Disentangled Counterfactual Recurrent Networks for Treatment Effect Inference over Time

Jeroen Berrevoets, Alicia Curth, Ioana Bica, Eoin McKinney, Mihaela van der Schaar
DOI: https://doi.org/10.48550/arXiv.2112.03811
2021-12-08
Abstract: Choosing the best treatment plan for each individual patient requires accurate forecasts of their outcome trajectories over time, as a function of the treatment. While large observational data sets are rich sources of information to learn from, they also contain biases, because treatments are rarely assigned randomly in practice. To provide accurate and unbiased forecasts, we introduce the Disentangled Counterfactual Recurrent Network (DCRN), a novel sequence-to-sequence architecture that estimates treatment outcomes over time by learning representations of patient histories that are disentangled into three separate latent factors: a treatment factor, influencing only treatment selection; an outcome factor, influencing only the outcome; and a confounding factor, influencing both. With an architecture fully inspired by the causal structure of treatment influence over time, we advance both forecast accuracy and disease understanding: unlike other approaches in this domain, our architecture allows practitioners to infer which patient features influence which parts of a patient's trajectory. We demonstrate that DCRN outperforms current state-of-the-art methods in forecasting treatment responses on both real and simulated data.
Machine Learning
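The three-factor split described in the abstract can be illustrated with a small toy model. The sketch below is not the authors' implementation; it only shows the wiring implied by the abstract, under several assumptions: a plain Elman-style RNN stands in for DCRN's recurrent encoder, the treatment is binary, the outcome head predicts a single next-step value, and all weights are random and untrained. The class `DisentangledEncoderSketch` and its attributes (`P_treat`, `P_conf`, `P_out`, `w_propensity`, `w_outcome`) are hypothetical names. The training losses that actually force each factor to carry only its intended information are not described in the abstract and are not modeled here.

```python
import math
import random


def matvec(W, x):
    # Multiply a matrix W (list of rows) by a vector x.
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def init(rows, cols, rng):
    # Small random weights; illustrative only, never trained.
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


class DisentangledEncoderSketch:
    """Hypothetical toy: an RNN summarizes a patient history, and its final
    state is projected into treatment, confounding, and outcome factors."""

    def __init__(self, n_in, n_hidden, n_factor, seed=0):
        rng = random.Random(seed)
        self.n_hidden = n_hidden
        self.W_x = init(n_hidden, n_in, rng)      # input-to-hidden weights
        self.W_h = init(n_hidden, n_hidden, rng)  # hidden-to-hidden weights
        # One linear projection per latent factor.
        self.P_treat = init(n_factor, n_hidden, rng)
        self.P_conf = init(n_factor, n_hidden, rng)
        self.P_out = init(n_factor, n_hidden, rng)
        # Treatment-selection head: sees treatment + confounding factors only.
        self.w_propensity = [rng.uniform(-0.5, 0.5) for _ in range(2 * n_factor)]
        # Outcome head: sees confounding + outcome factors plus the treatment.
        self.w_outcome = [rng.uniform(-0.5, 0.5) for _ in range(2 * n_factor + 1)]

    def encode(self, history):
        # Run a tanh RNN over the time steps, then split the final state.
        h = [0.0] * self.n_hidden
        for x_t in history:
            pre = [a + b for a, b in zip(matvec(self.W_x, x_t), matvec(self.W_h, h))]
            h = [math.tanh(v) for v in pre]
        return matvec(self.P_treat, h), matvec(self.P_conf, h), matvec(self.P_out, h)

    def propensity(self, treat, conf):
        # Probability of receiving treatment; the outcome factor is excluded.
        z = sum(w * v for w, v in zip(self.w_propensity, treat + conf))
        return 1.0 / (1.0 + math.exp(-z))

    def outcome(self, conf, out, treatment):
        # Predicted outcome; the treatment factor is excluded.
        return sum(w * v for w, v in zip(self.w_outcome, conf + out + [treatment]))


model = DisentangledEncoderSketch(n_in=3, n_hidden=4, n_factor=2, seed=42)
history = [[0.1, 0.2, 0.3], [0.0, -0.1, 0.4]]  # two time steps, three features
treat, conf, out = model.encode(history)
p = model.propensity(treat, conf)
y1 = model.outcome(conf, out, 1.0)  # outcome under treatment
y0 = model.outcome(conf, out, 0.0)  # outcome without treatment
print(len(treat), len(conf), len(out))  # 2 2 2
print(0.0 < p < 1.0)  # True
```

The design point the sketch tries to make is that each head only receives the factors the abstract says may influence it: the treatment-selection head never sees the outcome factor, and the outcome head never sees the treatment factor. This routing is what lets one ask which parts of the history drive treatment choice versus outcome; in a real model, the factors would only acquire that meaning through training objectives that the abstract does not specify.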