[[2024 NeurIPS#[NeurIPS Tutorial Beyond Decoding Meta-Generation Algorithms for Large Language Models](https //neurips.cc/virtual/2024/tutorial/99522)]]
Greedy decoding: pick the highest-scoring token at each step
last_token_logits = logits[:, -1, :]  # only the final time step matters: (batch, time, vocab) -> (batch, vocab)
next_token_ids = last_token_logits.argmax(dim=-1, keepdim=True)  # best id per row, kept 2-D: (batch, 1)
tokens = torch.cat((tokens, next_token_ids), dim=1)  # append the chosen ids as a new column
Sampling with Temperature
last_token_logits = logits[:, -1, :].div(temperature)  # final-step scores (batch, vocab), rescaled by temperature
probs = last_token_logits.softmax(dim=-1)  # per-row distribution over the vocabulary
next_token_ids = probs.multinomial(num_samples=1)  # one sampled id per row: (batch, 1)
tokens = torch.cat((tokens, next_token_ids), dim=1)  # append the sampled ids as a new column
Top K Sampling
last_token_logits = logits[:, -1, :] / temperature  # final-step scores (batch, vocab), temperature-scaled
values, indices = last_token_logits.topk(k, dim=-1)  # keep the k best scores and their vocab ids: both (batch, k)
probs = values.softmax(dim=-1)  # renormalise over the k surviving tokens only
next_token_positions = probs.multinomial(num_samples=1)  # sampled slot within the top-k: (batch, 1)
next_token_id = indices.gather(1, next_token_positions)  # translate the slot back to a vocabulary id
tokens = torch.cat((tokens, next_token_id), dim=1)  # append the sampled ids as a new column
Top P Sampling / Nucleus Sampling
sample only from the smallest set of highest-probability tokens whose cumulative probability reaches p
last_token_logits = logits[:, -1, :] / temperature  # final-step scores (batch, vocab), temperature-scaled
sorted_logits, sorted_indices = torch.sort(last_token_logits, descending=True)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
# Drop a token only when the tokens ranked strictly above it already exceed p.
# Comparing the exclusive cumulative sum (cum_probs - sorted_probs) keeps the
# token that crosses p. It also always keeps the top token, whose exclusive sum
# is 0 (for p >= 0), so the nucleus is never empty. Masking on `cum_probs > p`
# instead would discard the crossing token. When the top token alone exceeds p,
# it would mask every token, which makes the softmax below NaN and
# torch.multinomial fail.
cutoff = (cum_probs - sorted_probs) > p
# Masked tokens get -inf logits, so they receive zero probability after re-normalising.
sorted_logits = sorted_logits.masked_fill(cutoff, float('-inf'))
probs = F.softmax(sorted_logits, dim=-1)
next_token_positions = torch.multinomial(probs, num_samples=1)  # (batch, 1), positions in sorted order
next_token_id = torch.gather(sorted_indices, 1, next_token_positions)  # map back to vocabulary ids
tokens = torch.cat([tokens, next_token_id], dim=1)
Beam Search
Structured Generation with LLMs