Act II · Inside the Machine
lesson mha-to-gqa · 9 min · 50 xp

From MHA to GQA

Collapse 32 heads into 8 groups — KV cache drops 4×

The KV cache is the real constraint

In the multi-head lesson we discovered that each head has its own learned $W^Q, W^K, W^V$ projections. At inference time, during autoregressive generation, we cache the keys and values of every token we've already processed — otherwise every new token would require recomputing attention over the whole sequence from scratch. This is called the KV cache.

KV cache memory, per token
\text{bytes per token} \;=\; 2 \cdot L \cdot h \cdot d_h \cdot b

2 because we store both K and V. $L$ = number of transformer layers. $h$ = number of KV heads. $d_h$ = per-head dimension. $b$ = bytes per element (2 for FP16).

For a Llama-2 7B with $L=32$, $h=32$, $d_h=128$, FP16, that's 0.5 MB per token. At 4096 tokens it's 2 GB of KV cache for one sequence — as much as the FP16 weights of a 1B-parameter model. Serve ten concurrent users at 4k each and your KV cache is 20 GB, before you've even allocated the model weights.
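The per-token formula above is simple enough to sanity-check directly. A minimal sketch (the function name is illustrative, not a real library API):

```python
def kv_cache_bytes_per_token(layers: int, kv_heads: int, head_dim: int,
                             bytes_per_elem: int = 2) -> int:
    """2 (K and V) * L layers * h KV heads * d_h dims * b bytes per element."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

# Llama-2 7B: L=32, h=32 KV heads, d_h=128, FP16 (b=2)
per_token = kv_cache_bytes_per_token(32, 32, 128)
print(per_token / 2**20)          # 0.5  -> MB per token
print(per_token * 4096 / 2**30)   # 2.0  -> GB at a 4096-token context
```

Swap in $h=8$ KV heads and the same call returns 0.125 MB per token, which is exactly the GQA figure the interactive example below reports.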

historical note
2019 · Noam Shazeer, Google
The first attempt to shrink the KV cache was Multi-Query Attention (MQA): share a single K/V pair across all query heads. That's an $h\times$ reduction — dramatic. Shazeer's one-page 2019 paper (“Fast Transformer Decoding: One Write-Head is All You Need”) showed the quality cost was real but acceptable for many tasks. MQA got adopted by PaLM and a handful of Google models, but it was consistently noted to hurt quality on harder benchmarks. For four years the industry wanted something between MHA's quality and MQA's compression.

The fix — divide queries into groups

Grouped-query attention (Ainslie et al., 2023) found the sweet spot. Divide the $h$ query heads into $g$ groups, each group sharing one K/V pair:

  • $g = h$ → full MHA (no compression)
  • $g = 1$ → MQA (maximum compression, quality hit)
  • $1 < g < h$ → GQA (tunable)

Modern SLMs live in the sweet spot with $h/g = 3$ or $h/g = 4$ — Phi-4-mini, Llama 3.2-3B, Qwen3-4B, SmolLM3-3B all sit there.

◆ paper
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón, Sanghai · 2023 · EMNLP 2023
arxiv:2305.13245
The Ainslie paper showed that you can uptrain an existing MHA model into GQA by mean-pooling the K and V projection weights within each new group, then fine-tuning for roughly 5% of the original pretraining compute. The quality loss is minimal and the KV cache savings are full. Llama 2 70B was the first major model to ship with GQA via uptraining.
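The mean-pooling initialization from the paper can be sketched in a few lines. This assumes per-head K (or V) projection weights are stored as a `(num_heads, d_model, d_head)` array — a layout chosen for illustration, not how any particular framework lays out its checkpoints:

```python
import numpy as np

def mean_pool_kv(w_kv: np.ndarray, groups: int) -> np.ndarray:
    """Initialize GQA K (or V) projections from an MHA checkpoint by
    mean-pooling the per-head projection weights within each new group.
    w_kv: (num_heads, d_model, d_head) -> returns (groups, d_model, d_head)."""
    num_heads, d_model, d_head = w_kv.shape
    assert num_heads % groups == 0, "heads must divide evenly into groups"
    heads_per_group = num_heads // groups
    # group consecutive heads, then average over the heads-per-group axis
    grouped = w_kv.reshape(groups, heads_per_group, d_model, d_head)
    return grouped.mean(axis=1)

# 32 MHA heads pooled into 8 KV heads (toy dimensions)
w_k = np.random.randn(32, 64, 16)
w_k_gqa = mean_pool_kv(w_k, groups=8)
print(w_k_gqa.shape)  # (8, 64, 16)
```

After this initialization, the uptraining phase (the ~5% of pretraining compute) lets the model adapt to the pooled projections.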
MQA — one K/V pair total

Maximum compression (h× reduction)

All query heads pool their retrieval direction into one K/V

Quality loss: 1–3 points on SuperGLUE, visible on summarisation

Used by: PaLM, some early Google models

GQA — g K/V pairs

Tunable compression (h/g×)

Query heads in a group share one K/V

Quality loss: < 0.5 points at h/g ≈ 4

Used by: Llama 2 70B+, Phi-4-mini, Qwen3, Gemma 3, SmolLM3 — essentially every 2024+ SLM

[interactive: KV-head slider, snapping to divisors of 32 (MHA = 32, MQA = 1)]

At $g = 8$: 32 query heads grouped into 8 KV heads — each group of 4 query heads shares 1 K/V pair.
current (GQA, g = 8): 0.13 MB/token · 512 MB @ 4k ctx
baseline (MHA): 0.50 MB/token · 2048 MB @ 4k ctx

Why it works — mechanistically

Empirically, attention heads cluster: many heads learn nearly-identical retrieval patterns (“attend to the previous token”, “attend to the first token”, “attend to the subject”). Forcing every query head to maintain its own K/V pair is redundant — those near-identical patterns don't need independent K/V storage. GQA lets a small set of K/V pairs be shared across similar query directions, preserving the head diversity via the independent WiQW_i^Q projections but consolidating the storage.
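The sharing mechanism is just a broadcast: each group's single K/V pair is reused by all the query heads in that group. A minimal NumPy sketch of a GQA forward pass (single sequence, causal mask, no learned projections shown — the per-head $W_i^Q$ would have produced `q` upstream):

```python
import numpy as np

def gqa_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                  groups: int) -> np.ndarray:
    """q: (h, T, d) per-query-head; k, v: (g, T, d) one pair per group.
    Each group of h // g query heads shares one K/V pair."""
    h, T, d = q.shape
    # broadcast each shared K/V pair to the query heads in its group
    k = np.repeat(k, h // groups, axis=0)   # (h, T, d)
    v = np.repeat(v, h // groups, axis=0)   # (h, T, d)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)          # (h, T, T)
    # causal mask: token t attends only to positions <= t
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                       # (h, T, d)

# 8 query heads, 2 KV groups (4 query heads per shared K/V pair)
out = gqa_attention(np.random.randn(8, 5, 16),
                    np.random.randn(2, 5, 16),
                    np.random.randn(2, 5, 16), groups=2)
print(out.shape)  # (8, 5, 16)
```

Note that only `k` and `v` shrink to `groups` entries — the query side keeps all `h` heads, which is exactly why head diversity survives the compression.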

Concrete current choices:

  • Phi-4-mini: $h = 24$, $g = 8$ (3× compression)
  • Llama 3.2-3B: $h = 24$, $g = 8$ (3×)
  • Qwen3-4B: $h = 32$, $g = 8$ (4×)
  • SmolLM3-3B: $h = 16$, $g = 4$ (4×)

GQA composes cleanly with FP8 KV cache quantization (another 2× compression on top of GQA's 3–4×). Together you can serve roughly 6–8× more concurrent sessions on the same GPU than with naive MHA in FP16. This is why Act VIII focuses so heavily on KV cache mechanics — it's the binding constraint in every production serving workload.
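The composed savings follow directly from the per-token formula: GQA shrinks the head count $h$, FP8 shrinks the bytes-per-element $b$, and the two multiply. A quick check for the Llama-2-7B-shaped baseline used earlier:

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int,
                       bytes_per_elem: int) -> int:
    """Same formula as before: 2 * L * h * d_h * b."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

mha_fp16 = kv_bytes_per_token(32, 32, 128, 2)  # baseline: 32 KV heads, FP16
gqa_fp8  = kv_bytes_per_token(32, 8, 128, 1)   # g=8 GQA (4x) + FP8 cache (2x)
print(mha_fp16 // gqa_fp8)  # 8  -> ~8x more sequences in the same memory budget
```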
comprehension check
comprehension · 1 / 3

For a 32-layer model with 32 KV heads, d_h = 128, FP16, what's the KV cache per token?