- [2406.16838] From Decoding to Meta-Generation: Inference-time Algorithms for Large Language Models
- GitHub - shreyansh26/LLM-Sampling: A collection of various LLM sampling methods implemented in pure Pytorch
- Test Time Compute, LLM Reasoning, Inference Time Scaling
- Text generation strategies
- Maxime Labonne - Decoding Strategies in Large Language Models
- [[2024 NeurIPS#[NeurIPS Tutorial Beyond Decoding Meta-Generation Algorithms for Large Language Models](https://neurips.cc/virtual/2024/tutorial/99522)]]
- [2402.10200] Chain-of-Thought Reasoning Without Prompting
- Chain of Thought Empowers Transformers to Solve Inherently Serial Problems
Greedy
- pick the highest-probability token at each step
last_token_logits = logits[:, -1, :] # batch, time, token_index => batch, token_index
next_token_ids = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1) # batch, 1
tokens = torch.cat([tokens, next_token_ids], dim=1)
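- the per-step snippets in these notes assume an outer autoregressive loop; a minimal sketch of that loop (greedy step shown; the model name, prompt, and generation length are placeholder assumptions, not from the notes):
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

tokens = tokenizer("The capital of France is", return_tensors="pt").input_ids  # (batch, time)
for _ in range(20):  # generate up to 20 new tokens
    with torch.no_grad():
        logits = model(tokens).logits  # (batch, time, vocab)
    # --- one of the per-step selection snippets goes here (greedy shown) ---
    last_token_logits = logits[:, -1, :]
    next_token_ids = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)
    tokens = torch.cat([tokens, next_token_ids], dim=1)
    if next_token_ids.item() == tokenizer.eos_token_id:
        break
print(tokenizer.decode(tokens[0]))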
Sampling with Temperature
last_token_logits = logits[:, -1, :] # batch, time, token_index => batch, token_index
last_token_logits = last_token_logits / temperature
probs = F.softmax(last_token_logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1)
tokens = torch.cat([tokens, next_token_ids], dim=1)
Top K Sampling
- sample from top K
last_token_logits = logits[:, -1, :] / temperature
# Get top k
values, indices = torch.topk(last_token_logits, k=k, dim=-1) # (batch_size, k)
probs = F.softmax(values, dim=-1)
next_token_positions = torch.multinomial(probs, num_samples=1)
next_token_id = torch.gather(indices, 1, next_token_positions)
tokens = torch.cat([tokens, next_token_id], dim=1)
Top P Sampling / Nucleus Sampling
- sample from the smallest set of top tokens whose cumulative probability exceeds P
last_token_logits = logits[:, -1, :] / temperature
sorted_logits, sorted_indices = torch.sort(last_token_logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens once the cumulative probability exceeds p, but keep the token
# that crosses the threshold so at least one candidate always survives
cutoff = (cum_probs > p).float().cumsum(dim=-1) # cutoff >= 1 once we've exceeded p
sorted_logits[cutoff > 1] = float('-inf')
# Re-normalize over the remaining tokens
probs = F.softmax(sorted_logits, dim=-1)
next_token_positions = torch.multinomial(probs, num_samples=1)
next_token_id = torch.gather(sorted_indices, 1, next_token_positions)
tokens = torch.cat([tokens, next_token_id], dim=1)
Beam Search
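- keep the K highest-scoring partial sequences (beams) at each step, ranked by cumulative log-probability, instead of committing to a single token
- a minimal sketch for batch size 1 (no length penalty or EOS handling; beam_width, num_steps, and the HF-style model(...).logits call are illustrative assumptions):
import torch
import torch.nn.functional as F

def beam_search(model, tokens, beam_width=4, num_steps=20):
    # Each beam is (token_ids, cumulative_log_prob)
    beams = [(tokens, 0.0)]
    for _ in range(num_steps):
        candidates = []
        for seq, score in beams:
            with torch.no_grad():
                logits = model(seq).logits  # (1, time, vocab)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)  # (1, vocab)
            top_log_probs, top_ids = torch.topk(log_probs, k=beam_width, dim=-1)
            for lp, tid in zip(top_log_probs[0], top_ids[0]):
                new_seq = torch.cat([seq, tid.view(1, 1)], dim=1)
                candidates.append((new_seq, score + lp.item()))
        # Keep only the beam_width highest-scoring candidates
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0][0]  # highest-scoring sequence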
Speculative
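- a small draft model proposes several tokens cheaply; the large target model checks them all in one forward pass and keeps an acceptable prefix
- a simplified greedy-verification sketch (the published algorithms accept/reject probabilistically and resample on rejection; draft_model, target_model, and num_draft are illustrative assumptions, and no KV cache is used for clarity):
import torch

def speculative_step(target_model, draft_model, tokens, num_draft=4):
    # 1. Draft model proposes num_draft tokens greedily
    draft_tokens = tokens
    for _ in range(num_draft):
        with torch.no_grad():
            logits = draft_model(draft_tokens).logits[:, -1, :]
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
        draft_tokens = torch.cat([draft_tokens, next_id], dim=1)

    # 2. Target model scores the whole draft in a single forward pass
    with torch.no_grad():
        target_logits = target_model(draft_tokens).logits  # (1, time, vocab)
    # Target's greedy choice at each drafted position
    target_preds = torch.argmax(target_logits[:, tokens.shape[1]-1:-1, :], dim=-1)  # (1, num_draft)
    proposed = draft_tokens[:, tokens.shape[1]:]  # (1, num_draft)

    # 3. Accept the longest prefix where draft and target agree, then append
    #    the target's own prediction for the position after that prefix
    matches = (proposed == target_preds)[0].long()        # 1 where draft == target
    n_accept = int(matches.cumprod(dim=0).sum().item())   # length of agreeing prefix
    accepted = proposed[:, :n_accept]
    bonus = torch.argmax(target_logits[:, tokens.shape[1]-1 + n_accept, :], dim=-1, keepdim=True)
    return torch.cat([tokens, accepted, bonus], dim=1)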
Structured
Structured Generation with LLMs
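- constrain decoding to a target format (JSON, regex, grammar) by masking disallowed tokens before selecting the next one; libraries such as Outlines compile the format into a state machine that yields the allowed token set at each step
- a minimal sketch of the per-step masking (allowed_token_ids stands in for whatever the format permits in the current state; illustrative, not any specific library's API):
import torch

def constrained_step(logits, allowed_token_ids):
    last_token_logits = logits[:, -1, :]                    # (batch, vocab)
    mask = torch.full_like(last_token_logits, float('-inf'))
    mask[:, allowed_token_ids] = 0.0                         # keep only allowed tokens
    constrained_logits = last_token_logits + mask
    return torch.argmax(constrained_logits, dim=-1).unsqueeze(-1)  # (batch, 1)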