File size: 3,996 Bytes
c3a6888
1b4701d
0ca0c79
 
929f4df
587862b
 
c3a6888
ac41ebb
c3a6888
929f4df
c3a6888
 
 
 
ac41ebb
587862b
929f4df
65f698e
7b7ccff
c3a6888
 
 
587862b
1b4701d
 
0ca0c79
3d60c5e
0ca0c79
587862b
 
 
 
 
1b4701d
0ca0c79
 
 
 
 
c3a6888
929f4df
 
 
 
 
 
 
 
587862b
0ca0c79
 
1b4701d
 
0ca0c79
 
 
 
929f4df
0ca0c79
929f4df
0ca0c79
 
 
 
 
 
587862b
 
7b7ccff
587862b
 
1b4701d
7b7ccff
0ca0c79
 
 
 
 
929f4df
 
7b7ccff
587862b
929f4df
 
7b7ccff
929f4df
 
 
 
 
 
 
 
 
c3a6888
929f4df
 
 
7b7ccff
929f4df
c3a6888
929f4df
 
 
 
 
0ca0c79
 
929f4df
7b7ccff
929f4df
 
0ca0c79
 
 
 
 
 
 
 
 
 
 
 
929f4df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# app.py
import base64
import io
import random
import traceback

import torch
from fastapi import FastAPI, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ========= CONFIG =========
MODEL_ID = "prithivMLmods/Trash-Net"

# English model label (lower-case) -> Portuguese output label.
# Only labels listed here are ever returned directly; the targets must all be
# members of ALLOWED.
# "cardboard" is folded into "papel": cardboard is recycled with paper, so
# mapping it beats answering with a random label.
# NOTE(review): assumes the model's id2label contains "cardboard" (as in the
# TrashNet dataset) — an unused key is harmless if it does not.
MAP_PT = {
    "cardboard": "papel",
    "glass": "vidro",
    "metal": "metal",
    "paper": "papel",
    "plastic": "plastico",
}
# The only four labels the API may emit; order is fixed so random.choice
# fallbacks draw from a stable sequence.
ALLOWED = ["plastico", "papel", "vidro", "metal"]

# ========= OPTIMIZATIONS (Space CPU) =========
# The service only runs inference, so autograd is switched off process-wide.
torch.set_grad_enabled(False)
# Limit torch to a single intra-op thread and a single inter-op thread.
for _set_threads in (torch.set_num_threads, torch.set_num_interop_threads):
    _set_threads(1)
del _set_threads

# Cap the decoded image size to avoid huge images.
# NOTE: per Pillow's decompression-bomb check, Image.open only warns
# (DecompressionBombWarning) above this pixel count and raises
# DecompressionBombError above twice this value.
Image.MAX_IMAGE_PIXELS = 25_000_000

# ========= LOADING =========
# Processor and model are loaded once at import time and shared by all requests.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
# nn.Module.eval() returns the module itself, so it can be chained here.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()

app = FastAPI()

# CORS (optional): every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True may make
# Starlette echo the caller's Origin on credentialed requests; this API uses no
# cookies/auth, so credentials could probably be disabled — confirm clients first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def _force_allowed(label_en: str | None) -> str:
    """Translate an English model label into one of the ALLOWED Portuguese labels.

    The lookup in ``MAP_PT`` ignores case and surrounding whitespace. If
    ``label_en`` is ``None``/empty, has no mapping, or maps to something
    outside ``ALLOWED``, a label is drawn at random from ``ALLOWED`` so the
    caller always receives one of the four valid answers.
    """
    candidate = MAP_PT.get(label_en.strip().lower()) if label_en else None
    return candidate if candidate in ALLOWED else random.choice(ALLOWED)

