llm

[[2024 NeurIPS#[NeurIPS Tutorial Beyond Decoding Meta-Generation Algorithms for Large Language Models](https://neurips.cc/virtual/2024/tutorial/99522)]]

Greedy

  • pick the highest-probability token at each step
last_token_logits = logits[:, -1, :] # batch, time, token_index => batch, token_index
next_token_ids = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1) # batch, 1
 
tokens = torch.cat([tokens, next_token_ids], dim=1)
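
A minimal sketch of how this greedy step sits inside a full generation loop, assuming a Hugging Face-style causal LM (the model name, prompt, and token budget below are placeholders, not from the tutorial):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokens = tokenizer("The capital of France is", return_tensors="pt").input_ids
max_new_tokens = 20  # assumed budget

for _ in range(max_new_tokens):
    logits = model(tokens).logits                  # batch, time, token_index
    last_token_logits = logits[:, -1, :]           # batch, token_index
    next_token_ids = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)  # batch, 1
    tokens = torch.cat([tokens, next_token_ids], dim=1)
    if next_token_ids.item() == tokenizer.eos_token_id:  # stop at end-of-text (batch size 1)
        break

print(tokenizer.decode(tokens[0]))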

Sampling with Temperature

last_token_logits = logits[:, -1, :] # batch, time, token_index => batch, token_index
 
last_token_logits = last_token_logits / temperature
probs = F.softmax(last_token_logits, dim=-1)
 
next_token_ids = torch.multinomial(probs, num_samples=1)
 
tokens = torch.cat([tokens, next_token_ids], dim=1)
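
Dividing the logits by a temperature below 1 sharpens the distribution (temperature → 0 recovers greedy decoding), while a temperature above 1 flattens it toward uniform; temperature = 1 samples from the model's unmodified distribution.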

Top K Sampling

  • restrict sampling to the k highest-probability tokens
last_token_logits = logits[:, -1, :] / temperature
values, indices = torch.topk(last_token_logits, k=k, dim=-1) # get top k: (batch_size, k)
 
probs = F.softmax(values, dim=-1)
next_token_positions = torch.multinomial(probs, num_samples=1)
next_token_id = torch.gather(indices, 1, next_token_positions) # map positions within top k back to vocabulary ids
tokens = torch.cat([tokens, next_token_id], dim=1)
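
k = 1 reduces to greedy decoding and k = vocab size reduces to plain temperature sampling; the topk/gather pair is needed because multinomial returns positions within the top-k slice, not vocabulary ids.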

Top P Sampling / Nucleus Sampling

  • sample from the smallest set of top tokens whose cumulative probability reaches p (the nucleus)
last_token_logits = logits[:, -1, :] / temperature 
sorted_logits, sorted_indices = torch.sort(last_token_logits, descending=True) 
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 
# Remove tokens with cumulative probability above p,
# shifting the mask right so the first token that crosses p is kept
# (this also guarantees at least one token always survives)
mask = cum_probs > p
mask[:, 1:] = mask[:, :-1].clone()
mask[:, 0] = False
# Mask out those tokens by setting their logits to -inf, then re-normalize
sorted_logits[mask] = float('-inf') 
probs = F.softmax(sorted_logits, dim=-1) 
next_token_positions = torch.multinomial(probs, num_samples=1) 
next_token_id = torch.gather(sorted_indices, 1, next_token_positions) # map back to vocabulary ids
tokens = torch.cat([tokens, next_token_id], dim=1)
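
Unlike top-k, the candidate set adapts to the shape of the distribution: a peaked distribution keeps only a few tokens, a flat one keeps many; p = 1.0 reduces to plain temperature sampling.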

Beam Search
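
Keep the beam_width highest-scoring partial sequences at each step instead of a single one. A minimal sketch over summed log-probabilities, assuming batch size 1, no EOS handling, and the same model interface as the greedy loop above (beam width and token budget are placeholder choices):

import torch
import torch.nn.functional as F

def beam_search(model, tokens, beam_width=4, max_new_tokens=20):
    # each beam is (sequence of shape (1, time), cumulative log-prob)
    beams = [(tokens, 0.0)]
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            logits = model(seq).logits                           # 1, time, token_index
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)  # 1, token_index
            top_log_probs, top_ids = torch.topk(log_probs, k=beam_width, dim=-1)
            for lp, tid in zip(top_log_probs[0], top_ids[0]):
                new_seq = torch.cat([seq, tid.view(1, 1)], dim=1)
                candidates.append((new_seq, score + lp.item()))
        # keep only the beam_width highest-scoring extensions
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    # return the best sequence (scores could also be length-normalized here)
    return max(beams, key=lambda c: c[1])[0]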

Speculative
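
The idea: a cheap draft model proposes several tokens and the large target model verifies them all in one forward pass, so several tokens can be accepted per target call. A minimal sketch of the simplest greedy-verification variant (the full method instead uses rejection sampling against the target distribution so outputs match the target exactly); target_model, draft_model, and gamma are assumed names:

import torch

def speculative_step(target_model, draft_model, tokens, gamma=4):
    # 1. Cheap draft model proposes gamma tokens greedily.
    draft = tokens
    for _ in range(gamma):
        d_logits = draft_model(draft).logits[:, -1, :]
        draft = torch.cat([draft, d_logits.argmax(dim=-1, keepdim=True)], dim=1)
    # 2. Target model scores the whole draft in a single forward pass.
    t_logits = target_model(draft).logits                            # 1, time, token_index
    verify = t_logits[:, tokens.shape[1] - 1 : -1, :].argmax(dim=-1) # target's choice at each drafted position
    proposed = draft[:, tokens.shape[1]:]                            # the gamma drafted tokens
    # 3. Accept the longest prefix on which draft and target agree.
    matches = (verify == proposed).long()[0]
    n_accept = int(matches.cumprod(dim=0).sum().item())
    accepted = proposed[:, :n_accept]
    # 4. Append the target's own token at the first disagreement (or after a fully accepted draft).
    if n_accept < gamma:
        correction = verify[:, n_accept:n_accept + 1]
    else:
        correction = t_logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return torch.cat([tokens, accepted, correction], dim=1)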

Structured

Structured Generation with LLMs
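
The core trick in structured/constrained generation is to mask the logits at each step so that only tokens legal under the constraint (a regex, JSON schema, or grammar compiled to a state machine over the vocabulary) can be sampled. A minimal sketch; allowed_token_ids stands in for whatever a real constraint engine would return at the current state:

import torch
import torch.nn.functional as F

def constrained_step(logits, allowed_token_ids, temperature=1.0):
    # allowed_token_ids: LongTensor of vocabulary ids that are legal at this step
    last_token_logits = logits[:, -1, :] / temperature
    mask = torch.full_like(last_token_logits, float('-inf'))
    mask[:, allowed_token_ids] = 0.0
    probs = F.softmax(last_token_logits + mask, dim=-1)  # only legal tokens keep probability mass
    return torch.multinomial(probs, num_samples=1)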

MCTS