import gradio as gr
import openai
import os
import sys
import torch

# Add the parent directory to the path to import from final_model
# sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'final_model'))
from safetensors.torch import load_file
from lionguard2 import LionGuard2
from utils import get_embeddings

# Set up OpenAI client
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Load LionGuard2 model
model = LionGuard2()
model.eval()

# Load model weights
model_path = os.path.join(os.path.dirname(__file__), '..', 'final_model', 'LionGuard2.safetensors')
state_dict = load_file(model_path)
model.load_state_dict(state_dict)

def lionguard_2(message, threshold=0.5):
    """
    LionGuard 2 function that uses the actual model to determine if content is unsafe.

    Args:
        message: The text message to check
        threshold: Probability threshold for flagging content as unsafe (default: 0.5)

    Returns:
        bool: True if content is flagged as unsafe, False otherwise
    """
    try:
        # Get embeddings for the message
        embeddings = get_embeddings([message])

        # Get predictions from the model
        results = model.predict(embeddings)

        # Check the binary classification result (overall safety)
        binary_prob = results['binary'][0]  # First (and only) message's binary probability

        # Flag as unsafe if probability exceeds threshold
        return binary_prob > threshold
    except Exception as e:
        print(f"Error in LionGuard 2: {e}")
        # In case of error, default to not flagging to avoid blocking legitimate content
        return False
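
# Illustrative usage (a sketch, assuming the weights loaded above and a
# reachable embeddings backend); a higher threshold flags only
# high-confidence cases:
#
#   lionguard_2("Hello, how are you?")       # expected: False
#   lionguard_2("<borderline text>", 0.8)    # stricter bar before flagging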

def get_openai_response(message, system_prompt="You are a helpful assistant."):
    """Get response from OpenAI API"""
    try:
        response = client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message}
            ],
            max_tokens=500,
            temperature=0,
            seed=42,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}. Please check your OpenAI API key."

def openai_moderation(message):
    """
    OpenAI moderation function that uses OpenAI's built-in moderation API.

    Args:
        message: The text message to check

    Returns:
        bool: True if content is flagged as unsafe, False otherwise
    """
    try:
        response = client.moderations.create(input=message)
        return response.results[0].flagged
    except Exception as e:
        print(f"Error in OpenAI moderation: {e}")
        # In case of error, default to not flagging
        return False
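
# Both checkers share the same bool-returning interface, so a quick
# side-by-side comparison (illustrative, with a hypothetical sample prompt) is:
#
#   sample = "How do I get to Changi Airport?"
#   print("OpenAI flagged:     ", openai_moderation(sample))
#   print("LionGuard 2 flagged:", lionguard_2(sample))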

def process_message(message, history_no_mod, history_openai, history_lg):
    """Process message for all three chatbots"""
    if not message.strip():
        return history_no_mod, history_openai, history_lg, ""

    # Process for gpt-4.1-nano (no moderation)
    no_mod_response = get_openai_response(message)
    history_no_mod.append({"role": "user", "content": message})
    history_no_mod.append({"role": "assistant", "content": no_mod_response})

    # Process for gpt-4.1-nano with OpenAI moderation
    openai_flagged = openai_moderation(message)
    history_openai.append({"role": "user", "content": message})
    if openai_flagged:
        openai_response = "🚫 This message has been flagged by OpenAI moderation"
    else:
        openai_response = get_openai_response(message)
    history_openai.append({"role": "assistant", "content": openai_response})

    # Process for gpt-4.1-nano with LionGuard 2
    lg_flagged = lionguard_2(message)
    history_lg.append({"role": "user", "content": message})
    if lg_flagged:
        lg_response = "🚫 This message has been flagged by LionGuard 2"
    else:
        lg_response = get_openai_response(message)
    history_lg.append({"role": "assistant", "content": lg_response})

    return history_no_mod, history_openai, history_lg, ""
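
# Illustrative call: starting from empty histories, process_message returns the
# three updated histories plus an empty string that clears the input textbox:
#
#   h_no_mod, h_openai, h_lg, cleared = process_message("Hello!", [], [], [])
#   # each history now ends with a user turn followed by an assistant turn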

def clear_all_chats():
    """Clear all chat histories"""
    return [], [], []

# Create the Gradio interface
with gr.Blocks(title="LionGuard 2", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# EMNLP 2025 System Demonstration: LionGuard 2 🦁")
    gr.Markdown("**LionGuard 2 is a content moderator localised to Singapore - use it to detect unsafe LLM inputs and outputs**")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔵 No Moderation")
            chatbot_no_mod = gr.Chatbot(
                height=800,
                label="No Moderation",
                show_label=False,
                bubble_full_width=False,
                type='messages'
            )
        with gr.Column(scale=1):
            gr.Markdown("## 🟠 OpenAI Moderation")
            chatbot_openai = gr.Chatbot(
                height=800,
                label="OpenAI Moderation",
                show_label=False,
                bubble_full_width=False,
                type='messages'
            )
        with gr.Column(scale=1):
            gr.Markdown("## 🛡️ LionGuard 2")
            chatbot_lg = gr.Chatbot(
                height=800,
                label="LionGuard 2",
                show_label=False,
                bubble_full_width=False,
                type='messages'
            )

    # Single input for all chatbots
    gr.Markdown("### 💬 Send Message to All Models")
    with gr.Row():
        message_input = gr.Textbox(
            placeholder="Type your message to compare responses...",
            show_label=False,
            scale=4
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    # Control buttons
    with gr.Row():
        clear_btn = gr.Button("Clear All Chats", variant="stop")

    # Event handlers
    send_btn.click(
        process_message,
        inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
    )
    message_input.submit(
        process_message,
        inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
    )

    # Clear button
    clear_btn.click(
        clear_all_chats,
        outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True, debug=True)
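
# Programmatic access (a sketch, assuming the app is running locally and that
# Gradio's default of naming endpoints after the bound function applies):
#
#   from gradio_client import Client
#
#   remote = Client("http://127.0.0.1:7860")
#   h_no_mod, h_openai, h_lg, _ = remote.predict(
#       "Hello from a test script", [], [], [],
#       api_name="/process_message",
#   )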