import os
import shutil

# Point the Hugging Face caches at a writable location. These variables are
# read when huggingface_hub is imported, so set them before that import.
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/app/cache/huggingface/hub"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

app = FastAPI(title="GPT-OSS-20B API")

# Model repo, local directory, and GGUF file to serve
MODEL_ID = "unsloth/gpt-oss-20b-GGUF"
MODEL_DIR = "/app/gpt-oss-20b"
MODEL_FILE = "gpt-oss-20b.Q4_K_M.gguf"  # Adjust based on actual filename

# Clear the cache directory so stale or partial downloads don't linger
cache_dir = os.environ["HF_HOME"]
if os.path.exists(cache_dir):
    print(f"Clearing cache directory: {cache_dir}")
    for item in os.listdir(cache_dir):
        item_path = os.path.join(cache_dir, item)
        if os.path.isdir(item_path):
            shutil.rmtree(item_path, ignore_errors=True)
        elif os.path.exists(item_path):
            os.remove(item_path)

# Create the cache and model directories
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the GGUF file from the Hub
print("Downloading model file...")
try:
    hf_hub_download(
        repo_id=MODEL_ID,
        filename=MODEL_FILE,
        local_dir=MODEL_DIR,
        cache_dir=cache_dir,
    )
    print("Model file downloaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to download model: {e}") from e

# Load the model with ctransformers
print("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        # NOTE: ctransformers normally expects an architecture name here
        # (e.g. "llama"), not the file format; adjust if loading fails.
        model_type="gguf",
        model_file=MODEL_FILE,
    )
    print("Model loaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}") from e


class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 256
    temperature: float = 0.7


@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    try:
        # Generate a completion for the incoming message
        response = model(
            request.message,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
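
With the server running (for example via `python app.py`, which starts Uvicorn on port 8000 as configured above), the `/chat` endpoint can be exercised with a small client. The sketch below is illustrative, not part of the server script: it assumes the API is reachable at `localhost:8000` and that the `requests` package is installed.

# client_example.py -- minimal sketch for calling the /chat endpoint above.
# Assumes the server runs locally on port 8000; adjust host/port as needed.
import requests

payload = {
    "message": "Explain what a GGUF file is in one sentence.",
    "max_tokens": 128,      # forwarded as max_new_tokens on the server side
    "temperature": 0.7,
}

resp = requests.post("http://localhost:8000/chat", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["response"])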