Discrete Key-Value Bottleneck

Frederik Träuble, Anirudh Goyal, Nasim Rahaman, Michael Curtis Mozer, Kenji Kawaguchi, Yoshua Bengio, Bernhard Schölkopf
Proceedings of the 40th International Conference on Machine Learning, PMLR 202:34431-34455, 2023.

Abstract

Deep neural networks perform well on classification tasks where data streams are i.i.d. and labeled data is abundant. Challenges emerge with non-stationary training data streams such as continual learning. One powerful approach that has addressed this challenge involves pre-training large encoders on volumes of readily available data, followed by task-specific tuning. Given a new task, however, updating the weights of these encoders is challenging as a large number of weights needs to be fine-tuned, and as a result, they forget information about the previous tasks. In the present work, we propose a model architecture to address this issue, building upon a discrete bottleneck containing pairs of separate and learnable key-value codes. Our paradigm is to encode; process the representation via a discrete bottleneck; and decode. Here, the input is fed to the pre-trained encoder, the output of the encoder is used to select the nearest keys, and the corresponding values are fed to the decoder to solve the current task. The model can only fetch and re-use a sparse subset of these key-value pairs during inference, enabling localized and context-dependent model updates. We theoretically investigate the ability of the discrete key-value bottleneck to minimize the effect of learning under distribution shifts and show that it reduces the complexity of the hypothesis class. We empirically verify the proposed method under challenging class-incremental learning scenarios and show that the proposed model, without any task boundaries, reduces catastrophic forgetting across a wide variety of pre-trained models, outperforming relevant baselines on this task.
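
To make the encode, bottleneck, decode pipeline above concrete, the following is a minimal PyTorch sketch under simplifying assumptions: a single codebook of key-value pairs, top-1 nearest-key selection, and randomly initialized keys. The module and variable names (e.g. DiscreteKeyValueBottleneck, num_pairs) are illustrative and do not reproduce the authors' reference implementation, which additionally splits the encoder representation across multiple codebooks and initializes the keys in a separate unsupervised phase before freezing them.

import torch
import torch.nn as nn


class DiscreteKeyValueBottleneck(nn.Module):
    """Minimal sketch of a discrete key-value bottleneck layer.

    Keys quantize the frozen encoder representation; each selected key
    indexes a separately learnable value code that is passed to the decoder.
    """

    def __init__(self, num_pairs: int, key_dim: int, value_dim: int):
        super().__init__()
        # Keys are kept fixed here (in the paper they are first adapted to
        # cover the encoder's feature space, then frozen).
        self.keys = nn.Parameter(torch.randn(num_pairs, key_dim), requires_grad=False)
        # Values are the only parameters updated by the task loss.
        self.values = nn.Parameter(torch.randn(num_pairs, value_dim) * 0.01)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, key_dim) encoder features.
        # Select the nearest key per input (top-1; top-k and multiple
        # codebooks are straightforward extensions).
        dists = torch.cdist(z, self.keys)   # (batch, num_pairs)
        idx = dists.argmin(dim=-1)          # (batch,)
        return self.values[idx]             # (batch, value_dim)


# Usage: frozen pre-trained encoder -> bottleneck -> small trainable decoder.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32 * 3, 128)).eval()
for p in encoder.parameters():
    p.requires_grad_(False)

bottleneck = DiscreteKeyValueBottleneck(num_pairs=512, key_dim=128, value_dim=64)
decoder = nn.Linear(64, 10)  # e.g. class logits

x = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    z = encoder(x)
logits = decoder(bottleneck(z))  # only the values and the decoder receive gradients

Because the keys are frozen and only the sparsely selected values (plus the decoder) receive gradients, an update on a new task touches only the codes activated by that task's inputs, which is the mechanism the abstract credits for localized, forgetting-resistant updates.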

Cite this Paper


BibTeX
@InProceedings{pmlr-v202-trauble23a,
  title     = {Discrete Key-Value Bottleneck},
  author    = {Tr\"{a}uble, Frederik and Goyal, Anirudh and Rahaman, Nasim and Mozer, Michael Curtis and Kawaguchi, Kenji and Bengio, Yoshua and Sch\"{o}lkopf, Bernhard},
  booktitle = {Proceedings of the 40th International Conference on Machine Learning},
  pages     = {34431--34455},
  year      = {2023},
  editor    = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume    = {202},
  series    = {Proceedings of Machine Learning Research},
  month     = {23--29 Jul},
  publisher = {PMLR},
  pdf       = {https://proceedings.mlr.press/v202/trauble23a/trauble23a.pdf},
  url       = {https://proceedings.mlr.press/v202/trauble23a.html}
}
APA
Träuble, F., Goyal, A., Rahaman, N., Mozer, M.C., Kawaguchi, K., Bengio, Y. & Schölkopf, B. (2023). Discrete Key-Value Bottleneck. Proceedings of the 40th International Conference on Machine Learning, in Proceedings of Machine Learning Research 202:34431-34455. Available from https://proceedings.mlr.press/v202/trauble23a.html.
