Theory-guided Message Passing Neural Network for Probabilistic Inference
Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, PMLR 238:667-675, 2024.
Abstract
Probabilistic inference can be tackled by minimizing a variational free energy through message passing. To improve performance, neural networks are adopted for message computation. Neural message learning is heuristic, however, and requires strong guidance to perform well. In this work, we propose a theory-guided message passing neural network (TMPNN) for probabilistic inference. Inspired by existing work, we consider a generalized Bethe free energy which allows for a learnable variational assumption. Instead of using a black-box neural network for message computation, we utilize a general message equation and introduce a symbolic message function with semantically meaningful parameters. The analytically derived symbolic message function is integrated into the MPNN framework, giving rise to the proposed TMPNN. TMPNN is trained using algorithmic supervision and does not require exact inference results. Because its message function is derived from theory, TMPNN offers stronger theoretical guarantees than conventional heuristic neural models. It applies to both MAP and marginal inference and outperforms state-of-the-art methods on both tasks. Furthermore, TMPNN generalizes better across various graph structures and is more data-efficient.
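For readers unfamiliar with the message-passing setting the abstract starts from, the following is a minimal sketch of standard sum-product belief propagation on a small pairwise model. Its fixed points are known to correspond to stationary points of the ordinary Bethe free energy. It is not the TMPNN method: the generalized free energy, the symbolic message function, and the training procedure are not reproduced here. All names (`sum_product`, `brute_force_marginals`, the toy potentials) are hypothetical and chosen only for illustration.

```python
import itertools
import math


def sum_product(n_states, unary, pairwise, edges, n_iters=20):
    """Plain sum-product belief propagation on a pairwise model.

    unary[i][x]             : potential of node i in state x
    pairwise[(i, j)][xi][xj]: potential of edge (i, j)
    Returns normalized beliefs, which are exact marginals on trees.
    """
    n = len(unary)
    nbrs = {i: [] for i in range(n)}
    msgs = {}  # msgs[(i, j)][xj] is the message from node i to node j
    for i, j in edges:
        nbrs[i].append(j)
        nbrs[j].append(i)
        msgs[(i, j)] = [1.0 / n_states] * n_states
        msgs[(j, i)] = [1.0 / n_states] * n_states

    def psi(i, j, xi, xj):
        # Edge potentials are stored once per undirected edge.
        if (i, j) in pairwise:
            return pairwise[(i, j)][xi][xj]
        return pairwise[(j, i)][xj][xi]

    for _ in range(n_iters):
        new_msgs = {}
        for (i, j) in msgs:
            out = []
            for xj in range(n_states):
                total = 0.0
                for xi in range(n_states):
                    # Product of messages into i from all neighbors except j.
                    incoming = math.prod(
                        msgs[(k, i)][xi] for k in nbrs[i] if k != j
                    )
                    total += unary[i][xi] * psi(i, j, xi, xj) * incoming
                out.append(total)
            z = sum(out)
            new_msgs[(i, j)] = [v / z for v in out]
        msgs = new_msgs  # parallel (synchronous) update

    beliefs = []
    for i in range(n):
        b = [
            unary[i][x] * math.prod(msgs[(k, i)][x] for k in nbrs[i])
            for x in range(n_states)
        ]
        z = sum(b)
        beliefs.append([v / z for v in b])
    return beliefs


def brute_force_marginals(n_states, unary, pairwise, edges):
    """Exact marginals by enumerating every joint state (tiny models only)."""
    n = len(unary)
    marg = [[0.0] * n_states for _ in range(n)]
    for x in itertools.product(range(n_states), repeat=n):
        w = math.prod(unary[i][x[i]] for i in range(n))
        w *= math.prod(pairwise[(i, j)][x[i]][x[j]] for i, j in edges)
        for i in range(n):
            marg[i][x[i]] += w
    z = sum(marg[0])
    return [[v / z for v in row] for row in marg]


# Toy example (assumed values): a binary chain 0 - 1 - 2.
edges = [(0, 1), (1, 2)]
unary = [[0.7, 0.3], [0.5, 0.5], [0.2, 0.8]]
pairwise = {
    (0, 1): [[2.0, 1.0], [1.0, 2.0]],
    (1, 2): [[1.0, 3.0], [3.0, 1.0]],
}
beliefs = sum_product(2, unary, pairwise, edges)
exact = brute_force_marginals(2, unary, pairwise, edges)
```

On a tree such as this chain, the beliefs agree with the brute-force marginals. On graphs with cycles, plain belief propagation is only approximate, which is the regime where learned or generalized message functions such as the one proposed here are meant to help.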