def _predict_image_bytes(img_bytes: bytes) -> str:
    """Classify an encoded image and return one of the ALLOWED Portuguese labels.

    The image is decoded with PIL, converted to RGB, downscaled to 256x256 and
    run through ``processor``/``model``. The argmax is restricted to classes
    whose English label maps (via ``MAP_PT``) into ``ALLOWED``. This avoids
    discarding the model's evidence when the raw top-1 class has no
    Portuguese mapping, which previously produced a random answer. The
    random fallback of ``_force_allowed`` is used only when no class maps.

    Raises:
        PIL.UnidentifiedImageError: if ``img_bytes`` is not a decodable image.
    """
    with Image.open(io.BytesIO(img_bytes)) as src:
        # convert() forces a full decode, so the resulting image stays valid
        # after the ``with`` closes the source image.
        img = src.convert("RGB").resize((256, 256))  # good trade-off for the Space CPU

    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        logits = model(**inputs).logits[0]  # single image -> first batch row

    id2label = model.config.id2label
    num_classes = logits.shape[-1]
    # Class indices whose label translates to one of the four output labels.
    mapped = [
        idx
        for idx, name in id2label.items()
        if 0 <= int(idx) < num_classes
        and MAP_PT.get(str(name).strip().lower()) in ALLOWED
    ]
    if not mapped:
        # Nothing in the model vocabulary is mappable: random fallback.
        return _force_allowed(None)

    best = max(mapped, key=lambda idx: float(logits[int(idx)]))
    return _force_allowed(id2label[best])

# ========= ROUTES =========
@app.get("/")
def root():
    """Info endpoint: reports that the service is up and which model it serves."""
    return dict(ok=True, message="TrashNet classifier up", model=MODEL_ID)

@app.get("/health")
def health():
    """Health-check endpoint returning the id of the served model."""
    return dict(ok=True, model=MODEL_ID)

@app.post("/predict")
async def predict(request: Request):
    """
    Classify an uploaded image as one of four recyclable-material labels.

    Accepted bodies, selected by the Content-Type header:
    - application/octet-stream: raw encoded image bytes
    - image/* (image/jpeg, image/png, ...): raw encoded image bytes
    - anything else: JSON {"image_b64": "..."} holding a data URL or plain base64

    Always returns plain text: 'vidro' | 'papel' | 'plastico' | 'metal'
    (never 'nao_identificado'). The following are all answered with a random
    label from ALLOWED:
    - an empty payload
    - an undecodable image
    - any other error (unexpected errors also print a traceback)

    Inference is CPU-bound and synchronous, so it runs in the threadpool
    instead of blocking the event loop (and with it /health and other
    requests).
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""

        # "image/" covers image/jpeg as before, plus PNG/WebP/... raw uploads.
        if "application/octet-stream" in ctype or "image/" in ctype:
            img_bytes = await request.body()
        else:
            # Fallback: JSON body carrying base64.
            data = await request.json()
            b64 = data.get("image_b64") if isinstance(data, dict) else None
            if not isinstance(b64, str):
                b64 = ""
            if "," in b64:  # data URL: keep only the payload after the comma
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""

        # Nothing to classify: still answer with one of the four labels.
        if not img_bytes:
            return Response(random.choice(ALLOWED), media_type="text/plain")

        label = await run_in_threadpool(_predict_image_bytes, img_bytes)

        # Safety net: never emit anything outside the four labels.
        if label not in ALLOWED:
            label = random.choice(ALLOWED)

        return Response(label, media_type="text/plain")

    except UnidentifiedImageError:
        # Payload was not a decodable image.
        return Response(random.choice(ALLOWED), media_type="text/plain")
    except Exception:
        traceback.print_exc()
        return Response(random.choice(ALLOWED), media_type="text/plain")

# ========= WARM-UP =========
# NOTE(review): on_event is deprecated in recent FastAPI in favour of lifespan
# handlers; kept here because switching requires changing the FastAPI() call.
@app.on_event("startup")
def _warmup():
    """Run one throwaway forward pass at startup so the first request is not slow.

    Best effort: any failure is printed with its traceback and startup continues.
    """
    try:
        mid_gray = (127,) * 3
        warm_img = Image.new("RGB", (256, 256), mid_gray)
        with torch.inference_mode():
            batch = processor(images=warm_img, return_tensors="pt")
            _ = model(**batch).logits
        print("[startup] warm-up ok")
    except Exception:
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")