import os, io, base64, traceback
from typing import Any, List, Optional

import requests
import torch
from PIL import Image
from fastapi import FastAPI, HTTPException, Header
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq

MODEL_ID = "Qwen/Qwen3-VL-4B-Instruct"
DEVICE = "cpu"
DTYPE = torch.float32

# Image & text input limits
MAX_IMAGES = 5                      # Maximum number of images per request
MAX_TOTAL_IMAGE_BYTES = 5_000_000   # 5 MB total image size limit
MAX_INPUT_TOKENS = 4096             # Max tokens to pass into the model
MAX_NEW_TOKENS = 256                # Max tokens to generate in output

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map=DEVICE,  # device_map already places the model on CPU; no extra .to() needed
)
model.eval()

BACKEND_KEY = os.getenv("BACKEND_KEY")


def require_auth(authorization: Optional[str]):
    # If no key is configured, the endpoint is open; otherwise expect "Bearer <key>".
    if not BACKEND_KEY:
        return
    if not authorization or authorization.split()[-1] != BACKEND_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")


class ChatMessage(BaseModel):
    role: str
    content: Any  # string or list of parts


class ChatRequest(BaseModel):
    model: Optional[str] = MODEL_ID
    messages: List[ChatMessage]
    max_tokens: Optional[int] = MAX_NEW_TOKENS
    temperature: Optional[float] = 0.0
    stream: Optional[bool] = False


app = FastAPI()


@app.get("/health")
def health():
    return {"ok": True, "model": MODEL_ID, "device": DEVICE}


def _pil_from_data_url(url: str) -> Image.Image:
    head, b64 = url.split(",", 1)
    raw = base64.b64decode(b64)
    if len(raw) > MAX_TOTAL_IMAGE_BYTES:
        raise ValueError(f"image too large: {len(raw)} bytes > {MAX_TOTAL_IMAGE_BYTES}")
    return Image.open(io.BytesIO(raw)).convert("RGB")


def _pil_from_http(url: str) -> Image.Image:
    r = requests.get(url, timeout=15)
    r.raise_for_status()
    raw = r.content
    if len(raw) > MAX_TOTAL_IMAGE_BYTES:
        raise ValueError(f"image too large: {len(raw)} bytes > {MAX_TOTAL_IMAGE_BYTES}")
    return Image.open(io.BytesIO(raw)).convert("RGB")


def _openai_to_qwen(messages: List[dict]) -> tuple[list[dict], list[Image.Image]]:
    """Convert OpenAI-style messages into Qwen chat messages plus a flat image list."""
    chat, imgs = [], []
    total_img_bytes = 0
    for m in messages:
        role = m.get("role", "user")
        contents = m.get("content", [])
        if isinstance(contents, str):
            contents = [{"type": "text", "text": contents}]
        norm = []
        for part in contents:
            if isinstance(part, dict) and part.get("type") == "text":
                norm.append({"type": "text", "text": part["text"]})
            elif isinstance(part, dict) and part.get("type") == "image_url":
                url = part["image_url"]["url"]
                if url.startswith("data:image/"):
                    head, b64 = url.split(",", 1)
                    raw = base64.b64decode(b64)
                    total_img_bytes += len(raw)
                    img = Image.open(io.BytesIO(raw)).convert("RGB")
                elif url.startswith("http"):
                    r = requests.get(url, timeout=15)
                    r.raise_for_status()
                    raw = r.content
                    total_img_bytes += len(raw)
                    img = Image.open(io.BytesIO(raw)).convert("RGB")
                else:
                    img = Image.open(url[7:] if url.startswith("file://") else url).convert("RGB")
                img.thumbnail((1024, 1024))  # aggressive downscale for CPU
                imgs.append(img)
                norm.append({"type": "image", "image": img})
        if norm:
            chat.append({"role": role, "content": norm})
    if len(imgs) > MAX_IMAGES:
        raise ValueError(f"too many images: {len(imgs)} > {MAX_IMAGES}")
    if total_img_bytes > MAX_TOTAL_IMAGE_BYTES:
        raise ValueError(f"total image bytes {total_img_bytes} > {MAX_TOTAL_IMAGE_BYTES}")
    return chat, imgs


@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest, authorization: Optional[str] = Header(None)):
    try:
        require_auth(authorization)

        # Pydantic v2 uses model_dump(); fall back to dict() for v1
        msgs = [m.model_dump() if hasattr(m, "model_dump") else m.dict() for m in req.messages]
        qwen_msgs, imgs = _openai_to_qwen(msgs)

        # Build chat text
        text = processor.apply_chat_template(
            qwen_msgs, tokenize=False, add_generation_prompt=True
        )

        # Tokenize with truncation
        inputs = processor(
            text=[text],
            images=imgs if imgs else None,
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        max_new = int(req.max_tokens or MAX_NEW_TOKENS)
        max_new = min(max_new, MAX_NEW_TOKENS)
        temp = float(req.temperature or 0.0)

        # Greedy decoding unless a positive temperature is requested; only pass
        # temperature when sampling to avoid generate() warnings.
        gen_kwargs = {"max_new_tokens": max_new, "do_sample": temp > 0.0}
        if temp > 0.0:
            gen_kwargs["temperature"] = temp

        with torch.inference_mode():
            out = model.generate(**inputs, **gen_kwargs)

        # Strip the prompt tokens; keep only the newly generated completion
        trimmed = out[:, inputs["input_ids"].shape[1]:]
        text_out = processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0].strip()

        return {
            "id": "chatcmpl-qwen3vl",
            "object": "chat.completion",
            "model": req.model,
            "choices": [{
                "index": 0,
                "finish_reason": "stop",
                "message": {"role": "assistant", "content": text_out},
            }],
        }
    except HTTPException:
        raise
    except Exception as e:
        # Log server-side traceback for debugging in Space Logs
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": f"{type(e).__name__}: {str(e)}",
                    "type": "server_error",
                    "code": "internal_server_error",
                }
            },
        )
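

# ---------------------------------------------------------------------------
# Example client call: a minimal sketch, assuming the app is served with
# uvicorn at http://localhost:8000 and BACKEND_KEY is unset (add an
# "Authorization: Bearer <key>" header otherwise). The image URL below is a
# placeholder. Paste into a separate script to try the endpoint.
#
#   import requests
#
#   payload = {
#       "messages": [{
#           "role": "user",
#           "content": [
#               {"type": "text", "text": "Describe this image in one sentence."},
#               {"type": "image_url",
#                "image_url": {"url": "https://example.com/cat.jpg"}},
#           ],
#       }],
#       "max_tokens": 128,
#   }
#   resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
#   resp.raise_for_status()
#   print(resp.json()["choices"][0]["message"]["content"])
# ---------------------------------------------------------------------------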