Ti-Tai Wang commited on
Commit ·
619b90a
1
Parent(s): be5a948
update to 4.56.0.dev0-08152025
Browse files- onnx/modeling_gemma3.py +209 -104
onnx/modeling_gemma3.py
CHANGED
|
@@ -27,20 +27,21 @@ from typing import Optional, Union
|
|
| 27 |
import torch
|
| 28 |
import torch.nn as nn
|
| 29 |
|
| 30 |
-
from
|
| 31 |
-
from
|
| 32 |
-
from
|
| 33 |
-
from
|
| 34 |
-
from
|
| 35 |
-
from
|
| 36 |
-
from
|
| 37 |
-
from
|
| 38 |
-
from
|
| 39 |
-
from
|
| 40 |
-
from
|
| 41 |
-
from
|
| 42 |
-
from
|
| 43 |
-
from
|
|
|
|
| 44 |
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
|
| 45 |
|
| 46 |
|
|
@@ -55,7 +56,7 @@ logger = logging.get_logger(__name__)
|
|
| 55 |
)
|
| 56 |
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
|
| 57 |
r"""
|
| 58 |
-
past_key_values (`
|
| 59 |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 60 |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 61 |
|
|
@@ -81,7 +82,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput):
|
|
| 81 |
Language modeling loss (for next-token prediction).
|
| 82 |
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
| 83 |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 84 |
-
past_key_values (`
|
| 85 |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 86 |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 87 |
|
|
@@ -150,10 +151,12 @@ class Gemma3RMSNorm(nn.Module):
|
|
| 150 |
|
| 151 |
|
| 152 |
class Gemma3RotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
| 153 |
def __init__(self, config: Gemma3TextConfig, device=None):
|
| 154 |
super().__init__()
|
| 155 |
# BC: "rope_type" was originally "type"
|
| 156 |
-
if hasattr(config, "rope_scaling") and config.rope_scaling
|
| 157 |
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 158 |
else:
|
| 159 |
self.rope_type = "default"
|
|
@@ -296,12 +299,13 @@ class Gemma3Attention(nn.Module):
|
|
| 296 |
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 297 |
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 298 |
|
|
|
|
| 299 |
def forward(
|
| 300 |
self,
|
| 301 |
hidden_states: torch.Tensor,
|
| 302 |
position_embeddings: torch.Tensor,
|
| 303 |
attention_mask: Optional[torch.Tensor],
|
| 304 |
-
|
| 305 |
cache_position: Optional[torch.LongTensor] = None,
|
| 306 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 307 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
|
@@ -318,10 +322,10 @@ class Gemma3Attention(nn.Module):
|
|
| 318 |
cos, sin = position_embeddings
|
| 319 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 320 |
|
| 321 |
-
if
|
| 322 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 323 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 324 |
-
key_states, value_states =
|
| 325 |
|
| 326 |
attention_interface: Callable = eager_attention_forward
|
| 327 |
if self.config._attn_implementation != "eager":
|
|
@@ -358,7 +362,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
|
| 358 |
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 359 |
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 360 |
|
| 361 |
-
@deprecate_kwarg("
|
| 362 |
def forward(
|
| 363 |
self,
|
| 364 |
hidden_states: torch.Tensor,
|
|
@@ -366,7 +370,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
|
| 366 |
position_embeddings_local: torch.Tensor,
|
| 367 |
attention_mask: Optional[torch.Tensor] = None,
|
| 368 |
position_ids: Optional[torch.LongTensor] = None,
|
| 369 |
-
|
| 370 |
output_attentions: Optional[bool] = False,
|
| 371 |
use_cache: Optional[bool] = False,
|
| 372 |
cache_position: Optional[torch.LongTensor] = None,
|
|
@@ -387,7 +391,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
|
| 387 |
position_embeddings=position_embeddings,
|
| 388 |
attention_mask=attention_mask,
|
| 389 |
position_ids=position_ids,
|
| 390 |
-
|
| 391 |
output_attentions=output_attentions,
|
| 392 |
use_cache=use_cache,
|
| 393 |
cache_position=cache_position,
|
|
@@ -412,7 +416,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
|
| 412 |
|
| 413 |
@auto_docstring
|
| 414 |
class Gemma3PreTrainedModel(PreTrainedModel):
|
| 415 |
-
|
| 416 |
base_model_prefix = ""
|
| 417 |
supports_gradient_checkpointing = True
|
| 418 |
_no_split_modules = [
|
|
@@ -422,35 +426,26 @@ class Gemma3PreTrainedModel(PreTrainedModel):
|
|
| 422 |
"SiglipMultiheadAttentionPoolingHead",
|
| 423 |
]
|
| 424 |
_skip_keys_device_placement = ["past_key_values"]
|
| 425 |
-
|
| 426 |
-
_supports_flash_attn_2 = True
|
| 427 |
_supports_sdpa = True
|
| 428 |
_supports_flex_attn = True
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
_supports_static_cache = True
|
| 432 |
_supports_attention_backend = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
|
| 434 |
def _init_weights(self, module):
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 438 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 439 |
-
if module.bias is not None:
|
| 440 |
-
module.bias.data.zero_()
|
| 441 |
-
elif isinstance(module, nn.Embedding):
|
| 442 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 443 |
-
if module.padding_idx is not None:
|
| 444 |
-
module.weight.data[module.padding_idx].zero_()
|
| 445 |
-
elif isinstance(module, Gemma3RMSNorm):
|
| 446 |
-
module.weight.data.fill_(1.0)
|
| 447 |
-
elif isinstance(module, Gemma3MultiModalProjector):
|
| 448 |
module.mm_input_projection_weight.data.zero_()
|
| 449 |
|
| 450 |
|
| 451 |
@auto_docstring
|
| 452 |
class Gemma3TextModel(Gemma3PreTrainedModel):
|
| 453 |
-
|
| 454 |
|
| 455 |
def __init__(self, config: Gemma3TextConfig):
|
| 456 |
super().__init__(config)
|
|
@@ -478,13 +473,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
|
| 478 |
# Initialize weights and apply final processing
|
| 479 |
self.post_init()
|
| 480 |
|
| 481 |
-
|
| 482 |
-
return self.embed_tokens
|
| 483 |
-
|
| 484 |
-
def set_input_embeddings(self, value):
|
| 485 |
-
self.embed_tokens = value
|
| 486 |
-
|
| 487 |
-
@can_return_tuple
|
| 488 |
@auto_docstring
|
| 489 |
def forward(
|
| 490 |
self,
|
|
@@ -497,7 +486,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
|
| 497 |
output_attentions: Optional[bool] = None,
|
| 498 |
output_hidden_states: Optional[bool] = None,
|
| 499 |
cache_position: Optional[torch.LongTensor] = None,
|
| 500 |
-
**
|
| 501 |
) -> BaseModelOutputWithPast:
|
| 502 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 503 |
output_hidden_states = (
|
|
@@ -518,7 +507,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
|
| 518 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 519 |
|
| 520 |
if use_cache and past_key_values is None and not self.training:
|
| 521 |
-
past_key_values = DynamicCache()
|
| 522 |
|
| 523 |
if cache_position is None:
|
| 524 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
@@ -569,11 +558,11 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
|
| 569 |
position_embeddings_local=position_embeddings_local,
|
| 570 |
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
| 571 |
position_ids=position_ids,
|
| 572 |
-
|
| 573 |
output_attentions=output_attentions,
|
| 574 |
use_cache=use_cache,
|
| 575 |
cache_position=cache_position,
|
| 576 |
-
**
|
| 577 |
)
|
| 578 |
|
| 579 |
hidden_states = layer_outputs[0]
|
|
@@ -599,7 +588,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 599 |
_tied_weights_keys = ["lm_head.weight"]
|
| 600 |
_tp_plan = {"lm_head": "colwise_rep"}
|
| 601 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 602 |
-
|
| 603 |
base_model_prefix = "language_model"
|
| 604 |
|
| 605 |
def __init__(self, config: Gemma3TextConfig):
|
|
@@ -611,18 +600,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 611 |
# Initialize weights and apply final processing
|
| 612 |
self.post_init()
|
| 613 |
|
| 614 |
-
def get_input_embeddings(self):
|
| 615 |
-
return self.model.embed_tokens
|
| 616 |
-
|
| 617 |
-
def set_input_embeddings(self, value):
|
| 618 |
-
self.model.embed_tokens = value
|
| 619 |
-
|
| 620 |
-
def get_output_embeddings(self):
|
| 621 |
-
return self.lm_head
|
| 622 |
-
|
| 623 |
-
def set_output_embeddings(self, new_embeddings):
|
| 624 |
-
self.lm_head = new_embeddings
|
| 625 |
-
|
| 626 |
def set_decoder(self, decoder):
|
| 627 |
self.model = decoder
|
| 628 |
|
|
@@ -644,14 +621,9 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 644 |
output_hidden_states: Optional[bool] = None,
|
| 645 |
cache_position: Optional[torch.LongTensor] = None,
|
| 646 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 647 |
-
**
|
| 648 |
) -> CausalLMOutputWithPast:
|
| 649 |
r"""
|
| 650 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 651 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 652 |
-
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 653 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 654 |
-
|
| 655 |
Example:
|
| 656 |
|
| 657 |
```python
|
|
@@ -689,7 +661,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 689 |
output_attentions=output_attentions,
|
| 690 |
output_hidden_states=output_hidden_states,
|
| 691 |
cache_position=cache_position,
|
| 692 |
-
**
|
| 693 |
)
|
| 694 |
|
| 695 |
hidden_states = outputs.last_hidden_state
|
|
@@ -703,7 +675,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 703 |
|
| 704 |
loss = None
|
| 705 |
if labels is not None:
|
| 706 |
-
loss = self.loss_function(logits, labels, self.vocab_size, **
|
| 707 |
|
| 708 |
return CausalLMOutputWithPast(
|
| 709 |
loss=loss,
|
|
@@ -750,7 +722,11 @@ class Gemma3MultiModalProjector(nn.Module):
|
|
| 750 |
return projected_vision_outputs.type_as(vision_outputs)
|
| 751 |
|
| 752 |
|
| 753 |
-
def token_type_ids_mask_function(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
"""
|
| 755 |
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
| 756 |
not start and end indices.
|
|
@@ -760,10 +736,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
|
|
| 760 |
return None
|
| 761 |
|
| 762 |
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 763 |
-
# If the difference is less than image size, both are part of the same image block
|
| 764 |
-
same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
|
| 765 |
# If it's 1 for both query and key/value, we are in an image block
|
| 766 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
# This is bidirectional attention whenever we are dealing with image tokens
|
| 769 |
return is_image_block & same_image_block
|
|
@@ -850,6 +834,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
|
| 850 |
|
| 851 |
def image_features_is_none(inputs_embeds, image_features=None):
|
| 852 |
return inputs_embeds
|
|
|
|
| 853 |
def image_features_is_not_none(inputs_embeds, image_features=None):
|
| 854 |
# input_ids: [batch_size, seq_len]
|
| 855 |
# input_embeds: [batch_size, seq_len, 2560 (hidden_size)]
|
|
@@ -867,9 +852,33 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
|
| 867 |
image_features_is_not_none,
|
| 868 |
(inputs_embeds, image_features,)
|
| 869 |
)
|
| 870 |
-
|
| 871 |
return inputs_embeds
|
| 872 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
@can_return_tuple
|
| 874 |
@auto_docstring
|
| 875 |
def forward(
|
|
@@ -945,23 +954,10 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
|
| 945 |
# Merge text and images
|
| 946 |
if pixel_values is not None:
|
| 947 |
image_features = self.get_image_features(pixel_values)
|
| 948 |
-
|
| 949 |
-
if input_ids is None:
|
| 950 |
-
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 951 |
-
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 952 |
-
)
|
| 953 |
-
else:
|
| 954 |
-
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
| 955 |
-
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 956 |
-
|
| 957 |
-
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 958 |
-
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
| 959 |
-
raise ValueError(
|
| 960 |
-
f"Number of images does not match number of special image tokens in the input text. "
|
| 961 |
-
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
| 962 |
-
"tokens from image embeddings."
|
| 963 |
-
)
|
| 964 |
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
|
|
|
|
|
|
|
|
| 965 |
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 966 |
|
| 967 |
# It may already have been prepared by e.g. `generate`
|
|
@@ -977,8 +973,17 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
|
| 977 |
}
|
| 978 |
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
| 979 |
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 980 |
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
| 981 |
-
token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
|
| 982 |
)
|
| 983 |
|
| 984 |
# Create the masks
|
|
@@ -1035,12 +1040,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 1035 |
def set_input_embeddings(self, value):
|
| 1036 |
self.model.set_input_embeddings(value)
|
| 1037 |
|
| 1038 |
-
def get_output_embeddings(self):
|
| 1039 |
-
return self.lm_head
|
| 1040 |
-
|
| 1041 |
-
def set_output_embeddings(self, new_embeddings):
|
| 1042 |
-
self.lm_head = new_embeddings
|
| 1043 |
-
|
| 1044 |
def set_decoder(self, decoder):
|
| 1045 |
self.model.set_decoder(decoder)
|
| 1046 |
|
|
@@ -1219,6 +1218,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 1219 |
**kwargs,
|
| 1220 |
)
|
| 1221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1222 |
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 1223 |
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
| 1224 |
if cache_position[0] == 0:
|
|
@@ -1249,17 +1252,119 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
|
| 1249 |
# Add the token type ids mask for generate as well
|
| 1250 |
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
| 1251 |
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1252 |
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
| 1253 |
-
token_type_ids.to(cache_position.device), config.mm_tokens_per_image
|
| 1254 |
)
|
| 1255 |
|
| 1256 |
return create_masks_for_generate(**mask_kwargs)
|
| 1257 |
|
| 1258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1259 |
__all__ = [
|
| 1260 |
"Gemma3PreTrainedModel",
|
| 1261 |
"Gemma3TextModel",
|
| 1262 |
"Gemma3ForCausalLM",
|
| 1263 |
"Gemma3ForConditionalGeneration",
|
| 1264 |
"Gemma3Model",
|
|
|
|
| 1265 |
]
|
|
|
|
| 27 |
import torch
|
| 28 |
import torch.nn as nn
|
| 29 |
|
| 30 |
+
from ...activations import ACT2FN
|
| 31 |
+
from ...cache_utils import Cache, DynamicCache
|
| 32 |
+
from ...configuration_utils import PretrainedConfig
|
| 33 |
+
from ...generation import GenerationMixin
|
| 34 |
+
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
| 35 |
+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
| 36 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 37 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
| 38 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 39 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 40 |
+
from ...processing_utils import Unpack
|
| 41 |
+
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 42 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 43 |
+
from ...utils.generic import check_model_inputs
|
| 44 |
+
from ..auto import AutoModel
|
| 45 |
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
|
| 46 |
|
| 47 |
|
|
|
|
| 56 |
)
|
| 57 |
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
|
| 58 |
r"""
|
| 59 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 60 |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 61 |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 62 |
|
|
|
|
| 82 |
Language modeling loss (for next-token prediction).
|
| 83 |
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
| 84 |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 85 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 86 |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 87 |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 88 |
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
class Gemma3RotaryEmbedding(nn.Module):
|
| 154 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 155 |
+
|
| 156 |
def __init__(self, config: Gemma3TextConfig, device=None):
|
| 157 |
super().__init__()
|
| 158 |
# BC: "rope_type" was originally "type"
|
| 159 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 160 |
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 161 |
else:
|
| 162 |
self.rope_type = "default"
|
|
|
|
| 299 |
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 300 |
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
| 301 |
|
| 302 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 303 |
def forward(
|
| 304 |
self,
|
| 305 |
hidden_states: torch.Tensor,
|
| 306 |
position_embeddings: torch.Tensor,
|
| 307 |
attention_mask: Optional[torch.Tensor],
|
| 308 |
+
past_key_values: Optional[Cache] = None,
|
| 309 |
cache_position: Optional[torch.LongTensor] = None,
|
| 310 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 311 |
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
|
|
|
| 322 |
cos, sin = position_embeddings
|
| 323 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 324 |
|
| 325 |
+
if past_key_values is not None:
|
| 326 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 327 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 328 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 329 |
|
| 330 |
attention_interface: Callable = eager_attention_forward
|
| 331 |
if self.config._attn_implementation != "eager":
|
|
|
|
| 362 |
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 363 |
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 364 |
|
| 365 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 366 |
def forward(
|
| 367 |
self,
|
| 368 |
hidden_states: torch.Tensor,
|
|
|
|
| 370 |
position_embeddings_local: torch.Tensor,
|
| 371 |
attention_mask: Optional[torch.Tensor] = None,
|
| 372 |
position_ids: Optional[torch.LongTensor] = None,
|
| 373 |
+
past_key_values: Optional[Cache] = None,
|
| 374 |
output_attentions: Optional[bool] = False,
|
| 375 |
use_cache: Optional[bool] = False,
|
| 376 |
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
| 391 |
position_embeddings=position_embeddings,
|
| 392 |
attention_mask=attention_mask,
|
| 393 |
position_ids=position_ids,
|
| 394 |
+
past_key_values=past_key_values,
|
| 395 |
output_attentions=output_attentions,
|
| 396 |
use_cache=use_cache,
|
| 397 |
cache_position=cache_position,
|
|
|
|
| 416 |
|
| 417 |
@auto_docstring
|
| 418 |
class Gemma3PreTrainedModel(PreTrainedModel):
|
| 419 |
+
config: Gemma3Config
|
| 420 |
base_model_prefix = ""
|
| 421 |
supports_gradient_checkpointing = True
|
| 422 |
_no_split_modules = [
|
|
|
|
| 426 |
"SiglipMultiheadAttentionPoolingHead",
|
| 427 |
]
|
| 428 |
_skip_keys_device_placement = ["past_key_values"]
|
| 429 |
+
_supports_flash_attn = True
|
|
|
|
| 430 |
_supports_sdpa = True
|
| 431 |
_supports_flex_attn = True
|
| 432 |
+
|
| 433 |
+
_can_compile_fullgraph = True
|
|
|
|
| 434 |
_supports_attention_backend = True
|
| 435 |
+
_can_record_outputs = {
|
| 436 |
+
"hidden_states": Gemma3DecoderLayer,
|
| 437 |
+
"attentions": Gemma3Attention,
|
| 438 |
+
}
|
| 439 |
|
| 440 |
def _init_weights(self, module):
|
| 441 |
+
super()._init_weights(module)
|
| 442 |
+
if isinstance(module, Gemma3MultiModalProjector):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
module.mm_input_projection_weight.data.zero_()
|
| 444 |
|
| 445 |
|
| 446 |
@auto_docstring
|
| 447 |
class Gemma3TextModel(Gemma3PreTrainedModel):
|
| 448 |
+
config: Gemma3TextConfig
|
| 449 |
|
| 450 |
def __init__(self, config: Gemma3TextConfig):
|
| 451 |
super().__init__(config)
|
|
|
|
| 473 |
# Initialize weights and apply final processing
|
| 474 |
self.post_init()
|
| 475 |
|
| 476 |
+
@check_model_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
@auto_docstring
|
| 478 |
def forward(
|
| 479 |
self,
|
|
|
|
| 486 |
output_attentions: Optional[bool] = None,
|
| 487 |
output_hidden_states: Optional[bool] = None,
|
| 488 |
cache_position: Optional[torch.LongTensor] = None,
|
| 489 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 490 |
) -> BaseModelOutputWithPast:
|
| 491 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 492 |
output_hidden_states = (
|
|
|
|
| 507 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 508 |
|
| 509 |
if use_cache and past_key_values is None and not self.training:
|
| 510 |
+
past_key_values = DynamicCache(config=self.config)
|
| 511 |
|
| 512 |
if cache_position is None:
|
| 513 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
|
|
| 558 |
position_embeddings_local=position_embeddings_local,
|
| 559 |
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
| 560 |
position_ids=position_ids,
|
| 561 |
+
past_key_values=past_key_values,
|
| 562 |
output_attentions=output_attentions,
|
| 563 |
use_cache=use_cache,
|
| 564 |
cache_position=cache_position,
|
| 565 |
+
**kwargs,
|
| 566 |
)
|
| 567 |
|
| 568 |
hidden_states = layer_outputs[0]
|
|
|
|
| 588 |
_tied_weights_keys = ["lm_head.weight"]
|
| 589 |
_tp_plan = {"lm_head": "colwise_rep"}
|
| 590 |
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 591 |
+
config: Gemma3TextConfig
|
| 592 |
base_model_prefix = "language_model"
|
| 593 |
|
| 594 |
def __init__(self, config: Gemma3TextConfig):
|
|
|
|
| 600 |
# Initialize weights and apply final processing
|
| 601 |
self.post_init()
|
| 602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
def set_decoder(self, decoder):
|
| 604 |
self.model = decoder
|
| 605 |
|
|
|
|
| 621 |
output_hidden_states: Optional[bool] = None,
|
| 622 |
cache_position: Optional[torch.LongTensor] = None,
|
| 623 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 624 |
+
**kwargs,
|
| 625 |
) -> CausalLMOutputWithPast:
|
| 626 |
r"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
Example:
|
| 628 |
|
| 629 |
```python
|
|
|
|
| 661 |
output_attentions=output_attentions,
|
| 662 |
output_hidden_states=output_hidden_states,
|
| 663 |
cache_position=cache_position,
|
| 664 |
+
**kwargs,
|
| 665 |
)
|
| 666 |
|
| 667 |
hidden_states = outputs.last_hidden_state
|
|
|
|
| 675 |
|
| 676 |
loss = None
|
| 677 |
if labels is not None:
|
| 678 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
| 679 |
|
| 680 |
return CausalLMOutputWithPast(
|
| 681 |
loss=loss,
|
|
|
|
| 722 |
return projected_vision_outputs.type_as(vision_outputs)
|
| 723 |
|
| 724 |
|
| 725 |
+
def token_type_ids_mask_function(
|
| 726 |
+
token_type_ids: Optional[torch.Tensor],
|
| 727 |
+
image_group_ids: Optional[torch.Tensor],
|
| 728 |
+
tokens_per_image: int,
|
| 729 |
+
) -> Optional[Callable]:
|
| 730 |
"""
|
| 731 |
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
| 732 |
not start and end indices.
|
|
|
|
| 736 |
return None
|
| 737 |
|
| 738 |
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
|
|
|
|
|
|
| 739 |
# If it's 1 for both query and key/value, we are in an image block
|
| 740 |
+
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
|
| 741 |
+
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
|
| 742 |
+
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
| 743 |
+
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
|
| 744 |
+
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
|
| 745 |
+
|
| 746 |
+
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
|
| 747 |
+
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
|
| 748 |
+
|
| 749 |
+
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
|
| 750 |
+
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
|
| 751 |
|
| 752 |
# This is bidirectional attention whenever we are dealing with image tokens
|
| 753 |
return is_image_block & same_image_block
|
|
|
|
| 834 |
|
| 835 |
def image_features_is_none(inputs_embeds, image_features=None):
|
| 836 |
return inputs_embeds
|
| 837 |
+
|
| 838 |
def image_features_is_not_none(inputs_embeds, image_features=None):
|
| 839 |
# input_ids: [batch_size, seq_len]
|
| 840 |
# input_embeds: [batch_size, seq_len, 2560 (hidden_size)]
|
|
|
|
| 852 |
image_features_is_not_none,
|
| 853 |
(inputs_embeds, image_features,)
|
| 854 |
)
|
| 855 |
+
|
| 856 |
return inputs_embeds
|
| 857 |
|
| 858 |
+
def get_placeholder_mask(
|
| 859 |
+
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
| 860 |
+
):
|
| 861 |
+
"""
|
| 862 |
+
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
| 863 |
+
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
| 864 |
+
"""
|
| 865 |
+
if input_ids is None:
|
| 866 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 867 |
+
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 868 |
+
)
|
| 869 |
+
special_image_mask = special_image_mask.all(-1)
|
| 870 |
+
else:
|
| 871 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 872 |
+
|
| 873 |
+
n_image_tokens = special_image_mask.sum()
|
| 874 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 875 |
+
n_image_features = image_features.shape[0] * image_features.shape[1]
|
| 876 |
+
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 877 |
+
raise ValueError(
|
| 878 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
| 879 |
+
)
|
| 880 |
+
return special_image_mask
|
| 881 |
+
|
| 882 |
@can_return_tuple
|
| 883 |
@auto_docstring
|
| 884 |
def forward(
|
|
|
|
| 954 |
# Merge text and images
|
| 955 |
if pixel_values is not None:
|
| 956 |
image_features = self.get_image_features(pixel_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 958 |
+
special_image_mask = self.get_placeholder_mask(
|
| 959 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
| 960 |
+
)
|
| 961 |
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 962 |
|
| 963 |
# It may already have been prepared by e.g. `generate`
|
|
|
|
| 973 |
}
|
| 974 |
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
| 975 |
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
| 976 |
+
|
| 977 |
+
# First find where a new image block starts: 1 if image and previous not image
|
| 978 |
+
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
| 979 |
+
is_image = (token_type_ids == 1).to(cache_position.device)
|
| 980 |
+
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
| 981 |
+
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
| 982 |
+
image_group_ids = torch.where(
|
| 983 |
+
is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device)
|
| 984 |
+
)
|
| 985 |
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
| 986 |
+
token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
|
| 987 |
)
|
| 988 |
|
| 989 |
# Create the masks
|
|
|
|
| 1040 |
def set_input_embeddings(self, value):
|
| 1041 |
self.model.set_input_embeddings(value)
|
| 1042 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1043 |
def set_decoder(self, decoder):
|
| 1044 |
self.model.set_decoder(decoder)
|
| 1045 |
|
|
|
|
| 1218 |
**kwargs,
|
| 1219 |
)
|
| 1220 |
|
| 1221 |
+
# position_ids in Gemma3 are 1-indexed
|
| 1222 |
+
if model_inputs.get("position_ids") is not None:
|
| 1223 |
+
model_inputs["position_ids"] += 1
|
| 1224 |
+
|
| 1225 |
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 1226 |
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
| 1227 |
if cache_position[0] == 0:
|
|
|
|
| 1252 |
# Add the token type ids mask for generate as well
|
| 1253 |
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
| 1254 |
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
| 1255 |
+
|
| 1256 |
+
# First find where a new image block starts: 1 if image and previous not image
|
| 1257 |
+
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
| 1258 |
+
is_image = (token_type_ids == 1).to(cache_position.device)
|
| 1259 |
+
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
| 1260 |
+
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
| 1261 |
+
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
| 1262 |
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
| 1263 |
+
token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
|
| 1264 |
)
|
| 1265 |
|
| 1266 |
return create_masks_for_generate(**mask_kwargs)
|
| 1267 |
|
| 1268 |
|
| 1269 |
+
class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
|
| 1270 |
+
_checkpoint_conversion_mapping = {
|
| 1271 |
+
"^language_model.model": "model.language_model",
|
| 1272 |
+
"^vision_tower": "model.vision_tower",
|
| 1273 |
+
"^multi_modal_projector": "model.multi_modal_projector",
|
| 1274 |
+
}
|
| 1275 |
+
|
| 1276 |
+
def __init__(self, config):
|
| 1277 |
+
super().__init__(config)
|
| 1278 |
+
self.num_labels = config.num_labels
|
| 1279 |
+
self.model = Gemma3Model(config)
|
| 1280 |
+
self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
|
| 1281 |
+
|
| 1282 |
+
# Initialize weights and apply final processing
|
| 1283 |
+
self.post_init()
|
| 1284 |
+
|
| 1285 |
+
def get_input_embeddings(self):
|
| 1286 |
+
return self.model.get_input_embeddings()
|
| 1287 |
+
|
| 1288 |
+
def set_input_embeddings(self, value):
|
| 1289 |
+
self.model.set_input_embeddings(value)
|
| 1290 |
+
|
| 1291 |
+
@can_return_tuple
|
| 1292 |
+
@auto_docstring
|
| 1293 |
+
def forward(
|
| 1294 |
+
self,
|
| 1295 |
+
input_ids: torch.LongTensor = None,
|
| 1296 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1297 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1298 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1299 |
+
past_key_values: Optional[Cache] = None,
|
| 1300 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1301 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1302 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1303 |
+
use_cache: Optional[bool] = None,
|
| 1304 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1305 |
+
) -> SequenceClassifierOutputWithPast:
|
| 1306 |
+
r"""
|
| 1307 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1308 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1309 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1310 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1311 |
+
"""
|
| 1312 |
+
|
| 1313 |
+
transformer_outputs = self.model(
|
| 1314 |
+
input_ids,
|
| 1315 |
+
attention_mask=attention_mask,
|
| 1316 |
+
pixel_values=pixel_values,
|
| 1317 |
+
position_ids=position_ids,
|
| 1318 |
+
past_key_values=past_key_values,
|
| 1319 |
+
inputs_embeds=inputs_embeds,
|
| 1320 |
+
token_type_ids=token_type_ids,
|
| 1321 |
+
use_cache=use_cache,
|
| 1322 |
+
**kwargs,
|
| 1323 |
+
)
|
| 1324 |
+
hidden_states = transformer_outputs.last_hidden_state
|
| 1325 |
+
logits = self.score(hidden_states)
|
| 1326 |
+
|
| 1327 |
+
if input_ids is not None:
|
| 1328 |
+
batch_size = input_ids.shape[0]
|
| 1329 |
+
else:
|
| 1330 |
+
batch_size = inputs_embeds.shape[0]
|
| 1331 |
+
|
| 1332 |
+
if self.config.text_config.pad_token_id is None and batch_size != 1:
|
| 1333 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 1334 |
+
if self.config.text_config.pad_token_id is None:
|
| 1335 |
+
last_non_pad_token = -1
|
| 1336 |
+
elif input_ids is not None:
|
| 1337 |
+
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
| 1338 |
+
non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
|
| 1339 |
+
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
| 1340 |
+
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
| 1341 |
+
else:
|
| 1342 |
+
last_non_pad_token = -1
|
| 1343 |
+
logger.warning_once(
|
| 1344 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 1345 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
| 1349 |
+
|
| 1350 |
+
loss = None
|
| 1351 |
+
if labels is not None:
|
| 1352 |
+
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
| 1353 |
+
|
| 1354 |
+
return SequenceClassifierOutputWithPast(
|
| 1355 |
+
loss=loss,
|
| 1356 |
+
logits=pooled_logits,
|
| 1357 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 1358 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 1359 |
+
attentions=transformer_outputs.attentions,
|
| 1360 |
+
)
|
| 1361 |
+
|
| 1362 |
+
|
| 1363 |
__all__ = [
|
| 1364 |
"Gemma3PreTrainedModel",
|
| 1365 |
"Gemma3TextModel",
|
| 1366 |
"Gemma3ForCausalLM",
|
| 1367 |
"Gemma3ForConditionalGeneration",
|
| 1368 |
"Gemma3Model",
|
| 1369 |
+
"Gemma3ForSequenceClassification",
|
| 1370 |
]
|