Tom Aarsen
commited on
Commit
·
d713204
1
Parent(s):
9ae8623
Integrate with transformers, sentence transformers
Browse files- README.md +1 -1
- config.json +7 -1
- modeling_zeranker.py +128 -206
- tokenizer_config.json +4 -1
README.md
CHANGED
|
@@ -41,8 +41,8 @@ query_documents = [
|
|
| 41 |
]
|
| 42 |
|
| 43 |
scores = model.predict(query_documents)
|
| 44 |
-
|
| 45 |
print(scores)
|
|
|
|
| 46 |
```
|
| 47 |
|
| 48 |
The model can also be inferenced using ZeroEntropy's [/models/rerank](https://docs.zeroentropy.dev/api-reference/models/rerank) endpoint, and on [AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-o7avk66msiukc).
|
|
|
|
| 41 |
]
|
| 42 |
|
| 43 |
scores = model.predict(query_documents)
|
|
|
|
| 44 |
print(scores)
|
| 45 |
+
# [0.7531883 0.28894895]
|
| 46 |
```
|
| 47 |
|
| 48 |
The model can also be inferenced using ZeroEntropy's [/models/rerank](https://docs.zeroentropy.dev/api-reference/models/rerank) endpoint, and on [AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-o7avk66msiukc).
|
config.json
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
-
"
|
| 4 |
],
|
| 5 |
"attention_bias": false,
|
| 6 |
"attention_dropout": 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
"bos_token_id": 151643,
|
| 8 |
"dtype": "bfloat16",
|
| 9 |
"eos_token_id": 151645,
|
|
@@ -56,6 +60,8 @@
|
|
| 56 |
"num_attention_heads": 32,
|
| 57 |
"num_hidden_layers": 36,
|
| 58 |
"num_key_value_heads": 8,
|
|
|
|
|
|
|
| 59 |
"rms_norm_eps": 1e-06,
|
| 60 |
"rope_scaling": null,
|
| 61 |
"rope_theta": 1000000,
|
|
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
+
"ZeroEntropyForSequenceClassification"
|
| 4 |
],
|
| 5 |
"attention_bias": false,
|
| 6 |
"attention_dropout": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "modeling_zeranker.ZeroEntropyConfig",
|
| 9 |
+
"AutoModelForSequenceClassification": "modeling_zeranker.ZeroEntropyForSequenceClassification"
|
| 10 |
+
},
|
| 11 |
"bos_token_id": 151643,
|
| 12 |
"dtype": "bfloat16",
|
| 13 |
"eos_token_id": 151645,
|
|
|
|
| 60 |
"num_attention_heads": 32,
|
| 61 |
"num_hidden_layers": 36,
|
| 62 |
"num_key_value_heads": 8,
|
| 63 |
+
"num_labels": 1,
|
| 64 |
+
"pad_token_id": 151643,
|
| 65 |
"rms_norm_eps": 1e-06,
|
| 66 |
"rope_scaling": null,
|
| 67 |
"rope_theta": 1000000,
|
modeling_zeranker.py
CHANGED
|
@@ -1,216 +1,138 @@
|
|
| 1 |
-
from
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
|
|
|
|
| 7 |
|
|
|
|
| 8 |
import torch
|
| 9 |
-
from transformers
|
| 10 |
-
|
| 11 |
-
from transformers.models.auto.configuration_auto import AutoConfig
|
| 12 |
-
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
| 13 |
-
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 14 |
-
from transformers.models.gemma3.modeling_gemma3 import (
|
| 15 |
-
Gemma3ForCausalLM,
|
| 16 |
-
Gemma3ForConditionalGeneration,
|
| 17 |
-
)
|
| 18 |
-
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
| 19 |
-
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
| 20 |
-
from transformers.tokenization_utils_base import BatchEncoding
|
| 21 |
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
AutoTokenizer,
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
apply_softmax: Any = None,
|
| 120 |
-
convert_to_numpy: Any = None,
|
| 121 |
-
convert_to_tensor: Any = None,
|
| 122 |
-
) -> list[float]:
|
| 123 |
-
if query_documents is None:
|
| 124 |
-
if sentences is None:
|
| 125 |
-
raise ValueError("query_documents or sentences must be provided")
|
| 126 |
-
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 127 |
-
|
| 128 |
-
if not hasattr(self, "inner_model"):
|
| 129 |
-
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 130 |
-
self.inner_model.gradient_checkpointing_enable()
|
| 131 |
-
self.inner_model.eval()
|
| 132 |
-
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 133 |
-
"Yes", add_special_tokens=False
|
| 134 |
-
)[0]
|
| 135 |
-
|
| 136 |
-
model = self.inner_model
|
| 137 |
-
tokenizer = self.inner_tokenizer
|
| 138 |
-
|
| 139 |
-
query_documents = [
|
| 140 |
-
(query[:2_000], document[:10_000]) for query, document in query_documents
|
| 141 |
-
]
|
| 142 |
-
# Sort
|
| 143 |
-
permutation = list(range(len(query_documents)))
|
| 144 |
-
permutation.sort(
|
| 145 |
-
key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
|
| 146 |
-
)
|
| 147 |
-
query_documents = [query_documents[i] for i in permutation]
|
| 148 |
-
|
| 149 |
-
# Extract document batches from this line of datapoints
|
| 150 |
-
max_length = 0
|
| 151 |
-
batches: list[list[tuple[str, str]]] = []
|
| 152 |
-
for query, document in query_documents:
|
| 153 |
-
if (
|
| 154 |
-
len(batches) == 0
|
| 155 |
-
or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
|
| 156 |
-
> PER_DEVICE_BATCH_SIZE_TOKENS
|
| 157 |
-
):
|
| 158 |
-
batches.append([])
|
| 159 |
-
max_length = 0
|
| 160 |
-
|
| 161 |
-
batches[-1].append((query, document))
|
| 162 |
-
max_length = max(max_length, 20 + len(query) + len(document))
|
| 163 |
-
|
| 164 |
-
# Inference all of the document batches
|
| 165 |
-
all_logits: list[float] = []
|
| 166 |
-
for batch in batches:
|
| 167 |
-
batch_inputs = format_pointwise_datapoints(
|
| 168 |
-
tokenizer,
|
| 169 |
-
batch,
|
| 170 |
)
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
outputs = model(**batch_inputs, use_cache=False)
|
| 181 |
|
| 182 |
-
# Extract the logits
|
| 183 |
-
logits = cast(torch.Tensor, outputs.logits)
|
| 184 |
-
attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
|
| 185 |
last_positions = attention_mask.sum(dim=1) - 1
|
| 186 |
-
|
| 187 |
batch_size = logits.shape[0]
|
| 188 |
-
batch_indices = torch.arange(batch_size, device=
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
yes_logits =
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
|
| 201 |
-
|
| 202 |
-
return scores
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 206 |
-
global global_device
|
| 207 |
-
global_device = new_device
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
_CE.predict = predict
|
| 211 |
-
|
| 212 |
-
from transformers import Qwen3Config
|
| 213 |
-
|
| 214 |
-
ZEConfig = Qwen3Config
|
| 215 |
-
|
| 216 |
-
_CE.to = to_device
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from transformers.modeling_outputs import (
|
| 3 |
+
BaseModelOutputWithPast,
|
| 4 |
+
CausalLMOutputWithPast,
|
| 5 |
+
SequenceClassifierOutputWithPast,
|
| 6 |
+
)
|
| 7 |
+
from transformers.utils import auto_docstring
|
| 8 |
+
from transformers.utils.generic import TransformersKwargs, can_return_tuple
|
| 9 |
|
| 10 |
+
from typing import Optional, Union
|
| 11 |
|
| 12 |
+
from transformers.processing_utils import Unpack
|
| 13 |
import torch
|
| 14 |
+
from transformers import Cache, Qwen3Config
|
| 15 |
+
from transformers.models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel, Qwen3Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 17 |
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
logger = logging.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ZeroEntropyTokenizer(PreTrainedTokenizerFast):
|
| 24 |
+
def __init__(self, **kwargs):
|
| 25 |
+
super().__init__(**kwargs)
|
| 26 |
+
|
| 27 |
+
def __call__(self, pairs, *args, **kwargs):
|
| 28 |
+
input_texts: list[str] = []
|
| 29 |
+
for query, document in pairs:
|
| 30 |
+
messages = [
|
| 31 |
+
{"role": "system", "content": query.strip()},
|
| 32 |
+
{"role": "user", "content": document.strip()},
|
| 33 |
+
]
|
| 34 |
+
input_text = self.apply_chat_template(
|
| 35 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 36 |
+
)
|
| 37 |
+
assert isinstance(input_text, str)
|
| 38 |
+
input_texts.append(input_text)
|
| 39 |
+
|
| 40 |
+
batch_inputs = super().__call__(input_texts, *args, **kwargs)
|
| 41 |
+
return batch_inputs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ZeroEntropyConfig(Qwen3Config):
|
| 45 |
+
model_type = "zeroentropy"
|
| 46 |
+
|
| 47 |
+
def __init__(self, yes_token_id: int = 9454, **kwargs):
|
| 48 |
+
super().__init__(**kwargs)
|
| 49 |
+
self.yes_token_id = yes_token_id
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ZeroEntropyForSequenceClassification(Qwen3PreTrainedModel):
|
| 53 |
+
config: ZeroEntropyConfig
|
| 54 |
+
|
| 55 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 56 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 57 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 58 |
+
|
| 59 |
+
def __init__(self, config):
|
| 60 |
+
super().__init__(config)
|
| 61 |
+
self.model = Qwen3Model(config)
|
| 62 |
+
self.vocab_size = config.vocab_size
|
| 63 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 64 |
+
|
| 65 |
+
# Initialize weights and apply final processing
|
| 66 |
+
self.post_init()
|
| 67 |
+
|
| 68 |
+
@can_return_tuple
|
| 69 |
+
@auto_docstring
|
| 70 |
+
def forward(
|
| 71 |
+
self,
|
| 72 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 73 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 74 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 75 |
+
past_key_values: Optional[Cache] = None,
|
| 76 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 77 |
+
labels: Optional[torch.LongTensor] = None,
|
| 78 |
+
use_cache: Optional[bool] = None,
|
| 79 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 80 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 81 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 82 |
+
) -> CausalLMOutputWithPast:
|
| 83 |
+
r"""
|
| 84 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 85 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 86 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 87 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 88 |
+
|
| 89 |
+
Example:
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
>>> from transformers import AutoTokenizer, Qwen3ForCausalLM
|
| 93 |
+
|
| 94 |
+
>>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
|
| 95 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
| 96 |
+
|
| 97 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 98 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 99 |
+
|
| 100 |
+
>>> # Generate
|
| 101 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 102 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 103 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 104 |
+
```"""
|
| 105 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 106 |
+
input_ids=input_ids,
|
| 107 |
+
attention_mask=attention_mask,
|
| 108 |
+
position_ids=position_ids,
|
| 109 |
+
past_key_values=past_key_values,
|
| 110 |
+
inputs_embeds=inputs_embeds,
|
| 111 |
+
use_cache=use_cache,
|
| 112 |
+
cache_position=cache_position,
|
| 113 |
+
**kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
)
|
| 115 |
|
| 116 |
+
hidden_states = outputs.last_hidden_state
|
| 117 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 118 |
+
slice_indices = (
|
| 119 |
+
slice(-logits_to_keep, None)
|
| 120 |
+
if isinstance(logits_to_keep, int)
|
| 121 |
+
else logits_to_keep
|
| 122 |
+
)
|
| 123 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
|
|
| 124 |
|
|
|
|
|
|
|
|
|
|
| 125 |
last_positions = attention_mask.sum(dim=1) - 1
|
|
|
|
| 126 |
batch_size = logits.shape[0]
|
| 127 |
+
batch_indices = torch.arange(batch_size, device=logits.device)
|
| 128 |
+
yes_logits = logits[batch_indices, last_positions, self.config.yes_token_id]
|
| 129 |
+
yes_logits = yes_logits / 5.0
|
| 130 |
+
yes_logits = yes_logits.unsqueeze(-1)
|
| 131 |
+
|
| 132 |
+
return SequenceClassifierOutputWithPast(
|
| 133 |
+
loss=None,
|
| 134 |
+
logits=yes_logits,
|
| 135 |
+
past_key_values=outputs.past_key_values,
|
| 136 |
+
hidden_states=outputs.hidden_states,
|
| 137 |
+
attentions=outputs.attentions,
|
| 138 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer_config.json
CHANGED
|
@@ -226,6 +226,9 @@
|
|
| 226 |
"<|image_pad|>",
|
| 227 |
"<|video_pad|>"
|
| 228 |
],
|
|
|
|
|
|
|
|
|
|
| 229 |
"bos_token": null,
|
| 230 |
"clean_up_tokenization_spaces": false,
|
| 231 |
"eos_token": "<|im_end|>",
|
|
@@ -235,6 +238,6 @@
|
|
| 235 |
"pad_token": "<|endoftext|>",
|
| 236 |
"padding_side": "right",
|
| 237 |
"split_special_tokens": false,
|
| 238 |
-
"tokenizer_class": "
|
| 239 |
"unk_token": null
|
| 240 |
}
|
|
|
|
| 226 |
"<|image_pad|>",
|
| 227 |
"<|video_pad|>"
|
| 228 |
],
|
| 229 |
+
"auto_map": {
|
| 230 |
+
"AutoTokenizer": [null, "modeling_zeranker.ZeroEntropyTokenizer"]
|
| 231 |
+
},
|
| 232 |
"bos_token": null,
|
| 233 |
"clean_up_tokenization_spaces": false,
|
| 234 |
"eos_token": "<|im_end|>",
|
|
|
|
| 238 |
"pad_token": "<|endoftext|>",
|
| 239 |
"padding_side": "right",
|
| 240 |
"split_special_tokens": false,
|
| 241 |
+
"tokenizer_class": "ZeroEntropyTokenizer",
|
| 242 |
"unk_token": null
|
| 243 |
}
|