Learning Decomposed Representation for Counterfactual Inference

Anpeng Wu, Kun Kuang, Junkun Yuan, Bo Li, Runze Wu, Qiang Zhu, Yueting Zhuang, Fei Wu
DOI: https://doi.org/10.48550/arXiv.2006.07040
2021-10-12
Abstract: The fundamental problem in treatment effect estimation from observational data is confounder identification and balancing. Most previous methods realize confounder balancing by treating all observed pre-treatment variables as confounders, without further distinguishing confounders from non-confounders. In general, not all observed pre-treatment variables are confounders, that is, common causes of both the treatment and the outcome; some variables contribute only to the treatment and some only to the outcome. Balancing those non-confounders, including instrumental variables and adjustment variables, would generate additional bias for treatment effect estimation. By modeling the different causal relations among observed pre-treatment variables, treatment and outcome, we propose a synergistic learning framework to 1) identify confounders by learning decomposed representations of both confounders and non-confounders, 2) balance confounders with a sample re-weighting technique, and simultaneously 3) estimate the treatment effect in observational studies via counterfactual inference. Empirical results on synthetic and real-world datasets demonstrate that the proposed method can precisely decompose confounders and achieve a more precise estimation of treatment effect than baselines.
Methodology, Machine Learning
What problem does this paper attempt to address?
The problem this paper addresses is how to accurately identify and balance confounders when drawing causal inferences from observational data, so that treatment effects are estimated more accurately. Specifically, the paper points out that most previous methods balance all observed pre-treatment variables as if they were confounders, without distinguishing confounders from non-confounders (such as instrumental variables and adjustment variables). This practice can introduce additional bias. The paper therefore proposes a synergistic learning framework that aims to:

1. Identify confounders by learning decomposed representations of confounders and non-confounders.
2. Balance confounders using a sample re-weighting technique.
3. Simultaneously estimate treatment effects in observational studies through counterfactual inference.

To achieve these goals, the authors propose an algorithm named Decomposed Representations for CounterFactual Regression (DeR-CFR). It consists of the following components:

- **Decomposed Representation Network**: Three representation networks learn the latent factors \(I(X)\) (instrumental), \(C(X)\) (confounding) and \(A(X)\) (adjustment), respectively.
- **Decomposition and Balance Regularizers**: Three regularizers are designed for:
  - Decomposing the adjustment factor \(A\), ensuring that \(A\perp T\) and that \(A\) predicts \(Y\) as accurately as possible.
  - Decomposing the instrumental factor \(I\), ensuring that \(I\perp Y|T\) and that \(I\) predicts \(T\) as accurately as possible.
  - Balancing the confounder \(C\) by minimizing the difference between the distributions of \(C\) in the different treatment groups.
- **Outcome Regression Network**: Two regression networks predict the potential outcomes, one network per treatment group.
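The components above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. The encoders and heads are untrained random single-layer maps, and all function and variable names are hypothetical. The squared difference of (weighted) group means is a simple stand-in for whatever distribution distance the paper uses. The sample weights are fixed to one rather than learned. Feeding the outcome heads with the concatenation of \(C(X)\) and \(A(X)\) is an assumption. The regularizer for \(I\perp Y|T\) is omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_encoder(d_in, d_rep, rng):
    """Hypothetical one-layer encoder: x -> tanh(x W + b), random untrained weights."""
    W = rng.normal(scale=0.5, size=(d_in, d_rep))
    b = np.zeros(d_rep)
    return lambda x: np.tanh(x @ W + b)

def make_head(d_in, rng):
    """Hypothetical linear outcome head: h -> h v, random untrained weights."""
    v = rng.normal(scale=0.5, size=d_in)
    return lambda h: h @ v

def mean_discrepancy(rep, t, w=None):
    """Squared distance between the (weighted) means of `rep` in the
    treated (t == 1) and control (t == 0) groups. A simple stand-in
    for a distribution distance between the two groups."""
    if w is None:
        w = np.ones(len(t))
    w1, w0 = w[t == 1], w[t == 0]
    mu1 = (w1[:, None] * rep[t == 1]).sum(axis=0) / w1.sum()
    mu0 = (w0[:, None] * rep[t == 0]).sum(axis=0) / w0.sum()
    return float(((mu1 - mu0) ** 2).sum())

# Toy observational data: treatment assignment depends on the first covariate.
n, d, d_rep = 200, 6, 4
X = rng.normal(size=(n, d))
T = (X[:, 0] + rng.normal(size=n) > 0).astype(int)

# Decomposed representation networks I(X), C(X), A(X).
enc_I, enc_C, enc_A = (make_encoder(d, d_rep, rng) for _ in range(3))
I, C, A = enc_I(X), enc_C(X), enc_A(X)

# Sample weights; learned jointly in a re-weighting method, fixed to 1 here.
w = np.ones(n)

# Regularizers: push A toward independence from T (unweighted), and
# balance C across treatment groups after re-weighting.
loss_A_indep = mean_discrepancy(A, T)
loss_C_balance = mean_discrepancy(C, T, w)

# Outcome regression: one head per treatment group (input [C, A] is an assumption).
H = np.concatenate([C, A], axis=1)
head0, head1 = make_head(H.shape[1], rng), make_head(H.shape[1], rng)
y0_hat, y1_hat = head0(H), head1(H)

# Estimated individual and average treatment effects.
ite_hat = y1_hat - y0_hat
ate_hat = float(ite_hat.mean())
```

In a trained model, these two regularizer values would be added to the factual outcome loss and to the treatment-prediction loss on \(I\), and minimized jointly over the network parameters and the sample weights. The relative weighting of the terms is a tuning choice and is not specified here.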
With these components, the DeR-CFR algorithm can more accurately identify and balance confounders in observational data, thereby improving the accuracy of treatment effect estimation. Experimental results on both synthetic and real-world datasets show that the method decomposes confounders precisely and outperforms baseline methods in treatment effect estimation.