File size: 3,855 Bytes
c3a6888
1b4701d
0ca0c79
 
1b4701d
587862b
 
c3a6888
ac41ebb
c3a6888
7b7ccff
c3a6888
 
 
 
ac41ebb
587862b
7b7ccff
65f698e
7b7ccff
c3a6888
 
 
587862b
1b4701d
 
0ca0c79
3d60c5e
1b4701d
 
0ca0c79
587862b
 
 
 
 
1b4701d
0ca0c79
 
 
 
 
c3a6888
0ca0c79
1b4701d
587862b
0ca0c79
 
1b4701d
 
0ca0c79
 
 
 
1b4701d
0ca0c79
 
 
 
 
 
 
 
587862b
 
7b7ccff
587862b
 
1b4701d
7b7ccff
0ca0c79
 
 
 
 
7b7ccff
 
587862b
ac41ebb
1b4701d
7b7ccff
1b4701d
3d60c5e
c3a6888
1b4701d
c3a6888
0ca0c79
1b4701d
0ca0c79
65f698e
c3a6888
587862b
7b7ccff
 
0ca0c79
7b7ccff
 
c3a6888
7b7ccff
0ca0c79
 
 
7b7ccff
0ca0c79
7b7ccff
0ca0c79
 
 
 
 
 
 
 
 
 
 
 
1b4701d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# app.py
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import io, torch, base64, traceback
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ========= CONFIG =========
# Model identifier handed to the `from_pretrained` loaders.
MODEL_ID = "prithivMLmods/Trash-Net"

# Only these four classes are kept, translated to PT-BR; every other label
# is reported as "nao_identificado".
MAP_PT = dict(
    glass="vidro",
    metal="metal",
    paper="papel",
    plastic="plastico",
)
# Set of the PT-BR labels that may be returned as a positive classification.
ALLOWED = {*MAP_PT.values()}

# ========= OPTIMIZATIONS (Space CPU) =========
# Inference only: gradient tracking is switched off globally.
torch.set_grad_enabled(False)
# Restrict torch to one intra-op thread and one inter-op thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Guard against gigantic images.
Image.MAX_IMAGE_PIXELS = 25_000_000

# ========= LOADING =========
# If a warning like "use_fast=True but torchvision is not available" shows up,
# it is only a warning; switch to use_fast=False if you want to hide it.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

app = FastAPI()

# CORS (optional)
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is the intended policy for this service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)

def _to_label_pt(label_en: str) -> str:
    """Translate an English class label to its PT-BR name.

    Matching ignores surrounding whitespace and letter case. Labels that are
    not keys of ``MAP_PT`` — including empty or ``None`` input — yield
    "nao_identificado".
    """
    normalized = label_en.strip().lower() if label_en else ""
    return MAP_PT.get(normalized, "nao_identificado")

def _predict_image_bytes(img_bytes: bytes) -> str:
    """Classify an encoded image and return its PT-BR label.

    Args:
        img_bytes: The encoded image file contents.

    Returns:
        The label produced by ``_to_label_pt`` for the highest-scoring class.

    Raises:
        Whatever ``Image.open`` raises for undecodable input; the caller is
        expected to handle it.
    """
    with Image.open(io.BytesIO(img_bytes)) as source:
        # Fixed 256x256 RGB input: a good trade-off for the Space's CPU.
        rgb = source.convert("RGB").resize((256, 256))

    with torch.inference_mode():
        batch = processor(images=rgb, return_tensors="pt")
        scores = model(**batch).logits
        # argmax works directly on the logits; no softmax needed.
        best = int(scores.argmax(-1))

    return _to_label_pt(model.config.id2label[best])

# ========= ROTAS =========
@app.get("/")
def root():
    """Landing endpoint: report that the service is up and name the model."""
    return dict(ok=True, message="TrashNet classifier up", model=MODEL_ID)

@app.get("/health")
def health():
    """Health-check endpoint: always answers ok together with the model id."""
    return dict(ok=True, model=MODEL_ID)

@app.post("/predict")
async def predict(request: Request):
    """Classify the image carried by the request.

    Accepted request bodies:
    - application/octet-stream: raw encoded image bytes
    - image/* (e.g. image/jpeg, image/png): raw encoded image bytes
    - anything else is parsed as JSON {"image_b64": "..."}, where the value
      is either a data URL or plain base64

    Returns:
        A text/plain response whose body is one of 'vidro', 'papel',
        'plastico', 'metal' or 'nao_identificado'. Every failure (empty
        body, malformed JSON or base64, unreadable image, unexpected error)
        also yields 'nao_identificado' instead of an HTTP error.
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""

        # Any image/* media type carries raw image bytes; Pillow identifies
        # the format from the content, so JPEG is not the only valid upload.
        if "application/octet-stream" in ctype or "image/" in ctype:
            img_bytes = await request.body()
        else:
            # Fallback: JSON body with a base64 payload.
            data = await request.json()
            b64 = data.get("image_b64") if isinstance(data, dict) else None
            if not isinstance(b64, str):
                # Non-object JSON or a non-string field means "no image".
                b64 = ""
            if "," in b64:  # data URL: keep only what follows the first comma
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""

        if not img_bytes:
            return Response("nao_identificado", media_type="text/plain")

        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is busy while the model runs — consider offloading
        # to a threadpool if concurrent requests matter.
        label = _predict_image_bytes(img_bytes)
        if label not in ALLOWED:
            label = "nao_identificado"

        return Response(label, media_type="text/plain")

    except UnidentifiedImageError:
        # The body was not a decodable image: expected client error, so no
        # traceback is printed.
        return Response("nao_identificado", media_type="text/plain")
    except Exception:
        # Top-level boundary: log the failure and degrade to the fallback
        # label rather than answering with a 500.
        traceback.print_exc()
        return Response("nao_identificado", media_type="text/plain")

# ========= WARM-UP =========
@app.on_event("startup")
def _warmup():
    """Push one dummy gray image through the processor and model at startup.

    Any failure is printed and ignored so the application still starts.
    """
    gray = (127,) * 3
    try:
        placeholder = Image.new("RGB", (256, 256), gray)
        with torch.inference_mode():
            batch = processor(images=placeholder, return_tensors="pt")
            # The output is discarded; only the forward pass itself matters.
            model(**batch).logits
        print("[startup] warm-up ok")
    except Exception:
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")