# lionguard-demo / app.py
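"""
Gradio demo for the EMNLP 2025 system demonstration of LionGuard 2, a content
moderator localised to Singapore. The same user message is sent to three
chatbots side by side: gpt-4.1-nano with no moderation, gpt-4.1-nano behind
OpenAI's moderation API, and gpt-4.1-nano behind LionGuard 2.
"""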
import os

import gradio as gr
import openai
from safetensors.torch import load_file

from lionguard2 import LionGuard2
from utils import get_embeddings

# Set up the OpenAI client (reads the key from the OPENAI_API_KEY environment variable)
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Load the LionGuard 2 classifier and its weights from the sibling final_model directory
model = LionGuard2()
model_path = os.path.join(os.path.dirname(__file__), '..', 'final_model', 'LionGuard2.safetensors')
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
model.eval()  # inference only


def lionguard_2(message, threshold=0.5):
    """
    Run the LionGuard 2 model to determine whether a message is unsafe.

    Args:
        message: The text message to check
        threshold: Probability threshold for flagging content as unsafe (default: 0.5)

    Returns:
        bool: True if content is flagged as unsafe, False otherwise
    """
    try:
        # Embed the message, then score it with LionGuard 2
        embeddings = get_embeddings([message])
        results = model.predict(embeddings)

        # 'binary' holds the overall unsafe probability per input; we passed a
        # single message, so take the first entry
        binary_prob = results['binary'][0]

        # Flag as unsafe if the probability exceeds the threshold
        return binary_prob > threshold
    except Exception as e:
        print(f"Error in LionGuard 2: {e}")
        # On error, default to not flagging to avoid blocking legitimate content
        return False
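
# Illustrative usage (not executed here; assumes the model weights and the
# embedding backend used by get_embeddings are available):
#   lionguard_2("Hello, how are you?")   -> False for benign text
#   lionguard_2(text, threshold=0.8)     -> flags only high-confidence cases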


def get_openai_response(message, system_prompt="You are a helpful assistant."):
    """Get a response from gpt-4.1-nano via the OpenAI chat completions API."""
    try:
        response = client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message},
            ],
            max_tokens=500,
            # temperature=0 and a fixed seed keep responses comparable across the three panels
            temperature=0,
            seed=42,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}. Please check your OpenAI API key."


def openai_moderation(message):
    """
    Check a message with OpenAI's built-in moderation API.

    Args:
        message: The text message to check

    Returns:
        bool: True if content is flagged as unsafe, False otherwise
    """
    try:
        response = client.moderations.create(input=message)
        return response.results[0].flagged
    except Exception as e:
        print(f"Error in OpenAI moderation: {e}")
        # On error, default to not flagging
        return False
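
# Illustrative usage (not executed here; requires a valid OPENAI_API_KEY):
#   openai_moderation("Hello, how are you?")  -> False for benign text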


def process_message(message, history_no_mod, history_openai, history_lg):
    """Send the message to all three chatbots; return the updated histories plus an empty string to clear the input box."""
    if not message.strip():
        return history_no_mod, history_openai, history_lg, ""

    # 1) gpt-4.1-nano with no moderation
    no_mod_response = get_openai_response(message)
    history_no_mod.append({"role": "user", "content": message})
    history_no_mod.append({"role": "assistant", "content": no_mod_response})

    # 2) gpt-4.1-nano behind OpenAI moderation
    history_openai.append({"role": "user", "content": message})
    if openai_moderation(message):
        openai_response = "🚫 This message has been flagged by OpenAI moderation"
    else:
        openai_response = get_openai_response(message)
    history_openai.append({"role": "assistant", "content": openai_response})

    # 3) gpt-4.1-nano behind LionGuard 2
    history_lg.append({"role": "user", "content": message})
    if lionguard_2(message):
        lg_response = "🚫 This message has been flagged by LionGuard 2"
    else:
        lg_response = get_openai_response(message)
    history_lg.append({"role": "assistant", "content": lg_response})

    return history_no_mod, history_openai, history_lg, ""
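
# Each history is a list of Gradio 'messages'-style dicts, e.g.
# [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}],
# matching the gr.Chatbot(type='messages') components defined below.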


def clear_all_chats():
    """Clear all chat histories"""
    return [], [], []


# Create the Gradio interface
with gr.Blocks(title="LionGuard 2", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# EMNLP 2025 System Demonstration: LionGuard 2 🦁")
    gr.Markdown("**LionGuard 2 is a content moderator localised to Singapore - use it to detect unsafe LLM inputs and outputs**")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔵 No Moderation")
            chatbot_no_mod = gr.Chatbot(
                height=800,
                label="No Moderation",
                show_label=False,
                bubble_full_width=False,
                type='messages'
            )
        with gr.Column(scale=1):
            gr.Markdown("## 🟠 OpenAI Moderation")
            chatbot_openai = gr.Chatbot(
                height=800,
                label="OpenAI Moderation",
                show_label=False,
                bubble_full_width=False,
                type='messages'
            )
        with gr.Column(scale=1):
            gr.Markdown("## 🛡️ LionGuard 2")
            chatbot_lg = gr.Chatbot(
                height=800,
                label="LionGuard 2",
                show_label=False,
                bubble_full_width=False,
                type='messages'
            )
    # Single input for all chatbots
    gr.Markdown("### 💬 Send Message to All Models")
    with gr.Row():
        message_input = gr.Textbox(
            placeholder="Type your message to compare responses...",
            show_label=False,
            scale=4
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    # Control buttons
    with gr.Row():
        clear_btn = gr.Button("Clear All Chats", variant="stop")
    # Event handlers: the Send button and pressing Enter both route the message
    # through process_message and update all three chatbots
    send_btn.click(
        process_message,
        inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
    )
    message_input.submit(
        process_message,
        inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
    )

    # Clear button resets all three chat histories
    clear_btn.click(
        clear_all_chats,
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True, debug=True)
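
# To run the demo locally (assuming dependencies are installed and OPENAI_API_KEY is set):
#   python app.py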