# app.py
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import io, torch, base64, traceback
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
MODEL_ID = "prithivMLmods/Trash-Net"
# Keep only these 4 classes (PT-BR labels); everything else becomes "nao_identificado"
MAP_PT = {
    "glass": "vidro",
    "metal": "metal",
    "paper": "papel",
    "plastic": "plastico",
}
ALLOWED = set(MAP_PT.values())
# ========= OPTIMIZATIONS (Space CPU) =========
torch.set_grad_enabled(False)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# Avoid decoding huge images
Image.MAX_IMAGE_PIXELS = 25_000_000
# ========= CARREGAMENTO =========
# If the warning "use_fast=True but torchvision is not available" shows up,
# it is only a warning; switch to use_fast=False if you want to hide it.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()
app = FastAPI()
# CORS (optional)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)
def _to_label_pt(label_en: str) -> str:
    return MAP_PT.get((label_en or "").strip().lower(), "nao_identificado")
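# Quick sanity check of the mapping (a sketch; TrashNet also ships labels such as
# "cardboard" and "trash", which intentionally fall through to "nao_identificado"):
#   assert _to_label_pt("Glass") == "vidro"
#   assert _to_label_pt("cardboard") == "nao_identificado"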
def _predict_image_bytes(img_bytes: bytes) -> str:
    with Image.open(io.BytesIO(img_bytes)) as img:
        img = img.convert("RGB")
        img = img.resize((256, 256))  # good size/latency tradeoff on the Space CPU
        with torch.inference_mode():
            inputs = processor(images=img, return_tensors="pt")
            logits = model(**inputs).logits
    idx = int(logits.argmax(-1))  # softmax is unnecessary before argmax
    label_en = model.config.id2label[idx]
    return _to_label_pt(label_en)
# ========= ROTAS =========
@app.get("/")
def root():
return {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
@app.get("/health")
def health():
return {"ok": True, "model": MODEL_ID}
@app.post("/predict")
async def predict(request: Request):
"""
Aceita:
- application/octet-stream (raw JPEG no corpo)
- image/jpeg (raw JPEG no corpo)
- application/json {"image_b64": "..."} (dataURL ou base64 puro)
Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
"""
try:
ctype = (request.headers.get("content-type") or "").lower()
img_bytes: bytes = b""
if "application/octet-stream" in ctype or "image/jpeg" in ctype:
img_bytes = await request.body()
else:
# fallback: JSON base64
data = await request.json()
b64 = (data.get("image_b64") or "")
if "," in b64: # dataURL
b64 = b64.split(",", 1)[1]
img_bytes = base64.b64decode(b64) if b64 else b""
if not img_bytes:
return Response("nao_identificado", media_type="text/plain")
label = _predict_image_bytes(img_bytes)
if label not in ALLOWED:
label = "nao_identificado"
return Response(label, media_type="text/plain")
except UnidentifiedImageError:
return Response("nao_identificado", media_type="text/plain")
except Exception:
traceback.print_exc()
return Response("nao_identificado", media_type="text/plain")
# ========= WARM-UP =========
@app.on_event("startup")
def _warmup():
    try:
        dummy = Image.new("RGB", (256, 256), (127, 127, 127))
        with torch.inference_mode():
            inputs = processor(images=dummy, return_tensors="pt")
            _ = model(**inputs).logits
        print("[startup] warm-up ok")
    except Exception:
        traceback.print_exc()
        print("[startup] warm-up failed (continuing without it)")