| import contextlib |
| import logging |
| import os |
| import sys |
| import ast |
| import json |
| from threading import Thread |
| import time |
| from traceback import print_exception |
| from typing import List |
| from pydantic import BaseModel, Field |
|
|
| import uvicorn |
| from fastapi import Depends, FastAPI, Header, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.requests import Request |
| from fastapi.responses import JSONResponse |
| from sse_starlette import EventSourceResponse |
| from starlette.responses import PlainTextResponse |
|
|
| from openai_server.log import logger |
|
|
| sys.path.append('openai_server') |
|
|
|
|
| |
| |
| |
|
|
|
|
class Generation(BaseModel):
    # Extra generation settings that are mixed into the request models
    # (TextRequest and ChatRequest both inherit from this class).
    # Every field is optional and falls back to the default shown here.
    top_k: int | None = Field(default=1)
    repetition_penalty: float | None = Field(default=1)
    min_p: float | None = Field(default=0.0)
    # NOTE(review): presumably a time budget in seconds -- confirm against the backend.
    max_time: float | None = Field(default=360)
|
|
|
|
class Params(BaseModel):
    # Request parameters shared by the completion and chat request models
    # (CompletionParams and ChatParams both extend this class).
    # All fields are optional; defaults are spelled uniformly through Field().
    user: str | None = Field(default=None, description="Track user")
    model: str | None = Field(default=None, description="Choose model")
    best_of: int | None = Field(default=1, description="Unused")
    frequency_penalty: float | None = Field(default=0.0)
    max_tokens: int | None = Field(default=256)
    n: int | None = Field(default=1, description="Unused")
    presence_penalty: float | None = Field(default=0.0)
    # Either a single stop string or a list of them.
    stop: str | List[str] | None = Field(default=None)
    stop_token_ids: List[int] | None = Field(default=None)
    # When true, the route handlers answer with a server-sent event stream.
    stream: bool | None = Field(default=False)
    temperature: float | None = Field(default=0.3)
    top_p: float | None = Field(default=1.0)
    seed: int | None = Field(default=1234)
|
|
|
|
class CompletionParams(Params):
    # Parameters specific to text completion, on top of the shared Params.
    # `prompt` has no default, so it is the only required field here; it may be
    # a single string or a list of strings.
    prompt: str | List[str]
    logit_bias: dict | None = Field(default=None)
    logprobs: int | None = Field(default=None)
|
|
|
|
class TextRequest(Generation, CompletionParams):
    # Request body model for the /v1/completions route: the completion
    # parameters combined with the extra Generation settings.
    # No fields of its own are added.
    pass
|
|
|
|
class TextResponse(BaseModel):
    # Response schema declared as `response_model` on the /v1/completions route.
    id: str
    choices: List[dict]
    # Use a default factory so the timestamp is taken each time a response
    # object is built.  A plain `int(time.time())` default is evaluated once,
    # when the class body runs at import, and would be shared by every instance.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "text_completion"
    usage: dict
|
|
|
|
class ChatParams(Params):
    # Parameters specific to chat completion, on top of the shared Params.
    # `messages` has no default, so it is required: a list of dicts.
    messages: List[dict]
    # Both tool-related fields are marked "WIP" in their own descriptions.
    tools: list | None = Field(default=None, description="WIP")
    tool_choice: str | None = Field(default=None, description="WIP")
|
|
|
|
class ChatRequest(Generation, ChatParams):
    # Request body model for the /v1/chat/completions route: the chat
    # parameters combined with the extra Generation settings.
    # No fields of its own are added.
    pass
|
|
|
|
class ChatResponse(BaseModel):
    # Response schema declared as `response_model` on the /v1/chat/completions route.
    id: str
    choices: List[dict]
    # Use a default factory so the timestamp is taken each time a response
    # object is built.  A plain `int(time.time())` default is evaluated once,
    # when the class body runs at import, and would be shared by every instance.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "chat.completion"
    usage: dict
|
|
|
|
class ModelInfoResponse(BaseModel):
    # Response schema declared on the /v1/internal/model/info route:
    # a single required string field.
    model_name: str
|
|
|
|
class ModelListResponse(BaseModel):
    # Response schema declared on the /v1/internal/model/list route:
    # a single required list of strings.
    model_names: List[str]
