|
|
""" |
|
|
SafeGem: Vision-Language Model with Visual Guard Module |
|
|
|
|
|
This implementation extends Gemma3ForConditionalGeneration with image safety classification |
|
|
capabilities using a pooling-based approach for safety feature extraction. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional, Tuple, List, Union |
|
|
from dataclasses import dataclass |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from transformers import Gemma3ForConditionalGeneration |
|
|
from transformers.utils import logging |
|
|
|
|
|
from .configuration_safegem import SafeGemConfig |
|
|
|
|
|
# Module-level logger obtained from transformers' logging utilities.
# NOTE(review): nothing in the code shown here calls it — diagnostics in this
# module go through rank0_print()/print() instead; confirm before removing.
logger = logging.get_logger(name=__name__)
|
|
|
|
|
# Rank of the current process, used only to gate console output below.
# ``None`` is treated the same as rank 0 (messages are printed).
# NOTE(review): nothing in this module assigns it; presumably a launcher or
# training script sets ``<this module>.local_rank`` after import — confirm.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only when running as the main process.

    The check reads the module-level ``local_rank``: output is produced when it
    is ``None`` or equals rank zero given as either ``0`` or ``"0"``; any other
    value suppresses the call entirely.
    """
    on_main_process = local_rank is None or local_rank in (0, "0")
    if on_main_process:
        print(*args)
|
|
|
|
|
|
|
|
@dataclass
class SafeGemOutput(CausalLMOutputWithPast):
    """
    Causal-LM output that additionally carries image and safety results.

    On top of the usual causal-LM fields it holds the ``image_hidden_states``
    returned by the base model and, when the Visual Guard Module ran, its raw
    scores (``img_safety_logits``) together with their softmax over the last
    dimension (``img_safety_probs``). Every field defaults to ``None``.
    """

    # Standard causal-LM fields. The declaration order is kept as-is, since
    # presumably ModelOutput's positional / to_tuple() view follows it.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Image features as produced by the base multimodal model.
    image_hidden_states: Optional[torch.FloatTensor] = None

    # Unnormalised safety-head scores; None when classification was skipped.
    img_safety_logits: Optional[torch.FloatTensor] = None
    # Softmax of img_safety_logits over its last dimension.
    img_safety_probs: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
|
class SafetyMLP(nn.Module):
    """
    Multi-layer perceptron for safety classification (Visual Guard Module).

    Layout (inside ``self.mlp``, an ``nn.Sequential``)::

        Linear(input_size, hidden_size) -> GELU -> Dropout
        [Linear(hidden_size, hidden_size) -> GELU -> Dropout] * (num_hidden_layers - 1)
        Linear(hidden_size, output_size)

    The final layer has no activation, so ``forward`` returns raw scores.
    Every ``nn.Linear`` is Xavier-uniform initialised with zero bias.

    Args:
        input_size: Size of the input's last dimension.
        hidden_size: Width of every hidden stage.
        output_size: Number of output scores (e.g. safety categories).
        num_hidden_layers: Number of Linear+GELU+Dropout stages. The first
            stage always exists, so values below 1 behave exactly like 1.
        dropout: Dropout probability used after every hidden stage. Defaults
            to 0.1, the value that was previously hard-coded.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_hidden_layers: int = 1,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Module creation order is kept identical to the original layout so
        # Sequential indices (and therefore state_dict keys) do not change.
        layers = [nn.Linear(input_size, hidden_size), nn.GELU(), nn.Dropout(dropout)]
        for _ in range(max(num_hidden_layers, 1) - 1):
            layers.extend((nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Dropout(dropout)))
        layers.append(nn.Linear(hidden_size, output_size))

        self.mlp = nn.Sequential(*layers)

        # Override PyTorch's default Linear init for every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero bias for ``nn.Linear`` modules; others untouched."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` and return raw scores."""
        return self.mlp(x)
