Post-Training Quantization en LLMs: Reducción de Precisión de FP16 a INT4
Análisis técnico de estrategias de Post-Training Quantization (PTQ) para LLM, enfocándose en la transición de FP16 a INT4. Se examinan los fundamentos de la cuantización, el impacto en la perplejidad del modelo y la reducción de latencia.
La barrera principal para el despliegue de Large Language Models (LLMs) como LLaMA-65B o Falcon-40B es el requisito de memoria VRAM. Un modelo de 65 mil millones de parámetros en precisión estándar (FP16, 16-bit floating point) requiere aproximadamente 130 GB de memoria solo para almacenar los pesos, excluyendo la sobrecarga del contexto (KV cache) y las activaciones. Esto hace inviable su ejecución en hardware de consumo o servidores single-GPU.
La cuantización (quantization) aborda este problema reduciendo la precisión numérica de los pesos. El estado del arte se ha desplazado del Quantization Aware Training (QAT), costoso e impracticable para modelos fundacionales ya entrenados, hacia el Post-Training Quantization (PTQ). Algoritmos recientes como GPTQ (Generative Pre-trained Transformer Quantization) permiten comprimir pesos a 4 bits con una degradación mínima de la perplejidad, permitiendo la inferencia de modelos masivos en tarjetas como la RTX 3090 o 4090.
Fundamentos matemáticos
La cuantización busca mapear un rango de valores de punto flotante $x_{f}$ (FP16) a un rango entero $x_{q}$ de menor precisión (INT8 o INT4). El enfoque más común es la cuantización uniforme afín (asimétrica).
La operación de cuantización se define como:
$$x_{q} = \text{clamp}\left( \text{round}\left( \frac{x_{f}}{S} + Z \right), q_{min}, q_{max} \right)$$
Donde:
- $S$ (Scale factor) es un escalar positivo de punto flotante.
- $Z$ (Zero-point) es un valor entero que mapea el cero real al dominio cuantizado.
- $[q_{min}, q_{max}]$ define el rango dinámico del tipo de dato objetivo (ej. $[0, 15]$ para INT4 sin signo).
El siguiente diagrama ilustra el proceso de conversión entre precisiones:
La recuperación aproximada del valor (decuantización) se realiza mediante:
$$\hat{x}_{f} = S \cdot (x_{q} - Z)$$
Cálculo de parámetros
Para un tensor de pesos $W$, los parámetros se calculan basándose en sus valores mínimos y máximos observados ($\alpha = \min(W), \beta = \max(W)$):
$$S = \frac{\beta - \alpha}{q_{max} - q_{min}}$$$$Z = \text{round}\left( q_{min} - \frac{\alpha}{S} \right)$$
Optimización del Error (GPTQ)
La cuantización simple (Round-to-Nearest o RTN) introduce un error de cuantización $E = || W - \hat{W} ||^2$. En LLMs, RTN degrada severamente la precisión en 4 bits. GPTQ minimiza el error de salida capa por capa utilizando información de segundo orden (la matriz Hessiana inversa $H^{-1}$):
$$\text{argmin}_{\hat{W}} || WX - \hat{W}X ||^2_2$$
GPTQ actualiza los pesos restantes en una fila basándose en el error introducido por la cuantización de los pesos anteriores, ajustando la compensación mediante la Hessiana inversa.
Implementación práctica
A continuación se ilustra una implementación de referencia en Python para una cuantización simétrica básica (donde $Z=0$), común en implementaciones on-the-fly para reducir la latencia de kernel.
import torch
def quantize_tensor_sym(w, n_bits=4, group_size=128):
"""
Simulación de cuantización simétrica block-wise (simulando GPTQ/GGML logic).
w: Tensor de pesos original (FP16)
n_bits: Bits objetivo (INT4)
group_size: Tamaño del bloque para calcular la escala (granularidad).
"""
# 1. Definir rango máximo entero
max_int = 2**(n_bits - 1) - 1
# 2. Reshape para cuantización por grupos (Block-wise quantization)
# Esto mejora la precisión al tener escalas locales
orig_shape = w.shape
w_reshaped = w.reshape(-1, group_size)
# 3. Calcular escala por bloque: max(abs(w)) / max_int
max_val = w_reshaped.abs().amax(dim=1, keepdim=True)
scale = max_val / max_int
# 4. Cuantizar
# w_q = round(w / scale)
w_q = (w_reshaped / scale).round().clamp(-max_int, max_int)
# 5. Decuantizar (Simulación de lo que ocurre en inferencia)
w_deq = w_q * scale
return w_deq.reshape(orig_shape)
# Ejemplo de uso y medición de error L2
torch.manual_seed(42)
weights = torch.randn(4096, 4096, dtype=torch.float16)
# Cuantización a 4 bits
weights_int4_sim = quantize_tensor_sym(weights, n_bits=4)
# Cálculo de error de reconstrucción
error = torch.linalg.norm(weights - weights_int4_sim) / torch.linalg.norm(weights)
print(f"Error relativo L2 (INT4): {error:.4f}")
En entornos de producción reales (usando librerías como AutoGPTQ o bitsandbytes), la carga del modelo omite la conversión manual y carga directamente los safetensors pre-cuantizados, descomprimiéndolos a FP16 solo en el momento del cálculo matricial dentro del kernel CUDA.
Pipeline de Block-wise Quantization con grupos y escalas locales
Análisis de comportamiento
Estabilidad Numérica y Outliers
La principal anomalía en la cuantización de Transformers son los "outliers" de magnitud extrema en las activaciones (feature emergent descrita por Dettmers et al. en LLM.int8()).
- INT8: Generalmente estable. La pérdida de precisión es despreciable (< 1% degradación en zero-shot).
- INT4: La cuantización RTN (Round-to-Nearest) rompe el modelo. Se requiere GPTQ o AWQ para mantener la coherencia. Los outliers deben manejarse preservándolos en FP16 o utilizando agrupamiento (group-wise quantization) para que no distorsionen la escala $S$ de todo el tensor.
Latencia: Compute-bound vs Memory-bound
La inferencia de LLMs en batch size bajo (1, uso típico de chat) es un proceso memory-bound. La velocidad está limitada por qué tan rápido se pueden mover los pesos de la VRAM a los registros del chip.
- Al reducir los pesos de 16 bits a 4 bits, se reduce el tráfico de memoria en un factor de 4x.
- Aunque existe un overhead computacional para decuantizar (INT4 $\rightarrow$ FP16) antes de la operación matricial, la ganancia en ancho de banda supera el costo de cómputo, resultando en una generación de tokens más rápida.
Consumo de VRAM
Relación directa con el tamaño del modelo:
- FP16: 2 bytes por parámetro.
- INT8: 1 byte por parámetro.
- INT4: 0.5 bytes por parámetro.
Un modelo LLaMA-7B:
- FP16: ~13.5 GB
- INT4: ~3.8 GB (Entra holgadamente en GPUs de 6GB/8GB).
La reducción en consumo de VRAM es directamente proporcional a la reducción de bits por parámetro:
Comparativas
Tabla comparativa de rendimiento en una NVIDIA RTX 3090 (24GB VRAM) sobre el modelo LLaMA-30B.
| Métrica | FP16 (Baseline) | INT8 (bitsandbytes) |
INT4 (GPTQ) |
|---|---|---|---|
| VRAM Ocupada | ~60 GB (OOM)* | ~32 GB (OOM)* | ~17.5 GB |
| Perplejidad (WikiText2) | 4.10 | 4.12 | 4.25 |
| Latencia (token/s) | N/A | N/A | ~18-22 t/s |
*OOM: Out of Memory. El modelo no carga en una sola GPU sin cuantización.
Se observa que INT4 permite ejecutar un modelo que previamente requería múltiples GPUs de grado servidor (A100 80GB o 2x3090 NVLink), con una penalización de perplejidad marginal para tareas generales.
Limitaciones y casos donde no conviene usarlo
- Degradación en Razonamiento Complejo: Aunque la perplejidad se mantenga baja, la cuantización agresiva (INT4 o INT3) afecta desproporcionadamente a la capacidad de razonamiento "paso a paso" (Chain-of-Thought) y a la codificación. Si la precisión lógica es crítica, INT4 puede introducir alucinaciones sutiles.
- Overhead de CPU/Kernel: Si la implementación del kernel de decuantización no está optimizada (ej. implementaciones naive en Python puro), la latencia aumentará en lugar de disminuir. Se depende de kernels Triton o CUDA altamente optimizados.
- No apto para Entrenamiento: La cuantización PTQ descrita es solo para inferencia. Entrenar (Fine-tuning) sobre pesos INT4 requiere técnicas como LoRA (Low-Rank Adaptation) aplicado sobre el modelo base congelado (QLoRA), pero no se puede hacer backpropagation estándar directamente sobre pesos INT4 debido a la naturaleza discreta de la función de redondeo (gradiente cero casi en todas partes).