Mixture of Experts (MoE)

TLDR

  1. Better models for less compute (larger capacity models under same compute budget)
  2. Best for batched inference (server or offline)
  3. Balanced routing is important; otherwise you get stragglers, memory imbalance, and dead experts
  4. Most models use MoE layers in FFNs (token level feed forward networks)

Routing

Problem: Load Balancing

If token assignments are not balanced across experts, you end up with stragglers and out-of-memory issues, because some experts hold and process far more inputs than others.

Dead experts receive no tokens and therefore no gradients, so the router keeps ignoring them.

Auxiliary Load Balancing Loss

Penalize routing function for producing imbalanced weights
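
One common form of this penalty is the Switch Transformer style loss, which multiplies the fraction of tokens sent to each expert by the mean router probability for that expert. A minimal NumPy sketch, assuming top-1 routing; the function name `load_balancing_loss` is illustrative and the usual scaling coefficient is left out:

```python
import numpy as np

def load_balancing_loss(router_logits: np.ndarray) -> float:
    """Switch-style auxiliary loss for top-1 routing (illustrative sketch).

    router_logits: [num_tokens, num_experts]
    Returns num_experts * sum_i(f_i * P_i). The value is 1.0 when tokens are
    spread evenly over experts and approaches num_experts when routing
    collapses onto a single expert.
    """
    num_tokens, num_experts = router_logits.shape
    # Softmax over experts (shifted for numerical stability).
    z = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    # f_i: fraction of tokens whose top-1 expert is i.
    top1 = probs.argmax(axis=-1)
    f = np.bincount(top1, minlength=num_experts) / num_tokens
    # P_i: mean router probability assigned to expert i.
    p = probs.mean(axis=0)
    return float(num_experts * np.sum(f * p))
```

In training this term would be scaled by a small coefficient and added to the main loss; only P_i carries gradient, since f_i comes from an argmax.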

Capped Expert Capacity

Set a max capacity C for each expert; any inputs to the expert beyond C get dropped or assigned to a different expert.

Thanks to residual connections the dropped inputs are not completely dropped from the network

NOTE: Expert Capacity leaks information in causal models since it uses batch level statistics
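
A minimal sketch of capacity-based dropping, assuming top-1 routing and the common convention C = ceil(capacity_factor * num_tokens / num_experts). The name `apply_capacity` and the first-come-first-served order are assumptions for illustration:

```python
import math
import numpy as np

def apply_capacity(expert_ids, num_experts, capacity_factor=1.25):
    """Return (keep_mask, capacity) for top-1 routed tokens.

    expert_ids: [num_tokens] expert index chosen for each token.
    Tokens are admitted in sequence order; once an expert is full, further
    tokens routed to it are dropped (mask False) and only reach the next
    layer through the residual connection.
    """
    num_tokens = expert_ids.shape[0]
    capacity = math.ceil(capacity_factor * num_tokens / num_experts)
    keep = np.zeros(num_tokens, dtype=bool)
    load = np.zeros(num_experts, dtype=int)
    for t, e in enumerate(expert_ids):
        if load[e] < capacity:
            keep[t] = True
            load[e] += 1
    return keep, capacity
```

Because C is computed from the whole batch, whether a given token is kept depends on the other tokens routed alongside it.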

Token Choice

Each token sent to K experts
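
As a sketch, token choice is a top-K over the expert axis of the router score matrix. The function name `token_choice` is illustrative:

```python
import numpy as np

def token_choice(scores: np.ndarray, k: int) -> np.ndarray:
    """scores: [num_tokens, num_experts].

    Returns [num_tokens, k] expert ids per token, best-scoring expert first.
    Nothing here limits how many tokens land on the same expert.
    """
    return np.argsort(-scores, axis=-1)[:, :k]
```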

Expert Choice

Each expert picks N tokens

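Leaks information in causal (decoder) models, since whether a token is picked depends on the other tokens it competes with, including later ones.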
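
As a sketch, expert choice is a top-N over the token axis of the same score matrix. The function name `expert_choice` is illustrative:

```python
import numpy as np

def expert_choice(scores: np.ndarray, n: int) -> np.ndarray:
    """scores: [num_tokens, num_experts].

    Returns [num_experts, n] token ids per expert, best-scoring token first.
    Every expert gets exactly n tokens, so load is balanced by construction,
    but a token may be picked by several experts or by none.
    """
    return np.argsort(-scores, axis=0)[:n, :].T
```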

Gating

Layer wise

sum(gate_i(x) * expert_i(x))
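
A minimal sketch of this weighted sum, assuming the gate weights are already computed and mostly zero, so each expert only runs on the tokens routed to it. The names `moe_layer` and `experts` are illustrative:

```python
import numpy as np

def moe_layer(x, gate_weights, experts):
    """Compute sum_i gate_i(x) * expert_i(x) per token.

    x: [num_tokens, d] float array
    gate_weights: [num_tokens, num_experts], zero where a token is not routed
    experts: list of callables mapping [n, d] -> [n, d]
    """
    out = np.zeros_like(x)
    for i, expert in enumerate(experts):
        idx = np.nonzero(gate_weights[:, i])[0]  # tokens routed to expert i
        if idx.size:
            out[idx] += gate_weights[idx, i][:, None] * expert(x[idx])
    return out
```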

Noisy Top-K Gating

No longer used in practice

gate(x) = softmax(keep_top_k(H(x), k))
H(x)_i = (x * W_g)_i + StandardNormal() * softplus((x * W_noise)_i)


# keep_top_k  set all values except top k to -inf (masking for softmax)
# StandardNormal() * softplus((x * W_noise)_i) randomizes the gating
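
The formulas above can be sketched in NumPy as follows. This is an illustrative version, not a reference implementation: the optional `rng` argument is an assumption used to switch the noise on, and ties at the k-th value would keep more than k entries.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)  # log(1 + exp(z)), numerically stable

def keep_top_k(h, k):
    """Set everything except the k largest entries of each row to -inf."""
    kth = np.sort(h, axis=-1)[:, -k][:, None]
    return np.where(h >= kth, h, -np.inf)

def noisy_top_k_gate(x, w_g, w_noise, k, rng=None):
    """x: [num_tokens, d]; w_g, w_noise: [d, num_experts].

    Returns gate weights [num_tokens, num_experts]; each row sums to 1 and
    has k non-zero entries.
    """
    h = x @ w_g
    if rng is not None:  # pass a Generator to add the noise term
        h = h + rng.standard_normal(h.shape) * softplus(x @ w_noise)
    masked = keep_top_k(h, k)
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)  # exp(-inf) == 0, so masked experts get zero weight
    return e / e.sum(axis=-1, keepdims=True)
```

Calling it without `rng` gives deterministic top-k gating over the same logits.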

Top-K Gating

gate(x) = softmax(keep_top_k(H(x), k))
H(x)_i = (x * W_g)_i


# keep_top_k  set all values except top k to -inf (so they get zero weight after softmax)

ReMoE

Distributed Setting

Expert Parallelism (EP)

Puts experts on different GPUs and sends routed tokens to the GPU that holds their expert.

Can be combined with Data Parallelism by replicating non-expert layers on all devices and keeping experts split up.
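
A single-process sketch of the dispatch step, assuming experts are split into contiguous, equal-sized blocks per device. The name `dispatch_plan` is illustrative; a real system would turn this plan into an all-to-all exchange between GPUs.

```python
import numpy as np

def dispatch_plan(expert_ids, num_experts, num_devices):
    """Map each token to the device that owns its expert.

    expert_ids: [num_tokens] top-1 expert index per token.
    Returns {device: array of token indices to send to that device}.
    """
    experts_per_device = num_experts // num_devices
    device_ids = expert_ids // experts_per_device
    return {d: np.nonzero(device_ids == d)[0] for d in range(num_devices)}
```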

Inference

Decouples processing speed from parameter count - allows us to use really large capacity models without having to do as much compute at inference time

Adds some overhead vs dense model of the same size due to communication (about 20% slower)

Reduces compute to memory ratio (less compute with same amount of memory compared to dense models), meaning it’s more likely to be memory bound than dense models.

Need large batch size at inference to balance compute and memory
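
A rough way to see why: each expert's weights have to be read from memory no matter how few tokens it serves, and with E experts and top-K routing each expert only sees about K/E of the batch. A small sketch, assuming perfectly balanced routing (the function name is illustrative):

```python
def avg_tokens_per_expert(batch_tokens: int, num_experts: int, top_k: int) -> float:
    """Average number of tokens each expert processes in one forward step,
    assuming perfectly balanced routing."""
    return batch_tokens * top_k / num_experts

# With 8 experts and top-2 routing, an expert sees a quarter of the batch.
for batch in (1, 8, 64, 512):
    print(batch, avg_tokens_per_expert(batch, num_experts=8, top_k=2))
```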

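More efficient for batch processing and other compute-intensive tasks like prefill and speculative decoding

Larger weight memory footprint since there are more FFN parameters; the KV cache itself depends on the attention layers, not on the experts, but has less memory left over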

Training

MoE models tend to perform better than dense models for the same training compute budget

Increasing the number of experts tends to improve quality (8-64 tends to be the best tradeoff)

Tweaks

Mixture of Depths

Use expert-style routing at each block and only process the top K tokens, forwarding the remaining tokens unchanged through the residual path. The routing weight scales the outputs of the processed tokens, which keeps the router differentiable.
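
A minimal sketch of one such block, assuming a linear router with a single weight vector. The names `mixture_of_depths_block`, `router_w` and `block_fn` are illustrative:

```python
import numpy as np

def mixture_of_depths_block(x, router_w, block_fn, k):
    """x: [num_tokens, d]; router_w: [d]; block_fn: [n, d] -> [n, d].

    Only the k tokens with the highest router score go through block_fn;
    all other tokens are passed through unchanged.
    """
    scores = x @ router_w          # one scalar score per token
    top = np.argsort(-scores)[:k]  # tokens selected for compute
    out = x.copy()
    # Scaling by the router score puts the router on the gradient path.
    out[top] = x[top] + scores[top][:, None] * block_fn(x[top])
    return out
```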

Granularity

[2402.07871] Scaling Laws for Fine-Grained Mixture of Experts

Shared Experts

Have a few experts that process all tokens, allowing the routed ones to specialize more
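
A minimal sketch, assuming the shared experts' outputs are simply added to the gated sum of the routed experts. The function and argument names are illustrative:

```python
import numpy as np

def shared_expert_layer(x, gate_weights, routed_experts, shared_experts):
    """x: [num_tokens, d]; gate_weights: [num_tokens, num_routed_experts].

    Shared experts run on every token without routing; routed experts run
    only on tokens with a non-zero gate weight.
    """
    out = np.zeros_like(x)
    for expert in shared_experts:
        out += expert(x)
    for i, expert in enumerate(routed_experts):
        idx = np.nonzero(gate_weights[:, i])[0]
        if idx.size:
            out[idx] += gate_weights[idx, i][:, None] * expert(x[idx])
    return out
```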

Batch Prioritized Routing

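When an expert is over capacity, rank its tokens by gating weight across the batch so that the tokens with the smallest weights are the ones dropped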
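
A minimal sketch, assuming top-1 routing and a fixed per-expert capacity. The name `batch_prioritized_keep` is illustrative:

```python
import numpy as np

def batch_prioritized_keep(expert_ids, gate_weights, num_experts, capacity):
    """expert_ids, gate_weights: [num_tokens] top-1 expert and its gate weight.

    Returns a boolean keep-mask in which every expert keeps at most
    `capacity` tokens, chosen by highest gate weight rather than by position.
    """
    keep = np.zeros(expert_ids.shape[0], dtype=bool)
    for e in range(num_experts):
        idx = np.nonzero(expert_ids == e)[0]
        best = idx[np.argsort(-gate_weights[idx])[:capacity]]
        keep[best] = True
    return keep
```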

Multi Head MoE

MoE LLMs / Transformers

FFN MoE

The MoE layer usually replaces the feed-forward module.

GShard

Mixtral 8x7B

Replaced FFN layers with MoEs with 8 experts each, with every token routed to 2 experts. Leads to ~13B active parameters per token (out of ~47B total)
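
The active-parameter figure can be checked with some rough arithmetic. The dimensions below (hidden size 4096, FFN size 14336, 32 layers, 32 attention heads with 8 KV heads, vocabulary of 32000) come from the published Mixtral 8x7B configuration rather than from these notes, and normalization weights are ignored:

```python
def mixtral_param_counts():
    """Return (total, active-per-token) parameter counts, approximately."""
    dim, hidden, layers, vocab = 4096, 14336, 32, 32000
    n_heads, n_kv_heads, head_dim = 32, 8, 128
    num_experts, top_k = 8, 2

    expert = 3 * dim * hidden  # gate, up and down projections of one expert
    # Query and output projections are full size; key and value are smaller
    # because of grouped-query attention.
    attn = 2 * dim * n_heads * head_dim + 2 * dim * n_kv_heads * head_dim
    router = dim * num_experts
    embed = 2 * vocab * dim    # input embedding plus untied output head

    total = layers * (num_experts * expert + attn + router) + embed
    active = layers * (top_k * expert + attn + router) + embed
    return total, active

total, active = mixtral_param_counts()
print(f"total ~ {total / 1e9:.1f}B, active ~ {active / 1e9:.1f}B")
```

Only the expert FFNs scale with the number of experts, which is why the total is well below 8 x 7B.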

DeepSeek-V2

OLMoE

Attention MoE

Not usually done with Attention layers because attention requires access to all tokens.

Mixture-of-Attention - routes queries

Switch Transformer - successor to GShard, but claims attention MoE was too unstable

SwitchHead - routes values

MoEs in Vision

Sparse MoE

Soft MoE