A Statistical Framework for Data-dependent Retrieval-Augmented Models

Soumya Basu, Ankit Singh Rawat, Manzil Zaheer
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:3197-3223, 2024.

Abstract

Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods along with the key takeaways from our statistical analysis on an open-domain question answering task where retrieval augmentation is important.

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-basu24a,
  title     = {A Statistical Framework for Data-dependent Retrieval-Augmented Models},
  author    = {Basu, Soumya and Rawat, Ankit Singh and Zaheer, Manzil},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {3197--3223},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/basu24a/basu24a.pdf},
  url       = {https://proceedings.mlr.press/v235/basu24a.html},
  abstract  = {Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods along with the key takeaways from our statistical analysis on an open-domain question answering task where retrieval augmentation is important.}
}
Endnote
%0 Conference Paper
%T A Statistical Framework for Data-dependent Retrieval-Augmented Models
%A Soumya Basu
%A Ankit Singh Rawat
%A Manzil Zaheer
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-basu24a
%I PMLR
%P 3197--3223
%U https://proceedings.mlr.press/v235/basu24a.html
%V 235
%X Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods along with the key takeaways from our statistical analysis on an open-domain question answering task where retrieval augmentation is important.
APA
Basu, S., Rawat, A.S. & Zaheer, M. (2024). A Statistical Framework for Data-dependent Retrieval-Augmented Models. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:3197-3223. Available from https://proceedings.mlr.press/v235/basu24a.html.

Related Material