Ti-Tai Wang committed on
Commit
619b90a
·
1 Parent(s): be5a948

update to 4.56.0.dev0-08152025

Browse files
Files changed (1) hide show
  1. onnx/modeling_gemma3.py +209 -104
onnx/modeling_gemma3.py CHANGED
@@ -27,20 +27,21 @@ from typing import Optional, Union
27
  import torch
28
  import torch.nn as nn
29
 
30
- from transformers.activations import ACT2FN
31
- from transformers.cache_utils import Cache, DynamicCache
32
- from transformers.configuration_utils import PretrainedConfig
33
- from transformers.generation import GenerationMixin
34
- from transformers.masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
35
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
36
- from transformers.modeling_layers import GradientCheckpointingLayer
37
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
38
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
- from transformers.processing_utils import Unpack
41
- from transformers.utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
42
- from transformers.utils.deprecation import deprecate_kwarg
43
- from transformers import AutoModel
 
44
  from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
45
 
46
 
@@ -55,7 +56,7 @@ logger = logging.get_logger(__name__)
55
  )
56
  class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
57
  r"""
58
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
59
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
60
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
61
 
@@ -81,7 +82,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput):
81
  Language modeling loss (for next-token prediction).
82
  logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
83
  Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
84
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
85
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
86
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
87
 
@@ -150,10 +151,12 @@ class Gemma3RMSNorm(nn.Module):
150
 
151
 
152
  class Gemma3RotaryEmbedding(nn.Module):
 
 
153
  def __init__(self, config: Gemma3TextConfig, device=None):
154
  super().__init__()
155
  # BC: "rope_type" was originally "type"
156
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
157
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
158
  else:
159
  self.rope_type = "default"
@@ -296,12 +299,13 @@ class Gemma3Attention(nn.Module):
296
  self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
297
  self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
298
 
 
299
  def forward(
300
  self,
301
  hidden_states: torch.Tensor,
302
  position_embeddings: torch.Tensor,
303
  attention_mask: Optional[torch.Tensor],
304
- past_key_value: Optional[Cache] = None,
305
  cache_position: Optional[torch.LongTensor] = None,
306
  **kwargs: Unpack[FlashAttentionKwargs],
307
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
@@ -318,10 +322,10 @@ class Gemma3Attention(nn.Module):
318
  cos, sin = position_embeddings
319
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
320
 
321
- if past_key_value is not None:
322
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
323
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
324
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
325
 
326
  attention_interface: Callable = eager_attention_forward
327
  if self.config._attn_implementation != "eager":
@@ -358,7 +362,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
358
  self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
359
  self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
360
 
361
- @deprecate_kwarg("last_cache_position", version="4.53.0")
362
  def forward(
363
  self,
364
  hidden_states: torch.Tensor,
@@ -366,7 +370,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
366
  position_embeddings_local: torch.Tensor,
367
  attention_mask: Optional[torch.Tensor] = None,
368
  position_ids: Optional[torch.LongTensor] = None,
369
- past_key_value: Optional[Cache] = None,
370
  output_attentions: Optional[bool] = False,
371
  use_cache: Optional[bool] = False,
372
  cache_position: Optional[torch.LongTensor] = None,
@@ -387,7 +391,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
387
  position_embeddings=position_embeddings,
388
  attention_mask=attention_mask,
389
  position_ids=position_ids,
390
- past_key_value=past_key_value,
391
  output_attentions=output_attentions,
392
  use_cache=use_cache,
393
  cache_position=cache_position,
@@ -412,7 +416,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
412
 
413
  @auto_docstring
414
  class Gemma3PreTrainedModel(PreTrainedModel):
415
- config_class = Gemma3Config
416
  base_model_prefix = ""
417
  supports_gradient_checkpointing = True
418
  _no_split_modules = [
@@ -422,35 +426,26 @@ class Gemma3PreTrainedModel(PreTrainedModel):
422
  "SiglipMultiheadAttentionPoolingHead",
423
  ]
424
  _skip_keys_device_placement = ["past_key_values"]
425
- _supports_flash_attn_3 = True
426
- _supports_flash_attn_2 = True
427
  _supports_sdpa = True
428
  _supports_flex_attn = True
429
- _supports_cache_class = True
430
- _supports_quantized_cache = True
431
- _supports_static_cache = True
432
  _supports_attention_backend = True
 
 
 
 
433
 
434
  def _init_weights(self, module):
435
- std = self.config.initializer_range
436
-
437
- if isinstance(module, (nn.Linear, nn.Conv2d)):
438
- module.weight.data.normal_(mean=0.0, std=std)
439
- if module.bias is not None:
440
- module.bias.data.zero_()
441
- elif isinstance(module, nn.Embedding):
442
- module.weight.data.normal_(mean=0.0, std=std)
443
- if module.padding_idx is not None:
444
- module.weight.data[module.padding_idx].zero_()
445
- elif isinstance(module, Gemma3RMSNorm):
446
- module.weight.data.fill_(1.0)
447
- elif isinstance(module, Gemma3MultiModalProjector):
448
  module.mm_input_projection_weight.data.zero_()
449
 
450
 
451
  @auto_docstring
452
  class Gemma3TextModel(Gemma3PreTrainedModel):
453
- config_class = Gemma3TextConfig
454
 
455
  def __init__(self, config: Gemma3TextConfig):
456
  super().__init__(config)
@@ -478,13 +473,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
478
  # Initialize weights and apply final processing
479
  self.post_init()
480
 
481
- def get_input_embeddings(self):
482
- return self.embed_tokens
483
-
484
- def set_input_embeddings(self, value):
485
- self.embed_tokens = value
486
-
487
- @can_return_tuple
488
  @auto_docstring
489
  def forward(
490
  self,
@@ -497,7 +486,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
497
  output_attentions: Optional[bool] = None,
498
  output_hidden_states: Optional[bool] = None,
499
  cache_position: Optional[torch.LongTensor] = None,
500
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
501
  ) -> BaseModelOutputWithPast:
502
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
503
  output_hidden_states = (
@@ -518,7 +507,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
518
  inputs_embeds = self.embed_tokens(input_ids)
519
 
520
  if use_cache and past_key_values is None and not self.training:
521
- past_key_values = DynamicCache()
522
 
523
  if cache_position is None:
524
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -569,11 +558,11 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
569
  position_embeddings_local=position_embeddings_local,
570
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
571
  position_ids=position_ids,
572
- past_key_value=past_key_values,
573
  output_attentions=output_attentions,
574
  use_cache=use_cache,
575
  cache_position=cache_position,
576
- **flash_attn_kwargs,
577
  )
578
 
579
  hidden_states = layer_outputs[0]
@@ -599,7 +588,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
599
  _tied_weights_keys = ["lm_head.weight"]
600
  _tp_plan = {"lm_head": "colwise_rep"}
601
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
602
- config_class = Gemma3TextConfig
603
  base_model_prefix = "language_model"
604
 
605
  def __init__(self, config: Gemma3TextConfig):
@@ -611,18 +600,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
611
  # Initialize weights and apply final processing
612
  self.post_init()
613
 
614
- def get_input_embeddings(self):
615
- return self.model.embed_tokens
616
-
617
- def set_input_embeddings(self, value):
618
- self.model.embed_tokens = value
619
-
620
- def get_output_embeddings(self):
621
- return self.lm_head
622
-
623
- def set_output_embeddings(self, new_embeddings):
624
- self.lm_head = new_embeddings
625
-
626
  def set_decoder(self, decoder):
627
  self.model = decoder
628
 
@@ -644,14 +621,9 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
644
  output_hidden_states: Optional[bool] = None,
645
  cache_position: Optional[torch.LongTensor] = None,
646
  logits_to_keep: Union[int, torch.Tensor] = 0,
647
- **loss_kwargs,
648
  ) -> CausalLMOutputWithPast:
649
  r"""
