Diffusion-Based Neural Network Weights Generation

Bedionita Soro, Bruno Andreis, Hayeon Lee, Wonyong Jeong, Song Chong, Frank Hutter, Sung Ju Hwang
2024-10-25
Abstract: Transfer learning has gained significant attention in recent deep learning research due to its ability to accelerate convergence and enhance performance on new tasks. However, its success is often contingent on the similarity between source and target data, and training on numerous datasets can be costly, leading to blind selection of pretrained models with limited insight into their effectiveness. To address these challenges, we introduce D2NWG, a diffusion-based neural network weights generation technique that efficiently produces high-performing weights for transfer learning, conditioned on the target dataset. Our method extends generative hyper-representation learning to recast the latent diffusion paradigm for neural network weights generation, learning the weight distributions of models pretrained on various datasets. This allows for automatic generation of weights that generalize well across both seen and unseen tasks, outperforming state-of-the-art meta-learning methods and pretrained models. Moreover, our approach is scalable to large architectures such as large language models (LLMs), overcoming the limitations of current parameter generation techniques that rely on task-specific model collections or access to original training data. By modeling the parameter distribution of LLMs, D2NWG enables task-specific parameter generation without requiring additional fine-tuning or large collections of model variants. Extensive experiments show that our method consistently enhances the performance of diverse base models, regardless of their size or complexity, positioning it as a robust solution for scalable transfer learning.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The main problem this paper addresses is the limitations of existing transfer learning methods when they are applied to new tasks, specifically:

1. **Dataset Similarity Requirement**: Transfer learning usually succeeds only when the source and target datasets are similar; when they differ substantially, transfer performance drops sharply.
2. **Blind Model Selection**: Training models on many datasets is costly, so practitioners often choose a pretrained model with little insight into how effective it will be on the target task.
3. **Limited Adaptability and Generalization**: Existing parameter generation techniques are usually tied to specific tasks or require access to the original training data, which makes it hard to generalize to unseen tasks.

To address these challenges, the paper introduces D2NWG, a diffusion-based neural network weight generation technique that efficiently produces high-performing weights for transfer learning, conditioned on the target dataset. D2NWG extends generative hyper-representation learning by applying the latent diffusion paradigm to neural network weights, learning the weight distribution of models pretrained on different datasets. This lets the method automatically generate weights that perform well on both seen and unseen tasks, outperforming state-of-the-art meta-learning methods and pretrained models. D2NWG also has the following features:

- **Scalability**: It applies to large architectures such as large language models (LLMs), overcoming the limitations of existing parameter generation techniques.
- **No Additional Fine-Tuning Required**: By modeling the parameter distribution of LLMs, it generates task-specific parameters without further fine-tuning or large collections of model variants.
- **Wide Applicability**: It performs strongly on vision tasks such as image classification and also shows strong potential on natural language processing tasks.

Through these properties, D2NWG aims to provide a more powerful and scalable solution for adaptive learning across many tasks.
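The core idea, learning a distribution over pretrained weights with latent diffusion conditioned on the target dataset, can be sketched in a few functions. This is a minimal pure-Python illustration under stated assumptions, not the authors' implementation: the trained weight autoencoder is replaced by a fixed random linear projection (`make_linear_encoder`), the dataset encoder is a simple mean over per-example feature vectors (`encode_dataset`), and the diffusion part uses the standard DDPM linear noise schedule with a noise-prediction loss. All function names are hypothetical.

```python
import math
import random


def flatten_weights(layers):
    """Concatenate per-layer weight lists into one flat vector.

    Hypothetical input format: a list of layers, each a flat list of floats.
    """
    return [w for layer in layers for w in layer]


def make_linear_encoder(in_dim, latent_dim, seed=0):
    """Return an encoder mapping a flat weight vector to a latent vector.

    Assumption: a fixed random linear projection stands in for the trained
    weight autoencoder (VAE encoder) used in latent diffusion.
    """
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(in_dim)
    proj = [[rng.gauss(0.0, scale) for _ in range(in_dim)] for _ in range(latent_dim)]

    def encode(weights):
        return [sum(p * w for p, w in zip(row, weights)) for row in proj]

    return encode


def encode_dataset(examples):
    """Hypothetical dataset encoder: mean of per-example feature vectors."""
    n = len(examples)
    return [sum(col) / n for col in zip(*examples)]


def linear_alpha_bars(num_steps, beta_start=1e-4, beta_end=0.02):
    """DDPM linear beta schedule, returned as cumulative products alpha_bar_t."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def q_sample(z0, t, alpha_bars, noise):
    """Forward process: z_t = sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bars[t]
    return [math.sqrt(a) * z + math.sqrt(1.0 - a) * e for z, e in zip(z0, noise)]


def diffusion_loss(denoiser, z0, cond, alpha_bars, rng):
    """One Monte Carlo sample of the conditional noise-prediction (epsilon) loss.

    `denoiser(z_t, t, cond)` should predict the noise that was added to z0.
    """
    t = rng.randrange(len(alpha_bars))
    noise = [rng.gauss(0.0, 1.0) for _ in z0]
    zt = q_sample(z0, t, alpha_bars, noise)
    pred = denoiser(zt, t, cond)
    return sum((p - e) ** 2 for p, e in zip(pred, noise)) / len(z0)


if __name__ == "__main__":
    rng = random.Random(0)
    weights = flatten_weights([[0.1, -0.2, 0.3], [0.05, 0.4]])  # toy "pretrained" model
    encode = make_linear_encoder(len(weights), latent_dim=3)
    z0 = encode(weights)
    cond = encode_dataset([[0.2, 0.1], [0.4, 0.3]])  # toy per-example features
    alpha_bars = linear_alpha_bars(100)

    def untrained(zt, t, c):
        return [0.0] * len(zt)  # placeholder for a learned, conditioned denoiser

    print("loss:", diffusion_loss(untrained, z0, cond, alpha_bars, rng))
```

In a full system the denoiser would be a neural network trained to minimize this loss over many (weights, dataset) pairs. Generating weights for a new task would then run the reverse diffusion process from Gaussian noise, conditioned on that dataset's embedding, and decode the resulting latent back into a weight vector.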