Towards Causal Foundation Model: on Duality between Causal Inference and Attention

Jiaqi Zhang, Joel Jennings, Agrin Hilmkil, Nick Pawlowski, Cheng Zhang, Chao Ma
2024-06-04
Abstract: Foundation models have brought changes to the landscape of machine learning, demonstrating sparks of human-level intelligence across a diverse array of tasks. However, a gap persists in complex tasks such as causal inference, primarily due to challenges associated with intricate reasoning steps and high numerical precision requirements. In this work, we take a first step towards building causally-aware foundation models for treatment effect estimations. We propose a novel, theoretically justified method called Causal Inference with Attention (CInA), which utilizes multiple unlabeled datasets to perform self-supervised causal learning, and subsequently enables zero-shot causal inference on unseen tasks with new data. This is based on our theoretical results that demonstrate the primal-dual connection between optimal covariate balancing and self-attention, facilitating zero-shot causal inference through the final layer of a trained transformer-type architecture. We demonstrate empirically that CInA effectively generalizes to out-of-distribution datasets and various real-world datasets, matching or even surpassing traditional per-dataset methodologies. These results provide compelling evidence that our method has the potential to serve as a stepping stone for the development of causal foundation models.
Machine Learning, Artificial Intelligence, Methodology
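The abstract builds on covariate balancing: units are weighted so that the treated and control groups have matching covariate distributions, and the treatment effect is then read off as a weighted difference in outcomes. The following is a minimal sketch of that generic estimator in plain Python. The function names (`weighted_ate`, `covariate_imbalance`) and the choice to normalize weights separately within each arm are illustrative assumptions, not CInA's exact formulation.

```python
from typing import List, Sequence


def _weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    """Mean of `values` under nonnegative `weights`, normalized to sum to one."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights within each arm must have a positive sum")
    return sum(v * w for v, w in zip(values, weights)) / total


def _split(values, t, w, arm):
    """Return the values and weights of units whose treatment indicator equals `arm`."""
    vals = [v for v, ti in zip(values, t) if ti == arm]
    wts = [wi for wi, ti in zip(w, t) if ti == arm]
    return vals, wts


def weighted_ate(y, t, w) -> float:
    """Weighted difference in mean outcomes between treated (t=1) and control (t=0)."""
    return _weighted_mean(*_split(y, t, w, 1)) - _weighted_mean(*_split(y, t, w, 0))


def covariate_imbalance(x, t, w) -> List[float]:
    """Per-covariate gap between the weighted treated mean and weighted control mean."""
    gaps = []
    for k in range(len(x[0])):
        col = [row[k] for row in x]
        gaps.append(_weighted_mean(*_split(col, t, w, 1)) - _weighted_mean(*_split(col, t, w, 0)))
    return gaps


t = [1, 1, 0, 0]
x = [[1.0], [3.0], [0.0], [2.0]]
y = [3.0, 5.0, 1.0, 2.0]
print(covariate_imbalance(x, t, [1.0, 1.0, 1.0, 1.0]))  # [1.0]
print(covariate_imbalance(x, t, [1.0, 1.0, 0.0, 1.0]))  # [0.0]
print(weighted_ate(y, t, [1.0, 1.0, 0.0, 1.0]))  # 2.0
```

With uniform weights the two groups differ by 1.0 in their covariate mean and the naive estimate is 2.5; the reweighting that removes the imbalance gives 2.0 instead. Choosing such weights well, and doing so without refitting on every new dataset, is the part the paper addresses.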
What problem does this paper attempt to address?
The paper addresses causal inference in machine learning, in particular treatment effect estimation, a complex task on which current foundation models still fall short. Its goals can be summarized as follows:

1. **Constructing a Causally Aware Foundation Model**: Current foundation models (such as large language models) are limited on tasks that require intricate reasoning steps or high numerical precision, and causal inference is one such task. The authors therefore take a first step towards foundation models that can carry out causal inference, specifically treatment effect estimation.
2. **Proposing the CInA Method**: To this end, the authors propose Causal Inference with Attention (CInA). The method performs self-supervised causal learning on multiple unlabeled datasets and can then perform zero-shot causal inference on unseen tasks with new data.
3. **Theoretical Foundation and Algorithm Design**: The paper establishes a primal-dual connection between optimal covariate balancing and the self-attention mechanism, proving that under certain conditions a trained self-attention layer can find the optimal covariate balancing weights. Building on this result, the paper designs a practical gradient-based algorithm for zero-shot causal inference through the final layer of a trained transformer-type architecture.
4. **Empirical Validation**: Experiments on synthetic and real-world datasets, including out-of-distribution datasets, validate the effectiveness and generalization ability of CInA. In the zero-shot setting, CInA performs well and in some cases outperforms traditional methods fit separately to each dataset, while significantly reducing inference time.

In summary, the main contribution of the paper is a causally aware foundation model framework for treatment effect estimation that is theoretically grounded and whose effectiveness and practicality are demonstrated experimentally.
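Item 3 links balancing weights to self-attention and to gradient-based training. The sketch below is not CInA's algorithm: it swaps in a simpler per-dataset technique that keeps only the shared functional form. Weights over control units are parameterized as a softmax, the same form as one row of a softmax attention matrix, and the logits are fitted by gradient descent so the weighted control covariate mean matches a target (for example, the treated-group mean). CInA, per the summary above, instead obtains the weights from a trained self-attention layer shared across datasets, which is what enables zero-shot use. The names `balance_weights`, `lr`, and `steps` are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: nonnegative weights that sum to one."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def balance_weights(x_control, target_mean, lr=0.1, steps=2000):
    """Fit softmax weights over control units so their weighted covariate mean
    matches `target_mean`, by gradient descent on the squared imbalance
    L(theta) = ||sum_j p_j x_j - target||^2 with p = softmax(theta)."""
    n, d = len(x_control), len(target_mean)
    theta = [0.0] * n  # start from uniform weights
    for _ in range(steps):
        p = softmax(theta)
        mu = [sum(p[j] * x_control[j][k] for j in range(n)) for k in range(d)]
        g = [2.0 * (mu[k] - target_mean[k]) for k in range(d)]  # dL/dmu
        # Chain rule through the softmax: dL/dtheta_j = p_j * (g.x_j - g.mu)
        gx = [sum(g[k] * x_control[j][k] for k in range(d)) for j in range(n)]
        g_mu = sum(p[j] * gx[j] for j in range(n))
        theta = [theta[j] - lr * p[j] * (gx[j] - g_mu) for j in range(n)]
    return softmax(theta)


x_control = [[0.0], [1.0], [4.0]]
w = balance_weights(x_control, target_mean=[2.0])
print(sum(wi * xi[0] for wi, xi in zip(w, x_control)))  # approximately 2.0
```

If the target mean lies outside the convex hull of the control covariates, no softmax weights can match it exactly; the logits then drift toward the nearest extreme point, and the `steps` cap keeps the loop finite.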