Mecanismo de Self-Attention: Formulación Matemática y Arquitectura Transformer
Análisis del mecanismo de Scaled Dot-Product Attention propuesto en "Attention Is All You Need". Se detalla el cálculo de matrices Key-Query-Value, la eliminación de la recurrencia para paralelización masiva y el impacto en la complejidad computacional respecto a la longitud de la secuencia.
El procesamiento de secuencias (Seq2Seq) ha estado dominado históricamente por arquitecturas recurrentes (RNN, LSTM, GRU). Estos modelos operan bajo una restricción secuencial inherente: el estado oculto $h_t$ es función de $h_{t-1}$ y la entrada $x_t$. Esta dependencia temporal impide la paralelización del entrenamiento, convirtiendo a la longitud de la secuencia en un cuello de botella computacional.
El mecanismo de Self-Attention, introducido centralmente en la arquitectura Transformer (Vaswani et al., 2017), propone prescindir totalmente de la recurrencia y las convoluciones. El objetivo es modelar dependencias globales entre entradas y salidas independientemente de su distancia en la secuencia, reduciendo el path length de la propagación de la señal a $O(1)$.
En el ecosistema actual, esta arquitectura está desplazando a las LSTM con atención (Bahdanau) en tareas de traducción automática (NMT) y constituency parsing, estableciendo un nuevo estado del arte en datasets como WMT 2014.
Fundamentos matemáticos
El núcleo del modelo es el cálculo de la atención como una función de consulta (Query) sobre un conjunto de pares clave-valor (Key-Value).
Dada una entrada $X$, se proyecta linealmente en tres matrices: Queries ($Q$), Keys ($K$) y Values ($V$). La función de atención, denominada Scaled Dot-Product Attention, se define formalmente como:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Donde:
- $Q, K \in \mathbb{R}^{n \times d_k}$ y $V \in \mathbb{R}^{n \times d_v}$.
- $n$ es la longitud de la secuencia.
- $\sqrt{d_k}$ es el factor de escalado.
El siguiente diagrama ilustra el flujo de operaciones del mecanismo Scaled Dot-Product Attention:
Justificación del escalado:
El término $\frac{1}{\sqrt{d_k}}$ es crítico. Para valores grandes de $d_k$, el producto punto $QK^T$ crece en magnitud, empujando a la función Softmax hacia regiones donde los gradientes son extremadamente pequeños (vanishing gradients). El escalado normaliza la varianza del producto punto.
Multi-Head Attention:
Para permitir que el modelo atienda conjuntamente a información de diferentes subespacios de representación, se computan $h$ atenciones en paralelo:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
Donde $W^O$ es una matriz de proyección final. Esto permite al modelo capturar distintos tipos de relaciones (ej. sintácticas y semánticas) simultáneamente.
La arquitectura Multi-Head Attention permite procesar información en paralelo a través de múltiples subespacios:
Implementación práctica
A continuación, se presenta una implementación vectorizada del mecanismo de atención escalada utilizando PyTorch. Se asume que las proyecciones lineales $W_Q, W_K, W_V$ ya han sido aplicadas antes de llamar a esta función.
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
"""
Calcula la atención escalada por producto punto.
Args:
query: Tensor de forma [batch_size, heads, len_q, d_k]
key: Tensor de forma [batch_size, heads, len_k, d_k]
value: Tensor de forma [batch_size, heads, len_v, d_v]
mask: Máscara opcional para ocultar posiciones (ej. padding o look-ahead)
"""
d_k = query.size(-1)
# 1. Producto punto entre Query y Key traspuesta
# scores shape: [batch_size, heads, len_q, len_k]
scores = torch.matmul(query, key.transpose(-2, -1))
# 2. Escalado para estabilidad de gradientes
scores = scores / math.sqrt(d_k)
# 3. Aplicación de máscara (si existe)
# Se reemplazan valores por -infinito antes del softmax
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 4. Softmax para obtener probabilidades de atención
p_attn = F.softmax(scores, dim=-1)
# Aplicación de dropout a los pesos de atención (regularización)
if dropout is not None:
p_attn = dropout(p_attn)
# 5. Suma ponderada de los Values
return torch.matmul(p_attn, value), p_attn
Notas de implementación:
- La operación
matmules eficiente en GPU. - La máscara es fundamental en el decodificador para preservar la propiedad autoregresiva (evitar que la posición $t$ atienda a $t+1$).
Análisis de comportamiento
Al desplegar este mecanismo en entrenamiento, se observan comportamientos distintivos frente a las RNN:
- Paralelización: A diferencia de una RNN que requiere $O(n)$ pasos secuenciales, el mecanismo de atención procesa toda la secuencia en $O(1)$ pasos secuenciales (aunque con un coste total de operaciones mayor por capa). Esto satura mejor las GPUs modernas (NVIDIA Pascal/Volta).
- Captura de dependencias: La distancia máxima del camino (path length) entre cualquier par de posiciones en la entrada y salida es 1. Esto facilita el flujo de gradientes, permitiendo aprender dependencias de muy largo alcance que LSTM suele olvidar.
- Positional Encoding: Al ser el mecanismo invariante a la permutación (no hay noción intrínseca de orden), es obligatorio inyectar información posicional. El uso de funciones sinusoidales:
$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
Ha demostrado ser estable y permite al modelo extrapolar a longitudes de secuencia mayores que las vistas en entrenamiento.
Comparativas o referencias técnicas
Comparando la capa de Self-Attention contra capas recurrentes y convolucionales en términos de complejidad por capa y operaciones secuenciales:
| Tipo de Capa | Complejidad por Capa | Operaciones Secuenciales | Path Length Máximo |
|---|---|---|---|
| Self-Attention | $O(n^2 \cdot d)$ | $O(1)$ | $O(1)$ |
| Recurrent (RNN) | $O(n \cdot d^2)$ | $O(n)$ | $O(n)$ |
| Convolutional | $O(k \cdot n \cdot d^2)$ | $O(1)$ | $O(\log_k(n))$ |
- $n$: longitud de secuencia
- $d$: dimensión de representación
- $k$: tamaño del kernel
En tareas de traducción (WMT 2014 English-to-German), el Transformer base alcanza 27.3 BLEU, superando a los mejores modelos de ensamble previos, con un coste de entrenamiento (FLOPs) significativamente menor ($3.2 \times 10^{18}$ vs $>10^{19}$ para arquitecturas recurrentes como GNMT).
La siguiente visualización compara las características clave de cada tipo de capa, destacando las ventajas del Self-Attention en paralelización y path length:
Limitaciones y casos donde no conviene usarlo
A pesar de su eficiencia en entrenamiento, el mecanismo presenta limitaciones estructurales:
- Complejidad Cuadrática: El coste computacional y de memoria de la matriz de atención es $O(n^2)$. Para secuencias muy largas (ej. documentos completos, genómica), el consumo de VRAM se vuelve prohibitivo rápidamente comparado con la linealidad $O(n)$ de las RNN.
- Inferencia: Durante la generación de texto (decoding), el modelo sigue siendo autoregresivo. A diferencia del entrenamiento, la inferencia no es completamente paralela y el recalculo de las claves/valores para pasos anteriores (o el uso de caché KV) introduce latencia y consumo de memoria.
- Data Hunger: La falta de sesgos inductivos fuertes (como la localidad en CNNs o la temporalidad en RNNs) implica que el modelo requiere grandes volúmenes de datos para aprender patrones que otras arquitecturas asumen por diseño. En datasets pequeños, el riesgo de overfitting es alto si no se regulariza agresivamente.