Improving Prototypical Visual Explanations with Reward Reweighing, Reselection, and Retraining

Aaron Jiaxun Li, Robin Netzorg, Zhihan Cheng, Zhuoqin Zhang, Bin Yu
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:28466-28479, 2024.

Abstract

In recent years, work has gone into developing deep interpretable methods for image classification that clearly attribute a model’s output to specific features of the data. One such method is the Prototypical Part Network (ProtoPNet), which attempts to classify images based on meaningful parts of the input. While this architecture is able to produce visually interpretable classifications, it often learns to classify based on parts of the image that are not semantically meaningful. To address this problem, we propose the Reward Reweighing, Reselecting, and Retraining (R3) post-processing framework, which performs three additional corrective updates to a pretrained ProtoPNet in an offline and efficient manner. The first two steps involve learning a reward model from collected human feedback and then aligning the prototypes with human preferences. The final step is retraining, which realigns the base features and the classifier layer of the original model with the updated prototypes. We find that our R3 framework consistently improves both the interpretability and the predictive accuracy of ProtoPNet and its variants.
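The three corrective steps in the abstract can be sketched at a high level. This is a minimal, hypothetical illustration of the R3 workflow, not the authors' implementation: the reward threshold, the patch representation, and all helper names are assumptions made for the example, and the learned reward model is replaced by a stored human score.

```python
REWARD_THRESHOLD = 0.5  # assumed cutoff for a "semantically meaningful" prototype

def reward(patch):
    # Stand-in for the learned reward model: in the paper this is trained
    # on collected human feedback; here we simply read a stored score.
    return patch["human_score"]

def r3(prototypes, candidate_patches):
    # Step 1: reward reweighing -- nudge each prototype toward the
    # highest-reward candidate patch (illustrated as simple averaging).
    for proto in prototypes:
        best = max(candidate_patches, key=reward)
        proto["vector"] = [
            0.5 * p + 0.5 * b
            for p, b in zip(proto["vector"], best["vector"])
        ]

    # Step 2: reselection -- prototypes whose reward is still below the
    # threshold are replaced outright by a high-reward candidate patch.
    for proto in prototypes:
        if reward(proto) < REWARD_THRESHOLD:
            best = max(candidate_patches, key=reward)
            proto["vector"] = list(best["vector"])
            proto["human_score"] = best["human_score"]

    # Step 3: retraining -- the paper realigns the base features and the
    # classifier layer with the updated prototypes; this sketch only
    # flags that the pass is required.
    return {"prototypes": prototypes, "needs_retraining": True}
```

The point of the sketch is the ordering: the reward model drives the first two prototype-level updates, and only afterwards is the rest of the network retrained to stay consistent with the modified prototypes.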

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-li24ba,
  title     = {Improving Prototypical Visual Explanations with Reward Reweighing, Reselection, and Retraining},
  author    = {Li, Aaron Jiaxun and Netzorg, Robin and Cheng, Zhihan and Zhang, Zhuoqin and Yu, Bin},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {28466--28479},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/li24ba/li24ba.pdf},
  url       = {https://proceedings.mlr.press/v235/li24ba.html},
  abstract  = {In recent years, work has gone into developing deep interpretable methods for image classification that clearly attribute a model’s output to specific features of the data. One such method is the Prototypical Part Network (ProtoPNet), which attempts to classify images based on meaningful parts of the input. While this architecture is able to produce visually interpretable classifications, it often learns to classify based on parts of the image that are not semantically meaningful. To address this problem, we propose the Reward Reweighing, Reselecting, and Retraining (R3) post-processing framework, which performs three additional corrective updates to a pretrained ProtoPNet in an offline and efficient manner. The first two steps involve learning a reward model from collected human feedback and then aligning the prototypes with human preferences. The final step is retraining, which realigns the base features and the classifier layer of the original model with the updated prototypes. We find that our R3 framework consistently improves both the interpretability and the predictive accuracy of ProtoPNet and its variants.}
}
Endnote
%0 Conference Paper
%T Improving Prototypical Visual Explanations with Reward Reweighing, Reselection, and Retraining
%A Aaron Jiaxun Li
%A Robin Netzorg
%A Zhihan Cheng
%A Zhuoqin Zhang
%A Bin Yu
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-li24ba
%I PMLR
%P 28466--28479
%U https://proceedings.mlr.press/v235/li24ba.html
%V 235
%X In recent years, work has gone into developing deep interpretable methods for image classification that clearly attribute a model’s output to specific features of the data. One such method is the Prototypical Part Network (ProtoPNet), which attempts to classify images based on meaningful parts of the input. While this architecture is able to produce visually interpretable classifications, it often learns to classify based on parts of the image that are not semantically meaningful. To address this problem, we propose the Reward Reweighing, Reselecting, and Retraining (R3) post-processing framework, which performs three additional corrective updates to a pretrained ProtoPNet in an offline and efficient manner. The first two steps involve learning a reward model from collected human feedback and then aligning the prototypes with human preferences. The final step is retraining, which realigns the base features and the classifier layer of the original model with the updated prototypes. We find that our R3 framework consistently improves both the interpretability and the predictive accuracy of ProtoPNet and its variants.
APA
Li, A. J., Netzorg, R., Cheng, Z., Zhang, Z., & Yu, B. (2024). Improving Prototypical Visual Explanations with Reward Reweighing, Reselection, and Retraining. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research, 235:28466-28479. Available from https://proceedings.mlr.press/v235/li24ba.html.