Training Diffusion Models with Reinforcement Learning

Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine
2024-01-05
Abstract: Diffusion models are a class of flexible generative models trained with an approximation to the log-likelihood objective. However, most use cases of diffusion models are not concerned with likelihoods, but instead with downstream objectives such as human-perceived image quality or drug effectiveness. In this paper, we investigate reinforcement learning methods for directly optimizing diffusion models for such objectives. We describe how posing denoising as a multi-step decision-making problem enables a class of policy gradient algorithms, which we refer to as denoising diffusion policy optimization (DDPO), that are more effective than alternative reward-weighted likelihood approaches. Empirically, DDPO is able to adapt text-to-image diffusion models to objectives that are difficult to express via prompting, such as image compressibility, and those derived from human feedback, such as aesthetic quality. Finally, we show that DDPO can improve prompt-image alignment using feedback from a vision-language model without the need for additional data collection or human annotation. The project's website can be found at http://rl-diffusion.github.io.
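The abstract's central idea, treating denoising as a multi-step decision-making problem, can be written out as a Markov decision process. The block below is a sketch in our own indexing (the chain runs from the noise sample x_T to the final image x_0, and c is the prompt); the paper's exact notation may differ. The reward arrives only once x_0 is produced, and the objective is the expected reward. The first gradient is the score-function (REINFORCE) estimator, called DDPO_SF in the paper; the second is the importance-sampled estimator DDPO_IS, which evaluates trajectories collected under an older policy. Both rely on each denoising step being Gaussian, so its log-likelihood is exact and differentiable.

```latex
\begin{aligned}
s_t &= (c,\ t,\ x_t), \qquad a_t = x_{t-1}, \qquad
\pi_\theta(a_t \mid s_t) = p_\theta(x_{t-1} \mid x_t, c), \qquad
c \sim p(c),\ x_T \sim \mathcal{N}(0, I), \\
R(s_t, a_t) &=
\begin{cases}
  r(x_0, c) & \text{if } t = 1 \ (\text{the step that produces } x_0), \\
  0 & \text{otherwise},
\end{cases} \\
\mathcal{J}(\theta) &= \mathbb{E}_{c \sim p(c),\ x_0 \sim p_\theta(\cdot \mid c)}\big[\, r(x_0, c) \,\big], \\
\nabla_\theta \mathcal{J}_{\mathrm{SF}} &= \mathbb{E}_{x_{T:0} \sim p_\theta}\Big[\, r(x_0, c) \sum_{t=1}^{T} \nabla_\theta \log p_\theta(x_{t-1} \mid x_t, c) \Big], \\
\nabla_\theta \mathcal{J}_{\mathrm{IS}} &= \mathbb{E}_{x_{T:0} \sim p_{\theta_{\text{old}}}}\Big[\, r(x_0, c) \sum_{t=1}^{T}
  \frac{p_\theta(x_{t-1} \mid x_t, c)}{p_{\theta_{\text{old}}}(x_{t-1} \mid x_t, c)}
  \nabla_\theta \log p_\theta(x_{t-1} \mid x_t, c) \Big].
\end{aligned}
```

In the importance-sampled form, the likelihood ratio is clipped to a trust region in the style of PPO, which allows several gradient steps per batch of sampled trajectories.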
Machine Learning, Artificial Intelligence, Computer Vision and Pattern Recognition
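To make the score-function estimator concrete, here is a minimal sketch on a hypothetical one-dimensional toy problem, not the paper's text-to-image setup. Each of `T` denoising steps shifts the sample by a learned scalar `theta[t]` and adds Gaussian noise with fixed scale `SIGMA`, and an assumed reward prefers final samples near `TARGET`. Because every step is Gaussian, the gradient of its log-likelihood has a closed form, which is the property DDPO depends on. Rewards are normalized within each batch to act as advantages, similar in spirit to the reward normalization DDPO applies; all names and hyperparameters here are illustrative.

```python
import math
import random

# Hypothetical toy "denoiser": step t maps x_t to x_{t-1} ~ N(x_t + theta[t], SIGMA^2).
T = 10
SIGMA = 0.5
TARGET = 3.0


def reward(x0: float) -> float:
    """Assumed black-box reward, observed only on the final sample x_0."""
    return -(x0 - TARGET) ** 2


def sample_trajectory(theta):
    """Run the chain x_T -> x_0 and record, for each step, the score
    d/dtheta[t] log N(x_{t-1}; x_t + theta[t], SIGMA^2) = eps / SIGMA."""
    x = random.gauss(0.0, 1.0)  # x_T ~ N(0, 1)
    scores = []
    for t in range(T):
        eps = random.gauss(0.0, 1.0)
        x = x + theta[t] + SIGMA * eps
        scores.append(eps / SIGMA)
    return x, scores


def ddpo_sf_step(theta, batch_size=64, lr=0.05):
    """One score-function (REINFORCE) update on a fresh batch of trajectories."""
    samples = [sample_trajectory(theta) for _ in range(batch_size)]
    rewards = [reward(x0) for x0, _ in samples]
    # Normalize rewards across the batch so they act as advantages.
    mean = sum(rewards) / batch_size
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / batch_size) + 1e-8
    advantages = [(r - mean) / std for r in rewards]
    # Every step of a trajectory is weighted by the same trajectory-level advantage.
    grads = [0.0] * T
    for adv, (_, scores) in zip(advantages, samples):
        for t in range(T):
            grads[t] += adv * scores[t] / batch_size
    # Gradient ascent on expected reward.
    return [th + lr * g for th, g in zip(theta, grads)]


def train(iterations=200, seed=0):
    random.seed(seed)
    theta = [0.0] * T
    for _ in range(iterations):
        theta = ddpo_sf_step(theta)
    return theta


if __name__ == "__main__":
    theta = train()
    final = [sample_trajectory(theta)[0] for _ in range(1000)]
    print(f"mean x_0 after training: {sum(final) / len(final):.2f} (target {TARGET})")
```

Running it should print a mean x_0 close to the target. The reward is a black box here: the update never differentiates through `reward`, only through the per-step log-likelihoods, which is why non-differentiable objectives such as file size can be optimized the same way.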
What problem does this paper attempt to address?
### Main Purpose and Research Questions

This paper explores how reinforcement learning can be used to optimize diffusion models directly for downstream objectives, such as human-perceived image quality or drug effectiveness, rather than for likelihood alone.

### Main Contributions and Methods

- **Proposed DDPO algorithm**: The authors propose denoising diffusion policy optimization (DDPO), a family of policy gradient algorithms for fine-tuning diffusion models on downstream tasks. By treating the denoising process as a multi-step decision-making problem, DDPO optimizes the model directly against a black-box reward function.
- **Comparison with alternative methods**: The paper compares DDPO with reward-weighted likelihood approaches (reward-weighted regression) and finds DDPO more effective across tasks. DDPO adapts to objectives that are difficult to specify through prompting, such as image compressibility and aesthetic quality, and can use feedback from a vision-language model (VLM) to improve prompt-image alignment without additional data collection or human annotation.
- **Experimental validation**: Experiments on image compressibility, incompressibility, and aesthetic quality validate DDPO's effectiveness. They also show that DDPO combined with a VLM can automatically improve the alignment between images generated by a pretrained model and their text prompts, without additional human labels. Furthermore, these improvements generalize to prompts not seen during fine-tuning.

### Application Scenarios

- **Compressibility and incompressibility**: With image file size as the reward, DDPO learns to generate images that are easy or hard to compress.
- **Aesthetic quality**: The LAION aesthetic predictor serves as the reward function for improving the aesthetic quality of images.
- **VLM-assisted prompt-image alignment**: A vision-language model (such as LLaVA) describes each generated image, and the similarity between that description and the original prompt serves as the reward for improving prompt-image alignment.

### Conclusion and Future Work

The paper proposes a framework for optimizing diffusion models directly with reinforcement learning, centered on the DDPO algorithm, which handles a range of challenging downstream tasks. Possible future directions include broadening the use of vision-language models and enriching the prompt distribution, as well as studying over-optimization so that fine-tuned models remain useful and generalize well.
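The reward functions listed under Application Scenarios are black boxes from DDPO's point of view: each maps a generated image (and possibly its prompt) to a scalar. The sketch below shows that interface with two deliberate stand-ins. The paper measures JPEG file size for compressibility; zlib is substituted here only because it ships with Python. The paper scores a VLM's description against the prompt with a learned text-similarity metric; the token-recall function here is a crude lexical substitute, and `describe` is a hypothetical callable standing in for a VLM such as LLaVA.

```python
import random
import zlib


def compressed_size_kb(pixels: bytes) -> float:
    """Size of the compressed raw image bytes in kilobytes.
    Stand-in: zlib instead of the JPEG encoder used for the paper's task."""
    return len(zlib.compress(pixels, 9)) / 1024


def compressibility_reward(pixels: bytes) -> float:
    # Smaller files earn higher reward.
    return -compressed_size_kb(pixels)


def incompressibility_reward(pixels: bytes) -> float:
    # Larger files earn higher reward.
    return compressed_size_kb(pixels)


def token_recall(description: str, prompt: str) -> float:
    """Fraction of distinct prompt words that also appear in the description.
    Stand-in for a learned text-similarity score."""
    prompt_tokens = set(prompt.lower().split())
    desc_tokens = set(description.lower().split())
    if not prompt_tokens:
        return 0.0
    return len(prompt_tokens & desc_tokens) / len(prompt_tokens)


def alignment_reward(image, prompt: str, describe) -> float:
    """`describe` is a hypothetical VLM wrapper that returns a caption for `image`."""
    return token_recall(describe(image), prompt)


if __name__ == "__main__":
    flat = bytes(64 * 64 * 3)  # an all-black 64x64 RGB image
    rng = random.Random(0)
    noisy = bytes(rng.randrange(256) for _ in range(64 * 64 * 3))
    print(compressibility_reward(flat) > compressibility_reward(noisy))  # True
    fake_vlm = lambda image: "a raccoon riding a bike"
    print(alignment_reward(None, "a raccoon washing dishes", fake_vlm))  # 0.5
```

Keeping each reward behind a plain function boundary mirrors why DDPO only needs reward values, not reward gradients: swapping the zlib stand-in for a real JPEG encoder, or the token recall for a learned metric, would not change the policy-gradient update.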