A Fast Optimization View: Reformulating Single Layer Attention in LLM Based on Tensor and SVM Trick, and Solving It in Matrix Multiplication Time
Proceedings of the Forty-first Conference on Uncertainty in Artificial Intelligence, PMLR 286:1381-1452, 2025.
Abstract
Large language models (LLMs) have played a pivotal role in revolutionizing various facets of our daily existence. Solving attention regression is a fundamental task in optimizing LLMs. In this work, we focus on providing a provable guarantee for the one-layer attention network objective function: given input matrices of a layer, $A_1, A_2, A_3 \in \mathbb{R}^{n \times d}$, our goal is to minimize the loss function: \begin{align*} L(X,Y) = \sum_{j_0=1}^n \sum_{i_0=1}^d \left( \left\langle \langle \exp( \mathsf{A}_{j_0} x ) , {\bf 1}_n \rangle^{-1} \cdot \exp( \mathsf{A}_{j_0} x ), \; A_{3} Y_{*,i_0} \right\rangle - b_{j_0,i_0} \right)^2, \end{align*} where $\mathsf{A}_{j_0} \in \mathbb{R}^{n \times d^2}$ is the $j_0$-th block of the Kronecker product of $A_1$ and $A_2$. The matrix $B \in \mathbb{R}^{n \times d}$ represents the output of a layer, and $b_{j_0,i_0} \in \mathbb{R}$ is the $(j_0, i_0)$-th entry of $B$. $Y_{*,i_0} \in \mathbb{R}^d$ is the $i_0$-th column vector of $Y$, and $x \in \mathbb{R}^{d^2}$ is the vectorization of $X$. In self-attention, $Q, K, V \in \mathbb{R}^{d \times d}$ represent the weight matrices for the query, key, and value, respectively. Unlike prior works that relied on simplified and single-variable versions of the attention computation problem, our multivariate loss function analyzes a complete and unsimplified attention layer, treating all these weights as variables, where $X = QK^\top \in \mathbb{R}^{d \times d}$ and $Y = V \in \mathbb{R}^{d \times d}$. We propose an iterative greedy algorithm to train a neural network using the loss function $L(X,Y)$, achieving an error bound of $\epsilon \in (0, 0.1)$. The algorithm runs in $O( ({\cal T}_{\mathrm{mat}}(n,n,d) + {\cal T}_{\mathrm{mat}}(n,d,d) + d^{2\omega}) \log(1/\epsilon) )$ time, where ${\cal T}_{\mathrm{mat}}(a,b,c)$ denotes the time complexity of multiplying an $a \times b$ matrix with a $b \times c$ matrix, and $\omega \approx 2.37$ is the exponent of matrix multiplication.
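To make the objective concrete, the following is a minimal NumPy sketch of evaluating $L(X,Y)$, assuming the row-wise Kronecker convention common in this line of work: the $i$-th row of $\mathsf{A}_{j_0}$ is the Kronecker product of the $j_0$-th row of $A_1$ with the $i$-th row of $A_2$, so that $\exp(\mathsf{A}_{j_0} x)$ is the $j_0$-th row of $\exp(A_1 X A_2^\top)$ and the loss reduces to the entrywise squared error of a one-layer attention output against $B$. The function name `attention_loss` and the shapes in the usage snippet are illustrative, not taken from the paper, and this sketch does not implement the paper's iterative greedy training algorithm.

```python
import numpy as np

def attention_loss(A1, A2, A3, B, X, Y):
    """Evaluate L(X, Y) under the assumed Kronecker-block convention:
    exp(A_{j0} x) is the j0-th row of exp(A1 @ X @ A2.T)."""
    # Unnormalized attention scores: S[j0, i] = exp(<A_{j0} row i, vec(X)>).
    S = np.exp(A1 @ X @ A2.T)                        # shape (n, n)
    # Row-wise softmax: <exp(A_{j0} x), 1_n>^{-1} * exp(A_{j0} x).
    softmax_rows = S / S.sum(axis=1, keepdims=True)  # shape (n, n)
    # One-layer attention output compared entrywise against B.
    attn_out = softmax_rows @ (A3 @ Y)               # shape (n, d)
    return np.sum((attn_out - B) ** 2)

# Illustrative usage with random data (shapes chosen for the demo only).
n, d = 8, 4
rng = np.random.default_rng(0)
A1, A2, A3, B = (rng.standard_normal((n, d)) for _ in range(4))
X, Y = rng.standard_normal((d, d)), rng.standard_normal((d, d))
print(attention_loss(A1, A2, A3, B, X, Y))
```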