# Spaces: Sleeping
import io
import logging

import torch
from fastapi import FastAPI, Request, Response
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Identifier passed to ``from_pretrained`` for both the processor and the model.
MODEL_ID = "prithivMLmods/Trash-Net"

# Lower-cased model label -> Portuguese category string returned to clients.
# "paper" and "cardboard" deliberately share the same output category.
PT_MAP = {
    "plastic": "plastico",
    "paper": "papel",
    "glass": "vidro",
    "metal": "metal",
    "cardboard": "papel",
    "trash": "nao_identificado",
}

# Loaded once at import time and reused by every call to ``predict_bytes``.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# Put the model in evaluation mode before it is used for prediction.
model.eval()

app = FastAPI()
def predict_bytes(img_bytes: bytes) -> str:
    """Classify an encoded image and return its Portuguese category label.

    Args:
        img_bytes: Raw bytes of an encoded image file, as accepted by
            ``PIL.Image.open``.

    Returns:
        The ``PT_MAP`` entry for the model's top-scoring label (lower-cased
        before lookup), or ``"nao_identificado"`` when that label has no
        entry in ``PT_MAP``.

    Raises:
        Nothing is caught here: any exception raised while decoding the
        image or running the model propagates to the caller.
    """
    # Close the decoder's handle as soon as the pixels are converted;
    # ``convert`` returns a separate image that stays usable afterwards.
    with Image.open(io.BytesIO(img_bytes)) as source:
        img = source.convert("RGB")

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is monotonic, so the arg-max of the raw logits selects the same
    # class without the extra pass (and without softmax rounding extreme
    # logits into ties).
    idx = int(logits.argmax(dim=-1).item())
    label_en = model.config.id2label[idx].lower()
    return PT_MAP.get(label_en, "nao_identificado")
def health():
    """Return the constant payload ``{"ok": True}``."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text -- confirm it is registered on ``app``.
    status = {"ok": True}
    return status
async def predict(request: Request):
    """Classify the image sent as the raw request body.

    Args:
        request: Incoming request whose body holds the encoded image bytes.

    Returns:
        A ``text/plain`` ``Response`` containing the label produced by
        ``predict_bytes``, or ``"nao_identificado"`` when the body is empty
        or any step fails.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text -- confirm it is registered on ``app``.
    try:
        img_bytes = await request.body()
        if not img_bytes:
            return Response("nao_identificado", media_type="text/plain")
        # Called synchronously: the coroutine does not yield while the
        # model runs.
        label = predict_bytes(img_bytes)
    except Exception:
        # Best-effort endpoint: clients always get a label, but the failure
        # is logged with its traceback instead of vanishing silently.
        logging.getLogger(__name__).exception("Image classification failed")
        return Response("nao_identificado", media_type="text/plain")
    return Response(label, media_type="text/plain")