Attention is the mathematical operation within a transformer that allows different parts of the input to figure out how important they are to each other.

Info

How attention fits into transformers is described in Transformers > Attention. This page describes different ways attention can be implemented.

Full attention

“Full attention” is the form of attention where every token interacts with every other token. It takes an input (the activations tensor) and creates three derived tensors:

  • Query tensor $Q$, which captures what information each token needs from the other tokens
  • Key tensor $K$, which captures what information each token has to offer the other tokens
  • Value tensor $V$, which captures the features that get shared by other tokens when the query and key are similar

All three tensors are simply the activations multiplied by learned weights (which are model parameters). If the activation tensor (aka the “residual stream”) is $X$, then

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$

where these weight tensors have dimensions $d \times d$, where $d$ is the hidden dimension of the transformer. This means $Q$, $K$, and $V$ have the same dimensions as the activations tensor ($n \times d$, where $n$ is the transformer’s sequence length).

Roughly, attention works like this:

  1. It first calculates a score between each token and every other token by comparing their projected representations ($Q$ vs $K$). The more similar they are, the higher the score. After this, we have a tensor of shape $n \times n$ describing how much every token attends to every other token. These scores are then normalized with softmax.
  2. For each token $i$, the normalized score for every token $j$ is multiplied by that token’s value vector $v_j$. These weighted value vectors are added up, producing a new context vector for token $i$. Doing this for all tokens results in an output tensor of shape $n \times d$.
  3. This output tensor is finally multiplied by one last matrix of learned weights, $W_O$ (shape $d \times d$), to produce the new activations that feed into the next part of the transformer layer.

Mathematically, attention looks like this:

$$\text{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$

Because $Q$, $K$, and $V$ have the dimensions $n \times d$ (remember, $n$ is sequence length, or when the context isn’t completely filled, the number of tokens in the context so far), there are two quadratic terms here:

  • $Q K^\top$ multiplies tensors of shape $n \times d$ by $d \times n$ and outputs an $n \times n$ tensor
  • $\mathrm{softmax}(Q K^\top / \sqrt{d})\, V$ then multiplies tensors of shape $n \times n$ and $n \times d$ and outputs an $n \times d$ tensor

So the computational cost of attention scales as $O(n^2 d)$: quadratic in the sequence length.
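The three numbered steps above can be sketched in plain NumPy. This is a minimal single-head sketch; the weight matrices here are random stand-ins for learned parameters:

```python
import numpy as np

def full_attention(x, w_q, w_k, w_v, w_o):
    """Single-head full attention over a sequence of n token vectors.

    x: (n, d) activations; w_q, w_k, w_v, w_o: (d, d) learned weights.
    """
    d = x.shape[-1]
    q, k, v = x @ w_q, x @ w_k, x @ w_v                 # each (n, d)
    scores = q @ k.T / np.sqrt(d)                       # (n, n): every token vs. every token
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = scores / scores.sum(axis=-1, keepdims=True)  # row-wise softmax
    context = weights @ v                               # (n, d): blend the value vectors
    return context @ w_o                                # final output projection

rng = np.random.default_rng(0)
n, d = 8, 16
x = rng.standard_normal((n, d))
w = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4)]
out = full_attention(x, *w)
assert out.shape == (n, d)
```

The two quadratic terms show up as the `q @ k.T` and `weights @ v` lines, both of which touch an $n \times n$ intermediate.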

Multi-head Attention (MHA)

Multi-head attention splits the attention part of each transformer layer into multiple attention heads, and each head gets a slice of every token’s hidden dimension. Each head decides the importance of every token with respect to the others, but only in its slice of the hidden dimension. Each head does this independently (and in parallel), and only after each head calculates the attention values for its slice is there a global mixing across heads.

A little more precisely, each token vector (of length $d$) going into the attention block is divided up over the $h$ attention heads. For example, if a model has a hidden dimension of $d$ features and $h$ attention heads, each attention head would be assigned a slice that contains $d / h$ features.

As a result, each head has its own subset of $Q$, $K$, and $V$ and their associated weights $W_Q$, $W_K$, and $W_V$. Each head computes attention on its subslice of the hidden dimension independently, resulting in $h$ attention output tensors. These per-head outputs are all concatenated together, then multiplied with another set of learned weights, $W_O$, to give you a fully mixed attention output tensor.

Properties

Multi-head attention is generally a good idea because it allows the model to learn multiple attention patterns in parallel. The output projection $W_O$ is also a more compact way of representing the mixing of features than materializing a full attention tensor across the entire hidden dimension.

If you have too many attention heads, many of them will either zero themselves out or become redundant. At the extreme, the subslice of the feature space that each head gets is so small that it cannot learn useful attention patterns.

Worked Example

As an example, let’s say we have an attention block with hidden dimension $d = 4$ and two heads ($h = 2$), and we pass a single token through it. The activations vector for that token might be

$$x = \begin{bmatrix} x_1 & x_2 & x_3 & x_4 \end{bmatrix}$$

It gets divided over the two heads:

$$x^{(1)} = \begin{bmatrix} x_1 & x_2 \end{bmatrix}, \qquad x^{(2)} = \begin{bmatrix} x_3 & x_4 \end{bmatrix}$$

After each head processes its slice, there are two activation outputs $o^{(1)}$ and $o^{(2)}$:

$$o^{(1)} = \begin{bmatrix} o_1 & o_2 \end{bmatrix}, \qquad o^{(2)} = \begin{bmatrix} o_3 & o_4 \end{bmatrix}$$

They get concatenated along the feature axis:

$$o = \begin{bmatrix} o_1 & o_2 & o_3 & o_4 \end{bmatrix}$$

Let’s say $W_O$ is a $4 \times 4$ learned matrix that looks like:

$$W_O = \begin{bmatrix} w_{11} & w_{12} & w_{13} & w_{14} \\ w_{21} & w_{22} & w_{23} & w_{24} \\ w_{31} & w_{32} & w_{33} & w_{34} \\ w_{41} & w_{42} & w_{43} & w_{44} \end{bmatrix}$$

The final attention output is the concatenated output tensor multiplied by this weight matrix:

$$y = o\, W_O, \qquad y_j = \sum_{i=1}^{4} o_i\, w_{ij}$$

The resulting attention tensor coming out of multi-head attention would then be:

$$y = \begin{bmatrix} y_1 & y_2 & y_3 & y_4 \end{bmatrix}$$
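The split/compute/concatenate/mix flow above can be sketched in NumPy. The weights are random stand-ins for learned parameters, and real implementations usually use one fused projection per tensor rather than explicit per-head slicing:

```python
import numpy as np

def multi_head_attention(x, heads, w_o):
    """x: (n, d); heads: list of h (w_q, w_k, w_v) triples, each (d/h, d/h);
    w_o: (d, d) output projection that mixes features across heads."""
    n, d = x.shape
    slice_size = d // len(heads)
    outputs = []
    for i, (w_q, w_k, w_v) in enumerate(heads):
        xs = x[:, i * slice_size:(i + 1) * slice_size]   # this head's slice of every token
        q, k, v = xs @ w_q, xs @ w_k, xs @ w_v
        s = q @ k.T / np.sqrt(slice_size)
        s = np.exp(s - s.max(axis=-1, keepdims=True))
        outputs.append((s / s.sum(axis=-1, keepdims=True)) @ v)
    return np.concatenate(outputs, axis=-1) @ w_o        # concat heads, then global mix

rng = np.random.default_rng(1)
n, d, h = 6, 4, 2
heads = [tuple(rng.standard_normal((d // h, d // h)) for _ in range(3)) for _ in range(h)]
w_o = rng.standard_normal((d, d))
y = multi_head_attention(rng.standard_normal((n, d)), heads, w_o)
assert y.shape == (n, d)
```

Note that the only place features cross head boundaries is the final `@ w_o`.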

Grouped-Query Attention (GQA)

In grouped-query attention, groups of query heads share a single key/value head (rather than every query head having its own), which shrinks the KV cache. GQA was used by Llama-3.1.
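As a sketch of the idea (illustrative head counts and random tensors, not Llama-3.1’s actual configuration): in GQA, each key/value head is shared by a group of query heads, so the KV cache shrinks by the ratio of query heads to key/value heads:

```python
import numpy as np

# Grouped-query attention sketch: h_q query heads share h_kv key/value heads.
# (h_q == h_kv gives ordinary MHA; h_kv == 1 gives multi-query attention.)
n, d_head = 8, 16
h_q, h_kv = 8, 2                  # each KV head serves h_q // h_kv = 4 query heads
rng = np.random.default_rng(2)
q = rng.standard_normal((h_q, n, d_head))
k = rng.standard_normal((h_kv, n, d_head))
v = rng.standard_normal((h_kv, n, d_head))

group = h_q // h_kv
out = np.empty_like(q)
for i in range(h_q):
    kk, vv = k[i // group], v[i // group]      # query head i reuses its group's K/V
    s = q[i] @ kk.T / np.sqrt(d_head)
    s = np.exp(s - s.max(axis=-1, keepdims=True))
    out[i] = (s / s.sum(axis=-1, keepdims=True)) @ vv

kv_cache_gqa = k.nbytes + v.nbytes             # h_kv heads' worth of K and V
kv_cache_mha = 2 * h_q * n * d_head * 8        # what per-head K/V would need
assert kv_cache_mha // kv_cache_gqa == h_q // h_kv
```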

Latent attention (MLA)

Latent attention factorizes the key/value projection into two smaller matrices. One matrix produces a compressed latent that goes into the KV cache, and the other holds the weights used to expand it back out. Both take less memory than the original, and you rehydrate the full matrix only when needed.

This approach effectively stores a compressed representation (a latent representation) of the $K$ and $V$ tensors that is shared across the heads in each transformer layer. During a forward pass, each attention head rehydrates its real $K$ and $V$ from this latent representation on demand. This rehydration is more computationally expensive than storing $K$ and $V$ in memory outright, but it saves a ton of memory and allows you to fit a larger model in the same GPU memory footprint.

This method was used by DeepSeek-R1.
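A minimal sketch of the latent factorization, assuming a simple down-projection/up-projection split (the matrix names and rank here are illustrative, not DeepSeek’s actual parameterization):

```python
import numpy as np

# Latent-attention sketch: instead of caching full K and V (each n x d),
# cache a low-rank latent (n x r, with r << d) and rehydrate K/V on demand.
rng = np.random.default_rng(3)
n, d, r = 32, 64, 8
x = rng.standard_normal((n, d))

w_down = rng.standard_normal((d, r))   # compresses activations into the latent
w_up_k = rng.standard_normal((r, d))   # rehydrates keys from the latent
w_up_v = rng.standard_normal((r, d))   # rehydrates values from the latent

latent = x @ w_down                    # (n, r): this is what goes into the KV cache
k = latent @ w_up_k                    # (n, d): recomputed during each forward pass
v = latent @ w_up_v                    # (n, d)

# Cache cost per token: r floats instead of 2 * d floats.
assert latent.shape == (n, r) and k.shape == v.shape == (n, d)
```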

Linear attention

Linear attention is a family of techniques that replaces full attention with an approximation that scales as $O(n)$ rather than $O(n^2)$. Recall that full attention does the following:

| Step | Operation | Complexity | Output shape |
| --- | --- | --- | --- |
| 1 | Multiply $Q$ by $K^\top$ | $O(n^2 d)$ | $n \times n$ |
| 2 | Multiply the previous result by $V$ | $O(n^2 d)$ | $n \times d$ |

Linear attention does this:

| Step | Operation | Complexity | Output shape |
| --- | --- | --- | --- |
| 1 | Compute a summary of $K$ and $V$: multiply $\phi(K)^\top$ by $V$ | $O(n d^2)$ | $d \times d$ |
| 2 | Multiply $\phi(Q)$ with the previous summary | $O(n d^2)$ | $n \times d$ |

The summary transformation is a nonlinear feature mapping. Remember that the key $K$, value $V$, and query $Q$ tensors are really each a bunch of per-token vectors. Each token $i$ has its own key vector $k_i$, value vector $v_i$, and query vector $q_i$ within $K$, $V$, and $Q$, and these vectors all have length $d$. With this in mind,

  1. Linear attention uses a function $\phi$, which is a feature map or kernel map. This turns a key vector $k_i$ into a feature-mapped key vector $\phi(k_i)$ of length $d$.
  2. For each token $i$, we take the outer product of the feature-mapped key vector $\phi(k_i)$ and the unmodified value vector $v_i$. This results in a tensor $\phi(k_i)\, v_i^\top$ of shape $d \times d$ that encodes how token $i$’s key and value interact.
  3. Each token’s $d \times d$ matrix is then added together to form a summary matrix $S = \sum_i \phi(k_i)\, v_i^\top$ of shape $d \times d$, which encodes a summary of how tokens’ keys and values interact. However, there is no mixing across tokens yet!
  4. Now for each query vector $q_i$, calculate its feature-mapped query vector $\phi(q_i)$ using the same feature map we used to generate $\phi(k_i)$.
  5. This feature-mapped query vector is multiplied by the summary matrix ($\phi(q_i)^\top S$), resulting in an output vector in the same space as the value vectors. Its length is still $d$, and it now encodes a blended weight of all keys and values (which were represented in the summary matrix) for each feature of the query.

Stacking all $n$ of these output vectors back into a tensor gives you an $n \times d$ attention tensor, but you never had to calculate a full $n \times n$ matrix-matrix product.
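The five steps above can be sketched in NumPy. The feature map here, $\mathrm{elu}(x) + 1$, and the normalization term standing in for softmax are common choices from the linear-attention literature, not taken from any particular model:

```python
import numpy as np

def linear_attention(q, k, v):
    """q, k, v: (n, d). Linear attention via a d x d summary matrix."""
    phi = lambda t: np.where(t > 0, t + 1.0, np.exp(t))  # elu(x)+1: a positive feature map
    qf, kf = phi(q), phi(k)
    summary = kf.T @ v                     # (d, d): sum_i phi(k_i) v_i^T, costs O(n d^2)
    norm = qf @ kf.sum(axis=0)             # (n,): normalizer in place of softmax's row sums
    return (qf @ summary) / norm[:, None]  # (n, d); no n x n matrix is ever formed

rng = np.random.default_rng(4)
n, d = 10, 4
out = linear_attention(*(rng.standard_normal((n, d)) for _ in range(3)))
assert out.shape == (n, d)
```

Both matrix products touch only $d \times d$ or $n \times d$ shapes, which is where the $O(n d^2)$ scaling comes from.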

I first became aware of linear attention after reading the Jet-Nemotron paper, which selectively replaces full attention with linear attention in layers where doing so does not compromise model quality. The corollary is that linear attention does negatively impact model quality.

Ring attention

Ring attention is a method for distributing Full attention across multiple nodes. It uses a specialized version of the generalized distributed (e.g., block-cyclic) GEMMs that the HPC world has been doing for decades. Because it is specifically designed for attention, it performs a 1D decomposition (partitioning along the sequence length, so that nodes get blocks of token vectors) and incorporates the distributed softmax along with the GEMM.

Ring attention is likely the method by which Google Gemini offers sequence lengths of millions of tokens, since its creator works at DeepMind.