Bypass Back-propagation: Optimization-based Structural Pruning for Large Language Models via Policy Gradient

Yuan Gao, Zujing Liu, Weizhong Zhang, Bo Du, Gui-Song Xia
2024-10-21
Abstract: In contrast to moderate-size neural network pruning, structural weight pruning on Large Language Models (LLMs) imposes a novel challenge on the efficiency of pruning algorithms, due to the heavy computation/memory demands of LLMs. Recent efficient LLM pruning methods typically operate at the post-training phase without expensive weight fine-tuning; however, their pruning criteria often rely on heuristically hand-crafted metrics, potentially leading to suboptimal performance. We instead propose a novel optimization-based structural pruning that learns the pruning masks in a probabilistic space directly by optimizing the loss of the pruned model. To preserve efficiency, our method eliminates back-propagation through the LLM per se during the optimization, requiring only the forward pass of the LLM. We achieve this by learning an underlying Bernoulli distribution to sample binary pruning masks, where we decouple the Bernoulli parameters from the LLM loss, thus facilitating an efficient optimization via a policy gradient estimator without back-propagation. As a result, our method is able to 1) operate at structural granularities of channels, heads, and layers, 2) support global and heterogeneous pruning (i.e., our method automatically determines different redundancy for different layers), and 3) optionally initialize with a metric-based method (for our Bernoulli distributions). Extensive experiments on LLaMA, LLaMA-2, LLaMA-3, Vicuna, and Mistral using the C4 and WikiText2 datasets demonstrate that our method operates for 2.7 hours with around 35GB memory for the 13B models on a single A100 GPU, and our pruned models outperform the state of the art w.r.t. both perplexity and the majority of various zero-shot tasks. Code will be released.
Machine Learning, Computation and Language
What problem does this paper attempt to address?
This paper addresses the efficiency and performance challenges of structured pruning for large language models (LLMs). It focuses on three issues:

1. **High computational and memory requirements**: Because LLMs have many parameters and complex architectures, pruning algorithms face very high computational and memory demands. Existing optimization-based methods are too inefficient at LLM scale for practical use.
2. **Limitations of existing pruning methods**:
   - **Metric-based pruning methods**: These methods rely on hand-crafted metrics to decide which parts to prune, which can lead to suboptimal performance, especially at high pruning rates.
   - **Difficulties in global and heterogeneous pruning**: Different layers of an LLM may have different degrees of redundancy. Metric-based methods struggle to support global and heterogeneous pruning because the hand-crafted metrics vary greatly across layers, which makes appropriate pruning thresholds hard to set.
3. **Avoiding back-propagation to improve efficiency**: Existing optimization-based methods usually rely on back-propagation through the model, which is very time-consuming and memory-intensive for LLMs. Achieving efficient pruning without back-propagation is therefore an important research direction.
To address these problems, the authors propose a new optimization-based structured pruning method, which achieves efficient pruning in the following ways:

- **Learning binary masks in probability space**: The pruning problem is modeled as sampling binary masks from a Bernoulli distribution. This decouples the Bernoulli parameters from the LLM loss, so they can be optimized with a policy gradient estimator that needs only forward passes, without back-propagation.
- **Support for pruning at multiple granularities**: The method can prune at the granularity of channels, heads, and layers, and it supports global and heterogeneous pruning.
- **Initialization flexibility**: The Bernoulli distributions can be initialized from a metric-based method or randomly, which keeps the method flexible and broadly applicable.

With these design choices, the method is efficient in computation and memory (the abstract reports 2.7 hours and around 35GB for 13B models on a single A100 GPU), and experiments on multiple benchmark datasets show that it outperforms current state-of-the-art pruning methods in perplexity and on the majority of zero-shot tasks.
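
The idea of learning Bernoulli mask probabilities from forward passes alone can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the toy loss, the batch-mean baseline, the shift-and-clip projection used to enforce a keep budget, and all names (`project_to_budget`, `learn_mask_probs`, `toy_loss`) are assumptions made for illustration. It uses the standard score-function (REINFORCE) identity for a Bernoulli variable, where the derivative of log p(m) with respect to the keep probability p is (m - p) / (p (1 - p)).

```python
import numpy as np

def project_to_budget(probs, budget, eps=0.01, iters=50):
    """Shift-and-clip projection onto {eps <= p <= 1-eps, sum(p) <= budget}.

    Assumption for this sketch: the sparsity target is enforced by projection.
    Requires budget >= len(probs) * eps.
    """
    clipped = np.clip(probs, eps, 1.0 - eps)
    if clipped.sum() <= budget:
        return clipped
    lo, hi = 0.0, float(probs.max())  # at shift = hi every entry clips to eps
    for _ in range(iters):  # bisection on the shift
        tau = 0.5 * (lo + hi)
        if np.clip(probs - tau, eps, 1.0 - eps).sum() > budget:
            lo = tau
        else:
            hi = tau
    return np.clip(probs - hi, eps, 1.0 - eps)

def learn_mask_probs(loss_fn, n_units, budget, steps=500, lr=0.02,
                     samples=16, seed=0):
    """Learn keep probabilities using only evaluations of loss_fn(mask)."""
    rng = np.random.default_rng(seed)
    probs = project_to_budget(np.full(n_units, budget / n_units), budget)
    for _ in range(steps):
        # Sample binary masks; loss_fn stands in for a forward pass of the pruned model.
        masks = (rng.random((samples, n_units)) < probs).astype(float)
        losses = np.array([loss_fn(m) for m in masks])
        # Batch-mean baseline reduces the variance of the estimator.
        advantages = losses - losses.mean()
        # Score function of the Bernoulli distribution w.r.t. probs.
        scores = (masks - probs) / (probs * (1.0 - probs))
        grad = (advantages[:, None] * scores).mean(axis=0)
        # Gradient step on the probabilities, then re-impose the budget.
        probs = project_to_budget(probs - lr * grad, budget)
    return probs

# Hypothetical toy problem: pruning unit i costs weights[i].
weights = np.array([10.0] * 4 + [0.1] * 4)

def toy_loss(mask):
    return float(np.sum(weights * (1.0 - mask)))

probs = learn_mask_probs(toy_loss, n_units=8, budget=4.0)
print(np.round(probs, 2))
```

In this toy setting the loss is never differentiated; only its value under sampled masks is used, which mirrors the forward-only property described above. A real setup would replace `toy_loss` with the loss of the masked LLM on a calibration batch.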