GQA is a way of computing attention where K and V vectors are re-used across groups of query heads, which reduces the amount of KV cache (and therefore GPU memory) required during decode. For example, “GQA-8” means there are eight KV groups; in a model with 96 attention heads, each of those eight KV groups is shared by 12 query heads.
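
As a rough illustration of the head-to-group mapping, here is a minimal NumPy sketch; the function name, shapes, and the omission of batching and causal masking are simplifications for illustration, not taken from any particular implementation:

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Minimal GQA sketch: query heads share the K/V head of their group.

    Illustrative shapes (single sequence, no batching, no causal mask):
      q: (num_heads,   seq_len, head_dim)
      k: (head_groups, seq_len, head_dim)
      v: (head_groups, seq_len, head_dim)
    """
    num_heads, seq_len, head_dim = q.shape
    head_groups = k.shape[0]
    group_size = num_heads // head_groups   # e.g. 96 // 8 = 12 query heads per KV group
    out = np.empty_like(q)
    for h in range(num_heads):
        g = h // group_size                 # every head in a group reads the same K/V
        scores = q[h] @ k[g].T / np.sqrt(head_dim)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[h] = weights @ v[g]
    return out
```

The point to notice is that k and v have only head_groups entries along their leading dimension, so only those vectors need to be cached during decode.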

This greatly reduces the KV cache footprint. For a model with 96 attention heads and a head dimension of 128:

  • Standard MHA requires num_heads * head_dim * 2 values per token per layer (the factor of 2 covers both K and V). That is 96 * 128 * 2 = 24,576 floating-point values per token per layer.
  • GQA-8 requires head_groups * head_dim * 2 values per token per layer. That is 8 * 128 * 2 = 2,048 floating-point values per token per layer.

This is a 12x reduction in required KV cache, since num_heads / head_groups = 96 / 8 = 12.
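
The same arithmetic as a quick check, in plain Python, using the example configuration above:

```python
num_heads, head_groups, head_dim = 96, 8, 128    # example configuration from above

mha_values_per_token_per_layer = num_heads   * head_dim * 2   # K and V -> 24,576
gqa_values_per_token_per_layer = head_groups * head_dim * 2   # K and V ->  2,048

print(mha_values_per_token_per_layer)                                    # 24576
print(gqa_values_per_token_per_layer)                                    # 2048
print(mha_values_per_token_per_layer // gqa_values_per_token_per_layer)  # 12
```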

GQA was used by Llama-3.1.