import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch


class VibeThinker:
    """Thin wrapper around a causal LM and its tokenizer for chat-style inference."""

    def __init__(self, model_path):
        self.model_path = model_path
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,  # bf16 weights; assumes bf16-capable hardware
            device_map="auto"  # place weights on the available GPU(s), else CPU
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)

    def infer_text(self, messages):
        # Render the conversation with the model's chat template, appending the
        # assistant turn marker so generation continues as the assistant.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        generation_config = dict(
            max_new_tokens=4096,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            top_k=0  # 0 disables top-k filtering in transformers; -1 is vLLM's convention and raises a ValueError here
        )
        
        generated_ids = self.model.generate(
            **model_inputs,
            generation_config=GenerationConfig(**generation_config)
        )
        
        # Strip the prompt tokens so only the newly generated completion is decoded
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
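
# Standalone usage example (a minimal sketch; exercises the same inference path
# as the UI below, without Gradio):
#   model = VibeThinker("WeiboAI/VibeThinker-1.5B")
#   print(model.infer_text([{"role": "user", "content": "What is 2 + 2?"}]))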


# Initialize the model
print("Loading VibeThinker model...")
vibe_model = VibeThinker('WeiboAI/VibeThinker-1.5B')
print("Model loaded successfully!")


def respond(message, history):
    """
    Generate a response for the chatbot.

    Args:
        message: The user's current message
        history: Previous turns; with type="messages", Gradio passes a list of
            {"role": ..., "content": ...} dicts rather than [user, assistant] pairs
    """
    # History entries are already role/content dicts; keep only the keys the
    # chat template expects (Gradio may attach extra metadata)
    messages = [{"role": m["role"], "content": m["content"]} for m in history]

    # Add the current user message
    messages.append({"role": "user", "content": message})

    # Generate the full response in one shot (non-streaming)
    response = vibe_model.infer_text(messages)

    return response
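
# Optional streaming variant (a sketch, not wired into the UI below). It assumes
# transformers' TextIteratorStreamer: generation runs in a background thread and
# decoded text is yielded incrementally, which gr.ChatInterface renders as it arrives.
def respond_stream(message, history):
    from threading import Thread
    from transformers import TextIteratorStreamer

    messages = [{"role": m["role"], "content": m["content"]} for m in history]
    messages.append({"role": "user", "content": message})

    text = vibe_model.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = vibe_model.tokenizer([text], return_tensors="pt").to(vibe_model.model.device)

    # skip_prompt drops the echoed input; skip_special_tokens is passed through to decode
    streamer = TextIteratorStreamer(
        vibe_model.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    Thread(
        target=vibe_model.model.generate,
        kwargs=dict(
            **model_inputs, streamer=streamer,
            max_new_tokens=4096, do_sample=True, temperature=0.6, top_p=0.95,
        ),
    ).start()

    partial = ""
    for chunk in streamer:  # blocks until the next decoded piece is available
        partial += chunk
        yield partial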


# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .header-link { text-decoration: none; color: inherit; }
    .header-link:hover { text-decoration: underline; }
    """
) as demo:
    gr.Markdown(
        """
        # 💭 VibeThinker Chatbot
        Chat with [WeiboAI/VibeThinker-1.5B](https://huggingface.co/WeiboAI/VibeThinker-1.5B) - a compact 1.5B-parameter reasoning model from WeiboAI.
        
        <a href="https://huggingface.co/spaces/akhaliq/anycoder" class="header-link">Built with anycoder</a>
        """
    )
    
    chatbot = gr.ChatInterface(
        fn=respond,
        type="messages",
        title="",
        description="Ask me anything! I'm powered by VibeThinker.",
        examples=[
            "What is the meaning of life?",
            "Explain quantum computing in simple terms",
            "Write a short poem about artificial intelligence",
            "How can I improve my productivity?",
        ],
        cache_examples=False,
        # Note: the retry_btn / undo_btn / clear_btn kwargs were removed in
        # Gradio 5; retry, undo, and clear controls are built into the chat UI.
    )
    
    gr.Markdown(
        """
        ### About VibeThinker
        VibeThinker-1.5B is a compact 1.5B-parameter model from WeiboAI.
        This demo samples with temperature 0.6 and top-p 0.95 to balance creativity and coherence.
        """
    )

if __name__ == "__main__":
    demo.launch()