Keypoint-based Progressive Chain-of-Thought Distillation for LLMs
Kaituo Feng, Changsheng Li, Xiaolu Zhang, JUN ZHOU, Ye Yuan, Guoren Wang
2024-01-01
Abstract: Chain-of-thought distillation is a powerful technique for transferring reasoning abilities from large language models (LLMs) to smaller student models. Previous methods typically require the student to mimic the step-by-step rationale produced by LLMs, and they often face the following challenges: (i) Tokens within a rationale vary in significance, and treating them equally may cause the student to mimic keypoint tokens inaccurately, leading to reasoning errors. (ii) They usually distill knowledge by consistently predicting all the steps in a rationale, which fails to distinguish the order in which step generation should be learned. This diverges from the human cognitive progression of starting with easy tasks and advancing to harder ones, resulting in sub-optimal outcomes. To this end, we propose a unified framework, called KPOD, to address these issues. Specifically, we propose a token weighting module that uses mask learning to encourage the student to mimic keypoint tokens accurately during distillation. In addition, we develop an in-rationale progressive distillation strategy, which starts by training the student to generate the final reasoning steps and gradually extends to cover the entire rationale. To accomplish this, a weighted token generation loss is proposed to assess step reasoning difficulty, and a value function is devised to schedule the progressive distillation by considering both step difficulty and question diversity. Extensive experiments on four reasoning benchmarks show that KPOD outperforms previous methods by a large margin.
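Two of the mechanisms described in the abstract, a token-weighted generation loss and an in-rationale schedule that begins with the final reasoning steps, can be illustrated with a minimal Python sketch. It is not the KPOD implementation and makes several assumptions: the per-token weights are passed in as given inputs (the paper learns them with a mask-learning module, which is not reproduced here); the earlier steps are assumed to be supplied as input context while the student generates only the remaining ones; and `linear_schedule` is a hypothetical stand-in for the paper's value function, which additionally accounts for step difficulty and question diversity. All function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple


def weighted_token_loss(token_probs: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted negative log-likelihood of the reference rationale tokens.

    token_probs[i] is the probability in (0, 1] that the student assigns to the
    i-th reference token; weights[i] is a HYPOTHETICAL importance weight supplied
    by the caller (KPOD learns such weights via mask learning, not shown here).
    Larger weights on keypoint tokens make errors on them cost more.
    """
    if len(token_probs) != len(weights):
        raise ValueError("token_probs and weights must have the same length")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    weighted_nll = sum(w * -math.log(p) for p, w in zip(token_probs, weights))
    return weighted_nll / total_weight  # normalize so the scale ignores total weight


def split_for_stage(steps: List[str], num_target_steps: int) -> Tuple[List[str], List[str]]:
    """Split a rationale into (context steps, steps the student must generate).

    ASSUMPTION: earlier steps are given as input and only the last
    `num_target_steps` steps are generated; the count is clamped to [1, len(steps)].
    """
    k = max(1, min(num_target_steps, len(steps)))
    cut = len(steps) - k
    return steps[:cut], steps[cut:]


def linear_schedule(epoch: int, total_epochs: int, num_steps: int) -> int:
    """HYPOTHETICAL schedule: grow the generated steps from 1 to num_steps.

    Stands in for KPOD's value function, which also considers step difficulty
    and question diversity; here the count simply increases with the epoch.
    """
    if total_epochs <= 1:
        return num_steps
    return 1 + (epoch * (num_steps - 1)) // (total_epochs - 1)


if __name__ == "__main__":
    rationale = ["step 1", "step 2", "step 3", "step 4"]
    for epoch in range(4):
        k = linear_schedule(epoch, total_epochs=4, num_steps=len(rationale))
        context, targets = split_for_stage(rationale, k)
        print(f"epoch {epoch}: given {context}, generate {targets}")

    # A poorly predicted keypoint token (p=0.1) dominates once it is upweighted.
    print(weighted_token_loss([0.1, 0.9], [1.0, 1.0]))
    print(weighted_token_loss([0.1, 0.9], [5.0, 1.0]))
```

With a four-step rationale and four epochs, `linear_schedule` yields 1, 2, 3 and 4 target steps, so `split_for_stage` first asks the student for only the last step and finally for the whole rationale. Applying `weighted_token_loss` to the tokens of a single step gives one possible per-step difficulty score of the kind the abstract mentions, which a more elaborate scheduler could use in place of the fixed linear growth.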