|
|
|
|
def verify_api_key(authorization: str = Header(None)) -> None:
    """Reject the request unless it carries the server's API key as a bearer token.

    The expected key is read from the ``H2OGPT_OPENAI_API_KEY`` environment
    variable on every call.  When the variable is unset, empty, or holds the
    placeholder ``'EMPTY'``, no check is performed and the request is allowed.

    Args:
        authorization: Value of the ``Authorization`` request header, or
            ``None`` when the header is absent.

    Raises:
        HTTPException: With status 401 when a key is configured and the
            header is missing or is not exactly ``Bearer <key>``.
    """
    # Local import keeps this function self-contained.
    import hmac

    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    if not server_api_key or server_api_key == 'EMPTY':
        # No key configured: authentication is disabled.
        return

    expected = f"Bearer {server_api_key}"
    # Compare in constant time so response timing does not leak how much of
    # the key matched.  Both sides are encoded to bytes because
    # compare_digest() rejects str arguments containing non-ASCII characters.
    if authorization is None or not hmac.compare_digest(
            authorization.encode('utf-8'), expected.encode('utf-8')):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
|
|
|
# The application object; every route below is registered on it.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Dependency list passed to each route so requests go through verify_api_key.
check_key = [Depends(verify_api_key)]
|
|
|
|
| |
|
|
|
|
class InvalidRequestError(Exception):
    """Error type used to wrap the message of a failed request."""
|
|
|
|
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    # Handler registered for `Exception`: print the traceback, then answer with
    # the exception's message as plain text and HTTP status 400.
    # `request` is part of the handler signature but is not used here.
    print_exception(exc)
    message = str(InvalidRequestError(str(exc)))
    return PlainTextResponse(message, status_code=400)
|
|
|
|
@app.options("/", dependencies=check_key)
async def options_route():
    # Answer OPTIONS requests on the root path with the JSON string "OK".
    ok_body = "OK"
    return JSONResponse(content=ok_body)
|
|
|
|
@app.post('/v1/completions', response_model=TextResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: TextRequest):
    # Text-completion endpoint.  The validated request model is converted to a
    # plain dict and handed to the backend, which is imported lazily.
    #   * stream falsy  -> one JSON response with the backend's result.
    #   * stream truthy -> a server-sent event stream, one event per backend item.
    # NOTE(review): the backend functions are called synchronously inside this
    # async handler; if they block, the event loop blocks with them -- confirm
    # whether that is acceptable.
    if not request_data.stream:
        from openai_server.backend import completions
        return JSONResponse(completions(dict(request_data)))

    async def event_stream():
        # Imported on first iteration of the stream, as in the non-stream branch
        # the backend is only loaded when it is actually needed.
        from openai_server.backend import stream_completions
        for chunk in stream_completions(dict(request_data)):
            # Stop producing events as soon as the client has gone away.
            if await request.is_disconnected():
                break
            yield {"data": json.dumps(chunk)}

    return EventSourceResponse(event_stream())
|
|
|
|
@app.post('/v1/chat/completions', response_model=ChatResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatRequest):
    # Chat-completion endpoint.  The validated request model is converted to a
    # plain dict and handed to the backend, which is imported lazily.
    #   * stream falsy  -> one JSON response with the backend's result.
    #   * stream truthy -> a server-sent event stream, one event per backend item.
    # NOTE(review): the backend functions are called synchronously inside this
    # async handler; if they block, the event loop blocks with them -- confirm
    # whether that is acceptable.
    if not request_data.stream:
        from openai_server.backend import chat_completions
        return JSONResponse(chat_completions(dict(request_data)))

    # Imported before the stream starts, so an import failure surfaces while
    # the request is still being handled.
    from openai_server.backend import stream_chat_completions

    async def event_stream():
        for chunk in stream_chat_completions(dict(request_data)):
            # Stop producing events as soon as the client has gone away.
            if await request.is_disconnected():
                break
            yield {"data": json.dumps(chunk)}

    return EventSourceResponse(event_stream())
|
|
|
|
| |
@app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}", dependencies=check_key)
@app.get("/v1/models/{repo}/{model}", dependencies=check_key)
async def handle_models(request: Request):
    # List the available models, or return the entry for one named model.
    #
    # The model name is whatever follows '/v1/models/' in the request path, so
    # it may itself contain a slash (the "{repo}/{model}" route).  For the bare
    # '/v1/models' path the slice starts past the end of the string and yields
    # '', which selects the listing branch.
    path = request.url.path
    model_name = path[len('/v1/models/'):]

    # Imported lazily, at request time.
    from openai_server.backend import gradio_client
    # The '/model_names' endpoint returns a string that is parsed as a Python
    # literal; each parsed entry is expected to carry a 'base_model' key.
    model_dict = ast.literal_eval(gradio_client.predict(api_name='/model_names'))
    base_models = [x['base_model'] for x in model_dict]

    if not model_name:
        response = {
            "object": "list",
            "data": base_models,
        }
    else:
        # list.index() raises ValueError for a missing item rather than
        # returning a negative number, so an unknown model has to be caught
        # here to reach the 'INVALID' answer instead of failing the request.
        try:
            response = model_dict[base_models.index(model_name)]
        except ValueError:
            response = dict(model_name='INVALID')

    return JSONResponse(response)
