Survey of Current Modified Transformer Attention Designs
By Nathan Brake (@njbrake)
The Transformer model has dominated the AI landscape for several years. Popular models like Gemini, GPT, Claude, T5, and Llama all use the general design. The core architectural feature of the Transformer is its attention block. In high-level terms for NLP, it's the component that gives each word the ability to understand how it relates to the other words in a sentence. One problem with attention, though, is that its original design is resource intensive: it's expensive to train and run inference on models using the original design (often referred to as "vanilla attention"). To enable fast generation for increasingly long sequences (state-of-the-art models can now support hundreds of thousands of tokens), a popular area of AI research is exploring modifications that improve efficiency without degrading performance. In this blog, I'll describe a few recent advancements in attention mechanism modifications.
Original Causal Self-Attention Mechanism
First, a brief overview of vanilla attention. The attention mechanism of the Transformer is described in the "Attention Is All You Need" paper. Several great explanations of attention can be found here:
- Visualized: LLM Visualization
- Mathy: Stanford Lecture on Self-Attention by John Hewitt
- Explained with Python Code: Karpathy Building GPT from Scratch, go to timestamp 1:11:38
- Narrative Explanation: The Illustrated Transformer
Here's a high-level description of how attention works. The mechanism differs slightly between the encoder layers and the decoder layers of a Transformer, which I'll annotate below:
Basic Concept:
- Each token in a sequence attends to all tokens (if in an encoder layer) or all previous tokens (if in a decoder layer) and itself.
- For decoder layers, this creates a "causal" structure, ensuring that predictions for a given position can depend only on known outputs at earlier positions.
The equation:
- Attention(Q, K, V) = softmax(QK^T / √d_k) V
Computation Steps:
a) For each token, create three vectors by multiplying the token's embedding by trainable weight matrices inside the model:
- Query (Q): What the token is looking for
- Key (K): What the token offers to other tokens
- Value (V): The actual information the token holds
b) Calculate attention scores:
- For each token, compute dot products of its query with all keys. (THIS IS THE EXPENSIVE PART)
- Apply a scaling factor (1/√d_k, where d_k is the dimension of the key vectors).
- For decoder layers, to ensure causality, apply a mask that sets all attention scores for future tokens to negative infinity.
- Apply a softmax to normalize the scores.
c) Apply the attention:
- Multiply each value vector by its corresponding attention score.
- Sum up these weighted values to get the output for each token.
Complexity Analysis (why it's O(n^2)):
- For a sequence of length n:
- We compute n query vectors, n key vectors, and n value vectors.
- The attention matrix QK^T is n x n.
- Computing QK^T requires a dot product between every query and every key.
- This results in O(n^2) time and space complexity (see the code sketch below).
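To make these steps concrete, here's a minimal single-head sketch in PyTorch. This is my own simplification, not production code: real implementations batch the computation across multiple heads and sequences.

```python
import torch
import torch.nn.functional as F

def causal_self_attention(x, w_q, w_k, w_v):
    # x: (seq_len, d_model); w_q, w_k, w_v: (d_model, d_k) trainable projections
    q = x @ w_q                          # Query: what each token is looking for
    k = x @ w_k                          # Key: what each token offers to other tokens
    v = x @ w_v                          # Value: the information each token holds
    d_k = q.shape[-1]
    scores = q @ k.T / d_k ** 0.5        # (seq_len, seq_len) -- the O(n^2) step
    causal_mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))  # hide future tokens (decoder)
    weights = F.softmax(scores, dim=-1)  # normalize scores for each token
    return weights @ v                   # weighted sum of value vectors

# Toy usage
seq_len, d_model, d_k = 8, 16, 16
x = torch.randn(seq_len, d_model)
w_q, w_k, w_v = (torch.randn(d_model, d_k) for _ in range(3))
out = causal_self_attention(x, w_q, w_k, w_v)  # shape: (8, 16)
```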
Proposed Attention Modifications
Sliding Window Attention, Dilated Sliding Window, and Global Attention (Longformer)
The 2020 Longformer paper is a seminal paper in the area of expanding context length support for the Transformer architecture. The sliding window attention mechanism introduced in the paper is a modification of standard transformer attention designed to efficiently handle very long sequences. Sliding window attention is still used in modern models (e.g., the Gemma and Phi families). Here are the key details:
Basic Concept:
- Each token attends to a fixed-size window of surrounding tokens, rather than the entire sequence.
- The window "slides" along the sequence, centering on each token.
- The window size is a hyperparameter: in the paper they settle on a window size of 512 tokens.
Implementation:
- The attention mask is modified to only allow attention within the defined window.
- This results in a band-diagonal attention matrix (see the mask sketch at the end of this section).
Complexity:
- Reduces the self-attention complexity from O(n^2) to O(n*w), where n is the sequence length and w is the window size.
- This linear complexity allows for processing much longer sequences efficiently.
Global Tokens:
- The Longformer combines sliding window attention with global attention for certain tokens.
- This allows the model to capture both local and global context.
- Certain pre-defined tokens (e.g., [CLS] tokens or sentence starts) have global attention, attending to and being attended by all tokens.
One notable limitation (addressed by some of the newer papers) is that the global attention tokens are manually selected based on the task, which makes it a bit of an exercise in feature engineering.
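To illustrate the idea, here's a rough sketch of how a combined sliding-window plus global-token mask could be built. This is my own simplification, not the Longformer implementation: materializing the full n x n mask like this keeps the O(n^2) memory cost, whereas the actual implementation uses banded matrix operations and custom kernels to realize the savings.

```python
import torch

def longformer_style_mask(seq_len, window, global_idx):
    """Build a boolean attention mask: True = allowed to attend.
    window: one-sided window size; global_idx: positions given global attention."""
    i = torch.arange(seq_len)
    # Sliding window: token i attends to tokens within `window` positions of i.
    allowed = (i[:, None] - i[None, :]).abs() <= window
    # Global tokens attend to everyone and are attended to by everyone.
    allowed[global_idx, :] = True
    allowed[:, global_idx] = True
    return allowed

mask = longformer_style_mask(seq_len=10, window=2, global_idx=[0])  # e.g. a [CLS] token
print(mask.int())  # band-diagonal pattern plus a fully-connected first row and column
```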
ETC
The 2020 paper ETC: Encoding Long and Structured Inputs in Transformers was published around the same time as the Longformer paper mentioned above. Notably, the ETC design requires a modification to the Transformer architecture: a new input sequence called the "global sequence" is added.
Two-Tier Structure:
- ETC introduces a two-tier structure: a long sequence and a global sequence.
- The long sequence contains the main input tokens, while the global sequence contains a smaller set of tokens that can attend to and be attended by all tokens in the long sequence.
Global-Local Attention:
- Tokens in the long sequence attend to: a) A local window of surrounding tokens (similar to the Longformer sliding window attention) b) All tokens in the global sequence
- Tokens in the global sequence attend to: a) All other tokens in the global sequence b) All tokens in the long sequence (these four attention patterns are sketched below)
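Here's a rough way to picture ETC's attention pattern as four boolean masks (long-to-local, long-to-global, global-to-long, and global-to-global). This only illustrates the pattern and is not the authors' implementation; the sizes and radius are made-up values for the example.

```python
import torch

def etc_attention_masks(n_long, n_global, radius):
    """Boolean masks for ETC's four attention pieces (True = may attend)."""
    i = torch.arange(n_long)
    # Long-to-local: each long token sees a window of `radius` tokens on each side.
    l2l = (i[:, None] - i[None, :]).abs() <= radius
    # Long-to-global, global-to-long, global-to-global: fully connected.
    l2g = torch.ones(n_long, n_global, dtype=torch.bool)
    g2l = torch.ones(n_global, n_long, dtype=torch.bool)
    g2g = torch.ones(n_global, n_global, dtype=torch.bool)
    return l2l, l2g, g2l, g2g

l2l, l2g, g2l, g2g = etc_attention_masks(n_long=12, n_global=3, radius=2)
```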
CPC (Contrastive Predictive Coding) Auxiliary Task:
- To help the model learn long-range dependencies, ETC introduces a CPC auxiliary task.
- This task involves predicting tokens far from the current position, encouraging the model to capture long-range information. Instead of the standard negative log likelihood loss function, they use what is called a "noise contrastive estimation" loss.
The big difference between ETC and Longformer is the way they handle global attention: Longformer relies on a selected set of input tokens, while ETC modifies the input so that the embeddings of the global tokens can be learned during both pre-training (through the CPC task) and fine-tuning.
Transient-Global Attention (LongT5)
The 2022 LongT5 paper improves on the ETC paper by introducing the concept of Transient-Global attention. This is a hybrid approach that combines local and global attention patterns, and leverages the existing T5 architecture.
- Similar to ETC and Longformer, LongT5 uses a sliding window so that each token attends to all nearby tokens.
- However, the global attention mechanism is built by splitting the input into "blocks", where each block is summarized by the average of its "k" input tokens.
- For example, if k=3, the input sequence is length 6, and the window size "r" is 1, each token would attend to one token on its left and one token on its right in the sequence, and would also attend to the 2 "global tokens" that are created, where each global token is the average of the 3 tokens in its bucket.
These global tokens can be created "on the fly" during the forward pass of the model, so no other modifications need to be made to the T5 architecture. This makes it much simpler to integrate the modified attention mechanism into existing models.
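Here's a minimal sketch of the idea, following the spirit of the description above rather than the exact LongT5 code: block-averaged global tokens are built on the fly and appended to the set of positions each token can attend to. It uses a single head, no learned projections, and assumes the sequence length is a multiple of the block size.

```python
import torch
import torch.nn.functional as F

def tglobal_attention(x, k_block=3, radius=1):
    """Transient-global style attention sketch.
    x: (seq_len, d_model), with seq_len assumed to be a multiple of k_block."""
    seq_len, d_model = x.shape
    # 1) Build transient global tokens: the mean of each block of k_block tokens.
    globals_ = x.reshape(seq_len // k_block, k_block, d_model).mean(dim=1)
    # 2) Each token attends to its local window plus every global token.
    keys = torch.cat([x, globals_], dim=0)               # (seq_len + n_blocks, d_model)
    scores = x @ keys.T / d_model ** 0.5
    i = torch.arange(seq_len)
    local_ok = (i[:, None] - i[None, :]).abs() <= radius           # local window
    global_ok = torch.ones(seq_len, globals_.shape[0], dtype=torch.bool)
    allowed = torch.cat([local_ok, global_ok], dim=1)
    scores = scores.masked_fill(~allowed, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ keys                                 # values == keys in this sketch

# Matches the example above: k=3, sequence length 6, window radius r=1.
out = tglobal_attention(torch.randn(6, 16), k_block=3, radius=1)  # shape: (6, 16)
```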
Global-Local Block Attention (Pegasus-X)
Later in 2022, the Pegasus-X paper took a different approach to extending the context length.
- Global-Local attention:
- The encoder uses a combination of local attention and global tokens
- Local attention: Input tokens are divided into non-overlapping blocks, and tokens can only attend to other tokens within the same block. This is slightly different from sliding window attention in that a token belongs to exactly one block and isn't at the center of its own window (as it would be with a sliding window).
- Global tokens: A set of learnable global token embeddings are added that can attend to and be attended by all encoder tokens (Similar to ETC).
- Staggered blocks:
- The local attention blocks are staggered across layers.
- Block boundaries are shifted by half a block every other layer.
- This allows information to flow across blocks without increasing computational complexity (the staggering is sketched below).
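Here's a rough illustration of block-local attention with staggering (again my own sketch, not the Pegasus-X code): each token is assigned a block id, and on alternating layers the block boundaries are shifted by half a block. The learnable global tokens would additionally attend to and be attended by everything, as in ETC.

```python
import torch

def block_local_mask(seq_len, block_size, layer_idx):
    """Boolean mask for block-local attention (True = may attend).
    On odd layers the block boundaries are shifted by half a block ("staggering")."""
    shift = (block_size // 2) if layer_idx % 2 == 1 else 0
    block_id = (torch.arange(seq_len) + shift) // block_size
    # Tokens may attend only to tokens that share their block id.
    return block_id[:, None] == block_id[None, :]

even_layer = block_local_mask(seq_len=8, block_size=4, layer_idx=0)  # blocks {0-3}, {4-7}
odd_layer = block_local_mask(seq_len=8, block_size=4, layer_idx=1)   # boundaries shifted by 2
```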
Grouped Query Attention (GQA)
One of the most recent advancements is Grouped Query Attention (GQA), introduced in a 2023 paper. It's a modification that reduces memory usage and improves inference speed:
- Instead of each query head having its own key and value head, the query heads are divided into groups, and each group shares a single key/value head.
- This sits between standard multi-head attention (one key/value head per query head) and multi-query attention (one key/value head shared by all query heads).
- This significantly reduces the size of the key/value cache and the memory bandwidth required during inference.
GQA has been shown to maintain quality close to standard multi-head attention while offering substantial efficiency gains, making it particularly useful for very large models.
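Here's a minimal sketch of the grouping idea (single layer, no causal mask or output projection, and made-up shapes): each group of query heads shares one key/value head, and the shared heads are simply repeated to line up with the query heads.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """q: (n_q_heads, seq_len, d_head); k, v: (n_kv_heads, seq_len, d_head),
    where n_q_heads is a multiple of n_kv_heads."""
    n_q_heads, seq_len, d_head = q.shape
    n_kv_heads = k.shape[0]
    group_size = n_q_heads // n_kv_heads
    # Each group of `group_size` query heads reuses the same key/value head.
    k = k.repeat_interleave(group_size, dim=0)
    v = v.repeat_interleave(group_size, dim=0)
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5   # (n_q_heads, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)
    return weights @ v                                  # (n_q_heads, seq_len, d_head)

# 8 query heads sharing 2 key/value heads: the key/value cache is 4x smaller than
# with standard multi-head attention (which would need 8 key and 8 value heads).
q = torch.randn(8, 16, 64)
k = torch.randn(2, 16, 64)
v = torch.randn(2, 16, 64)
out = grouped_query_attention(q, k, v)  # shape: (8, 16, 64)
```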
What's next
These are just a few advancements from the past handful of years. It's exciting to see how each breakthrough helps support the next one: Pegasus-X and LongT5 built on top of Longformer and ETC, and the T5 architecture in general has been instrumental in unlocking research like GQA. The big players like OpenAI/Meta/Google/Microsoft are pouring a lot of resources into being the leaders in the space of AI models, and fast, efficient models that support long input context lengths are going to be a key differentiator.