# GradLLM / config.py - johnbridges - revision db550a4 - 7.97 kB
# (web file-viewer header captured along with the source; not part of the module)
# app/config.py
import json
import os
from functools import lru_cache
from urllib.parse import quote
# pydantic v2-first import, soft-fallback to v1.
# In v2 the settings base class lives in the separate ``pydantic_settings``
# package; in v1 it is exported by ``pydantic`` itself.
try:
    from pydantic_settings import BaseSettings, SettingsConfigDict  # v2
    from pydantic import BaseModel, Field

    IS_V2 = True
except ImportError:  # v1 fallback
    # Only a failed import selects the v1 layout; any other error raised
    # while importing the v2 packages is surfaced instead of being masked.
    from pydantic import BaseSettings as _BaseSettings, BaseModel, Field

    class BaseSettings(_BaseSettings):  # shim name
        """Alias subclass so the rest of the module can use one name."""

        pass

    # NOTE: ``SettingsConfigDict`` is not defined on this path; it is only
    # referenced behind ``if IS_V2`` checks.
    IS_V2 = False
# Path of the JSON settings file. Override it with the APPSETTINGS_JSON
# environment variable; the default is "appsettings.json" (a relative path).
APPSETTINGS_PATH = os.environ.get("APPSETTINGS_JSON", "appsettings.json")
def _load_json(path: str):
if not path or not os.path.exists(path):
return {}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
def _env_or(raw_val: str | None, env_name: str | None):
"""
If raw_val == ".env" -> fetch from env_name (or the obvious uppercase key).
Otherwise return raw_val.
"""
if raw_val != ".env":
return raw_val
key = env_name or ""
# If an explicit env name isn't passed, try a sensible default
if not key:
return None
return os.environ.get(key)
def _bool(v, default=False):
if v is None:
return default
if isinstance(v, bool):
return v
s = str(v).strip().lower()
if s in {"1", "true", "yes", "y", "on"}:
return True
if s in {"0", "false", "no", "n", "off"}:
return False
return default
class LocalSystemUrl(BaseModel):
    """
    Shape of the ``LocalSystemUrl`` object from appsettings.json.

    Every field is optional. Within this module, _build_amqp_url_from_local
    reads RabbitHostName, RabbitUserName, RabbitPassword, RabbitPort,
    RabbitVHost and UseTls; the remaining fields are parsed but not used here.
    """
    ExternalUrl: str | None = None
    IPAddress: str | None = None
    RabbitHostName: str | None = None
    RabbitPort: int | None = None
    RabbitInstanceName: str | None = None
    RabbitUserName: str | None = None
    RabbitPassword: str | None = None  # might be ".env" (then read from the environment)
    RabbitVHost: str | None = None
    UseTls: bool | None = None
# One field below is literally named ``LocalSystemUrl``. In a class body an
# annotated assignment binds the name (here to None) *before* its annotation
# is evaluated, so writing ``LocalSystemUrl: LocalSystemUrl | None = None``
# evaluates ``None | None`` and raises TypeError while the class is being
# created. Refer to the model type through a non-clashing alias instead.
_LocalSystemUrlModel = LocalSystemUrl


class _SettingsModel(BaseModel):
    """
    Internal working model: appsettings.json keys plus the effective settings.

    Built from the raw JSON dict (unknown keys are ignored) and then mutated
    by _merge_env_over_json, which applies environment overrides.
    """

    # Core (env can still override each)
    AMQP_URL: str | None = None
    # Rabbit defaults (match what you already had)
    RABBIT_INSTANCE_NAME: str = "prod"
    RABBIT_EXCHANGE_TYPE: str = "topic"
    RABBIT_ROUTING_KEY: str = ""
    RABBIT_PREFETCH: int = 1
    # TLS controls (default to strict=False only because your broker uses custom certs)
    RABBIT_TLS_VERIFY: bool = False
    RABBIT_TLS_CHECK_HOSTNAME: bool = False
    RABBIT_TLS_CA_FILE: str | None = None
    SERVICE_ID: str = "gradllm"
    USE_TLS: bool = True
    EXCHANGE_TYPES: dict[str, str] = Field(default_factory=dict)
    # appsettings.json fields that we may use to assemble AMQP_URL
    LocalSystemUrl: _LocalSystemUrlModel | None = None
    ServiceID: str | None = None
    RabbitRoutingKey: str | None = None
    RabbitExhangeType: str | None = None  # sic: key is spelled this way in the JSON
    UseTls: bool | None = None  # top-level duplicate of LocalSystemUrl.UseTls
    # misc passthroughs if you want them later
    RedisUrl: str | None = None
    RedisSecret: str | None = None
    # pydantic config: tolerate JSON keys this model does not declare
    if IS_V2:
        model_config = SettingsConfigDict(extra="ignore")
    else:
        class Config:
            extra = "ignore"
