Spaces:
Sleeping
Sleeping
File size: 3,996 Bytes
c3a6888 1b4701d 0ca0c79 929f4df 587862b c3a6888 ac41ebb c3a6888 929f4df c3a6888 ac41ebb 587862b 929f4df 65f698e 7b7ccff c3a6888 587862b 1b4701d 0ca0c79 3d60c5e 0ca0c79 587862b 1b4701d 0ca0c79 c3a6888 929f4df 587862b 0ca0c79 1b4701d 0ca0c79 929f4df 0ca0c79 929f4df 0ca0c79 587862b 7b7ccff 587862b 1b4701d 7b7ccff 0ca0c79 929f4df 7b7ccff 587862b 929f4df 7b7ccff 929f4df c3a6888 929f4df 7b7ccff 929f4df c3a6888 929f4df 0ca0c79 929f4df 7b7ccff 929f4df 0ca0c79 929f4df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
# app.py
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import io, torch, base64, traceback, random
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
# Hugging Face model id of the TrashNet image classifier.
MODEL_ID = "prithivMLmods/Trash-Net"
# EN -> PT label map (only the 4 classes this service is allowed to return).
MAP_PT = {
    "glass": "vidro",
    "metal": "metal",
    "paper": "papel",
    "plastic": "plastico",
}
ALLOWED = ["plastico", "papel", "vidro", "metal"]  # fixed order for random.choice
# ========= OPTIMIZATIONS (Space CPU) =========
torch.set_grad_enabled(False)  # inference-only service — autograd never needed
torch.set_num_threads(1)  # keep the shared Space CPU from being oversubscribed
torch.set_num_interop_threads(1)
# Avoid huge images (PIL raises DecompressionBombError above this pixel count)
Image.MAX_IMAGE_PIXELS = 25_000_000
# ========= LOADING =========
# use_fast=True selects the fast image processor implementation when available.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # disable dropout / batch-norm training behavior
app = FastAPI()
# CORS (optional)
# NOTE(review): browsers reject credentialed requests when allow_origins=["*"]
# is combined with allow_credentials=True — confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)
def _force_allowed(label_en: str | None) -> str:
    """Translate an English model label to its PT name, always returning one of ALLOWED.

    Unknown, unmapped, or missing labels fall back to a random allowed label,
    so callers never see an "unidentified" result.
    """
    if not label_en:
        # No label at all -> forced fallback.
        return random.choice(ALLOWED)
    translated = MAP_PT.get(label_en.strip().lower())
    return translated if translated in ALLOWED else random.choice(ALLOWED)
def _predict_image_bytes(img_bytes: bytes) -> str:
    """Classify raw image bytes and return one of the 4 allowed PT labels."""
    buffer = io.BytesIO(img_bytes)
    with Image.open(buffer) as raw:
        # convert() returns a new in-memory image, safe to use after the
        # source file handle is closed; 256x256 is a good CPU tradeoff here.
        rgb = raw.convert("RGB").resize((256, 256))
    with torch.inference_mode():
        batch = processor(images=rgb, return_tensors="pt")
        scores = model(**batch).logits
    best = int(scores.argmax(-1))
    return _force_allowed(model.config.id2label[best])
# ========= ROUTES =========
@app.get("/")
def root():
    """Root endpoint: basic liveness info plus the model id in use."""
    info = {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
    return info
@app.get("/health")
def health():
    """Health probe: reports readiness and the configured model id."""
    status = {"ok": True, "model": MODEL_ID}
    return status
@app.post("/predict")
async def predict(request: Request):
    """
    Accepts:
      - application/octet-stream (raw image bytes in the body)
      - image/* (raw image bytes in the body — jpeg, png, webp, ...)
      - application/json {"image_b64": "..."} (dataURL or plain base64)
    ALWAYS returns plain text: 'vidro' | 'papel' | 'plastico' | 'metal'
    (never 'nao_identificado' — on any failure a random allowed label is
    returned instead; this best-effort contract is deliberate).
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""
        # Generalized from "image/jpeg" only: PIL sniffs the actual format
        # from the bytes, so any image/* subtype can be read raw.
        if "application/octet-stream" in ctype or ctype.startswith("image/"):
            img_bytes = await request.body()
        else:
            # Fallback: JSON carrying a base64 payload.
            data = await request.json()
            b64 = (data.get("image_b64") or "")
            if "," in b64:  # dataURL prefix, e.g. "data:image/jpeg;base64,"
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""
        # An empty body still yields one of the 4 labels.
        if not img_bytes:
            return Response(random.choice(ALLOWED), media_type="text/plain")
        label = _predict_image_bytes(img_bytes)
        # Belt-and-braces: force the reply into the allowed set.
        if label not in ALLOWED:
            label = random.choice(ALLOWED)
        return Response(label, media_type="text/plain")
    except UnidentifiedImageError:
        # Bytes were not a decodable image.
        return Response(random.choice(ALLOWED), media_type="text/plain")
    except Exception:
        # Deliberate top-level boundary: log the error but keep the contract
        # of always answering with an allowed label.
        traceback.print_exc()
        return Response(random.choice(ALLOWED), media_type="text/plain")
# ========= WARM-UP =========
@app.on_event("startup")
def _warmup():
    """Run one dummy inference at startup so the first real request is fast."""
    try:
        dummy = Image.new("RGB", (256, 256), (127, 127, 127))
        with torch.inference_mode():
            inputs = processor(images=dummy, return_tensors="pt")
            _ = model(**inputs).logits
        print("[startup] warm-up ok")
    except Exception:
        # Warm-up is best-effort: log and keep serving without it.
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")