|
|
|
|
|
|
|
|
class SafeGemForConditionalGeneration(Gemma3ForConditionalGeneration):
    """
    SafeGem model with Visual Guard Module for image safety classification.

    This model extends Gemma3ForConditionalGeneration with:
    1. Visual Guard Module (VGM) - a safety classification head
    2. Pooling-based safety feature extraction from image tokens
    3. Simultaneous text generation and safety classification

    Key design principles:
    - Minimal modification to base Gemma3 forward pass
    - Extract safety features from visual tokens using mean pooling
    - Non-invasive architecture that maintains full base model capabilities

    The safety head (``img_safety_head``) is only built when
    ``config.num_safety_categories`` is a positive number; otherwise it is
    ``None`` and no safety outputs are ever produced.
    """

    config_class = SafeGemConfig

    def __init__(self, config: SafeGemConfig):
        """Build the base Gemma3 model and, if configured, the safety head.

        Optional config attributes read here (fallback in parentheses):
        ``num_safety_categories`` (None -> no head), ``safety_head_hidden_scale``
        (1.0; head width = int(text hidden_size * scale)) and
        ``safety_num_hidden_layers`` (1).
        """
        super().__init__(config)

        num_safety_categories = getattr(config, 'num_safety_categories', None)
        if num_safety_categories and num_safety_categories > 0:
            hidden_size = config.text_config.hidden_size
            safety_head_hidden_scale = getattr(config, 'safety_head_hidden_scale', 1.0)
            safety_hidden_size = int(hidden_size * safety_head_hidden_scale)
            safety_num_hidden_layers = getattr(config, 'safety_num_hidden_layers', 1)

            rank0_print(f"🔧 [INIT] Initializing Visual Guard Module: {hidden_size} -> {safety_hidden_size} -> {num_safety_categories}")

            self.img_safety_head = SafetyMLP(
                input_size=hidden_size,
                hidden_size=safety_hidden_size,
                output_size=num_safety_categories,
                num_hidden_layers=safety_num_hidden_layers,
            )
        else:
            rank0_print("🔧 [INIT] No safety configuration found, Visual Guard Module not initialized")
            self.img_safety_head = None

    def _extract_image_features_pooling(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_hidden_states: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Extract image features by mean-pooling ``image_hidden_states``.

        Args:
            hidden_states: Last-layer hidden states. Currently unused; kept
                for interface compatibility.
            attention_mask: Currently unused; kept for interface compatibility.
            input_ids: Only checked for ``None`` to pick a diagnostic message
                when ``image_hidden_states`` is missing.
            image_hidden_states: 3-D ``(N, num_patches, hidden)`` or 4-D
                ``(batch, num_images, num_patches, hidden)`` tensor.

        Returns:
            - For 3-D input, the mean over the patch dimension, shape ``(N, hidden)``.
              NOTE(review): if the base model flattens images across the batch,
              N is the number of images rather than the batch size. Confirm this
              before pairing the result with per-sample labels.
            - For 4-D input, the mean over patches and then over images, shape
              ``(batch, hidden)``.
            - ``None`` for missing input or any other rank.
        """
        if image_hidden_states is None:
            if input_ids is None:
                rank0_print("🔧 [POOL] No input_ids available for image token detection")
                return None
            rank0_print("🔧 [POOL] No image_hidden_states available, cannot extract image features")
            return None

        if image_hidden_states.dim() == 3:
            return image_hidden_states.mean(dim=1)

        if image_hidden_states.dim() == 4:
            # Average patches within each image, then average the images.
            pooled_features = image_hidden_states.mean(dim=2).mean(dim=1)
            rank0_print(f"🔧 [POOL] 4D pooled features shape: {pooled_features.shape}")
            return pooled_features

        rank0_print(f"🔧 [POOL] Unexpected image_hidden_states shape: {image_hidden_states.shape}")
        return None

    @staticmethod
    def _sanitize_logits(outputs, log_nonfinite: bool) -> None:
        """Clean ``outputs.logits`` in place.

        Non-finite entries (NaN or +/-Inf) are replaced with ``N(0, 1) * 0.001``
        noise. A 4-D logits tensor whose dim 1 has size 1 is squeezed to 3-D.
        Does nothing when ``outputs.logits`` is ``None``.

        Args:
            outputs: Base-model output object exposing a writable ``logits``.
            log_nonfinite: Print the NaN/Inf counts when replacement happens.
        """
        if outputs.logits is None:
            return

        # A single fused mask replaces the separate full-tensor scans for NaN
        # count, Inf count and the torch.where condition.
        nonfinite = ~torch.isfinite(outputs.logits)
        if nonfinite.any():
            if log_nonfinite:
                # Counts are only needed for the message, so compute them lazily.
                nan_count = torch.isnan(outputs.logits).sum()
                inf_count = torch.isinf(outputs.logits).sum()
                print(f"[CRITICAL] Found NaN or Inf in logits! NaN count: {nan_count}, Inf count: {inf_count}")
            replacement_values = torch.randn_like(outputs.logits) * 0.001
            outputs.logits = torch.where(nonfinite, replacement_values, outputs.logits)

        if outputs.logits.dim() == 4 and outputs.logits.shape[1] == 1:
            outputs.logits = outputs.logits.squeeze(1)

    def _classify_image_safety(self, outputs, attention_mask, input_ids):
        """Run the safety head on pooled image features.

        Returns ``(logits, softmax probabilities over the last dim)``, or
        ``(None, None)`` when no image features could be extracted.
        """
        image_features = self._extract_image_features_pooling(
            hidden_states=outputs.hidden_states[-1] if outputs.hidden_states else None,
            attention_mask=attention_mask,
            input_ids=input_ids,
            image_hidden_states=outputs.image_hidden_states,
        )
        if image_features is None:
            rank0_print("🔧 [SafeGem] ❌ Image features extraction failed")
            return None, None

        img_safety_logits = self.img_safety_head(image_features)
        return img_safety_logits, torch.softmax(img_safety_logits, dim=-1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        do_safety: bool = True,
        safety_labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Union[Tuple, SafeGemOutput]:
        """
        Forward pass with optional safety classification.

        Args:
            do_safety: Whether to perform safety classification (default: True).
                Classification only runs when all of the following hold:
                the safety head exists, ``pixel_values`` is given, and
                ``past_key_values`` is ``None``.
            safety_labels: Accepted for API compatibility but not used; no
                safety loss is computed by this method.
            All other args: Same as Gemma3ForConditionalGeneration. The base
                model is always called with ``return_dict=True``.

        Returns:
            - A tuple ``(loss, logits, past_key_values, hidden_states,
              attentions)`` when ``return_dict is False``.
            - The base model's output object (with sanitized logits) when
              ``past_key_values`` is given.
            - Otherwise a ``SafeGemOutput``, whose safety fields are ``None``
              when classification did not run.

        Note:
            Non-finite logits are replaced *after* the base model has computed
            ``loss`` from ``labels``, so ``loss`` still reflects the raw logits.
        """
        # A cache means this call continues an earlier pass; image features are
        # only expected on the first (uncached) one.
        # NOTE(review): some transformers versions hand generate()'s first step
        # a freshly created, empty cache object. This check would treat that
        # step as a continuation and skip safety. Confirm for the version in use.
        is_generation = past_key_values is not None
        wants_safety = do_safety and self.img_safety_head is not None

        if wants_safety and not is_generation:
            # Request every output the safety path may read, and force the
            # SafeGemOutput return type below.
            output_hidden_states = True
            return_dict = True

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            pixel_values=pixel_values,
            return_dict=True,
            **kwargs
        )

        # Only the uncached pass reports non-finite counts, matching the
        # original noise level during token-by-token decoding.
        self._sanitize_logits(outputs, log_nonfinite=not is_generation)

        img_safety_logits = None
        img_safety_probs = None
        should_do_safety = (
            wants_safety
            and (outputs.hidden_states is not None or outputs.image_hidden_states is not None)
            and pixel_values is not None
            and not is_generation
        )
        if should_do_safety:
            img_safety_logits, img_safety_probs = self._classify_image_safety(
                outputs, attention_mask, input_ids
            )

        if return_dict is False:
            output = (outputs.loss, outputs.logits, outputs.past_key_values,
                      outputs.hidden_states, outputs.attentions)
            # NOTE(review): whenever safety runs, return_dict was forced to True
            # above, so this append is not reached in practice.
            if img_safety_logits is not None:
                output += (img_safety_logits, img_safety_probs)
            return output

        if is_generation:
            return outputs

        return SafeGemOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            img_safety_logits=img_safety_logits,
            img_safety_probs=img_safety_probs,
        )
|
|
|