Overview
When a Transformer (especially in decoder-only architectures like GPT-style models) processes a long context, it constructs a KV cache that grows with the sequence length. Over very long sequences or with large models, the memory consumed by this cache can become too large for a single GPU. Distributing these tensors across multiple GPUs helps keep memory usage per GPU within acceptable limits.
We’ll cover:
- Basic concepts of storing KV caches.
- How to split the KV cache across devices.
- How to handle indexing and gather/synchronize for attention operations.
- A toy code example using PyTorch.
At a high level, you have two main strategies for distributing a KV cache:
- Splitting by batch or microbatch dimension.
- Splitting by hidden dimensions (e.g., heads).
The details can vary depending on whether you’re doing data parallel, model parallel, or pipeline parallel training. This tutorial focuses on a conceptual example in PyTorch that demonstrates the logic of distributing the cache, rather than an integration with a specific framework (like DeepSpeed or Megatron-LM). However, the fundamental ideas are similar.
1. KV Cache Basics
Recall that for each attention head, the model stores a key tensor `K` and a value tensor `V` for every token seen so far, typically kept per layer with shape `(batch_size, num_heads, seq_len, head_dim)`.
In a decoder-only Transformer, each new token can attend to all previously generated tokens, so `seq_len` keeps increasing as tokens are generated. Hence, `K` and `V` keep growing.
If your model is loaded onto multiple GPUs (say `n_gpus` devices), we want to place fragments of `K` and `V` on each device to spread out memory usage.
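As a concrete (toy) illustration of how the cache grows, the snippet below builds a per-layer cache with made-up sizes; the specific head count, `head_dim`, and three-step loop are arbitrary choices for this sketch:

```python
import torch

batch_size, num_heads, head_dim = 1, 12, 64

# Per-layer cache: K and V start empty and grow along the sequence dimension.
K = torch.empty(batch_size, num_heads, 0, head_dim)
V = torch.empty(batch_size, num_heads, 0, head_dim)

for step in range(3):  # pretend we decode 3 tokens
    k_new = torch.randn(batch_size, num_heads, 1, head_dim)  # keys for the new token
    v_new = torch.randn(batch_size, num_heads, 1, head_dim)  # values for the new token
    K = torch.cat([K, k_new], dim=2)
    V = torch.cat([V, v_new], dim=2)
    print(step, K.shape)  # the seq_len dimension grows: 1, 2, 3, ...
```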
2. Splitting the KV Cache
2.1 By batch dimension
A straightforward approach is to distribute by the batch dimension (e.g., if you have multiple sequences in a batch, each GPU holds the KV cache for its subset of the sequences).
- Pros: Simple indexing. Each GPU can handle separate slices of the batch.
- Cons: If you are only generating a single long sequence (or a small batch), this doesn’t necessarily help.
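A minimal sketch of this idea, assuming two CUDA devices and a hypothetical `per_device_cache` list (the sizes are arbitrary):

```python
import torch

batch_size, num_heads, seq_len, head_dim = 8, 12, 1024, 64
devices = ["cuda:0", "cuda:1"]
sub_batch = batch_size // len(devices)  # 4 sequences per GPU here

# Each device holds the K/V cache only for its own slice of the batch.
per_device_cache = [
    {
        "key": torch.zeros(sub_batch, num_heads, seq_len, head_dim, device=dev),
        "value": torch.zeros(sub_batch, num_heads, seq_len, head_dim, device=dev),
    }
    for dev in devices
]

print([c["key"].shape for c in per_device_cache])  # two caches, half the batch each
```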
2.2 By hidden dimension (heads)
Another approach is to chunk the heads across GPUs (model-parallel style). For example, if you have 12 heads and 2 GPUs, GPU0 can store heads `[0..5]` while GPU1 stores heads `[6..11]`.
- Pros: You can distribute the memory for a single sequence, which is handy when generating long sequences one at a time.
- Cons: You need to aggregate the attention outputs from all heads, so some cross-device communication is required.
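A minimal sketch of the head split, assuming the 12-head, two-GPU setup from the example above:

```python
import torch

batch_size, num_heads, seq_len, head_dim = 1, 12, 1024, 64

# Full K for one layer, then chunked along the head dimension (dim=1).
K = torch.randn(batch_size, num_heads, seq_len, head_dim)
K_gpu0 = K[:, : num_heads // 2].to("cuda:0")  # heads 0..5
K_gpu1 = K[:, num_heads // 2 :].to("cuda:1")  # heads 6..11

print(K_gpu0.shape, K_gpu1.shape)  # each device holds half the heads, half the memory
```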
2.3 By sequence length
This is less common in practice, but it can be done by keeping a partial slice of the sequence on each device. However, this approach requires careful attention indexing across devices and can be tricky to get right; a rough sketch follows.
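Here is a rough single-head sketch of what sequence-length sharding involves, assuming two CUDA devices and a cache already split in half along the sequence dimension. Note that the softmax must see scores for the full sequence, so the per-device score chunks are gathered onto one device before normalizing:

```python
import torch
import torch.nn.functional as F

head_dim, cached_len = 64, 1024
q = torch.randn(1, head_dim, device="cuda:0")  # query for one head, one new token

# Each device holds half of the cached keys/values along the sequence dimension.
K_chunks = [torch.randn(cached_len // 2, head_dim, device=d) for d in ("cuda:0", "cuda:1")]
V_chunks = [torch.randn(cached_len // 2, head_dim, device=d) for d in ("cuda:0", "cuda:1")]

# Each device scores its own slice of the sequence; chunks are gathered on cuda:0.
scores = [(q.to(K.device) @ K.T) / head_dim ** 0.5 for K in K_chunks]
weights = F.softmax(torch.cat([s.to("cuda:0") for s in scores], dim=-1), dim=-1)

# Split the normalized weights back per device and combine the weighted values.
w0, w1 = weights.split(cached_len // 2, dim=-1)
out = w0 @ V_chunks[0] + (w1.to("cuda:1") @ V_chunks[1]).to("cuda:0")
print(out.shape)  # (1, head_dim)
```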
3. Handling Indexing and Gather Operations
If you are splitting by the batch dimension (for multi-batch generation):
- Forward pass: Only the tokens in the sub-batch on each GPU attend to the part of the KV cache stored on the same GPU.
- No immediate gather needed: Each sub-batch only needs its chunk of K and V.
- Final results: Combine or gather final predictions across GPUs.
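A minimal sketch of that final gather step, with hypothetical per-GPU logits for two sub-batches of four sequences each:

```python
import torch

# Hypothetical next-token logits produced independently on each GPU.
logits_gpu0 = torch.randn(4, 1, 50257, device="cuda:0")
logits_gpu1 = torch.randn(4, 1, 50257, device="cuda:1")

# Gather everything onto one device (or the CPU), preserving the original batch order.
logits = torch.cat([logits_gpu0, logits_gpu1.to("cuda:0")], dim=0)
print(logits.shape)  # (8, 1, 50257)
```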
If you are splitting by heads (model parallel style):
- Storing: You maintain a chunk of heads `[head_start:head_end]` on each GPU.
- Forward pass: Each GPU performs attention for its subset of heads using its own `K` and `V`.
- Attention combination: You need an all-reduce or gather step to combine the partial attention outputs from all GPUs into a final output (see the sketch after this list).
  - Within each head, the softmax-weighted sum over positions is already complete on the GPU that owns that head, so no cross-device reduction of the scores is needed.
  - Across heads, multi-head attention concatenates the per-head context vectors along the head dimension and then applies the output projection, so make sure to gather the chunks in the correct head order.
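The combination step might look roughly like this (a sketch assuming 6 heads per GPU and partial context tensors already computed on each device):

```python
import torch
import torch.nn as nn

batch_size, seq_len, head_dim, heads_per_gpu = 1, 16, 64, 6
d_model = 2 * heads_per_gpu * head_dim
out_proj = nn.Linear(d_model, d_model).to("cuda:0")

# Partial per-head context vectors, one chunk of heads per device.
ctx_gpu0 = torch.randn(batch_size, heads_per_gpu, seq_len, head_dim, device="cuda:0")
ctx_gpu1 = torch.randn(batch_size, heads_per_gpu, seq_len, head_dim, device="cuda:1")

# Gather onto one device, concatenate along the head dimension, then project.
ctx = torch.cat([ctx_gpu0, ctx_gpu1.to("cuda:0")], dim=1)         # (B, 12, S, head_dim)
ctx = ctx.transpose(1, 2).reshape(batch_size, seq_len, d_model)   # (B, S, d_model)
out = out_proj(ctx)
print(out.shape)
```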
4. A Toy Example in PyTorch
Below is a simplistic example demonstrating:
- A minimal transformer block with attention.
- Splitting the KV cache by heads across 2 GPUs (`cuda:0` and `cuda:1`).
- A forward pass that collects partial attention outputs.
Note: This code is illustrative and omits many real-world complexities (e.g., layer norms, multi-layer stack, etc.). It should give you a blueprint for how you might adapt your own code.
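A minimal sketch consistent with the description above, reduced to a single attention layer. The class name `DistributedKVAttention`, the `reset_cache` helper, and the assumption of exactly two visible CUDA devices are illustrative choices rather than requirements; causal masking and dropout are also omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistributedKVAttention(nn.Module):
    """Single attention layer whose KV cache is split by heads across two GPUs."""

    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super().__init__()
        assert num_heads % 2 == 0, "toy example assumes an even number of heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Projections live on cuda:0; only the KV cache is spread across devices.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model).to("cuda:0")
        self.out_proj = nn.Linear(d_model, d_model).to("cuda:0")

        # kv_cache['key'][gpu_index]: (batch, heads_per_gpu, cached_len, head_dim)
        self.kv_cache = {"key": [None, None], "value": [None, None]}

    def reset_cache(self):
        self.kv_cache = {"key": [None, None], "value": [None, None]}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model) on cuda:0
        batch_size, seq_len, _ = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # Reshape to (batch, num_heads, seq_len, head_dim).
        def split_heads(t):
            return t.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Split the head dimension across the two GPUs.
        split_size = self.num_heads // 2
        Q_gpu0, K_gpu0, V_gpu0 = (t[:, :split_size].to("cuda:0") for t in (q, k, v))
        Q_gpu1, K_gpu1, V_gpu1 = (t[:, split_size:].to("cuda:1") for t in (q, k, v))

        contexts = []
        for gpu_index, Q, K, V in [(0, Q_gpu0, K_gpu0, V_gpu0), (1, Q_gpu1, K_gpu1, V_gpu1)]:
            # Append the new keys/values to this GPU's slice of the cache (sequence dim = 2).
            if self.kv_cache["key"][gpu_index] is None:
                self.kv_cache["key"][gpu_index] = K
                self.kv_cache["value"][gpu_index] = V
            else:
                self.kv_cache["key"][gpu_index] = torch.cat([self.kv_cache["key"][gpu_index], K], dim=2)
                self.kv_cache["value"][gpu_index] = torch.cat([self.kv_cache["value"][gpu_index], V], dim=2)

            K_full = self.kv_cache["key"][gpu_index]
            V_full = self.kv_cache["value"][gpu_index]

            # Attention for this GPU's heads only (no causal mask in this toy version).
            scores = Q @ K_full.transpose(-2, -1) / self.head_dim ** 0.5
            weights = F.softmax(scores, dim=-1)
            context = weights @ V_full  # (batch, heads_per_gpu, seq_len, head_dim)

            # Bring the partial context back to cuda:0 for the final combination.
            contexts.append(context.to("cuda:0"))

        # Concatenate along the head dimension, then restore (batch, seq_len, d_model).
        context = torch.cat(contexts, dim=1)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(context)


if __name__ == "__main__":
    attn = DistributedKVAttention(d_model=512, num_heads=8)
    prompt = torch.randn(2, 4, 512, device="cuda:0")     # prompt tokens
    out = attn(prompt)                                    # fills the cache
    new_token = torch.randn(2, 1, 512, device="cuda:0")   # one newly generated token
    out_next = attn(new_token)                            # reuses and extends the cache
    print(out.shape, out_next.shape)
```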
Explanation of Key Steps:
- Splitting Heads: We define `split_size = num_heads // 2` (for 2 GPUs). Realistically, you would compute `split_size` from however many GPUs you have, and adapt the chunking if the head count does not divide evenly.
- Moving Tensors to the Correct GPU: `Q_gpu0`, `K_gpu0`, `V_gpu0` go to `cuda:0`; `Q_gpu1`, `K_gpu1`, `V_gpu1` go to `cuda:1`.
- Appending to the KV Cache:
  - We keep the existing KV cache in `kv_cache['key'][gpu_index]` and `kv_cache['value'][gpu_index]`.
  - For each new token (or set of tokens), we append along the sequence dimension.
- Attention Computation:
  - Each GPU computes attention scores only for the heads it stores.
  - Then we bring those partial context vectors back onto one device (in this example, `cuda:0`) and concatenate.
- Combining Results:
  - We transpose back and reshape so that we have `(batch_size, seq_len, d_model)`.
  - Finally, we apply the output projection (`self.out_proj`).
5. Considerations for Production
- Performance vs. Complexity:
  - Splitting across multiple GPUs helps memory, but cross-device communication can become a bottleneck.
  - In production setups, you typically use existing distributed frameworks (e.g., DeepSpeed ZeRO stage 3, Megatron-LM model parallelism, or PyTorch's sharded tensors / FullyShardedDataParallel) rather than hand-implementing everything.
- Synchronization:
  - Be mindful of calls like `torch.cuda.synchronize()` or collecting outputs to the CPU; they can degrade performance if used naively.
  - Asynchronous GPU operations can help.
- Batch vs. Single-Sequence:
  - For multi-batch scenarios, it might be simpler to just store each slice of the batch on a different GPU.
  - For single-sequence generation, splitting across heads (or sequence length) is often how large language models implement tensor parallelism.
- Indexing:
  - If you skip or remove tokens, or reorder sequences across GPUs, your code must carefully maintain the correct index references.
- Integration with Hugging Face:
  - If using `transformers`, you can manually manipulate the `past_key_values` argument or adapt internal classes by hooking into the generation loop (a minimal sketch follows this list).
  - Or rely on the built-in pipeline-parallel and sharding support in libraries like Accelerate and DeepSpeed, or FullyShardedDataParallel in PyTorch.
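As a rough sketch of the first option: the `gpt2` checkpoint below is only an illustrative choice, and depending on your `transformers` version `past_key_values` may be a legacy tuple of per-layer `(key, value)` pairs or a `Cache` object (both are indexable as shown).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only causal LM behaves similarly.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tok("Distributing a KV cache", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)

past = out.past_key_values  # per-layer cached keys/values
# Each layer's key/value has shape (batch, num_heads, seq_len, head_dim).
print(len(past), past[0][0].shape)

# Feed only the newest token plus the cache on the next decoding step.
next_token = out.logits[:, -1:].argmax(dim=-1)
with torch.no_grad():
    out2 = model(input_ids=next_token, past_key_values=past, use_cache=True)
print(out2.logits.shape)
```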
Conclusion
Splitting the KV cache across multiple GPUs allows you to handle longer context lengths and reduce per-GPU memory load. The key points are:
- Decide how to split: by batch dimension or by heads (model parallel).
- Implement consistent indexing: each GPU’s cache must stay in sync with the correct token positions.
- Aggregate results: attention outputs need to be combined across all GPUs.
The code sample above demonstrates the essential logic in a toy scenario. In a real setup, you’d integrate this approach into a distributed training/inference framework that handles the nitty-gritty details of parallelism and communication for you.