Mixture of Experts: Sparse Activation y Estrategias de Routing en LLMs

Análisis de la arquitectura Mixture of Experts (MoE), enfocado en el desacople del coste de inferencia respecto al conteo total de parámetros mediante "conditional computation".

El escalado de modelos densos (Dense Transformers) presenta una relación lineal directa entre el rendimiento y el coste computacional por token: para aumentar la capacidad del modelo, se debe incrementar el número de parámetros, lo que obliga a activar la totalidad de los pesos en cada paso de la inferencia. Esto genera un cuello de botella en latencia y FLOPs, limitando la viabilidad de modelos superiores a los 100B parámetros en entornos de producción sensibles al tiempo de respuesta.

Mixture of Experts (MoE) reintroduce el concepto de computación condicional. En lugar de una red monolítica, el modelo se compone de múltiples sub-redes ("expertos") especializadas. Para cada token de entrada, solo se activa un subconjunto de estos expertos.

En el estado del arte, arquitecturas como Mixtral 8x7B (basada en Mistral) y las especulaciones sobre GPT-4 han validado el uso de MoE sparse para obtener el rendimiento de un modelo denso masivo (e.g., Llama 2 70B) con una fracción del coste de inferencia, aunque con requerimientos de VRAM significativos debido a la necesidad de cargar todos los expertos en memoria.


Fundamentos matemáticos

La arquitectura MoE reemplaza las capas Feed-Forward (FFN) densas tradicionales por una capa MoE. Esta capa consta de un conjunto de $N$ redes expertas ${E_1, E_2, ..., E_N}$ y una red de compuerta o Gating Network $G$.

Para una entrada $x$, la salida $y$ de la capa MoE es la suma ponderada de las salidas de los expertos activados:

$$y = \sum_{i=1}^{N} G(x)_i E_i(x)$$

Donde $G(x)$ es un vector disperso (sparse vector) determinado por la función de routing. Generalmente, se utiliza un mecanismo de Top-k Gating, donde solo los $k$ expertos con mayor puntuación procesan el token. Si definimos $H(x) = x \cdot W_g$ como los logits del router (donde $W_g$ son los pesos entrenables):

$$G(x) = \text{Softmax}(\text{TopK}(H(x), k))$$

En esta formulación, $\text{TopK}$ mantiene los valores de los $k$ índices mayores y establece el resto en $-\infty$ antes del Softmax, asegurando que la contribución de los expertos no seleccionados sea estrictamente cero.

El siguiente diagrama ilustra cómo un token atraviesa la capa MoE: el Gating Network evalúa todos los expertos pero solo activa los top-k (k=2 en este ejemplo), combinando sus salidas mediante suma ponderada.

Load Balancing Loss

Un problema crítico es el "colapso de expertos", donde el Gating Network converge a usar siempre los mismos pocos expertos, desperdiciando la capacidad del resto. Para mitigar esto, se introduce una pérdida auxiliar durante el entrenamiento:

$$L_{aux} = N \sum_{i=1}^{N} f_i \cdot P_i$$

Donde $f_i$ es la fracción de tokens asignados al experto $i$, y $P_i$ es la probabilidad promedio de que el router seleccione al experto $i$. Esta pérdida penaliza la distribución desigual de la carga.

La visualización muestra el problema de colapso de expertos (izquierda), donde el router concentra el tráfico en pocos expertos mientras ignora al resto, versus el estado balanceado (derecha) logrado mediante la pérdida auxiliar $L_{aux}$.


Implementación práctica

