letran1110's picture
Update app.py
996d3ac verified
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import gradio as gr
# Kiểm tra thiết bị
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model và tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("letran1110/vit5_motor_extractor").to(device)
tokenizer = AutoTokenizer.from_pretrained("letran1110/vit5_motor_extractor")
# Warm-up để lần đầu không bị delay
def warmup():
inputs = tokenizer("test", return_tensors="pt", truncation=True, padding=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
_ = model.generate(**inputs, max_length=256)
warmup()
# Hàm dự đoán
def predict(text):
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
output = model.generate(**inputs, max_length=256)
return tokenizer.decode(output[0], skip_special_tokens=True)
# Giao diện Gradio
demo = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=15, placeholder="Nhập mô tả động cơ..."),
outputs="text",
title="Motor Info Extractor",
description="Nhập mô tả động cơ (tiếng Việt hoặc Anh), hệ thống sẽ trích xuất thông tin cấu hình dưới dạng JSON."
)
if __name__ == "__main__":
demo.launch()