# SafeGem-27B / modeling_safegem.py
"""
SafeGem: Vision-Language Model with Visual Guard Module
This implementation extends Gemma3ForConditionalGeneration with image safety classification
capabilities using a pooling-based approach for safety feature extraction.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple, List, Union
from dataclasses import dataclass
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import Gemma3ForConditionalGeneration
from transformers.utils import logging
from .configuration_safegem import SafeGemConfig
logger = logging.get_logger(__name__)
local_rank = None

def rank0_print(*args):
    """Print only on rank 0, or when the rank is unknown, to avoid duplicated logs in distributed runs."""
    if local_rank in (0, '0', None):
        print(*args)
@dataclass
class SafeGemOutput(CausalLMOutputWithPast):
"""
Output class for SafeGem with safety classification results.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
img_safety_logits: Optional[torch.FloatTensor] = None
img_safety_probs: Optional[torch.FloatTensor] = None
class SafetyMLP(nn.Module):
"""
Multi-layer perceptron for safety classification (Visual Guard Module).
"""
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: int,
num_hidden_layers: int = 1
):
super().__init__()
layers = []
# First layer
layers.append(nn.Linear(input_size, hidden_size))
layers.append(nn.GELU())
layers.append(nn.Dropout(0.1))
# Additional hidden layers
for _ in range(num_hidden_layers - 1):
layers.append(nn.Linear(hidden_size, hidden_size))
layers.append(nn.GELU())
layers.append(nn.Dropout(0.1))
# Output layer
layers.append(nn.Linear(hidden_size, output_size))
self.mlp = nn.Sequential(*layers)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.constant_(module.bias, 0)
def forward(self, x):
return self.mlp(x)
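
# Layer stack for the default num_hidden_layers=1:
#   Linear(input, hidden) -> GELU -> Dropout(0.1) -> Linear(hidden, output)
# so a batch of pooled image features [B, input_size] maps to per-category
# safety logits [B, output_size]. A runnable smoke test with illustrative
# sizes appears in the __main__ block at the bottom of this file.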
class SafeGemForConditionalGeneration(Gemma3ForConditionalGeneration):
"""
SafeGem model with Visual Guard Module for image safety classification.
This model extends Gemma3ForConditionalGeneration with:
1. Visual Guard Module (VGM) - a safety classification head
2. Pooling-based safety feature extraction from image tokens
3. Simultaneous text generation and safety classification
Key design principles:
- Minimal modification to base Gemma3 forward pass
- Extract safety features from visual tokens using mean pooling
- Non-invasive architecture that maintains full base model capabilities
"""
config_class = SafeGemConfig
def __init__(self, config: SafeGemConfig):
super().__init__(config)
# Add safety head (Visual Guard Module) if safety configuration is present
num_safety_categories = getattr(config, 'num_safety_categories', None)
        if num_safety_categories is not None and num_safety_categories > 0:
hidden_size = config.text_config.hidden_size
safety_head_hidden_scale = getattr(config, 'safety_head_hidden_scale', 1.0)
safety_hidden_size = int(hidden_size * safety_head_hidden_scale)
safety_num_hidden_layers = getattr(config, 'safety_num_hidden_layers', 1)
rank0_print(f"🔧 [INIT] Initializing Visual Guard Module: {hidden_size} -> {safety_hidden_size} -> {num_safety_categories}")
self.img_safety_head = SafetyMLP(
input_size=hidden_size,
hidden_size=safety_hidden_size,
output_size=num_safety_categories,
num_hidden_layers=safety_num_hidden_layers
)
else:
rank0_print(f"🔧 [INIT] No safety configuration found, Visual Guard Module not initialized")
self.img_safety_head = None
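
    # Configuration fields consumed above (a minimal sketch with illustrative
    # values; the shipped SafeGem-27B config may differ):
    #   num_safety_categories: 10       # > 0 enables the Visual Guard Module
    #   safety_head_hidden_scale: 1.0   # MLP hidden = text hidden_size * scale
    #   safety_num_hidden_layers: 1     # extra hidden blocks in the SafetyMLP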
def _extract_image_features_pooling(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
image_hidden_states: Optional[torch.Tensor] = None
) -> Optional[torch.Tensor]:
"""
Extract image features using pooling over visual tokens.
Args:
hidden_states: [batch_size, seq_len, hidden_size]
attention_mask: [batch_size, seq_len]
input_ids: [batch_size, seq_len]
image_hidden_states: [batch_size, num_images, num_patches, hidden_size]
Returns:
image_features: [batch_size, hidden_size] or None
"""
        # First try image_hidden_states from the vision tower, if available
        if image_hidden_states is not None:
            if image_hidden_states.dim() == 3:
                # [batch_size, num_patches, hidden_size]
                # Mean over patches -> [batch_size, hidden_size]
                return image_hidden_states.mean(dim=1)
            elif image_hidden_states.dim() == 4:
                # [batch_size, num_images, num_patches, hidden_size]
                # Mean over patches, then over images -> [batch_size, hidden_size]
                pooled_features = image_hidden_states.mean(dim=2).mean(dim=1)
                rank0_print(f"🔧 [POOL] 4D pooled features shape: {pooled_features.shape}")
                return pooled_features
            else:
                rank0_print(f"🔧 [POOL] Unexpected image_hidden_states shape: {image_hidden_states.shape}")
                return None
        # Fallback: image_hidden_states was not provided, so no visual features
        # can be pooled. A token-mask fallback over hidden_states would need
        # input_ids to locate the image tokens; that path is not implemented here.
        if input_ids is None:
            rank0_print("🔧 [POOL] No input_ids available for image token detection")
            return None
        rank0_print("🔧 [POOL] No image_hidden_states available, cannot extract image features")
        return None
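
    # Pooling sketch with illustrative sizes: image_hidden_states of shape
    # [2, 3, 256, H] (batch=2, 3 images, 256 patches each) becomes [2, 3, H]
    # after mean(dim=2) and the final [2, H] feature after mean(dim=1); that
    # vector is what the Visual Guard Module classifies.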
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
do_safety: bool = True, # Default to True for training, can be overridden for generation
safety_labels: Optional[torch.LongTensor] = None,
**kwargs
) -> Union[Tuple, SafeGemOutput]:
"""
Forward pass with optional safety classification.
Args:
do_safety: Whether to perform safety classification (default: True)
All other args: Same as Gemma3ForConditionalGeneration
Returns:
SafeGemOutput with optional safety classification results
"""
# Force output_hidden_states if we need safety classification
# BUT only during initial forward pass, not during generation
if do_safety and self.img_safety_head is not None and past_key_values is None:
output_hidden_states = True
return_dict = True
# Standard Gemma3 forward pass - NO MODIFICATIONS
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
pixel_values=pixel_values,
return_dict=True,
**kwargs
)
        # Guard against NaN/Inf in logits: replace bad entries with small random
        # noise so loss computation and sampling do not crash
if outputs.logits is not None:
nan_count = torch.isnan(outputs.logits).sum()
inf_count = torch.isinf(outputs.logits).sum()
if nan_count > 0 or inf_count > 0:
if past_key_values is None:
print(f"[CRITICAL] Found NaN or Inf in logits! NaN count: {nan_count}, Inf count: {inf_count}")
replacement_values = torch.randn_like(outputs.logits) * 0.001
outputs.logits = torch.where(
torch.isnan(outputs.logits) | torch.isinf(outputs.logits),
replacement_values,
outputs.logits
)
        # Squeeze a stray extra dimension if present: [batch, 1, seq, vocab] -> [batch, seq, vocab]
if len(outputs.logits.shape) == 4 and outputs.logits.shape[1] == 1:
outputs.logits = outputs.logits.squeeze(1)
# Initialize safety outputs
img_safety_logits = None
img_safety_probs = None
# Check if we should perform safety classification
is_generation = past_key_values is not None
has_images = pixel_values is not None
should_do_safety = (
do_safety and
self.img_safety_head is not None and
(outputs.hidden_states is not None or outputs.image_hidden_states is not None) and
has_images and
not is_generation
)
if should_do_safety:
# Extract image features
image_features = self._extract_image_features_pooling(
hidden_states=outputs.hidden_states[-1] if outputs.hidden_states else None,
attention_mask=attention_mask,
input_ids=input_ids,
image_hidden_states=outputs.image_hidden_states
)
if image_features is not None:
# Run through Visual Guard Module
img_safety_logits = self.img_safety_head(image_features)
img_safety_probs = torch.softmax(img_safety_logits, dim=-1)
else:
rank0_print("🔧 [SafeGem] ❌ Image features extraction failed")
# Return results
if return_dict is False:
output = (outputs.loss, outputs.logits, outputs.past_key_values,
outputs.hidden_states, outputs.attentions)
if img_safety_logits is not None:
output += (img_safety_logits, img_safety_probs)
return output
        # During generation (cache present), return the standard output
        if is_generation:
            return outputs
        # During training/inference, return the custom output with safety info
        return SafeGemOutput(
loss=outputs.loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
img_safety_logits=img_safety_logits,
img_safety_probs=img_safety_probs
)
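
# Usage sketch. Loading the full checkpoint is environment-dependent; the repo
# id, prompt, and dtype below are illustrative assumptions, not pinned values:
#
#   from transformers import AutoProcessor
#   processor = AutoProcessor.from_pretrained("ywlee88/SafeGem-27B")
#   model = SafeGemForConditionalGeneration.from_pretrained(
#       "ywlee88/SafeGem-27B", torch_dtype=torch.bfloat16
#   )
#   inputs = processor(text=prompt, images=image, return_tensors="pt")
#   out = model(**inputs, do_safety=True)
#   print(out.img_safety_probs)  # [batch_size, num_safety_categories]

if __name__ == "__main__":
    # Standalone smoke test of the SafetyMLP head with hypothetical sizes
    # (the real input size is config.text_config.hidden_size).
    head = SafetyMLP(input_size=64, hidden_size=64, output_size=10)
    head.eval()  # disable dropout for a deterministic forward pass
    dummy_features = torch.randn(2, 64)       # stand-in pooled image features
    safety_logits = head(dummy_features)      # -> shape [2, 10]
    safety_probs = torch.softmax(safety_logits, dim=-1)
    print(safety_logits.shape, safety_probs.sum(dim=-1))  # rows sum to ~1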