|
|
""" |
|
|
SafeGem Configuration |
|
|
|
|
|
Configuration class for SafeGem models with safety classification capabilities. |
|
|
""" |
|
|
|
|
|
from typing import Optional, List |
|
|
from transformers import Gemma3Config |
|
|
|
|
|
|
|
|
class SafeGemConfig(Gemma3Config):
    """
    Configuration for SafeGem model.

    Extends ``Gemma3Config`` with safety-specific parameters. Every keyword
    argument that is not listed below is forwarded unchanged to
    ``Gemma3Config.__init__``.

    Args:
        safety_categories: Names of the safety classes. When omitted (or
            empty), the built-in list of 20 categories is used.
        safety_head_hidden_scale: Stored on the config as-is. Presumably a
            multiplier for the safety head's hidden width -- confirm against
            the model code.
        safety_loss_lambda: Stored on the config as-is. Presumably the weight
            of the safety loss term -- confirm against the training code.
        safety_num_hidden_layers: Stored on the config as-is. Presumably the
            number of hidden layers in the safety head -- confirm against the
            model code.
        num_safety_categories: Number of safety classes. When omitted
            (``None`` or ``0``), it is derived from ``len(safety_categories)``
            so that the count always matches the category list by default.
    """

    model_type = "safegem"

    def __init__(
        self,
        safety_categories: Optional[List[str]] = None,
        safety_head_hidden_scale: float = 1.0,
        safety_loss_lambda: float = 1.0,
        safety_num_hidden_layers: int = 1,
        num_safety_categories: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fall back to the built-in taxonomy when no (or an empty) list is
        # supplied. A fresh list is built per instance, so configs never
        # share a mutable default.
        if not safety_categories:
            safety_categories = [
                "safe",
                "gender",
                "race",
                "religion",
                "harassment",
                "disability_discrimination",
                "drug_crime",
                "property_crime",
                "facial_data",
                "identity_data",
                "physical_self_injury",
                "suicide",
                "animal_abuse",
                "obscene_gestures",
                "physical_altercation",
                "terrorism",
                "weapon_related_violence",
                "sexual_content",
                "financial_advice",
                "medical_advice",
            ]
        self.safety_categories = safety_categories

        self.safety_head_hidden_scale = safety_head_hidden_scale
        self.safety_loss_lambda = safety_loss_lambda
        self.safety_num_hidden_layers = safety_num_hidden_layers

        # Derive the count from the category list unless the caller pinned it
        # explicitly. The previous hard-coded default of 20 silently
        # disagreed with any custom list of a different length.
        if not num_safety_categories:
            num_safety_categories = len(self.safety_categories)
        self.num_safety_categories = num_safety_categories
|
|
|