A continuación se presenta una implementación simplificada en PyTorch de una capa MoE con Top-2 Gating, similar a la arquitectura utilizada en Mixtral.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Gating Network (Router)
        self.gate = nn.Linear(input_dim, num_experts)
        
        # Expertos: Lista de redes Feed-Forward independientes
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_dim)
        batch_size, seq_len, _ = x.shape
        x_flat = x.view(-1, x.size(-1)) # Flatten para procesar tokens individualmente
        
        # Calcular logits del router
        router_logits = self.gate(x_flat) # (total_tokens, num_experts)
        
        # Seleccionar los top-k expertos y sus pesos
        routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
        
        # Normalizar pesos (Softmax sobre los top-k)
        routing_weights = F.softmax(routing_weights, dim=-1)
        
        # Inicializar tensor de salida
        final_output = torch.zeros_like(x_flat)
        
        # Procesamiento (Iteración ingenua para claridad, optimizable con scatter/gather)
        for i in range(self.top_k):
            expert_idx = selected_experts[:, i]
            weight = routing_weights[:, i].unsqueeze(1)
            
            # En una implementación real vectorizada, se usarían máscaras
            # para enviar lotes de tokens a cada experto en paralelo.
            for j in range(self.num_experts):
                mask = (expert_idx == j)
                if mask.any():
                    tokens_for_expert = x_flat[mask]
                    expert_out = self.experts[j](tokens_for_expert)
                    final_output[mask] += weight[mask] * expert_out
        
        return final_output.view(batch_size, seq_len, -1)

# Ejemplo de instanciación
# 8 expertos, activando 2 por token (Configuración tipo Mixtral)
moe_layer = MoELayer(input_dim=4096, hidden_dim=14336, num_experts=8, top_k=2)


Análisis de comportamiento

Al desplegar modelos MoE en producción, se observan comportamientos distintivos respecto a modelos densos:

  • Disociación Memoria-Cómputo: Un modelo como Mixtral 8x7B tiene ~47B de parámetros totales, pero solo utiliza ~13B parámetros activos por token (2 expertos activos). Esto resulta en una velocidad de inferencia (tokens/segundo) comparable a un modelo de 12B-14B, no a uno de 47B.
  • Ancho de Banda de Memoria (Memory Bandwidth Bound): Aunque los FLOPs son bajos, el modelo debe cargar en VRAM los parámetros de todos los expertos. En inferencia con batch size bajo (e.g., 1 usuario), el cuello de botella se traslada de la capacidad de cómputo a la velocidad de lectura de la VRAM, ya que para cada token se deben acceder a diferentes regiones de memoria de forma no contigua.
  • Estabilidad en Entrenamiento: Los MoE son notoriamente inestables en etapas tempranas. Sin un auxiliary loss bien calibrado, es común ver divergencias o que el router ignore a ciertos expertos permanentemente.

Comparativas o referencias técnicas

Comparativa técnica de arquitecturas relevantes a Enero 2024:

Métrica Llama 2 70B (Denso) Mixtral 8x7B (MoE)
Parámetros Totales ~70B ~46.7B
Parámetros Activos/Token ~70B ~12.9B
Ventana de Contexto 4k tokens 32k tokens
Requisitos VRAM (FP16) ~140 GB ~90 GB
Inferencia (Tokens/s) Base ~4x - 6x vs Llama 2 70B
Rendimiento (MMLU) ~68.9% ~70.6%

Mixtral demuestra que es posible superar el rendimiento de modelos densos con una fracción del coste computacional activo, aunque la huella de memoria sigue siendo alta.


Limitaciones y casos donde no conviene usarlo

El uso de MoE no es una solución universal y presenta desventajas técnicas específicas:

  1. Despliegue en Edge Devices: Debido a que todos los parámetros deben residir en memoria (o intercambiarse rápidamente), los MoE son inviables para dispositivos con RAM limitada, incluso si su coste de cómputo es bajo. Un modelo denso pequeño (e.g., 7B) es preferible en estos entornos.
  2. Complejidad de Infraestructura: El entrenamiento distribuido de MoE requiere estrategias avanzadas de paralelismo (Expert Parallelism), donde diferentes expertos se alojan en diferentes GPUs. Esto aumenta la sobrecarga de comunicación (all-to-all communication) entre nodos.
  3. Fine-tuning (LoRA): Aunque es posible realizar fine-tuning con técnicas como LoRA, la dispersión de los adaptadores en múltiples expertos complica la convergencia y requiere ajustes específicos en los hiperparámetros de optimización (learning rate, rank) respecto a modelos densos.
  4. Latencia en Batch Size alto: En escenarios de alto throughput, si los tokens de un mismo batch activan expertos muy diversos, se pierde la eficiencia del caché y la paralelización matricial, degradando el rendimiento.