Causal Invariance-aware Augmentation for Brain Graph Contrastive Learning
Proceedings of the 42nd International Conference on Machine Learning, PMLR 267:73221-73240, 2025.
Abstract
Deep models are increasingly used to analyze brain graphs for the diagnosis and understanding of brain diseases. However, because of multi-site data aggregation and individual differences, brain graph datasets exhibit widespread distribution shifts. These shifts impair a model's ability to generalize to the test set and thereby limit the performance of existing methods. To address this, we propose Causal Invariance-aware Augmentation for brain Graph Contrastive Learning (CIA-GCL). The method first constructs a brain graph by extracting node features based on its topological structure. Then, a learnable invariant brain subgraph is identified through a causal decoupling approach, using invariant learning to capture as much label-related invariant information as possible. Around this invariant subgraph, we design a novel invariance-aware augmentation strategy that generates meaningful augmented samples for graph contrastive learning. Finally, the extracted invariant subgraph is used for brain disease classification. This mitigates distribution shifts and also identifies critical local graph structures, enhancing the model's interpretability. Experiments on three real-world brain disease datasets demonstrate that our method achieves state-of-the-art performance, generalizes effectively to multi-site brain datasets, and provides a degree of interpretability.
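The core idea of perturbing a graph while leaving a label-relevant subgraph untouched can be illustrated with a small NumPy sketch. This is not the authors' implementation. Several pieces are assumptions made only for illustration. The brain graph is built from thresholded correlations of synthetic ROI signals. A fixed score matrix (`edge_scores`) stands in for the learned causal decoupling. The top `keep_ratio` fraction of edges is treated as the "invariant" subgraph. Augmentation is plain random edge dropping restricted to the complementary (variant) edges. The function names `split_invariant_subgraph` and `invariance_aware_augment` are hypothetical.

```python
import numpy as np


def split_invariant_subgraph(adj, edge_scores, keep_ratio=0.3):
    """Split an undirected weighted graph into an 'invariant' part and its complement.

    Hypothetical stand-in for a learned causal mask: keep the top `keep_ratio`
    fraction of existing edges, ranked by `edge_scores` (higher = more label-relevant).
    """
    n = adj.shape[0]
    iu = np.triu_indices(n, k=1)                      # upper-triangle edge slots
    exists = adj[iu] != 0                             # slots that hold real edges
    n_edges = int(exists.sum())
    k = min(n_edges, max(1, int(round(keep_ratio * n_edges))))
    scores = np.where(exists, edge_scores[iu], -np.inf)  # non-edges can never be picked
    top = np.argsort(scores)[::-1][:k]                # k highest-scoring existing edges
    mask_upper = np.zeros(scores.shape, dtype=bool)
    mask_upper[top] = True
    mask = np.zeros((n, n), dtype=bool)
    mask[iu] = mask_upper
    mask = mask | mask.T                              # keep the mask symmetric
    invariant = np.where(mask, adj, 0.0)
    variant = adj - invariant                         # everything outside the invariant subgraph
    return invariant, variant


def invariance_aware_augment(adj, invariant, drop_prob=0.2, rng=None):
    """Randomly drop edges, but only among the variant (non-invariant) edges."""
    rng = np.random.default_rng() if rng is None else rng
    variant = adj - invariant
    n = adj.shape[0]
    keep = np.triu(rng.random((n, n)) >= drop_prob, k=1)  # one coin flip per undirected edge
    keep = keep | keep.T                              # mirror so the view stays symmetric
    np.fill_diagonal(keep, True)                      # leave any self-loops as they are
    return invariant + np.where(keep, variant, 0.0)


# Usage with synthetic data: 10 hypothetical ROIs x 50 time points.
rng = np.random.default_rng(0)
signals = rng.standard_normal((10, 50))
corr = np.corrcoef(signals)
adj = np.where(np.abs(corr) > 0.1, corr, 0.0)         # assumed thresholding rule
np.fill_diagonal(adj, 0.0)
edge_scores = rng.random((10, 10))
edge_scores = (edge_scores + edge_scores.T) / 2       # symmetric placeholder scores

inv, var = split_invariant_subgraph(adj, edge_scores, keep_ratio=0.3)
view1 = invariance_aware_augment(adj, inv, drop_prob=0.2, rng=rng)
view2 = invariance_aware_augment(adj, inv, drop_prob=0.2, rng=rng)
```

Both views share the same invariant edges and differ only in which variant edges survive. A contrastive objective applied to `view1` and `view2` would therefore be pushed toward representations driven by the preserved subgraph. In the paper, the edge scores are learned jointly with the encoder rather than fixed.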