DeepSeek's Multi-Head Latent Attention and Other KV Cache Tricks

Overview:

  1. Introduction: We'll explore how Key-Value (KV) caches make language models like ChatGPT and DeepSeek faster at generating text, by making a clever trade-off between memory usage and computation time.
  2. MLA and other Tricks: We'll then look at 11 recent research papers, including DeepSeek's Multi-head Latent Attention (MLA), that build upon this basic idea to make LLM inference even more time-efficient.

Understanding the Problem: Why Text Generation is Slow

Let's start with a simple analogy. Imagine you're writing a story, and for each new word you write, you need to re-read the entire story so far to maintain consistency. The longer your story gets, the more time you spend re-reading. This is exactly what large language models face during text generation.

The Basic Building Block: Self-Attention

At the heart of modern language models is a mechanism called self-attention. For a sequence of n tokens (think of tokens as roughly corresponding to words), each token needs to "look at" or "attend to" all other tokens to understand the context.

This looking-at-everything process has a computational cost that grows with the sequence length:

  • For n tokens, each token needs to look at all n tokens
  • This means the cost is proportional to n \times n = n^2
  • In mathematical notation, we write this as O(n^2) complexity

The Real Problem: Generating Text One Token at a Time

When a language model generates text, it does so one token at a time, and this is where things get computationally expensive:

  1. First token: Look at 1 token (cost: O(1^2))
  2. Second token: Look at 2 tokens (cost: O(2^2))
  3. Third token: Look at 3 tokens (cost: O(3^2))
  4. And so on until the n-th token: Look at n tokens (cost: O(n^2))

If we add up all these costs for generating a sequence of length n, we get:

O(1^2 + 2^2 + 3^2 + \dots + n^2) \approx O(n^3)

This O(n^3) cost means that as your text gets longer, the generation time grows extremely quickly. For example, generating a sequence twice as long takes roughly eight times as long! Clearly, we need a better approach.


The Solution: Key-Value (KV) Cache

The key insight behind KV caching is that we're doing a lot of redundant work. When generating each new token, we're recomputing things for all previous tokens that we've already processed before. Let's see how we can fix this.

What is a Key-Value Cache?

Think of a KV cache like a smart notepad where we write down important information about each token the first time we see it. For each token, we compute and store two things:

  1. A key (k): Think of this as an addressing mechanism - it helps determine how relevant this token is to future tokens
  2. A value (v): Think of this as the actual information that gets used when this token is found to be relevant

Mathematically, we compute these as:

  • Key: k = x W_K (where x is the token representation and W_K is a learned transformation)
  • Value: v = x W_V (where W_V is another learned transformation)

When generating a new token, we use its query (computed similarly to keys) to find relevant information in our cache by comparing it with all stored keys. The matching values are then used to help generate the token.
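
To make this concrete, here is a minimal sketch of decoding with a KV cache in plain NumPy. The dimensions and the randomly initialized matrices W_Q, W_K, W_V are toy stand-ins for a model's learned parameters, not any particular implementation.

```python
import numpy as np

d_model, d_k = 16, 8                       # toy dimensions for illustration
rng = np.random.default_rng(0)
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

k_cache, v_cache = [], []                  # the KV cache: one (k, v) pair per past token

def attend(x_new):
    """Compute the attention output for a new token, reusing cached keys/values."""
    k_cache.append(x_new @ W_K)            # compute k and v for this token once...
    v_cache.append(x_new @ W_V)            # ...and store them for all future steps
    q = x_new @ W_Q                        # query for the new token
    K, V = np.stack(k_cache), np.stack(v_cache)
    scores = K @ q / np.sqrt(d_k)          # how relevant each cached token is to the query
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax over all tokens seen so far
    return weights @ V                     # weighted mix of the cached values

for _ in range(5):                         # decode a few toy "tokens"
    out = attend(rng.standard_normal(d_model))
```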

How the KV Cache Makes Things Faster

With a KV cache, the process becomes much more efficient:

  1. When we see a new token, we only need to compute its key and value once
  2. For all future tokens, we can just look up these pre-computed values from our cache
  3. This means each new token only needs to do a small amount of new work, instead of redoing all previous computations

The trade-off is clear:

  • We use more memory to store all the keys and values. For a model with:
    • L layers
    • H attention heads
    • Sequence length n
    • Key/value dimension d_k
    the total memory cost is L \times H \times n \times d_k \times 2 values (the factor of 2 accounts for both keys and values). This grows linearly with sequence length (O(n)), but the constant factors can be substantial for large models.
  • But in return, we reduce the computation cost from O(n^3) to O(n^2)

To understand why it's O(n^2), let's look at the cost at each step:

  1. Step 1: Process 1 token → cost O(1)
  2. Step 2: Process 1 new token + look at 1 cached token → cost O(2)
  3. Step 3: Process 1 new token + look at 2 cached tokens → cost O(3)
  4. And so on...

Adding these up:

O(1 + 2 + 3 + \dots + n) = O(n^2)

This is a dramatic improvement over O(n^3)! While we still have to do the fundamental work of looking at all previous tokens (O(n^2)), we avoid the costly recomputation at each step.


The Memory Challenge: Why We Need Better Solutions

While KV cache is a powerful optimization, it comes with a significant memory cost. Let's look at a concrete example using a modern large language model like Llama3 70B with:

  • L = 80 layers
  • H = 64 attention heads
  • Batch size B = 8
  • Key/value dimension d_k = 128
  • 16-bit precision

The memory required for a batch of 8 sequences of 1000 tokens each would be:

L \times H \times B \times n \times d_k \times 2 \times 2 \text{ bytes} = 80 \times 64 \times 8 \times 1000 \times 128 \times 2 \times 2 \text{ bytes} = 20.97 \text{ GB}
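
As a quick sanity check, here is the same arithmetic in a few lines of Python, using the example dimensions above:

```python
# KV cache size for the example configuration above
L, H, B, n, d_k = 80, 64, 8, 1000, 128   # layers, heads, batch size, tokens, head dimension
bytes_per_value = 2                       # 16-bit precision
kv = 2                                    # one key and one value per token per head

total_bytes = L * H * B * n * d_k * kv * bytes_per_value
print(f"{total_bytes / 1e9:.2f} GB")      # -> 20.97 GB
```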

This substantial memory usage creates several challenges:

  1. Scales linearly with sequence length
  2. Multiplies with batch size for parallel processing
  3. Limits the maximum context length we can handle
  4. Constrains deployment on memory-limited devices

These challenges have sparked a wave of innovation in the research community, leading to various techniques for optimizing KV cache usage. Let's explore these cutting-edge solutions.

Can we improve over naive KV caches?

The following papers represent key innovations in KV cache optimization. We'll explore them through three main approaches: token selection, post-hoc compression techniques, and architectural redesigns.

Token Selection and Pruning Approaches

1) Heavy-Hitter Oracle (H2O)

H2O introduces the concept of identifying and preserving important tokens in the KV cache:

  • Heavy-Hitter Tokens: H2O identifies the tokens with the highest accumulated attention scores during generation; these scores are observed to follow a power-law distribution, so a small set of "heavy hitter" tokens accounts for most of the attention. These tokens are critical for model functionality and are prioritized in the cache.
  • Dynamic Submodular Eviction: The method frames cache management as an optimization problem with a submodular objective function F(S) that quantifies the importance of a token set S: F(S) = \sum_{i \in S} A_i, where A_i is the accumulated attention score for token i. The cache S_t is updated by S_t = \text{argmax}_{S \subseteq S_{t-1} \cup \{i\}, |S| \leq k} F(S), ensuring that at most one token is evicted per step (see the sketch after this list). This greedy algorithm is computationally efficient and guarantees near-optimal performance under submodular constraints.
  • Results: Achieves 5× reduction in KV cache size with negligible accuracy loss and up to 29× throughput improvement.
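
A minimal sketch of the heavy-hitter eviction idea (the cache budget and bookkeeping here are illustrative assumptions; the actual H2O implementation also keeps a window of recent tokens):

```python
import numpy as np

budget = 64                                   # assumed maximum number of cached tokens
acc_scores = np.zeros(0)                      # accumulated attention received by each cached token
kept = []                                     # token indices currently in the cache

def update_cache(attn_weights, new_token_idx):
    """attn_weights: attention from the new token to every cached token plus itself."""
    global acc_scores
    acc_scores += attn_weights[:-1]           # add the attention each cached token just received
    acc_scores = np.append(acc_scores, attn_weights[-1])
    kept.append(new_token_idx)
    if len(kept) > budget:                    # evict the token with the lowest accumulated score
        evict = int(np.argmin(acc_scores))
        acc_scores = np.delete(acc_scores, evict)
        kept.pop(evict)
```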

2) StreamingLLM

  • The authors observe the phenomenon of Attention Sinks: Initial tokens that act as natural attention anchors during decoding
    • Without these attention sink tokens, the performance of naive window attention drops
  • Based on that observation, they introduce a Rolling Cache that keeps the recent context together with the retained initial tokens, enabling infinite-length sequence processing (see the sketch after this list).
  • They show that a sink token can also be learned during pre-training, serving as a dedicated attention anchor and reducing the reliance on multiple initial tokens.
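
A rough sketch of the resulting cache policy, assuming we keep n_sink initial tokens plus a sliding window of the most recent ones (the parameter values are illustrative):

```python
def evict(cache, n_sink=4, window=1020):
    """Keep the first n_sink entries (attention sinks) plus the most recent `window` entries.

    `cache` is a list of (key, value) pairs, one per token seen so far. In the actual
    method, positions are assigned relative to the cache rather than to the original text.
    """
    if len(cache) <= n_sink + window:
        return cache
    return cache[:n_sink] + cache[-window:]
```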

3) Value-Aware Token Pruning (VATP)

VATP extends H2O's token importance concept by considering both attention patterns and value vector properties:

  • Importance Scoring: Combines attention scores with value vector information: I_k^t = S_k^t \cdot \|v_k\|_1, where S_k^t = \sum_{k \leq j \leq t} a_{j,k} is the accumulated attention score and \|v_k\|_1 is the L1 norm of the value vector (see the sketch after this list).
  • Token Pruning: Tokens are ranked by I_k^t, and those with the lowest scores are pruned, while attention sink tokens (e.g., start or newline tokens) are preserved to prevent performance degradation.
  • Performance and Efficiency:
    • Outperforms baselines like H2O and Scissorhands in 12–14 out of 16 LongBench tasks.
    • Achieves effective 50% compression with minimal performance loss.
    • Introduces negligible computational overhead and is compatible with FlashAttention when integrated with Scissorhands.
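
A minimal sketch of this scoring rule (illustrative variable names; the accumulated attention scores are assumed to be tracked as in H2O):

```python
import numpy as np

def vatp_prune(keys, values, acc_attention, keep_ratio=0.5, n_sink=1):
    """keys, values: (n, d) cached tensors; acc_attention: (n,) accumulated scores S_k^t."""
    scores = acc_attention * np.abs(values).sum(axis=1)   # I_k^t = S_k^t * ||v_k||_1
    scores[:n_sink] = np.inf                               # never prune attention-sink tokens
    m = max(n_sink, int(keep_ratio * len(scores)))
    keep = np.sort(np.argsort(-scores)[:m])                # keep the m highest-scoring tokens
    return keys[keep], values[keep]
```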

Post-hoc Compression Techniques

These methods compress or optimize the KV cache while preserving the standard transformer architecture.

4) Adaptive KV Compression (FastGen)

FastGen introduces adaptive compression based on attention patterns observed at run-time:

  • Attention Profiling: During prompt encoding, FastGen identifies attention patterns and selects the compression policy C^* that minimizes memory cost while preserving attention recovery: C^* = \arg\min_{C \in \mathcal{C}} \text{CacheMemoryCost}(C) \quad \text{s.t.} \quad \|A - \text{softmax}(Q K_C^T)\| \leq 1 - T.
  • Adaptive Compression Policies:
    • Compression strategies include:
      • Special Tokens (C_{\text{special}}): Retain only special tokens.
      • Locality (C_{\text{local}}): Evict tokens beyond a relative distance r_l.
      • Frequency (C_{\text{frequent}}): Keep tokens with high cumulative attention scores (r_f).
      • Hybrid policies combine strategies, starting with C_{\text{special}}, and apply them adaptively to each head: \mathcal{C} = \{C_{\text{special}}, C_{\text{special}} + C_{\text{punct}}, \ldots, C_{\text{full}}\}.
  • Token Generation: During decoding, the pre-selected compression policies manage the KV cache efficiently: K_{C_i}, V_{C_i} = f(K, V, C_i). A simplified sketch of the per-head policy selection follows this list.
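
Here is a loose sketch of that per-head selection step; the candidate policy set, the recovery metric, and all thresholds below are simplifying assumptions rather than FastGen's exact definitions:

```python
import numpy as np

def recovered_mass(A, keep_mask):
    """Fraction of the head's attention mass that lands on retained key positions."""
    return A[:, keep_mask].sum() / A.sum()

def choose_policy(A, special_mask, T=0.95, r_local=256, r_freq=0.3):
    """A: (n, n) prompt attention matrix for one head; special_mask: (n,) boolean mask."""
    n = A.shape[0]
    local_mask = np.arange(n) >= n - r_local            # keep only the most recent tokens
    freq_mask = np.zeros(n, dtype=bool)
    freq_mask[np.argsort(-A.sum(axis=0))[: int(r_freq * n)]] = True  # top tokens by attention
    # Candidate policies, roughly ordered from cheapest to most expensive.
    candidates = [
        ("special", special_mask),
        ("special+local", special_mask | local_mask),
        ("special+local+frequent", special_mask | local_mask | freq_mask),
        ("full", np.ones(n, dtype=bool)),
    ]
    for name, mask in candidates:
        if recovered_mass(A, mask) >= T:                # cheapest policy that recovers enough
            return name, mask
    return candidates[-1]
```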

5) Dynamic Memory Compression (DMC)

DMC introduces adaptive token merging:

  • Decision Mechanism: At time t, the model predicts a merge decision \alpha_t and a weight \omega_t: \alpha_t = \lfloor \text{sigmoid}(k_t[0]) \rceil, \quad \omega_t = \text{sigmoid}(q_t[0]).
  • Weighted Merging: When \alpha_t = 1, the current entry is merged with the previous one (see the sketch after this list): k' = \frac{\omega_t k_t + z_{t-1} k_{t-1}}{\omega_t + z_{t-1}}, \quad v' = \frac{\omega_t v_t + z_{t-1} v_{t-1}}{\omega_t + z_{t-1}}, where z_t = z_{t-1} + \omega_t accumulates importance weights.
  • Training:
    • Uses a Gumbel-Sigmoid relaxation for \alpha_t to allow end-to-end training with gradient descent: \alpha_t \sim \text{Gumbel-Sigmoid}(k_t[0], \tau), where \tau is a temperature parameter.
    • Optimizes a combined objective: \mathcal{L} = \mathcal{L}_{\text{LM}} + \lambda \max\left(0, \frac{n}{\text{CR}} - \sum_{t} \alpha_t \right), where \mathcal{L}_{\text{LM}} is the language modeling loss and the second term encourages the model to match a target compression ratio (CR).
  • Results: Up to 8× compression with maintained performance.
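
A toy sketch of the inference-time merge step (the decision \alpha_t and weight \omega_t are assumed to come from the model's own key/query features, as described above):

```python
def dmc_step(cache, k_t, v_t, alpha_t, omega_t):
    """cache: list of [k, v, z] entries; alpha_t in {0, 1}; omega_t in (0, 1)."""
    if alpha_t == 1 and cache:
        k_prev, v_prev, z_prev = cache[-1]
        denom = omega_t + z_prev
        cache[-1] = [
            (omega_t * k_t + z_prev * k_prev) / denom,   # weighted average of keys
            (omega_t * v_t + z_prev * v_prev) / denom,   # weighted average of values
            z_prev + omega_t,                            # z_t accumulates the importance weights
        ]
    else:
        cache.append([k_t, v_t, omega_t])                # open a fresh cache entry
    return cache
```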

6) L_2 Norm-Based Compression

This paper presents a surprising observation: a clear correlation between the L_2 norm and the attention scores over cached KV pairs, where a low L_2 norm of a key embedding usually leads to a high attention score during decoding. Consequently, they introduce a simple but effective compression objective:

  • Norm-Based Selection: For a set of cached keys K = \{k_1, k_2, \dots, k_n\}, compute each key's norm: \|k_i\|_2 = \sqrt{\sum_{j=1}^d k_{i,j}^2}
  • Sorting and Selection: To compress the KV cache, sort all keys by their L_2 norm, K_{\text{sorted}} = \text{Sort}\big(\{\|k_1\|_2, \|k_2\|_2, \dots, \|k_n\|_2\}\big), and retain the top-m keys with the lowest norms, where m = \lfloor c \cdot n \rfloor and c is the compression ratio (see the sketch below).
  • Compressed Cache: The compressed key-value cache consists of: K_{\text{compressed}} = \{k_i \mid \|k_i\|_2 \in K_{\text{sorted}}[1:m]\}, \quad V_{\text{compressed}} = \{v_i \mid k_i \in K_{\text{compressed}}\}
  • Due to its simplicity, this approach maintains compatibility with FlashAttention.
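
Because the rule depends only on key norms, the whole compression step fits in a few lines (a sketch; the compression ratio c is an assumed parameter):

```python
import numpy as np

def l2_compress(keys, values, c=0.5):
    """Keep the fraction c of cached entries whose keys have the lowest L2 norm."""
    norms = np.linalg.norm(keys, axis=-1)      # ||k_i||_2 for each cached key
    m = int(c * len(norms))
    keep = np.sort(np.argsort(norms)[:m])      # indices of the m lowest-norm keys, kept in order
    return keys[keep], values[keep]
```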

Architectural Redesigns

These approaches change the Transformer architecture to handle KV caches more efficiently, often incorporating compression directly into the architecture.

7) Multi-Query Attention (MQA)

  • Key Idea: MQA reduces the KV cache size by sharing a single key-value head across all query heads, replacing the traditional Multi-Head Attention (MHA): K = X W_K, \quad V = X W_V, where K and V are the shared key and value projections (see the sketch after this list).
  • Benefits: Reduces the KV cache size by a factor of H (the number of attention heads), significantly lowering memory bandwidth overhead.
  • Trade-Off: While MQA is faster, it often suffers from quality degradation, especially in tasks requiring diverse attention patterns.
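
A shape-level sketch of the difference, using toy dimensions and random stand-ins for the learned projection matrices:

```python
import numpy as np

d_model, H, d_k, n = 512, 8, 64, 128
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d_model))              # token representations

# Multi-Head Attention: one K (and V) projection per head -> cache grows as H * n * d_k.
W_K_mha = rng.standard_normal((H, d_model, d_k))
K_mha = np.einsum("nd,hdk->hnk", X, W_K_mha)       # shape (H, n, d_k)

# Multi-Query Attention: a single shared K (and V) head -> cache grows as n * d_k.
W_K_mqa = rng.standard_normal((d_model, d_k))
K_mqa = X @ W_K_mqa                                # shape (n, d_k), H times smaller to cache
```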

8) Group-Query Attention (GQA)

  • Key Idea: GQA interpolates between full multi-head attention and MQA, offering a scalable trade-off between inference speed and model quality. It divides the query heads into G groups, where each group shares a single key-value head: K_{\text{group}} = \frac{1}{|G|} \sum_{h \in G} K_h, \quad V_{\text{group}} = \frac{1}{|G|} \sum_{h \in G} V_h
    • GQA-1: Equivalent to MQA (G = 1).
    • GQA-H: Equivalent to MHA (G = H).
  • Uptraining: GQA can be introduced to existing pre-trained models through fine-tuning:
    • First, convert MHA checkpoints to GQA by mean-pooling key and value heads into groups (see the sketch after this list)
    • Then fine-tune ("uptrain") the model briefly to adapt to the new attention pattern
    • This adaptation process requires only 5% of the original pre-training compute, making it very efficient
    • The resulting model maintains quality while gaining GQA's memory benefits
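
A minimal sketch of that mean-pooling step (in practice the projection weights are pooled, which is equivalent for the cached keys/values shown here; shapes are illustrative):

```python
import numpy as np

def mha_to_gqa(K_heads, V_heads, G):
    """Mean-pool per-head keys/values into G grouped heads.

    K_heads, V_heads: (H, n, d_k) cached tensors for one layer.
    Returns tensors of shape (G, n, d_k); G=1 recovers MQA, G=H recovers MHA."""
    H = K_heads.shape[0]
    assert H % G == 0, "number of heads must be divisible by the number of groups"
    K_groups = K_heads.reshape(G, H // G, *K_heads.shape[1:]).mean(axis=1)
    V_groups = V_heads.reshape(G, H // G, *V_heads.shape[1:]).mean(axis=1)
    return K_groups, V_groups
```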

9) Multi-head Latent Attention (MLA)

DeepSeek's Multi-Head Latent Attention (MLA) takes a novel approach to reducing KV cache overhead. While MQA and GQA achieve this through head-sharing, MLA instead employs a low-rank latent compression technique that maintains the benefits of multiple attention heads.

  • MLA reduces KV cache size by compressing keys and values into low-dimensional latent vectors before reconstruction.
  • It down-projects the key-value representations into a compressed latent space: c_{\text{KV}, t} = W_{\text{DKV}} h_t, \quad k_C = W_{\text{UK}} c_{\text{KV}, t}, \quad v_C = W_{\text{UV}} c_{\text{KV}, t}, where W_{\text{DKV}} is the down-projection matrix and W_{\text{UK}}, W_{\text{UV}} are up-projection matrices for keys and values (see the sketch after this list).
  • It retains per-head flexibility through compressed representations, unlike MQA's complete head sharing.
  • It introduces a decoupled Rotary Positional Embedding (RoPE) for position-aware keys: k_R = \text{RoPE}(W_{\text{KR}} h_t), \quad k_t = [k_C; k_R]. This reduces KV cache storage further, since only the compressed latent vectors c_{\text{KV}} and the positional keys k_R are cached.
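
A dimensional sketch of the low-rank idea (toy sizes; the decoupled RoPE keys are omitted, and the random matrices stand in for learned weights):

```python
import numpy as np

d_model, d_latent, H, d_k = 1024, 64, 16, 64
rng = np.random.default_rng(0)
W_DKV = rng.standard_normal((d_latent, d_model))     # shared down-projection
W_UK = rng.standard_normal((H, d_k, d_latent))       # per-head key up-projection
W_UV = rng.standard_normal((H, d_k, d_latent))       # per-head value up-projection

h_t = rng.standard_normal(d_model)                   # hidden state for token t
c_kv = W_DKV @ h_t                                   # (d_latent,) - the only per-token cache entry
k_heads = np.einsum("hkl,l->hk", W_UK, c_kv)         # (H, d_k) keys reconstructed on the fly
v_heads = np.einsum("hkl,l->hk", W_UV, c_kv)         # (H, d_k) values reconstructed on the fly

print(d_latent, "vs", 2 * H * d_k)                   # cached numbers per token: 64 vs 2048
```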

10) SnapKV

  • SnapKV introduces an Observation Window: It uses the tokens at the end of the prompt to identify attention patterns: C = \sum_{i=0}^{L_{\text{obs}}} W_{\text{obs}}[:, i, :], \quad I = \text{Top}_k(C, k), where W_{\text{obs}} represents the attention weights of the observation window and k is determined by the compression rate (see the sketch after this list).
  • Compression: Clusters features around the selected positions using a pooling layer to preserve context completeness.
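
A rough sketch of the selection step for a single head (the observation-window length, budget, and pooling width are illustrative assumptions):

```python
import numpy as np

def snapkv_select(attn, L_obs=32, keep=256, pool=7):
    """attn: (n_query, n_key) prompt attention for one head.

    Vote with the last L_obs query rows, smooth the votes with a small pooling
    window to keep neighbouring tokens together, then retain the top positions."""
    votes = attn[-L_obs:].sum(axis=0)                 # attention each prompt token receives
    pad = pool // 2
    padded = np.pad(votes, pad, mode="edge")
    smoothed = np.array([padded[i:i + pool].max() for i in range(len(votes))])
    top = np.argsort(-smoothed)[:keep]
    window = np.arange(len(votes) - L_obs, len(votes))
    return np.union1d(top, window)                    # keep selected tokens plus the window itself
```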

11) You Only Cache Once (YOCO)

YOCO modifies the transformer architecture for caching:

  • Global Cache: Uses a decoder-decoder design with a single shared KV cache.
  • Complexity Reduction: Reduces memory from O(N \times L) to O(N + L), where N is the sequence length and L is the number of layers.
  • Efficient Attention: The self-decoder employs sliding-window attention or gated retention, enabling constant memory usage (O(C), where C is a small window size).

Conclusion

Key-Value caching techniques are central to scaling and optimizing Transformer-based models for real-world use. Innovations like dynamic eviction, compression, and structured approximations continue to push the boundaries of what is possible in long-context or resource-constrained scenarios. KV caching remains a lively research area, offering both theoretical insights and practical improvements.

PS: This blog post is mostly AI-generated using a PySpur workflow with minor human edits.

Support us by starring our repository.