import gradio as gr
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Model and Tokenizer Configuration
MODEL_REPO_ID = "LGAI-EXAONE/EXAONE-4.0-1.2B"
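# Any other causal-LM repo id on the Hugging Face Hub could be swapped in here
# (assumption: it either ships a chat template or the plain-text fallback below is acceptable);
# larger checkpoints need correspondingly more GPU memory.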

print("βœ… Starting application...")

# Load the model with bfloat16 to save memory
try:
    print(f"πŸ”„ Loading tokenizer from '{MODEL_REPO_ID}'...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("βœ… Tokenizer loaded successfully.")

    print(f"πŸ”„ Loading model '{MODEL_REPO_ID}' with torch_dtype=torch.bfloat16...")
    
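    # device_map="auto" (provided by the accelerate package) places the weights
    # on a GPU when one is available and falls back to CPU otherwise.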
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto" 
    )
    print("βœ… Model loaded successfully.")

except Exception as e:
    print(f"❌ Error loading model or tokenizer: {e}")
    # Re-raise so the app fails fast, since it is unusable without the model.
    raise

# Streaming Chat Function
def user_input_handler(user_message, history):
    """Handles user input by appending it to the history."""
    return "", history + [[user_message, None]]

def bot_stream(history):
    """
    Generates the bot's response using a streaming approach.
    This function runs the model in a separate thread to avoid blocking the UI.
    """
    print(f"πŸ“ History received: {history}")
    # The last message is the user's prompt.
    user_message = history[-1][0]
    history[-1][1] = "" # Initialize the bot's response field.

    # Format the conversation history into the model's expected chat format.
    messages = []
    for human, assistant in history[:-1]: # All but the last interaction
        messages.append({"role": "user", "content": human})
        if assistant: # Assistant message might be None
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})

    try:
        # Apply the chat template to format the prompt.
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"⚠️ Warning: Could not apply chat template. Using basic formatting. Error: {e}")
        # Fallback for models without a registered chat template
        prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + "\nassistant:"

    print("➑️ Generated Prompt for Model:\n" + prompt)

    # Tokenize the formatted prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Use TextIteratorStreamer for non-blocking, token-by-token generation.
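    # skip_prompt=True drops the echoed input tokens from the stream, and
    # skip_special_tokens strips markers such as the EOS token from the output text.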
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up the generation parameters in a dictionary.
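    # do_sample=True enables sampling; temperature and top_p control how random
    # the sampled tokens are (these values are reasonable defaults, not tuned for EXAONE).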
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run the generation in a separate thread to avoid blocking the Gradio UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield each new token to the Gradio chat interface as it's generated.
    for token in streamer:
        history[-1][1] += token
        yield history

# Gradio User Interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css="footer {display: none !important}") as demo:
    gr.Markdown("## πŸ€– EXAONE-4.0-1.2B")
    gr.Markdown("This demo runs the standard `LGAI-EXAONE/EXAONE-4.0-1.2B` model using the `transformers` library.")

    chatbot = gr.Chatbot(label="Chat History", height=600, bubble_full_width=False)
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Your Message",
            scale=8,
            autofocus=True,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")

    clear_btn = gr.ClearButton([msg, chatbot], value="πŸ—‘οΈ Clear Chat")

    # Event Handlers
    msg.submit(user_input_handler, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_stream, chatbot, chatbot
    )
    send_btn.click(user_input_handler, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_stream, chatbot, chatbot
    )

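# queue() enables Gradio's request queue, which is what delivers the partial
# responses yielded by bot_stream to the Chatbot component as they are generated.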
demo.queue()
demo.launch(debug=True)
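
# To run locally (assuming this file is saved as app.py and the gradio, torch,
# transformers, and accelerate packages are installed):
#   python app.py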