Abrupt Learning in Transformers – A Case Study in Matrix Completion

by Pulkit Gopalani (PhD Candidate)


This post is based on Abrupt Learning in Transformers: A Case Study on Matrix Completion, presented at NeurIPS 2024.

Have you ever wondered how a Transformer works mechanistically, and how it learns to solve various tasks?

Introduction

Transformer-based models are the key architectural innovation behind large language models (LLMs) such as ChatGPT.

However, our fundamental understanding of how these models perform so well on such a diverse array of tasks is limited. When trained on specific tasks, transformer models sometimes get stuck at a “partial” (or memorizing) solution with low accuracy for a long time before abruptly converging to a solution that generalizes well on the given task. This intriguing training behavior has recently been studied in the finite-dataset setting as “grokking.”

Studying such phenomena associated with training transformers can help us understand how LLMs learn various capabilities from training on web-scale data. 

To study abrupt learning in a controlled manner, we formulated the classical low-rank matrix completion problem as a masked language modeling (MLM) task, and showed that it is possible to train a BERT, a well-known type of transformer, to solve this task with low error. Low-rank matrix completion is an estimation problem: given a low-rank matrix (i.e., one expressible as the product of two much smaller matrices) with some entries missing, the aim is to fill in the missing entries as accurately as possible while satisfying the low-rank constraint. In masked language modeling (MLM), some tokens in the input sequence (usually a sentence in some natural language) are masked out (i.e., hidden from the model), and the model is required to predict the correct token at each masked position.

Hence, the two problems are similar: in both, we want to complete missing elements of the input, subject to the input satisfying some constraints (linguistic in the case of MLM, low-rank in the case of matrix completion).

Fig 1: Low rank matrix completion is equivalent to masked language modeling
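To make this setup concrete, here is a minimal sketch of how a single training example might be generated: a random rank-2 7×7 matrix, flattened into a sequence of scalar “tokens”, with a random subset of entries hidden from the model. The specific sizes, observation probability, and use of 0 as the mask value are illustrative assumptions, not the exact settings from the paper.

```python
import numpy as np

def make_masked_example(n=7, rank=2, p_obs=0.7, mask_value=0.0, rng=None):
    """Sample a random rank-`rank` n x n matrix and hide some entries,
    mimicking a masked-language-modeling input (illustrative settings)."""
    rng = rng or np.random.default_rng()
    # A low-rank matrix is the product of two thin factor matrices.
    U = rng.standard_normal((n, rank))
    V = rng.standard_normal((rank, n))
    M = U @ V                                   # ground-truth rank-2 matrix
    observed = rng.random((n, n)) < p_obs       # True where an entry is visible
    X = np.where(observed, M, mask_value)       # masked input shown to the model
    # Flatten row-wise into a length n*n "token" sequence, as in MLM.
    return X.reshape(-1), M.reshape(-1), observed.reshape(-1)

x, target, observed = make_masked_example()
print(x.shape, observed.mean())                 # (49,) and the observed fraction
```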

Importantly, the loss curve for online training (i.e., new training data sampled at every step) showed an early, extended plateau, followed by a sudden drop to near-optimal values, despite no change in the training procedure or hyperparameters. To understand this sudden drop in loss, we investigated the model’s predictions and attention heads before and after the drop. We found that the model shifts from simply copying the masked input to accurately predicting missing entries in the matrix, and that the attention heads transition to interpretable patterns relevant to the task.

Fig 2: Change in the algorithm used by the transformer after the abrupt drop in loss
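To illustrate what “online training” means here, below is a rough sketch of a training loop in which a fresh batch of masked low-rank matrices is sampled at every step, and a small BERT-style (encoder-only) transformer is trained to regress the full matrix with an MSE loss. The model size, optimizer, learning rate, and data parameters are illustrative assumptions, not the configuration used in the paper.

```python
import torch
import torch.nn as nn

def sample_batch(batch=64, n=7, rank=2, p_obs=0.7):
    """Fresh masked low-rank matrices every call (online data), flattened to sequences."""
    U, V = torch.randn(batch, n, rank), torch.randn(batch, rank, n)
    M = (U @ V).reshape(batch, -1)                     # ground-truth entries
    observed = torch.rand(batch, n * n) < p_obs
    X = torch.where(observed, M, torch.zeros_like(M))  # hidden entries masked with 0
    return X, M

class TinyEncoder(nn.Module):
    """Illustrative BERT-like encoder: embed each scalar entry, apply a small
    Transformer encoder, and read out one predicted value per position."""
    def __init__(self, seq_len=49, d_model=64, nhead=8, nlayers=4):
        super().__init__()
        self.embed = nn.Linear(1, d_model)
        self.pos = nn.Parameter(torch.zeros(seq_len, d_model))
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, x):                              # x: (batch, seq_len)
        h = self.embed(x.unsqueeze(-1)) + self.pos
        return self.head(self.encoder(h)).squeeze(-1)

model = TinyEncoder()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
for step in range(1000):                               # new training data at every step
    x, target = sample_batch()
    loss = ((model(x) - target) ** 2).mean()           # MSE over all matrix entries
    opt.zero_grad(); loss.backward(); opt.step()
    if step % 100 == 0:
        print(step, loss.item())
```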

Results 

• Pre-transition: Copying the Input Matrix – Before the transition, the model simply copies the input: it reproduces the observed entries and predicts values close to 0 for the missing entries. The attention maps at this stage do not correspond to any interpretable structure. We also find that these attention heads contribute little to the model output: if we replace them with an artificial attention head that attends equally to every element of the input (see the sketch after this list), model performance barely changes with respect to the mean squared error.

• Post-transition: Computing Missing Entries – After the transition, the model accurately completes the missing entries, while still copying the observed entries. The attention maps at this stage clearly show that the model ‘attends’ to relevant tokens in the input, and that the attention layers are crucial for accurate matrix completion. Interestingly, the post-transition model can outperform classical methods for matrix completion (e.g., nuclear norm minimization; see the baseline sketch after the figures below) with respect to mean squared error (MSE), suggesting that it does not simply recover this classical algorithm.
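The “uniform attention” intervention mentioned in the pre-transition result above can be phrased very simply: replace a head’s learned attention pattern with one in which every query position attends equally to every key position. Below is a minimal sketch of a single attention head with this ablation option; the tensor shapes and toy inputs are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def head_output(q, k, v, uniform=False):
    """One attention head, with an option to ablate it by forcing uniform
    attention weights (every query attends equally to every key)."""
    # q, k, v: (batch, seq_len, d_head)
    if uniform:
        seq_len = k.shape[1]
        attn = torch.full((seq_len, seq_len), 1.0 / seq_len)       # uniform pattern
    else:
        scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
        attn = F.softmax(scores, dim=-1)                           # learned pattern
    return attn @ v

# Toy shapes matching a flattened 7x7 matrix (49 positions).
q, k, v = (torch.randn(2, 49, 8) for _ in range(3))
out, out_ablated = head_output(q, k, v), head_output(q, k, v, uniform=True)
# The observation in the paper: for the pre-transition model, swapping in the
# uniform pattern barely changes the MSE; after the transition, attention matters.
print(out.shape, out_ablated.shape)
```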

Fig 3: Attention heads for a 4-layer, 8-head transformer model trained on low-rank matrix completion (7×7 matrices of rank 2). The highlighted attention head attends to the row of the query element – since the input is the flattened matrix, the block structure corresponds to the model attending to elements in the same row as the query element.

Fig 4: Attention heads for a 4-layer, 8-head transformer model trained on low-rank matrix completion (7×7 matrices of rank 2). The highlighted attention head attends to the column of the query element – as in Fig 3, since the input is the flattened matrix, the diagonal structure corresponds to the model attending to elements in the same column as the query element.
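For reference, the nuclear norm minimization baseline mentioned above can be written as a convex program: among all matrices that agree with the observed entries, find one of minimum nuclear norm (the standard convex surrogate for rank). A minimal sketch using cvxpy follows; the library choice, problem size, and observation probability are assumptions for illustration, not the exact experimental setup.

```python
import numpy as np
import cvxpy as cp

rng = np.random.default_rng(0)
n, rank, p_obs = 7, 2, 0.7
M = rng.standard_normal((n, rank)) @ rng.standard_normal((rank, n))  # ground truth
mask = (rng.random((n, n)) < p_obs).astype(float)                    # 1 = observed

# Nuclear norm minimization: minimize ||X||_* subject to matching observed entries.
X = cp.Variable((n, n))
problem = cp.Problem(cp.Minimize(cp.normNuc(X)),
                     [cp.multiply(mask, X) == mask * M])
problem.solve()

missing = mask == 0
print("MSE on missing entries:", np.mean((X.value[missing] - M[missing]) ** 2))
```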

Understanding how and why such attention maps appear suddenly during training is an important direction for future work. This would potentially involve understanding the loss landscape of transformers on such tasks, from the perspective of “circuits” formed before and after the sudden drop in loss. 

About the author

Pulkit is a second-year Ph.D. candidate in CSE advised by Prof. Wei Hu. His research focuses on developing a principled, scientific understanding of deep learning, especially of transformer-based models. He is interested in understanding abrupt improvements in model performance during transformer training, and in various LLM capabilities such as Chain-of-Thought computation.

Website: Pulkit Gopalani

Editors: Trenton Chang, Vaibhav Balloli, Aurelia Bunescu