650
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
651
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
652
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
653
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
654
-
655
  Example:
656
 
657
  ```python
@@ -689,7 +661,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
689
  output_attentions=output_attentions,
690
  output_hidden_states=output_hidden_states,
691
  cache_position=cache_position,
692
- **loss_kwargs,
693
  )
694
 
695
  hidden_states = outputs.last_hidden_state
@@ -703,7 +675,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
703
 
704
  loss = None
705
  if labels is not None:
706
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
707
 
708
  return CausalLMOutputWithPast(
709
  loss=loss,
@@ -750,7 +722,11 @@ class Gemma3MultiModalProjector(nn.Module):
750
  return projected_vision_outputs.type_as(vision_outputs)
751
 
752
 
753
- def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
 
 
 
 
754
  """
755
  This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
756
  not start and end indices.
@@ -760,10 +736,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
760
  return None
761
 
762
  def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
763
- # If the difference is less than image size, both are part of the same image block
764
- same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
765
  # If it's 1 for both query and key/value, we are in an image block
766
- is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
 
 
 
 
 
 
 
 
 
 
767
 
768
  # This is bidirectional attention whenever we are dealing with image tokens
769
  return is_image_block & same_image_block
@@ -850,6 +834,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
850
 
851
  def image_features_is_none(inputs_embeds, image_features=None):
852
  return inputs_embeds
 
853
  def image_features_is_not_none(inputs_embeds, image_features=None):
854
  # input_ids: [batch_size, seq_len]
855
  # input_embeds: [batch_size, seq_len, 2560 (hidden_size)]
@@ -867,9 +852,33 @@ class Gemma3Model(Gemma3PreTrainedModel):
867
  image_features_is_not_none,
868
  (inputs_embeds, image_features,)
869
  )
870
-
871
  return inputs_embeds
872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873
  @can_return_tuple
874
  @auto_docstring
875
  def forward(
@@ -945,23 +954,10 @@ class Gemma3Model(Gemma3PreTrainedModel):
945
  # Merge text and images
946
  if pixel_values is not None:
947
  image_features = self.get_image_features(pixel_values)
948
-
949
- if input_ids is None:
950
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
951
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
952
- )
953
- else:
954
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
955
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
956
-
957
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
958
- image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
959
- raise ValueError(
960
- f"Number of images does not match number of special image tokens in the input text. "
961
- f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
962
- "tokens from image embeddings."
963
- )
964
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
 
 
 
965
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
966
 
967
  # It may already have been prepared by e.g. `generate`
@@ -977,8 +973,17 @@ class Gemma3Model(Gemma3PreTrainedModel):
977
  }
978
  if token_type_ids is not None and inputs_embeds.shape[1] != 1:
979
  # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
 
 
 
 
 
 
 
 
 
980
  mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
981
- token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
982
  )
983
 
984
  # Create the masks
@@ -1035,12 +1040,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1035
  def set_input_embeddings(self, value):
1036
  self.model.set_input_embeddings(value)
1037
 
1038
- def get_output_embeddings(self):
1039
- return self.lm_head
1040
-
1041
- def set_output_embeddings(self, new_embeddings):
1042
- self.lm_head = new_embeddings
1043
-
1044
  def set_decoder(self, decoder):
1045
  self.model.set_decoder(decoder)
1046
 
@@ -1219,6 +1218,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1219
  **kwargs,
1220
  )
1221
 
 
 
 
 
1222
  # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
1223
  # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
1224
  if cache_position[0] == 0:
@@ -1249,17 +1252,119 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1249
  # Add the token type ids mask for generate as well
1250
  if token_type_ids is not None and input_embeds.shape[1] != 1:
1251
  # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
 
 
 
 
 
 
 
1252
  mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
1253
- token_type_ids.to(cache_position.device), config.mm_tokens_per_image
1254
  )
1255
 
1256
  return create_masks_for_generate(**mask_kwargs)
1257
 
1258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1259
  __all__ = [
1260
  "Gemma3PreTrainedModel",
1261
  "Gemma3TextModel",
1262
  "Gemma3ForCausalLM",
1263
  "Gemma3ForConditionalGeneration",
1264
  "Gemma3Model",
 
1265
  ]
 
27
  import torch
28
  import torch.nn as nn
29
 
30
+ from ...activations import ACT2FN
31
+ from ...cache_utils import Cache, DynamicCache
32
+ from ...configuration_utils import PretrainedConfig
33
+ from ...generation import GenerationMixin
34
+ from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
35
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
36
+ from ...modeling_layers import GradientCheckpointingLayer
37
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
38
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
+ from ...processing_utils import Unpack
41
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
42
+ from ...utils.deprecation import deprecate_kwarg
43
+ from ...utils.generic import check_model_inputs
44
+ from ..auto import AutoModel
45
  from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
46
 
47
 
 
56
  )
57
  class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
58
  r"""
59
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
60
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
61
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
62
 
 
82
  Language modeling loss (for next-token prediction).
83
  logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
84
  Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
85
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
86
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
87
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
88
 
 
151
 
152
 
153
  class Gemma3RotaryEmbedding(nn.Module):
154
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
155
+
156
  def __init__(self, config: Gemma3TextConfig, device=None):
157
  super().__init__()
158
  # BC: "rope_type" was originally "type"
159
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
160
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
161
  else:
162
  self.rope_type = "default"
 
299
  self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
300
  self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
301
 
302
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
303
  def forward(
304
  self,
305
  hidden_states: torch.Tensor,
306
  position_embeddings: torch.Tensor,
307
  attention_mask: Optional[torch.Tensor],
308
+ past_key_values: Optional[Cache] = None,
309
  cache_position: Optional[torch.LongTensor] = None,
310
  **kwargs: Unpack[FlashAttentionKwargs],
311
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
 
322
  cos, sin = position_embeddings
323
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
324
 
325
+ if past_key_values is not None:
326
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
327
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
328
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
329
 
330
  attention_interface: Callable = eager_attention_forward
331
  if self.config._attn_implementation != "eager":
 
362
  self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
363
  self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
364
 
365
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
366
  def forward(
367
  self,
368
  hidden_states: torch.Tensor,
 
370
  position_embeddings_local: torch.Tensor,
371
  attention_mask: Optional[torch.Tensor] = None,
372
  position_ids: Optional[torch.LongTensor] = None,
373
+ past_key_values: Optional[Cache] = None,
374
  output_attentions: Optional[bool] = False,
375
  use_cache: Optional[bool] = False,
376
  cache_position: Optional[torch.LongTensor] = None,
 
391
  position_embeddings=position_embeddings,
392
  attention_mask=attention_mask,
393
  position_ids=position_ids,
394
+ past_key_values=past_key_values,
395
  output_attentions=output_attentions,
396
  use_cache=use_cache,
397
  cache_position=cache_position,
 
416
 
417
  @auto_docstring
418
  class Gemma3PreTrainedModel(PreTrainedModel):
419
+ config: Gemma3Config
420
  base_model_prefix = ""
421
  supports_gradient_checkpointing = True
422
  _no_split_modules = [
 
426
  "SiglipMultiheadAttentionPoolingHead",
427
  ]
428
  _skip_keys_device_placement = ["past_key_values"]
429
+ _supports_flash_attn = True
 
430
  _supports_sdpa = True
431
  _supports_flex_attn = True
432
+
433
+ _can_compile_fullgraph = True
 
434
  _supports_attention_backend = True
435
+ _can_record_outputs = {
436
+ "hidden_states": Gemma3DecoderLayer,
437
+ "attentions": Gemma3Attention,
438
+ }
439
 
440
  def _init_weights(self, module):
441
+ super()._init_weights(module)
442
+ if isinstance(module, Gemma3MultiModalProjector):
 
 
 
 
 
 
 
 
 
 
 
443
  module.mm_input_projection_weight.data.zero_()
444
 
445
 
446
  @auto_docstring
447
  class Gemma3TextModel(Gemma3PreTrainedModel):
448
+ config: Gemma3TextConfig
449
 
450
  def __init__(self, config: Gemma3TextConfig):
451
  super().__init__(config)
 
473
  # Initialize weights and apply final processing
474
  self.post_init()
475
 
476
+ @check_model_inputs
 
 
 
 
 
 
477
  @auto_docstring
478
  def forward(
479
  self,
 
486
  output_attentions: Optional[bool] = None,
487
  output_hidden_states: Optional[bool] = None,
488
  cache_position: Optional[torch.LongTensor] = None,
489
+ **kwargs: Unpack[TransformersKwargs],
490
  ) -> BaseModelOutputWithPast:
491
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
492
  output_hidden_states = (
 
507
  inputs_embeds = self.embed_tokens(input_ids)
508
 
509
  if use_cache and past_key_values is None and not self.training:
510
+ past_key_values = DynamicCache(config=self.config)
511
 
512
  if cache_position is None:
513
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
558
  position_embeddings_local=position_embeddings_local,
559
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
560
  position_ids=position_ids,
561
+ past_key_values=past_key_values,
562
  output_attentions=output_attentions,
563
  use_cache=use_cache,
564
  cache_position=cache_position,
565
+ **kwargs,
566
  )
567
 
568
  hidden_states = layer_outputs[0]
 
588
  _tied_weights_keys = ["lm_head.weight"]
589
  _tp_plan = {"lm_head": "colwise_rep"}
590
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
591
+ config: Gemma3TextConfig
592
  base_model_prefix = "language_model"
593
 
594
  def __init__(self, config: Gemma3TextConfig):
 
600
  # Initialize weights and apply final processing
601
  self.post_init()
602
 
 
 
 
 
 
 
 
 
 
 
 
 
603
  def set_decoder(self, decoder):
604
  self.model = decoder
605
 
 
621
  output_hidden_states: Optional[bool] = None,
622
  cache_position: Optional[torch.LongTensor] = None,
623
  logits_to_keep: Union[int, torch.Tensor] = 0,
624
+ **kwargs,
625
  ) -> CausalLMOutputWithPast:
626
  r"""
 
 
 
 
 
627
  Example:
628
 
629
  ```python
 
