Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence

Berfin Simsek, Amire Bendjeddou, Daniel Hsu
Proceedings of The 28th International Conference on Artificial Intelligence and Statistics, PMLR 258:4204-4212, 2025.

Abstract

This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) = \sum_{j=1}^k \sigma^*(v_j^T x)$, where $v_1, \ldots, v_k$ are unit vectors and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k=1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \asymp k \log k$ neurons ensures that gradient flow finds the full set of index vectors with high probability over random initialization. When $v_i^T v_j = \beta \geq 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c = c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that using a correlation loss and a mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal; however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
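As an illustration only, the NumPy sketch below (not the authors' code) instantiates the setup with hypothetical choices: $\sigma^* = \mathrm{He}_3$ (the degree-3 probabilists' Hermite polynomial, which has no degree-1 or degree-2 component), a single student neuron that also uses $\mathrm{He}_3$, orthonormal index vectors, and the population correlation evaluated in closed form via the identity $\mathbb{E}[\mathrm{He}_3(u^T x)\,\mathrm{He}_3(w^T x)] = 3!\,(u^T w)^3$ for unit $u, w$. Under these assumptions, the neuron's spherical gradient flow moves toward the index vector with the largest positive overlap, in line with the directional-convergence statement. The last two helpers read $n \asymp k \log k$ as a coupon-collector count. This reading is an assumed interpretation that treats each randomly initialized neuron as nearest to a uniformly random index vector. All function names are hypothetical.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermeval

# Assumption (illustrative only): sigma* = He_3(z) = z^3 - 3z, the
# probabilists' Hermite polynomial of degree 3. Its expansion has no
# degree-1 or degree-2 Hermite component.
HE3 = [0.0, 0.0, 0.0, 1.0]


def sigma_star(z):
    """Evaluate He_3 elementwise on a scalar or array."""
    return hermeval(z, HE3)


def target(x, V):
    """f*(x) = sum_j sigma*(v_j^T x); the rows of V are the index vectors."""
    return sigma_star(x @ V.T).sum(axis=-1)


def population_correlation(w, V):
    """E[f*(x) He_3(w^T x)] for x ~ N(0, I) and unit vectors w, v_j.

    Uses the Hermite identity E[He_3(u^T x) He_3(w^T x)] = 3! (u^T w)^3.
    """
    return 6.0 * np.sum((V @ w) ** 3)


def spherical_flow(w0, V, eta=0.01, steps=2000):
    """Euler-discretized gradient flow on the unit sphere that decreases the
    correlation loss -population_correlation(w, V) for one He_3 neuron."""
    w = w0 / np.linalg.norm(w0)
    for _ in range(steps):
        grad = 18.0 * V.T @ ((V @ w) ** 2)  # Euclidean gradient of the correlation
        tangent = grad - (w @ grad) * w     # project onto the sphere's tangent space
        w = w + eta * tangent               # ascent on correlation = descent on loss
        w = w / np.linalg.norm(w)           # retract back onto the sphere
    return w


def expected_neurons_to_cover(k):
    """Coupon-collector mean k * H_k (about k ln k), assuming each neuron's
    nearest index vector is uniform over the k orthogonal directions."""
    return k * sum(1.0 / i for i in range(1, k + 1))


def prob_all_covered(n, k):
    """P(each of k index vectors is nearest to at least one of n neurons),
    by inclusion-exclusion under the same uniformity assumption."""
    return sum((-1) ** i * math.comb(k, i) * (1.0 - i / k) ** n
               for i in range(k + 1))


if __name__ == "__main__":
    V = np.eye(5)[:3]  # k = 3 orthonormal index vectors in d = 5
    w0 = np.array([0.3, 0.2, 0.1, 0.5, 0.5])  # largest positive overlap with v_1
    print(np.round(spherical_flow(w0, V), 6))
    for k in (10, 100):
        print(k, expected_neurons_to_cover(k), k * math.log(k))
```

Running the example should show the neuron aligning with $v_1$, and $k H_k$ tracking $k \ln k$ up to a lower-order $O(k)$ term. The step size `eta` and the step count are arbitrary choices for this small example, not values taken from the paper.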

Cite this Paper


BibTeX
@InProceedings{pmlr-v258-simsek25a,
  title = {Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence},
  author = {Simsek, Berfin and Bendjeddou, Amire and Hsu, Daniel},
  booktitle = {Proceedings of The 28th International Conference on Artificial Intelligence and Statistics},
  pages = {4204--4212},
  year = {2025},
  editor = {Li, Yingzhen and Mandt, Stephan and Agrawal, Shipra and Khan, Emtiyaz},
  volume = {258},
  series = {Proceedings of Machine Learning Research},
  month = {03--05 May},
  publisher = {PMLR},
  pdf = {https://raw.githubusercontent.com/mlresearch/v258/main/assets/simsek25a/simsek25a.pdf},
  url = {https://proceedings.mlr.press/v258/simsek25a.html},
  abstract = {This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) = \sum_{j=1}^k \sigma^*(v_j^T x)$ where $v_1, ..., v_k$ are unit vectors, and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k=1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors, or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \asymp k \log k$ neurons ensures finding the full set of index vectors with gradient flow with high probability over random initialization. When $v_i^T v_j = \beta \geq 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c = c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that using a correlation loss and a mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal, however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.}
}
Endnote
%0 Conference Paper
%T Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence
%A Berfin Simsek
%A Amire Bendjeddou
%A Daniel Hsu
%B Proceedings of The 28th International Conference on Artificial Intelligence and Statistics
%C Proceedings of Machine Learning Research
%D 2025
%E Yingzhen Li
%E Stephan Mandt
%E Shipra Agrawal
%E Emtiyaz Khan
%F pmlr-v258-simsek25a
%I PMLR
%P 4204--4212
%U https://proceedings.mlr.press/v258/simsek25a.html
%V 258
%X This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) = \sum_{j=1}^k \sigma^*(v_j^T x)$ where $v_1, ..., v_k$ are unit vectors, and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k=1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors, or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \asymp k \log k$ neurons ensures finding the full set of index vectors with gradient flow with high probability over random initialization. When $v_i^T v_j = \beta \geq 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c = c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that using a correlation loss and a mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal, however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
APA
Simsek, B., Bendjeddou, A. & Hsu, D. (2025). Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence. Proceedings of The 28th International Conference on Artificial Intelligence and Statistics, in Proceedings of Machine Learning Research 258:4204-4212. Available from https://proceedings.mlr.press/v258/simsek25a.html.
