How the KV Cache Works in HuggingFace Transformers
Every token a Transformer generates forces it to re-read the entire conversation so far. The KV cache is what makes this tractable — it stores the key and value projections of every past token so they aren’t recomputed from scratch. HuggingFace models manage this through a Cache abstraction, and the default implementation, DynamicCache, is a clean example of the full mechanism: how the cache is created, how position IDs coordinate RoPE rotations and causal masking, how each attention layer reads and updates it, and how it’s threaded through the generation loop.
This post walks through the KV cache system end-to-end, using Qwen3-0.6B as a concrete reference.
1. DynamicCache
DynamicCache is the default KV cache and the central state object threaded through every HuggingFace generation loop. It is created once at the start and passed through every model.forward() call. All 28 decoder layers share the same DynamicCache instance; each layer reads and writes its own slot, cache.layers[layer_idx], which is a DynamicLayer.
1.1 Cache creation
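When use_cache=True and no past_key_values is passed, Qwen3Model.forward constructs a fresh DynamicCache.
At this point the cache is empty. Each layer's storage is lazily initialised on its first update() call, after which it has the following shapes (8 is the number of KV heads, 64 the head dimension):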
| What | Shape |
|---|---|
| Keys | [batch, 8, total_seq_len, 64] |
| Values | [batch, 8, total_seq_len, 64] |
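To make the lazy initialisation concrete, here is a minimal sketch. ToyLayer and ToyCache are hypothetical stand-ins written for this post, not the transformers classes, and the shapes simply follow the table above.

```python
import torch

class ToyLayer:
    """Hypothetical stand-in for one per-layer cache slot."""
    def __init__(self):
        self.keys = None    # nothing is allocated until the first update
        self.values = None

    def update(self, key_states, value_states):
        if self.keys is None:
            # First call: adopt the incoming tensors as the initial storage.
            self.keys, self.values = key_states, value_states
        else:
            # Later calls: grow along the sequence dimension.
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

class ToyCache:
    """Hypothetical stand-in for the cache object shared by all layers."""
    def __init__(self, num_layers):
        self.layers = [ToyLayer() for _ in range(num_layers)]

    def get_seq_length(self, layer_idx=0):
        keys = self.layers[layer_idx].keys
        return 0 if keys is None else keys.shape[-2]

cache = ToyCache(num_layers=28)
empty_len = cache.get_seq_length()       # 0: the cache is empty after creation

prompt_k = torch.zeros(1, 8, 5, 64)      # [batch, kv_heads, seq_len, head_dim]
prompt_v = torch.zeros(1, 8, 5, 64)
keys, values = cache.layers[0].update(prompt_k, prompt_v)
```

After the first update, layer 0 holds tensors of shape [1, 8, 5, 64], while the other 27 layers remain unallocated until their own first call.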
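For the full 28-layer model at sequence length S (bf16, 2 bytes per element), K and V together take 2 × 28 × 8 × S × 64 × 2 bytes = 56 KiB × S for each sequence in the batch.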
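The arithmetic can be checked with a few lines of Python. The helper name kv_cache_bytes is made up for this sketch, and its defaults are the shapes from the table above rather than values read from a model config.

```python
def kv_cache_bytes(seq_len, batch=1, num_layers=28, kv_heads=8,
                   head_dim=64, bytes_per_elem=2):
    """Bytes held by the K and V tensors of every layer (bf16 = 2 bytes)."""
    per_tensor = batch * kv_heads * seq_len * head_dim * bytes_per_elem
    return 2 * num_layers * per_tensor   # 2 = one K tensor + one V tensor

per_token = kv_cache_bytes(seq_len=1)    # 57,344 bytes = 56 KiB
at_4096 = kv_cache_bytes(seq_len=4096)   # 234,881,024 bytes = 224 MiB

print(f"{per_token / 1024:.0f} KiB per token")
print(f"{at_4096 / 1024**2:.0f} MiB at S = 4096")
```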
1.2 position_ids
position_ids is a tensor of shape [batch_size, seq_len] that assigns an absolute integer position to every token.
Creation in Qwen3Model.forward
If the caller doesn’t supply position_ids (the common case), they are computed from the cache state:
- Prefill (first call, no cache): past_seen_tokens = 0 → position_ids = [[0, 1, 2, ..., prompt_len-1]]
- Decode step 1: past_seen_tokens = prompt_len → position_ids = [[prompt_len]]
- Decode step N: past_seen_tokens = prompt_len + N - 1 → position_ids = [[prompt_len + N - 1]]
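The three cases above can be reproduced with a small sketch. make_position_ids is a hypothetical helper, and plain Python lists stand in for the [1, seq_len] tensor; the only input it needs is how many tokens the cache has already seen.

```python
def make_position_ids(past_seen_tokens, num_new_tokens):
    """Absolute positions for the new tokens, shaped [1, num_new_tokens]."""
    start = past_seen_tokens
    return [list(range(start, start + num_new_tokens))]

prompt_len = 5

# Prefill: nothing cached, the whole prompt is new.
prefill = make_position_ids(past_seen_tokens=0, num_new_tokens=prompt_len)
# -> [[0, 1, 2, 3, 4]]

# Decode step 1: the prompt is cached, one new token arrives.
decode_1 = make_position_ids(past_seen_tokens=prompt_len, num_new_tokens=1)
# -> [[5]]

# Decode step N: prompt plus N - 1 generated tokens are cached.
n = 3
decode_n = make_position_ids(past_seen_tokens=prompt_len + n - 1, num_new_tokens=1)
# -> [[7]]
```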
Before the layer loop
After position_ids is created, Qwen3Model.forward computes two things once, before entering the 28-layer loop. Both are reused unchanged by every attention layer.
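RoPE position embeddings (consumes position_ids):
Qwen3RotaryEmbedding.forward turns position_ids into a (cos, sin) pair. Each position p gets a rotation angle θ_{p,i} = p / base^(2i/d) for every frequency index i in [0, d/2), where d is the head dimension. In the HuggingFace layout the rotation pairs head dimension i with i + d/2 (the rotate_half convention), not the interleaved 2i and 2i+1. The result (cos, sin) has shape [batch, seq_len, head_dim].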
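As an illustration of how the angles become a rotation, here is a pure-Python sketch for a single head vector. It uses the half-split pairing (entry i with entry i + d/2) that rotate_half implies; the function names rope_cos_sin and apply_rope are made up for this example, and base=10000.0 is an illustrative default, not a value taken from the Qwen3 config.

```python
import math

def rope_cos_sin(position, head_dim, base=10000.0):
    """cos/sin rows for one position, each of length head_dim."""
    half = head_dim // 2
    # One angle per frequency index i: position / base^(2i / head_dim).
    angles = [position / base ** (2 * i / head_dim) for i in range(half)]
    # Duplicate the angles so entry i and entry i + half share a frequency.
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]

def rotate_half(x):
    """Map (x1, x2) -> (-x2, x1), where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def apply_rope(x, cos, sin):
    """Rotate one head vector: x * cos + rotate_half(x) * sin."""
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]

cos0, sin0 = rope_cos_sin(position=0, head_dim=8)   # position 0: no rotation
cos3, sin3 = rope_cos_sin(position=3, head_dim=8)

q = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
q_rot = apply_rope(q, cos3, sin3)   # entries 0 and 4 now hold cos(3) and sin(3)
```

Because cos and sin depend only on the position, they can be computed once per forward pass and shared by all layers, which is why the model does this before the layer loop.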
Causal mask (consumes past_key_values and position_ids):
create_causal_mask uses past_key_values (for kv_length and kv_offset) together with position_ids to build an additive attention mask (0 for allowed, −∞ for forbidden). Each query position can only attend to key positions ≤ itself, and the mask spans the full cached key length, not just the current input slice.
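A rough sketch of the mask described here, using nested lists in place of a tensor. additive_causal_mask is a hypothetical helper, not the library's create_causal_mask; it only illustrates the rule "key position ≤ query position" over the full cached length.

```python
NEG_INF = float("-inf")

def additive_causal_mask(query_positions, kv_length):
    """Rows = queries, columns = cached key positions 0 .. kv_length - 1."""
    return [
        [0.0 if key_pos <= q_pos else NEG_INF for key_pos in range(kv_length)]
        for q_pos in query_positions
    ]

# Prefill: 3 new tokens, nothing cached yet, so kv_length == 3.
prefill_mask = additive_causal_mask(query_positions=[0, 1, 2], kv_length=3)

# Decode: 1 new token at position 3, attending over 4 keys (3 cached + itself).
decode_mask = additive_causal_mask(query_positions=[3], kv_length=4)
```

During prefill the mask is lower-triangular; during decode it is a single row of zeros, because the newest token may attend to everything already in the cache.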
1.3 Per-layer attention flow
Inside each of the 28 attention layers, the flow proceeds in four steps:
1. Project & normalize:

   ```python
   # view -> [batch, seq_len, heads, head_dim]; transpose(1, 2) -> [batch, heads, seq_len, head_dim]
   query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
   key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
   value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
   ```

   Q, K, V are projected and reshaped to [batch, num_heads, seq_len, head_dim], and Q and K are RMS-normalised via Qwen3RMSNorm over the head dimension (a Qwen3-specific detail).

2. Apply RoPE (consumes position_embeddings from §1.2):

   ```python
   cos, sin = position_embeddings
   query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
   ```

3. Update cache (consumes past_key_values):

   ```python
   if past_key_values is not None:
       key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
   ```

   This calls DynamicLayer.update():

   ```python
   def update(self, key_states, value_states, *args, **kwargs):
       # First call for this layer: set up its storage.
       if not self.is_initialized:
           self.lazy_initialization(key_states, value_states)
       # Append the new tokens along the sequence-length dimension.
       self.keys = torch.cat([self.keys, key_states], dim=-2)
       self.values = torch.cat([self.values, value_states], dim=-2)
       return self.keys, self.values
   ```

   The new KV tensors are concatenated along the sequence-length dimension (dim=-2). The returned key_states and value_states now contain all historical tokens plus the current token.

4. Run attention (consumes the causal mask from §1.2 and the full cached K/V from step 3):

   Attention operates over the concatenated tensors, masked by the pre-computed causal mask (via ALL_ATTENTION_FUNCTIONS, which dispatches to Flash Attention, SDPA, or eager).
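The property that makes the cache worthwhile is that decoding one token against cached K/V gives the same result as recomputing attention over the whole sequence. The toy layer below checks this; it is a sketch under simplifying assumptions (no RoPE, no Q/K norm, no grouped-query heads, a plain dict as the cache), and every name in it is invented for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, head_dim, hidden = 1, 2, 4, 8
q_proj = torch.nn.Linear(hidden, heads * head_dim, bias=False)
k_proj = torch.nn.Linear(hidden, heads * head_dim, bias=False)
v_proj = torch.nn.Linear(hidden, heads * head_dim, bias=False)

def split_heads(x):
    # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
    return x.view(x.shape[0], x.shape[1], heads, head_dim).transpose(1, 2)

def attend(hidden_states, cache):
    """One toy attention call; `cache` is a dict holding past K/V (or empty)."""
    q = split_heads(q_proj(hidden_states))
    k = split_heads(k_proj(hidden_states))
    v = split_heads(v_proj(hidden_states))
    if "k" in cache:
        # Append the new K/V to the history along the sequence dimension.
        k = torch.cat([cache["k"], k], dim=-2)
        v = torch.cat([cache["v"], v], dim=-2)
    cache["k"], cache["v"] = k, v
    q_len, kv_len = q.shape[-2], k.shape[-2]
    # Query row r sits at absolute position (kv_len - q_len + r).
    q_pos = torch.arange(kv_len - q_len, kv_len).unsqueeze(-1)
    k_pos = torch.arange(kv_len).unsqueeze(0)
    mask = k_pos <= q_pos                  # True = may attend
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

with torch.no_grad():
    x = torch.randn(batch, 6, hidden)      # hidden states for 6 tokens

    full = attend(x, cache={})             # recompute everything at once

    cache = {}
    attend(x[:, :5], cache)                # prefill on the first 5 tokens
    step = attend(x[:, 5:], cache)         # decode the 6th token with the cache

    same = torch.allclose(full[:, :, -1:], step, atol=1e-5)
```

The decode call projects only one token, yet its output matches the last row of the full computation, because the keys and values of earlier tokens do not depend on later ones.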
1.4 Returning the cache
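Qwen3Model.forward returns a BaseModelOutputWithPast carrying the final hidden states together with the same past_key_values object, which now holds the keys and values of every token seen so far.
Qwen3ForCausalLM.forward projects the hidden states through lm_head and returns a CausalLMOutputWithPast with the logits and that same past_key_values, ready to be passed into the next forward call.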
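How the returned cache is threaded through generation can be sketched with a toy model. Everything here is hypothetical: toy_forward stands in for model.forward, its "cache" is simply the list of tokens seen so far, and its "logits" are an arbitrary deterministic function of that history. The point is only the calling pattern, full prompt first and then one token at a time with the cache handed back in.

```python
def toy_forward(input_ids, past_key_values=None, use_cache=True):
    """Hypothetical stand-in for model.forward."""
    history = past_key_values if past_key_values is not None else []
    history.extend(input_ids)            # the cache object is updated in place
    vocab_size = 10
    next_token = (sum(history) + len(history)) % vocab_size
    logits = [1.0 if t == next_token else 0.0 for t in range(vocab_size)]
    return logits, (history if use_cache else None)

def argmax(values):
    return max(range(len(values)), key=values.__getitem__)

def greedy_generate(prompt_ids, max_new_tokens):
    # Prefill: the whole prompt goes in, the cache comes back.
    logits, cache = toy_forward(list(prompt_ids))
    generated = []
    for _ in range(max_new_tokens):
        token = argmax(logits)
        generated.append(token)
        # Decode: feed only the new token and hand the same cache back in.
        logits, cache = toy_forward([token], past_key_values=cache)
    return generated

def greedy_generate_no_cache(prompt_ids, max_new_tokens):
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        # Without a cache, every step re-reads the entire sequence.
        logits, _ = toy_forward(list(ids), use_cache=False)
        ids.append(argmax(logits))
    return ids[len(prompt_ids):]

with_cache = greedy_generate([3, 1, 4], max_new_tokens=5)
without_cache = greedy_generate_no_cache([3, 1, 4], max_new_tokens=5)
```

Both loops produce the same tokens; the cached version just passes one new token per step instead of the growing sequence.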
2. The Cost of torch.cat
DynamicCache is correct and safe, but its core mechanism — torch.cat on every decode step — has a performance cost worth understanding.
Each DynamicLayer.update() call runs torch.cat([self.keys, key_states], dim=-2) for the keys and the equivalent call for the values.
This allocates a brand-new tensor, copies the entire old buffer into it, then frees the old one. At sequence length 4096 with bf16, summed over K and V in all 28 layers, that is 2 × 28 × 8 × 4096 × 64 × 2 bytes = 224 MiB allocated, copied, and freed on every decode step.
Additionally, because the tensor shape and memory address change on every step, torch.compile cannot capture the decode step as a static graph: CUDA graphs are ruled out, and the Python interpreter runs between every token.
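Both effects are easy to observe on a small tensor. The sketch below shows that torch.cat returns new storage, and then estimates how much history gets re-copied over a whole generation; history_bytes_copied is a hypothetical helper whose defaults reuse the shapes from §1.1.

```python
import torch

keys = torch.zeros(1, 8, 16, 64, dtype=torch.bfloat16)    # 16 cached tokens
new_key = torch.zeros(1, 8, 1, 64, dtype=torch.bfloat16)  # 1 new token

old = keys                                  # keep a handle on the old buffer
keys = torch.cat([keys, new_key], dim=-2)   # what update() does on each step

moved = keys.data_ptr() != old.data_ptr()   # a new buffer, not the old one grown
bytes_copied = old.numel() * old.element_size()   # history re-copied for this one tensor

def history_bytes_copied(prompt_len, new_tokens, num_layers=28, kv_heads=8,
                         head_dim=64, bytes_per_elem=2):
    """Total bytes of old K/V re-copied by torch.cat over a whole generation."""
    per_token = 2 * num_layers * kv_heads * head_dim * bytes_per_elem
    # Decode step t re-copies a history of (prompt_len + t) tokens.
    return sum(per_token * (prompt_len + t) for t in range(new_tokens))

total = history_bytes_copied(prompt_len=4096, new_tokens=256)
print(f"{total / 1024**3:.1f} GiB re-copied while generating 256 tokens")
```

Since each step copies the whole history, the total grows quadratically with the number of generated tokens.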
3. StaticCache — a Pre-Allocated Alternative
For use cases that need torch.compile and CUDA graphs, HuggingFace provides StaticCache, a pre-allocated alternative that never reallocates or re-copies the history.
Instead of growing dynamically, StaticCache allocates one fixed-size slab per layer upfront, of length max_cache_len along the sequence dimension, and writes new keys and values into it in place with index_copy_.
Because tensor pointers never change, StaticCache supports torch.compile(fullgraph=True) and CUDA graph capture. The trade-off is that max_cache_len must be known upfront — there is no auto-grow.
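The in-place write can be sketched in a few lines of PyTorch. static_update is a hypothetical helper, not the StaticCache method itself; it shows a slab being filled at explicit sequence positions while its data pointer stays fixed.

```python
import torch

max_cache_len = 8
k_cache = torch.zeros(1, 8, max_cache_len, 64)   # [batch, kv_heads, max_len, head_dim]
ptr_before = k_cache.data_ptr()

def static_update(cache, new_states, cache_position):
    """Write new_states into the slab at the given sequence positions, in place."""
    cache.index_copy_(2, cache_position, new_states)
    return cache

# Prefill: 3 prompt tokens land in slots 0, 1, 2.
static_update(k_cache, torch.ones(1, 8, 3, 64), torch.arange(0, 3))
# Decode: 1 new token lands in slot 3.
static_update(k_cache, torch.full((1, 8, 1, 64), 2.0), torch.tensor([3]))

same_buffer = k_cache.data_ptr() == ptr_before   # the slab never moved
filled = k_cache[0, 0, :, 0].tolist()
# -> [1.0, 1.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
```

Each step touches only the new token's slot, so the work per decode step no longer depends on how long the sequence already is.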
Usage is a simple change:
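A minimal sketch of one way to make the switch, assuming a transformers version whose generate() accepts cache_implementation="static". The helper generate_with_static_cache is hypothetical, and loading the model and tokenizer (for example with AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained) is left to the caller.

```python
def generate_with_static_cache(model, tokenizer, prompt, max_new_tokens=32):
    """Hypothetical helper: the usual generate() call, asking for a static cache."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        cache_implementation="static",   # assumption: supported by the installed version
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

The rest of the generation code stays the same; only the cache selection changes.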