Commit · 0da6031
1 Parent(s): 8f67839

eliminate hf helpers

Files changed:
- README.md: +20 -56
- custom_generate/generate.py: +185 -100
README.md
CHANGED
@@ -2,13 +2,14 @@
 library_name: transformers
 tags:
 - custom_generate
-
+- sampling
+- kvcache
 ---
 
-#
+# Sampling with KV Cache
 
 ## Description
-A clean, hackable implementation of ancestral sampling
+A clean, hackable implementation of sampling (also called ancestral sampling or multinomial sampling) with full KV cache support. This is a simplified alternative to the complex generation mixin in transformers, designed for readability and ease of modification while maintaining full performance.
 
 The implementation supports both sampling and greedy decoding modes, with optional temperature scaling and top-k/top-p filtering.
 
@@ -18,19 +19,23 @@ The implementation supports both sampling and greedy decoding modes, with option
 ## Model compatibility
 Most transformer LLM/VLM models trained for causal language modeling.
 
-##
+## Relevant Arguments
 - `temperature` (float): Sampling temperature (default: 1.0, higher = more random)
 - `top_k` (int): Only consider top-k most probable tokens (default: None)
 - `top_p` (float): Only consider tokens with cumulative probability <= top_p (default: None)
 - `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)
 
-
+### Logits Processing Order
+Logits processors are applied in sequence: `temperature → softmax → top_k → top_p` (same as HuggingFace's `LogitsProcessor` system). Temperature scaling occurs before top-p filtering, affecting the probability distribution that top-p operates on.
+
+For example, with `temperature=1.0`, `top_p=0.9` might include tokens A, B, C. With `temperature=0.5`, probability mass is much more concentrated, so `top_p=0.9` might only include token A.
+
+## Outputs
 When `return_dict_in_generate=True`, returns a dictionary with:
 - `sequences`: Generated token IDs
 - `scores`: Log probabilities of sampled tokens (with temperature/sampling modifications)
-- `
-
-- `lens`: Final sequence lengths
+- `logprobs`: Original model log probabilities (T=1, no modifications)
+Otherwise, returns a tensor of generated token IDs.
 
 ## Example usage
 
@@ -43,30 +48,30 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", devic
 inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
 
 # Basic sampling
-gen_out = model.generate(**inputs, custom_generate="manueldeprada/
+gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", trust_remote_code=True)
 
 # With temperature
-gen_out = model.generate(**inputs, custom_generate="manueldeprada/
+gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", temperature=0.8, trust_remote_code=True)
 
 # With top-k
-gen_out = model.generate(**inputs, custom_generate="manueldeprada/
+gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", top_k=50, trust_remote_code=True)
 
 # With top-p (nucleus sampling)
-gen_out = model.generate(**inputs, custom_generate="manueldeprada/
+gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", top_p=0.9, trust_remote_code=True)
 
 # Greedy decoding (no sampling)
-gen_out = model.generate(**inputs, custom_generate="manueldeprada/
+gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache", do_sample=False, trust_remote_code=True)
 
 # Get detailed output with probabilities
 gen_out = model.generate(
     **inputs,
-    custom_generate="manueldeprada/
+    custom_generate="manueldeprada/sampling_with_kvcache",
     return_dict_in_generate=True,
     trust_remote_code=True
 )
 print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
 print(f"Sampling scores: {gen_out['scores']}")
-print(f"Model log probabilities: {gen_out['logps']}")
+print(f"Model log probabilities: {gen_out['logprobs']}")
 ```
 
 ## Algorithm
@@ -82,47 +87,6 @@ print(f"Model log probabilities: {gen_out['logps']}")
   - Update KV cache and track sequence completion
 3. Return generated sequences and probability information
 
-## Helper Functions for Custom Generation
-
-The implementation provides two key helper functions that you can use to build your own generation strategies:
-
-### `init_gen(model_kwargs, model, max_new_tokens, bos_token_id)`
-Initializes the generation process and prepares the KV cache:
-- Sets up input sequences and model inputs
-- Prepares the KV cache for generation
-- Returns updated `model_kwargs` and `input_ids`
-
-### `ps_next(model, model_kwargs, input_ids)`
-Gets the next token logits and updates the KV cache:
-- Runs the model forward pass
-- Extracts logits for the last token
-- Updates the KV cache
-- Returns updated `model_kwargs` and `logits`
-
-### Example: Custom Generation Loop
-
-```py
-from ancestral_sampling.generate import init_gen, ps_next
-
-def custom_generation(model, model_kwargs, max_new_tokens=20, temperature=1.0):
-    # Initialize generation
-    model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens, bos_token_id)
-
-    for i in range(max_new_tokens):
-        # Get next token logits
-        model_kwargs, logits = ps_next(model, model_kwargs, input_ids)
-
-        # Your custom logic here
-        probs = (logits / temperature).softmax(-1)
-        next_token = torch.multinomial(probs, 1)
-
-        # Append token and continue
-        input_ids = torch.cat([input_ids, next_token], dim=-1)
-
-        # Add your stopping conditions
-        if next_token.item() == eos_token_id:
-            break
-
-    return input_ids
-```
+
+
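The README's new "Logits Processing Order" note (temperature is applied before the softmax and before top-p, so it changes which tokens survive nucleus filtering) can be checked with a few lines of standalone PyTorch. The sketch below is illustrative only and independent of the repository code; the `nucleus_probs` helper, the toy logits, and the thresholds are made up for the example.

```py
import torch

# Toy next-token logits over a 5-token vocabulary.
logits = torch.tensor([[2.0, 1.5, 1.0, 0.5, -1.0]])

def nucleus_probs(logits, temperature=1.0, top_p=0.9):
    # Order matches the README: temperature -> softmax -> (top-k skipped here) -> top-p.
    ps = (logits / temperature).softmax(-1)
    sorted_ps, idx = torch.sort(ps, descending=True)
    # Keep a token if the cumulative mass *before* it is still below top_p,
    # so the most probable token is always kept.
    keep = (torch.cumsum(sorted_ps, dim=-1) - sorted_ps) < top_p
    filtered = torch.zeros_like(ps).scatter(-1, idx, sorted_ps * keep)
    return filtered / filtered.sum(-1, keepdim=True)

print(nucleus_probs(logits, temperature=1.0))  # four tokens keep nonzero probability
print(nucleus_probs(logits, temperature=0.5))  # sharper distribution: only three survive
```

This also illustrates why the return dict carries both fields: `scores` are gathered from the filtered, temperature-scaled distribution, while `logprobs` come from the plain `softmax(-1)` of the raw logits, so the two can differ for the very same sampled token.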
custom_generate/generate.py
CHANGED
@@ -1,87 +1,157 @@
 import torch
-from transformers import
+from transformers import Cache, DynamicCache
+from transformers.generation.utils import ModelOutput
+from typing import Optional, Any
 
+def prepare_inputs_for_generation(
+    input_ids: torch.LongTensor,
+    past_key_values: Optional[Cache] = None,
+    attention_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+):
+    input_ids = input_ids[:, cache_position].clone(memory_format=torch.contiguous_format)
+    cur_len = input_ids.shape[1]
+    model_inputs = {"cache_position": cache_position,
+                    "past_key_values": past_key_values,
+                    "input_ids": input_ids,
+                    "inputs_embeds": None,
+                    "attention_mask": attention_mask,
+                    }
+    if attention_mask is not None and kwargs.get("position_ids") is None:
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        kwargs["position_ids"] = position_ids
+    if past_key_values is not None:
+        for name in ("position_ids", "token_type_ids"):
+            if name in kwargs:
+                kwargs[name] = kwargs[name][:, -cur_len:].clone(memory_format=torch.contiguous_format)
+    model_inputs.update({k: v for k, v in kwargs.items() if k not in model_inputs})
+    return model_inputs
 
-def
+def update_model_kwargs_for_generation(
+    outputs: ModelOutput,
+    model_kwargs: dict[str, Any],
+    num_new_tokens: int = 1,
+) -> dict[str, Any]:
+    model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
+    if "token_type_ids" in model_kwargs:
+        token_type_ids = model_kwargs["token_type_ids"]
+        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+    if "attention_mask" in model_kwargs:
+        attention_mask = model_kwargs["attention_mask"]
+        model_kwargs["attention_mask"] = torch.cat(
+            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+        )
+    model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+    return model_kwargs
+
+
+def next_logits_with_cache_update(model, model_kwargs, input_ids):
     """
-
-
+    Gets the next token logits and updates the KV cache:
+    - Runs the model forward pass
+    - Extracts logits for the last token
+    - Updates the KV cache
+    - Returns updated `model_kwargs` and `logits`
+
     Args:
         model: The language model
         model_kwargs: Model keyword arguments including KV cache
         input_ids: Current input token IDs
-
-
+
     Returns:
-        Updated model_kwargs,
+        Updated model_kwargs, logits for the next token
     """
-    model_inputs =
+    model_inputs = prepare_inputs_for_generation(input_ids, **model_kwargs)
     with torch.no_grad():
         outputs = model(**model_inputs, return_dict=True)
-
+
     logits = outputs.logits[:, -1].detach()
-    model_kwargs =
-        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
-    )
+    model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
    del outputs
    return model_kwargs, logits
 
+
 def init_gen(model_kwargs, model, max_new_tokens, bos_token_id):
     """
-
-
+    Initializes the generation process and prepares the KV cache:
+    - Sets up input sequences and model inputs
+    - Prepares the KV cache for generation
+    - Returns updated `model_kwargs` and `input_ids`
+
     Args:
         model_kwargs: Model keyword arguments
         model: The language model
         max_new_tokens: Maximum number of new tokens to generate
-
+        bos_token_id: Beginning-of-sequence token ID
+
     Returns:
         Model keyword arguments and input token IDs
     """
+    input_ids = model_kwargs.pop("input_ids")
+    model_kwargs["past_key_values"] = DynamicCache() if model_kwargs.get("past_key_values") is None else model_kwargs["past_key_values"]
+    assert isinstance(model_kwargs["past_key_values"], Cache), "past_key_values must be a Cache object"
+    cache_position = torch.ones(input_ids.shape[1], dtype=torch.int64, device=input_ids.device).cumsum(0) - 1
+    cache_position = cache_position[model_kwargs["past_key_values"].get_seq_length() :]
+    model_kwargs["cache_position"] = cache_position
+    return model_kwargs, input_ids
 
-
-
-
-
-
-
-
-
+
+def _apply_top_k(ps, model):
+    """Apply top-k filtering to probabilities."""
+    if not hasattr(model, "generation_config") or not hasattr(
+        model.generation_config, "top_k"
+    ):
+        return ps
+
+    top_k = model.generation_config.top_k
+    if top_k is None or top_k >= ps.size(-1):
+        return ps
+
+    indices_to_remove = ps < torch.topk(ps, top_k)[0][..., -1, None]
+    ps[indices_to_remove] = 0.0
+    return ps / ps.sum(dim=-1, keepdim=True)
+
+
+def _apply_top_p(ps, model):
+    """Apply top-p (nucleus) filtering to probabilities."""
+    if not hasattr(model, "generation_config") or not hasattr(
+        model.generation_config, "top_p"
+    ):
+        return ps
+
+    top_p = model.generation_config.top_p
+    if top_p is None or top_p >= 1.0:
+        return ps
+
+    sorted_probs, sorted_indices = torch.sort(ps, descending=True)
+    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+    sorted_indices_to_remove = cumulative_probs > top_p
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = 0
+
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        1, sorted_indices, sorted_indices_to_remove
     )
-
-
-
-    return model_kwargs, input_ids
+    ps[indices_to_remove] = 0.0
+    return ps / ps.sum(dim=-1, keepdim=True)
+
 
-def
-
-
-
-
-
-
-
-
-
-    top_p = model.generation_config.top_p
-    if top_p < 1.0:
-        sorted_probs, sorted_indices = torch.sort(ps, descending=True)
-        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        ps[indices_to_remove] = 0.0
-        ps = ps / ps.sum(dim=-1, keepdim=True)
-    return ps
-
-def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_token_id, do_sample=True, max_new_tokens=20, T=1.0):
+def sampling_with_kvcache(
+    model_kwargs,
+    model,
+    eos_token_ids,
+    pad_token_id,
+    bos_token_id,
+    do_sample=True,
+    max_new_tokens=20,
+    temperature=1.0,
+):
     """
-
-
+    Sampling implementation with proper KV caching.
+
     Args:
         prompts: List of input prompts
         model: The language model
@@ -90,55 +160,64 @@ def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_tok
         pad_token_id: Padding token ID
         bos_token_id: Beginning-of-sequence token ID
         max_new_tokens: Maximum number of new tokens to generate
-
+
     Returns:
         Generated sequences, log probabilities, and metadata
     """
     # Initialize the generation process and prepare the KV cache
-    model_kwargs, input_ids = init_gen(
-
-
+    model_kwargs, input_ids = init_gen(
+        model_kwargs, model, max_new_tokens, bos_token_id
+    )
+    batch_size, _ = input_ids.shape
 
     # Keeps track of which sequences are finished and their lengths
     active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
-    lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
     # Modified log probabilities of the sequences
     scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
-    # Unfiltered sequence log probabilities (
-
+    # Unfiltered sequence log probabilities (temperature=1, no sampling processors applied)
+    logprobs = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
 
     for i in range(max_new_tokens):
         # Get the next token probabilities and update the KV cache
-        model_kwargs, logits =
-
+        model_kwargs, logits = next_logits_with_cache_update(
+            model, model_kwargs, input_ids
+        )
+        # Store original model probabilities (temperature=1, no sampling processors applied)
         model_ps = logits.softmax(-1)
-
-
-        ps =
-
+
+        # Logit processors (temperature, top-k, top-p). We can chain these!
+        ps = (logits / temperature).softmax(-1)
+        ps = _apply_top_k(ps, model)
+        ps = _apply_top_p(ps, model)
+
         # Sample the next token and gather the log probabilities
-        if do_sample:
-            next_token_ids =
-
-
-
-
-
+        if do_sample:  # Sampling
+            next_token_ids = (
+                torch.multinomial(ps, 1) * active_seqs + pad_token_id * ~active_seqs
+            )
+        else:  # Greedy decoding
+            next_token_ids = (
+                torch.argmax(ps, dim=-1).unsqueeze(-1) * active_seqs
+                + pad_token_id * ~active_seqs
+            )
+        next_token_logprobs = ps.gather(-1, next_token_ids).log()
+        next_token_model_logprobs = model_ps.gather(-1, next_token_ids).log()
+
         input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
-        scores[:, i] = (
-
-
-        lens += active_seqs.squeeze(-1).long()
+        scores[:, i] = (next_token_logprobs * active_seqs).squeeze()
+        logprobs[:, i] = (next_token_model_logprobs * active_seqs).squeeze()
+
         active_seqs &= ~torch.isin(next_token_ids, eos_token_ids)
         if active_seqs.sum() == 0:
-            break
-    return input_ids.detach().cpu(), scores[
+            break
+    return input_ids.detach().cpu(), scores[:, : i + 1], logprobs[:, : i + 1]
+
 
 def generate(model, **kwargs):
     """
-
+    Sampling strategy - multinomial sampling with temperature and optional top-k/top-p filtering.
     Simple implementation with proper KV caching support.
-
+
     Args:
         model: The language model
         model_kwargs: Model keyword arguments from the tokenizer
@@ -147,29 +226,38 @@
         top_k: Only consider top-k most probable tokens
         top_p: Only consider tokens with cumulative probability <= top_p
         **kwargs: Additional arguments
-
+
     Returns:
         Generated token IDs
     """
     generation_config = model.generation_config
-    max_new_tokens = kwargs.get(
+    max_new_tokens = kwargs.get("max_new_tokens", generation_config.max_new_tokens)
     max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
-    do_sample = kwargs.get(
-    eos_token_ids = kwargs.get(
+    do_sample = kwargs.get("do_sample", True)
+    eos_token_ids = kwargs.get("eos_token_ids", generation_config.eos_token_id)
     if eos_token_ids is None:
-        raise ValueError(
+        raise ValueError(
+            "Model generation config does not have an EOS token id. You must provide it to generate() with the eos_token_ids argument."
+        )
     eos_token_ids = torch.as_tensor(eos_token_ids, device=model.device)
     if eos_token_ids is not None and eos_token_ids.ndim == 0:
         eos_token_ids = eos_token_ids.unsqueeze(0)
-
-    pad_token_id = kwargs.get(
-
+
+    pad_token_id = kwargs.get(
+        "pad_token_id",
+        generation_config.pad_token_id
+        if generation_config.pad_token_id is not None
+        else eos_token_ids[0],
+    )
+    bos_token_id = kwargs.get("bos_token_id", generation_config.bos_token_id)
     if bos_token_id is None:
-        raise ValueError(
-
-
+        raise ValueError(
+            "Model generation config does not have a BOS token id. You must provide it to generate() with the bos_token_id argument."
+        )
+    temperature = kwargs.get("temperature", 1.0)
+    return_dict = kwargs.get("return_dict_in_generate", False)
 
-    generated_ids, scores,
+    generated_ids, scores, logprobs = sampling_with_kvcache(
         model_kwargs=kwargs,
         model=model,
         eos_token_ids=eos_token_ids,
@@ -177,17 +265,14 @@
         bos_token_id=bos_token_id,
         do_sample=do_sample,
         max_new_tokens=max_new_tokens,
-
+        temperature=temperature,
     )
 
     if return_dict:
         return {
             "sequences": generated_ids,
             "scores": scores,
-            "
-            "prompt_lens": prompt_lens,
-            "lens": lens,
+            "logprobs": logprobs,
         }
     else:
         return generated_ids
-
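The commit drops the old README walkthrough of `init_gen` and `ps_next`, but the same two building blocks survive in the new file as `init_gen` and `next_logits_with_cache_update`, so a custom decoding loop can still be written against them. The sketch below is an assumption-laden illustration, not part of the repository: it presumes `custom_generate/generate.py` is importable locally as `generate`, and the 0.8 temperature and the EOS stopping rule are arbitrary choices.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical local import path; adjust to wherever custom_generate/generate.py lives.
from generate import init_gen, next_logits_with_cache_update

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# init_gen pops "input_ids" and seeds past_key_values / cache_position in model_kwargs.
model_kwargs = dict(tok("The quick brown", return_tensors="pt"))
model_kwargs, input_ids = init_gen(model_kwargs, model, max_new_tokens=20,
                                   bos_token_id=tok.bos_token_id)

for _ in range(20):
    # Forward pass over the uncached suffix only; returns logits for the last position.
    model_kwargs, logits = next_logits_with_cache_update(model, model_kwargs, input_ids)

    # Custom decoding logic goes here; plain temperature sampling as an example.
    probs = (logits / 0.8).softmax(-1)
    next_token = torch.multinomial(probs, 1)
    input_ids = torch.cat([input_ids, next_token], dim=-1)

    if next_token.item() == tok.eos_token_id:
        break

print(tok.decode(input_ids[0], skip_special_tokens=True))
```

Compared with the deleted example, only the helper names change; the KV cache, attention mask, and cache positions are still threaded through `model_kwargs` between calls.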