|
|
|
|
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
async def handle_model_info():
    # Return whatever the backend's get_model_info() reports, as JSON.
    # The backend is imported lazily, at request time.
    from openai_server.backend import get_model_info
    model_info = get_model_info()
    return JSONResponse(content=model_info)
|
|
|
|
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_key)
async def handle_list_models():
    # Return whatever the backend's get_model_list() reports, as JSON.
    # The backend is imported lazily, at request time.
    from openai_server.backend import get_model_list
    model_list = get_model_list()
    return JSONResponse(content=model_list)
|
|
|
|
def run_server(host='0.0.0.0',
               port=5000,
               ssl_certfile=None,
               ssl_keyfile=None,
               gradio_prefix=None,
               gradio_host=None,
               gradio_port=None,
               h2ogpt_key=None,
               ):
    """Publish the Gradio connection settings and run the API server with uvicorn.

    The Gradio settings are written to environment variables
    (``GRADIO_PREFIX``, ``GRADIO_SERVER_HOST``, ``GRADIO_SERVER_PORT``,
    ``GRADIO_H2OGPT_H2OGPT_KEY``).  This call blocks inside ``uvicorn.run``.

    Args:
        host: Interface uvicorn binds to.
        port: Listening port; overridden by ``H2OGPT_OPENAI_PORT`` when set.
        ssl_certfile: TLS certificate path; overridden by
            ``H2OGPT_OPENAI_CERT_PATH`` when set.
        ssl_keyfile: TLS key path; overridden by ``H2OGPT_OPENAI_KEY_PATH``
            when set.
        gradio_prefix: URL scheme of the Gradio server; defaults to ``'http'``.
        gradio_host: Host of the Gradio server; defaults to ``'localhost'``.
        gradio_port: Port of the Gradio server, as ``str`` or ``int``;
            defaults to ``'7860'``.
        h2ogpt_key: Key stored in ``GRADIO_H2OGPT_H2OGPT_KEY``; also used as
            the API key when ``H2OGPT_OPENAI_API_KEY`` is not already set.
    """
    os.environ['GRADIO_PREFIX'] = gradio_prefix or 'http'
    os.environ['GRADIO_SERVER_HOST'] = gradio_host or 'localhost'
    # os.environ only accepts str values; coerce so an int port does not raise
    # TypeError.  A str port is stored unchanged.
    os.environ['GRADIO_SERVER_PORT'] = str(gradio_port or '7860')
    os.environ['GRADIO_H2OGPT_H2OGPT_KEY'] = h2ogpt_key or ''

    # An API key already present in the environment wins over h2ogpt_key; an
    # empty result becomes the placeholder 'EMPTY'.
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', os.environ['GRADIO_H2OGPT_H2OGPT_KEY']) or 'EMPTY'
    os.environ['H2OGPT_OPENAI_API_KEY'] = server_api_key

    # Environment variables, when set, override the corresponding arguments.
    port = int(os.getenv('H2OGPT_OPENAI_PORT', port))
    ssl_certfile = os.getenv('H2OGPT_OPENAI_CERT_PATH', ssl_certfile)
    ssl_keyfile = os.getenv('H2OGPT_OPENAI_KEY_PATH', ssl_keyfile)

    # The URL is reported as https only when both key and certificate are given.
    prefix = 'https' if ssl_keyfile and ssl_certfile else 'http'
    logger.info(f'OpenAI API URL: {prefix}://{host}:{port}')
    # NOTE(review): this writes the API key to the log in plain text -- confirm
    # that is intended for the deployments where this log is collected.
    logger.info(f'OpenAI API key: {server_api_key}')

    # Stop "uvicorn.error" records from propagating to ancestor loggers.
    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=host, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
|
|
|
def run(wait=True, **kwargs):
    """Start the API server, either in the calling thread or in the background.

    Args:
        wait: When true, call ``run_server`` directly in the calling thread.
            Otherwise start it in a daemon thread and return at once.
        **kwargs: Passed through unchanged to ``run_server``.
    """
    if not wait:
        background = Thread(target=run_server, kwargs=kwargs, daemon=True)
        background.start()
        return
    run_server(**kwargs)
|
|