"""Minimal FastAPI service: classifies an image sent as the raw request body
and returns a Portuguese trash-category label as plain text."""
import io

import torch
from fastapi import FastAPI, Request, Response
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_ID = "prithivMLmods/Trash-Net"

# Map the model's English class names to the Portuguese labels this API returns;
# anything unmapped falls back to "nao_identificado".
PT_MAP = {
    "plastic": "plastico", "paper": "papel", "glass": "vidro",
    "metal": "metal", "cardboard": "papel", "trash": "nao_identificado",
}

# Load the processor and model once at import time and put the model in eval mode.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

app = FastAPI()

def predict_bytes(img_bytes: bytes) -> str:
    """Classify a raw image payload and return its Portuguese category label."""
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Pick the highest-scoring class (softmax does not change the argmax; it is
    # kept so probabilities are at hand if a confidence threshold is added later).
    idx = int(logits.softmax(-1).argmax(-1))
    label_en = model.config.id2label[idx].lower()
    return PT_MAP.get(label_en, "nao_identificado")

# Simple liveness check.
@app.get("/health")
def health():
    return {"ok": True}

@app.post("/predict")
async def predict(request: Request):
    # The image arrives as the raw request body (no multipart form). Any failure
    # (empty body, unreadable image, inference error) degrades to the same
    # plain-text fallback label instead of an HTTP error.
    try:
        img_bytes = await request.body()
        if not img_bytes:
            return Response("nao_identificado", media_type="text/plain")
        label = predict_bytes(img_bytes)
        return Response(label, media_type="text/plain")
    except Exception:
        return Response("nao_identificado", media_type="text/plain")