BECAUSE: Bilinear Causal Representation for Generalizable Offline Model-based Reinforcement Learning

Haohong Lin, Wenhao Ding, Jian Chen, Laixi Shi, Jiacheng Zhu, Bo Li, Ding Zhao
2024-07-16
Abstract: Offline model-based reinforcement learning (MBRL) enhances data efficiency by utilizing pre-collected datasets to learn models and policies, especially in scenarios where exploration is costly or infeasible. Nevertheless, its performance often suffers from the objective mismatch between model and policy learning, resulting in inferior performance despite accurate model predictions. This paper first identifies that the primary source of this mismatch comes from the underlying confounders present in offline data for MBRL. Subsequently, we introduce BilinEar CAUSal rEpresentation (BECAUSE), an algorithm to capture causal representation for both states and actions to reduce the influence of the distribution shift, thus mitigating the objective mismatch problem. Comprehensive evaluations on 18 tasks that vary in data quality and environment context demonstrate the superior performance of BECAUSE over existing offline RL algorithms. We show the generalizability and robustness of BECAUSE under fewer samples or larger numbers of confounders. Additionally, we offer theoretical analysis of BECAUSE to prove its error bound and sample efficiency when integrating causal representation into offline MBRL.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
This paper addresses the objective mismatch between model learning and policy learning in offline reinforcement learning (RL), which arises from distribution shift. Specifically, it points out that in offline model-based reinforcement learning (MBRL), underlying confounders in the dataset can cause the policy to perform poorly even when the learned model predicts accurately.

These problems mainly stem from two kinds of distribution shift:

1. Distribution shift between the offline, sub-optimal behavior policy and the online optimal policy.
2. Distribution shift between the data collection environment and the online test environment.

To solve these problems, the paper proposes the **BilinEar CAUSal rEpresentation (BECAUSE)** algorithm, which reduces the impact of distribution shift by capturing causal representations of states and actions, thereby alleviating the objective mismatch problem. The core of the BECAUSE algorithm lies in:

- **Causal representation learning**: Learning causal representations of states and actions reduces the spurious correlations caused by underlying confounders.
- **Uncertainty quantification**: The learned causal representations are used to quantify the uncertainty of state transitions, so that planning can be more conservative and avoid entering out-of-distribution states.

Through these methods, BECAUSE aims to improve the generalization ability and robustness of offline MBRL algorithms. The paper verifies the effectiveness of BECAUSE through experiments on 18 tasks, reporting superior performance across different levels of data quality and different environment contexts.
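The two ideas above can be illustrated with a toy sketch. It is not the paper's implementation: the bilinear form `phi^T (M * mask) mu`, the binary `mask`, the count-based uncertainty term, and every function and variable name here are assumptions chosen for illustration only. The first function shows how a mask can zero out spurious feature interactions in a bilinear transition score; the second shows how an uncertainty penalty makes planning more conservative toward poorly covered transitions.

```python
import math


def masked_bilinear_score(phi, M, mask, mu):
    """Bilinear score phi^T (M * mask) mu.

    phi:  features of a (state, action) pair      (hypothetical)
    mu:   features of a candidate next state      (hypothetical)
    M:    dense interaction weights
    mask: binary matrix; a 0 entry removes an interaction that is
          assumed to be spurious (non-causal)
    """
    return sum(
        phi[i] * M[i][j] * mask[i][j] * mu[j]
        for i in range(len(phi))
        for j in range(len(mu))
    )


def pessimistic_value(reward, visit_count, beta=1.0):
    """Predicted reward minus an uncertainty penalty.

    The penalty beta / sqrt(count + 1) is a simple count-based stand-in
    for an uncertainty estimate; it shrinks as data coverage grows.
    """
    return reward - beta / math.sqrt(visit_count + 1)


def choose_action(candidates, beta=1.0):
    """Pick the action with the highest pessimistic value.

    candidates maps an action name to (predicted_reward, visit_count).
    """
    return max(
        candidates,
        key=lambda name: pessimistic_value(*candidates[name], beta=beta),
    )


if __name__ == "__main__":
    phi, mu = [1.0, 2.0], [1.0, 1.0]
    M = [[0.5, 1.0], [2.0, 0.0]]
    mask = [[1, 0], [1, 1]]  # drops the (0, 1) interaction
    print(masked_bilinear_score(phi, M, mask, mu))

    candidates = {"well_covered": (1.0, 99), "rarely_seen": (1.2, 0)}
    print(choose_action(candidates, beta=1.0))
    print(choose_action(candidates, beta=0.0))
```

With `beta=0.0` the planner is purely greedy on predicted reward, while a positive `beta` trades some predicted reward for staying close to transitions that are well supported by the offline data.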