Latent Dataset Distillation with Diffusion Models

Brian B. Moser, Federico Raue, Sebastian Palacio, Stanislav Frolov, Andreas Dengel
2024-07-11
Abstract: Machine learning traditionally relies on increasingly large datasets. Yet, such datasets pose major storage challenges and usually contain non-influential samples, which could be ignored during training without negatively impacting training quality. In response, the idea of distilling a dataset into a condensed set of synthetic samples, i.e., a distilled dataset, emerged. One key aspect is the architecture selected for linking the original and synthetic datasets, usually a ConvNet. However, the final accuracy is lower if the employed model architecture differs from the one used during distillation. Another challenge is the generation of high-resolution images (128x128 and higher). To address both challenges, this paper proposes Latent Dataset Distillation with Diffusion Models (LD3M), which combines diffusion in latent space with dataset distillation. Our novel diffusion process is tailored for this task and significantly improves the gradient flow for distillation. By adjusting the number of diffusion steps, LD3M also offers a convenient way of controlling the trade-off between distillation speed and dataset quality. Overall, LD3M consistently outperforms state-of-the-art methods by up to 4.8 p.p. and 4.2 p.p. for 1 and 10 images per class, respectively, on several ImageNet subsets and at high resolutions (128x128 and 256x256).
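Two claims in the abstract concern the unrolled diffusion chain: the number of steps sets the cost of each distillation update, and the design of the process affects how well gradients flow back to the starting latent. The toy sketch below illustrates both effects under loudly stated assumptions: each reverse step is modeled as multiplication by a hypothetical contraction factor `a`, and a hypothetical weight `w` re-injects the initial latent at every step as a skip path. It is not LD3M's diffusion process; it only shows that `steps` unrolled updates cost one update each and can shrink the derivative reaching the initial latent, and that a skip path from that latent keeps it from vanishing.

```python
# Toy model of backpropagating through an unrolled chain of reverse steps.
# ASSUMPTIONS (not LD3M's process): each step multiplies the latent by a
# hypothetical contraction factor `a`; the hypothetical weight `w` mixes the
# initial latent back in at every step as a skip path.


def unroll(z_init, steps, a=0.8, w=0.0):
    """Run `steps` toy updates z <- (1 - w) * a * z + w * z_init.

    Returns the final latent and its derivative with respect to z_init,
    accumulated with the chain rule alongside the forward pass.
    """
    z, dz = z_init, 1.0
    for _ in range(steps):  # one update (one "denoiser call") per step
        z = (1 - w) * a * z + w * z_init
        dz = (1 - w) * a * dz + w  # d/dz_init of the update above
    return z, dz


for steps in (5, 20, 50):
    _, plain = unroll(1.0, steps)
    _, skip = unroll(1.0, steps, w=0.1)
    print(f"steps={steps:2d}  plain grad={plain:.2e}  with skip={skip:.3f}")
```

With `a = 0.8`, the plain chain's derivative is exactly `0.8**steps`, about 1.4e-5 after 50 steps, while `w = 0.1` keeps it near the fixed point `w / (1 - (1 - w) * a)`, roughly 0.357. Fewer steps mean fewer updates per distillation iteration and a stronger plain gradient, which is the kind of speed/quality lever the abstract describes.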
Computer Vision and Pattern Recognition, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
The paper aims to address the following issues:

1. **Dataset Compression**: Machine learning typically relies on increasingly large datasets, which pose significant storage challenges and often contain many samples that could be ignored during training without harming training quality. Researchers have therefore proposed compressing the original dataset into a small set of synthetic samples (i.e., a distilled dataset).
2. **Cross-Architecture Generalization**: Existing methods lose accuracy when the model architecture trained on the distilled dataset differs from the one used during distillation (usually a ConvNet).
3. **High-Resolution Image Generation**: Existing methods struggle to generate high-resolution images (128x128 and higher). Directly optimizing pixel values can cause the distilled dataset to overfit to the network used during distillation.

To address these issues, the paper proposes **Latent Dataset Distillation with Diffusion Models (LD3M)**, which combines diffusion in latent space with dataset distillation. Its diffusion process, tailored to this task, improves the gradient flow during distillation, and adjusting the number of diffusion steps controls the trade-off between distillation speed and dataset quality. Experimental results show that LD3M outperforms existing state-of-the-art methods by up to 4.8 p.p. (1 image per class) and 4.2 p.p. (10 images per class) on several ImageNet subsets at resolutions of 128x128 and 256x256.
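The core idea of distilling in latent space rather than in pixel space can be illustrated with a deliberately tiny sketch: each synthetic sample is parameterized by a latent code, decoded by a fixed generator, and the gradient of the distillation loss flows through the decoder back into that code. Everything here is a hypothetical assumption for illustration: `DECODER` is a fixed 2x2 linear map standing in for a pretrained generator, the data are 2-D Gaussian clusters, and the loss simply matches each class mean (a simplified, distribution-matching-style objective swapped in for this example; it is not LD3M's objective and involves no diffusion steps).

```python
import random

# Hypothetical fixed linear "decoder" standing in for a pretrained generator
# that maps a 2-D latent code to a 2-D data point (not LD3M's model).
DECODER = [[1.0, 0.5],
           [-0.3, 2.0]]


def decode(z):
    """Map a latent code to data space: x = D z."""
    return [sum(DECODER[i][j] * z[j] for j in range(2)) for i in range(2)]


def distill_latent(real_samples, lr=0.05, iters=500):
    """Learn one synthetic latent for a class (1 sample per class).

    Swapped-in objective (assumption): match the class mean,
    loss = ||D z - mu||^2, whose gradient w.r.t. z is 2 * D^T (D z - mu).
    Only the latent z is updated; the decoder stays fixed.
    """
    n = len(real_samples)
    mu = [sum(x[i] for x in real_samples) / n for i in range(2)]
    z = [0.0, 0.0]
    for _ in range(iters):
        residual = [xi - mi for xi, mi in zip(decode(z), mu)]
        grad = [2 * sum(DECODER[i][j] * residual[i] for i in range(2))
                for j in range(2)]
        z = [zj - lr * gj for zj, gj in zip(z, grad)]
    return z, mu


rng = random.Random(0)
class_centers = {"class_a": (1.0, 2.0), "class_b": (-2.0, 0.5)}
for name, (cx, cy) in class_centers.items():
    real = [[rng.gauss(cx, 0.3), rng.gauss(cy, 0.3)] for _ in range(200)]
    z, mu = distill_latent(real)
    print(name, "latent:", [round(v, 3) for v in z],
          "decoded:", [round(v, 3) for v in decode(z)],
          "class mean:", [round(v, 3) for v in mu])
```

Because this toy decoder is invertible, the decoded sample can reach the class mean exactly. With a real generator, optimizing the latent restricts synthetic samples to outputs the generator can produce, which is roughly the intuition behind avoiding direct pixel optimization.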