def _build_amqp_url_from_local(local: LocalSystemUrl, top_level_use_tls: bool | None):
"""
Build amqp(s)://user:pass@host:port/vhost?heartbeat=30
Uses LocalSystemUrl + `.env` indirections.
"""
if not local or not local.RabbitHostName or not local.RabbitUserName:
return None
# determine TLS
use_tls = local.UseTls
if use_tls is None:
use_tls = top_level_use_tls
if use_tls is None:
use_tls = True # default secure
scheme = "amqps" if use_tls else "amqp"
# password indirection
# If RabbitPassword is ".env", read RABBIT_PASSWORD (conventional)
raw_pwd = local.RabbitPassword
pwd = raw_pwd if raw_pwd and raw_pwd != ".env" else os.environ.get("RABBIT_PASSWORD")
# fall back to standard names, just in case
if not pwd:
pwd = os.environ.get("RabbitPassword") or os.environ.get("RABBIT_PASS")
user = local.RabbitUserName
host = local.RabbitHostName
port = local.RabbitPort or (5671 if scheme == "amqps" else 5672)
vhost = local.RabbitVHost or "/"
vhost_enc = quote(vhost, safe="") # encode e.g. "/vhostuser" -> "%2Fvhostuser"
return f"{scheme}://{user}:{pwd}@{host}:{port}/{vhost_enc}?heartbeat=30"
def _merge_env_over_json(j: dict) -> _SettingsModel:
    """
    Merge appsettings.json values with environment overrides.

    Precedence:
    1) Environment variables (HF Secrets)
    2) appsettings.json values
    3) built defaults
    We also synthesize AMQP_URL from LocalSystemUrl if not set via env.

    An empty AMQP_URL environment value is treated as unset, so it does not
    erase a URL taken or synthesized from the JSON. EXCHANGE_TYPES from the
    environment is applied only when it parses to a JSON object; anything
    else is ignored.

    Raises:
        RuntimeError: no AMQP_URL could be determined from any source.
        ValueError: the RABBIT_PREFETCH environment value is not an integer.
    """
    # Start with JSON
    model = _SettingsModel(**j)

    # Map top-level JSON keys to our fields when used
    if model.ServiceID:
        model.SERVICE_ID = model.ServiceID
    if model.RabbitRoutingKey:
        model.RABBIT_ROUTING_KEY = model.RabbitRoutingKey
    if model.RabbitExhangeType:
        model.RABBIT_EXCHANGE_TYPE = model.RabbitExhangeType
    if model.UseTls is not None:
        model.USE_TLS = _bool(model.UseTls, model.USE_TLS)

    # If AMQP_URL not set, try to build from LocalSystemUrl
    if not model.AMQP_URL and model.LocalSystemUrl:
        built = _build_amqp_url_from_local(model.LocalSystemUrl, model.UseTls)
        if built:
            model.AMQP_URL = built

    # Now overlay environment variables (HF Secrets)
    env = os.environ

    # Direct env override of AMQP_URL wins, but a blank value counts as unset
    env_url = env.get("AMQP_URL")
    if env_url:
        model.AMQP_URL = env_url

    # Other rabbit knobs
    model.RABBIT_INSTANCE_NAME = env.get("RABBIT_INSTANCE_NAME", model.RABBIT_INSTANCE_NAME)
    model.RABBIT_EXCHANGE_TYPE = env.get("RABBIT_EXCHANGE_TYPE", model.RABBIT_EXCHANGE_TYPE)
    model.RABBIT_ROUTING_KEY = env.get("RABBIT_ROUTING_KEY", model.RABBIT_ROUTING_KEY)
    model.RABBIT_PREFETCH = int(env.get("RABBIT_PREFETCH", model.RABBIT_PREFETCH))

    # TLS env overrides
    if "RABBIT_TLS_VERIFY" in env:
        model.RABBIT_TLS_VERIFY = _bool(env["RABBIT_TLS_VERIFY"], model.RABBIT_TLS_VERIFY)
    if "RABBIT_TLS_CHECK_HOSTNAME" in env:
        model.RABBIT_TLS_CHECK_HOSTNAME = _bool(env["RABBIT_TLS_CHECK_HOSTNAME"], model.RABBIT_TLS_CHECK_HOSTNAME)
    model.RABBIT_TLS_CA_FILE = env.get("RABBIT_TLS_CA_FILE", model.RABBIT_TLS_CA_FILE)

    # SERVICE_ID can be overridden (you renamed yours to gradllm)
    model.SERVICE_ID = env.get("SERVICE_ID", model.SERVICE_ID)

    # Optional EXCHANGE_TYPES as JSON string in env
    et = env.get("EXCHANGE_TYPES")
    if et:
        try:
            parsed = json.loads(et)
        except ValueError:
            # Malformed JSON (json.JSONDecodeError): keep the current mapping.
            parsed = None
        # The field is a mapping; a JSON list/string/number is not usable.
        if isinstance(parsed, dict):
            model.EXCHANGE_TYPES = parsed

    # Final sanity: must have AMQP_URL
    if not model.AMQP_URL:
        raise RuntimeError(
            "AMQP_URL is not configured. Set it via:\n"
            "- env secret AMQP_URL, or\n"
            "- appsettings.json LocalSystemUrl (RabbitHostName/UserName/Password/VHost/UseTls)."
        )
    return model
