Scalable Adaptive Computation for Iterative Generation

Allan Jabri, David J. Fleet, Ting Chen
Proceedings of the 40th International Conference on Machine Learning, PMLR 202:14569-14589, 2023.

Abstract

Natural data is redundant, yet predominant architectures tile computation uniformly across their input and output space. We propose the Recurrent Interface Network (RIN), an attention-based architecture that decouples its core computation from the dimensionality of the data, enabling adaptive computation for more scalable generation of high-dimensional data. RINs focus the bulk of computation (i.e. global self-attention) on a set of latent tokens, using cross-attention to read and write (i.e. route) information between latent and data tokens. Stacking RIN blocks allows bottom-up (data to latent) and top-down (latent to data) feedback, leading to deeper and more expressive routing. While this routing introduces challenges, this is less problematic in recurrent computation settings where the task (and routing problem) changes gradually, such as iterative generation with diffusion models. We show how to leverage recurrence by conditioning the latent tokens at each forward pass of the reverse diffusion process with those from prior computation, i.e. latent self-conditioning. RINs yield state-of-the-art pixel diffusion models for image and video generation, scaling to 1024×1024 images without cascades or guidance, while being domain-agnostic and up to 10× more efficient than 2D and 3D U-Nets.
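The read/compute/write pattern and latent self-conditioning described in the abstract can be illustrated with a small NumPy sketch. This is not the authors' implementation. It makes the following simplifying assumptions:
- single-head attention with no learned projections;
- parameter-free layer norm;
- no MLPs or timestep conditioning;
- previous latents are combined with the initial ones by simple addition.

All function names (layer_norm, attend, rin_block, rin_forward) are hypothetical.

```python
import numpy as np


def layer_norm(t, eps=1e-6):
    """Parameter-free layer norm over the feature axis (assumption: no gain/bias)."""
    mean = t.mean(axis=-1, keepdims=True)
    var = t.var(axis=-1, keepdims=True)
    return (t - mean) / np.sqrt(var + eps)


def softmax(s, axis=-1):
    s = s - s.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)


def attend(queries, context):
    """Single-head scaled dot-product attention. Keys and values are the
    normalized context tokens themselves (simplification: no learned projections)."""
    q, kv = layer_norm(queries), layer_norm(context)
    scores = q @ kv.T / np.sqrt(q.shape[-1])  # shape (num_queries, num_context)
    return softmax(scores) @ kv


def rin_block(x, z, num_latent_layers=2):
    """One simplified RIN block: read, compute, write (hypothetical helper)."""
    z = z + attend(z, x)          # read: latents cross-attend to data tokens
    for _ in range(num_latent_layers):
        z = z + attend(z, z)      # compute: global self-attention over latents only
    x = x + attend(x, z)          # write: data tokens cross-attend to latents
    return x, z


def rin_forward(x, z_init, z_prev=None, num_blocks=3):
    """Stack RIN blocks. If z_prev is given, condition on it.
    Hypothetical choice: add the previous latents to the initial latents."""
    z = z_init if z_prev is None else z_init + z_prev
    for _ in range(num_blocks):
        x, z = rin_block(x, z)
    return x, z


rng = np.random.default_rng(0)
num_data_tokens, num_latents, dim = 256, 16, 32
x = rng.standard_normal((num_data_tokens, dim))   # stand-in for noisy data tokens
z_init = rng.standard_normal((num_latents, dim))  # stand-in for learned initial latents

z_prev = None
for _ in range(4):
    # A real sampler would update x at every reverse-diffusion step; here it is
    # held fixed so that only the carry-over of latents between passes is shown.
    out, z_prev = rin_forward(x, z_init, z_prev)

print(out.shape, z_prev.shape)  # (256, 32) (16, 32)
```

In this sketch, self-attention runs only over the latents, so its cost grows with num_latents squared rather than with num_data_tokens squared. The read and write cross-attentions each cost on the order of num_latents × num_data_tokens.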

Cite this Paper


BibTeX
@InProceedings{pmlr-v202-jabri23a,
  title = {Scalable Adaptive Computation for Iterative Generation},
  author = {Jabri, Allan and Fleet, David J. and Chen, Ting},
  booktitle = {Proceedings of the 40th International Conference on Machine Learning},
  pages = {14569--14589},
  year = {2023},
  editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume = {202},
  series = {Proceedings of Machine Learning Research},
  month = {23--29 Jul},
  publisher = {PMLR},
  pdf = {https://proceedings.mlr.press/v202/jabri23a/jabri23a.pdf},
  url = {https://proceedings.mlr.press/v202/jabri23a.html},
  abstract = {Natural data is redundant, yet predominant architectures tile computation uniformly across their input and output space. We propose the Recurrent Interface Network (RIN), an attention-based architecture that decouples its core computation from the dimensionality of the data, enabling adaptive computation for more scalable generation of high-dimensional data. RINs focus the bulk of computation (i.e. global self-attention) on a set of latent tokens, using cross-attention to read and write (i.e. route) information between latent and data tokens. Stacking RIN blocks allows bottom-up (data to latent) and top-down (latent to data) feedback, leading to deeper and more expressive routing. While this routing introduces challenges, this is less problematic in recurrent computation settings where the task (and routing problem) changes gradually, such as iterative generation with diffusion models. We show how to leverage recurrence by conditioning the latent tokens at each forward pass of the reverse diffusion process with those from prior computation, i.e. latent self-conditioning. RINs yield state-of-the-art pixel diffusion models for image and video generation, scaling to 1024×1024 images without cascades or guidance, while being domain-agnostic and up to 10× more efficient than 2D and 3D U-Nets.}
}
Endnote
%0 Conference Paper
%T Scalable Adaptive Computation for Iterative Generation
%A Allan Jabri
%A David J. Fleet
%A Ting Chen
%B Proceedings of the 40th International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2023
%E Andreas Krause
%E Emma Brunskill
%E Kyunghyun Cho
%E Barbara Engelhardt
%E Sivan Sabato
%E Jonathan Scarlett
%F pmlr-v202-jabri23a
%I PMLR
%P 14569--14589
%U https://proceedings.mlr.press/v202/jabri23a.html
%V 202
%X Natural data is redundant, yet predominant architectures tile computation uniformly across their input and output space. We propose the Recurrent Interface Network (RIN), an attention-based architecture that decouples its core computation from the dimensionality of the data, enabling adaptive computation for more scalable generation of high-dimensional data. RINs focus the bulk of computation (i.e. global self-attention) on a set of latent tokens, using cross-attention to read and write (i.e. route) information between latent and data tokens. Stacking RIN blocks allows bottom-up (data to latent) and top-down (latent to data) feedback, leading to deeper and more expressive routing. While this routing introduces challenges, this is less problematic in recurrent computation settings where the task (and routing problem) changes gradually, such as iterative generation with diffusion models. We show how to leverage recurrence by conditioning the latent tokens at each forward pass of the reverse diffusion process with those from prior computation, i.e. latent self-conditioning. RINs yield state-of-the-art pixel diffusion models for image and video generation, scaling to 1024×1024 images without cascades or guidance, while being domain-agnostic and up to 10× more efficient than 2D and 3D U-Nets.
APA
Jabri, A., Fleet, D. J., & Chen, T. (2023). Scalable Adaptive Computation for Iterative Generation. Proceedings of the 40th International Conference on Machine Learning, in Proceedings of Machine Learning Research 202:14569-14589. Available from https://proceedings.mlr.press/v202/jabri23a.html.
