Improve Student's Reasoning Generalizability through Cascading Decomposed CoTs Distillation

Chengwei Dai, Kun Li, Wei Zhou, Songlin Hu
2024-05-30
Abstract: Large language models (LLMs) exhibit enhanced reasoning at larger scales, driving efforts to distill these capabilities into smaller models via teacher-student learning. Previous works simply fine-tune student models on teachers' generated Chain-of-Thoughts (CoTs) data. Although these methods enhance in-domain (IND) reasoning performance, they struggle to generalize to out-of-domain (OOD) tasks. We believe that the widespread spurious correlations between questions and answers may lead the model to preset a specific answer which restricts the diversity and generalizability of its reasoning process. In this paper, we propose Cascading Decomposed CoTs Distillation (CasCoD) to address these issues by decomposing the traditional single-step learning process into two cascaded learning steps. Specifically, by restructuring the training objectives -- removing the answer from outputs and concatenating the question with the rationale as input -- CasCoD's two-step learning process ensures that students focus on learning rationales without interference from the preset answers, thus improving reasoning generalizability. Extensive experiments demonstrate the effectiveness of CasCoD on both IND and OOD benchmark reasoning datasets. Code can be found at https://github.com/C-W-D/CasCoD.
Computation and Language, Artificial Intelligence
What problem does this paper attempt to address?
The paper aims to address the following issues:

1. **Generalization of reasoning ability on Out-of-Domain (OOD) tasks**: Existing methods fine-tune student models on Chain-of-Thoughts (CoTs) data generated by teacher models. These students perform well on In-Domain (IND) tasks but generalize poorly to OOD tasks. The authors attribute this mainly to spurious correlations between questions and answers, which lead the model to preset an answer during reasoning and thereby limit the diversity and generalizability of its reasoning.
2. **Improving the learning strategy for the reasoning process**: The paper proposes Cascading Decomposed CoTs Distillation (CasCoD), which decomposes the traditional single-step learning process into two cascaded learning steps, a reasoning step and an answer step, to mitigate the negative impact of those spurious correlations on training. Specifically, CasCoD removes the answer from the output of the reasoning step and concatenates the question with the rationale as the input of the answer step, so that the student focuses on learning the rationale without interference from a preset answer.

In extensive experiments, CasCoD improves results on both IND and OOD tasks across multiple benchmark datasets, and the summary reports that these gains hold across different model scales and training data volumes.
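The decomposition described above can be illustrated with a minimal sketch of how one teacher-annotated sample might be turned into training pairs. This is not the authors' implementation (see their repository for that). The `CoTSample` class, the function names, the prompt layout, and the "Therefore, the answer is" template are all hypothetical choices made for illustration, and how the two steps are weighted during training is not covered here.

```python
from dataclasses import dataclass


@dataclass
class CoTSample:
    """One teacher-annotated example (hypothetical container)."""
    question: str
    rationale: str  # chain of thought generated by the teacher
    answer: str


def standard_cot_example(sample: CoTSample) -> dict:
    """Single-step baseline: the student maps question -> rationale + answer."""
    # The answer template below is an assumed format, not taken from the paper.
    target = f"{sample.rationale} Therefore, the answer is {sample.answer}."
    return {"input": sample.question, "target": target}


def cascaded_examples(sample: CoTSample) -> list[dict]:
    """Two cascaded steps, as described in the text above.

    Reasoning step: question -> rationale (the answer is removed from the output).
    Answer step: question concatenated with the rationale -> answer.
    """
    reasoning_step = {"input": sample.question, "target": sample.rationale}
    answer_step = {
        # Newline-joined concatenation is an assumed layout.
        "input": f"{sample.question}\n{sample.rationale}",
        "target": sample.answer,
    }
    return [reasoning_step, answer_step]


sample = CoTSample(
    question="Tom has 3 apples and buys 2 more. How many apples does he have?",
    rationale="Tom starts with 3 apples. Buying 2 more gives 3 + 2 = 5.",
    answer="5",
)
pairs = cascaded_examples(sample)
```

In this sketch the reasoning-step target never contains the answer, so the student cannot shortcut from question to answer while learning the rationale. The answer is only predicted once the rationale is already part of the input.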