661
  output_attentions=output_attentions,
662
  output_hidden_states=output_hidden_states,
663
  cache_position=cache_position,
664
+ **kwargs,
665
  )
666
 
667
  hidden_states = outputs.last_hidden_state
 
675
 
676
  loss = None
677
  if labels is not None:
678
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
679
 
680
  return CausalLMOutputWithPast(
681
  loss=loss,
 
722
  return projected_vision_outputs.type_as(vision_outputs)
723
 
724
 
725
+ def token_type_ids_mask_function(
726
+ token_type_ids: Optional[torch.Tensor],
727
+ image_group_ids: Optional[torch.Tensor],
728
+ tokens_per_image: int,
729
+ ) -> Optional[Callable]:
730
  """
731
  This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
732
  not start and end indices.
 
736
  return None
737
 
738
  def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
 
 
739
  # If it's 1 for both query and key/value, we are in an image block
740
+ # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
741
+ # Since vmap doesn't support `if` statements, we work around it with `torch.where`
742
+ safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
743
+ token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
744
+ token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
745
+
746
+ image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
747
+ image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
748
+
749
+ is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
750
+ same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
751
 
752
  # This is bidirectional attention whenever we are dealing with image tokens
753
  return is_image_block & same_image_block
 
834
 
835
  def image_features_is_none(inputs_embeds, image_features=None):
836
  return inputs_embeds
837
+
838
  def image_features_is_not_none(inputs_embeds, image_features=None):
839
  # input_ids: [batch_size, seq_len]
840
  # input_embeds: [batch_size, seq_len, 2560 (hidden_size)]
 
852
  image_features_is_not_none,
853
  (inputs_embeds, image_features,)
854
  )
855
+
856
  return inputs_embeds
857
 
858
+ def get_placeholder_mask(
859
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
860
+ ):
861
+ """
862
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
863
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
864
+ """
865
+ if input_ids is None:
866
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
867
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
868
+ )
869
+ special_image_mask = special_image_mask.all(-1)
870
+ else:
871
+ special_image_mask = input_ids == self.config.image_token_id
872
+
873
+ n_image_tokens = special_image_mask.sum()
874
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
875
+ n_image_features = image_features.shape[0] * image_features.shape[1]
876
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
877
+ raise ValueError(
878
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
879
+ )
880
+ return special_image_mask
881
+
882
  @can_return_tuple
883
  @auto_docstring
884
  def forward(
 
954
  # Merge text and images
955
  if pixel_values is not None:
956
  image_features = self.get_image_features(pixel_values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
958
+ special_image_mask = self.get_placeholder_mask(
959
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
960
+ )
961
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
962
 
963
  # It may already have been prepared by e.g. `generate`
 
973
  }
974
  if token_type_ids is not None and inputs_embeds.shape[1] != 1:
975
  # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
976
+
977
+ # First find where a new image block starts: 1 if image and previous not image
978
+ # An image cannot attend to future images, but it can attend to all previous images and to itself bidirectionally
979
+ is_image = (token_type_ids == 1).to(cache_position.device)
980
+ new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
981
+ image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
982
+ image_group_ids = torch.where(
983
+ is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device)
984
+ )
985
  mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
986
+ token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
987
  )
988
 
989
  # Create the masks
 
1040
  def set_input_embeddings(self, value):
1041
  self.model.set_input_embeddings(value)
1042
 
 
 
 
 
 
 
1043
  def set_decoder(self, decoder):
1044
  self.model.set_decoder(decoder)
1045
 
 
1218
  **kwargs,
1219
  )
1220
 
1221
+ # position_ids in Gemma3 are 1-indexed
1222
+ if model_inputs.get("position_ids") is not None:
1223
+ model_inputs["position_ids"] += 1
1224
+
1225
  # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
1226
  # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
1227
  if cache_position[0] == 0:
 
1252
  # Add the token type ids mask for generate as well
1253
  if token_type_ids is not None and input_embeds.shape[1] != 1:
1254
  # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
1255
+
1256
+ # First find where a new image block starts: 1 if image and previous not image
1257
+ # An image cannot attend to future images, but it can attend to all previous images and to itself bidirectionally
1258
+ is_image = (token_type_ids == 1).to(cache_position.device)
1259
+ new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
1260
+ image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
1261
+ image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
1262
  mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
1263
+ token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
1264
  )
1265
 
1266
  return create_masks_for_generate(**mask_kwargs)
1267
 
1268
 
1269
+ class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
1270
+ _checkpoint_conversion_mapping = {
1271
+ "^language_model.model": "model.language_model",
1272
+ "^vision_tower": "model.vision_tower",
1273
+ "^multi_modal_projector": "model.multi_modal_projector",
1274
+ }
1275
+
1276
+ def __init__(self, config):
1277
+ super().__init__(config)
1278
+ self.num_labels = config.num_labels
1279
+ self.model = Gemma3Model(config)
1280
+ self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
1281
+
1282
+ # Initialize weights and apply final processing
1283
+ self.post_init()
1284
+
1285
+ def get_input_embeddings(self):
1286
+ return self.model.get_input_embeddings()
1287
+
1288
+ def set_input_embeddings(self, value):
1289
+ self.model.set_input_embeddings(value)
1290
+
1291
+ @can_return_tuple
1292
+ @auto_docstring
1293
+ def forward(
1294
+ self,
1295
+ input_ids: torch.LongTensor = None,
1296
+ pixel_values: Optional[torch.FloatTensor] = None,
1297
+ attention_mask: Optional[torch.Tensor] = None,
1298
+ position_ids: Optional[torch.LongTensor] = None,
1299
+ past_key_values: Optional[Cache] = None,
1300
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1301
+ token_type_ids: Optional[torch.LongTensor] = None,
1302
+ labels: Optional[torch.LongTensor] = None,
1303
+ use_cache: Optional[bool] = None,
1304
+ **kwargs: Unpack[TransformersKwargs],
1305
+ ) -> SequenceClassifierOutputWithPast:
1306
+ r"""
1307
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1308
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1309
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1310
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1311
+ """
1312
+
1313
+ transformer_outputs = self.model(
1314
+ input_ids,
1315
+ attention_mask=attention_mask,
1316
+ pixel_values=pixel_values,
1317
+ position_ids=position_ids,
1318
+ past_key_values=past_key_values,
1319
+ inputs_embeds=inputs_embeds,
1320
+ token_type_ids=token_type_ids,
1321
+ use_cache=use_cache,
1322
+ **kwargs,
1323
+ )
1324
+ hidden_states = transformer_outputs.last_hidden_state
1325
+ logits = self.score(hidden_states)
1326
+
1327
+ if input_ids is not None:
1328
+ batch_size = input_ids.shape[0]
1329
+ else:
1330
+ batch_size = inputs_embeds.shape[0]
1331
+
1332
+ if self.config.text_config.pad_token_id is None and batch_size != 1:
1333
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1334
+ if self.config.text_config.pad_token_id is None:
1335
+ last_non_pad_token = -1
1336
+ elif input_ids is not None:
1337
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1338
+ non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
1339
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1340
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1341
+ else:
1342
+ last_non_pad_token = -1
1343
+ logger.warning_once(
1344
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1345
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1346
+ )
1347
+
1348
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1349
+
1350
+ loss = None
1351
+ if labels is not None:
1352
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1353
+
1354
+ return SequenceClassifierOutputWithPast(
1355
+ loss=loss,
1356
+ logits=pooled_logits,
1357
+ past_key_values=transformer_outputs.past_key_values,
1358
+ hidden_states=transformer_outputs.hidden_states,
1359
+ attentions=transformer_outputs.attentions,
1360
+ )
1361
+
1362
+
1363
  __all__ = [
1364
  "Gemma3PreTrainedModel",
1365
  "Gemma3TextModel",
1366
  "Gemma3ForCausalLM",
1367
  "Gemma3ForConditionalGeneration",
1368
  "Gemma3Model",
1369
+ "Gemma3ForSequenceClassification",
1370
  ]