trashnet-server / app.py
froidhj's picture
Create app.py
587862b verified
raw
history blame
1.33 kB
import asyncio
import io
import logging

import torch
from fastapi import FastAPI, Request, Response
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub identifier of the image-classification checkpoint this app serves.
MODEL_ID = "prithivMLmods/Trash-Net"

# Model's English class names -> category labels returned to clients.
# Note: "cardboard" is reported as "papel", and the model's generic "trash"
# class is reported as unidentified ("nao_identificado").
PT_MAP = dict(
    plastic="plastico",
    paper="papel",
    glass="vidro",
    metal="metal",
    cardboard="papel",
    trash="nao_identificado",
)
# Load the preprocessing pipeline and classifier once, at import time, so all
# requests share the same instances.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# nn.Module.eval() returns the module itself; it switches layers such as
# dropout to inference behaviour.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()

app = FastAPI()
def predict_bytes(img_bytes: bytes) -> str:
    """Classify an encoded image and return its mapped category label.

    Args:
        img_bytes: Encoded image file contents (any format PIL can open).

    Returns:
        The ``PT_MAP`` translation of the model's top-1 class name, or
        ``"nao_identificado"`` when that class has no entry in ``PT_MAP``.

    Raises:
        Whatever PIL raises for undecodable data (typically
        ``PIL.UnidentifiedImageError``, an ``OSError`` subclass); callers are
        expected to handle it.
    """
    with Image.open(io.BytesIO(img_bytes)) as src:
        # convert() loads the pixels and returns an independent copy, so the
        # source image can be closed as soon as we leave this block.
        img = src.convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so argmax over the raw logits selects the same class.
    # int() relies on a batch of exactly one image (a single logits row).
    idx = int(logits.argmax(-1))
    label_en = model.config.id2label[idx].lower()
    return PT_MAP.get(label_en, "nao_identificado")
@app.get("/health")
def health():
    """Liveness endpoint: responds with a fixed payload whenever the app is serving."""
    status = dict(ok=True)
    return status
@app.post("/predict")
async def predict(request: Request):
    """Classify the image sent as the raw request body.

    The entire request body is passed to ``predict_bytes`` as encoded image
    bytes. The response is always ``text/plain``. Its content is the predicted
    label, or ``"nao_identificado"`` when the body is empty or anything fails.
    """
    try:
        img_bytes = await request.body()
        if not img_bytes:
            return Response("nao_identificado", media_type="text/plain")
        # Model inference is synchronous and CPU-bound. Run it in a worker
        # thread so it does not block the event loop (and every other
        # endpoint, /health included) for the duration of the forward pass.
        loop = asyncio.get_running_loop()
        label = await loop.run_in_executor(None, predict_bytes, img_bytes)
        return Response(label, media_type="text/plain")
    except Exception:
        # Best-effort API: clients always receive a label, but record the
        # cause instead of silently discarding it.
        logging.getLogger(__name__).exception("Prediction failed")
        return Response("nao_identificado", media_type="text/plain")