Progressive Inference: Explaining Decoder-Only Sequence Classification Models Using Intermediate Predictions

Sanjay Kariyappa, Freddy Lecue, Saumitra Mishra, Christopher Pond, Daniele Magazzeni, Manuela Veloso
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:23238-23255, 2024.

Abstract

This paper proposes Progressive Inference, a framework to explain the predictions of decoder-only transformer models trained to perform sequence classification tasks. Our work is based on the insight that the classification head of a decoder-only model can be used to make intermediate predictions by evaluating it at different points in the input sequence. Due to the masked attention mechanism used in decoder-only models, these intermediate predictions depend only on the tokens seen before the inference point, allowing us to obtain the model’s prediction on a masked input sub-sequence with negligible computational overhead. We develop two methods that provide sub-sequence-level attributions using this core insight. First, we propose Single Pass-Progressive Inference (SP-PI), which computes attributions by simply taking the difference between intermediate predictions. Second, we exploit a connection with Kernel SHAP to develop Multi Pass-Progressive Inference (MP-PI); this uses intermediate predictions from multiple masked versions of the input to compute higher-quality attributions that approximate SHAP values. We perform studies on several text classification datasets to demonstrate that our proposal provides better explanations than prior work, in both the single-pass and multi-pass settings.
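The SP-PI idea described above, attributing each token the change it causes in the intermediate prediction, can be sketched in a few lines. The sketch below is illustrative only: `intermediate_probs` stands in for the per-position class probabilities a real decoder-only model would emit in one forward pass, and the choice of a uniform prior as the baseline for the first token is an assumption, not necessarily the paper's exact formulation.

```python
import numpy as np

def sp_pi_attributions(intermediate_probs, target_class):
    """Single Pass-Progressive Inference sketch.

    intermediate_probs: (seq_len, num_classes) array of softmax outputs,
    where row t is the model's prediction after reading tokens 0..t.
    Returns one attribution score per token for `target_class`.
    """
    p = intermediate_probs[:, target_class]
    # Assumed baseline: the first token is scored against a uniform prior;
    # every later token gets the difference between consecutive predictions.
    prior = 1.0 / intermediate_probs.shape[1]
    return np.diff(p, prepend=prior)

# Toy example: 4 tokens, binary classification.
probs = np.array([[0.5, 0.5],
                  [0.3, 0.7],
                  [0.35, 0.65],
                  [0.1, 0.9]])
attr = sp_pi_attributions(probs, target_class=1)  # [0.0, 0.2, -0.05, 0.25]
```

Because the attributions telescope, they sum to the final prediction minus the baseline, which is what makes the single-pass scheme essentially free: all intermediate predictions fall out of one causal forward pass.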

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-kariyappa24a,
  title     = {Progressive Inference: Explaining Decoder-Only Sequence Classification Models Using Intermediate Predictions},
  author    = {Kariyappa, Sanjay and Lecue, Freddy and Mishra, Saumitra and Pond, Christopher and Magazzeni, Daniele and Veloso, Manuela},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {23238--23255},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/kariyappa24a/kariyappa24a.pdf},
  url       = {https://proceedings.mlr.press/v235/kariyappa24a.html},
  abstract  = {This paper proposes Progressive Inference, a framework to explain the predictions of decoder-only transformer models trained to perform sequence classification tasks. Our work is based on the insight that the classification head of a decoder-only model can be used to make intermediate predictions by evaluating it at different points in the input sequence. Due to the masked attention mechanism used in decoder-only models, these intermediate predictions depend only on the tokens seen before the inference point, allowing us to obtain the model’s prediction on a masked input sub-sequence with negligible computational overhead. We develop two methods that provide sub-sequence-level attributions using this core insight. First, we propose Single Pass-Progressive Inference (SP-PI), which computes attributions by simply taking the difference between intermediate predictions. Second, we exploit a connection with Kernel SHAP to develop Multi Pass-Progressive Inference (MP-PI); this uses intermediate predictions from multiple masked versions of the input to compute higher-quality attributions that approximate SHAP values. We perform studies on several text classification datasets to demonstrate that our proposal provides better explanations than prior work, in both the single-pass and multi-pass settings.}
}
Endnote
%0 Conference Paper
%T Progressive Inference: Explaining Decoder-Only Sequence Classification Models Using Intermediate Predictions
%A Sanjay Kariyappa
%A Freddy Lecue
%A Saumitra Mishra
%A Christopher Pond
%A Daniele Magazzeni
%A Manuela Veloso
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-kariyappa24a
%I PMLR
%P 23238--23255
%U https://proceedings.mlr.press/v235/kariyappa24a.html
%V 235
%X This paper proposes Progressive Inference, a framework to explain the predictions of decoder-only transformer models trained to perform sequence classification tasks. Our work is based on the insight that the classification head of a decoder-only model can be used to make intermediate predictions by evaluating it at different points in the input sequence. Due to the masked attention mechanism used in decoder-only models, these intermediate predictions depend only on the tokens seen before the inference point, allowing us to obtain the model’s prediction on a masked input sub-sequence with negligible computational overhead. We develop two methods that provide sub-sequence-level attributions using this core insight. First, we propose Single Pass-Progressive Inference (SP-PI), which computes attributions by simply taking the difference between intermediate predictions. Second, we exploit a connection with Kernel SHAP to develop Multi Pass-Progressive Inference (MP-PI); this uses intermediate predictions from multiple masked versions of the input to compute higher-quality attributions that approximate SHAP values. We perform studies on several text classification datasets to demonstrate that our proposal provides better explanations than prior work, in both the single-pass and multi-pass settings.
APA
Kariyappa, S., Lecue, F., Mishra, S., Pond, C., Magazzeni, D. & Veloso, M. (2024). Progressive Inference: Explaining Decoder-Only Sequence Classification Models Using Intermediate Predictions. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:23238-23255. Available from https://proceedings.mlr.press/v235/kariyappa24a.html.
