Microscale
Act II · Inside the Machine
lesson attention-production · 14 min · 65 xp

Attention in production

Prefill vs decode, KV cache as a tensor, one kernel call serving many users

What actually runs when you call an inference server

You learned attention as \text{softmax}(QK^\top/\sqrt{d_k})V. That formula is correct and beautiful. It is also almost useless for understanding what a production serving engine is actually doing when a request lands. There are four things the textbook version hides:

  1. The shapes. Real tensors carry a batch dimension, a heads dimension, a sequence dimension, and a head-dim dimension. They reshape, transpose, and reshape again several times per layer. Getting these shapes in your head is half the battle.
  2. The prefill vs decode split. The same model does two completely different things when you feed it a prompt (one big matmul over the whole prompt) versus when it generates the next token (a single new query against an enormous cached history).
  3. The KV cache — not as a concept, but as a literal giant GPU tensor whose shape depends on [\text{layers}, \text{heads}, L_\text{max}, d_h]. Decode is fundamentally "append one row to this tensor and attend against the whole thing."
  4. The batching. Eight users hit your API at once, each at a different decode step, each with a different prompt length. A single kernel call has to serve all of them. Serving engines achieve this with a trick that most papers never explain: flatten the batch dimension entirely and track sequence boundaries with a cumulative-length tensor.

This lesson walks through all four, using a specific reference: Llama 7B. Hidden dimension d = 4096, heads h = 32, head dimension d_h = 128, 32 layers.
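To keep those numbers straight, here's a tiny sanity check (plain Python; the dictionary keys are illustrative, not the official checkpoint config fields):

```python
# Reference dimensions for Llama 7B as used in this lesson.
# (Key names are illustrative, not the official config fields.)
llama_7b = {
    "d": 4096,       # hidden dimension
    "h": 32,         # attention heads
    "d_h": 128,      # head dimension
    "layers": 32,
}

# The hidden dimension factors exactly into heads x head_dim,
# which is what makes the split-heads reshape later possible.
assert llama_7b["d"] == llama_7b["h"] * llama_7b["d_h"]
```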

One sentence, end to end — the prefill path

Let's take a specific prompt and trace it through a single attention layer with the shapes of every intermediate tensor. The prompt is “The cat sat on the mat”. Six tokens. That's our L = 6.

Click through each phase on the right. The shape annotation under each tensor tells you what PyTorch prints when you call .shape on it. Keep an eye on the heads dimension — this is where the multi-head aspect becomes real.

prefill — one layer of Llama 7B, L = 6 tokens, batch = 1

x: [1, 6, 4096] (input hidden states for the 6 tokens)

Reading the shapes

Eight phases. Follow the dimensions as they shift. A few things worth noticing:

  • At phase 3 (split-heads) the hidden dimension 4096 gets reshaped into h \times d_h = 32 \times 128. The heads axis moves to position 1 after the transpose, which puts each head's data in contiguous memory for the attention kernel. That transpose is not cosmetic — it's what lets the matmul run efficiently.
  • At phase 4 (scores) you get the full [1, 32, 6, 6] attention tensor. Notice: there's one 6×6 attention matrix per head, so 32 of them in parallel. This is also where causal masking is applied.
  • At phase 7 (concat) the heads dimension is merged back into the hidden dimension. This undoes the reshape from phase 3 and leaves you with [1, 6, 4096] again — the same shape as the input.
  • By the end, one prefill layer has produced a [1, 6, 4096] output. That output becomes the input to the next layer. Llama 7B has 32 layers, so the whole prefill is 32 repetitions of what you just saw.
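The prefill pass above can be sketched in a few lines of NumPy. This is a shapes-only toy: random weights, one shared projection matrix, no RoPE, no grouped-query attention — just the reshapes and matmuls the phases describe.

```python
import numpy as np

B, L, d, h, dh = 1, 6, 4096, 32, 128   # batch, seq len, hidden, heads, head dim
rng = np.random.default_rng(0)

x = rng.standard_normal((B, L, d)).astype(np.float32)   # [1, 6, 4096]
W = rng.standard_normal((d, d)).astype(np.float32) * 0.01
W_q = W_k = W_v = W_o = W   # one shared random matrix, just to track shapes

def split_heads(t):
    # [B, L, d] -> [B, L, h, dh] -> transpose -> [B, h, L, dh]
    return t.reshape(B, L, h, dh).transpose(0, 2, 1, 3)

q, k, v = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)
assert q.shape == (1, 32, 6, 128)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh)       # [1, 32, 6, 6]
scores += np.triu(np.full((L, L), -np.inf, dtype=np.float32), k=1)  # causal mask
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                # softmax over keys
assert weights.shape == (1, 32, 6, 6)   # one 6x6 matrix per head

out = weights @ v                                        # [1, 32, 6, 128]
out = out.transpose(0, 2, 1, 3).reshape(B, L, d) @ W_o   # concat heads + W_O
assert out.shape == (1, 6, 4096)   # same shape as the input x
```

Every assert here corresponds to one of the phases in the diagram; stack 32 copies of this and you have the whole prefill.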

The sleight of hand — decoding one token

Now you've finished the prefill. The model has seen all six tokens of “The cat sat on the mat” and is about to generate token #7. What does the next forward pass look like?

You might think: same as before, but with seven tokens instead of six. Run the whole attention operation on a sequence of length 7. That would work but it would be ridiculous — the first six tokens haven't changed, so their K and V vectors haven't changed either, so recomputing them is pure waste. So we don't.

Instead, we cache the K and V of the first six tokens during prefill. Then, when decoding, we compute Q, K, and V only for the new token. Q is used once to produce attention scores; the new K and V are appended to the cache so they're available for the next decode step. This is the KV cache, finally introduced at the tensor level.

decode step — generating token #7 with 6 tokens already in the KV cache

new token only:
  x_new: [1, 1, 4096]

after projections:
  Q_new: [1, 32, 1, 128]
  K_new: [1, 32, 1, 128]
  V_new: [1, 32, 1, 128]

kv cache (from prefill):
  K_cache, V_cache: [1, 32, 6, 128]

after appending the new token:
  K = concat(K_cache, K_new): [1, 32, 7, 128]
  V = concat(V_cache, V_new): [1, 32, 7, 128]

attention — one row of the matrix:
  Q_new [1, 32, 1, 128] · K^T [1, 32, 128, 7] = scores [1, 32, 1, 7]
  softmax(scores / √128) [1, 32, 1, 7] · V [1, 32, 7, 128] = output [1, 32, 1, 128]

concat heads + W_O → final output: [1, 1, 4096]
The key insight: Q has sequence length 1, K and V have sequence length 7. The matrix QK^\top is 1 \times 7 — one row, seven columns. It is a single row of what would be the full attention matrix if this were the very last token of a 7-token prefill. You compute only this row. The cost is linear in L, not quadratic.

Why this single fact changes everything about serving

Look at what happened. During decode:

  • Q has shape [1, 32, 1, 128]: one query vector per head, not six.
  • K and V are retrieved from the cache as [1, 32, 6, 128], then the new K and V vectors for position 7 are appended to make [1, 32, 7, 128].
  • QK^\top has shape [1, 32, 1, 7]. This is not a full 7×7 attention matrix — it is a single row, because there is only one query. Softmax happens over 7 keys.
  • The final output for this step is just [1, 1, 4096] — one new token vector. You project to logits, sample, append the sampled token to the sequence, and the next step begins.
what decode actually computes, in two lines
Q_\text{new} = X_\text{new} W_Q, \quad K_\text{new} = X_\text{new} W_K, \quad V_\text{new} = X_\text{new} W_V
\text{output}_t = \text{softmax}\!\left(\frac{Q_\text{new}\,[K_\text{cache}; K_\text{new}]^\top}{\sqrt{d_h}}\right)[V_\text{cache}; V_\text{new}]

Notice the brackets. [K_\text{cache}; K_\text{new}] means “concatenate along the sequence axis”. The cache keeps growing; the query shrinks to length 1. The single row of the attention matrix that each decode step computes grows linearly with the sequence length, but the math is shaped identically to prefill.
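You can verify that the decode path really is just one row of the full attention matrix with a short NumPy check. This sketch drops the batch and heads dimensions (one head only) and uses random vectors in place of real activations:

```python
import numpy as np

rng = np.random.default_rng(0)
dh = 128

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Pretend prefill already produced K/V for 6 tokens (one head, no batch dim).
K_cache = rng.standard_normal((6, dh))
V_cache = rng.standard_normal((6, dh))

# Decode: Q/K/V for the one new token.
q_new = rng.standard_normal((1, dh))
k_new = rng.standard_normal((1, dh))
v_new = rng.standard_normal((1, dh))

K = np.concatenate([K_cache, k_new])        # [7, 128]  [K_cache; K_new]
V = np.concatenate([V_cache, v_new])        # [7, 128]  [V_cache; V_new]

row = softmax(q_new @ K.T / np.sqrt(dh))    # [1, 7]: one row of the matrix
decode_out = row @ V                        # [1, 128]

# Sanity check: this equals the last row of full causal attention over
# 7 tokens, if q_new had been the 7th query of a prefill.
Q_full = np.concatenate([rng.standard_normal((6, dh)), q_new])
scores = Q_full @ K.T / np.sqrt(dh) + np.triu(np.full((7, 7), -np.inf), k=1)
full_out = softmax(scores) @ V
assert np.allclose(full_out[-1:], decode_out)
```

The final assert is the whole argument: the cached decode step and the last row of a full prefill produce identical outputs, so skipping the other six rows loses nothing.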

Four users, one kernel — batched decode

Production serving is rarely one user at a time. Your API has 8, 32, 128 concurrent requests. Each one is at a different decode step — some have just started with a long prompt, some are mid-generation with a few hundred tokens cached, some are on the last token before they finish. How does a single GPU kernel call handle this?

The answer from modern engines (vLLM, SGLang, TGI, TensorRT-LLM) is elegant and a little mind-bending: throw away the batch dimension. Pack all the sequences end-to-end into a single flat tensor, and track where each sequence starts with a small cumulative-length vector. The attention kernel reads the lengths and knows which tokens can attend to which.

continuous batching — 4 concurrent users at different decode positions

each sequence in its own slot:
  A: Q_new, L_A = 12 + 1
  B: Q_new, L_B = 27 + 1
  C: Q_new, L_C = 4 + 1
  D: Q_new, L_D = 33 + 1
Four users, four different KV cache lengths (12, 27, 4, 33). Each is about to get one new token. If we treated them as a normal batch, we'd have to pad the sequence dim to max(27, 33, 12, 4) = 33 and waste most of the compute. Real engines don't pad.
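A quick back-of-envelope on what padding would cost here, using the cache lengths from the diagram (counting KV slots as a proxy for wasted work):

```python
lengths = [12, 27, 4, 33]                    # KV cache lengths for users A-D

padded_slots = len(lengths) * max(lengths)   # pad every sequence to 33
packed_slots = sum(lengths)                  # pack them end-to-end instead

waste = 1 - packed_slots / padded_slots
print(padded_slots, packed_slots, round(waste, 2))   # 132 76 0.42
```

Even with only four moderately skewed sequences, padding wastes about 42% of the slots; with realistic batch sizes and length spreads the waste gets much worse.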

The cu_seqlens trick

That cu_seqlens tensor you saw in the diagram is the whole ballgame. It's a cumulative sum of sequence lengths. For a batch of three sequences with lengths [3, 4, 3]:

python
from flash_attn import flash_attn_varlen_func

sequences = [seq_a, seq_b, seq_c]
lengths   = [3, 4, 3]
cu_seqlens = [0, 3, 7, 10]    # cumulative, with a leading 0 (an int32 GPU tensor in practice)

# packed tensor:
# positions [0..3) = seq_a
# positions [3..7) = seq_b
# positions [7..10) = seq_c

flash_attn_varlen_func(
    q=q_packed,     # shape [10, heads, dim] — one flat axis, no batch dim
    k=k_packed,
    v=v_packed,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=4,   # longest sequence in this batch
    max_seqlen_k=4,
)

Inside the kernel, each thread block figures out which sequence its token belongs to by looking up cu_seqlens, and restricts its attention to keys within the same sequence boundary. No padding, no wasted FLOPs, no separate tensor per sequence. A single 1D packed tensor, one fused kernel call, and all the sequences attend without cross-contamination.
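The boundary lookup the kernel performs can be mimicked in plain NumPy. This is a toy re-implementation of the index math, not the actual kernel code; `seq_of` and `allowed_keys` are names invented for this sketch:

```python
import numpy as np

lengths = np.array([3, 4, 3])
cu_seqlens = np.concatenate([[0], np.cumsum(lengths)])   # [0, 3, 7, 10]

def seq_of(token_idx):
    # side="right" puts a boundary token into the sequence that starts at it
    return int(np.searchsorted(cu_seqlens, token_idx, side="right") - 1)

# Tokens 0-2 belong to seq 0, tokens 3-6 to seq 1, tokens 7-9 to seq 2.
assert [seq_of(i) for i in range(10)] == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]

def allowed_keys(token_idx):
    # causal + same-sequence: keys from this sequence's start up to this token
    s = seq_of(token_idx)
    return range(int(cu_seqlens[s]), token_idx + 1)

assert list(allowed_keys(5)) == [3, 4, 5]   # token 5 is seq B's third token
```

Each thread block does this same lookup once, then only reads keys inside its own `[cu_seqlens[s], cu_seqlens[s+1])` window, which is what prevents cross-contamination between sequences.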

◆ paper
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Tri Dao · 2023
arxiv:2307.08691
FA2 added first-class support for variable-length sequences via flash_attn_varlen_func. Together with PagedAttention, this is what makes continuous batching in vLLM computationally free relative to fixed-batch inference.

Why prefill is compute-bound but decode is memory-bound, definitively

Now you can read the numbers end-to-end. Here's the roofline story told with real shapes:

prefill on a 6-token prompt

Q has shape [1, 32, 6, 128].

K has shape [1, 32, 6, 128].

QK^\top is [1, 32, 6, 6] — a full 6×6 matrix of attention scores.

FLOPs: \sim 2 \cdot 6 \cdot 6 \cdot 128 = 9{,}216 per head for the score matmul, ×32 heads, ×32 layers ≈ 9.4 MFLOPs for attention scores alone.

Same 14 GB of weights streamed, but used for 6 tokens.

Arithmetic intensity: ≈ 6 FLOPs per weight byte — six times higher than decode, much closer to the ridge point.

Status: compute-bound.

decode generating one token

Q has shape [1, 32, 1, 128] — one query.

K has shape [1, 32, L, 128] from the cache.

QK^\top is [1, 32, 1, L] — one row.

FLOPs per token: \sim 2N where N is the total parameter count (≈ 14 GFLOPs for 7B parameters).

14 GB of weights streamed through HBM per decoded token.

Arithmetic intensity: ≈ 1 FLOP per weight byte — 295× below the H100 ridge point.

Status: memory-bound. Compute cores sit idle.

Batching changes the decode story: if you process 32 sequences at once, the same weight read serves 32 queries, so the effective arithmetic intensity jumps from 1 to 32. That is still below the H100 ridge point of ~295, but it is 32× more useful work per weight byte, and at batch sizes in the hundreds decode crosses into the compute-bound region. This is why continuous batching matters so dramatically — it moves decode from the worst-case roofline workload toward something that runs near peak.
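The arithmetic behind those intensity numbers, as a sketch (rough figures: ~7B parameters in fp16; the H100 peak FLOP/s and bandwidth values are approximate datasheet numbers):

```python
params = 7e9               # Llama 7B, approximate parameter count
weight_bytes = params * 2  # fp16: 2 bytes per weight, ~14 GB streamed per step

def arithmetic_intensity(batch_size):
    flops = 2 * params * batch_size   # ~2 FLOPs per parameter per sequence
    return flops / weight_bytes       # FLOPs per weight byte streamed

ridge = 989e12 / 3.35e12   # H100 SXM: ~989 TFLOP/s fp16 / ~3.35 TB/s HBM

print(arithmetic_intensity(1))    # 1.0  -> deep in the memory-bound region
print(arithmetic_intensity(32))   # 32.0 -> 32x better, still below the ridge
print(round(ridge))               # 295
```

Note that the weight read cancels out of the ratio, which is why the intensity scales exactly linearly with batch size until KV-cache reads start to dominate.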

comprehension check
comprehension · 1 / 4

In a prefill forward pass over a 6-token prompt, what shape is the attention-scores tensor QKQK^\top for a single Llama-7B layer?