
Mixture of Experts

Scaling a dense neural network means every parameter activates for every input. A 70B-parameter model runs all 70B parameters on every token. Mixture of Experts (MoE) breaks this constraint: the model is divided into many specialized sub-networks ("experts"), and a learned router activates only a small subset per token. This is how GPT-4, Mixtral, and Gemini achieve model capacity far beyond what the inference cost would suggest.

Theory

[Interactive routing diagram: each token of "Paris is the capital of France." is routed to 2 of 8 experts (func, entity, noun, det, verb, geo, adj, bound); "Paris" and "France" share the same experts. 25% of experts active per token, 75% compute saved.]

The diagram above shows this routing in action: each token activates a different subset of experts, and related tokens ("Paris", "France") tend to land on the same ones. The total parameter count is large, but the active parameter count per token is small — the capacity of a 200B+ parameter model at the inference cost of a 20B one.

The MoE Layer

A standard Transformer replaces its dense Feed-Forward Network (FFN) sublayer with an MoE layer containing E expert networks \{f_1, f_2, \ldots, f_E\}. For each token representation x \in \mathbb{R}^d, a router selects K experts and weights their outputs:

\text{MoE}(x) = \sum_{i=1}^{K} g_i(x) \cdot f_i(x)

where g_i(x) is the gating weight for the i-th selected expert. Only the top-K gates are nonzero (typically K = 2).

Routing Mechanisms

Top-K Softmax Gating (the standard): compute a score for each expert, take the top K, normalize:

h(x) = \text{Softmax}(W_g \cdot x) \in \mathbb{R}^E

\mathcal{T}(x) = \text{top-}K \text{ indices of } h(x)

g_i(x) = \begin{cases} \dfrac{h_i(x)}{\sum_{j \in \mathcal{T}(x)} h_j(x)} & i \in \mathcal{T}(x) \\ 0 & \text{otherwise} \end{cases}
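The gating and combine steps above can be sketched for a single token in NumPy. This is a minimal illustration, not a production kernel — the function names and toy expert setup are assumptions for the example:

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def moe_forward(x, W_g, experts, K=2):
    """One MoE layer applied to a single token x.

    W_g: router weights, shape (E, d).
    experts: list of E callables, each mapping R^d -> R^d.
    """
    h = softmax(W_g @ x)                  # router scores h(x), shape (E,)
    topk = np.argsort(h)[-K:]             # indices T(x) of the top-K experts
    g = h[topk] / h[topk].sum()           # renormalized gating weights g_i(x)
    return sum(w * experts[i](x) for w, i in zip(g, topk))
```

Note that the E - K unselected experts are never evaluated — that skipped computation is the entire sparsity win.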

The load balancing problem: without additional constraints, a router quickly learns to send nearly all tokens to one or two experts, leaving the rest untrained and wasted. To prevent this, training adds an auxiliary load balancing loss:

\mathcal{L}_{\text{aux}} = \alpha \cdot E \sum_{i=1}^{E} f_i \cdot p_i

where f_i = fraction of tokens routed to expert i, p_i = mean routing probability assigned to expert i. Minimizing this encourages uniform routing. \alpha = 0.01 is typical.
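The loss is easy to compute from a batch of router outputs. A sketch, with the simplification that expert assignments are passed as a flat array of chosen indices (for top-K routing, each token contributes K entries):

```python
import numpy as np

def load_balancing_loss(router_probs, expert_assignments, E, alpha=0.01):
    """Auxiliary load balancing loss: alpha * E * sum_i f_i * p_i.

    router_probs: (num_tokens, E) softmax outputs of the router.
    expert_assignments: flat int array of chosen expert indices.
    """
    # f_i: fraction of assignments that went to expert i
    f = np.bincount(expert_assignments, minlength=E) / len(expert_assignments)
    # p_i: mean routing probability mass assigned to expert i
    p = router_probs.mean(axis=0)
    return alpha * E * float((f * p).sum())
```

With perfectly uniform routing (f_i = p_i = 1/E), the loss takes its minimum value alpha; a collapsed router scores up to alpha * E, so the gradient pushes toward balance.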

💡Intuition

Think of each expert as a specialist. A generalist doctor can handle everything, but a hospital with specialists handles complex cases faster. The router acts like a triage system: it reads the "symptom" (token) and sends it to the right specialist(s). The key engineering challenge is keeping all specialists busy — if only one specialist gets all the patients, you haven't gained anything.

Expert Capacity and Token Dropping

In practice, each expert has a fixed capacity — the maximum number of tokens it can process in a batch. If more tokens are routed to an expert than its capacity, the excess tokens are dropped (passed through unchanged via the residual connection).

C = \left\lfloor \frac{B \cdot T}{E} \cdot \text{capacity\_factor} \right\rfloor

where B = batch size, T = sequence length. A capacity factor of 1.25 gives experts 25% overhead above uniform routing. Too small: tokens dropped. Too large: memory wasted on padding.
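The capacity formula is a one-liner; a quick sketch with illustrative batch numbers (the function name is ours, not from any framework):

```python
import math

def expert_capacity(batch_size, seq_len, num_experts, capacity_factor=1.25):
    """C = floor(B * T / E * capacity_factor): max tokens one expert
    accepts per batch before the excess is dropped to the residual path."""
    return math.floor(batch_size * seq_len / num_experts * capacity_factor)

# e.g. B=8, T=2048, E=8: uniform share is 2048 tokens per expert,
# and a 1.25 factor buys 25% headroom on top of that
c = expert_capacity(8, 2048, 8)   # 2560
```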

Token dropping is a hardware constraint masquerading as a design choice. GPUs execute synchronously — all experts must process the same number of tokens per batch for efficient parallelism. A variable-length queue per expert would stall the entire batch waiting for overloaded experts to finish. The capacity buffer accepts some token loss (via the residual connection bypass) in exchange for keeping all experts on the same clock. It's an engineering compromise forced by SIMD hardware, not a theoretical preference.

Sparse vs. Dense MoE

| | Dense FFN | Sparse MoE |
| --- | --- | --- |
| Parameters | N | N × E |
| Activated params / token | N | ≈ N × K |
| Training compute / token | O(N) | O(N × K) |
| Memory | N | N × E |
| Communication | None | Expert parallelism overhead |

With E = 8 experts and K = 2, an MoE model uses 8× more parameters but only 2× the compute per token. This is the core efficiency gain: you get a model that "remembers" 8× more, but pays for only 2/8 of it at inference time.
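The arithmetic behind that claim, written out (N normalized to the size of one expert-sized dense FFN):

```python
E, K = 8, 2            # experts per layer, experts active per token
N = 1.0                # parameters of one expert-sized dense FFN (normalized)

total_params = N * E   # 8x the memory footprint of the dense FFN
active_params = N * K  # but only 2x the per-token compute

# only K/E = a quarter of the MoE parameters are touched per token
active_fraction = active_params / total_params   # 0.25
```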

Walkthrough

Tracing a Token Through a Mixtral MoE Layer

Take the token "Paris" entering the MoE FFN sublayer in layer 12 of Mixtral 8x7B. Here is the exact computation:

Step 1 — Router scores. The router projects the token's 4096-dim representation with W_g (shape 4096 × 8), producing 8 raw scores. After softmax: [0.02, 0.08, 0.31, 0.04, 0.44, 0.06, 0.03, 0.02]. Experts 4 and 2 have the highest scores.

Step 2 — Top-2 selection. Only experts 4 and 2 are selected. Their normalized gating weights: g_4 = 0.44/(0.44+0.31) = 0.587, g_2 = 0.31/(0.44+0.31) = 0.413.

Step 3 — Expert computation. Experts 4 and 2 each independently run a standard FFN (two linear layers with SiLU activation, hidden dim 14336):

out_4 = FFN_4("Paris" representation)   # shape: (4096,)
out_2 = FFN_2("Paris" representation)   # shape: (4096,)

Both computations happen in parallel (on the same GPU or on separate GPUs with expert parallelism).

Step 4 — Combine. The final MoE output is the weighted sum:

moe_out = 0.587 × out_4 + 0.413 × out_2   # shape: (4096,)

This replaces what a standard dense FFN would compute, and is added to the residual stream as usual.

What this achieves. The model has 8 × 14336-dim FFN experts, but for this token it only ran 2. The total parameters engaged for "Paris" are ~13B out of 47B total — the full capacity of a large model at roughly 28% of the inference cost.
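Steps 1–4 can be reproduced numerically with the router scores above hard-coded. This is a sketch: the dimensions are shrunk from 4096, the weights are random, and a plain two-layer SiLU FFN stands in for Mixtral's actual gated (SwiGLU) experts:

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

# Step 1 — router scores from the walkthrough (already post-softmax)
scores = np.array([0.02, 0.08, 0.31, 0.04, 0.44, 0.06, 0.03, 0.02])

# Step 2 — top-2 selection and renormalization (0-indexed experts)
topk = np.argsort(scores)[-2:]           # experts 2 and 4
g = scores[topk] / scores[topk].sum()    # [0.413..., 0.586...]

# Step 3 — stand-in expert FFNs with random weights
# (real Mixtral experts map 4096 -> 14336 -> 4096; shrunk here for brevity)
rng = np.random.default_rng(0)
d, hidden = 16, 64
x = rng.normal(size=d)                   # the "Paris" representation

def make_ffn():
    W1 = rng.normal(size=(hidden, d))
    W2 = rng.normal(size=(d, hidden))
    return lambda v: W2 @ silu(W1 @ v)

ffns = [make_ffn() for _ in range(8)]

# Step 4 — weighted combine of the two selected experts
moe_out = sum(w * ffns[i](x) for w, i in zip(g, topk))   # shape (d,)
```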

Analysis & Evaluation

Where Your Intuition Breaks

"Sparse MoE is strictly more efficient than a dense model." At inference, yes — only a fraction of parameters activate per token. During training, no: all experts receive gradients, the load balancing auxiliary loss adds overhead, and expert parallelism requires cross-device communication. A sparse MoE with 64 experts trains meaningfully slower than a dense model with the same active parameter count. MoE is an inference-time efficiency win achieved at training-time cost — the tradeoff is favorable only at scale.

Real-World MoE Architectures

Mixtral 8x7B (Mistral AI, 2024): 8 experts per FFN layer, top-2 routing (K = 2). Despite "8×7B" in the name, the activated parameter count per token is ~13B (shared attention params + 2 of 8 expert FFNs). On benchmarks it matches or exceeds Llama 2 70B at roughly half the inference cost.

GPT-4: widely believed to be an MoE architecture based on inference characteristics, though OpenAI has not published architectural details.

Gemini 1.5: Google confirmed use of MoE for achieving 1M token context at practical inference costs.

Grok-1 (xAI, 2024): 314B total parameters, 25% activated per token — publicly released MoE model.

DeepSeek-V2/V3 (2024): fine-grained MoE with many small routed experts (160 in V2, 256 in V3) and only a handful (6–8) activated per token, plus shared experts that every token uses. Demonstrated that smaller, more numerous experts can outperform coarser routing schemes.

Why MoE is Hard to Train

  1. Router collapse: without load balancing, routing degenerates to a few experts. The auxiliary loss helps but doesn't fully solve it.
  2. Expert specialization is slow: experts don't specialize until late in training; for the first few thousand steps, an MoE can perform worse than a comparable dense model.
  3. Communication overhead: in distributed training, each expert may live on a different GPU. Token routing requires all-to-all communication that dense models don't need. Expert parallelism is a distinct axis of parallelism.
  4. Fine-tuning instability: MoE models are more sensitive to fine-tuning than dense models — the routing distribution can shift dramatically with small datasets, causing expert collapse.

MoE vs. Dense: When to Use Each

| Scenario | Prefer |
| --- | --- |
| Fixed inference hardware | Dense — simpler, no routing overhead |
| Very large parameter budget, limited compute | MoE — 3–5× more params for same FLOPs |
| Multilingual or multi-domain tasks | MoE — experts may specialize by domain |
| Small models (under 7B) | Dense — MoE overhead not worth it |
| Serving latency critical | Dense — no token dropping, simpler batching |
| Fine-tuning on small datasets | Dense — more stable |

The industry rule of thumb: MoE starts paying off above ~30B total parameters. Below that, the load balancing and communication overhead costs more than the extra capacity gains.

Most practitioners interact with MoE models through APIs (GPT-4, Gemini) or by loading pre-trained checkpoints (Mixtral, Grok-1). You generally don't implement MoE from scratch.

Loading Mixtral with transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    device_map="auto",        # spread experts across GPUs automatically
    load_in_4bit=True,        # quantize to fit in consumer hardware
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
```

🚀Production

Memory requirements:

  • Mixtral 8x7B fp16: ~93GB — needs 2× A100 80GB or 4× A6000
  • Mixtral 8x7B 4-bit quantized: ~26GB — fits on 2× RTX 3090
  • Use device_map="auto" — transformers handles expert placement across GPUs

Fine-tuning MoE:

  • Prefer LoRA — only fine-tunes a small fraction of parameters, less routing instability
  • Use small learning rates (1e-5 or lower) — MoE is more sensitive than dense models
  • Monitor routing entropy during fine-tuning — if it drops sharply, reduce learning rate
  • Avoid fine-tuning on datasets smaller than ~10K examples
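One way to monitor routing entropy, as the checklist above suggests — a sketch of a metric you would log during fine-tuning (the function and normalization are our choices, not a library API):

```python
import numpy as np

def routing_entropy(router_probs):
    """Mean per-token entropy of the routing distribution, normalized to [0, 1].

    router_probs: (num_tokens, E) softmax outputs of the router.
    ~1.0 means near-uniform routing; a sharp drop toward 0 signals
    expert collapse, and is a cue to reduce the learning rate.
    """
    p = np.clip(router_probs, 1e-12, 1.0)        # guard log(0)
    ent = -(p * np.log(p)).sum(axis=1).mean()    # mean Shannon entropy per token
    return float(ent / np.log(router_probs.shape[1]))   # divide by max entropy log(E)
```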

When building new architectures:

  • MoE is production-ready in transformers, megablocks, and switch-transformers
  • Start with E = 8, K = 2, capacity_factor = 1.25 — the Mixtral defaults are well-validated
  • Load balancing loss coefficient \alpha = 0.01 is the most commonly reported value
  • Expect ~15–20% token dropping in early training; it should decrease as routing stabilizes
