Diffusion World Model: Future Modeling Beyond Step-by-Step Rollout for Offline Reinforcement Learning

Zihan Ding, Amy Zhang, Yuandong Tian, Qinqing Zheng
2024-10-16
Abstract: We introduce Diffusion World Model (DWM), a conditional diffusion model capable of predicting multistep future states and rewards concurrently. As opposed to traditional one-step dynamics models, DWM offers long-horizon predictions in a single forward pass, eliminating the need for recursive queries. We integrate DWM into model-based value estimation, where the short-term return is simulated by future trajectories sampled from DWM. In the context of offline reinforcement learning, DWM can be viewed as a conservative value regularization through generative modeling. Alternatively, it can be seen as a data source that enables offline Q-learning with synthetic data. Our experiments on the D4RL dataset confirm the robustness of DWM to long-horizon simulation. In terms of absolute performance, DWM significantly surpasses one-step dynamics models with a $44\%$ performance gain, and is comparable to or slightly surpassing their model-free counterparts.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
This paper addresses the error accumulation that arises in offline reinforcement learning when a traditional one-step dynamics model is queried recursively to make multi-step predictions. Such methods predict the states and rewards of future time steps by feeding each one-step prediction back into the model. Every query is therefore conditioned on an already imperfect prediction, and the error compounds quickly as the prediction horizon grows. This degrades prediction accuracy and limits how well model-based reinforcement learning methods can plan over long horizons.
To overcome this, the paper introduces the Diffusion World Model (DWM), a conditional diffusion model that predicts the states and rewards of multiple future steps jointly in a single forward pass, with no recursive queries. Predicting the future steps jointly is intended to reduce error accumulation in long-horizon prediction. DWM can also be viewed as a conservative value regularization through generative modeling, or as a data source that enables offline Q-learning with synthetic data.
In summary, the main goal of the paper is to improve the accuracy and robustness of long-horizon prediction in offline reinforcement learning, and thereby make model-based reinforcement learning methods more applicable to practical problems.
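The two ideas above, compounding error under recursive one-step rollout and a short-term return built from predicted future rewards, can be illustrated with a small toy sketch. This is not the paper's implementation and involves no diffusion model. The scalar dynamics, the 2% model bias, and the names `recursive_rollout` and `h_step_value_target` are assumptions chosen only for illustration. The value target follows the standard form of a discounted sum of predicted rewards plus a discounted bootstrap value, which may differ in detail from what DWM uses.

```python
from typing import Callable, List, Sequence


def recursive_rollout(step_fn: Callable[[float], float], s0: float, horizon: int) -> List[float]:
    """Predict `horizon` future states by feeding each prediction back into a one-step model."""
    states = []
    s = s0
    for _ in range(horizon):
        s = step_fn(s)  # the next query is conditioned on a predicted state, not a real one
        states.append(s)
    return states


def h_step_value_target(rewards: Sequence[float], terminal_value: float, gamma: float) -> float:
    """Discounted sum of H predicted rewards plus a discounted bootstrap value."""
    short_term = sum(gamma ** t * r for t, r in enumerate(rewards))
    return short_term + gamma ** len(rewards) * terminal_value


# Toy scalar system (assumption): the true dynamics keep the state fixed,
# while the hypothetical learned one-step model overshoots by 2% per step.
def true_step(s: float) -> float:
    return 1.0 * s


def learned_step(s: float) -> float:
    return 1.02 * s


HORIZON = 32
true_states = recursive_rollout(true_step, 1.0, HORIZON)
pred_states = recursive_rollout(learned_step, 1.0, HORIZON)
errors = [abs(p - t) for p, t in zip(pred_states, true_states)]

# The per-step bias is constant, yet the absolute error grows with the horizon.
for h in (1, 8, 32):
    print(f"horizon {h:2d}: absolute error {errors[h - 1]:.4f}")

# Value target from three (made-up) predicted rewards and a made-up bootstrap value.
target = h_step_value_target(rewards=[1.0, 1.0, 1.0], terminal_value=10.0, gamma=0.5)
print(f"3-step value target: {target}")
```

In this toy setting the one-step bias never changes, but the rollout error still grows geometrically because each step multiplies the previous error. A model that outputs the whole future segment at once does not re-consume its own predictions, which is the motivation the summary gives for DWM. The sketch does not show how large DWM's own multi-step errors are.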