Tom Aarsen committed
Commit d713204 · 1 Parent(s): 9ae8623

Integrate with transformers, sentence transformers

Files changed (4)
  1. README.md +1 -1
  2. config.json +7 -1
  3. modeling_zeranker.py +128 -206
  4. tokenizer_config.json +4 -1
README.md CHANGED
@@ -41,8 +41,8 @@ query_documents = [
 ]
 
 scores = model.predict(query_documents)
-
 print(scores)
+# [0.7531883 0.28894895]
 ```
 
 The model can also be inferenced using ZeroEntropy's [/models/rerank](https://docs.zeroentropy.dev/api-reference/models/rerank) endpoint, and on [AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-o7avk66msiukc).
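The hunk above only shows the tail of the README snippet; the lines before 41 construct the model and the query-document pairs. A minimal sketch of how the full snippet reads after this change, assuming the repository id `zeroentropy/zerank-2` (the `MODEL_PATH` constant in the removed `modeling_zeranker.py`) and that the custom classes added in this commit require `trust_remote_code=True`; the example pairs below are placeholders, not the README's own:

```python
from sentence_transformers import CrossEncoder

# Assumed model id and loading flags; the custom ZeroEntropy* classes are
# registered via auto_map, so remote code must be trusted.
model = CrossEncoder("zeroentropy/zerank-2", trust_remote_code=True)

query = "What is the capital of France?"
query_documents = [
    (query, "Paris is the capital and most populous city of France."),
    (query, "Berlin is the capital of Germany."),
]

scores = model.predict(query_documents)  # one relevance score in (0, 1) per pair
print(scores)
```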
config.json CHANGED
@@ -1,9 +1,13 @@
 {
   "architectures": [
-    "Qwen3ForCausalLM"
+    "ZeroEntropyForSequenceClassification"
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZeroEntropyConfig",
+    "AutoModelForSequenceClassification": "modeling_zeranker.ZeroEntropyForSequenceClassification"
+  },
   "bos_token_id": 151643,
   "dtype": "bfloat16",
   "eos_token_id": 151645,
@@ -56,6 +60,8 @@
   "num_attention_heads": 32,
   "num_hidden_layers": 36,
   "num_key_value_heads": 8,
+  "num_labels": 1,
+  "pad_token_id": 151643,
   "rms_norm_eps": 1e-06,
   "rope_scaling": null,
   "rope_theta": 1000000,
modeling_zeranker.py CHANGED
@@ -1,216 +1,138 @@
-from sentence_transformers import CrossEncoder as _CE
-
-import math
-from typing import cast, Any
-import types
+from torch import nn
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.utils import auto_docstring
+from transformers.utils.generic import TransformersKwargs, can_return_tuple
 
+from typing import Optional, Union
 
+from transformers.processing_utils import Unpack
 import torch
-from transformers.configuration_utils import PretrainedConfig
-
-from transformers.models.auto.configuration_auto import AutoConfig
-from transformers.models.auto.modeling_auto import AutoModelForCausalLM
-from transformers.models.auto.tokenization_auto import AutoTokenizer
-from transformers.models.gemma3.modeling_gemma3 import (
-    Gemma3ForCausalLM,
-    Gemma3ForConditionalGeneration,
-)
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
-from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
-from transformers.tokenization_utils_base import BatchEncoding
+from transformers import Cache, Qwen3Config
+from transformers.models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel, Qwen3Model
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
-# pyright: reportUnknownMemberType=false
-# pyright: reportUnknownVariableType=false
-
-MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
-global_device = (
-    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-)
-
-
-def format_pointwise_datapoints(
-    tokenizer: PreTrainedTokenizerFast,
-    query_documents: list[tuple[str, str]],
-) -> BatchEncoding:
-    input_texts: list[str] = []
-    for query, document in query_documents:
-        system_prompt = f"""
-{query}
-""".strip()
-        user_message = f"""
-{document}
-""".strip()
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_message},
-        ]
-        input_text = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-        assert isinstance(input_text, str)
-        input_texts.append(input_text)
-
-    batch_inputs = tokenizer(
-        input_texts,
-        padding=True,
-        return_tensors="pt",
-    )
-    return batch_inputs
-
-
-def load_model(
-    device: torch.device | None = None,
-) -> tuple[
-    PreTrainedTokenizerFast,
-    LlamaForCausalLM
-    | Gemma3ForConditionalGeneration
-    | Gemma3ForCausalLM
-    | Qwen3ForCausalLM,
-]:
-    if device is None:
-        device = global_device
-
-    config = AutoConfig.from_pretrained(MODEL_PATH)
-    assert isinstance(config, PretrainedConfig)
-
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        torch_dtype="auto",
-        quantization_config=None,
-        device_map={"": device},
-    )
-    if config.model_type == "llama":
-        model.config.attn_implementation = "flash_attention_2"
-    assert isinstance(
-        model,
-        LlamaForCausalLM
-        | Gemma3ForConditionalGeneration
-        | Gemma3ForCausalLM
-        | Qwen3ForCausalLM,
-    )
-
-    tokenizer = cast(
-        AutoTokenizer,
-        AutoTokenizer.from_pretrained(
-            MODEL_PATH,
-            padding_side="right",
-        ),
-    )
-    assert isinstance(tokenizer, PreTrainedTokenizerFast)
-
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    return tokenizer, model
-
-
-def predict(
-    self,
-    query_documents: list[tuple[str, str]] | None = None,
-    *,
-    sentences: Any = None,
-    batch_size: Any = None,
-    show_progress_bar: Any = None,
-    activation_fn: Any = None,
-    apply_softmax: Any = None,
-    convert_to_numpy: Any = None,
-    convert_to_tensor: Any = None,
-) -> list[float]:
-    if query_documents is None:
-        if sentences is None:
-            raise ValueError("query_documents or sentences must be provided")
-        query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
-
-    if not hasattr(self, "inner_model"):
-        self.inner_tokenizer, self.inner_model = load_model(global_device)
-        self.inner_model.gradient_checkpointing_enable()
-        self.inner_model.eval()
-        self.inner_yes_token_id = self.inner_tokenizer.encode(
-            "Yes", add_special_tokens=False
-        )[0]
-
-    model = self.inner_model
-    tokenizer = self.inner_tokenizer
-
-    query_documents = [
-        (query[:2_000], document[:10_000]) for query, document in query_documents
-    ]
-    # Sort
-    permutation = list(range(len(query_documents)))
-    permutation.sort(
-        key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
-    )
-    query_documents = [query_documents[i] for i in permutation]
-
-    # Extract document batches from this line of datapoints
-    max_length = 0
-    batches: list[list[tuple[str, str]]] = []
-    for query, document in query_documents:
-        if (
-            len(batches) == 0
-            or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
-            > PER_DEVICE_BATCH_SIZE_TOKENS
-        ):
-            batches.append([])
-            max_length = 0
-
-        batches[-1].append((query, document))
-        max_length = max(max_length, 20 + len(query) + len(document))
-
-    # Inference all of the document batches
-    all_logits: list[float] = []
-    for batch in batches:
-        batch_inputs = format_pointwise_datapoints(
-            tokenizer,
-            batch,
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class ZeroEntropyTokenizer(PreTrainedTokenizerFast):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def __call__(self, pairs, *args, **kwargs):
+        input_texts: list[str] = []
+        for query, document in pairs:
+            messages = [
+                {"role": "system", "content": query.strip()},
+                {"role": "user", "content": document.strip()},
+            ]
+            input_text = self.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+            assert isinstance(input_text, str)
+            input_texts.append(input_text)
+
+        batch_inputs = super().__call__(input_texts, *args, **kwargs)
+        return batch_inputs
+
+
+class ZeroEntropyConfig(Qwen3Config):
+    model_type = "zeroentropy"
+
+    def __init__(self, yes_token_id: int = 9454, **kwargs):
+        super().__init__(**kwargs)
+        self.yes_token_id = yes_token_id
+
+
+class ZeroEntropyForSequenceClassification(Qwen3PreTrainedModel):
+    config: ZeroEntropyConfig
+
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Qwen3Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )
 
-        batch_inputs = batch_inputs.to(global_device)
-
-        try:
-            outputs = model(**batch_inputs, use_cache=False)
-        except torch.OutOfMemoryError:
-            print(f"GPU OOM! {torch.cuda.memory_reserved()}")
-            torch.cuda.empty_cache()
-            print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
-            outputs = model(**batch_inputs, use_cache=False)
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
 
-        # Extract the logits
-        logits = cast(torch.Tensor, outputs.logits)
-        attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
         last_positions = attention_mask.sum(dim=1) - 1
-
         batch_size = logits.shape[0]
-        batch_indices = torch.arange(batch_size, device=global_device)
-        last_logits = logits[batch_indices, last_positions]
-
-        yes_logits = last_logits[:, self.inner_yes_token_id]
-        all_logits.extend([float(logit) / 5.0 for logit in yes_logits])
-
-    def sigmoid(x: float) -> float:
-        return 1 / (1 + math.exp(-x))
-
-    scores = [sigmoid(logit) for logit in all_logits]
-
-    # Unsort by indices
-    scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
-
-    return scores
-
-
-def to_device(self: _CE, new_device: torch.device) -> None:
-    global global_device
-    global_device = new_device
-
-
-_CE.predict = predict
-
-from transformers import Qwen3Config
-
-ZEConfig = Qwen3Config
-
-_CE.to = to_device
+        batch_indices = torch.arange(batch_size, device=logits.device)
+        yes_logits = logits[batch_indices, last_positions, self.config.yes_token_id]
+        yes_logits = yes_logits / 5.0
+        yes_logits = yes_logits.unsqueeze(-1)
+
+        return SequenceClassifierOutputWithPast(
+            loss=None,
+            logits=yes_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
tokenizer_config.json CHANGED
@@ -226,6 +226,9 @@
     "<|image_pad|>",
     "<|video_pad|>"
   ],
+  "auto_map": {
+    "AutoTokenizer": [null, "modeling_zeranker.ZeroEntropyTokenizer"]
+  },
   "bos_token": null,
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
@@ -235,6 +238,6 @@
   "pad_token": "<|endoftext|>",
   "padding_side": "right",
   "split_special_tokens": false,
-  "tokenizer_class": "Qwen2Tokenizer",
+  "tokenizer_class": "ZeroEntropyTokenizer",
   "unk_token": null
 }
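Taken together with the model changes, the tokenizer-side `auto_map` means `AutoTokenizer` now resolves to `ZeroEntropyTokenizer`, which renders each (query, document) pair through the chat template with the query as the system message and the document as the user message; the model then scores the "Yes" token (`yes_token_id`, default 9454) at the last non-padded position. A small sketch for inspecting that prompt, assuming the same repository id and `trust_remote_code=True`; the pair is a placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)

# The same formatting ZeroEntropyTokenizer.__call__ applies to each pair internally.
messages = [
    {"role": "system", "content": "placeholder query"},
    {"role": "user", "content": "placeholder document"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # the text whose final "Yes"-token logit becomes the relevance logit
```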