Splitting the KV Cache Across Multiple GPUs

Overview

When a Transformer (especially a decoder-only architecture like a GPT-style model) processes a long context, it builds a KV cache that grows with the sequence length. For very long sequences or large models, the memory consumed by this cache can exceed what a single GPU can hold. Distributing these tensors across multiple GPUs keeps the per-GPU memory usage within acceptable limits.

We’ll cover:

  1. Basic concepts of storing KV caches.
  2. How to split the KV cache across devices.
  3. How to handle indexing and gather/synchronize for attention operations.
  4. A toy code example using PyTorch.

At a high level, you have two main strategies for distributing a KV cache:

  1. Splitting by batch or microbatch dimension.
  2. Splitting by hidden dimensions (e.g., heads).

The details can vary depending on whether you’re doing data parallel, model parallel, or pipeline parallel training. This tutorial focuses on a conceptual example in PyTorch that demonstrates the logic of distributing the cache, rather than an integration with a specific framework (like DeepSpeed or Megatron-LM). However, the fundamental ideas are similar.


1. KV Cache Basics

Recall that for each attention head, the model stores:

K: (batch_size, num_heads, seq_len, head_dim)
V: (batch_size, num_heads, seq_len, head_dim)

In a decoder-only Transformer, each new token can attend to all previously generated tokens, so seq_len keeps increasing as tokens are generated. Hence, K and V get larger.

If the model runs across multiple GPUs (say n_gpus devices), we want to place fragments of K and V on each device to spread out memory usage.
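
To get a feel for the numbers, here is a back-of-the-envelope sizing sketch. The layer/head counts below are assumptions for a hypothetical GPT-style model, not measurements of any particular one:

# Rough KV cache sizing for a hypothetical fp16 decoder-only model
num_layers = 32        # assumed number of decoder layers
num_heads  = 32        # assumed number of attention heads
head_dim   = 128       # assumed per-head dimension
bytes_per_elem = 2     # fp16

# Each layer stores K and V: two tensors of num_heads * head_dim values per token.
bytes_per_token = num_layers * 2 * num_heads * head_dim * bytes_per_elem
print(f"KV cache per token: {bytes_per_token / 1024:.0f} KiB")   # 512 KiB

seq_len = 32_768  # a long context
print(f"KV cache at {seq_len} tokens: {bytes_per_token * seq_len / 1024**3:.0f} GiB")  # 16 GiB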


2. Splitting the KV Cache

2.1 By batch dimension

A straightforward approach is to distribute by the batch dimension (e.g., if you have multiple sequences in a batch, each GPU holds the KV cache for its subset of the sequences).

  • Pros: Simple indexing. Each GPU can handle separate slices of the batch.
  • Cons: If you are only generating a single long sequence (or a small batch), this doesn’t necessarily help.
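
As a concrete illustration, here is a minimal sketch of sharding an existing cache tensor along the batch dimension. The shapes and device names are assumptions, and it expects at least two visible GPUs:

import torch

# Hypothetical full cache: (batch_size, num_heads, seq_len, head_dim)
K = torch.randn(8, 12, 1024, 64)
V = torch.randn(8, 12, 1024, 64)

devices = ['cuda:0', 'cuda:1']

# Split along dim 0 (batch) and place each shard on its own device.
K_shards = [k.to(d) for k, d in zip(K.chunk(len(devices), dim=0), devices)]
V_shards = [v.to(d) for v, d in zip(V.chunk(len(devices), dim=0), devices)]

for d, k in zip(devices, K_shards):
    print(d, k.shape)  # each GPU holds the cache for 4 of the 8 sequences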

2.2 By hidden dimension (heads)

Another approach is to chunk the heads across GPUs (model parallel style). For example, if you have 12 heads and 2 GPUs, GPU0 can store heads [0..5] while GPU1 stores heads [6..11].

  • Pros: You can distribute the memory for a single sequence, which is handy when generating long sequences one at a time.
  • Cons: You need to aggregate the attention outputs from all heads, so some cross-device communication is required.

2.3 By sequence length

This is less common in practice, but it can be done by keeping a contiguous chunk of the sequence on each device. It requires careful attention indexing across devices, because the softmax normalization spans all chunks; a minimal sketch of the required merge follows.
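
If you do go this route, the partial attention results from each sequence chunk must be merged with a log-sum-exp (online softmax) step. Below is a minimal single-query sketch of that merge; the shapes and device placement are assumptions, not a drop-in implementation:

import torch

def attend_over_sequence_chunks(q, k_chunks, v_chunks):
    # q: (num_heads, 1, head_dim) for a single decoding step
    # k_chunks / v_chunks: lists of (num_heads, chunk_len, head_dim),
    #                      each chunk possibly living on a different GPU
    device = q.device
    scale = q.shape[-1] ** -0.5
    m = l = o = None  # running max, normalizer, and unnormalized output
    for k, v in zip(k_chunks, v_chunks):
        scores = (q.to(k.device) @ k.transpose(-2, -1)) * scale  # (heads, 1, chunk_len)
        m_c = scores.amax(dim=-1, keepdim=True)                  # chunk-local max
        p = torch.exp(scores - m_c)
        l_c = p.sum(dim=-1, keepdim=True)                        # chunk-local normalizer
        o_c = p @ v                                              # chunk-local weighted sum
        m_c, l_c, o_c = m_c.to(device), l_c.to(device), o_c.to(device)
        if m is None:
            m, l, o = m_c, l_c, o_c
        else:
            m_new = torch.maximum(m, m_c)
            l = l * torch.exp(m - m_new) + l_c * torch.exp(m_c - m_new)
            o = o * torch.exp(m - m_new) + o_c * torch.exp(m_c - m_new)
            m = m_new
    return o / l  # identical to softmax attention over the full sequence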


3. Handling Indexing and Gather Operations

If you are splitting by the batch dimension (for multi-batch generation):

  1. Forward pass: Only the tokens in the sub-batch on each GPU attend to the part of the KV cache stored on the same GPU.
  2. No immediate gather needed: Each sub-batch only needs its chunk of K and V.
  3. Final results: Combine or gather the final predictions across GPUs (a minimal sketch follows this list).
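
In the single-process, multi-device style used later in this tutorial, that final gather can be as simple as copying each sub-batch's output to one device and concatenating along the batch dimension. A sketch, with placeholder tensors:

import torch

# Hypothetical per-GPU outputs for two sub-batches of 4 sequences each
out_gpu0 = torch.randn(4, 1, 64, device='cuda:0')
out_gpu1 = torch.randn(4, 1, 64, device='cuda:1')

# Gather on one device and stitch the batch back together.
final_out = torch.cat([out_gpu0, out_gpu1.to('cuda:0')], dim=0)  # (8, 1, 64)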

If you are splitting by heads (model parallel style):

  1. Storing: You maintain a chunk of heads [head_start:head_end] on each GPU.
  2. Forward pass: Each GPU performs attention for its subset of heads using its own K, V.
  3. Attention combination: You need a gather or all-reduce step to combine the partial attention outputs from all GPUs into the final output.
    • Each head's output is independent, so the simplest scheme gathers the per-GPU head outputs onto one device, concatenates them along the head dimension, and then applies the output projection.
    • Alternatively, each GPU can project its own heads with its slice of the output projection, and the partial results are summed with an all-reduce (Megatron-style tensor parallelism); a sketch of this follows the list.
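
For reference, the second bullet above is roughly how Megatron-style tensor parallelism is implemented: each rank projects its own heads with the matching columns of the output projection, and an all-reduce sums the partial results. A rough sketch with torch.distributed, assuming one process per GPU and an already-initialized process group (a real implementation would store only the local slice of the weight):

import torch
import torch.distributed as dist

def row_parallel_out_proj(context_r, W_out, rank, world_size):
    # context_r: (bsz, seq_len, heads_per_rank * head_dim) -- this rank's head outputs
    # W_out: full output projection weight, shape (d_model, d_model), replicated here
    d_model = W_out.shape[1]
    chunk = d_model // world_size
    # Columns of W_out that multiply this rank's slice of the concatenated heads
    W_r = W_out[:, rank * chunk:(rank + 1) * chunk]      # (d_model, chunk)
    partial = context_r @ W_r.T                          # (bsz, seq_len, d_model)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)       # sum the partial projections
    return partial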

4. A Toy Example in PyTorch

Below is a simplistic example demonstrating:

  • A minimal transformer block with attention.
  • Splitting the KV cache by heads across 2 GPUs (cuda:0 and cuda:1).
  • Forward pass that collects partial attention outputs.

Note: This code is illustrative and omits many real-world complexities (e.g., layer norms, multi-layer stack, etc.). It should give you a blueprint for how you might adapt your own code.

import torch
import torch.nn as nn
 
# ----------------------------
# 1. Minimal Multi-Head Self-Attention Module
# ----------------------------
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
 
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
 
        # Simple linear projections
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj   = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
 
        # Output projection
        self.out_proj   = nn.Linear(d_model, d_model)
 
    def forward(self, x, kv_cache, time_step):
        """
        x: (batch_size, seq_len, d_model)
        kv_cache: dict with 'key' and 'value'
                  Each is a list [gpu0_tensor, gpu1_tensor], split by heads.
        time_step: index of the current generation step (kept for clarity; the cache
                   update below simply appends by concatenating along the seq dimension)
        """
 
        bsz, seq_len, _ = x.size()
 
        # Project queries
        Q = self.query_proj(x)  # (bsz, seq_len, d_model)
 
        # Reshape Q: (bsz, seq_len, num_heads, head_dim)
        Q = Q.view(bsz, seq_len, self.num_heads, self.head_dim)
        # Transpose to: (bsz, num_heads, seq_len, head_dim)
        Q = Q.transpose(1, 2)
 
        # We'll split the Q across heads for each GPU as well
        # This example: if num_heads == 8, GPU0 = heads 0..3, GPU1 = heads 4..7
        # Hardcode for 2 GPUs for simplicity
        split_size = self.num_heads // 2
 
        Q_gpu0 = Q[:, :split_size].contiguous().to('cuda:0')
        Q_gpu1 = Q[:, split_size:].contiguous().to('cuda:1')
 
        # We'll do the same for keys and values. We need to build the new K, V
        # for the current time_step. We'll append them to kv_cache accordingly.
 
        K_new = self.key_proj(x)  # (bsz, seq_len, d_model)
        V_new = self.value_proj(x)
        K_new = K_new.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V_new = V_new.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
        # Split by heads
        K_gpu0_new = K_new[:, :split_size].contiguous().to('cuda:0')
        V_gpu0_new = V_new[:, :split_size].contiguous().to('cuda:0')
 
        K_gpu1_new = K_new[:, split_size:].contiguous().to('cuda:1')
        V_gpu1_new = V_new[:, split_size:].contiguous().to('cuda:1')
 
        # ---------------
        # Update KV Cache
        # ---------------
        # Suppose each entry in kv_cache['key'] and kv_cache['value'] is shaped:
        #   (bsz, num_heads_chunk, total_seq_len, head_dim)
        # We just append along seq_len dimension at time_step.
 
        # GPU0
        if kv_cache['key'][0] is None:
            kv_cache['key'][0] = K_gpu0_new
            kv_cache['value'][0] = V_gpu0_new
        else:
            # Concat along seq_len dimension
            kv_cache['key'][0] = torch.cat([kv_cache['key'][0], K_gpu0_new], dim=2)
            kv_cache['value'][0] = torch.cat([kv_cache['value'][0], V_gpu0_new], dim=2)
 
        # GPU1
        if kv_cache['key'][1] is None:
            kv_cache['key'][1] = K_gpu1_new
            kv_cache['value'][1] = V_gpu1_new
        else:
            kv_cache['key'][1] = torch.cat([kv_cache['key'][1], K_gpu1_new], dim=2)
            kv_cache['value'][1] = torch.cat([kv_cache['value'][1], V_gpu1_new], dim=2)
 
        # ---------------
        # Compute Attention on each GPU separately
        # ---------------
        # We'll define a helper function
        def attend(Q_sub, K_sub, V_sub):
            # Q_sub: (bsz, chunk_heads, seq_len, head_dim)
            # K_sub: (bsz, chunk_heads, total_seq_len_so_far, head_dim)
            # V_sub: (bsz, chunk_heads, total_seq_len_so_far, head_dim)
 
            attn_scores = torch.matmul(
                Q_sub,
                K_sub.transpose(-2, -1)
            ) / (self.head_dim ** 0.5)  # (bsz, chunk_heads, seq_len, total_seq_len_so_far)
 
            attn_weights = torch.softmax(attn_scores, dim=-1)
            context = torch.matmul(attn_weights, V_sub)  # (bsz, chunk_heads, seq_len, head_dim)
            return context
 
        # GPU0 attention
        context_gpu0 = attend(
            Q_gpu0,
            kv_cache['key'][0],
            kv_cache['value'][0]
        )
        # GPU1 attention
        context_gpu1 = attend(
            Q_gpu1,
            kv_cache['key'][1],
            kv_cache['value'][1]
        )
 
        # Gather the partial contexts on a single GPU (cuda:0) to combine them
        context_gpu0 = context_gpu0.to('cuda:0')  # already on cuda:0; kept for symmetry
        context_gpu1 = context_gpu1.to('cuda:0')  # cross-device copy from cuda:1
        # Now we can combine along heads dimension
        context_combined = torch.cat([context_gpu0, context_gpu1], dim=1)
 
        # (bsz, num_heads, seq_len, head_dim) -> (bsz, seq_len, num_heads, head_dim)
        context_combined = context_combined.transpose(1, 2).contiguous()
        # Merge heads
        context_combined = context_combined.view(bsz, seq_len, self.d_model)
 
        # Final projection
        out = self.out_proj(context_combined)
 
        return out
 
# ----------------------------
# 2. Example usage
# ----------------------------
def main():
    # Config
    d_model = 8
    num_heads = 4  # must be divisible by the number of GPUs you split across
    seq_len = 1    # we feed one new token per step
    batch_size = 2

    assert torch.cuda.device_count() >= 2, "this toy example expects at least 2 GPUs"

    # Dummy tokens for the first step (created on cuda:0, where the projections live)
    x = torch.randn(batch_size, seq_len, d_model, device='cuda:0')
 
    # Initialize attention module
    attn = MultiHeadSelfAttention(d_model, num_heads)
 
    # Keep the module's weights on a single GPU (cuda:0); only the KV cache and per-head attention are spread across GPUs
    attn = attn.to('cuda:0')
 
    # Initialize KV cache structure
    # We'll store [kv_gpu0, kv_gpu1]
    # Start with None; they will be set after first pass
    kv_cache = {
        'key':   [None, None],
        'value': [None, None]
    }
 
    # Simulate generating tokens one by one
    # time_step increments each time a new token is generated
    time_steps = 5
    for t in range(time_steps):
        # We'll do a forward pass for the new token(s) at time t
        out = attn(x, kv_cache, time_step=t)
        print(f"Time step: {t}, out shape = {out.shape}")
 
        # For the sake of demonstration, the next token embedding is random
        x = torch.randn(batch_size, seq_len, d_model, device='cuda:0')
 
    print("Final KV cache shapes:")
    for i, (k, v) in enumerate(zip(kv_cache['key'], kv_cache['value'])):
        if k is not None:
            print(f"GPU {i} => K shape = {k.shape}, V shape = {v.shape}")
 
if __name__ == "__main__":
    main()

Explanation of Key Steps:

  1. Splitting Heads: We define split_size = num_heads // 2 because the example hardcodes 2 GPUs. In practice you would divide the heads over however many GPUs you have (and decide how to handle head counts that do not divide evenly); a more general version is sketched after this list.

  2. Moving Tensors to the Correct GPU

    • Q_gpu0, K_gpu0_new, V_gpu0_new go to cuda:0.
    • Q_gpu1, K_gpu1_new, V_gpu1_new go to cuda:1.
  3. Appending to the KV Cache

    • We keep the existing KV cache in kv_cache['key'][gpu_index] and kv_cache['value'][gpu_index].
    • For each new token (or set of tokens), we append along the sequence dimension.
  4. Attention Computation

    • Each GPU computes attention scores only for the heads it stores.
    • Then we bring those partial context vectors back onto one device (in this example, cuda:0) and concatenate.
  5. Combining Results

    • We transpose back and reshape so that we have (batch_size, seq_len, d_model).
    • Finally, we apply the output projection (self.out_proj).
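
The example above hardcodes two GPUs. A more general helper might split the heads over an arbitrary list of devices along these lines (a sketch; it assumes num_heads divides evenly, since torch.chunk would otherwise produce uneven chunks):

import torch

def split_heads_across_devices(t, devices):
    # t: (bsz, num_heads, seq_len, head_dim); returns one chunk of heads per device
    chunks = t.chunk(len(devices), dim=1)  # split along the heads dimension
    return [c.contiguous().to(d) for c, d in zip(chunks, devices)]

# Inside forward(), this would replace the hardcoded Q_gpu0 / Q_gpu1 slicing, e.g.:
# Q_chunks = split_heads_across_devices(Q, ['cuda:0', 'cuda:1'])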

5. Considerations for Production

  1. Performance vs. Complexity:

    • Splitting across multiple GPUs helps memory, but cross-device communication can become a bottleneck.
    • In production setups, you typically rely on existing distributed frameworks (e.g., DeepSpeed ZeRO stage 3, Megatron-LM tensor parallelism, or PyTorch's FullyShardedDataParallel and tensor-parallel APIs) rather than hand-implementing everything.
  2. Synchronization:

    • Be mindful of calls like torch.cuda.synchronize() or collecting outputs to the CPU. They can degrade performance if used naively.
    • Asynchronous GPU operations can help.
  3. Batch vs. Single-Sequence:

    • For multi-batch scenarios, it might be simpler to just store each slice of the batch on a different GPU.
    • For single-sequence generation, splitting across heads is how tensor parallelism typically spreads a single sequence's cache; splitting along the sequence length is another option, at the cost of the merging step described in Section 2.3.
  4. Indexing:

    • If you skip or remove tokens, or reorder sequences across GPUs, your code must carefully maintain the correct index references.
  5. Integration with Hugging Face:

    • If using transformers, you can manually manipulate the past_key_values arguments or adapt internal classes by hooking into their generation loop.
    • Or rely on the built-in pipeline-parallel and sharding support in libraries like Accelerate, DeepSpeed, or PyTorch's FullyShardedDataParallel (see the sketch below).
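
For example, loading a model with Accelerate's device_map support spreads the layers across the available GPUs, so each layer's past_key_values entry lives on whichever device runs that layer. A minimal sketch (the model name is a placeholder, and accelerate must be installed for device_map="auto"):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-model-name"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",    # let Accelerate shard layers across GPUs
    torch_dtype="auto",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model(**inputs, use_cache=True)
past = out.past_key_values  # per-layer K/V, each on the device of its layer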

Conclusion

Splitting the KV cache across multiple GPUs allows you to handle longer context lengths and reduce per-GPU memory load. The key points are:

  • Decide how to split: by batch dimension or by heads (model parallel).
  • Implement consistent indexing: each GPU’s cache must stay in sync with the correct token positions.
  • Aggregate results: attention outputs need to be combined across all GPUs.

The code sample above demonstrates the essential logic in a toy scenario. In a real setup, you’d integrate this approach into a distributed training/inference framework that handles the nitty-gritty details of parallelism and communication for you.
