gabrielchua committed on
Commit
25776aa
·
verified ·
1 Parent(s): 1e12913

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +103 -1
utils.py CHANGED
@@ -4,11 +4,14 @@ utils.py
4
 
5
  # Standard imports
6
  import os
7
- from typing import List
8
 
9
  # Third party imports
10
  import numpy as np
 
11
  from openai import OpenAI
 
 
12
 
13
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
14
 
@@ -42,3 +45,102 @@ def get_embeddings(
42
  # Extract embeddings from response
43
  embeddings = np.array([data.embedding for data in response.data])
44
  return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # Standard imports
6
  import os
7
+ from typing import List, Tuple
8
 
9
  # Third party imports
10
  import numpy as np
11
+ from google import genai
12
  from openai import OpenAI
13
+ from sentence_transformers import SentenceTransformer
14
+ from transformers import AutoModel
15
 
16
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
17
 
 
45
  # Extract embeddings from response
46
  embeddings = np.array([data.embedding for data in response.data])
47
  return embeddings
48
+
49
+
50
# Registry of selectable LionGuard variants, keyed by the identifier that the
# rest of this module passes around as ``model_key``. Each entry records:
#   label              - human-readable name used in status messages
#   repo_id            - repository handed to AutoModel.from_pretrained
#   embedding_strategy - backend that get_model_embeddings dispatches on
#   embedding_model    - model name forwarded to that backend
MODEL_CONFIGS = {
    "lionguard-2": {
        "label": "LionGuard 2",
        "repo_id": "govtech/lionguard-2",
        "embedding_strategy": "openai",
        "embedding_model": "text-embedding-3-large",
    },
    "lionguard-2-lite": {
        "label": "LionGuard 2 Lite",
        "repo_id": "govtech/lionguard-2-lite",
        "embedding_strategy": "sentence_transformer",
        "embedding_model": "google/embeddinggemma-300m",
    },
    "lionguard-2.1": {
        "label": "LionGuard 2.1",
        "repo_id": "govtech/lionguard-2.1",
        "embedding_strategy": "gemini",
        "embedding_model": "gemini-embedding-001",
    },
}

# Key used whenever a caller does not name a model explicitly.
DEFAULT_MODEL_KEY = "lionguard-2.1"

# Loaded classifiers, keyed by model key (filled by load_model_instance).
MODEL_CACHE = {}

# Loaded SentenceTransformer encoders, keyed by embedding model name.
EMBEDDING_MODEL_CACHE = {}

# Currently selected model key; reassigned by set_active_model.
current_model_choice = DEFAULT_MODEL_KEY

# Shared Gemini client; stays None until get_gemini_client first builds it.
GEMINI_CLIENT = None
76
+
77
+
78
def resolve_model_key(model_key: str = None) -> str:
    """Return a validated model key, defaulting to the active selection.

    Args:
        model_key: Key into ``MODEL_CONFIGS``. Any falsy value (``None`` or
            an empty string) selects ``current_model_choice`` instead.

    Returns:
        The resolved key, which is guaranteed to exist in ``MODEL_CONFIGS``.

    Raises:
        ValueError: If the resolved key is not present in ``MODEL_CONFIGS``.
    """
    selected = model_key if model_key else current_model_choice
    if selected in MODEL_CONFIGS:
        return selected
    raise ValueError(f"Unknown model selection: {selected}")
83
+
84
+
85
def load_model_instance(model_key: str):
    """Return the classifier for ``model_key``, loading it on first use.

    The key is validated through ``resolve_model_key``. Loaded models are
    kept in ``MODEL_CACHE`` so each repository is only loaded once per
    process.

    Args:
        model_key: Key into ``MODEL_CONFIGS``; a falsy value resolves to the
            currently active model.

    Returns:
        The object produced by ``AutoModel.from_pretrained`` for the
        configured ``repo_id``.

    Raises:
        ValueError: If the key does not resolve to a known model.
    """
    resolved = resolve_model_key(model_key)
    try:
        return MODEL_CACHE[resolved]
    except KeyError:
        pass  # first request for this model: fall through and load it

    # NOTE: trust_remote_code=True allows code shipped in the model repo to
    # run locally; the repos are the fixed ones listed in MODEL_CONFIGS.
    instance = AutoModel.from_pretrained(
        MODEL_CONFIGS[resolved]["repo_id"],
        trust_remote_code=True,
    )
    MODEL_CACHE[resolved] = instance
    return instance
91
+
92
+
93
def get_sentence_transformer(model_name: str):
    """Return a ``SentenceTransformer`` for ``model_name``, built once.

    Instances are stored in ``EMBEDDING_MODEL_CACHE`` keyed by model name, so
    repeated calls with the same name reuse the same encoder.

    Args:
        model_name: Name passed to the ``SentenceTransformer`` constructor.

    Returns:
        The cached (or newly created) ``SentenceTransformer`` instance.
    """
    try:
        return EMBEDDING_MODEL_CACHE[model_name]
    except KeyError:
        pass  # not built yet: construct and remember it below

    encoder = SentenceTransformer(model_name)
    EMBEDDING_MODEL_CACHE[model_name] = encoder
    return encoder
97
+
98
+
99
def get_gemini_client():
    """Return the shared Gemini client, creating it on the first call.

    The client is built from the ``GEMINI_API_KEY`` environment variable and
    stored in the module-level ``GEMINI_CLIENT`` for reuse.

    Returns:
        The ``genai.Client`` instance held in ``GEMINI_CLIENT``.

    Raises:
        EnvironmentError: If ``GEMINI_API_KEY`` is unset or empty when the
            client has to be created.
    """
    global GEMINI_CLIENT

    # Fast path: a client already exists, so the environment is not consulted.
    if GEMINI_CLIENT is not None:
        return GEMINI_CLIENT

    gemini_key = os.getenv("GEMINI_API_KEY")
    if not gemini_key:
        raise EnvironmentError(
            "GEMINI_API_KEY environment variable is required for LionGuard 2.1."
        )

    GEMINI_CLIENT = genai.Client(api_key=gemini_key)
    return GEMINI_CLIENT
109
+
110
+
111
def get_model_embeddings(model_key: str, texts: List[str]) -> np.ndarray:
    """Embed ``texts`` with the backend configured for ``model_key``.

    The model's ``embedding_strategy`` selects the backend:

    * ``"openai"`` delegates to ``get_embeddings``.
    * ``"sentence_transformer"`` encodes locally, after prefixing every text
      with ``"task: classification | query: "``.
    * ``"gemini"`` calls ``embed_content`` on the shared Gemini client.

    Args:
        model_key: Key into ``MODEL_CONFIGS``; a falsy value resolves to the
            currently active model.
        texts: Input strings to embed.

    Returns:
        The embeddings as a NumPy array (for the OpenAI strategy, whatever
        ``get_embeddings`` returns).

    Raises:
        ValueError: If the key is unknown or its embedding strategy is not
            one of the three supported values.
    """
    settings = MODEL_CONFIGS[resolve_model_key(model_key)]
    backend = settings["embedding_strategy"]
    embedding_model = settings.get("embedding_model")

    if backend == "openai":
        return get_embeddings(texts, model=embedding_model)

    if backend == "sentence_transformer":
        encoder = get_sentence_transformer(embedding_model)
        prompts = [f"task: classification | query: {text}" for text in texts]
        return np.array(encoder.encode(prompts))

    if backend == "gemini":
        # Named distinctly so it cannot be mistaken for the module-level
        # OpenAI ``client``.
        gemini = get_gemini_client()
        response = gemini.models.embed_content(
            model=embedding_model,
            contents=texts,
        )
        return np.array([item.values for item in response.embeddings])

    raise ValueError(f"Unsupported embedding strategy: {backend}")
130
+
131
+
132
def predict_with_model(texts: List[str], model_key: str = None) -> Tuple[dict, str]:
    """Run the selected model on ``texts``.

    The texts are embedded with the model's configured backend, then the
    embeddings are passed to the loaded model's ``predict`` method.

    Args:
        texts: Input strings to classify.
        model_key: Key into ``MODEL_CONFIGS``; when omitted or falsy, the
            currently active model is used.

    Returns:
        A ``(predictions, key)`` pair: the value returned by
        ``model.predict`` and the model key that was actually used.

    Raises:
        ValueError: If the key does not resolve to a known model.
    """
    resolved = resolve_model_key(model_key)
    # Embeddings are computed before the classifier is loaded, so an embedding
    # failure surfaces without paying for a model load.
    vectors = get_model_embeddings(resolved, texts)
    classifier = load_model_instance(resolved)
    predictions = classifier.predict(vectors)
    return predictions, resolved
137
+
138
+
139
def set_active_model(model_key: str) -> str:
    """Switch the module-wide active model and return a status message.

    The model is loaded (or fetched from ``MODEL_CACHE``) *before* the
    selection is committed. If loading raises, ``current_model_choice`` keeps
    its previous value and later calls do not default to a model that never
    loaded.

    Args:
        model_key: Key into ``MODEL_CONFIGS``.

    Returns:
        A confirmation string naming the selected model, or a warning string
        when ``model_key`` is not a known model (no state is changed then).

    Raises:
        Any exception raised by ``load_model_instance`` propagates unchanged.
    """
    global current_model_choice

    if model_key not in MODEL_CONFIGS:
        return f"⚠️ Unknown model {model_key}"

    # Load first: only commit the new selection once the model is available.
    load_model_instance(model_key)
    current_model_choice = model_key

    label = MODEL_CONFIGS[model_key]["label"]
    return f"🦁 Using {label} ({model_key})"