gabrielchua committed
Commit 2b87e64 · 1 Parent(s): ec71cb2

Update app.py

Files changed (1)
  1. app.py +94 -45
app.py CHANGED
@@ -7,6 +7,7 @@ import json
 import os
 import sys
 import uuid
+import asyncio
 from datetime import datetime

 # Third party imports
@@ -37,7 +38,9 @@ CATEGORIES = {
 }

 # --- OpenAI Setup ---
+# Create both sync and async clients
 client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+async_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

 # --- Model Loading ---
 def load_lionguard2():
@@ -204,10 +207,12 @@ def vote_thumbs_down(text_id):
         return '<div style="color: #fca5a5; font-weight:700;">📝 Thanks for the feedback!</div>'
     return '<div>Voting not available or analysis not yet run.</div>'

-# --- Guardrail Comparison logic ---
-def get_openai_response(message, system_prompt="You are a helpful assistant."):
+# --- Guardrail Comparison logic (ASYNC VERSION) ---
+
+async def get_openai_response_async(message, system_prompt="You are a helpful assistant."):
+    """Async version of OpenAI API call"""
     try:
-        response = client.chat.completions.create(
+        response = await async_client.chat.completions.create(
             model="gpt-4.1-nano",
             messages=[
                 {"role": "system", "content": system_prompt},
@@ -221,15 +226,17 @@ def get_openai_response(message, system_prompt="You are a helpful assistant."):
     except Exception as e:
         return f"Error: {str(e)}. Please check your OpenAI API key."

-def openai_moderation(message):
+async def openai_moderation_async(message):
+    """Async version of OpenAI moderation"""
     try:
-        response = client.moderations.create(input=message)
+        response = await async_client.moderations.create(input=message)
         return response.results[0].flagged
     except Exception as e:
         print(f"Error in OpenAI moderation: {e}")
         return False

-def lionguard_2(message, threshold=0.5):
+def lionguard_2_sync(message, threshold=0.5):
+    """LionGuard remains sync as it's using a local model"""
     try:
         embeddings = get_embeddings([message])
         results = model.predict(embeddings)
@@ -239,68 +246,109 @@ def lionguard_2(message, threshold=0.5):
         print(f"Error in LionGuard 2: {e}")
         return False, 0.0

-def process_message(message, history_no_mod, history_openai, history_lg):
-    if not message.strip():
-        return history_no_mod, history_openai, history_lg, ""
-    no_mod_response = get_openai_response(message)
+async def process_no_moderation(message, history_no_mod):
+    """Process message without moderation"""
+    no_mod_response = await get_openai_response_async(message)
     history_no_mod.append({"role": "user", "content": message})
     history_no_mod.append({"role": "assistant", "content": no_mod_response})
+    return history_no_mod

-    openai_flagged = openai_moderation(message)
+async def process_openai_moderation(message, history_openai):
+    """Process message with OpenAI moderation"""
+    openai_flagged = await openai_moderation_async(message)
     history_openai.append({"role": "user", "content": message})
     if openai_flagged:
         openai_response = "🚫 This message has been flagged by OpenAI moderation"
         history_openai.append({"role": "assistant", "content": openai_response})
     else:
-        openai_response = get_openai_response(message)
+        openai_response = await get_openai_response_async(message)
         history_openai.append({"role": "assistant", "content": openai_response})
-
-    lg_flagged, lg_score = lionguard_2(message)
+    return history_openai
+
+async def process_lionguard(message, history_lg):
+    """Process message with LionGuard 2"""
+    # Run LionGuard sync check in thread pool to not block
+    loop = asyncio.get_event_loop()
+    lg_flagged, lg_score = await loop.run_in_executor(None, lionguard_2_sync, message, 0.5)
+
     history_lg.append({"role": "user", "content": message})
     if lg_flagged:
         lg_response = "🚫 This message has been flagged by LionGuard 2"
         history_lg.append({"role": "assistant", "content": lg_response})
     else:
-        lg_response = get_openai_response(message)
+        lg_response = await get_openai_response_async(message)
         history_lg.append({"role": "assistant", "content": lg_response})
+    return history_lg, lg_score

-    # --- Logging for chatbot worksheet ---
+async def process_message_async(message, history_no_mod, history_openai, history_lg):
+    """Process message concurrently across all three guardrails"""
+    if not message.strip():
+        return history_no_mod, history_openai, history_lg, ""
+
+    # Run all three processes concurrently using asyncio.gather
+    results = await asyncio.gather(
+        process_no_moderation(message, history_no_mod),
+        process_openai_moderation(message, history_openai),
+        process_lionguard(message, history_lg),
+        return_exceptions=True  # Continue even if one fails
+    )
+
+    # Unpack results
+    history_no_mod = results[0] if not isinstance(results[0], Exception) else history_no_mod
+    history_openai = results[1] if not isinstance(results[1], Exception) else history_openai
+    history_lg_result = results[2] if not isinstance(results[2], Exception) else (history_lg, 0.0)
+    history_lg = history_lg_result[0]
+    lg_score = history_lg_result[1] if isinstance(history_lg_result, tuple) else 0.0
+
+    # --- Logging for chatbot worksheet (runs in background) ---
     if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
         try:
-            embeddings = get_embeddings([message])
-            results = model.predict(embeddings)
-            now = datetime.now().isoformat()
-            text_id = str(uuid.uuid4())
-            row = {
-                "datetime": now,
-                "text_id": text_id,
-                "text": message,
-                "binary_score": results.get("binary", [None])[0],
-                "hateful_l1_score": results.get(CATEGORIES['hateful'][0], [None])[0],
-                "hateful_l2_score": results.get(CATEGORIES['hateful'][1], [None])[0],
-                "insults_score": results.get(CATEGORIES['insults'][0], [None])[0],
-                "sexual_l1_score": results.get(CATEGORIES['sexual'][0], [None])[0],
-                "sexual_l2_score": results.get(CATEGORIES['sexual'][1], [None])[0],
-                "physical_violence_score": results.get(CATEGORIES['physical_violence'][0], [None])[0],
-                "self_harm_l1_score": results.get(CATEGORIES['self_harm'][0], [None])[0],
-                "self_harm_l2_score": results.get(CATEGORIES['self_harm'][1], [None])[0],
-                "aom_l1_score": results.get(CATEGORIES['all_other_misconduct'][0], [None])[0],
-                "aom_l2_score": results.get(CATEGORIES['all_other_misconduct'][1], [None])[0],
-                "openai_score": None
-            }
-            try:
-                openai_result = client.moderations.create(input=message)
-                # Using the "hate" category score as a demonstration. You may customize as needed.
-                row["openai_score"] = float(openai_result.results[0].category_scores.get("hate", 0.0))
-            except Exception:
-                row["openai_score"] = None
-
-            log_chatbot_data(row)
+            loop = asyncio.get_event_loop()
+            # Run logging in thread pool so it doesn't block
+            loop.run_in_executor(None, _log_chatbot_sync, message, lg_score)
         except Exception as e:
             print(f"Chatbot logging failed: {e}")

     return history_no_mod, history_openai, history_lg, ""

+def _log_chatbot_sync(message, lg_score):
+    """Sync helper for logging - runs in thread pool"""
+    try:
+        embeddings = get_embeddings([message])
+        results = model.predict(embeddings)
+        now = datetime.now().isoformat()
+        text_id = str(uuid.uuid4())
+        row = {
+            "datetime": now,
+            "text_id": text_id,
+            "text": message,
+            "binary_score": results.get("binary", [None])[0],
+            "hateful_l1_score": results.get(CATEGORIES['hateful'][0], [None])[0],
+            "hateful_l2_score": results.get(CATEGORIES['hateful'][1], [None])[0],
+            "insults_score": results.get(CATEGORIES['insults'][0], [None])[0],
+            "sexual_l1_score": results.get(CATEGORIES['sexual'][0], [None])[0],
+            "sexual_l2_score": results.get(CATEGORIES['sexual'][1], [None])[0],
+            "physical_violence_score": results.get(CATEGORIES['physical_violence'][0], [None])[0],
+            "self_harm_l1_score": results.get(CATEGORIES['self_harm'][0], [None])[0],
+            "self_harm_l2_score": results.get(CATEGORIES['self_harm'][1], [None])[0],
+            "aom_l1_score": results.get(CATEGORIES['all_other_misconduct'][0], [None])[0],
+            "aom_l2_score": results.get(CATEGORIES['all_other_misconduct'][1], [None])[0],
+            "openai_score": None
+        }
+        try:
+            openai_result = client.moderations.create(input=message)
+            row["openai_score"] = float(openai_result.results[0].category_scores.get("hate", 0.0))
+        except Exception:
+            row["openai_score"] = None
+
+        log_chatbot_data(row)
+    except Exception as e:
+        print(f"Error in sync logging: {e}")
+
+def process_message(message, history_no_mod, history_openai, history_lg):
+    """Wrapper function for Gradio (converts async to sync)"""
+    return asyncio.run(process_message_async(message, history_no_mod, history_openai, history_lg))
+
 def clear_all_chats():
     return [], [], []

 
 
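The heart of this commit is process_message_async, which fans the three guardrail paths out concurrently instead of running them one after another. A minimal, self-contained sketch of the same pattern is below; the fake_* functions and the 0.5 s timings are illustrative stand-ins rather than code from app.py, while asyncio.gather, loop.run_in_executor, and asyncio.run are used the same way the commit uses them.

import asyncio
import time

async def fake_openai_call(label):
    # Stand-in for an awaited network call such as async_client.chat.completions.create(...)
    await asyncio.sleep(0.5)
    return f"{label}: ok"

def fake_local_guardrail(message):
    # Stand-in for the blocking local LionGuard 2 prediction (CPU-bound)
    time.sleep(0.5)
    return False, 0.1

async def run_guardrails(message):
    loop = asyncio.get_event_loop()
    # The two API-backed paths are plain coroutines; the local model call is pushed
    # to the default thread pool so it does not block the event loop while it computes.
    return await asyncio.gather(
        fake_openai_call("no moderation"),
        fake_openai_call("openai moderation"),
        loop.run_in_executor(None, fake_local_guardrail, message),
        return_exceptions=True,  # one failing path does not cancel the others
    )

if __name__ == "__main__":
    start = time.time()
    print(asyncio.run(run_guardrails("hello")))
    print(f"elapsed: {time.time() - start:.2f}s")  # roughly 0.5 s, not 1.5 s, because the checks overlap

Because return_exceptions=True is set, a guardrail that raises shows up as an Exception object in the results list instead of aborting the whole call; that is what the isinstance(..., Exception) unpacking in process_message_async falls back on, so the other two chat panes still update.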
@@ -362,6 +410,7 @@ with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:

     with gr.Tab("Guardrail Comparison"):
         gr.HTML(DISCLAIMER)
+        gr.HTML('<div style="background: #34d399; color: #1e293b; border-radius: 8px; padding: 12px; margin-bottom: 12px; font-size: 14px; font-weight:600;">⚡ Concurrent Processing: All 3 guardrails run in parallel for faster responses!</div>')
         with gr.Row():
             with gr.Column(scale=1):
                 gr.Markdown("#### 🔵 No Moderation")
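On the Gradio side, the new process_message wrapper exists because the click handler stays synchronous: Gradio executes sync event handlers off the main event loop (in a worker thread), so asyncio.run can safely drive the async pipeline to completion there. The wiring sketch below is hypothetical: the component names (msg_box, send_btn, the three gr.Chatbot panes) and the stubbed pipeline body are illustrative and not taken from app.py, but the wrapper itself mirrors the one added in this commit.

import asyncio
import gradio as gr

async def process_message_async(message, h1, h2, h3):
    # Stubbed pipeline; the real version gathers the three guardrail coroutines.
    await asyncio.sleep(0)
    histories = [list(h or []) for h in (h1, h2, h3)]
    for history in histories:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": f"echo: {message}"})
    return (*histories, "")

def process_message(message, h1, h2, h3):
    # Sync bridge for Gradio: safe here because this runs outside any event loop.
    return asyncio.run(process_message_async(message, h1, h2, h3))

with gr.Blocks() as demo:
    with gr.Row():
        chat_no_mod = gr.Chatbot(type="messages", label="No Moderation")
        chat_openai = gr.Chatbot(type="messages", label="OpenAI Moderation")
        chat_lg = gr.Chatbot(type="messages", label="LionGuard 2")
    msg_box = gr.Textbox(label="Message")
    send_btn = gr.Button("Send")
    send_btn.click(
        process_message,
        inputs=[msg_box, chat_no_mod, chat_openai, chat_lg],
        outputs=[chat_no_mod, chat_openai, chat_lg, msg_box],
    )

if __name__ == "__main__":
    demo.launch()

Gradio also accepts an async function directly as an event handler, so the explicit wrapper is a design choice rather than a requirement; keeping the old name and signature means the existing click wiring does not have to change.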