class Settings(BaseSettings):
    """
    Thin wrapper that exposes the merged model as attributes.

    get_settings constructs it from the full dump of the internal merged
    model, which also carries appsettings.json passthrough keys that are not
    declared here (ServiceID, LocalSystemUrl, RedisUrl, ...). The settings
    base class rejects undeclared inputs by default, so ``extra`` is set to
    "ignore" to drop those keys instead of failing validation.
    """

    # required
    AMQP_URL: str
    # rabbit
    RABBIT_INSTANCE_NAME: str = "prod"
    RABBIT_EXCHANGE_TYPE: str = "topic"
    RABBIT_ROUTING_KEY: str = ""
    RABBIT_PREFETCH: int = 1
    # TLS
    RABBIT_TLS_VERIFY: bool = False
    RABBIT_TLS_CHECK_HOSTNAME: bool = False
    RABBIT_TLS_CA_FILE: str | None = None
    SERVICE_ID: str = "gradllm"
    USE_TLS: bool = True
    EXCHANGE_TYPES: dict[str, str] = Field(default_factory=dict)
    if IS_V2:
        model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
    else:
        class Config:
            case_sensitive = True
            extra = "ignore"
@lru_cache
def get_settings() -> Settings:
    """
    Return the process-wide Settings instance (built once, then cached).

    Loads the JSON file at APPSETTINGS_PATH, merges it with environment
    overrides (synthesizing AMQP_URL if needed), and projects the merged
    model into the public Settings class.
    """
    merged = _merge_env_over_json(_load_json(APPSETTINGS_PATH))
    # pydantic v2 models expose model_dump(); v1 models only have dict().
    dump = getattr(merged, "model_dump", None)
    if dump is None:
        dump = merged.dict
    return Settings(**dump())
# Module-level instance, built at import time: importing this module raises
# RuntimeError (from _merge_env_over_json) when no AMQP_URL can be determined.
settings = get_settings()