KV Cache: Optimización de Latencia y Memoria en Transformers Autoregresivos

Análisis técnico del mecanismo de Key-Value (KV) Cache en la inferencia de Large Language Models (LLMs). Se detalla la reducción de FLOPs redundantes durante la decodificación autoregresiva, el impacto en el consumo de VRAM y las implicaciones de ancho de banda.

Contexto técnico

La generación de texto en modelos basados en la arquitectura Transformer (como GPT-4, Llama 2 o Mistral) es un proceso autoregresivo. El modelo predice el siguiente token $x_{t}$ basándose en la secuencia de entrada completa $x_{1}, ..., x_{t-1}$.

En una implementación ingenua (stateless), para generar el token $x_{t}$, el modelo debe procesar nuevamente todos los tokens previos para calcular sus representaciones internas. Esto implica una redundancia computacional masiva: las proyecciones de Keys y Values para los tokens $x_{1}$ a $x_{t-1}$ ya fueron calculadas en pasos anteriores y permanecen invariantes (en arquitecturas causal decoder-only).

El KV Cache aborda este problema almacenando en memoria (VRAM) los tensores de Key y Value de los tokens pasados, transformando la complejidad de la atención de $O(N^2)$ a $O(N)$ por paso de generación, a costa de un incremento lineal en el uso de memoria. Esta técnica es el estándar de facto en frameworks de inferencia como vLLM, TGI y TensorRT-LLM.

El siguiente diagrama ilustra el flujo de datos durante un paso de generación con KV Cache:


Fundamentos matemáticos

El núcleo del mecanismo de atención escalada (Scaled Dot-Product Attention) para una sola cabeza se define como:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Donde $Q$ (Query), $K$ (Key) y $V$ (Value) son proyecciones de la entrada.

En el paso de tiempo $t$, queremos calcular la salida para el nuevo token $x_t$.
Sin caché, calculamos $Q_t, K_{1:t}, V_{1:t}$ re-procesando toda la secuencia.

Con KV Cache, aprovechamos que $K_{1:t-1}$ y $V_{1:t-1}$ ya fueron calculados. Solo necesitamos calcular la proyección para el token actual $x_t$:

$$q_t = x_t W_Q, \quad k_t = x_t W_K, \quad v_t = x_t W_V$$

Luego, concatenamos este nuevo $k_t$ y $v_t$ con la caché existente:

$$K_{cache}^{(t)} = [K_{cache}^{(t-1)}; k_t]$$$$V_{cache}^{(t)} = [V_{cache}^{(t-1)}; v_t]$$

La atención se calcula entonces usando solo el query actual $q_t$ contra toda la historia almacenada:

$$\text{Attention}(q_t, K_{cache}^{(t)}, V_{cache}^{(t)}) = \text{softmax}\left(\frac{q_t (K_{cache}^{(t)})^T}{\sqrt{d_k}}\right) V_{cache}^{(t)}$$

Esto elimina la necesidad de realizar multiplicaciones de matrices para obtener $K$ y $V$ de los tokens anteriores, reduciendo las operaciones de punto flotante (FLOPs) drásticamente.

La siguiente comparativa visualiza la reducción de cómputo que logra el KV Cache:


Implementación práctica

A continuación se presenta una implementación simplificada en Python/PyTorch ilustrando la lógica de actualización del caché durante un bucle de generación.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention_with_cache(
    query,          # [Batch, 1, Dim] - Solo el token actual
    key,            # [Batch, 1, Dim] - Solo el token actual
    value,          # [Batch, 1, Dim] - Solo el token actual
    kv_cache=None,  # Tupla (past_keys, past_values)
    scale=None
):
    # Si existe caché, concatenamos a lo largo de la dimensión de secuencia (dim=1)
    if kv_cache is not None:
        past_key, past_value = kv_cache
        key = torch.cat([past_key, key], dim=1)
        value = torch.cat([past_value, value], dim=1)
    
    # Actualizamos el caché para el siguiente paso
    current_cache = (key, value)
    
    # Cálculo de atención estándar
    # query: [B, 1, D], key.transpose: [B, D, Seq_len]
    scores = torch.matmul(query, key.transpose(-2, -1))
    if scale:
        scores = scores / scale
        
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, value)
    
    return output, current_cache

# Simulación de bucle de decodificación
def generate_step(model_layer, input_token, cache):
    # Proyecciones lineales (simplificadas)
    q = model_layer.W_q(input_token)
    k = model_layer.W_k(input_token)
    v = model_layer.W_v(input_token)
    
    # Atención con gestión de caché
    attn_out, new_cache = scaled_dot_product_attention_with_cache(
        q, k, v, kv_cache=cache, scale=model_layer.scale
    )
    
    return attn_out, new_cache


Análisis de comportamiento

El uso de KV Cache altera fundamentalmente el perfil de recursos de la inferencia:

  1. Shift de Compute-Bound a Memory-Bound:
    Sin caché, la operación está dominada por la multiplicación de matrices (Compute-Bound) al recalcular todo. Con caché, el cómputo disminuye, pero la necesidad de leer grandes tensores de KV desde la memoria VRAM en cada paso convierte el proceso en Memory Bandwidth Bound. La velocidad de generación (tokens/segundo) depende más del ancho de banda de la memoria (GB/s) que de los TFLOPS de la GPU.
  2. Consumo de Memoria VRAM:
    El tamaño del caché crece linealmente con la longitud de la secuencia y el tamaño del batch. La fórmula aproximada para el uso de memoria (en bytes) por token es:

$$\text{Mem}_{KV} = 2 \times \text{Batch} \times \text{Layers} \times \text{Heads} \times \text{Head\_Dim} \times \text{Precision (bytes)}$$

Para un modelo como Llama-2-70B (GQA, float16) con contexto largo, el KV Cache puede ocupar decenas de gigabytes, compitiendo por espacio con los pesos del modelo.
3. Latencia (Time to First Token vs. Inter-token Latency):
KV Cache no mejora el "Time to First Token" (fase de prefill), ya que ahí se procesa el prompt completo en paralelo. Su impacto es crítico en la latencia inter-token (fase de decode), donde reduce el tiempo de generación en órdenes de magnitud para secuencias largas.

El uso de KV Cache altera el perfil de recursos de la inferencia:


Limitaciones y casos donde no conviene usarlo

  • Contexto Extenso y OOM (Out of Memory): En secuencias muy largas (e.g., 128k tokens), el KV Cache puede crecer hasta superar la capacidad de la VRAM, forzando la reducción del batch size a 1 o causando errores de memoria.
  • Fragmentación de Memoria: Los tensores de caché crecen dinámicamente. En implementaciones ingenuas, esto causa fragmentación externa en la VRAM, desperdiciando espacio reservado pero no contiguo (problema abordado por técnicas avanzadas como PagedAttention).
  • Beam Search: Algoritmos que mantienen múltiples hipótesis (beams) requieren mantener múltiples cachés o gestionar bifurcaciones complejas, multiplicando el consumo de memoria por el factor del beam width.
  • Entrenamiento: Durante el entrenamiento no se utiliza KV Cache, ya que se emplea Teacher Forcing y paralelización completa de la atención (causal masking), no generación secuencial.