Overview:
- Introduction: We'll explore how Key-Value (KV) caches make language models like ChatGPT and DeepSeek faster at generating text, by making a clever trade-off between memory usage and computation time.
- MLA and other Tricks: We'll then look at 11 recent research papers, including DeepSeek's Multi-head Latent Attention (MLA), that build upon this basic idea to make LLM inference even more time-efficient.
Understanding the Problem: Why Text Generation is Slow
Let's start with a simple analogy. Imagine you're writing a story, and for each new word you write, you need to re-read the entire story so far to maintain consistency. The longer your story gets, the more time you spend re-reading. This is exactly what large language models face during text generation.
The Basic Building Block: Self-Attention
At the heart of modern language models is a mechanism called self-attention. For a sequence of tokens (think of tokens as roughly corresponding to words), each token needs to "look at" or "attend to" all other tokens to understand the context.
This looking-at-everything process has a computational cost that grows with the sequence length:
- For $n$ tokens, each token needs to look at all $n$ tokens
- This means the cost is proportional to $n \times n = n^2$
- In mathematical notation, we write this as $O(n^2)$ complexity
The Real Problem: Generating Text One Token at a Time
When a language model generates text, it does so one token at a time, and this is where things get computationally expensive:
- First token: Process 1 token, which looks at 1 token (cost: $1^2 = 1$)
- Second token: Re-process all 2 tokens, each looking at 2 tokens (cost: $2^2 = 4$)
- Third token: Re-process all 3 tokens, each looking at 3 tokens (cost: $3^2 = 9$)
- And so on until the $n$-th token: Re-process all $n$ tokens, each looking at $n$ tokens (cost: $n^2$)
If we add up all these costs for generating a sequence of length $n$, we get:
$$1^2 + 2^2 + 3^2 + \dots + n^2 = \frac{n(n+1)(2n+1)}{6} = O(n^3)$$
This $O(n^3)$ cost means that as your text gets longer, the generation time grows extremely quickly. For example, generating a sequence twice as long takes roughly eight times as long! Clearly, we need a better approach.
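To make this concrete, here is a minimal NumPy sketch (single attention head, no masking or output projection, all names mine) of a decoder that recomputes self-attention from scratch at every step; the work at step $t$ grows like $t^2$, which is what drives the cubic total above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def naive_decode_step(X, Wq, Wk, Wv):
    """Recompute single-head self-attention over all t tokens generated so far.
    X has shape (t, d); the (t, t) score matrix makes this one step cost ~t^2,
    and summing over t = 1..n steps gives the O(n^3) total discussed above."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv            # recomputed from scratch every step
    scores = Q @ K.T / np.sqrt(K.shape[-1])      # (t, t) attention scores
    return softmax(scores) @ V                   # only the last row is actually new

# toy decoding loop: every iteration redoes all previous work
d = 16
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
X = np.empty((0, d))
for t in range(5):
    X = np.vstack([X, rng.standard_normal((1, d))])  # pretend this is the next token's embedding
    out = naive_decode_step(X, Wq, Wk, Wv)           # cost grows like t^2 per step
```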
The Solution: Key-Value (KV) Cache
The key insight behind KV caching is that we're doing a lot of redundant work. When generating each new token, we're recomputing things for all previous tokens that we've already processed before. Let's see how we can fix this.
What is a Key-Value Cache?
Think of a KV cache like a smart notepad where we write down important information about each token the first time we see it. For each token, we compute and store two things:
- A key ($k$): Think of this as an addressing mechanism - it helps determine how relevant this token is to future tokens
- A value ($v$): Think of this as the actual information that gets used when this token is found to be relevant
Mathematically, we compute these as:
- Key: $k_i = W_K x_i$ (where $x_i$ is the token's embedding and $W_K$ is a learned transformation)
- Value: $v_i = W_V x_i$ (where $W_V$ is another learned transformation)
When generating a new token, we use its query (computed similarly to keys) to find relevant information in our cache by comparing it with all stored keys. The matching values are then used to help generate the token.
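Here is a minimal single-head NumPy sketch of this write-once, read-many pattern. The `KVCache` class and function names are illustrative, and batching, masking, and multiple heads are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class KVCache:
    """Stores one key and one value vector per token, computed exactly once."""
    def __init__(self, d_k):
        self.keys = np.empty((0, d_k))
        self.values = np.empty((0, d_k))

    def append(self, k, v):
        self.keys = np.vstack([self.keys, k])
        self.values = np.vstack([self.values, v])

def cached_decode_step(x_new, cache, Wq, Wk, Wv):
    """Process one new token: project it once, store its k/v, attend over the cache."""
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv   # only the new token is projected
    cache.append(k, v)                              # write k, v to the "notepad" once
    scores = q @ cache.keys.T / np.sqrt(len(q))     # compare the query with all cached keys
    return softmax(scores) @ cache.values           # weighted sum of the cached values

# same toy loop as before, now doing O(t) work per step instead of O(t^2)
d = 16
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
cache = KVCache(d)
for t in range(5):
    out = cached_decode_step(rng.standard_normal(d), cache, Wq, Wk, Wv)
```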
How the KV Cache Makes Things Faster
With a KV cache, the process becomes much more efficient:
- When we see a new token, we only need to compute its key and value once
- For all future tokens, we can just look up these pre-computed values from our cache
- This means each new token only needs to do a small amount of new work, instead of redoing all previous computations
The trade-off is clear:
- We use more memory to store all the keys and values. For a model with:
- $L$ layers
- $h$ attention heads
- Sequence length $n$
- Key/value dimension $d_k$ per head
The total memory cost is $2 \times L \times h \times n \times d_k$ values (the factor of 2 accounts for both keys and values). This grows linearly with sequence length ($O(n)$), but the constant factors can be substantial for large models.
- But in return, we reduce the computation cost from $O(n^3)$ to $O(n^2)$
To understand why it's $O(n^2)$, let's look at the cost at each step:
- Step 1: Process 1 token → cost $1$
- Step 2: Process 1 new token + look at 1 cached token → cost $2$
- Step 3: Process 1 new token + look at 2 cached tokens → cost $3$
- And so on...
Adding these up:
$$1 + 2 + 3 + \dots + n = \frac{n(n+1)}{2} = O(n^2)$$
This is a dramatic improvement over $O(n^3)$! While we still have to do the fundamental work of looking at all previous tokens ($O(n)$ per new token), we avoid the costly recomputation at each step.
The Memory Challenge: Why We Need Better Solutions
While KV cache is a powerful optimization, it comes with a significant memory cost. Let's look at a concrete example using a modern large language model like Llama3 70B with:
- $L = 80$ layers
- $h = 64$ attention heads
- batch size of 8 sequences
- key/value dimension $d_k = 128$ per head
- 16-bit precision (2 bytes per value)
The memory required for a batch of 8 sequences of 1000 tokens each would be:
$$2 \times 80 \times 64 \times 128 \times 1000 \times 8 \times 2 \text{ bytes} \approx 21 \text{ GB}$$
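A quick sanity check of this arithmetic in Python. The layer/head/dimension values are the ones assumed above; note that the real Llama3 70B shares key/value heads via grouped-query attention (discussed later), which would shrink this number further.

```python
# Back-of-the-envelope KV cache size for the configuration above
layers, heads, head_dim = 80, 64, 128          # assumed Llama3-70B-style dimensions
batch, seq_len, bytes_per_value = 8, 1000, 2   # 16-bit precision -> 2 bytes per value

kv_bytes = 2 * layers * heads * head_dim * seq_len * batch * bytes_per_value
print(f"KV cache: {kv_bytes / 1e9:.1f} GB")    # ~21 GB, before any optimization
```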
This substantial memory usage creates several challenges:
- Scales linearly with sequence length
- Multiplies with batch size for parallel processing
- Limits the maximum context length we can handle
- Constrains deployment on memory-limited devices
These challenges have sparked a wave of innovation in the research community, leading to various techniques for optimizing KV cache usage. Let's explore these cutting-edge solutions.
Can we improve over naive KV caches?
The following papers represent key innovations in KV cache optimization. We'll explore them through three main approaches: token selection, post-hoc compression techniques, and architectural redesigns.
Token Selection and Pruning Approaches
1) Heavy-Hitter Oracle (H2O)
H2O introduces the concept of identifying and preserving important tokens in the KV cache:
- Heavy-Hitter Tokens: H2O identifies tokens with the highest accumulated attention scores during generation; these scores follow a power-law distribution, so a small set of tokens accounts for most of the attention mass. These tokens are critical for model functionality and are prioritized in the cache.
- Dynamic Submodular Eviction: The method frames cache management as an optimization problem with a submodular objective function that quantifies the importance of a token set $S$: $F(S) = \sum_{i \in S} A_i$, where $A_i$ is the accumulated attention score for token $i$. The cache is updated by $S_t = \arg\max_{S \subseteq S_{t-1} \cup \{t\},\ |S| \le k} F(S)$, ensuring that at most one token is evicted per step. This greedy algorithm is computationally efficient and guarantees near-optimal performance under submodular constraints.
- Results: Achieves 5× reduction in KV cache size with negligible accuracy loss and up to 29× throughput improvement.
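As a rough illustration of the heavy-hitter idea (a sketch, not the authors' implementation), the following keeps a window of recent tokens plus the tokens with the largest accumulated attention scores and evicts everything else:

```python
import numpy as np

def h2o_evict(keys, values, acc_scores, budget=256, recent=32):
    """Keep the `recent` most recent tokens plus the heaviest hitters by
    accumulated attention score, up to `budget` tokens in total. A simplified,
    per-head-agnostic version of H2O-style greedy eviction."""
    n = len(acc_scores)
    keep = set(range(max(0, n - recent), n))        # always keep the local window
    for idx in np.argsort(acc_scores)[::-1]:        # heaviest hitters first
        if len(keep) >= budget:
            break
        keep.add(int(idx))
    keep = sorted(keep)
    return keys[keep], values[keep], acc_scores[keep]
```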
2) StreamingLLM
- The authors observe the phenomenon of Attention Sinks: Initial tokens that act as natural attention anchors during decoding
- Without these attention sink tokens, the performance of naive window attention drops
- Based on that observation, they introduce a Rolling Cache for recent context with retained initial tokens, enabling infinite-length sequence processing.
- They show that these sink tokens can also be trained, serving as dedicated attention anchors and reducing the reliance on multiple initial tokens.
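A minimal sketch of the resulting cache layout, assuming a fixed number of sink tokens and a recency window (the parameter names are mine):

```python
import numpy as np

def streaming_window(keys, values, n_sink=4, window=1024):
    """Keep the first n_sink "attention sink" tokens plus the most recent
    `window` tokens; everything in between is dropped. A sketch of the cache
    layout described above, not the authors' implementation."""
    n = len(keys)
    if n <= n_sink + window:
        return keys, values
    keep = list(range(n_sink)) + list(range(n - window, n))
    return keys[keep], values[keep]
```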
3) Value-Aware Token Pruning (VATP)
VATP extends H2O's token importance concept by considering both attention patterns and value vector properties:
- Importance Scoring: Combines attention scores with value vector information: $I_i = A_i \cdot \|v_i\|_1$, where $A_i$ is the accumulated attention score and $\|v_i\|_1$ is the value vector's L1 norm.
- Token Pruning: Tokens are ranked by $I_i$, and those with the lowest scores are pruned, while attention sink tokens (e.g., start or newline tokens) are preserved to prevent performance degradation.
- Performance and Efficiency:
- Outperforms baselines like H2O and Scissorhands in 12–14 out of 16 LongBench tasks.
- Achieves effective 50% compression with minimal performance loss.
- Introduces negligible computational overhead and is compatible with FlashAttention when integrated with Scissorhands.
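A small sketch of this scoring rule, combining accumulated attention with the L1 norms of the value vectors (per-head handling and FlashAttention integration omitted):

```python
import numpy as np

def vatp_prune(keys, values, acc_attention, keep_ratio=0.5, n_protected=1):
    """VATP-style pruning sketch: importance = accumulated attention score times
    the L1 norm of each token's value vector. Attention-sink tokens at the start
    are protected from eviction."""
    scores = acc_attention * np.abs(values).sum(axis=-1)
    scores[:n_protected] = np.inf                       # never evict sink tokens (e.g. BOS)
    k = max(n_protected, int(len(scores) * keep_ratio))
    keep = np.sort(np.argsort(scores)[::-1][:k])        # top-k by score, kept in original order
    return keys[keep], values[keep]
```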
Post-hoc Compression Techniques
These methods compress or optimize the KV cache while preserving the standard transformer architecture.
4) Adaptive KV Compression (FastGen)
FastGen introduces adaptive compression based on attention patterns observed at run-time:
- Attention Profiling: During prompt encoding, FastGen identifies each attention head's dominant attention pattern and selects the compression policy with the lowest memory cost that still recovers that head's attention output within a tolerance.
- Adaptive Compression Policies: Compression strategies include:
- Special Tokens ($C_{\text{special}}$): Retain only special tokens.
- Locality ($C_{\text{local}}$): Evict tokens beyond a relative distance $r$.
- Frequency ($C_{\text{frequent}}$): Keep tokens with high cumulative attention scores (heavy hitters).
- Hybrid Policies: Combine the above strategies, starting with $C_{\text{special}}$, and apply them adaptively to each head.
- Token Generation: During decoding, the pre-selected compression policies manage each head's KV cache without further profiling overhead.
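The per-head policies can be pictured with a sketch like the following; the policy names and thresholds are illustrative rather than FastGen's exact implementation:

```python
import numpy as np

def apply_policy(policy, keys, values, is_special, acc_attention, r=128, top_k=256):
    """Apply one compression policy to a single head's cache (illustrative)."""
    n = len(keys)
    if policy == "special":                      # keep only special tokens (BOS, separators, ...)
        keep = np.where(is_special)[0]
    elif policy == "local":                      # keep only tokens within distance r of the end
        keep = np.arange(max(0, n - r), n)
    elif policy == "frequent":                   # keep heavy hitters by cumulative attention
        keep = np.argsort(acc_attention)[::-1][:top_k]
    else:                                        # "full": no compression for this head
        keep = np.arange(n)
    keep = np.sort(keep)
    return keys[keep], values[keep]
```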
5) Dynamic Memory Compression (DMC)
DMC introduces adaptive token merging:
- Decision Mechanism: At time $t$, the model predicts a binary merge decision $\alpha_t \in \{0, 1\}$ and an importance weight $\omega_t \in (0, 1)$ for the incoming key-value pair.
- Weighted Merging: When $\alpha_t = 1$, the current entry is merged into the previous one as a weighted average, e.g. $k'_t = \frac{z_{t-1} k'_{t-1} + \omega_t k_t}{z_{t-1} + \omega_t}$ (and analogously for values), where $z_t = z_{t-1} + \omega_t$ accumulates importance weights; when $\alpha_t = 0$, a new cache entry is appended.
- Training:
- Uses a Gumbel-Sigmoid relaxation of the discrete decision $\alpha_t$ to allow end-to-end training with gradient descent, where a temperature parameter $\tau$ controls how closely the relaxation approximates a hard 0/1 decision.
- Optimizes a combined objective $\mathcal{L} = \mathcal{L}_{\text{LM}} + \lambda \mathcal{L}_{\text{CR}}$, where $\mathcal{L}_{\text{LM}}$ is the language modeling loss, and the second term encourages the model to match a target compression ratio (CR).
- Results: Up to 8× compression with maintained performance.
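A sketch of the merge step under these assumptions (plain Python lists stand in for per-head cache tensors, and $\alpha_t$, $\omega_t$ are assumed to be already predicted by the model):

```python
import numpy as np

def dmc_step(cache_k, cache_v, cache_z, k_t, v_t, alpha_t, omega_t):
    """One DMC-style cache update (sketch). If alpha_t == 1, the new key/value is
    merged into the last cache slot as a weighted average; otherwise a fresh slot
    is appended. cache_z holds the accumulated importance weights z per slot."""
    if alpha_t == 1 and cache_k:
        z = cache_z[-1]
        cache_k[-1] = (z * cache_k[-1] + omega_t * k_t) / (z + omega_t)
        cache_v[-1] = (z * cache_v[-1] + omega_t * v_t) / (z + omega_t)
        cache_z[-1] = z + omega_t
    else:
        cache_k.append(k_t)
        cache_v.append(v_t)
        cache_z.append(omega_t)
    return cache_k, cache_v, cache_z

# usage: 6 tokens processed, but the cache ends up with only 3 slots
ks, vs, zs = [], [], []
rng = np.random.default_rng(0)
for t in range(6):
    k_t, v_t = rng.standard_normal(16), rng.standard_normal(16)
    ks, vs, zs = dmc_step(ks, vs, zs, k_t, v_t, alpha_t=t % 2, omega_t=0.5)
```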
6) Norm-Based Compression
This paper presents a surprising observation: there is a clear correlation between the $L_2$ norm of a cached key and its attention scores, where a low $L_2$ norm of a key embedding usually leads to a high attention score during decoding. Consequently, they introduce a simple but effective compression objective:
- Norm-Based Selection: For a set of cached keys $\{k_1, \dots, k_n\}$, compute each key's norm $\|k_i\|_2$.
- Sorting and Selection: To compress the KV cache, sort all keys by their $L_2$ norm and retain the top-$m$ keys with the lowest norms, where $m = \lceil r \cdot n \rceil$ and $r$ is the compression ratio.
- Compressed Cache: The compressed key-value cache consists of the retained low-norm keys and their corresponding values.
- Due to its simplicity, this approach maintains compatibility with FlashAttention.
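Because the rule only needs key norms, it fits in a few lines of NumPy (a sketch; per-head and per-layer handling omitted):

```python
import numpy as np

def l2_compress(keys, values, compression_ratio=0.5):
    """Keep the keys with the lowest L2 norms (and their values), following the
    observation that low-norm keys tend to receive high attention."""
    norms = np.linalg.norm(keys, axis=-1)
    m = max(1, int(np.ceil(len(keys) * compression_ratio)))
    keep = np.sort(np.argsort(norms)[:m])   # the m smallest norms, original order preserved
    return keys[keep], values[keep]
```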
Architectural Redesigns
These approaches change the Transformer architecture itself to handle KV caches more efficiently, often incorporating compression directly into the architecture.
7) Multi-Query Attention (MQA)
- Key Idea: MQA reduces the KV cache size by sharing a single key-value head across all query heads, replacing the traditional Multi-Head Attention (MHA): $\text{head}_i = \text{Attention}(X W_Q^{(i)}, X W_K, X W_V)$, where $W_K$ and $W_V$ are the single shared key and value projections.
- Benefits: Reduces the KV cache size by a factor of $h$ (the number of attention heads), significantly lowering memory bandwidth overhead.
- Trade-Off: While MQA is faster, it often suffers from quality degradation, especially in tasks requiring diverse attention patterns.
8) Group-Query Attention (GQA)
- Key Idea: GQA interpolates between full multi-head attention and MQA, offering a scalable trade-off between inference speed and model quality. It divides the $h$ query heads into $g$ groups, where each group shares a single key-value head:
- GQA-1: Equivalent to MQA ($g = 1$).
- GQA-$h$: Equivalent to MHA ($g = h$).
- Uptraining: GQA can be introduced to existing pre-trained models through fine-tuning:
- First, convert MHA checkpoints to GQA by mean pooling key and value heads into groups
- Then fine-tune ("uptrain") the model briefly to adapt to the new attention pattern
- This adaptation process requires only 5% of the original pre-training compute, making it very efficient
- The resulting model maintains quality while gaining GQA's memory benefits
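MQA and GQA differ only in how many key/value heads are cached. The toy NumPy sketch below (causal mask included, projection shapes noted in comments, names mine) parameterizes the number of KV heads, so `n_kv_heads=1` corresponds to MQA and `n_kv_heads=n_heads` to standard MHA:

```python
import numpy as np

def gqa_attention(X, Wq, Wk, Wv, n_heads=8, n_kv_heads=2):
    """Grouped-query attention sketch. n_heads query heads share n_kv_heads
    key/value heads. Shapes assumed: X (n, d), Wq (d, d), Wk and Wv
    (d, n_kv_heads * d_head). Only the smaller K and V tensors need caching."""
    n, d = X.shape
    d_head = d // n_heads
    group = n_heads // n_kv_heads
    Q = (X @ Wq).reshape(n, n_heads, d_head)
    K = (X @ Wk).reshape(n, n_kv_heads, d_head)   # this reduced tensor is what gets cached
    V = (X @ Wv).reshape(n, n_kv_heads, d_head)
    mask = np.tril(np.ones((n, n), dtype=bool))   # causal mask
    outs = []
    for h in range(n_heads):
        kv = h // group                           # which shared KV head this query head uses
        scores = Q[:, h] @ K[:, kv].T / np.sqrt(d_head)
        scores = np.where(mask, scores, -np.inf)
        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        outs.append(w @ V[:, kv])
    return np.concatenate(outs, axis=-1)          # (n, d), before the output projection

# toy usage: 8 query heads sharing 2 KV heads
rng = np.random.default_rng(0)
d, n, d_head = 64, 10, 8
X = rng.standard_normal((n, d))
Wq = rng.standard_normal((d, d))
Wk = rng.standard_normal((d, 2 * d_head))
Wv = rng.standard_normal((d, 2 * d_head))
out = gqa_attention(X, Wq, Wk, Wv, n_heads=8, n_kv_heads=2)   # shape (10, 64)
```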
9) Multi-head Latent Attention (MLA)
DeepSeek's Multi-Head Latent Attention (MLA) takes a novel approach to reducing KV cache overhead. While MQA and GQA achieve this through head-sharing, MLA instead employs a low-rank latent compression technique that maintains the benefits of multiple attention heads.
- MLA reduces KV cache size by compressing keys and values into low-dimensional latent vectors before reconstruction.
- It down-projects key-value embeddings into a compressed latent space: $c_t^{KV} = W^{DKV} h_t$, with $k_t^C = W^{UK} c_t^{KV}$ and $v_t^C = W^{UV} c_t^{KV}$, where $W^{DKV}$ is the down-projection matrix, and $W^{UK}$, $W^{UV}$ are up-projection matrices for keys and values.
- It retains per-head flexibility through compressed representations, unlike MQA's complete head sharing.
- It uses a decoupled Rotary Positional Embedding (RoPE) scheme for position-aware keys: $k_t^R = \text{RoPE}(W^{KR} h_t)$. This reduces KV cache storage further, since only the compressed latent vectors $c_t^{KV}$ and the positional keys $k_t^R$ need to be cached.
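A toy sketch of the low-rank path (ignoring the decoupled RoPE branch and per-head reshaping; the matrix names follow the notation above, but the dimensions are made up):

```python
import numpy as np

def mla_cache_and_reconstruct(h_t, W_dkv, W_uk, W_uv):
    """MLA-style low-rank KV compression (sketch). Only the small latent c_t is
    cached; keys and values are reconstructed from it by up-projection when
    attention is computed."""
    c_t = h_t @ W_dkv            # compressed latent: this is what gets cached
    k_t = c_t @ W_uk             # reconstructed key
    v_t = c_t @ W_uv             # reconstructed value
    return c_t, k_t, v_t

# toy shapes: hidden 64, latent 8 -> cache 8 numbers per token instead of 64 + 64
rng = np.random.default_rng(0)
d_model, d_latent, d_kv = 64, 8, 64
W_dkv = rng.standard_normal((d_model, d_latent))
W_uk = rng.standard_normal((d_latent, d_kv))
W_uv = rng.standard_normal((d_latent, d_kv))
c, k, v = mla_cache_and_reconstruct(rng.standard_normal(d_model), W_dkv, W_uk, W_uv)
```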
10) SnapKV
- SnapKV introduces an Observation Window: It uses the last few prompt tokens (the end of the prompt) to identify attention patterns. The attention weights $W_{\text{obs}}$ from this window are aggregated per earlier position, and the top-$k$ positions are retained, where $k$ is determined by the compression rate.
- Compression: Clusters features around the selected positions using a pooling layer to preserve context completeness.
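A rough sketch of the selection step, assuming we already have one head's prompt attention matrix; the pooling width and other parameters are illustrative:

```python
import numpy as np

def snapkv_select(attn, obs_window=32, keep=256, pool=7):
    """SnapKV-style selection sketch: attention weights from the last `obs_window`
    prompt tokens vote for which earlier positions to keep; a small max-pooling
    window keeps neighbours of strong positions so local context stays intact.
    `attn` is an (n, n) prompt attention matrix for one head (assumed input)."""
    n = attn.shape[0]
    votes = attn[n - obs_window:, : n - obs_window].sum(axis=0)      # per-position importance
    pad = pool // 2
    padded = np.pad(votes, pad, mode="edge")
    pooled = np.max(np.lib.stride_tricks.sliding_window_view(padded, pool), axis=-1)
    top = np.sort(np.argsort(pooled)[::-1][:keep])                   # strongest earlier positions
    keep_idx = np.concatenate([top, np.arange(n - obs_window, n)])   # plus the window itself
    return np.unique(keep_idx)
```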
11) You Only Cache Once (YOCO)
YOCO modifies the transformer architecture for caching:
- Global Cache: Uses a decoder-decoder design with a single shared KV cache.
- Complexity Reduction: Reduces cache memory from $O(L \times n)$ to $O(n)$, where $n$ is the sequence length and $L$ is the number of layers.
- Efficient Attention: The self-decoder employs sliding-window attention or gated retention, enabling constant memory usage ($O(w)$, where $w$ is a small window size).
Conclusion
Key-Value caching techniques are central to scaling and optimizing Transformer-based models for real-world use. Innovations like dynamic eviction, compression, and structured approximations continue to push the boundaries on what is possible in long-context or resource-constrained scenarios. KV caching remains a lively research area, offering both theoretical insights and practical improvements.
PS: This blog post is mostly AI-generated using a PySpur workflow with minor human edits.