trashnet-server / app.py
froidhj's picture
Update app.py
7b7ccff verified
raw
history blame
2.42 kB
# app.py
from fastapi import FastAPI, Request, Response
from PIL import Image
import io, os, torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
MODEL_ID = "prithivMLmods/Trash-Net"
# Mantemos só estas 4 classes em PT-BR; o resto vira "nao_identificado"
MAP_PT = {
"glass": "vidro",
"metal": "metal",
"paper": "papel",
"plastic": "plastico",
}
ALLOWED = set(MAP_PT.values())
# ========= OTIMIZAÇÕES (CPU do Space) =========
torch.set_grad_enabled(False)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# ========= CARREGAMENTO =========
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()
app = FastAPI()
def predict_image_bytes(img_bytes: bytes) -> str:
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# Reduz um pouco para acelerar sem perder muito
img = img.resize((256, 256))
inputs = processor(images=img, return_tensors="pt")
logits = model(**inputs).logits
idx = int(logits.softmax(-1).argmax(-1))
label_en = model.config.id2label[idx].lower()
# Converte apenas se for uma das 4; senão marca como não identificado
return MAP_PT.get(label_en, "nao_identificado")
@app.get("/health")
def health():
return {"ok": True, "model": MODEL_ID}
@app.post("/predict")
async def predict(request: Request):
"""
Espera: bytes JPEG (application/octet-stream)
Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
"""
try:
ctype = (request.headers.get("content-type") or "").lower()
if "application/octet-stream" in ctype or "image/jpeg" in ctype:
img_bytes = await request.body()
else:
# fallback opcional para JSON base64 (testes manuais)
data = await request.json()
import base64
b64 = (data.get("image_b64") or "").split(",")[-1]
img_bytes = base64.b64decode(b64) if b64 else b""
if not img_bytes:
return Response("nao_identificado", media_type="text/plain")
label = predict_image_bytes(img_bytes)
if label not in ALLOWED:
label = "nao_identificado"
return Response(label, media_type="text/plain")
except Exception:
return Response("nao_identificado", media_type="text/plain")