Keypoint-based Progressive Chain-of-Thought Distillation for LLMs

Kaituo Feng, Changsheng Li, Xiaolu Zhang, Jun Zhou, Ye Yuan, Guoren Wang
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:13241-13255, 2024.

Abstract

Chain-of-thought distillation is a powerful technique for transferring reasoning abilities from large language models (LLMs) to smaller student models. Previous methods typically require the student to mimic the step-by-step rationale produced by LLMs, often facing the following challenges: (i) Tokens within a rationale vary in significance, and treating them equally may fail to accurately mimic keypoint tokens, leading to reasoning errors. (ii) They usually distill knowledge by consistently predicting all the steps in a rationale, which falls short in distinguishing the learning order of step generation. This diverges from the human cognitive progression of starting with easy tasks and advancing to harder ones, resulting in sub-optimal outcomes. To this end, we propose a unified framework, called KPOD, to address these issues. Specifically, we propose a token weighting module utilizing mask learning to encourage accurate mimicry of keypoint tokens by the student during distillation. Besides, we develop an in-rationale progressive distillation strategy, starting with training the student to generate the final reasoning steps and gradually extending to cover the entire rationale. To accomplish this, a weighted token generation loss is proposed to assess step reasoning difficulty, and a value function is devised to schedule the progressive distillation by considering both step difficulty and question diversity. Extensive experiments on four reasoning benchmarks illustrate our KPOD outperforms previous methods by a large margin.

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-feng24e, title = {Keypoint-based Progressive Chain-of-Thought Distillation for {LLM}s}, author = {Feng, Kaituo and Li, Changsheng and Zhang, Xiaolu and Zhou, Jun and Yuan, Ye and Wang, Guoren}, booktitle = {Proceedings of the 41st International Conference on Machine Learning}, pages = {13241--13255}, year = {2024}, editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix}, volume = {235}, series = {Proceedings of Machine Learning Research}, month = {21--27 Jul}, publisher = {PMLR}, pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/feng24e/feng24e.pdf}, url = {https://proceedings.mlr.press/v235/feng24e.html}, abstract = {Chain-of-thought distillation is a powerful technique for transferring reasoning abilities from large language models (LLMs) to smaller student models. Previous methods typically require the student to mimic the step-by-step rationale produced by LLMs, often facing the following challenges: (i) Tokens within a rationale vary in significance, and treating them equally may fail to accurately mimic keypoint tokens, leading to reasoning errors. (ii) They usually distill knowledge by consistently predicting all the steps in a rationale, which falls short in distinguishing the learning order of step generation. This diverges from the human cognitive progression of starting with easy tasks and advancing to harder ones, resulting in sub-optimal outcomes. To this end, we propose a unified framework, called KPOD, to address these issues. Specifically, we propose a token weighting module utilizing mask learning to encourage accurate mimicry of keypoint tokens by the student during distillation. Besides, we develop an in-rationale progressive distillation strategy, starting with training the student to generate the final reasoning steps and gradually extending to cover the entire rationale. To accomplish this, a weighted token generation loss is proposed to assess step reasoning difficulty, and a value function is devised to schedule the progressive distillation by considering both step difficulty and question diversity. Extensive experiments on four reasoning benchmarks illustrate our KPOD outperforms previous methods by a large margin.} }
Endnote
%0 Conference Paper %T Keypoint-based Progressive Chain-of-Thought Distillation for LLMs %A Kaituo Feng %A Changsheng Li %A Xiaolu Zhang %A Jun Zhou %A Ye Yuan %A Guoren Wang %B Proceedings of the 41st International Conference on Machine Learning %C Proceedings of Machine Learning Research %D 2024 %E Ruslan Salakhutdinov %E Zico Kolter %E Katherine Heller %E Adrian Weller %E Nuria Oliver %E Jonathan Scarlett %E Felix Berkenkamp %F pmlr-v235-feng24e %I PMLR %P 13241--13255 %U https://proceedings.mlr.press/v235/feng24e.html %V 235 %X Chain-of-thought distillation is a powerful technique for transferring reasoning abilities from large language models (LLMs) to smaller student models. Previous methods typically require the student to mimic the step-by-step rationale produced by LLMs, often facing the following challenges: (i) Tokens within a rationale vary in significance, and treating them equally may fail to accurately mimic keypoint tokens, leading to reasoning errors. (ii) They usually distill knowledge by consistently predicting all the steps in a rationale, which falls short in distinguishing the learning order of step generation. This diverges from the human cognitive progression of starting with easy tasks and advancing to harder ones, resulting in sub-optimal outcomes. To this end, we propose a unified framework, called KPOD, to address these issues. Specifically, we propose a token weighting module utilizing mask learning to encourage accurate mimicry of keypoint tokens by the student during distillation. Besides, we develop an in-rationale progressive distillation strategy, starting with training the student to generate the final reasoning steps and gradually extending to cover the entire rationale. To accomplish this, a weighted token generation loss is proposed to assess step reasoning difficulty, and a value function is devised to schedule the progressive distillation by considering both step difficulty and question diversity. Extensive experiments on four reasoning benchmarks illustrate our KPOD outperforms previous methods by a large margin.
APA
Feng, K., Li, C., Zhang, X., Zhou, J., Yuan, Y. & Wang, G.. (2024). Keypoint-based Progressive Chain-of-Thought Distillation for LLMs. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:13241-13255 Available from https://proceedings.mlr.press/v235/feng24e.html.

Related Material