What actually runs when you call an inference server
You learned attention as softmax(QK⊤/√d_k)·V. That formula is correct and beautiful. It is also almost useless for understanding what a production serving engine is actually doing when a request lands. There are four things the textbook version hides:
- The shapes. Real tensors carry a batch dimension, a heads dimension, a sequence dimension, and a head-dim dimension. They reshape, transpose, and reshape again several times per layer. Getting these shapes in your head is half the battle.
- The prefill vs decode split. The same model does two completely different things when you feed it a prompt (one big matmul over the whole prompt) versus when it generates the next token (a single new query against an enormous cached history).
- The KV cache — not as a concept, but as a literal giant GPU tensor whose shape, [batch, n_heads, seq_len, head_dim] per layer, grows with seq_len. Decode is fundamentally “append one row to this tensor and attend against the whole thing.”
- The batching. Eight users hit your API at once, each at a different decode step, each with a different prompt length. A single kernel call has to serve all of them. Serving engines achieve this with a trick that most papers never explain: flatten the batch dimension entirely and track sequence boundaries with a cumulative-length tensor.
This lesson walks through all four, using a specific reference: Llama 7B. Hidden dimension d_model = 4096, heads n_heads = 32, head dimension d_head = 128, 32 layers.
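One consequence of those numbers is worth computing up front: the size of the KV cache you'll meet below. A back-of-the-envelope sketch, assuming fp16 weights (2 bytes per element):

```python
# KV cache size per sequence for Llama 7B at fp16.
n_layers, n_heads, d_head = 32, 32, 128
bytes_per_elem = 2               # fp16
k_and_v = 2                      # one K and one V tensor per layer

def kv_cache_bytes(seq_len: int) -> int:
    return k_and_v * n_layers * n_heads * d_head * bytes_per_elem * seq_len

kv_cache_bytes(1)                # 524288 — half a MiB per token
kv_cache_bytes(2048) / 2**30     # 1.0 — a full GiB at a 2048-token context
```

Half a megabyte of cache per token is why serving engines obsess over cache management.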
One sentence, end to end — the prefill path
Let's take a specific prompt and trace it through a single attention layer, with the shapes of every intermediate tensor. The prompt is “The cat sat on the mat”. Six tokens. That's our seq_len = 6.
Click through each phase on the right. The shape annotation under each tensor tells you what PyTorch prints when you call .shape on it. Keep an eye on the heads dimension — this is where the multi-head aspect becomes real.
Reading the shapes
Eight phases. Follow the dimensions as they shift. A few things worth noticing:
- At phase 3 (split-heads) the hidden dimension 4096 gets reshaped into 32 heads × 128 dims, giving [1, 6, 32, 128]. The heads axis moves to position 1 after the transpose to [1, 32, 6, 128], which puts each head's data in contiguous memory for the attention kernel. That transpose is not cosmetic — it's what lets the matmul run efficiently.
- At phase 4 (scores) you get the full attention tensor, [1, 32, 6, 6]. Notice: there's one 6×6 attention matrix per head, so 32 of them in parallel. This is also where causal masking is applied.
- At phase 7 (concat) the heads dimension is merged back into the hidden dimension. This undoes the reshape from phase 3 and leaves you with [1, 6, 4096] again — the same shape as the input.
- By the end, one prefill layer has produced a [1, 6, 4096] output. That output becomes the input to the next layer. Llama 7B has 32 layers, so the whole prefill is 32 repetitions of what you just saw.
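The whole walkthrough compresses into a few lines of PyTorch. This is a shape-bookkeeping sketch only — real implementations fuse the projections, apply RoPE, and call a fused attention kernel — and the random matrices stand in for learned weights; the phase numbers follow the walkthrough above:

```python
import torch

B, S, d_model, n_heads, d_head = 1, 6, 4096, 32, 128
x = torch.randn(B, S, d_model)                            # input: [1, 6, 4096]

# Random stand-ins for the learned projection matrices.
Wq, Wk, Wv, Wo = (torch.randn(d_model, d_model) for _ in range(4))

q = (x @ Wq).view(B, S, n_heads, d_head).transpose(1, 2)  # phase 3: [1, 32, 6, 128]
k = (x @ Wk).view(B, S, n_heads, d_head).transpose(1, 2)
v = (x @ Wv).view(B, S, n_heads, d_head).transpose(1, 2)

scores = q @ k.transpose(-2, -1) / d_head**0.5            # phase 4: [1, 32, 6, 6]
mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))          # causal mask, same phase
attn = scores.softmax(dim=-1)                             # one 6×6 matrix per head
out = attn @ v                                            # [1, 32, 6, 128]
out = out.transpose(1, 2).reshape(B, S, d_model)          # phase 7: [1, 6, 4096]
out = out @ Wo                                            # output proj: shape of x
```

Printing `.shape` after each line reproduces the annotations in the diagram.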
The sleight of hand — decoding one token
Now you've finished the prefill. The model has seen all six tokens of “The cat sat on the mat” and is about to generate token #7. What does the next forward pass look like?
You might think: same as before, but with seven tokens instead of six. Run the whole attention operation on a sequence of length 7. That would work but it would be ridiculous — the first six tokens haven't changed, so their K and V vectors haven't changed either, so recomputing them is pure waste. So we don't.
Instead, we cache the K and V of the first six tokens during prefill. Then, when decoding, we compute Q, K, and V only for the new token. Q is used once to produce attention scores; the new K and V are appended to the cache so they're available for the next decode step. This is the KV cache, finally introduced at the tensor level.
Why this single fact changes everything about serving
Look at what happened. During decode:
- Q has shape [1, 32, 1, 128] — one query vector per head, not six.
- K and V are retrieved from the cache as [1, 32, 6, 128], then the new K and V vectors for position 7 are appended to make [1, 32, 7, 128].
- QK⊤ has shape [1, 32, 1, 7]. This is not a full 7×7 attention matrix — it is a single row, because there is only one query. Softmax happens over 7 keys.
- The final output for this step is just [1, 1, 4096] — one new token vector. You project to logits, sample, append the sampled token to the sequence, and the next step begins.
Notice the brackets: [K_cache ; k_new] means “concatenate along the sequence axis”. The cache keeps growing. The query shrinks to length 1. The single row of scores the decode step computes grows linearly with the cache, but the math is shaped identically to prefill.
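A decode step is just as short in PyTorch. Again a shapes-only sketch: the cache tensors stand in for what prefill stored, and the projection matrices are random:

```python
import torch

B, n_heads, d_head, d_model = 1, 32, 128, 4096
# What prefill left behind for "The cat sat on the mat" (random stand-ins).
k_cache = torch.randn(B, n_heads, 6, d_head)              # [1, 32, 6, 128]
v_cache = torch.randn(B, n_heads, 6, d_head)

x_new = torch.randn(B, 1, d_model)                        # the one new token
Wq, Wk, Wv = (torch.randn(d_model, d_model) for _ in range(3))
q     = (x_new @ Wq).view(B, 1, n_heads, d_head).transpose(1, 2)  # [1, 32, 1, 128]
k_new = (x_new @ Wk).view(B, 1, n_heads, d_head).transpose(1, 2)
v_new = (x_new @ Wv).view(B, 1, n_heads, d_head).transpose(1, 2)

k_cache = torch.cat([k_cache, k_new], dim=2)              # append: [1, 32, 7, 128]
v_cache = torch.cat([v_cache, v_new], dim=2)

scores = q @ k_cache.transpose(-2, -1) / d_head**0.5      # [1, 32, 1, 7] — one row
out = scores.softmax(dim=-1) @ v_cache                    # [1, 32, 1, 128]
out = out.transpose(1, 2).reshape(B, 1, d_model)          # [1, 1, 4096]
# No causal mask needed: the single query sits at the last position,
# so it may legally attend to all 7 keys.
```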
Four users, one kernel — batched decode
Production serving is rarely one user at a time. Your API has 8, 32, 128 concurrent requests. Each one is at a different decode step — some have just started with a long prompt, some are mid-generation with a few hundred tokens cached, some are on the last token before they finish. How does a single GPU kernel call handle this?
The answer from modern engines (vLLM, SGLang, TGI, TensorRT-LLM) is elegant and a little mind-bending: throw away the batch dimension. Pack all the sequences end-to-end into a single flat tensor, and track where each sequence starts with a small cumulative-length vector. The attention kernel reads the lengths and knows which tokens can attend to which.
The cu_seqlens trick
That cu_seqlens tensor you saw in the diagram is the whole ballgame. It's a cumulative sum of sequence lengths. For a batch of three sequences with lengths [3, 4, 3]:
sequences = [seq_a, seq_b, seq_c]
lengths = [3, 4, 3]
cu_seqlens = [0, 3, 7, 10] # cumulative, with a leading 0
# packed tensor:
# positions [0..3) = seq_a
# positions [3..7) = seq_b
# positions [7..10) = seq_c
flash_attn_varlen_func(
q=q_packed, # shape [10, heads, dim] — one flat axis
k=k_packed,
v=v_packed,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=4,
max_seqlen_k=4,
)
Inside the kernel, each thread block figures out which sequence its token belongs to by looking up cu_seqlens, and restricts its attention to keys within the same sequence boundary. No padding, no wasted FLOPs, no separate tensor per sequence. A single 1D packed tensor, one fused kernel call, and all the sequences attend without cross-contamination.
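You can check the boundary logic without a GPU or the flash-attn package. A plain-Python sketch of the lookup each thread block effectively performs (illustrative only — the real kernel uses block indices, not bisection):

```python
import bisect
from itertools import accumulate

lengths = [3, 4, 3]
cu_seqlens = [0, *accumulate(lengths)]     # [0, 3, 7, 10] — leading 0, then cumsum

def seq_of(pos: int) -> int:
    """Map a packed position to the index of the sequence that owns it."""
    return bisect.bisect_right(cu_seqlens, pos) - 1

seq_id = seq_of(5)                         # position 5 lives in seq_b → 1
start, end = cu_seqlens[seq_id], cu_seqlens[seq_id + 1]
# A query at packed position 5 may attend only to keys in [start, 5] —
# never across `end` into the next sequence.
```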
The function above is FlashAttention's real variable-length entry point, flash_attn_varlen_func. Together with PagedAttention, this is what makes continuous batching in vLLM computationally free relative to fixed-batch inference.
Why prefill is compute-bound but decode is memory-bound, definitively
Now you can read the numbers end-to-end. Here's the roofline story told with real shapes:
Prefill (the 6-token prompt):
- Q has shape [1, 32, 6, 128].
- K has shape [1, 32, 6, 128].
- QK⊤ is [1, 32, 6, 6] — a full 6×6 matrix of attention scores.
- FLOPs: 2·6·6·128 ≈ 9.2 kFLOPs per head, ×32 heads, ×32 layers ≈ 9.4 MFLOPs for attention alone.
- Same 14 GB of weights streamed, but used for 6 tokens at once.
- Arithmetic intensity: 6 tokens × 2 FLOPs per param ÷ 2 bytes per param ≈ 6 FLOPs per weight byte — six times higher than decode, much closer to the ridge point.
- Status: compute-bound.
Decode (token 7):
- Q has shape [1, 32, 1, 128] — one query.
- K has shape [1, 32, 7, 128] from the cache.
- QK⊤ is [1, 32, 1, 7] — one row.
- FLOPs per token: 2N, where N ≈ 7B is the total parameter count (~14 GFLOPs).
- 14 GB of weights streamed through HBM per decoded token.
- Arithmetic intensity: 2N FLOPs ÷ 2N bytes ≈ 1 FLOP per weight byte — 295× below the H100 ridge point.
- Status: memory-bound. Compute cores sit idle.
Batching changes the decode story: if you process 32 sequences at once, the same weight read serves 32 queries, so the effective arithmetic intensity jumps from 1 to 32 FLOPs per weight byte. Keep growing the batch toward the ridge point (~300 on an H100) and decode crosses into the compute-bound region. This is why continuous batching matters so dramatically — it turns decode from the worst-case roofline workload into something that runs near peak.
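The batch-scaling claim is one line of arithmetic. A sketch using the same fp16, 7B-parameter numbers as above:

```python
# Decode arithmetic intensity vs. batch size, Llama 7B at fp16.
n_params = 7e9
weight_bytes = 2 * n_params      # 2 bytes per fp16 param ≈ 14 GB
flops_per_token = 2 * n_params   # ≈ 14 GFLOPs per decoded token

def intensity(batch: int) -> float:
    # The same 14 GB weight stream serves every sequence in the batch.
    return batch * flops_per_token / weight_bytes

intensity(1)    # 1.0  — deep in the memory-bound region
intensity(32)   # 32.0 — 32× better, still short of the ~295 ridge
```

The ratio is exactly the batch size, which is why decode throughput scales almost linearly with batch until the ridge point.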