from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import json
import os
import time
import uuid

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_logger")
class LoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Read request body (must be buffered manually)
        body = await request.body()
        logger.info(f"REQUEST: {request.method} {request.url}\nBody: {body.decode('utf-8', errors='replace')}")

        # Rebuild the request with the buffered body so downstream handlers can still read it.
        # The receive callable must be awaitable, so use an async closure rather than a plain
        # lambda returning a dict.
        async def receive():
            return {"type": "http.request", "body": body}

        request = Request(request.scope, receive=receive)

        # Process the response
        response = await call_next(request)
        response_body = b""
        async for chunk in response.body_iterator:
            response_body += chunk

        # Log response body and status code
        logger.info(f"RESPONSE: Status {response.status_code}\nBody: {response_body.decode('utf-8', errors='replace')}")

        # Rebuild the response to preserve the original status, headers, and media type
        return Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
        )
# FastAPI app with middleware
app = FastAPI()
app.add_middleware(LoggingMiddleware)

llm = None

# Models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256

class GenerateRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: Optional[int] = 256
    temperature: Optional[float] = 0.7

class ModelInfo(BaseModel):
    id: str
    object: str
    type: str
    publisher: str
    arch: str
    compatibility_type: str
    quantization: str
    state: str
    max_context_length: int

AVAILABLE_MODELS = [
    ModelInfo(
        id="codellama-7b-instruct",
        object="model",
        type="llm",
        publisher="lmstudio-community",
        arch="llama",
        compatibility_type="gguf",
        quantization="Q4_K_M",
        state="loaded",
        max_context_length=32768
    )
]
def load_model():
    global llm
    model_path_file = "/tmp/model_path.txt"
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)
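
# NOTE: load_model() is defined but never called in the source as extracted. Below is a
# minimal sketch of one way to wire it up, assuming the model should be loaded once at
# startup. The startup hook (and the choice to log rather than crash on failure) is an
# assumption, not part of the original file; the handlers below already answer with an
# error while llm is None.
@app.on_event("startup")
def startup_load_model():
    try:
        load_model()
    except RuntimeError as exc:
        logger.error(f"Model load failed: {exc}")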
async def root():
    return "Ollama is running"

async def health_check():
    return {"status": "ok"}

async def api_tags():
    return JSONResponse(content={
        "data": [
            {
                "name": "codellama-7b-instruct",
                "modified_at": "2025-06-01T00:00:00Z",  # Replace with actual last modified ISO8601 UTC
                "size": 8000000000,  # Replace with actual model size in bytes
                "digest": "sha256:placeholderdigestcodellama7b",  # Replace with actual sha256 digest
                "details": {
                    "format": "gguf",
                    "family": "codellama",
                    "families": ["codellama"]
                }
            }
        ]
    })

async def list_models():
    # Return available models info
    return [model.dict() for model in AVAILABLE_MODELS]

async def api_models():
    return {"data": [model.dict() for model in AVAILABLE_MODELS]}

async def get_model(model_id: str):
    for model in AVAILABLE_MODELS:
        if model.id == model_id:
            return model.dict()
    raise HTTPException(status_code=404, detail="Model not found")
async def chat(req: ChatRequest):
    global llm
    if llm is None:
        # Return 503 for consistency with api_generate instead of a 200 with an error body
        raise HTTPException(status_code=503, detail="Model not initialized")
    # Validate model - simple check
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")
    # Construct prompt from messages
    prompt = ""
    for m in req.messages:
        prompt += f"{m.role}: {m.content}\n"
    prompt += "assistant:"
    output = llm(
        prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["user:", "assistant:"]
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()
    response = {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }
        ]
    }
    return response
async def api_generate(req: GenerateRequest):
    global llm
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")
    output = llm(
        req.prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["\n\n"]  # Or any stop sequence you want
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()
    return {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "text": text,
                "index": 0,
                "finish_reason": "stop"
            }
        ]
    }
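
# NOTE: the source as extracted defines the handlers above but no route decorators survived,
# so nothing is exposed over HTTP as written. Below is a minimal registration sketch; the
# paths are assumptions inferred from the handler names and the Ollama/OpenAI-style payloads,
# not confirmed by the original file.
app.add_api_route("/", root, methods=["GET"], response_class=PlainTextResponse)
app.add_api_route("/health", health_check, methods=["GET"])
app.add_api_route("/api/tags", api_tags, methods=["GET"])
app.add_api_route("/models", list_models, methods=["GET"])
app.add_api_route("/api/models", api_models, methods=["GET"])
app.add_api_route("/api/models/{model_id}", get_model, methods=["GET"])
app.add_api_route("/v1/chat/completions", chat, methods=["POST"])
app.add_api_route("/api/generate", api_generate, methods=["POST"])

# Example (assumed) usage once the server is running, e.g. `uvicorn app:app --port 8000`:
#   curl -X POST http://localhost:8000/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "codellama-7b-instruct", "prompt": "def fib(n):", "max_tokens": 64}'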