# -*- encoding: utf-8 -*-
# File: app.py
# Description: None
import base64
import io
import json
import subprocess
from copy import deepcopy
from typing import Dict, List

import gradio as gr
import librosa
import requests
from PIL import Image

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac", ".m4a", ".wma")

DEFAULT_SAMPLING_PARAMS = {
    "top_p": 0.8,
    "top_k": 100,
    "temperature": 0.7,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.2,
}
MAX_NEW_TOKENS = 1024


def load_image_to_base64(image_path):
    """Load image and convert to base64 string"""
    with Image.open(image_path) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
    return base64.b64encode(img_byte_arr).decode('utf-8')


def wav_to_bytes_with_ffmpeg(wav_file_path):
    """Re-encode an audio file to WAV with ffmpeg and return it as a base64 string."""
    process = subprocess.Popen(
        ['ffmpeg', '-i', wav_file_path, '-f', 'wav', '-'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    out, _ = process.communicate()
    return base64.b64encode(out).decode('utf-8')
def parse_sse_response(response):
    """Yield text chunks parsed from a streaming SSE response."""
    for line in response.iter_lines():
        if line:
            line = line.decode('utf-8')
            if line.startswith('data: '):
                data = line[6:]  # Remove 'data: ' prefix
                if data == '[DONE]':
                    break
                try:
                    json_data = json.loads(data)
                    yield json_data['text']
                except json.JSONDecodeError:
                    raise gr.Error(f"Failed to parse JSON: {data}")


def history2messages(history: List[Dict]) -> List[Dict]:
    """
    Transform gradio history to chat messages.
    """
    messages = []
    cur_message = dict()
    for item in history:
        if item["role"] == "assistant":
            if len(cur_message) > 0:
                messages.append(deepcopy(cur_message))
                cur_message = dict()
            messages.append(deepcopy(item))
            continue
        print(f"item:{item}")
        print(f"cur_message:{cur_message}")
        if "role" not in cur_message:
            cur_message["role"] = "user"
        if "content" not in cur_message:
            cur_message["content"] = dict()
        if "metadata" not in item or item["metadata"] is None:
            item["metadata"] = {"title": None}
        if item["metadata"]["title"] is None or item["metadata"]["title"] == "text":
            cur_message["content"]["text"] = item["content"]
        elif item["metadata"]["title"] == "image":
            cur_message["content"]["image"] = load_image_to_base64(item["content"][0])
        elif item["metadata"]["title"] == "audio":
            cur_message["content"]["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])
        print(f"cur_message:{cur_message}")
    if len(cur_message) > 0:
        messages.append(cur_message)
    return messages
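

# For reference: history2messages merges consecutive user turns into a single
# message whose "content" dict may hold "text", "image" (base64 PNG), and
# "audio" (base64 WAV) entries, while assistant turns are copied through
# unchanged.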
def check_messages(history, message, audio):
    """Validate the multimodal input and append it to the chat history."""
    has_text = message["text"] and message["text"].strip()
    has_files = len(message["files"]) > 0
    has_audio = audio is not None

    if not (has_text or has_files or has_audio):
        raise gr.Error("Please enter text or upload an audio/image file before sending.")

    audios = []
    images = []
    for file_msg in message["files"]:
        if file_msg.endswith(AUDIO_EXTENSIONS) or file_msg.endswith(VIDEO_EXTENSIONS):
            duration = librosa.get_duration(filename=file_msg)
            if duration > 30:
                raise gr.Error("Audio must not be longer than 30 seconds.")
            if duration == 0:
                raise gr.Error("Audio duration must not be 0 seconds.")
            audios.append(file_msg)
        elif file_msg.endswith(IMAGE_EXTENSIONS):
            images.append(file_msg)
        else:
            filename = file_msg.split("/")[-1]
            raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")

    if len(audios) > 1:
        raise gr.Error("Please upload only one audio file.")
    if len(images) > 1:
        raise gr.Error("Please upload only one image file.")

    if audio is not None:
        if len(audios) > 0:
            raise gr.Error("Please upload only one audio file or record audio.")
        audios.append(audio)

    # Append the message to the history
    for image in images:
        history.append({"role": "user", "content": (image,), "metadata": {"title": "image"}})
    for audio in audios:
        history.append({"role": "user", "content": (audio,), "metadata": {"title": "audio"}})
    if message["text"]:
        history.append({"role": "user", "content": message["text"], "metadata": {"title": "text"}})

    return history, gr.MultimodalTextbox(value=None, interactive=False), None
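

# check_messages returns a cleared, disabled MultimodalTextbox so that input is
# locked while the bot streams its reply; the .then() callback registered after
# bot_response below re-enables it.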
def bot(
    history: list,
    top_p: float,
    top_k: int,
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int = MAX_NEW_TOKENS,
    regenerate: bool = False,
):
    """Send the chat history to the backend and stream the assistant reply."""
    if history and regenerate:
        history = history[:-1]
    if not history:
        return history

    msgs = history2messages(history)
    API_URL = "http://8.152.0.142:80/v1/chat"
    payload = {
        "messages": msgs,
        "sampling_params": {
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "num_beams": 3,
        }
    }
    response = requests.get(
        API_URL,
        json=payload,
        headers={'Accept': 'text/event-stream'},
        stream=True
    )
    response_text = ""
    for text in parse_sse_response(response):
        print(f"text: {text}")
        response_text += text
        yield history + [{"role": "assistant", "content": response_text, "metadata": {"title": "text"}}]
    return response_text
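

# bot is a generator: each yield hands Gradio the history plus the partially
# accumulated assistant message, so the reply renders incrementally as the
# backend streams tokens.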
def change_state(state):
    return gr.update(visible=not state), not state


def reset_user_input():
    return gr.update(value="")


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            f"""
            # 🪐 Chat with <a href="https://github.com/infinigence/Infini-Megrez-Omni">Megrez-3B-Omni</a>
            """
        )
        chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height='48vh')
        sampling_params_group_hidden_state = gr.State(False)
        with gr.Row(equal_height=True):
            chat_input = gr.MultimodalTextbox(
                file_count="multiple",
                placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
                show_label=False,
                scale=8,
                file_types=["image", "audio"],
                interactive=True,
                # stop_btn=True,
            )
        with gr.Row(equal_height=True):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                scale=1,
                max_length=30
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=150):
                with gr.Row(equal_height=True):
                    regenerate_btn = gr.Button("Regenerate", variant="primary")
                    clear_btn = gr.ClearButton(
                        [chat_input, audio_input, chatbot],
                    )
        with gr.Row():
            sampling_params_toggle_btn = gr.Button("Sampling Parameters")
        with gr.Group(visible=False) as sampling_params_group:
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
                )
                repetition_penalty = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
                    label="Repetition Penalty",
                )
            with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
                top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=MAX_NEW_TOKENS,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    interactive=True,
                )

        # Show or hide the sampling-parameter panel.
        sampling_params_toggle_btn.click(
            change_state,
            sampling_params_group_hidden_state,
            [sampling_params_group, sampling_params_group_hidden_state],
        )
        # Submit flow: validate and append the input, stream the bot reply,
        # then re-enable the textbox.
        chat_msg = chat_input.submit(
            check_messages,
            [chatbot, chat_input, audio_input],
            [chatbot, chat_input, audio_input],
        )
        bot_msg = chat_msg.then(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
            outputs=chatbot,
            api_name="bot_response",
        )
        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
        # Regenerate drops the last assistant turn and re-queries the backend.
        regenerate_btn.click(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
            outputs=chatbot,
        )

    demo.launch()