TLDR
- Better models for less compute (larger capacity models under same compute budget)
- Best for batched inference (server or offline)
- Balanced routing is important; otherwise you get stragglers, memory imbalance, and dead experts
- Most models use MoE layers in place of the FFNs (token-level feed-forward networks)
- Mixture of Experts Explained
- [M2L 2024] Mixture of Experts - Diego de Las Casas (Mistral AI) - YouTube - Great lecture from the Mistral team that goes into detail; most of the notes here are from that talk
Routing
Problem: Load Balancing
If the token assignments are not balanced across experts, you end up with stragglers and out-of-memory issues, as some experts hold and process far more inputs than others.
Dead experts (experts that receive no tokens) get no gradients, so the router stops routing to them
Auxiliary Load Balancing Loss
Penalize routing function for producing imbalanced weights
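One common form of this penalty is the Switch Transformer style loss, `num_experts * sum_i(f_i * P_i)`, where `f_i` is the fraction of tokens sent to expert i and `P_i` is the mean router probability for expert i. The sketch below assumes that formulation and top-1 routing; the function name and the toy inputs are made up for illustration, and the talk may use a different variant.

```python
def load_balancing_loss(router_probs, assignments, num_experts):
    """Switch-style auxiliary loss: num_experts * sum_i f_i * P_i (assumed form).

    router_probs: per-token list of softmax probabilities over experts.
    assignments:  per-token index of the expert the token was sent to (top-1).
    """
    num_tokens = len(router_probs)
    # f_i: fraction of tokens dispatched to expert i
    f = [0.0] * num_experts
    for e in assignments:
        f[e] += 1.0 / num_tokens
    # P_i: mean router probability assigned to expert i
    p = [sum(tok[i] for tok in router_probs) / num_tokens
         for i in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Uniform routing gives the minimum value of 1.0; skewed routing scores higher.
balanced = load_balancing_loss([[0.5, 0.5], [0.5, 0.5]], [0, 1], 2)
skewed = load_balancing_loss([[0.9, 0.1], [0.8, 0.2]], [0, 0], 2)
```

In training this term would be scaled by a small coefficient and added to the main loss.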
Capped Expert Capacity
Set a max capacity C for each expert; any inputs to the expert beyond C get dropped or assigned to a different expert.
Thanks to residual connections the dropped inputs are not completely dropped from the network
NOTE: Expert Capacity leaks information in causal models since it uses batch level statistics
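A minimal sketch of capacity-capped dispatch, assuming top-1 routing and a first-come, first-served fill order. The function name and inputs are hypothetical; real implementations do this with batched tensor ops.

```python
def dispatch_with_capacity(expert_ids, num_experts, capacity):
    """Assign tokens to experts in order; tokens beyond capacity are dropped.

    Returns (kept, dropped): kept maps expert -> token indices, dropped lists
    the token indices that overflowed and only flow through the residual path.
    """
    kept = {e: [] for e in range(num_experts)}
    dropped = []
    for token, e in enumerate(expert_ids):
        if len(kept[e]) < capacity:
            kept[e].append(token)
        else:
            dropped.append(token)
    return kept, dropped


# Six tokens, four of which want expert 0, with room for two per expert.
kept, dropped = dispatch_with_capacity([0, 0, 1, 0, 1, 0], num_experts=2, capacity=2)
```

Whether a token is kept depends on how many other tokens in the batch chose the same expert, which is where the batch-level information leak comes from.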
Token Choice
Each token sent to K experts
Expert Choice
Each expert picks N tokens
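Leaks information in decoder (causal) models, since whether a token gets picked depends on the other tokens in the sequence, including later ones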
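The two routing directions can be compared on a small score matrix. This is a sketch with made-up scores (rows are tokens, columns are experts); the function names are hypothetical.

```python
def token_choice(scores, k):
    """Each token (row) picks its k highest-scoring experts."""
    return [sorted(range(len(row)), key=lambda e: -row[e])[:k] for row in scores]


def expert_choice(scores, n):
    """Each expert (column) picks its n highest-scoring tokens."""
    num_experts = len(scores[0])
    return [sorted(range(len(scores)), key=lambda t: -scores[t][e])[:n]
            for e in range(num_experts)]


scores = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
per_token = token_choice(scores, k=1)    # expert 0 receives 3 of the 4 tokens
per_expert = expert_choice(scores, n=2)  # every expert receives exactly 2 tokens
```

Token choice can overload an expert, while expert choice is balanced by construction but ranks each token against the others, including later positions.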
Gating
Layer wise
sum(gate_i(x) * expert_i(x))
Noisy Top-K Gating
No longer used in practice
gate(x) = softmax(keep_top_k(H(x), k))
H(x)_i = (x * W_g)_i + StandardNormal() * softplus((x * W_noise)_i)
# keep_top_k sets all values except the top k to -inf (so the softmax gives them zero weight)
# StandardNormal() * softplus((x * W_noise)_i) randomizes the gating
Top-K Gating
gate(x) = softmax(keep_top_k(H(x), k))
H(x)_i = (x * W_g)_i
# keep_top_k sets all values except the top k to -inf (setting them to 0 would still give them weight after the softmax)
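Both gating variants and the layer-wise sum can be sketched in plain Python. Here `clean_logits` stands in for `x * W_g` and `noise_logits` for `x * W_noise`; the matrix multiplies are omitted, and the toy experts are made up for illustration.

```python
import math
import random


def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def keep_top_k(logits, k):
    """Keep the k largest logits and mask the rest with -inf."""
    top = set(sorted(range(len(logits)), key=lambda i: -logits[i])[:k])
    return [v if i in top else -math.inf for i, v in enumerate(logits)]


def softplus(v):
    return math.log1p(math.exp(v))


def top_k_gate(clean_logits, k, noise_logits=None, rng=None):
    """Plain top-k gating, or noisy top-k gating when noise_logits is given."""
    h = list(clean_logits)
    if noise_logits is not None:
        rng = rng or random.Random()
        h = [c + rng.gauss(0.0, 1.0) * softplus(n)
             for c, n in zip(h, noise_logits)]
    return softmax(keep_top_k(h, k))


def moe_layer(x, gates, experts):
    """sum(gate_i(x) * expert_i(x)), skipping experts with zero gate."""
    return sum(g * f(x) for g, f in zip(gates, experts) if g > 0.0)


experts = [lambda x: x, lambda x: 2 * x, lambda x: 3 * x, lambda x: 4 * x]
gates = top_k_gate([2.0, 1.0, 0.0, 2.0], k=2)
y = moe_layer(1.0, gates, experts)
noisy_gates = top_k_gate([2.0, 1.0, 0.0, 2.0], k=2,
                         noise_logits=[0.0] * 4, rng=random.Random(0))
```

Only the selected experts are evaluated, which is where the compute saving comes from.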
Distributed Setting
Expert Parallelism (EP)
Puts experts on different GPUs and sends routed tokens to the GPU that hosts the expert
Can be combined with Data Parallelism by replicating the non-expert layers on all devices and keeping the experts split up.
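The dispatch step can be pictured as bucketing tokens by the device that owns their expert. This sketch assumes experts are placed contiguously (experts 0..n-1 on device 0, and so on) and only simulates the grouping; a real system would follow it with an all-to-all exchange. The function name is hypothetical.

```python
def dispatch_to_devices(expert_ids, experts_per_device):
    """Group token indices by the device that hosts their chosen expert."""
    buckets = {}
    for token, e in enumerate(expert_ids):
        device = e // experts_per_device  # assumed contiguous expert placement
        buckets.setdefault(device, []).append(token)
    return buckets


# Four experts on two devices; most tokens picked experts on device 1.
buckets = dispatch_to_devices([0, 3, 1, 2, 3, 3], experts_per_device=2)
```

Here device 1 gets twice as many tokens as device 0, which is the straggler problem that load balancing tries to avoid.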
Inference
Decouples processing speed from parameter count - allows us to use really large capacity models without having to do as much compute at inference time
Adds some overhead vs dense model of the same size due to communication (about 20% slower)
Reduces compute to memory ratio (less compute with same amount of memory compared to dense models), meaning it’s more likely to be memory bound than dense models.
Need large batch size at inference to balance compute and memory
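More efficient for batch processing and other compute-intensive workloads such as prefill and speculative decoding
Larger memory footprint since there are more FFN parameters, which leaves less room for the KV cache (the KV cache size itself depends on the attention layers, not the FFN)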
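A back-of-the-envelope sketch of the compute to memory ratio, under loud assumptions: about 2 FLOPs per active parameter per token, every weight read once per step (so all experts are hit, which holds for large batches), 2 bytes per parameter, and KV cache and activation traffic ignored. The 13B and 47B figures are illustrative, roughly Mixtral-sized.

```python
def arithmetic_intensity(active_params, total_params, batch_size, bytes_per_param=2):
    """Approximate FLOPs per byte of weights read in one decoding step."""
    flops = 2 * active_params * batch_size
    bytes_read = total_params * bytes_per_param
    return flops / bytes_read


dense = arithmetic_intensity(13e9, 13e9, batch_size=64)
moe = arithmetic_intensity(13e9, 47e9, batch_size=64)
# Batch size the MoE would need to match the dense model's ratio at batch 64.
batch_to_match = 64 * dense / moe
```

Under these assumptions the MoE needs a batch roughly 47/13 times larger to reach the same ratio, which is why large batches matter more for MoE inference.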
Training
Tend to perform better for the same training budget
Increasing the number of experts tends to increase quality (8-64 tends to be the best tradeoff)
Tweaks
Mixture of Depths
Use expert-style routing per block and only process the top K tokens; the remaining tokens skip the block through the residual connection. Use the routing weight to scale the outputs of the processed tokens (so the router gets gradients)
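A minimal sketch of this idea, with scalar "tokens" and a made-up block so the arithmetic is easy to follow. The function name and the choice to apply the router weight to the residual update are assumptions for illustration.

```python
def mixture_of_depths_block(tokens, router_weights, block, k):
    """Apply `block` to the k tokens with the highest router weight."""
    top = set(sorted(range(len(tokens)), key=lambda i: -router_weights[i])[:k])
    out = []
    for i, x in enumerate(tokens):
        if i in top:
            # Residual update scaled by the router weight, so the router
            # receives gradients through the output.
            out.append(x + router_weights[i] * block(x))
        else:
            out.append(x)  # token skips the block via the residual path
    return out


out = mixture_of_depths_block(
    tokens=[1.0, 2.0, 3.0, 4.0],
    router_weights=[0.1, 0.9, 0.5, 0.2],
    block=lambda x: 10 * x,
    k=2,
)
```

Tokens 1 and 2 are processed and the other two pass through unchanged, so the block costs half the compute of a full pass.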
Granularity
[2402.07871] Scaling Laws for Fine-Grained Mixture of Experts
Shared Experts
Have a few experts that process all tokens, allowing the routed ones to specialize more
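As a sketch, a layer with shared experts adds an always-on term to the gated sum. The function name, the toy experts, and the fixed gate values are hypothetical.

```python
def shared_expert_layer(x, shared, routed, gates):
    """Shared experts see every token; routed experts are weighted by gates."""
    shared_out = sum(f(x) for f in shared)
    routed_out = sum(g * f(x) for g, f in zip(gates, routed) if g > 0.0)
    return shared_out + routed_out


shared = [lambda x: x]                             # always active
routed = [lambda x: 2 * x, lambda x: 3 * x, lambda x: 4 * x]
y = shared_expert_layer(1.0, shared, routed, gates=[0.0, 0.5, 0.5])
```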
Batch Prioritized Routing
When experts overflow, drop the tokens with the smallest gating weights first
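A sketch of how this could combine with an expert capacity: tokens are visited in order of decreasing gate weight, so the ones that overflow are the ones the router cared about least. Top-1 routing and the function name are assumptions.

```python
def prioritized_dispatch(expert_ids, gate_weights, num_experts, capacity):
    """Fill expert capacity in order of decreasing gate weight."""
    order = sorted(range(len(expert_ids)), key=lambda t: -gate_weights[t])
    kept = {e: [] for e in range(num_experts)}
    dropped = []
    for t in order:
        e = expert_ids[t]
        if len(kept[e]) < capacity:
            kept[e].append(t)
        else:
            dropped.append(t)
    return kept, dropped


# Three tokens compete for expert 0, which only has room for two.
kept, dropped = prioritized_dispatch(
    expert_ids=[0, 0, 0, 1],
    gate_weights=[0.4, 0.9, 0.7, 0.6],
    num_experts=2,
    capacity=2,
)
```

Token 0 is dropped even though it comes first in the batch, because it has the lowest gate weight of the three.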
MoE LLMs / Transformers
FFN MoE
The MoE layer usually replaces the feed-forward module.
GShard
Mixtral 8x7B
Replaced FFN layers with MoE layers of 8 experts, with each token routed to 2 experts. Leads to ~13B active parameters per token (out of ~47B total)
DeepSeek-V2
OLMoE
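The active versus total parameter count for Mixtral 8x7B can be estimated with a few lines of arithmetic. The configuration values below are assumptions taken from the published Mixtral 8x7B configuration as I recall it, and normalization weights are ignored as negligible.

```python
# Assumed Mixtral 8x7B configuration (not stated in these notes).
dim, n_layers, hidden_dim = 4096, 32, 14336
n_heads, n_kv_heads, head_dim = 32, 8, 128
vocab_size, num_experts, top_k = 32000, 8, 2

expert = 3 * dim * hidden_dim                      # three projections per gated FFN
attention = (2 * dim * n_heads * head_dim          # query and output projections
             + 2 * dim * n_kv_heads * head_dim)    # key and value projections
router = dim * num_experts                         # one logit per expert
embeddings = 2 * vocab_size * dim                  # input embedding + output head

total = n_layers * (num_experts * expert + attention + router) + embeddings
active = n_layers * (top_k * expert + attention + router) + embeddings

summary = f"{total / 1e9:.1f}B total, {active / 1e9:.1f}B active"
# summary == "46.7B total, 12.9B active"
```

Almost all of the parameters sit in the experts, so routing each token to 2 of 8 experts cuts per-token compute to well under a third of what a dense model of the same total size would need.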
Attention MoE
Not usually done with Attention layers because attention requires access to all tokens.
Mixture-of-Attention - routes queries
Switch Transformer - successor to GShard, but claims attention MoE was too unstable
SwitchHead - routes values