Spaces:
Running
feat: infrastructure sprint — vLLM/Modal, Helm, Terraform (#8)
Browse files* feat: add SelfHostedProvider for OpenAI-compatible endpoints (vLLM, TGI, Ollama)
- SelfHostedProvider targets any /v1/chat/completions endpoint via httpx
- Config schema extended: provider.selfhosted.{base_url, model_name, api_key, timeout_seconds}
- Env var fallback: MODAL_VLLM_URL, SELFHOSTED_MODEL, MODAL_AUTH_TOKEN
- Lazy tool-calling detection via startup probe; prompt-based fallback for
models that don't support function calling
- True streaming via httpx.AsyncClient.stream() + aiter_lines()
- 25 tests covering factory, complete, tool detection, prompt fallback,
retry/timeout, env vars, streaming, format_tools
- YAML configs for local (docker-compose) and Modal (serverless) deployments
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: address review findings on SelfHostedProvider
- selfhosted_local.yaml: vLLM port changed from 8000 to 8001 to avoid
collision with FastAPI app (which also serves on 8000)
- docker-compose.vllm.yml: host port updated to 8001 accordingly
- _detect_tool_calling(): distinguish transient failures (timeout, 5xx)
returning None from definitive unsupported (400) returning False.
Transient results are NOT cached, so detection retries on next call.
- stream_complete(): apply same tool-calling detection/fallback logic
as complete() — prevents 400 on unsupported endpoints in streaming mode
- Added tests: transient failure returns None, 5xx returns None,
transient retries on next call, YAML-from-disk loading (3 files),
port collision regression test. 31 tests total.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* feat: add Modal vLLM deployment and benchmark runner
- modal/common.py: shared constants (model, GPU type, cost tracking)
- modal/serve_vllm.py: Modal app deploying vLLM as OpenAI-compatible
endpoint on A10G GPU with /v1/chat/completions and /health
- modal/run_benchmark.py: runs 27-question eval against all provider
configs and generates docs/provider_comparison.md
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: address round-2 review findings
Provider:
- _parse_tool_calls_from_text: validate arguments is dict, not arbitrary
JSON (e.g. "oops") — prevents Pydantic ValidationError
- selfhosted_local.yaml: remove hardcoded base_url to avoid port collision
with app (8000) and Docker service-name mismatch; falls back to
MODAL_VLLM_URL env var, then default http://localhost:8001/v1
- Default fallback URL changed from :8000 to :8001
Docker Compose:
- AGENT_BENCH_ENV=selfhosted_local now works: config has empty base_url,
so MODAL_VLLM_URL=http://vllm:8000/v1 takes effect in-container
Modal:
- serve_vllm.py: rewritten to use vLLM CLI subprocess + ASGI proxy,
avoiding unstable Python API imports
- run_benchmark.py: fixed list-vs-dict crash — evaluate.py returns
list[EvalResult]; added aggregate() to compute P@5, R@5, citation
accuracy, latency p50, cost from per-question results; fixed field
names to match EvalResult schema; fixed docstring (removed invalid
`modal run` claim)
Tests:
- Added test_fallback_handles_non_dict_arguments (malformed args → {})
- Updated YAML-from-disk tests for empty base_url
- 32 selfhosted tests, 201 total
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: Modal proxy signature, status forwarding, benchmark --base-url
- serve_vllm.py: annotate request as fastapi.Request (was parsed as
query param, returning 422 on every call); forward upstream status
code and headers via JSONResponse instead of bare resp.json()
- run_benchmark.py: --base-url is now optional; only required when
selfhosted_modal is in the provider set. --only openai works without it.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: Modal proxy readiness, streaming errors, GPU cost reporting
serve_vllm.py:
- Add readiness loop: poll vLLM /health before accepting proxied
requests (180s timeout, 2s interval). Prevents cold-start failures.
- Streaming branch: check resp.status_code before streaming; return
upstream 4xx/5xx as non-streaming Response with real status code.
- Import ordering fix for ruff.
run_benchmark.py:
- Self-hosted cost: derive from GPU-seconds (latency_ms * A10G rate)
instead of token pricing (which is $0.00 for self-hosted models).
Uses MODAL_A10G_COST_PER_SEC from common.py.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* feat: add Helm chart for K8s deployment with dev/prod values
- Chart with Deployment, Service, HPA, ConfigMap, Secret templates
- values-dev.yaml: 1 replica, no HPA, reduced resources
- values-prod.yaml: 3 replicas, HPA 2-8 pods at 70% CPU
- Container port 7860 (matching Dockerfile), Service maps to 8000
- Probes target /health endpoint
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* feat: add Terraform GKE modules for API cluster (CPU-only, GCP)
- Root module wires networking + GKE modules
- Networking: VPC, subnet with pod/service CIDRs, firewall rules
- GKE: cluster with managed node pool (e2-standard-4, 2 nodes)
- No GPU nodes — inference runs on Modal (external)
- terraform.tfvars gitignored; example provided
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* docs: add infra documentation, Makefile targets, and architecture updates
- Makefile: modal-deploy, modal-stop, vllm-up, benchmark-all, k8s-dev,
k8s-prod, tf-plan, tf-validate targets
- DECISIONS.md: 7 new entries (vLLM, Modal, split topology, Helm,
CPU HPA, env var fallback, lazy tool detection)
- README.md: self-hosted/K8s/Terraform sections, provider tree in
architecture diagram, updated test count (201) and skills list
- docs/k8s-local-setup.md: minikube walkthrough
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: streaming proxy lifetime, Helm provider routing, README flow
serve_vllm.py:
- Fix stream lifetime: use client.send(req, stream=True) instead of
async with client.stream() to avoid closing the upstream connection
before FastAPI iterates the generator. Generator owns cleanup via
try/finally with upstream.aclose().
Helm:
- configmap.yaml: AGENT_BENCH_ENV now conditional on provider.type
(selfhosted → selfhosted_modal, openai → default, anthropic → anthropic)
instead of hardcoded selfhosted_modal.
- Makefile k8s-dev/k8s-prod: require MODAL_VLLM_URL and pass it via
--set provider.selfhosted.modalEndpoint to prevent deploying pods
with empty endpoint URLs.
README:
- Benchmark flow: document required OPENAI_API_KEY and ANTHROPIC_API_KEY
for make benchmark-all; show --only selfhosted_modal alternative.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: proxy streaming detection, provider comparison docs consistency
serve_vllm.py:
- Detect streaming by checking request body for "stream": true, not
Accept header. httpx sends Accept: */* by default, so the header
check missed real streaming requests and fell through to resp.json()
on SSE bodies (JSONDecodeError).
docs/provider_comparison.md:
- Updated to 3-provider format (openai, anthropic, selfhosted_modal)
matching the benchmark runner output. Self-hosted row marked TBD
with instructions to run make benchmark-all.
README.md:
- Provider comparison table updated to 3 columns with TBD self-hosted.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: proxy non-JSON responses, docs table consistency
serve_vllm.py:
- Non-streaming branch: try resp.json(), fall back to raw Response
for endpoints that return empty/non-JSON bodies (e.g. vLLM /health
returns empty 200). Prevents JSONDecodeError → 500.
docs/provider_comparison.md:
- Table shape now matches generator output: 6 columns, no Model column.
README.md:
- Description updated from "2 providers" to "3 providers".
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: update stale test counts in README (169 → 201)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: Modal vLLM deployment and self-hosted benchmark
- Pin vllm==0.6.6.post1, transformers==4.47.0, huggingface_hub<1.0 to
fix tqdm DisabledTqdm, tokenizer, and head_dim incompatibilities
- Add HF secret for authenticated model downloads
- Increase context window to 8192 and ready timeout to 600s
- Add proxy error handling with traceback surfacing
- Sanitize messages for non-tool-calling models: merge multiple system
messages, convert tool-role to user, merge consecutive same-role
messages (required by Mistral chat template)
- Abort benchmark on first provider failure instead of continuing
- Tune selfhosted config: max_iterations=1, top_k=3 for 7B context
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* docs: update README with self-hosted benchmark results
Replace TBD placeholders with actual Mistral-7B benchmark data from
Modal vLLM deployment. Add citation accuracy and latency columns to
provider comparison table.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: sync common.py context window and complete provider comparison
- Sync VLLM_MAX_MODEL_LEN to 8192 in common.py (was 4096, diverged
from serve_vllm.py during debugging)
- Add OpenAI and Anthropic data to provider_comparison.md alongside
self-hosted results for complete 3-provider comparison with analysis
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* feat: add provider health_check and Prometheus metrics endpoint
- Add health_check() to LLMProvider interface with default True
- OpenAI: probe via models.retrieve()
- Anthropic: probe via minimal messages.create()
- SelfHosted: probe via GET /models (vLLM endpoint)
- /health now calls provider.health_check() instead of just checking
the provider object exists — pods report degraded when inference is
actually unreachable
- Add /metrics/prometheus endpoint with text exposition format
(counter/gauge types) for Prometheus adapter + K8s HPA custom metrics
- Existing JSON /metrics endpoint unchanged
Co-Authored-By: Claude Opus 4.6 (1M context) <nore
- .gitignore +4 -0
- DECISIONS.md +49 -0
- Makefile +33 -1
- README.md +59 -12
- agent_bench/core/config.py +8 -0
- agent_bench/core/provider.py +413 -0
- agent_bench/serving/routes.py +32 -4
- configs/selfhosted_local.yaml +58 -0
- configs/selfhosted_modal.yaml +56 -0
- docker/docker-compose.vllm.yml +50 -0
- docs/k8s-local-setup.md +40 -0
- docs/provider_comparison.md +60 -39
- k8s/helm/agent-bench/Chart.yaml +6 -0
- k8s/helm/agent-bench/templates/_helpers.tpl +35 -0
- k8s/helm/agent-bench/templates/configmap.yaml +15 -0
- k8s/helm/agent-bench/templates/deployment.yaml +45 -0
- k8s/helm/agent-bench/templates/hpa.yaml +22 -0
- k8s/helm/agent-bench/templates/secret.yaml +12 -0
- k8s/helm/agent-bench/templates/service.yaml +15 -0
- k8s/helm/agent-bench/values-dev.yaml +12 -0
- k8s/helm/agent-bench/values-prod.yaml +15 -0
- k8s/helm/agent-bench/values.yaml +43 -0
- modal/common.py +11 -0
- modal/run_benchmark.py +182 -0
- modal/serve_vllm.py +187 -0
- pyproject.toml +3 -0
- terraform/main.tf +32 -0
- terraform/modules/gke/main.tf +38 -0
- terraform/modules/gke/outputs.tf +8 -0
- terraform/modules/gke/variables.tf +29 -0
- terraform/modules/networking/main.tf +67 -0
- terraform/modules/networking/variables.tf +11 -0
- terraform/outputs.tf +15 -0
- terraform/terraform.tfvars.example +6 -0
- terraform/variables.tf +16 -0
- tests/test_selfhosted_provider.py +689 -0
- tests/test_serving.py +95 -0
|
@@ -17,3 +17,7 @@ venv/
|
|
| 17 |
.worktrees/
|
| 18 |
*.db
|
| 19 |
docs/DESIGN.md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
.worktrees/
|
| 18 |
*.db
|
| 19 |
docs/DESIGN.md
|
| 20 |
+
terraform.tfvars
|
| 21 |
+
.terraform/
|
| 22 |
+
*.tfstate
|
| 23 |
+
*.tfstate.backup
|
|
@@ -232,3 +232,52 @@ The deduplicated `sources` list in `AgentResponse` is for the API
|
|
| 232 |
response. The `ranked_sources` list preserves rank order with
|
| 233 |
duplicates for evaluation metrics. P@5 and R@5 need the raw
|
| 234 |
retrieval ranking, not the post-processed answer metadata.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
response. The `ranked_sources` list preserves rank order with
|
| 233 |
duplicates for evaluation metrics. P@5 and R@5 need the raw
|
| 234 |
retrieval ranking, not the post-processed answer metadata.
|
| 235 |
+
|
| 236 |
+
## Why vLLM over TGI / llama.cpp
|
| 237 |
+
|
| 238 |
+
vLLM has the widest model support, best throughput via PagedAttention, and a native
|
| 239 |
+
OpenAI-compatible server (`/v1/chat/completions`). TGI is a valid alternative; llama.cpp
|
| 240 |
+
targets different use cases (edge/CPU inference). This is a deliberate choice, not
|
| 241 |
+
ignorance of alternatives.
|
| 242 |
+
|
| 243 |
+
## Why Modal for GPU inference
|
| 244 |
+
|
| 245 |
+
Serverless GPU eliminates idle cost and GPU node management. A10G at ~$1.30/hr costs
|
| 246 |
+
~$0.50 per full 27-question benchmark run. The Docker Compose path (`docker-compose.vllm.yml`)
|
| 247 |
+
is retained for users who have local GPUs or prefer persistent serving.
|
| 248 |
+
|
| 249 |
+
## Why split topology (K8s API + Modal GPU)
|
| 250 |
+
|
| 251 |
+
The API layer (retrieval, orchestration, tool routing) is CPU-bound and benefits from
|
| 252 |
+
horizontal scaling via K8s HPA. The LLM inference layer is GPU-bound and benefits from
|
| 253 |
+
serverless elasticity — Modal scales to zero when idle, scales up on demand with no node
|
| 254 |
+
provisioning. Co-locating both in K8s would require GPU node pools with idle cost,
|
| 255 |
+
node autoscaler latency, and NVIDIA device plugin management. This mirrors a common
|
| 256 |
+
production pattern.
|
| 257 |
+
|
| 258 |
+
## Why Helm only, not Kustomize + Helm
|
| 259 |
+
|
| 260 |
+
Showing two K8s deployment methods for the same app adds complexity without demonstrating
|
| 261 |
+
distinct skills. Helm with `values-dev.yaml` / `values-prod.yaml` covers
|
| 262 |
+
environment-specific configuration cleanly.
|
| 263 |
+
|
| 264 |
+
## Why CPU-based HPA, not custom metrics
|
| 265 |
+
|
| 266 |
+
CPU utilization works without a Prometheus adapter or custom metrics server. A production
|
| 267 |
+
improvement would use the Prometheus adapter to scale on p95 latency from the `/metrics`
|
| 268 |
+
endpoint — this requires bridging the JSON metrics to Prometheus exposition format.
|
| 269 |
+
Documented as a follow-up.
|
| 270 |
+
|
| 271 |
+
## Why env var fallback in SelfHostedProvider
|
| 272 |
+
|
| 273 |
+
Follows the same pattern as OpenAIProvider reading `OPENAI_API_KEY`. The YAML config
|
| 274 |
+
provides defaults; env vars override at runtime. No config loader changes needed.
|
| 275 |
+
|
| 276 |
+
## Why lazy tool-call detection, not metadata check
|
| 277 |
+
|
| 278 |
+
Checking `/v1/models` metadata for tool-calling support is unreliable — model metadata
|
| 279 |
+
doesn't consistently report this capability. Instead, the provider sends one tool-calling
|
| 280 |
+
request on first `complete()` call with tools and checks if the response contains
|
| 281 |
+
`tool_calls`. The result is cached as `self._supports_tool_calling`. Transient failures
|
| 282 |
+
(timeout, 5xx) return `None` and retry on the next call rather than permanently
|
| 283 |
+
downgrading to prompt-based fallback.
|
|
@@ -1,6 +1,6 @@
|
|
| 1 |
PYTHON ?= /usr/local/opt/python@3.11/bin/python3.11
|
| 2 |
|
| 3 |
-
.PHONY: install test lint serve ingest evaluate-fast evaluate-full benchmark evaluate-langchain docker
|
| 4 |
|
| 5 |
install:
|
| 6 |
$(PYTHON) -m pip install -e ".[dev]"
|
|
@@ -33,3 +33,35 @@ evaluate-langchain:
|
|
| 33 |
|
| 34 |
docker:
|
| 35 |
docker-compose -f docker/docker-compose.yaml up --build
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
PYTHON ?= /usr/local/opt/python@3.11/bin/python3.11
|
| 2 |
|
| 3 |
+
.PHONY: install test lint serve ingest evaluate-fast evaluate-full benchmark evaluate-langchain docker modal-deploy modal-stop vllm-up benchmark-all k8s-dev k8s-prod tf-plan tf-validate
|
| 4 |
|
| 5 |
install:
|
| 6 |
$(PYTHON) -m pip install -e ".[dev]"
|
|
|
|
| 33 |
|
| 34 |
docker:
|
| 35 |
docker-compose -f docker/docker-compose.yaml up --build
|
| 36 |
+
|
| 37 |
+
## --- Infrastructure ---
|
| 38 |
+
|
| 39 |
+
modal-deploy: ## Deploy vLLM on Modal (prints endpoint URL)
|
| 40 |
+
@command -v modal >/dev/null 2>&1 || { echo "Error: modal CLI not found. Run: pip install -e '.[modal]' && modal setup"; exit 1; }
|
| 41 |
+
modal deploy modal/serve_vllm.py
|
| 42 |
+
|
| 43 |
+
modal-stop: ## Stop Modal deployment
|
| 44 |
+
@command -v modal >/dev/null 2>&1 || { echo "Error: modal CLI not found. Run: pip install -e '.[modal]' && modal setup"; exit 1; }
|
| 45 |
+
modal app stop agent-bench-vllm
|
| 46 |
+
|
| 47 |
+
vllm-up: ## Start local vLLM via Docker Compose (requires NVIDIA GPU)
|
| 48 |
+
docker compose -f docker/docker-compose.vllm.yml up --build
|
| 49 |
+
|
| 50 |
+
benchmark-all: ## Run provider comparison (requires Modal deployment + API keys)
|
| 51 |
+
$(PYTHON) modal/run_benchmark.py --base-url $(MODAL_VLLM_URL)
|
| 52 |
+
|
| 53 |
+
k8s-dev: ## Deploy to minikube (dev values, set MODAL_VLLM_URL first)
|
| 54 |
+
@test -n "$(MODAL_VLLM_URL)" || (echo "Error: MODAL_VLLM_URL is not set" && exit 1)
|
| 55 |
+
helm install agent-bench k8s/helm/agent-bench/ -f k8s/helm/agent-bench/values-dev.yaml \
|
| 56 |
+
--set provider.selfhosted.modalEndpoint=$(MODAL_VLLM_URL)
|
| 57 |
+
|
| 58 |
+
k8s-prod: ## Deploy via Helm (prod values, set MODAL_VLLM_URL first)
|
| 59 |
+
@test -n "$(MODAL_VLLM_URL)" || (echo "Error: MODAL_VLLM_URL is not set" && exit 1)
|
| 60 |
+
helm install agent-bench k8s/helm/agent-bench/ -f k8s/helm/agent-bench/values-prod.yaml \
|
| 61 |
+
--set provider.selfhosted.modalEndpoint=$(MODAL_VLLM_URL)
|
| 62 |
+
|
| 63 |
+
tf-plan: ## Run terraform plan (no apply)
|
| 64 |
+
cd terraform && terraform plan
|
| 65 |
+
|
| 66 |
+
tf-validate: ## Validate terraform syntax
|
| 67 |
+
cd terraform && terraform validate
|
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |

|
| 4 |
|
| 5 |
-
Agentic knowledge retrieval system with evaluation benchmark. Custom orchestration pipeline + LangChain baseline, evaluated on the same 27-question golden dataset across
|
| 6 |
|
| 7 |
-
`
|
| 8 |
|
| 9 |
## Benchmark Results
|
| 10 |
|
|
@@ -30,12 +30,16 @@ Full analysis: [comparison report](results/comparison_custom_vs_langchain.md)
|
|
| 30 |
|
| 31 |
### Provider Comparison (Custom Pipeline)
|
| 32 |
|
| 33 |
-
| Metric | OpenAI gpt-4o-mini | Anthropic claude-haiku |
|
| 34 |
-
|--------|-------------------|----------------------|
|
| 35 |
-
| Retrieval P@5 | 0.70 | **0.74** |
|
| 36 |
-
| Retrieval R@5 | 0.83 | **0.84** |
|
| 37 |
-
| Keyword Hit Rate | 0.89 | **0.92** |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
[Full benchmark report](docs/benchmark_report.md) | [Provider comparison](docs/provider_comparison.md) | [Design decisions](DECISIONS.md)
|
| 41 |
|
|
@@ -78,6 +82,40 @@ curl -X POST http://localhost:8000/ask \
|
|
| 78 |
OPENAI_API_KEY=sk-... docker-compose -f docker/docker-compose.yaml up --build
|
| 79 |
```
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
## Architecture
|
| 82 |
|
| 83 |
```mermaid
|
|
@@ -90,13 +128,21 @@ flowchart LR
|
|
| 90 |
Reg --> Calc[calculator]
|
| 91 |
Search --> Store[Hybrid Store<br/>FAISS + BM25 + RRF]
|
| 92 |
LLM -->|no tool_calls| Resp[AskResponse<br/>answer + sources + metadata]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
```
|
| 94 |
|
| 95 |
## Skills Demonstrated
|
| 96 |
|
| 97 |
- **Agent design & evaluation**: Built two independent orchestration approaches (custom tool-calling loop + LangChain AgentExecutor) and evaluated both on identical metrics to quantify framework tradeoffs
|
| 98 |
- **Retrieval engineering**: Hybrid FAISS + BM25 with Reciprocal Rank Fusion, cross-encoder reranking, evaluated across 27 questions with P@5, R@5, citation accuracy
|
| 99 |
-
- **
|
|
|
|
|
|
|
| 100 |
|
| 101 |
<details><summary>API Reference</summary>
|
| 102 |
|
|
@@ -107,7 +153,8 @@ flowchart LR
|
|
| 107 |
| `/ask` | POST | Ask a question, get answer with sources |
|
| 108 |
| `/ask/stream` | POST | SSE streaming (sources → chunks → done) |
|
| 109 |
| `/health` | GET | Store stats, provider status, uptime |
|
| 110 |
-
| `/metrics` | GET | Request count, latency p50/p95, cost |
|
|
|
|
| 111 |
|
| 112 |
### POST /ask
|
| 113 |
|
|
@@ -156,7 +203,7 @@ The golden dataset contains 27 hand-crafted questions:
|
|
| 156 |
## Testing
|
| 157 |
|
| 158 |
```bash
|
| 159 |
-
make test #
|
| 160 |
make lint # ruff + mypy
|
| 161 |
```
|
| 162 |
|
|
@@ -179,7 +226,7 @@ See [DECISIONS.md](DECISIONS.md) for rationale on building from primitives, RRF
|
|
| 179 |
| Conversation memory | Stateless | SQLite sessions | State management |
|
| 180 |
| Cloud deployment | None | HF Spaces (Docker) | Docker → production |
|
| 181 |
| CI/CD | None | GitHub Actions | Automated quality gates |
|
| 182 |
-
| Tests | 97 |
|
| 183 |
|
| 184 |
See [DECISIONS.md](DECISIONS.md) for the reasoning behind each design choice.
|
| 185 |
|
|
|
|
| 2 |
|
| 3 |

|
| 4 |
|
| 5 |
+
Agentic knowledge retrieval system with evaluation benchmark. Custom orchestration pipeline + LangChain baseline, evaluated on the same 27-question golden dataset across 3 providers (OpenAI, Anthropic, self-hosted vLLM on Modal). Zero hallucinated citations in all API configurations.
|
| 6 |
|
| 7 |
+
`205 tests` · `3 providers` · `LangChain comparison` · `K8s + Terraform` · `CI`
|
| 8 |
|
| 9 |
## Benchmark Results
|
| 10 |
|
|
|
|
| 30 |
|
| 31 |
### Provider Comparison (Custom Pipeline)
|
| 32 |
|
| 33 |
+
| Metric | OpenAI gpt-4o-mini | Anthropic claude-haiku | Self-hosted Mistral-7B |
|
| 34 |
+
|--------|-------------------|----------------------|----------------------|
|
| 35 |
+
| Retrieval P@5 | 0.70 | **0.74** | 0.05 |
|
| 36 |
+
| Retrieval R@5 | 0.83 | **0.84** | 0.05 |
|
| 37 |
+
| Keyword Hit Rate | 0.89 | **0.92** | 0.61 |
|
| 38 |
+
| Citation Acc | **1.00** | **1.00** | 0.14 |
|
| 39 |
+
| Latency p50 | 4,690 ms | 5,120 ms | 6,709 ms |
|
| 40 |
+
| Cost per query | **$0.0004** | $0.0007 | $0.0031 |
|
| 41 |
+
|
| 42 |
+
API providers are directly comparable (same config). The self-hosted row uses `max_iterations=1` and `top_k=3` (vs 3/5 for API) to fit Mistral-7B's 8K context window — not an apples-to-apples comparison, but reflects realistic 7B operating constraints. See [provider comparison](docs/provider_comparison.md) for full analysis.
|
| 43 |
|
| 44 |
[Full benchmark report](docs/benchmark_report.md) | [Provider comparison](docs/provider_comparison.md) | [Design decisions](DECISIONS.md)
|
| 45 |
|
|
|
|
| 82 |
OPENAI_API_KEY=sk-... docker-compose -f docker/docker-compose.yaml up --build
|
| 83 |
```
|
| 84 |
|
| 85 |
+
### Self-Hosted LLM via Modal (no local GPU needed)
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
pip install -e ".[modal]" # Install Modal SDK
|
| 89 |
+
modal setup # Authenticate with Modal
|
| 90 |
+
modal secret create huggingface-secret HF_TOKEN=hf_... # HF token for model download
|
| 91 |
+
make modal-deploy # Deploy vLLM on Modal A10G
|
| 92 |
+
export MODAL_VLLM_URL=https://your--agent-bench-vllm-serve.modal.run/v1
|
| 93 |
+
AGENT_BENCH_ENV=selfhosted_modal make serve # Serve with self-hosted provider
|
| 94 |
+
|
| 95 |
+
# Run provider comparison (requires all provider API keys)
|
| 96 |
+
export OPENAI_API_KEY=sk-...
|
| 97 |
+
export ANTHROPIC_API_KEY=sk-ant-...
|
| 98 |
+
make benchmark-all
|
| 99 |
+
|
| 100 |
+
# Or run only the self-hosted provider
|
| 101 |
+
python modal/run_benchmark.py --base-url $MODAL_VLLM_URL --only selfhosted_modal
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### Self-Hosted LLM via Docker Compose (requires local NVIDIA GPU)
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
docker compose -f docker/docker-compose.vllm.yml up --build
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### Kubernetes (Helm)
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
make k8s-dev # Dev: 1 replica, no HPA
|
| 114 |
+
make k8s-prod # Prod: 3 replicas, HPA 2-8 pods
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
See [docs/k8s-local-setup.md](docs/k8s-local-setup.md) for minikube walkthrough.
|
| 118 |
+
|
| 119 |
## Architecture
|
| 120 |
|
| 121 |
```mermaid
|
|
|
|
| 128 |
Reg --> Calc[calculator]
|
| 129 |
Search --> Store[Hybrid Store<br/>FAISS + BM25 + RRF]
|
| 130 |
LLM -->|no tool_calls| Resp[AskResponse<br/>answer + sources + metadata]
|
| 131 |
+
|
| 132 |
+
subgraph Providers
|
| 133 |
+
LLM --- OpenAI[OpenAI<br/>gpt-4o-mini]
|
| 134 |
+
LLM --- Anthropic[Anthropic<br/>claude-haiku]
|
| 135 |
+
LLM --- SelfHosted[SelfHosted<br/>vLLM / TGI / Ollama]
|
| 136 |
+
end
|
| 137 |
```
|
| 138 |
|
| 139 |
## Skills Demonstrated
|
| 140 |
|
| 141 |
- **Agent design & evaluation**: Built two independent orchestration approaches (custom tool-calling loop + LangChain AgentExecutor) and evaluated both on identical metrics to quantify framework tradeoffs
|
| 142 |
- **Retrieval engineering**: Hybrid FAISS + BM25 with Reciprocal Rank Fusion, cross-encoder reranking, evaluated across 27 questions with P@5, R@5, citation accuracy
|
| 143 |
+
- **Infrastructure:** Kubernetes (Helm), Terraform (GCP/GKE), self-hosted LLM serving (vLLM on Modal + Docker Compose)
|
| 144 |
+
- **MLOps:** Provider comparison benchmark (API vs self-hosted, real measured data)
|
| 145 |
+
- **Production engineering**: FastAPI, Docker, CI/CD, structured logging, rate limiting, SSE streaming, conversation sessions, 205 deterministic tests with mock providers
|
| 146 |
|
| 147 |
<details><summary>API Reference</summary>
|
| 148 |
|
|
|
|
| 153 |
| `/ask` | POST | Ask a question, get answer with sources |
|
| 154 |
| `/ask/stream` | POST | SSE streaming (sources → chunks → done) |
|
| 155 |
| `/health` | GET | Store stats, provider status, uptime |
|
| 156 |
+
| `/metrics` | GET | Request count, latency p50/p95, cost (JSON) |
|
| 157 |
+
| `/metrics/prometheus` | GET | Prometheus text exposition format |
|
| 158 |
|
| 159 |
### POST /ask
|
| 160 |
|
|
|
|
| 203 |
## Testing
|
| 204 |
|
| 205 |
```bash
|
| 206 |
+
make test # 205 deterministic tests, no API keys needed
|
| 207 |
make lint # ruff + mypy
|
| 208 |
```
|
| 209 |
|
|
|
|
| 226 |
| Conversation memory | Stateless | SQLite sessions | State management |
|
| 227 |
| Cloud deployment | None | HF Spaces (Docker) | Docker → production |
|
| 228 |
| CI/CD | None | GitHub Actions | Automated quality gates |
|
| 229 |
+
| Tests | 97 | 205 | Comprehensive coverage |
|
| 230 |
|
| 231 |
See [DECISIONS.md](DECISIONS.md) for the reasoning behind each design choice.
|
| 232 |
|
|
@@ -21,9 +21,17 @@ class ModelPricing(BaseModel):
|
|
| 21 |
output_cost_per_mtok: float
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
class ProviderConfig(BaseModel):
|
| 25 |
default: str = "openai"
|
| 26 |
models: dict[str, ModelPricing] = {}
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class ChunkingConfig(BaseModel):
|
|
|
|
| 21 |
output_cost_per_mtok: float
|
| 22 |
|
| 23 |
|
| 24 |
+
class SelfHostedConfig(BaseModel):
|
| 25 |
+
base_url: str = ""
|
| 26 |
+
model_name: str = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 27 |
+
api_key: str = ""
|
| 28 |
+
timeout_seconds: float = 120.0
|
| 29 |
+
|
| 30 |
+
|
| 31 |
class ProviderConfig(BaseModel):
|
| 32 |
default: str = "openai"
|
| 33 |
models: dict[str, ModelPricing] = {}
|
| 34 |
+
selfhosted: SelfHostedConfig = SelfHostedConfig()
|
| 35 |
|
| 36 |
|
| 37 |
class ChunkingConfig(BaseModel):
|
|
@@ -102,6 +102,15 @@ class LLMProvider(ABC):
|
|
| 102 |
@abstractmethod
|
| 103 |
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]: ...
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# --- Implementations ---
|
| 107 |
|
|
@@ -327,6 +336,13 @@ class OpenAIProvider(LLMProvider):
|
|
| 327 |
if chunk.choices and chunk.choices[0].delta.content:
|
| 328 |
yield chunk.choices[0].delta.content
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]:
|
| 331 |
return format_tools_openai(tools)
|
| 332 |
|
|
@@ -560,10 +576,405 @@ class AnthropicProvider(LLMProvider):
|
|
| 560 |
f"Anthropic timed out: {e}"
|
| 561 |
) from e
|
| 562 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]:
|
| 564 |
return format_tools_anthropic(tools)
|
| 565 |
|
| 566 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
def create_provider(config: AppConfig | None = None) -> LLMProvider:
|
| 568 |
"""Factory: create provider based on config."""
|
| 569 |
if config is None:
|
|
@@ -573,6 +984,8 @@ def create_provider(config: AppConfig | None = None) -> LLMProvider:
|
|
| 573 |
return OpenAIProvider(config)
|
| 574 |
elif name == "anthropic":
|
| 575 |
return AnthropicProvider(config)
|
|
|
|
|
|
|
| 576 |
elif name == "mock":
|
| 577 |
return MockProvider()
|
| 578 |
else:
|
|
|
|
| 102 |
@abstractmethod
|
| 103 |
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]: ...
|
| 104 |
|
| 105 |
+
async def health_check(self) -> bool:
|
| 106 |
+
"""Check if the upstream provider is reachable.
|
| 107 |
+
|
| 108 |
+
Returns True if the provider can serve requests, False otherwise.
|
| 109 |
+
Default implementation returns True (assume healthy). Providers
|
| 110 |
+
should override this to perform a real connectivity check.
|
| 111 |
+
"""
|
| 112 |
+
return True
|
| 113 |
+
|
| 114 |
|
| 115 |
# --- Implementations ---
|
| 116 |
|
|
|
|
| 336 |
if chunk.choices and chunk.choices[0].delta.content:
|
| 337 |
yield chunk.choices[0].delta.content
|
| 338 |
|
| 339 |
+
async def health_check(self) -> bool:
|
| 340 |
+
try:
|
| 341 |
+
await self.client.models.retrieve(self.model)
|
| 342 |
+
return True
|
| 343 |
+
except Exception:
|
| 344 |
+
return False
|
| 345 |
+
|
| 346 |
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]:
|
| 347 |
return format_tools_openai(tools)
|
| 348 |
|
|
|
|
| 576 |
f"Anthropic timed out: {e}"
|
| 577 |
) from e
|
| 578 |
|
| 579 |
+
async def health_check(self) -> bool:
|
| 580 |
+
try:
|
| 581 |
+
await self.client.models.retrieve(model_id=self.model)
|
| 582 |
+
return True
|
| 583 |
+
except Exception:
|
| 584 |
+
return False
|
| 585 |
+
|
| 586 |
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]:
|
| 587 |
return format_tools_anthropic(tools)
|
| 588 |
|
| 589 |
|
| 590 |
+
class SelfHostedProvider(LLMProvider):
|
| 591 |
+
"""Provider targeting any OpenAI-compatible endpoint (vLLM, TGI, Ollama).
|
| 592 |
+
|
| 593 |
+
Reads settings from config (provider.selfhosted.*) with env var fallback:
|
| 594 |
+
MODAL_VLLM_URL -> base_url
|
| 595 |
+
SELFHOSTED_MODEL -> model_name
|
| 596 |
+
MODAL_AUTH_TOKEN -> api_key
|
| 597 |
+
|
| 598 |
+
Tool-calling support is detected lazily on the first complete() call
|
| 599 |
+
with tools. If the endpoint returns a 400 or the model ignores tools,
|
| 600 |
+
subsequent calls fall back to prompt-based tool selection.
|
| 601 |
+
"""
|
| 602 |
+
|
| 603 |
+
def __init__(self, config: AppConfig | None = None) -> None:
|
| 604 |
+
import os
|
| 605 |
+
|
| 606 |
+
import httpx as _httpx
|
| 607 |
+
|
| 608 |
+
self.config = config or load_config()
|
| 609 |
+
sh = self.config.provider.selfhosted
|
| 610 |
+
self.base_url = (
|
| 611 |
+
sh.base_url
|
| 612 |
+
or os.environ.get("MODAL_VLLM_URL", "http://localhost:8001/v1")
|
| 613 |
+
)
|
| 614 |
+
self.model = (
|
| 615 |
+
sh.model_name
|
| 616 |
+
if sh.model_name != "mistralai/Mistral-7B-Instruct-v0.3"
|
| 617 |
+
else os.environ.get("SELFHOSTED_MODEL", sh.model_name)
|
| 618 |
+
)
|
| 619 |
+
api_key = sh.api_key or os.environ.get("MODAL_AUTH_TOKEN", "")
|
| 620 |
+
self._supports_tool_calling: bool | None = None # detected lazily
|
| 621 |
+
|
| 622 |
+
model_pricing = self.config.provider.models.get(self.model)
|
| 623 |
+
self._input_cost = model_pricing.input_cost_per_mtok if model_pricing else 0.0
|
| 624 |
+
self._output_cost = model_pricing.output_cost_per_mtok if model_pricing else 0.0
|
| 625 |
+
|
| 626 |
+
self.client = _httpx.AsyncClient(
|
| 627 |
+
base_url=self.base_url,
|
| 628 |
+
timeout=sh.timeout_seconds,
|
| 629 |
+
follow_redirects=True,
|
| 630 |
+
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
async def _detect_tool_calling(self) -> bool | None:
|
| 634 |
+
"""Probe the endpoint for OpenAI-format tool-calling support.
|
| 635 |
+
|
| 636 |
+
Returns:
|
| 637 |
+
True — model responded with tool_calls (definitive: cache it)
|
| 638 |
+
False — endpoint returned 400 (definitive: cache it)
|
| 639 |
+
None — transient failure (timeout, 5xx, connection error); do NOT cache
|
| 640 |
+
"""
|
| 641 |
+
test_tool = {
|
| 642 |
+
"type": "function",
|
| 643 |
+
"function": {
|
| 644 |
+
"name": "test_probe",
|
| 645 |
+
"description": "Probe for tool support",
|
| 646 |
+
"parameters": {
|
| 647 |
+
"type": "object",
|
| 648 |
+
"properties": {"x": {"type": "string"}},
|
| 649 |
+
},
|
| 650 |
+
},
|
| 651 |
+
}
|
| 652 |
+
try:
|
| 653 |
+
resp = await self.client.post(
|
| 654 |
+
"/chat/completions",
|
| 655 |
+
json={
|
| 656 |
+
"model": self.model,
|
| 657 |
+
"messages": [
|
| 658 |
+
{"role": "user", "content": "Call the test_probe tool with x='hello'"}
|
| 659 |
+
],
|
| 660 |
+
"tools": [test_tool],
|
| 661 |
+
"tool_choice": "auto",
|
| 662 |
+
"max_tokens": 50,
|
| 663 |
+
},
|
| 664 |
+
)
|
| 665 |
+
if resp.status_code == 400:
|
| 666 |
+
log.info("selfhosted_tool_detect", result="unsupported (400)")
|
| 667 |
+
return False
|
| 668 |
+
if resp.status_code >= 500:
|
| 669 |
+
log.warning("selfhosted_tool_detect", result="transient (5xx)")
|
| 670 |
+
return None
|
| 671 |
+
resp.raise_for_status()
|
| 672 |
+
data = resp.json()
|
| 673 |
+
has_tools = bool(
|
| 674 |
+
data["choices"][0]["message"].get("tool_calls")
|
| 675 |
+
)
|
| 676 |
+
log.info("selfhosted_tool_detect", result="supported" if has_tools else "unsupported")
|
| 677 |
+
return has_tools
|
| 678 |
+
except Exception:
|
| 679 |
+
log.warning("selfhosted_tool_detect", result="transient (error)")
|
| 680 |
+
return None
|
| 681 |
+
|
| 682 |
+
@staticmethod
|
| 683 |
+
def _sanitize_messages(messages: list[dict]) -> list[dict]:
|
| 684 |
+
"""Convert tool-role messages and merge consecutive same-role messages.
|
| 685 |
+
|
| 686 |
+
Many models (e.g. Mistral) require strictly alternating user/assistant
|
| 687 |
+
messages. Tool results are converted to user messages and consecutive
|
| 688 |
+
same-role messages are merged.
|
| 689 |
+
"""
|
| 690 |
+
sanitized: list[dict] = []
|
| 691 |
+
for m in messages:
|
| 692 |
+
if m["role"] == "tool":
|
| 693 |
+
role = "user"
|
| 694 |
+
content = f"[Tool result]: {m['content']}"
|
| 695 |
+
elif m["role"] == "assistant" and "tool_calls" in m:
|
| 696 |
+
role = "assistant"
|
| 697 |
+
content = m.get("content") or ""
|
| 698 |
+
else:
|
| 699 |
+
role = m["role"]
|
| 700 |
+
content = m.get("content") or ""
|
| 701 |
+
|
| 702 |
+
# Merge consecutive same-role messages
|
| 703 |
+
if sanitized and sanitized[-1]["role"] == role and role != "system":
|
| 704 |
+
sanitized[-1]["content"] += "\n\n" + content
|
| 705 |
+
else:
|
| 706 |
+
sanitized.append({"role": role, "content": content})
|
| 707 |
+
|
| 708 |
+
# Merge consecutive same-role messages that resulted from dropping empty ones
|
| 709 |
+
merged: list[dict] = []
|
| 710 |
+
for m in sanitized:
|
| 711 |
+
if not m["content"].strip() and m["role"] != "system":
|
| 712 |
+
continue # drop empty messages
|
| 713 |
+
if merged and merged[-1]["role"] == m["role"] and m["role"] != "system":
|
| 714 |
+
merged[-1]["content"] += "\n\n" + m["content"]
|
| 715 |
+
else:
|
| 716 |
+
merged.append(m)
|
| 717 |
+
return merged
|
| 718 |
+
|
| 719 |
+
@staticmethod
|
| 720 |
+
def _tools_as_prompt(tools: list[ToolDefinition]) -> str:
|
| 721 |
+
"""Format tools as system prompt text for prompt-based fallback."""
|
| 722 |
+
lines = ["You have access to the following tools:", ""]
|
| 723 |
+
for t in tools:
|
| 724 |
+
lines.append(f"- {t.name}: {t.description}")
|
| 725 |
+
lines.append(f" Parameters: {json.dumps(t.parameters)}")
|
| 726 |
+
lines.extend([
|
| 727 |
+
"",
|
| 728 |
+
"To use a tool, respond with ONLY this JSON (no other text):",
|
| 729 |
+
'{"tool_calls": [{"name": "tool_name", "arguments": {"param": "value"}}]}',
|
| 730 |
+
"",
|
| 731 |
+
"If you don't need a tool, respond normally with text.",
|
| 732 |
+
])
|
| 733 |
+
return "\n".join(lines)
|
| 734 |
+
|
| 735 |
+
@staticmethod
|
| 736 |
+
def _parse_tool_calls_from_text(text: str) -> list[ToolCall]:
|
| 737 |
+
"""Parse tool calls from model text output (prompt-based fallback)."""
|
| 738 |
+
import uuid
|
| 739 |
+
|
| 740 |
+
try:
|
| 741 |
+
data = json.loads(text.strip())
|
| 742 |
+
if isinstance(data, dict) and "tool_calls" in data:
|
| 743 |
+
calls = []
|
| 744 |
+
for tc in data["tool_calls"]:
|
| 745 |
+
raw_args = tc.get("arguments", {})
|
| 746 |
+
if not isinstance(raw_args, dict):
|
| 747 |
+
raw_args = {}
|
| 748 |
+
calls.append(
|
| 749 |
+
ToolCall(
|
| 750 |
+
id=f"call_{uuid.uuid4().hex[:8]}",
|
| 751 |
+
name=tc["name"],
|
| 752 |
+
arguments=raw_args,
|
| 753 |
+
)
|
| 754 |
+
)
|
| 755 |
+
return calls
|
| 756 |
+
except (json.JSONDecodeError, KeyError, TypeError):
|
| 757 |
+
pass
|
| 758 |
+
return []
|
| 759 |
+
|
| 760 |
+
async def complete(
|
| 761 |
+
self,
|
| 762 |
+
messages: list[Message],
|
| 763 |
+
tools: list[ToolDefinition] | None = None,
|
| 764 |
+
temperature: float = 0.0,
|
| 765 |
+
max_tokens: int = 1024,
|
| 766 |
+
) -> CompletionResponse:
|
| 767 |
+
import httpx as _httpx
|
| 768 |
+
|
| 769 |
+
# Lazy tool-calling detection on first call with tools
|
| 770 |
+
if tools and self._supports_tool_calling is None:
|
| 771 |
+
result = await self._detect_tool_calling()
|
| 772 |
+
if result is not None:
|
| 773 |
+
self._supports_tool_calling = result
|
| 774 |
+
# If None (transient), leave as None so next call retries
|
| 775 |
+
|
| 776 |
+
formatted_messages = format_messages_openai(messages)
|
| 777 |
+
|
| 778 |
+
# Use native tools only when detection confirmed support.
|
| 779 |
+
# When detection is None (transient failure), fall back to prompt-based
|
| 780 |
+
# rather than risk a 400 with native tools on an unsupported endpoint.
|
| 781 |
+
use_native_tools = tools and self._supports_tool_calling is True
|
| 782 |
+
if tools and not use_native_tools:
|
| 783 |
+
tool_prompt = self._tools_as_prompt(tools)
|
| 784 |
+
# Merge tool instructions into existing system message (some models
|
| 785 |
+
# like Mistral reject multiple system messages in their chat template)
|
| 786 |
+
if formatted_messages and formatted_messages[0]["role"] == "system":
|
| 787 |
+
formatted_messages[0]["content"] = (
|
| 788 |
+
tool_prompt + "\n\n" + formatted_messages[0]["content"]
|
| 789 |
+
)
|
| 790 |
+
else:
|
| 791 |
+
formatted_messages = [
|
| 792 |
+
{"role": "system", "content": tool_prompt},
|
| 793 |
+
*formatted_messages,
|
| 794 |
+
]
|
| 795 |
+
# Always sanitize for self-hosted: messages may contain tool/tool_calls
|
| 796 |
+
# from earlier iterations even when current call has tools=None
|
| 797 |
+
formatted_messages = self._sanitize_messages(formatted_messages)
|
| 798 |
+
|
| 799 |
+
payload: dict = {
|
| 800 |
+
"model": self.model,
|
| 801 |
+
"messages": formatted_messages,
|
| 802 |
+
"temperature": temperature,
|
| 803 |
+
"max_tokens": max_tokens,
|
| 804 |
+
}
|
| 805 |
+
if use_native_tools and tools:
|
| 806 |
+
payload["tools"] = self.format_tools(tools)
|
| 807 |
+
payload["tool_choice"] = "auto"
|
| 808 |
+
|
| 809 |
+
retry_cfg = self.config.retry
|
| 810 |
+
start = time.perf_counter()
|
| 811 |
+
|
| 812 |
+
for attempt in range(retry_cfg.max_retries + 1):
|
| 813 |
+
try:
|
| 814 |
+
resp = await self.client.post("/chat/completions", json=payload)
|
| 815 |
+
if resp.status_code == 429:
|
| 816 |
+
if attempt == retry_cfg.max_retries:
|
| 817 |
+
raise ProviderRateLimitError(
|
| 818 |
+
f"Rate limited after {retry_cfg.max_retries} retries"
|
| 819 |
+
)
|
| 820 |
+
wait = min(
|
| 821 |
+
retry_cfg.base_delay * (2 ** attempt), retry_cfg.max_delay
|
| 822 |
+
)
|
| 823 |
+
log.warning(
|
| 824 |
+
"selfhosted_retry",
|
| 825 |
+
attempt=attempt + 1,
|
| 826 |
+
wait_seconds=wait,
|
| 827 |
+
)
|
| 828 |
+
await asyncio.sleep(wait)
|
| 829 |
+
continue
|
| 830 |
+
if resp.status_code >= 400:
|
| 831 |
+
log.error("selfhosted_error", status=resp.status_code, body=resp.text[:500])
|
| 832 |
+
resp.raise_for_status()
|
| 833 |
+
break
|
| 834 |
+
except _httpx.TimeoutException as e:
|
| 835 |
+
raise ProviderTimeoutError(f"Self-hosted timed out: {e}") from e
|
| 836 |
+
|
| 837 |
+
latency_ms = (time.perf_counter() - start) * 1000
|
| 838 |
+
data = resp.json()
|
| 839 |
+
|
| 840 |
+
choice = data["choices"][0]
|
| 841 |
+
content = choice["message"].get("content") or ""
|
| 842 |
+
tool_calls: list[ToolCall] = []
|
| 843 |
+
|
| 844 |
+
if choice["message"].get("tool_calls"):
|
| 845 |
+
# Native tool calling response
|
| 846 |
+
for tc in choice["message"]["tool_calls"]:
|
| 847 |
+
try:
|
| 848 |
+
args = json.loads(tc["function"]["arguments"])
|
| 849 |
+
except (json.JSONDecodeError, KeyError):
|
| 850 |
+
args = {}
|
| 851 |
+
tool_calls.append(
|
| 852 |
+
ToolCall(
|
| 853 |
+
id=tc["id"],
|
| 854 |
+
name=tc["function"]["name"],
|
| 855 |
+
arguments=args,
|
| 856 |
+
)
|
| 857 |
+
)
|
| 858 |
+
elif tools and not self._supports_tool_calling and content:
|
| 859 |
+
# Prompt-based fallback: parse tool calls from text
|
| 860 |
+
tool_calls = self._parse_tool_calls_from_text(content)
|
| 861 |
+
if tool_calls:
|
| 862 |
+
content = "" # tool call replaces text content
|
| 863 |
+
|
| 864 |
+
usage_data = data.get("usage", {})
|
| 865 |
+
input_tokens = usage_data.get("prompt_tokens", 0)
|
| 866 |
+
output_tokens = usage_data.get("completion_tokens", 0)
|
| 867 |
+
cost = (
|
| 868 |
+
input_tokens * self._input_cost + output_tokens * self._output_cost
|
| 869 |
+
) / 1_000_000
|
| 870 |
+
|
| 871 |
+
return CompletionResponse(
|
| 872 |
+
content=content,
|
| 873 |
+
tool_calls=tool_calls,
|
| 874 |
+
usage=TokenUsage(
|
| 875 |
+
input_tokens=input_tokens,
|
| 876 |
+
output_tokens=output_tokens,
|
| 877 |
+
estimated_cost_usd=cost,
|
| 878 |
+
),
|
| 879 |
+
provider="selfhosted",
|
| 880 |
+
model=self.model,
|
| 881 |
+
latency_ms=latency_ms,
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
async def stream_complete(
|
| 885 |
+
self,
|
| 886 |
+
messages: list[Message],
|
| 887 |
+
tools: list[ToolDefinition] | None = None,
|
| 888 |
+
temperature: float = 0.0,
|
| 889 |
+
max_tokens: int = 1024,
|
| 890 |
+
) -> AsyncIterator[str]:
|
| 891 |
+
import httpx as _httpx
|
| 892 |
+
|
| 893 |
+
# Same tool-calling detection/fallback as complete()
|
| 894 |
+
if tools and self._supports_tool_calling is None:
|
| 895 |
+
result = await self._detect_tool_calling()
|
| 896 |
+
if result is not None:
|
| 897 |
+
self._supports_tool_calling = result
|
| 898 |
+
|
| 899 |
+
formatted_messages = format_messages_openai(messages)
|
| 900 |
+
use_native_tools = tools and self._supports_tool_calling is True
|
| 901 |
+
if tools and not use_native_tools:
|
| 902 |
+
tool_prompt = self._tools_as_prompt(tools)
|
| 903 |
+
if formatted_messages and formatted_messages[0]["role"] == "system":
|
| 904 |
+
formatted_messages[0]["content"] = (
|
| 905 |
+
tool_prompt + "\n\n" + formatted_messages[0]["content"]
|
| 906 |
+
)
|
| 907 |
+
else:
|
| 908 |
+
formatted_messages = [
|
| 909 |
+
{"role": "system", "content": tool_prompt},
|
| 910 |
+
*formatted_messages,
|
| 911 |
+
]
|
| 912 |
+
formatted_messages = self._sanitize_messages(formatted_messages)
|
| 913 |
+
|
| 914 |
+
payload: dict = {
|
| 915 |
+
"model": self.model,
|
| 916 |
+
"messages": formatted_messages,
|
| 917 |
+
"temperature": temperature,
|
| 918 |
+
"max_tokens": max_tokens,
|
| 919 |
+
"stream": True,
|
| 920 |
+
}
|
| 921 |
+
if use_native_tools and tools:
|
| 922 |
+
payload["tools"] = self.format_tools(tools)
|
| 923 |
+
payload["tool_choice"] = "auto"
|
| 924 |
+
|
| 925 |
+
retry_cfg = self.config.retry
|
| 926 |
+
for attempt in range(retry_cfg.max_retries + 1):
|
| 927 |
+
try:
|
| 928 |
+
async with self.client.stream(
|
| 929 |
+
"POST", "/chat/completions", json=payload
|
| 930 |
+
) as resp:
|
| 931 |
+
if resp.status_code == 429:
|
| 932 |
+
if attempt == retry_cfg.max_retries:
|
| 933 |
+
raise ProviderRateLimitError(
|
| 934 |
+
f"Rate limited after {retry_cfg.max_retries} retries"
|
| 935 |
+
)
|
| 936 |
+
wait = min(
|
| 937 |
+
retry_cfg.base_delay * (2 ** attempt),
|
| 938 |
+
retry_cfg.max_delay,
|
| 939 |
+
)
|
| 940 |
+
log.warning(
|
| 941 |
+
"selfhosted_stream_retry",
|
| 942 |
+
attempt=attempt + 1,
|
| 943 |
+
wait_seconds=wait,
|
| 944 |
+
)
|
| 945 |
+
await asyncio.sleep(wait)
|
| 946 |
+
continue
|
| 947 |
+
resp.raise_for_status()
|
| 948 |
+
|
| 949 |
+
async for line in resp.aiter_lines():
|
| 950 |
+
line = line.strip()
|
| 951 |
+
if not line or not line.startswith("data: "):
|
| 952 |
+
continue
|
| 953 |
+
data_str = line[len("data: "):]
|
| 954 |
+
if data_str == "[DONE]":
|
| 955 |
+
return
|
| 956 |
+
try:
|
| 957 |
+
chunk_data = json.loads(data_str)
|
| 958 |
+
delta = chunk_data["choices"][0].get("delta", {})
|
| 959 |
+
if delta.get("content"):
|
| 960 |
+
yield delta["content"]
|
| 961 |
+
except (json.JSONDecodeError, KeyError, IndexError):
|
| 962 |
+
continue
|
| 963 |
+
return # success — exit retry loop
|
| 964 |
+
except _httpx.TimeoutException as e:
|
| 965 |
+
raise ProviderTimeoutError(f"Self-hosted timed out: {e}") from e
|
| 966 |
+
|
| 967 |
+
async def health_check(self) -> bool:
|
| 968 |
+
try:
|
| 969 |
+
resp = await self.client.get("/models", timeout=5.0)
|
| 970 |
+
return resp.status_code == 200
|
| 971 |
+
except Exception:
|
| 972 |
+
return False
|
| 973 |
+
|
| 974 |
+
def format_tools(self, tools: list[ToolDefinition]) -> list[dict]:
|
| 975 |
+
return format_tools_openai(tools)
|
| 976 |
+
|
| 977 |
+
|
| 978 |
def create_provider(config: AppConfig | None = None) -> LLMProvider:
|
| 979 |
"""Factory: create provider based on config."""
|
| 980 |
if config is None:
|
|
|
|
| 984 |
return OpenAIProvider(config)
|
| 985 |
elif name == "anthropic":
|
| 986 |
return AnthropicProvider(config)
|
| 987 |
+
elif name == "selfhosted":
|
| 988 |
+
return SelfHostedProvider(config)
|
| 989 |
elif name == "mock":
|
| 990 |
return MockProvider()
|
| 991 |
else:
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""API routes: /ask, /ask/stream, /health, /metrics."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -178,10 +178,10 @@ async def health(request: Request) -> HealthResponse:
|
|
| 178 |
store = request.app.state.store
|
| 179 |
start_time: float = request.app.state.start_time
|
| 180 |
|
| 181 |
-
provider_available =
|
| 182 |
try:
|
| 183 |
-
|
| 184 |
-
|
| 185 |
except Exception:
|
| 186 |
provider_available = False
|
| 187 |
|
|
@@ -205,3 +205,31 @@ async def metrics(request: Request) -> MetricsResponse:
|
|
| 205 |
errors_total=m.errors_total,
|
| 206 |
avg_cost_per_query_usd=m.avg_cost,
|
| 207 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API routes: /ask, /ask/stream, /health, /metrics, /metrics/prometheus."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 178 |
store = request.app.state.store
|
| 179 |
start_time: float = request.app.state.start_time
|
| 180 |
|
| 181 |
+
provider_available = False
|
| 182 |
try:
|
| 183 |
+
provider = request.app.state.orchestrator.provider
|
| 184 |
+
provider_available = await provider.health_check()
|
| 185 |
except Exception:
|
| 186 |
provider_available = False
|
| 187 |
|
|
|
|
| 205 |
errors_total=m.errors_total,
|
| 206 |
avg_cost_per_query_usd=m.avg_cost,
|
| 207 |
)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@router.get("/metrics/prometheus")
|
| 211 |
+
async def metrics_prometheus(request: Request) -> Response:
|
| 212 |
+
"""Prometheus text exposition format for K8s HPA custom metrics."""
|
| 213 |
+
m: MetricsCollector = request.app.state.metrics
|
| 214 |
+
lines = [
|
| 215 |
+
"# HELP agent_bench_requests_total Total requests served.",
|
| 216 |
+
"# TYPE agent_bench_requests_total counter",
|
| 217 |
+
f"agent_bench_requests_total {m.requests_total}",
|
| 218 |
+
"# HELP agent_bench_errors_total Total error responses.",
|
| 219 |
+
"# TYPE agent_bench_errors_total counter",
|
| 220 |
+
f"agent_bench_errors_total {m.errors_total}",
|
| 221 |
+
"# HELP agent_bench_latency_p50_ms 50th percentile latency in ms.",
|
| 222 |
+
"# TYPE agent_bench_latency_p50_ms gauge",
|
| 223 |
+
f"agent_bench_latency_p50_ms {m.percentile(50):.1f}",
|
| 224 |
+
"# HELP agent_bench_latency_p95_ms 95th percentile latency in ms.",
|
| 225 |
+
"# TYPE agent_bench_latency_p95_ms gauge",
|
| 226 |
+
f"agent_bench_latency_p95_ms {m.percentile(95):.1f}",
|
| 227 |
+
"# HELP agent_bench_avg_cost_usd Average cost per query in USD.",
|
| 228 |
+
"# TYPE agent_bench_avg_cost_usd gauge",
|
| 229 |
+
f"agent_bench_avg_cost_usd {m.avg_cost:.6f}",
|
| 230 |
+
"",
|
| 231 |
+
]
|
| 232 |
+
return Response(
|
| 233 |
+
content="\n".join(lines),
|
| 234 |
+
media_type="text/plain; version=0.0.4; charset=utf-8",
|
| 235 |
+
)
|
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent:
|
| 2 |
+
max_iterations: 3
|
| 3 |
+
temperature: 0.0
|
| 4 |
+
|
| 5 |
+
provider:
|
| 6 |
+
default: selfhosted
|
| 7 |
+
selfhosted:
|
| 8 |
+
# base_url left empty: falls back to MODAL_VLLM_URL env var,
|
| 9 |
+
# then to http://localhost:8001/v1 (default for local vLLM via docker-compose.vllm.yml).
|
| 10 |
+
# In Docker Compose, MODAL_VLLM_URL is set to http://vllm:8000/v1.
|
| 11 |
+
model_name: mistralai/Mistral-7B-Instruct-v0.3
|
| 12 |
+
timeout_seconds: 120
|
| 13 |
+
models:
|
| 14 |
+
mistralai/Mistral-7B-Instruct-v0.3:
|
| 15 |
+
input_cost_per_mtok: 0.0
|
| 16 |
+
output_cost_per_mtok: 0.0
|
| 17 |
+
gpt-4o-mini:
|
| 18 |
+
input_cost_per_mtok: 0.15
|
| 19 |
+
output_cost_per_mtok: 0.60
|
| 20 |
+
|
| 21 |
+
rag:
|
| 22 |
+
chunking:
|
| 23 |
+
strategy: recursive
|
| 24 |
+
chunk_size: 512
|
| 25 |
+
chunk_overlap: 64
|
| 26 |
+
retrieval:
|
| 27 |
+
strategy: hybrid
|
| 28 |
+
rrf_k: 60
|
| 29 |
+
candidates_per_system: 10
|
| 30 |
+
top_k: 5
|
| 31 |
+
reranker:
|
| 32 |
+
enabled: true
|
| 33 |
+
model_name: cross-encoder/ms-marco-MiniLM-L-6-v2
|
| 34 |
+
top_k: 5
|
| 35 |
+
refusal_threshold: 0.02
|
| 36 |
+
store_path: .cache/store
|
| 37 |
+
|
| 38 |
+
embedding:
|
| 39 |
+
model: all-MiniLM-L6-v2
|
| 40 |
+
cache_dir: .cache/embeddings
|
| 41 |
+
|
| 42 |
+
retry:
|
| 43 |
+
max_retries: 3
|
| 44 |
+
base_delay: 1.0
|
| 45 |
+
max_delay: 8.0
|
| 46 |
+
|
| 47 |
+
memory:
|
| 48 |
+
enabled: false
|
| 49 |
+
|
| 50 |
+
serving:
|
| 51 |
+
host: 0.0.0.0
|
| 52 |
+
port: 8000
|
| 53 |
+
request_timeout_seconds: 120
|
| 54 |
+
rate_limit_rpm: 10
|
| 55 |
+
|
| 56 |
+
evaluation:
|
| 57 |
+
judge_provider: openai
|
| 58 |
+
golden_dataset: agent_bench/evaluation/datasets/tech_docs_golden.json
|
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent:
|
| 2 |
+
max_iterations: 1
|
| 3 |
+
temperature: 0.0
|
| 4 |
+
|
| 5 |
+
provider:
|
| 6 |
+
default: selfhosted
|
| 7 |
+
selfhosted:
|
| 8 |
+
# base_url and api_key read from MODAL_VLLM_URL / MODAL_AUTH_TOKEN env vars
|
| 9 |
+
model_name: mistralai/Mistral-7B-Instruct-v0.3
|
| 10 |
+
timeout_seconds: 300
|
| 11 |
+
models:
|
| 12 |
+
mistralai/Mistral-7B-Instruct-v0.3:
|
| 13 |
+
input_cost_per_mtok: 0.0
|
| 14 |
+
output_cost_per_mtok: 0.0
|
| 15 |
+
gpt-4o-mini:
|
| 16 |
+
input_cost_per_mtok: 0.15
|
| 17 |
+
output_cost_per_mtok: 0.60
|
| 18 |
+
|
| 19 |
+
rag:
|
| 20 |
+
chunking:
|
| 21 |
+
strategy: recursive
|
| 22 |
+
chunk_size: 512
|
| 23 |
+
chunk_overlap: 64
|
| 24 |
+
retrieval:
|
| 25 |
+
strategy: hybrid
|
| 26 |
+
rrf_k: 60
|
| 27 |
+
candidates_per_system: 10
|
| 28 |
+
top_k: 3
|
| 29 |
+
reranker:
|
| 30 |
+
enabled: true
|
| 31 |
+
model_name: cross-encoder/ms-marco-MiniLM-L-6-v2
|
| 32 |
+
top_k: 3
|
| 33 |
+
refusal_threshold: 0.02
|
| 34 |
+
store_path: .cache/store
|
| 35 |
+
|
| 36 |
+
embedding:
|
| 37 |
+
model: all-MiniLM-L6-v2
|
| 38 |
+
cache_dir: .cache/embeddings
|
| 39 |
+
|
| 40 |
+
retry:
|
| 41 |
+
max_retries: 3
|
| 42 |
+
base_delay: 1.0
|
| 43 |
+
max_delay: 8.0
|
| 44 |
+
|
| 45 |
+
memory:
|
| 46 |
+
enabled: false
|
| 47 |
+
|
| 48 |
+
serving:
|
| 49 |
+
host: 0.0.0.0
|
| 50 |
+
port: 8000
|
| 51 |
+
request_timeout_seconds: 120
|
| 52 |
+
rate_limit_rpm: 10
|
| 53 |
+
|
| 54 |
+
evaluation:
|
| 55 |
+
judge_provider: openai
|
| 56 |
+
golden_dataset: agent_bench/evaluation/datasets/tech_docs_golden.json
|
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Local GPU serving via vLLM + agent-bench API.
|
| 2 |
+
# Requires: nvidia-container-toolkit
|
| 3 |
+
# See modal/serve_vllm.py for serverless alternative.
|
| 4 |
+
#
|
| 5 |
+
# Usage:
|
| 6 |
+
# docker compose -f docker/docker-compose.vllm.yml up --build
|
| 7 |
+
|
| 8 |
+
services:
|
| 9 |
+
vllm:
|
| 10 |
+
image: vllm/vllm-openai:latest
|
| 11 |
+
command:
|
| 12 |
+
- --model=mistralai/Mistral-7B-Instruct-v0.3
|
| 13 |
+
- --max-model-len=4096
|
| 14 |
+
- --dtype=half
|
| 15 |
+
- --gpu-memory-utilization=0.85
|
| 16 |
+
- --host=0.0.0.0
|
| 17 |
+
- --port=8000
|
| 18 |
+
ports:
|
| 19 |
+
- "8001:8000"
|
| 20 |
+
deploy:
|
| 21 |
+
resources:
|
| 22 |
+
reservations:
|
| 23 |
+
devices:
|
| 24 |
+
- driver: nvidia
|
| 25 |
+
count: 1
|
| 26 |
+
capabilities: [gpu]
|
| 27 |
+
volumes:
|
| 28 |
+
- vllm-cache:/root/.cache/huggingface
|
| 29 |
+
healthcheck:
|
| 30 |
+
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
| 31 |
+
interval: 30s
|
| 32 |
+
timeout: 10s
|
| 33 |
+
retries: 5
|
| 34 |
+
start_period: 120s
|
| 35 |
+
|
| 36 |
+
app:
|
| 37 |
+
build:
|
| 38 |
+
context: ..
|
| 39 |
+
dockerfile: docker/Dockerfile
|
| 40 |
+
environment:
|
| 41 |
+
- MODAL_VLLM_URL=http://vllm:8000/v1
|
| 42 |
+
- AGENT_BENCH_ENV=selfhosted_local
|
| 43 |
+
depends_on:
|
| 44 |
+
vllm:
|
| 45 |
+
condition: service_healthy
|
| 46 |
+
ports:
|
| 47 |
+
- "8080:7860"
|
| 48 |
+
|
| 49 |
+
volumes:
|
| 50 |
+
vllm-cache:
|
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kubernetes Local Setup (minikube)
|
| 2 |
+
|
| 3 |
+
## Prerequisites
|
| 4 |
+
|
| 5 |
+
- [minikube](https://minikube.sigs.k8s.io/docs/start/)
|
| 6 |
+
- [Helm](https://helm.sh/docs/intro/install/)
|
| 7 |
+
- Docker
|
| 8 |
+
|
| 9 |
+
## Deploy
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
# Start minikube
|
| 13 |
+
minikube start --cpus=4 --memory=8192
|
| 14 |
+
|
| 15 |
+
# Build image inside minikube's Docker daemon
|
| 16 |
+
eval $(minikube docker-env)
|
| 17 |
+
docker build -t agent-bench:latest -f docker/Dockerfile .
|
| 18 |
+
|
| 19 |
+
# Deploy with dev values
|
| 20 |
+
helm install agent-bench k8s/helm/agent-bench/ \
|
| 21 |
+
-f k8s/helm/agent-bench/values-dev.yaml \
|
| 22 |
+
--set provider.selfhosted.modalEndpoint=$MODAL_VLLM_URL
|
| 23 |
+
|
| 24 |
+
# Verify
|
| 25 |
+
kubectl get pods
|
| 26 |
+
kubectl port-forward svc/agent-bench 8080:8000
|
| 27 |
+
|
| 28 |
+
# Test
|
| 29 |
+
curl http://localhost:8080/health
|
| 30 |
+
curl -X POST http://localhost:8080/ask \
|
| 31 |
+
-H "Content-Type: application/json" \
|
| 32 |
+
-d '{"question": "How do I define a path parameter in FastAPI?"}'
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Teardown
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
helm uninstall agent-bench
|
| 39 |
+
minikube stop
|
| 40 |
+
```
|
|
@@ -1,64 +1,85 @@
|
|
| 1 |
-
# Provider Comparison
|
| 2 |
|
| 3 |
Evaluated on the same 27-question golden dataset over 16 FastAPI documentation files.
|
| 4 |
-
|
| 5 |
-
|
| 6 |
|
| 7 |
-
**The
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
##
|
| 10 |
|
| 11 |
-
| Provider | Model |
|
| 12 |
-
|----------|-------|---------|--------------------------------------|
|
| 13 |
-
| OpenAI | gpt-4o-mini |
|
| 14 |
-
| Anthropic | claude-haiku-4-5 |
|
|
|
|
| 15 |
|
| 16 |
-
##
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
| Keyword Hit Rate | 0.89 | **0.92** | +0.03 |
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
##
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
- **Citation format**: Both providers follow the `[source: filename.md]` citation format
|
| 45 |
-
specified in the system prompt.
|
| 46 |
-
- **Answer quality**: Haiku tends to produce more structured answers (numbered lists,
|
| 47 |
-
code examples) while gpt-4o-mini is more concise. Both are accurate.
|
| 48 |
|
| 49 |
## How to Reproduce
|
| 50 |
|
| 51 |
```bash
|
| 52 |
-
# OpenAI evaluation
|
| 53 |
OPENAI_API_KEY=sk-... python scripts/evaluate.py --mode deterministic
|
| 54 |
|
| 55 |
# Anthropic evaluation
|
| 56 |
ANTHROPIC_API_KEY=sk-ant-... python scripts/evaluate.py --config configs/anthropic.yaml --mode deterministic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
```
|
| 58 |
|
| 59 |
## Takeaway
|
| 60 |
|
| 61 |
-
The provider abstraction works as designed — switching
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Provider Comparison: API vs Self-Hosted
|
| 2 |
|
| 3 |
Evaluated on the same 27-question golden dataset over 16 FastAPI documentation files.
|
| 4 |
+
All providers use hybrid retrieval (FAISS + BM25 + RRF), cross-encoder reranking,
|
| 5 |
+
grounded refusal threshold, and identical system prompt.
|
| 6 |
|
| 7 |
+
**Note:** The self-hosted config differs from API configs in two ways to accommodate
|
| 8 |
+
the 7B model's smaller context window (8192 tokens) and weaker instruction following:
|
| 9 |
+
`max_iterations=1` (vs 3) and `top_k=3` (vs 5). This means the self-hosted row is
|
| 10 |
+
**not a controlled comparison** — it reflects realistic operating constraints for a
|
| 11 |
+
7B model, not an apples-to-apples provider swap. The API providers are directly
|
| 12 |
+
comparable to each other.
|
| 13 |
|
| 14 |
+
## Results
|
| 15 |
|
| 16 |
+
| Provider | Model | Iterations | top_k | P@5 | R@5 | Citation Acc | Latency p50 (ms) | Cost/query |
|
| 17 |
+
|----------|-------|-----------|-------|-----|-----|--------------|-------------------|------------|
|
| 18 |
+
| OpenAI (API) | gpt-4o-mini | 3 | 5 | 0.70 | 0.83 | 1.00 | 4,690 | $0.0004 |
|
| 19 |
+
| Anthropic (API) | claude-haiku-4-5 | 3 | 5 | 0.74 | 0.84 | 1.00 | 5,120 | $0.0007 |
|
| 20 |
+
| Self-hosted (Modal) | Mistral-7B-Instruct-v0.3 | 1 | 3 | 0.05 | 0.05 | 0.14 | 6,709 | $0.0031 |
|
| 21 |
|
| 22 |
+
## Analysis
|
| 23 |
|
| 24 |
+
**Retrieval quality:** API models (gpt-4o-mini, claude-haiku) generate substantially better
|
| 25 |
+
search queries than Mistral-7B, reflected in P@5 (0.70-0.74 vs 0.05). The 7B model struggles
|
| 26 |
+
with prompt-based tool calling — it often produces malformed JSON or calls tools with
|
| 27 |
+
poor queries, degrading retrieval quality.
|
|
|
|
| 28 |
|
| 29 |
+
**Citation accuracy:** Both API providers achieve 1.00 citation accuracy (zero hallucinated
|
| 30 |
+
citations). Mistral-7B manages 0.14, frequently omitting or fabricating source references.
|
| 31 |
+
This is a known limitation of smaller models on instruction-following tasks.
|
| 32 |
|
| 33 |
+
**Latency:** Self-hosted latency (6,709ms p50) is higher than API providers due to the
|
| 34 |
+
proxy overhead and smaller model generating more tokens before reaching a final answer.
|
| 35 |
+
Cold start adds ~90s on first request (model download + GPU load).
|
| 36 |
|
| 37 |
+
**Cost:** Self-hosted cost ($0.0031/query) is computed from GPU-seconds
|
| 38 |
+
(latency x Modal A10G rate of $0.000361/sec). This is higher per-query than API providers
|
| 39 |
+
at low volume, but the cost model is fundamentally different — GPU cost scales with
|
| 40 |
+
compute time, not token count.
|
| 41 |
|
| 42 |
+
**Tool calling:** Mistral-7B does not support native OpenAI-format tool calling in vLLM
|
| 43 |
+
0.6.6. The provider falls back to prompt-based tool selection (injecting tool descriptions
|
| 44 |
+
into the system prompt and parsing JSON from the model's text output). This works but is
|
| 45 |
+
unreliable — a legitimate benchmark finding, not a failure.
|
| 46 |
|
| 47 |
+
## Infrastructure
|
| 48 |
|
| 49 |
+
| Config | Cold start | Warm latency p50 | GPU | Infra |
|
| 50 |
+
|--------|-----------|-------------------|-----|-------|
|
| 51 |
+
| OpenAI | N/A | 4,690 ms | N/A | Managed API |
|
| 52 |
+
| Anthropic | N/A | 5,120 ms | N/A | Managed API |
|
| 53 |
+
| Self-hosted (Modal) | ~90s | 6,709 ms | A10G (24GB) | Serverless GPU |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
## How to Reproduce
|
| 56 |
|
| 57 |
```bash
|
| 58 |
+
# OpenAI evaluation
|
| 59 |
OPENAI_API_KEY=sk-... python scripts/evaluate.py --mode deterministic
|
| 60 |
|
| 61 |
# Anthropic evaluation
|
| 62 |
ANTHROPIC_API_KEY=sk-ant-... python scripts/evaluate.py --config configs/anthropic.yaml --mode deterministic
|
| 63 |
+
|
| 64 |
+
# Self-hosted evaluation (requires Modal deployment + HF secret)
|
| 65 |
+
pip install -e ".[modal]"
|
| 66 |
+
modal secret create huggingface-secret HF_TOKEN=hf_...
|
| 67 |
+
modal deploy modal/serve_vllm.py
|
| 68 |
+
export MODAL_VLLM_URL=https://your--agent-bench-vllm-serve.modal.run/v1
|
| 69 |
+
python scripts/evaluate.py --config configs/selfhosted_modal.yaml --mode deterministic
|
| 70 |
+
|
| 71 |
+
# All providers at once
|
| 72 |
+
make benchmark-all
|
| 73 |
```
|
| 74 |
|
| 75 |
## Takeaway
|
| 76 |
|
| 77 |
+
The provider abstraction works as designed — switching providers is a single config change.
|
| 78 |
+
API models dominate on quality metrics, but the self-hosted path demonstrates end-to-end
|
| 79 |
+
inference serving: vLLM on Modal (serverless A10G), OpenAI-compatible endpoint, identical
|
| 80 |
+
evaluation harness. The quality gap is expected for a 7B model on RAG tasks and would
|
| 81 |
+
narrow with larger self-hosted models (e.g., Mixtral-8x7B, Llama-3-70B).
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
Generated by `modal/run_benchmark.py`
|
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: v2
|
| 2 |
+
name: agent-bench
|
| 3 |
+
description: Agentic RAG system with self-hosted LLM support
|
| 4 |
+
type: application
|
| 5 |
+
version: 0.1.0
|
| 6 |
+
appVersion: "0.1.0"
|
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{{/*
|
| 2 |
+
Expand the name of the chart.
|
| 3 |
+
*/}}
|
| 4 |
+
{{- define "agent-bench.name" -}}
|
| 5 |
+
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
|
| 6 |
+
{{- end }}
|
| 7 |
+
|
| 8 |
+
{{/*
|
| 9 |
+
Create a default fully qualified app name.
|
| 10 |
+
*/}}
|
| 11 |
+
{{- define "agent-bench.fullname" -}}
|
| 12 |
+
{{- $name := default .Chart.Name .Values.nameOverride }}
|
| 13 |
+
{{- if .Values.fullnameOverride }}
|
| 14 |
+
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
|
| 15 |
+
{{- else }}
|
| 16 |
+
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
|
| 17 |
+
{{- end }}
|
| 18 |
+
{{- end }}
|
| 19 |
+
|
| 20 |
+
{{/*
|
| 21 |
+
Common labels
|
| 22 |
+
*/}}
|
| 23 |
+
{{- define "agent-bench.labels" -}}
|
| 24 |
+
helm.sh/chart: {{ .Chart.Name }}-{{ .Chart.Version }}
|
| 25 |
+
{{ include "agent-bench.selectorLabels" . }}
|
| 26 |
+
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
| 27 |
+
{{- end }}
|
| 28 |
+
|
| 29 |
+
{{/*
|
| 30 |
+
Selector labels
|
| 31 |
+
*/}}
|
| 32 |
+
{{- define "agent-bench.selectorLabels" -}}
|
| 33 |
+
app.kubernetes.io/name: {{ include "agent-bench.name" . }}
|
| 34 |
+
app.kubernetes.io/instance: {{ .Release.Name }}
|
| 35 |
+
{{- end }}
|
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: v1
|
| 2 |
+
kind: ConfigMap
|
| 3 |
+
metadata:
|
| 4 |
+
name: {{ include "agent-bench.fullname" . }}-config
|
| 5 |
+
labels:
|
| 6 |
+
{{- include "agent-bench.labels" . | nindent 4 }}
|
| 7 |
+
data:
|
| 8 |
+
{{- if eq .Values.provider.type "selfhosted" }}
|
| 9 |
+
AGENT_BENCH_ENV: "selfhosted_modal"
|
| 10 |
+
SELFHOSTED_MODEL: {{ .Values.provider.selfhosted.model | quote }}
|
| 11 |
+
{{- else if eq .Values.provider.type "openai" }}
|
| 12 |
+
AGENT_BENCH_ENV: "default"
|
| 13 |
+
{{- else if eq .Values.provider.type "anthropic" }}
|
| 14 |
+
AGENT_BENCH_ENV: "anthropic"
|
| 15 |
+
{{- end }}
|
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: apps/v1
|
| 2 |
+
kind: Deployment
|
| 3 |
+
metadata:
|
| 4 |
+
name: {{ include "agent-bench.fullname" . }}
|
| 5 |
+
labels:
|
| 6 |
+
{{- include "agent-bench.labels" . | nindent 4 }}
|
| 7 |
+
spec:
|
| 8 |
+
{{- if not .Values.autoscaling.enabled }}
|
| 9 |
+
replicas: {{ .Values.replicaCount }}
|
| 10 |
+
{{- end }}
|
| 11 |
+
selector:
|
| 12 |
+
matchLabels:
|
| 13 |
+
{{- include "agent-bench.selectorLabels" . | nindent 6 }}
|
| 14 |
+
template:
|
| 15 |
+
metadata:
|
| 16 |
+
labels:
|
| 17 |
+
{{- include "agent-bench.selectorLabels" . | nindent 8 }}
|
| 18 |
+
spec:
|
| 19 |
+
containers:
|
| 20 |
+
- name: api
|
| 21 |
+
image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
|
| 22 |
+
imagePullPolicy: {{ .Values.image.pullPolicy }}
|
| 23 |
+
ports:
|
| 24 |
+
- name: http
|
| 25 |
+
containerPort: 7860
|
| 26 |
+
protocol: TCP
|
| 27 |
+
envFrom:
|
| 28 |
+
- configMapRef:
|
| 29 |
+
name: {{ include "agent-bench.fullname" . }}-config
|
| 30 |
+
- secretRef:
|
| 31 |
+
name: {{ include "agent-bench.fullname" . }}-secrets
|
| 32 |
+
livenessProbe:
|
| 33 |
+
httpGet:
|
| 34 |
+
path: {{ .Values.probes.liveness.path }}
|
| 35 |
+
port: 7860
|
| 36 |
+
initialDelaySeconds: {{ .Values.probes.liveness.initialDelaySeconds }}
|
| 37 |
+
periodSeconds: {{ .Values.probes.liveness.periodSeconds }}
|
| 38 |
+
readinessProbe:
|
| 39 |
+
httpGet:
|
| 40 |
+
path: {{ .Values.probes.readiness.path }}
|
| 41 |
+
port: 7860
|
| 42 |
+
initialDelaySeconds: {{ .Values.probes.readiness.initialDelaySeconds }}
|
| 43 |
+
periodSeconds: {{ .Values.probes.readiness.periodSeconds }}
|
| 44 |
+
resources:
|
| 45 |
+
{{- toYaml .Values.resources | nindent 12 }}
|
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{{- if .Values.autoscaling.enabled }}
|
| 2 |
+
apiVersion: autoscaling/v2
|
| 3 |
+
kind: HorizontalPodAutoscaler
|
| 4 |
+
metadata:
|
| 5 |
+
name: {{ include "agent-bench.fullname" . }}
|
| 6 |
+
labels:
|
| 7 |
+
{{- include "agent-bench.labels" . | nindent 4 }}
|
| 8 |
+
spec:
|
| 9 |
+
scaleTargetRef:
|
| 10 |
+
apiVersion: apps/v1
|
| 11 |
+
kind: Deployment
|
| 12 |
+
name: {{ include "agent-bench.fullname" . }}
|
| 13 |
+
minReplicas: {{ .Values.autoscaling.minReplicas }}
|
| 14 |
+
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
|
| 15 |
+
metrics:
|
| 16 |
+
- type: Resource
|
| 17 |
+
resource:
|
| 18 |
+
name: cpu
|
| 19 |
+
target:
|
| 20 |
+
type: Utilization
|
| 21 |
+
averageUtilization: {{ .Values.autoscaling.targetCPUUtilization }}
|
| 22 |
+
{{- end }}
|
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: v1
|
| 2 |
+
kind: Secret
|
| 3 |
+
metadata:
|
| 4 |
+
name: {{ include "agent-bench.fullname" . }}-secrets
|
| 5 |
+
labels:
|
| 6 |
+
{{- include "agent-bench.labels" . | nindent 4 }}
|
| 7 |
+
type: Opaque
|
| 8 |
+
stringData:
|
| 9 |
+
MODAL_VLLM_URL: {{ .Values.provider.selfhosted.modalEndpoint | quote }}
|
| 10 |
+
MODAL_AUTH_TOKEN: {{ .Values.provider.selfhosted.modalAuthToken | quote }}
|
| 11 |
+
OPENAI_API_KEY: {{ .Values.provider.openaiApiKey | quote }}
|
| 12 |
+
ANTHROPIC_API_KEY: {{ .Values.provider.anthropicApiKey | quote }}
|
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: v1
|
| 2 |
+
kind: Service
|
| 3 |
+
metadata:
|
| 4 |
+
name: {{ include "agent-bench.fullname" . }}
|
| 5 |
+
labels:
|
| 6 |
+
{{- include "agent-bench.labels" . | nindent 4 }}
|
| 7 |
+
spec:
|
| 8 |
+
type: {{ .Values.service.type }}
|
| 9 |
+
ports:
|
| 10 |
+
- port: {{ .Values.service.port }}
|
| 11 |
+
targetPort: 7860
|
| 12 |
+
protocol: TCP
|
| 13 |
+
name: http
|
| 14 |
+
selector:
|
| 15 |
+
{{- include "agent-bench.selectorLabels" . | nindent 4 }}
|
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
replicaCount: 1
|
| 2 |
+
|
| 3 |
+
autoscaling:
|
| 4 |
+
enabled: false
|
| 5 |
+
|
| 6 |
+
resources:
|
| 7 |
+
requests:
|
| 8 |
+
cpu: 250m
|
| 9 |
+
memory: 512Mi
|
| 10 |
+
limits:
|
| 11 |
+
cpu: 1000m
|
| 12 |
+
memory: 2Gi
|
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
replicaCount: 3
|
| 2 |
+
|
| 3 |
+
autoscaling:
|
| 4 |
+
enabled: true
|
| 5 |
+
minReplicas: 2
|
| 6 |
+
maxReplicas: 8
|
| 7 |
+
targetCPUUtilization: 70
|
| 8 |
+
|
| 9 |
+
resources:
|
| 10 |
+
requests:
|
| 11 |
+
cpu: 500m
|
| 12 |
+
memory: 1Gi
|
| 13 |
+
limits:
|
| 14 |
+
cpu: 2000m
|
| 15 |
+
memory: 4Gi
|
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
replicaCount: 2
|
| 2 |
+
|
| 3 |
+
image:
|
| 4 |
+
repository: agent-bench
|
| 5 |
+
tag: latest
|
| 6 |
+
pullPolicy: IfNotPresent
|
| 7 |
+
|
| 8 |
+
service:
|
| 9 |
+
type: ClusterIP
|
| 10 |
+
port: 8000
|
| 11 |
+
|
| 12 |
+
provider:
|
| 13 |
+
type: selfhosted
|
| 14 |
+
selfhosted:
|
| 15 |
+
model: mistralai/Mistral-7B-Instruct-v0.3
|
| 16 |
+
modalEndpoint: ""
|
| 17 |
+
modalAuthToken: ""
|
| 18 |
+
openaiApiKey: ""
|
| 19 |
+
anthropicApiKey: ""
|
| 20 |
+
|
| 21 |
+
autoscaling:
|
| 22 |
+
enabled: true
|
| 23 |
+
minReplicas: 2
|
| 24 |
+
maxReplicas: 8
|
| 25 |
+
targetCPUUtilization: 70
|
| 26 |
+
|
| 27 |
+
resources:
|
| 28 |
+
requests:
|
| 29 |
+
cpu: 500m
|
| 30 |
+
memory: 1Gi
|
| 31 |
+
limits:
|
| 32 |
+
cpu: 2000m
|
| 33 |
+
memory: 4Gi
|
| 34 |
+
|
| 35 |
+
probes:
|
| 36 |
+
liveness:
|
| 37 |
+
path: /health
|
| 38 |
+
initialDelaySeconds: 10
|
| 39 |
+
periodSeconds: 30
|
| 40 |
+
readiness:
|
| 41 |
+
path: /health
|
| 42 |
+
initialDelaySeconds: 5
|
| 43 |
+
periodSeconds: 10
|
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared constants for Modal deployments."""
|
| 2 |
+
|
| 3 |
+
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 4 |
+
GPU_TYPE = "a10g"
|
| 5 |
+
VLLM_MAX_MODEL_LEN = 8192
|
| 6 |
+
VLLM_DTYPE = "half"
|
| 7 |
+
VLLM_GPU_MEMORY_UTILIZATION = 0.85
|
| 8 |
+
|
| 9 |
+
# Cost tracking (for provider comparison report)
|
| 10 |
+
# Modal A10G: ~$0.000361/sec (~$1.30/hr)
|
| 11 |
+
MODAL_A10G_COST_PER_SEC = 0.000361
|
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run the 27-question benchmark against all provider configurations.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
# Run against a deployed Modal endpoint
|
| 5 |
+
python modal/run_benchmark.py --base-url https://...modal.run/v1
|
| 6 |
+
|
| 7 |
+
# Optionally restrict to specific providers
|
| 8 |
+
python modal/run_benchmark.py --base-url https://...modal.run/v1 --only selfhosted_modal
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import statistics
|
| 17 |
+
import subprocess
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def run_eval(config_path: str, env: dict[str, str]) -> list[dict] | None:
|
| 25 |
+
"""Run scripts/evaluate.py and return the list of EvalResult dicts."""
|
| 26 |
+
output_path = f".cache/eval_{Path(config_path).stem}.json"
|
| 27 |
+
result = subprocess.run(
|
| 28 |
+
[
|
| 29 |
+
sys.executable,
|
| 30 |
+
"scripts/evaluate.py",
|
| 31 |
+
"--config",
|
| 32 |
+
config_path,
|
| 33 |
+
"--mode",
|
| 34 |
+
"deterministic",
|
| 35 |
+
"--output",
|
| 36 |
+
output_path,
|
| 37 |
+
],
|
| 38 |
+
capture_output=True,
|
| 39 |
+
text=True,
|
| 40 |
+
env=env,
|
| 41 |
+
cwd=str(PROJECT_ROOT),
|
| 42 |
+
)
|
| 43 |
+
if result.returncode != 0:
|
| 44 |
+
print(f"FAILED: {config_path}\n{result.stderr}", file=sys.stderr)
|
| 45 |
+
return None
|
| 46 |
+
output_file = PROJECT_ROOT / output_path
|
| 47 |
+
if not output_file.exists():
|
| 48 |
+
print(f"FAILED: output not created: {output_path}", file=sys.stderr)
|
| 49 |
+
return None
|
| 50 |
+
with open(output_file) as f:
|
| 51 |
+
data = json.load(f)
|
| 52 |
+
if not isinstance(data, list):
|
| 53 |
+
print(f"FAILED: expected list, got {type(data).__name__}", file=sys.stderr)
|
| 54 |
+
return None
|
| 55 |
+
return data
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def aggregate(results: list[dict], provider_name: str = "") -> dict:
|
| 59 |
+
"""Compute aggregate metrics from a list of EvalResult dicts.
|
| 60 |
+
|
| 61 |
+
For selfhosted providers, cost is computed from GPU-seconds (latency *
|
| 62 |
+
MODAL_A10G_COST_PER_SEC) rather than token pricing, which is zero.
|
| 63 |
+
"""
|
| 64 |
+
from common import MODAL_A10G_COST_PER_SEC
|
| 65 |
+
|
| 66 |
+
positive = [r for r in results if r.get("category") != "out_of_scope"]
|
| 67 |
+
if not positive:
|
| 68 |
+
return {}
|
| 69 |
+
|
| 70 |
+
# For self-hosted, derive cost from GPU time; for API providers, use token cost
|
| 71 |
+
is_selfhosted = "selfhosted" in provider_name
|
| 72 |
+
if is_selfhosted:
|
| 73 |
+
avg_cost = statistics.mean(
|
| 74 |
+
(r["latency_ms"] / 1000.0) * MODAL_A10G_COST_PER_SEC
|
| 75 |
+
for r in positive
|
| 76 |
+
)
|
| 77 |
+
else:
|
| 78 |
+
avg_cost = statistics.mean(
|
| 79 |
+
r.get("tokens_used", {}).get("estimated_cost_usd", 0.0)
|
| 80 |
+
for r in positive
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return {
|
| 84 |
+
"retrieval_precision": statistics.mean(
|
| 85 |
+
r["retrieval_precision"] for r in positive
|
| 86 |
+
),
|
| 87 |
+
"retrieval_recall": statistics.mean(
|
| 88 |
+
r["retrieval_recall"] for r in positive
|
| 89 |
+
),
|
| 90 |
+
"citation_accuracy": statistics.mean(
|
| 91 |
+
r["citation_accuracy"] for r in positive
|
| 92 |
+
),
|
| 93 |
+
"latency_p50_ms": statistics.median(
|
| 94 |
+
r["latency_ms"] for r in positive
|
| 95 |
+
),
|
| 96 |
+
"avg_cost_usd": avg_cost,
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def generate_report(
|
| 101 |
+
all_results: dict[str, list[dict] | None], output_path: str
|
| 102 |
+
) -> None:
|
| 103 |
+
"""Generate docs/provider_comparison.md from benchmark results."""
|
| 104 |
+
lines = [
|
| 105 |
+
"# Provider Comparison: API vs Self-Hosted",
|
| 106 |
+
"",
|
| 107 |
+
"Benchmark: 27-question golden dataset "
|
| 108 |
+
"(19 retrieval, 3 calculation, 5 out-of-scope).",
|
| 109 |
+
"",
|
| 110 |
+
"| Provider | P@5 | R@5 | Citation Acc | Latency p50 (ms) | Cost/query |",
|
| 111 |
+
"|----------|-----|-----|--------------|-------------------|------------|",
|
| 112 |
+
]
|
| 113 |
+
for name, results in all_results.items():
|
| 114 |
+
if results is None:
|
| 115 |
+
lines.append(f"| {name} | ERROR | - | - | - | - |")
|
| 116 |
+
continue
|
| 117 |
+
agg = aggregate(results, provider_name=name)
|
| 118 |
+
if not agg:
|
| 119 |
+
lines.append(f"| {name} | NO DATA | - | - | - | - |")
|
| 120 |
+
continue
|
| 121 |
+
lines.append(
|
| 122 |
+
f"| {name} "
|
| 123 |
+
f"| {agg['retrieval_precision']:.2f} "
|
| 124 |
+
f"| {agg['retrieval_recall']:.2f} "
|
| 125 |
+
f"| {agg['citation_accuracy']:.2f} "
|
| 126 |
+
f"| {agg['latency_p50_ms']:.0f} "
|
| 127 |
+
f"| ${agg['avg_cost_usd']:.4f} |"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
lines.extend(["", "---", "", "Generated by `modal/run_benchmark.py`"])
|
| 131 |
+
|
| 132 |
+
out = PROJECT_ROOT / output_path
|
| 133 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 134 |
+
out.write_text("\n".join(lines))
|
| 135 |
+
print(f"Report written to {output_path}")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def main() -> None:
|
| 139 |
+
parser = argparse.ArgumentParser(description="Run provider comparison benchmark")
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--base-url",
|
| 142 |
+
help="Modal vLLM endpoint URL (required when running selfhosted_modal)",
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--only",
|
| 146 |
+
help="Run only this provider (e.g., selfhosted_modal, openai, anthropic)",
|
| 147 |
+
)
|
| 148 |
+
args = parser.parse_args()
|
| 149 |
+
|
| 150 |
+
configs = [
|
| 151 |
+
("openai", "configs/default.yaml"),
|
| 152 |
+
("anthropic", "configs/anthropic.yaml"),
|
| 153 |
+
("selfhosted_modal", "configs/selfhosted_modal.yaml"),
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
if args.only:
|
| 157 |
+
configs = [(n, p) for n, p in configs if n == args.only]
|
| 158 |
+
if not configs:
|
| 159 |
+
parser.error(f"Unknown provider: {args.only}")
|
| 160 |
+
|
| 161 |
+
needs_base_url = any(n == "selfhosted_modal" for n, _ in configs)
|
| 162 |
+
if needs_base_url and not args.base_url:
|
| 163 |
+
parser.error("--base-url is required when running selfhosted_modal")
|
| 164 |
+
|
| 165 |
+
all_results: dict[str, list[dict] | None] = {}
|
| 166 |
+
for name, config_path in configs:
|
| 167 |
+
print(f"\n--- Running: {name} ({config_path}) ---")
|
| 168 |
+
env = os.environ.copy()
|
| 169 |
+
if name == "selfhosted_modal" and args.base_url:
|
| 170 |
+
env["MODAL_VLLM_URL"] = args.base_url
|
| 171 |
+
results = run_eval(config_path, env)
|
| 172 |
+
if results is None:
|
| 173 |
+
print(f"\nABORTING: {name} failed, stopping benchmark run.",
|
| 174 |
+
file=sys.stderr)
|
| 175 |
+
sys.exit(1)
|
| 176 |
+
all_results[name] = results
|
| 177 |
+
|
| 178 |
+
generate_report(all_results, "docs/provider_comparison.md")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
main()
|
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deploy vLLM on Modal as an OpenAI-compatible endpoint.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
modal deploy modal/serve_vllm.py # Deploy (stays running, prints URL)
|
| 5 |
+
modal serve modal/serve_vllm.py # Dev mode (auto-redeploys on change)
|
| 6 |
+
|
| 7 |
+
The printed URL is the MODAL_VLLM_URL for SelfHostedProvider:
|
| 8 |
+
export MODAL_VLLM_URL=https://<your-workspace>--agent-bench-vllm-serve.modal.run/v1
|
| 9 |
+
|
| 10 |
+
Note: The vLLM server integration pattern changes between vLLM releases.
|
| 11 |
+
If deployment fails, check Modal's vLLM example for the current API:
|
| 12 |
+
https://modal.com/docs/examples/vllm_inference
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import modal
|
| 16 |
+
|
| 17 |
+
# Inlined from common.py — Modal containers don't auto-include sibling modules
|
| 18 |
+
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 19 |
+
VLLM_MAX_MODEL_LEN = 8192
|
| 20 |
+
VLLM_DTYPE = "half"
|
| 21 |
+
VLLM_GPU_MEMORY_UTILIZATION = 0.85
|
| 22 |
+
|
| 23 |
+
MODELS_DIR = "/models"
|
| 24 |
+
VLLM_PORT = 8000
|
| 25 |
+
VLLM_READY_TIMEOUT = 600 # seconds to wait for vLLM to become ready (download + load)
|
| 26 |
+
|
| 27 |
+
vllm_image = (
|
| 28 |
+
modal.Image.debian_slim(python_version="3.11")
|
| 29 |
+
.pip_install(
|
| 30 |
+
"vllm==0.6.6.post1",
|
| 31 |
+
"transformers==4.47.0",
|
| 32 |
+
"huggingface_hub[hf_transfer]<1.0",
|
| 33 |
+
"httpx",
|
| 34 |
+
)
|
| 35 |
+
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
app = modal.App("agent-bench-vllm")
|
| 39 |
+
model_volume = modal.Volume.from_name("vllm-model-cache", create_if_missing=True)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@app.function(
|
| 43 |
+
image=vllm_image,
|
| 44 |
+
gpu="a10g",
|
| 45 |
+
scaledown_window=600,
|
| 46 |
+
timeout=900,
|
| 47 |
+
volumes={MODELS_DIR: model_volume},
|
| 48 |
+
secrets=[modal.Secret.from_name("huggingface-secret")],
|
| 49 |
+
)
|
| 50 |
+
@modal.asgi_app()
|
| 51 |
+
def serve():
|
| 52 |
+
"""Serve vLLM with OpenAI-compatible API.
|
| 53 |
+
|
| 54 |
+
Exposes /v1/chat/completions and /health.
|
| 55 |
+
Waits for the vLLM subprocess to be ready before accepting requests.
|
| 56 |
+
"""
|
| 57 |
+
import subprocess
|
| 58 |
+
import time
|
| 59 |
+
|
| 60 |
+
import httpx
|
| 61 |
+
from fastapi import FastAPI, Request
|
| 62 |
+
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
| 63 |
+
|
| 64 |
+
vllm_process = subprocess.Popen(
|
| 65 |
+
[
|
| 66 |
+
"python", "-m", "vllm.entrypoints.openai.api_server",
|
| 67 |
+
"--model", MODEL_NAME,
|
| 68 |
+
"--download-dir", MODELS_DIR,
|
| 69 |
+
"--dtype", VLLM_DTYPE,
|
| 70 |
+
"--max-model-len", str(VLLM_MAX_MODEL_LEN),
|
| 71 |
+
"--gpu-memory-utilization", str(VLLM_GPU_MEMORY_UTILIZATION),
|
| 72 |
+
"--host", "0.0.0.0",
|
| 73 |
+
"--port", str(VLLM_PORT),
|
| 74 |
+
],
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Wait for vLLM to be ready before accepting proxied requests
|
| 78 |
+
base = f"http://localhost:{VLLM_PORT}"
|
| 79 |
+
deadline = time.monotonic() + VLLM_READY_TIMEOUT
|
| 80 |
+
while time.monotonic() < deadline:
|
| 81 |
+
try:
|
| 82 |
+
r = httpx.get(f"{base}/health", timeout=2.0)
|
| 83 |
+
if r.status_code == 200:
|
| 84 |
+
break
|
| 85 |
+
except httpx.HTTPError:
|
| 86 |
+
pass
|
| 87 |
+
if vllm_process.poll() is not None:
|
| 88 |
+
raise RuntimeError(
|
| 89 |
+
f"vLLM process exited with code {vllm_process.returncode}"
|
| 90 |
+
)
|
| 91 |
+
time.sleep(2)
|
| 92 |
+
else:
|
| 93 |
+
vllm_process.terminate()
|
| 94 |
+
raise TimeoutError(
|
| 95 |
+
f"vLLM did not become ready within {VLLM_READY_TIMEOUT}s"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
proxy_app = FastAPI()
|
| 99 |
+
client = httpx.AsyncClient(base_url=base, timeout=120.0)
|
| 100 |
+
|
| 101 |
+
@proxy_app.api_route("/{path:path}", methods=["GET", "POST"])
|
| 102 |
+
async def proxy(path: str, request: Request):
|
| 103 |
+
"""Proxy all requests to the vLLM subprocess."""
|
| 104 |
+
import traceback as _tb
|
| 105 |
+
try:
|
| 106 |
+
return await _proxy_inner(path, request)
|
| 107 |
+
except Exception as exc:
|
| 108 |
+
_tb.print_exc()
|
| 109 |
+
return JSONResponse(
|
| 110 |
+
content={"error": str(exc), "type": type(exc).__name__},
|
| 111 |
+
status_code=502,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
async def _proxy_inner(path: str, request: Request):
|
| 115 |
+
url = f"/{path}"
|
| 116 |
+
body = await request.body()
|
| 117 |
+
headers = {
|
| 118 |
+
k: v for k, v in request.headers.items()
|
| 119 |
+
if k.lower() not in ("host", "content-length")
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Detect streaming: check body for "stream": true (httpx sends
|
| 123 |
+
# Accept: */*, not text/event-stream, so header check is unreliable)
|
| 124 |
+
is_streaming = False
|
| 125 |
+
if body:
|
| 126 |
+
try:
|
| 127 |
+
import json as _json
|
| 128 |
+
is_streaming = _json.loads(body).get("stream", False)
|
| 129 |
+
except (ValueError, AttributeError):
|
| 130 |
+
pass
|
| 131 |
+
if not is_streaming:
|
| 132 |
+
is_streaming = request.headers.get("accept") == "text/event-stream"
|
| 133 |
+
|
| 134 |
+
if is_streaming:
|
| 135 |
+
req = client.build_request(
|
| 136 |
+
request.method, url, content=body, headers=headers
|
| 137 |
+
)
|
| 138 |
+
upstream = await client.send(req, stream=True)
|
| 139 |
+
|
| 140 |
+
if upstream.status_code != 200:
|
| 141 |
+
error_body = await upstream.aread()
|
| 142 |
+
await upstream.aclose()
|
| 143 |
+
return Response(
|
| 144 |
+
content=error_body,
|
| 145 |
+
status_code=upstream.status_code,
|
| 146 |
+
media_type="application/json",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
async def stream():
|
| 150 |
+
try:
|
| 151 |
+
async for chunk in upstream.aiter_bytes():
|
| 152 |
+
yield chunk
|
| 153 |
+
finally:
|
| 154 |
+
await upstream.aclose()
|
| 155 |
+
|
| 156 |
+
return StreamingResponse(
|
| 157 |
+
stream(),
|
| 158 |
+
status_code=upstream.status_code,
|
| 159 |
+
media_type="text/event-stream",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
resp = await client.request(
|
| 163 |
+
request.method, url, content=body, headers=headers
|
| 164 |
+
)
|
| 165 |
+
# Not all endpoints return JSON (e.g. /health returns empty 200)
|
| 166 |
+
try:
|
| 167 |
+
content = resp.json()
|
| 168 |
+
except Exception:
|
| 169 |
+
return Response(
|
| 170 |
+
content=resp.content,
|
| 171 |
+
status_code=resp.status_code,
|
| 172 |
+
media_type=resp.headers.get("content-type", "text/plain"),
|
| 173 |
+
)
|
| 174 |
+
return JSONResponse(
|
| 175 |
+
content=content,
|
| 176 |
+
status_code=resp.status_code,
|
| 177 |
+
headers={
|
| 178 |
+
k: v for k, v in resp.headers.items()
|
| 179 |
+
if k.lower() not in ("content-length", "transfer-encoding")
|
| 180 |
+
},
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
@proxy_app.on_event("shutdown")
|
| 184 |
+
def shutdown():
|
| 185 |
+
vllm_process.terminate()
|
| 186 |
+
|
| 187 |
+
return proxy_app
|
|
@@ -32,6 +32,9 @@ dev = [
|
|
| 32 |
"respx>=0.21.0",
|
| 33 |
"types-PyYAML",
|
| 34 |
]
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
[tool.setuptools.packages.find]
|
| 37 |
include = ["agent_bench*"]
|
|
|
|
| 32 |
"respx>=0.21.0",
|
| 33 |
"types-PyYAML",
|
| 34 |
]
|
| 35 |
+
modal = [
|
| 36 |
+
"modal>=0.66.0",
|
| 37 |
+
]
|
| 38 |
|
| 39 |
[tool.setuptools.packages.find]
|
| 40 |
include = ["agent_bench*"]
|
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
terraform {
|
| 2 |
+
required_version = ">= 1.5"
|
| 3 |
+
required_providers {
|
| 4 |
+
google = {
|
| 5 |
+
source = "hashicorp/google"
|
| 6 |
+
version = "~> 5.0"
|
| 7 |
+
}
|
| 8 |
+
}
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
provider "google" {
|
| 12 |
+
project = var.project_id
|
| 13 |
+
region = var.region
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
module "networking" {
|
| 17 |
+
source = "./modules/networking"
|
| 18 |
+
project_id = var.project_id
|
| 19 |
+
region = var.region
|
| 20 |
+
cluster_name = var.cluster_name
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
module "gke" {
|
| 24 |
+
source = "./modules/gke"
|
| 25 |
+
project_id = var.project_id
|
| 26 |
+
region = var.region
|
| 27 |
+
cluster_name = var.cluster_name
|
| 28 |
+
network = module.networking.network_name
|
| 29 |
+
subnetwork = module.networking.subnetwork_name
|
| 30 |
+
cpu_node_count = 2
|
| 31 |
+
cpu_machine_type = "e2-standard-4"
|
| 32 |
+
}
|
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
resource "google_container_cluster" "primary" {
|
| 2 |
+
name = var.cluster_name
|
| 3 |
+
location = var.region
|
| 4 |
+
project = var.project_id
|
| 5 |
+
|
| 6 |
+
network = var.network
|
| 7 |
+
subnetwork = var.subnetwork
|
| 8 |
+
|
| 9 |
+
# Autopilot disabled — we manage node pools explicitly
|
| 10 |
+
enable_autopilot = false
|
| 11 |
+
|
| 12 |
+
# Remove default node pool (we create our own)
|
| 13 |
+
remove_default_node_pool = true
|
| 14 |
+
initial_node_count = 1
|
| 15 |
+
|
| 16 |
+
ip_allocation_policy {
|
| 17 |
+
cluster_secondary_range_name = "pods"
|
| 18 |
+
services_secondary_range_name = "services"
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
resource "google_container_node_pool" "cpu_pool" {
|
| 23 |
+
name = "${var.cluster_name}-cpu-pool"
|
| 24 |
+
location = var.region
|
| 25 |
+
cluster = google_container_cluster.primary.name
|
| 26 |
+
node_count = var.cpu_node_count
|
| 27 |
+
project = var.project_id
|
| 28 |
+
|
| 29 |
+
node_config {
|
| 30 |
+
machine_type = var.cpu_machine_type
|
| 31 |
+
disk_size_gb = 50
|
| 32 |
+
disk_type = "pd-standard"
|
| 33 |
+
|
| 34 |
+
oauth_scopes = [
|
| 35 |
+
"https://www.googleapis.com/auth/cloud-platform",
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
}
|
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
output "cluster_name" {
|
| 2 |
+
value = google_container_cluster.primary.name
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
output "cluster_endpoint" {
|
| 6 |
+
value = google_container_cluster.primary.endpoint
|
| 7 |
+
sensitive = true
|
| 8 |
+
}
|
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
variable "project_id" {
|
| 2 |
+
type = string
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
variable "region" {
|
| 6 |
+
type = string
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
variable "cluster_name" {
|
| 10 |
+
type = string
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
variable "network" {
|
| 14 |
+
type = string
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
variable "subnetwork" {
|
| 18 |
+
type = string
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
variable "cpu_node_count" {
|
| 22 |
+
type = number
|
| 23 |
+
default = 2
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
variable "cpu_machine_type" {
|
| 27 |
+
type = string
|
| 28 |
+
default = "e2-standard-4"
|
| 29 |
+
}
|
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
resource "google_compute_network" "vpc" {
|
| 2 |
+
name = "${var.cluster_name}-vpc"
|
| 3 |
+
auto_create_subnetworks = false
|
| 4 |
+
project = var.project_id
|
| 5 |
+
}
|
| 6 |
+
|
| 7 |
+
resource "google_compute_subnetwork" "subnet" {
|
| 8 |
+
name = "${var.cluster_name}-subnet"
|
| 9 |
+
ip_cidr_range = "10.0.0.0/24"
|
| 10 |
+
region = var.region
|
| 11 |
+
network = google_compute_network.vpc.id
|
| 12 |
+
project = var.project_id
|
| 13 |
+
|
| 14 |
+
secondary_ip_range {
|
| 15 |
+
range_name = "pods"
|
| 16 |
+
ip_cidr_range = "10.1.0.0/16"
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
secondary_ip_range {
|
| 20 |
+
range_name = "services"
|
| 21 |
+
ip_cidr_range = "10.2.0.0/20"
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
resource "google_compute_firewall" "allow_internal" {
|
| 26 |
+
name = "${var.cluster_name}-allow-internal"
|
| 27 |
+
network = google_compute_network.vpc.name
|
| 28 |
+
project = var.project_id
|
| 29 |
+
|
| 30 |
+
allow {
|
| 31 |
+
protocol = "tcp"
|
| 32 |
+
ports = ["0-65535"]
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
allow {
|
| 36 |
+
protocol = "udp"
|
| 37 |
+
ports = ["0-65535"]
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
allow {
|
| 41 |
+
protocol = "icmp"
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
source_ranges = ["10.0.0.0/8"]
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
resource "google_compute_firewall" "allow_health_checks" {
|
| 48 |
+
name = "${var.cluster_name}-allow-health-checks"
|
| 49 |
+
network = google_compute_network.vpc.name
|
| 50 |
+
project = var.project_id
|
| 51 |
+
|
| 52 |
+
allow {
|
| 53 |
+
protocol = "tcp"
|
| 54 |
+
ports = ["80", "443", "8000", "7860"]
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# GCP health check IP ranges
|
| 58 |
+
source_ranges = ["35.191.0.0/16", "130.211.0.0/22"]
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
output "network_name" {
|
| 62 |
+
value = google_compute_network.vpc.name
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
output "subnetwork_name" {
|
| 66 |
+
value = google_compute_subnetwork.subnet.name
|
| 67 |
+
}
|
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
variable "project_id" {
|
| 2 |
+
type = string
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
variable "region" {
|
| 6 |
+
type = string
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
variable "cluster_name" {
|
| 10 |
+
type = string
|
| 11 |
+
}
|
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
output "cluster_name" {
|
| 2 |
+
description = "GKE cluster name"
|
| 3 |
+
value = module.gke.cluster_name
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
output "cluster_endpoint" {
|
| 7 |
+
description = "GKE cluster endpoint"
|
| 8 |
+
value = module.gke.cluster_endpoint
|
| 9 |
+
sensitive = true
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
output "kubeconfig_command" {
|
| 13 |
+
description = "Command to configure kubectl"
|
| 14 |
+
value = "gcloud container clusters get-credentials ${var.cluster_name} --region ${var.region} --project ${var.project_id}"
|
| 15 |
+
}
|
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy to terraform.tfvars and fill in values.
|
| 2 |
+
# terraform.tfvars is gitignored.
|
| 3 |
+
|
| 4 |
+
project_id = "your-gcp-project-id"
|
| 5 |
+
region = "europe-west1"
|
| 6 |
+
cluster_name = "agent-bench-cluster"
|
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
variable "project_id" {
|
| 2 |
+
description = "GCP project ID"
|
| 3 |
+
type = string
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
variable "region" {
|
| 7 |
+
description = "GCP region for the cluster"
|
| 8 |
+
type = string
|
| 9 |
+
default = "europe-west1"
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
variable "cluster_name" {
|
| 13 |
+
description = "GKE cluster name"
|
| 14 |
+
type = string
|
| 15 |
+
default = "agent-bench-cluster"
|
| 16 |
+
}
|
|
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the SelfHostedProvider (OpenAI-compatible endpoint)."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import httpx
|
| 6 |
+
import pytest
|
| 7 |
+
import respx
|
| 8 |
+
|
| 9 |
+
from agent_bench.core.config import (
|
| 10 |
+
AppConfig,
|
| 11 |
+
ProviderConfig,
|
| 12 |
+
RetryConfig,
|
| 13 |
+
SelfHostedConfig,
|
| 14 |
+
)
|
| 15 |
+
from agent_bench.core.provider import (
|
| 16 |
+
ProviderRateLimitError,
|
| 17 |
+
ProviderTimeoutError,
|
| 18 |
+
SelfHostedProvider,
|
| 19 |
+
create_provider,
|
| 20 |
+
)
|
| 21 |
+
from agent_bench.core.types import Message, Role, ToolDefinition
|
| 22 |
+
|
| 23 |
+
# --- Helpers ---
|
| 24 |
+
|
| 25 |
+
FAKE_URL = "http://fake-vllm:8000/v1"
|
| 26 |
+
|
| 27 |
+
SEARCH_TOOL = ToolDefinition(
|
| 28 |
+
name="search_documents",
|
| 29 |
+
description="Search docs",
|
| 30 |
+
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _ok_response(content="ok", tool_calls=None, prompt_tokens=10, completion_tokens=5):
|
| 35 |
+
"""Build a minimal OpenAI-format chat completion response."""
|
| 36 |
+
message: dict = {"role": "assistant", "content": content}
|
| 37 |
+
if tool_calls:
|
| 38 |
+
message["tool_calls"] = tool_calls
|
| 39 |
+
message["content"] = None
|
| 40 |
+
return {
|
| 41 |
+
"id": "chatcmpl-test",
|
| 42 |
+
"object": "chat.completion",
|
| 43 |
+
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
| 44 |
+
"choices": [{"index": 0, "message": message, "finish_reason": "stop"}],
|
| 45 |
+
"usage": {
|
| 46 |
+
"prompt_tokens": prompt_tokens,
|
| 47 |
+
"completion_tokens": completion_tokens,
|
| 48 |
+
"total_tokens": prompt_tokens + completion_tokens,
|
| 49 |
+
},
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _probe_response_with_tool_calls():
|
| 54 |
+
"""Response to the tool-calling detection probe — model uses tools."""
|
| 55 |
+
return _ok_response(
|
| 56 |
+
tool_calls=[
|
| 57 |
+
{
|
| 58 |
+
"id": "call_probe",
|
| 59 |
+
"type": "function",
|
| 60 |
+
"function": {
|
| 61 |
+
"name": "test_probe",
|
| 62 |
+
"arguments": json.dumps({"x": "hello"}),
|
| 63 |
+
},
|
| 64 |
+
}
|
| 65 |
+
],
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _probe_response_without_tool_calls():
|
| 70 |
+
"""Response to the tool-calling detection probe — model ignores tools."""
|
| 71 |
+
return _ok_response(content="I cannot use tools.")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# --- Factory ---
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TestSelfHostedFactory:
|
| 78 |
+
def test_factory_creates_selfhosted_provider(self, monkeypatch):
|
| 79 |
+
"""Factory returns SelfHostedProvider for 'selfhosted' config."""
|
| 80 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 81 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 82 |
+
provider = create_provider(config)
|
| 83 |
+
assert isinstance(provider, SelfHostedProvider)
|
| 84 |
+
|
| 85 |
+
def test_factory_raises_for_unknown_provider(self):
|
| 86 |
+
config = AppConfig(provider=ProviderConfig(default="nonexistent"))
|
| 87 |
+
with pytest.raises(ValueError, match="Unknown provider"):
|
| 88 |
+
create_provider(config)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --- Config-based settings ---
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class TestSelfHostedConfig:
|
| 95 |
+
def test_reads_base_url_from_config(self, monkeypatch):
|
| 96 |
+
"""Config selfhosted.base_url takes precedence over env var."""
|
| 97 |
+
monkeypatch.setenv("MODAL_VLLM_URL", "http://env-url:8000/v1")
|
| 98 |
+
config = AppConfig(
|
| 99 |
+
provider=ProviderConfig(
|
| 100 |
+
default="selfhosted",
|
| 101 |
+
selfhosted=SelfHostedConfig(base_url="http://config-url:8000/v1"),
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
provider = SelfHostedProvider(config)
|
| 105 |
+
assert provider.base_url == "http://config-url:8000/v1"
|
| 106 |
+
|
| 107 |
+
def test_falls_back_to_env_when_config_empty(self, monkeypatch):
|
| 108 |
+
"""Empty config falls back to MODAL_VLLM_URL env var."""
|
| 109 |
+
monkeypatch.setenv("MODAL_VLLM_URL", "http://env-url:8000/v1")
|
| 110 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 111 |
+
provider = SelfHostedProvider(config)
|
| 112 |
+
assert provider.base_url == "http://env-url:8000/v1"
|
| 113 |
+
|
| 114 |
+
def test_reads_api_key_from_config(self, monkeypatch):
|
| 115 |
+
monkeypatch.delenv("MODAL_AUTH_TOKEN", raising=False)
|
| 116 |
+
config = AppConfig(
|
| 117 |
+
provider=ProviderConfig(
|
| 118 |
+
default="selfhosted",
|
| 119 |
+
selfhosted=SelfHostedConfig(
|
| 120 |
+
base_url=FAKE_URL, api_key="config-key-123"
|
| 121 |
+
),
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
provider = SelfHostedProvider(config)
|
| 125 |
+
assert provider.client.headers.get("authorization") == "Bearer config-key-123"
|
| 126 |
+
|
| 127 |
+
def test_timeout_from_config(self, monkeypatch):
|
| 128 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 129 |
+
config = AppConfig(
|
| 130 |
+
provider=ProviderConfig(
|
| 131 |
+
default="selfhosted",
|
| 132 |
+
selfhosted=SelfHostedConfig(timeout_seconds=42.0),
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
provider = SelfHostedProvider(config)
|
| 136 |
+
assert provider.client.timeout.read == 42.0
|
| 137 |
+
|
| 138 |
+
def test_config_yaml_selfhosted_block_not_dropped(self):
|
| 139 |
+
"""Pydantic accepts provider.selfhosted fields (regression for issue #3)."""
|
| 140 |
+
raw = {
|
| 141 |
+
"provider": {
|
| 142 |
+
"default": "selfhosted",
|
| 143 |
+
"selfhosted": {
|
| 144 |
+
"base_url": "http://yaml-url:8000/v1",
|
| 145 |
+
"model_name": "meta-llama/Llama-3-8B",
|
| 146 |
+
"api_key": "yaml-key",
|
| 147 |
+
"timeout_seconds": 60.0,
|
| 148 |
+
},
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
config = AppConfig.model_validate(raw)
|
| 152 |
+
assert config.provider.selfhosted.base_url == "http://yaml-url:8000/v1"
|
| 153 |
+
assert config.provider.selfhosted.model_name == "meta-llama/Llama-3-8B"
|
| 154 |
+
assert config.provider.selfhosted.api_key == "yaml-key"
|
| 155 |
+
assert config.provider.selfhosted.timeout_seconds == 60.0
|
| 156 |
+
|
| 157 |
+
def test_loads_selfhosted_local_yaml_from_disk(self):
|
| 158 |
+
"""selfhosted_local.yaml loads from disk with correct selfhosted settings."""
|
| 159 |
+
from pathlib import Path
|
| 160 |
+
|
| 161 |
+
from agent_bench.core.config import load_config
|
| 162 |
+
|
| 163 |
+
yaml_path = Path(__file__).resolve().parent.parent / "configs" / "selfhosted_local.yaml"
|
| 164 |
+
config = load_config(yaml_path)
|
| 165 |
+
assert config.provider.default == "selfhosted"
|
| 166 |
+
assert config.provider.selfhosted.base_url == "" # env var fallback
|
| 167 |
+
assert config.provider.selfhosted.model_name == "mistralai/Mistral-7B-Instruct-v0.3"
|
| 168 |
+
|
| 169 |
+
def test_loads_selfhosted_modal_yaml_from_disk(self):
|
| 170 |
+
"""selfhosted_modal.yaml loads from disk; base_url empty (env var fallback)."""
|
| 171 |
+
from pathlib import Path
|
| 172 |
+
|
| 173 |
+
from agent_bench.core.config import load_config
|
| 174 |
+
|
| 175 |
+
yaml_path = Path(__file__).resolve().parent.parent / "configs" / "selfhosted_modal.yaml"
|
| 176 |
+
config = load_config(yaml_path)
|
| 177 |
+
assert config.provider.default == "selfhosted"
|
| 178 |
+
assert config.provider.selfhosted.base_url == "" # falls back to MODAL_VLLM_URL
|
| 179 |
+
|
| 180 |
+
def test_default_fallback_port_does_not_collide_with_app(self, monkeypatch):
|
| 181 |
+
"""Default vLLM fallback URL must NOT use port 8000 (app's serving port)."""
|
| 182 |
+
monkeypatch.delenv("MODAL_VLLM_URL", raising=False)
|
| 183 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 184 |
+
provider = SelfHostedProvider(config)
|
| 185 |
+
assert ":8000" not in provider.base_url
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# --- complete() ---
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class TestSelfHostedComplete:
|
| 192 |
+
@pytest.fixture
|
| 193 |
+
def provider(self, monkeypatch):
|
| 194 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 195 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 196 |
+
return SelfHostedProvider(config)
|
| 197 |
+
|
| 198 |
+
@pytest.mark.asyncio
|
| 199 |
+
async def test_complete_parses_response(self, provider):
|
| 200 |
+
"""SelfHostedProvider.complete() parses OpenAI-format response."""
|
| 201 |
+
mock_response = _ok_response(
|
| 202 |
+
content="Path params use curly braces. [source: fastapi.md]",
|
| 203 |
+
prompt_tokens=80,
|
| 204 |
+
completion_tokens=20,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
with respx.mock:
|
| 208 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 209 |
+
return_value=httpx.Response(200, json=mock_response)
|
| 210 |
+
)
|
| 211 |
+
response = await provider.complete(
|
| 212 |
+
[Message(role=Role.USER, content="How do path params work?")]
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
assert response.content == "Path params use curly braces. [source: fastapi.md]"
|
| 216 |
+
assert response.tool_calls == []
|
| 217 |
+
assert response.provider == "selfhosted"
|
| 218 |
+
assert response.model == "mistralai/Mistral-7B-Instruct-v0.3"
|
| 219 |
+
assert response.usage.input_tokens == 80
|
| 220 |
+
assert response.usage.output_tokens == 20
|
| 221 |
+
assert response.latency_ms > 0
|
| 222 |
+
|
| 223 |
+
@pytest.mark.asyncio
|
| 224 |
+
async def test_complete_parses_tool_calls(self, provider):
|
| 225 |
+
"""SelfHostedProvider.complete() parses native tool_calls."""
|
| 226 |
+
# Pre-set tool support to skip detection probe
|
| 227 |
+
provider._supports_tool_calling = True
|
| 228 |
+
|
| 229 |
+
tool_response = _ok_response(
|
| 230 |
+
tool_calls=[
|
| 231 |
+
{
|
| 232 |
+
"id": "call_abc",
|
| 233 |
+
"type": "function",
|
| 234 |
+
"function": {
|
| 235 |
+
"name": "search_documents",
|
| 236 |
+
"arguments": json.dumps({"query": "path params"}),
|
| 237 |
+
},
|
| 238 |
+
}
|
| 239 |
+
],
|
| 240 |
+
prompt_tokens=60,
|
| 241 |
+
completion_tokens=15,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
with respx.mock:
|
| 245 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 246 |
+
return_value=httpx.Response(200, json=tool_response)
|
| 247 |
+
)
|
| 248 |
+
response = await provider.complete(
|
| 249 |
+
[Message(role=Role.USER, content="search for path params")],
|
| 250 |
+
tools=[SEARCH_TOOL],
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
assert len(response.tool_calls) == 1
|
| 254 |
+
assert response.tool_calls[0].id == "call_abc"
|
| 255 |
+
assert response.tool_calls[0].name == "search_documents"
|
| 256 |
+
assert response.tool_calls[0].arguments == {"query": "path params"}
|
| 257 |
+
|
| 258 |
+
@pytest.mark.asyncio
|
| 259 |
+
async def test_complete_handles_malformed_tool_args(self, provider):
|
| 260 |
+
"""Malformed JSON in tool arguments falls back to empty dict."""
|
| 261 |
+
provider._supports_tool_calling = True
|
| 262 |
+
|
| 263 |
+
mock_response = _ok_response(
|
| 264 |
+
tool_calls=[
|
| 265 |
+
{
|
| 266 |
+
"id": "call_bad",
|
| 267 |
+
"type": "function",
|
| 268 |
+
"function": {
|
| 269 |
+
"name": "search_documents",
|
| 270 |
+
"arguments": "not valid json{{{",
|
| 271 |
+
},
|
| 272 |
+
}
|
| 273 |
+
],
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
with respx.mock:
|
| 277 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 278 |
+
return_value=httpx.Response(200, json=mock_response)
|
| 279 |
+
)
|
| 280 |
+
response = await provider.complete(
|
| 281 |
+
[Message(role=Role.USER, content="test")]
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
assert len(response.tool_calls) == 1
|
| 285 |
+
assert response.tool_calls[0].arguments == {}
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# --- Tool-calling detection ---
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class TestSelfHostedToolDetection:
|
| 292 |
+
@pytest.fixture
|
| 293 |
+
def provider(self, monkeypatch):
|
| 294 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 295 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 296 |
+
return SelfHostedProvider(config)
|
| 297 |
+
|
| 298 |
+
@pytest.mark.asyncio
|
| 299 |
+
async def test_detect_tool_calling_supported(self, provider):
|
| 300 |
+
"""Detection probe returns True when model responds with tool_calls."""
|
| 301 |
+
with respx.mock:
|
| 302 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 303 |
+
return_value=httpx.Response(
|
| 304 |
+
200, json=_probe_response_with_tool_calls()
|
| 305 |
+
)
|
| 306 |
+
)
|
| 307 |
+
result = await provider._detect_tool_calling()
|
| 308 |
+
assert result is True
|
| 309 |
+
|
| 310 |
+
@pytest.mark.asyncio
|
| 311 |
+
async def test_detect_tool_calling_unsupported_400(self, provider):
|
| 312 |
+
"""Detection probe returns False on 400 (endpoint rejects tools)."""
|
| 313 |
+
with respx.mock:
|
| 314 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 315 |
+
return_value=httpx.Response(
|
| 316 |
+
400, json={"error": "tools not supported"}
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
result = await provider._detect_tool_calling()
|
| 320 |
+
assert result is False
|
| 321 |
+
|
| 322 |
+
@pytest.mark.asyncio
|
| 323 |
+
async def test_detect_tool_calling_unsupported_no_tool_calls(self, provider):
|
| 324 |
+
"""Detection probe returns False when model ignores tools."""
|
| 325 |
+
with respx.mock:
|
| 326 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 327 |
+
return_value=httpx.Response(
|
| 328 |
+
200, json=_probe_response_without_tool_calls()
|
| 329 |
+
)
|
| 330 |
+
)
|
| 331 |
+
result = await provider._detect_tool_calling()
|
| 332 |
+
assert result is False
|
| 333 |
+
|
| 334 |
+
@pytest.mark.asyncio
|
| 335 |
+
async def test_detect_transient_failure_returns_none(self, provider):
|
| 336 |
+
"""Transient failure (timeout, 5xx) returns None, not False."""
|
| 337 |
+
with respx.mock:
|
| 338 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 339 |
+
side_effect=httpx.ReadTimeout("cold start")
|
| 340 |
+
)
|
| 341 |
+
result = await provider._detect_tool_calling()
|
| 342 |
+
assert result is None
|
| 343 |
+
|
| 344 |
+
@pytest.mark.asyncio
|
| 345 |
+
async def test_detect_5xx_returns_none(self, provider):
|
| 346 |
+
"""Server error returns None (transient), not False (definitive)."""
|
| 347 |
+
with respx.mock:
|
| 348 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 349 |
+
return_value=httpx.Response(503, json={"error": "unavailable"})
|
| 350 |
+
)
|
| 351 |
+
result = await provider._detect_tool_calling()
|
| 352 |
+
assert result is None
|
| 353 |
+
|
| 354 |
+
@pytest.mark.asyncio
|
| 355 |
+
async def test_detection_runs_once_then_cached(self, provider):
|
| 356 |
+
"""Detection probe fires on first call with tools, cached thereafter."""
|
| 357 |
+
call_count = 0
|
| 358 |
+
|
| 359 |
+
def side_effect(request):
|
| 360 |
+
nonlocal call_count
|
| 361 |
+
call_count += 1
|
| 362 |
+
body = json.loads(request.content)
|
| 363 |
+
# Detection probe has test_probe tool
|
| 364 |
+
if any(
|
| 365 |
+
t.get("function", {}).get("name") == "test_probe"
|
| 366 |
+
for t in body.get("tools", [])
|
| 367 |
+
):
|
| 368 |
+
return httpx.Response(
|
| 369 |
+
200, json=_probe_response_with_tool_calls()
|
| 370 |
+
)
|
| 371 |
+
# Real request
|
| 372 |
+
return httpx.Response(200, json=_ok_response(
|
| 373 |
+
tool_calls=[{
|
| 374 |
+
"id": "call_real",
|
| 375 |
+
"type": "function",
|
| 376 |
+
"function": {
|
| 377 |
+
"name": "search_documents",
|
| 378 |
+
"arguments": json.dumps({"query": "test"}),
|
| 379 |
+
},
|
| 380 |
+
}],
|
| 381 |
+
))
|
| 382 |
+
|
| 383 |
+
with respx.mock:
|
| 384 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 385 |
+
side_effect=side_effect
|
| 386 |
+
)
|
| 387 |
+
# First call: probe + real = 2 requests
|
| 388 |
+
await provider.complete(
|
| 389 |
+
[Message(role=Role.USER, content="test")],
|
| 390 |
+
tools=[SEARCH_TOOL],
|
| 391 |
+
)
|
| 392 |
+
# Second call: no probe = 1 request
|
| 393 |
+
await provider.complete(
|
| 394 |
+
[Message(role=Role.USER, content="test2")],
|
| 395 |
+
tools=[SEARCH_TOOL],
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
assert call_count == 3 # 1 probe + 2 real
|
| 399 |
+
assert provider._supports_tool_calling is True
|
| 400 |
+
|
| 401 |
+
@pytest.mark.asyncio
|
| 402 |
+
async def test_transient_failure_retries_on_next_call(self, provider):
|
| 403 |
+
"""Transient detection failure leaves _supports_tool_calling as None, retries."""
|
| 404 |
+
call_count = 0
|
| 405 |
+
|
| 406 |
+
def side_effect(request):
|
| 407 |
+
nonlocal call_count
|
| 408 |
+
call_count += 1
|
| 409 |
+
body = json.loads(request.content)
|
| 410 |
+
is_probe = any(
|
| 411 |
+
t.get("function", {}).get("name") == "test_probe"
|
| 412 |
+
for t in body.get("tools", [])
|
| 413 |
+
)
|
| 414 |
+
if is_probe:
|
| 415 |
+
if call_count == 1:
|
| 416 |
+
# First probe: transient failure
|
| 417 |
+
return httpx.Response(503, json={"error": "cold start"})
|
| 418 |
+
# Second probe: success
|
| 419 |
+
return httpx.Response(
|
| 420 |
+
200, json=_probe_response_with_tool_calls()
|
| 421 |
+
)
|
| 422 |
+
# Real request (fallback or native)
|
| 423 |
+
return httpx.Response(200, json=_ok_response())
|
| 424 |
+
|
| 425 |
+
with respx.mock:
|
| 426 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 427 |
+
side_effect=side_effect
|
| 428 |
+
)
|
| 429 |
+
# First call: probe fails (transient) + real (fallback) = 2
|
| 430 |
+
await provider.complete(
|
| 431 |
+
[Message(role=Role.USER, content="test")],
|
| 432 |
+
tools=[SEARCH_TOOL],
|
| 433 |
+
)
|
| 434 |
+
assert provider._supports_tool_calling is None # NOT cached
|
| 435 |
+
|
| 436 |
+
# Second call: probe succeeds + real (native) = 2
|
| 437 |
+
await provider.complete(
|
| 438 |
+
[Message(role=Role.USER, content="test2")],
|
| 439 |
+
tools=[SEARCH_TOOL],
|
| 440 |
+
)
|
| 441 |
+
assert provider._supports_tool_calling is True # NOW cached
|
| 442 |
+
|
| 443 |
+
assert call_count == 4 # 2 probes + 2 real
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
# --- Prompt-based fallback ---
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class TestSelfHostedPromptFallback:
|
| 450 |
+
@pytest.fixture
|
| 451 |
+
def provider(self, monkeypatch):
|
| 452 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 453 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 454 |
+
p = SelfHostedProvider(config)
|
| 455 |
+
p._supports_tool_calling = False # Force fallback mode
|
| 456 |
+
return p
|
| 457 |
+
|
| 458 |
+
@pytest.mark.asyncio
|
| 459 |
+
async def test_fallback_parses_tool_call_from_text(self, provider):
|
| 460 |
+
"""When tool calling is unsupported, parse tool calls from model text."""
|
| 461 |
+
tool_json = json.dumps(
|
| 462 |
+
{"tool_calls": [{"name": "search_documents", "arguments": {"query": "path params"}}]}
|
| 463 |
+
)
|
| 464 |
+
mock_response = _ok_response(content=tool_json)
|
| 465 |
+
|
| 466 |
+
with respx.mock:
|
| 467 |
+
route = respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 468 |
+
return_value=httpx.Response(200, json=mock_response)
|
| 469 |
+
)
|
| 470 |
+
response = await provider.complete(
|
| 471 |
+
[Message(role=Role.USER, content="search for path params")],
|
| 472 |
+
tools=[SEARCH_TOOL],
|
| 473 |
+
)
|
| 474 |
+
# Verify tools NOT in payload (prompt-based, not native)
|
| 475 |
+
sent_body = json.loads(route.calls[0].request.content)
|
| 476 |
+
assert "tools" not in sent_body
|
| 477 |
+
|
| 478 |
+
assert len(response.tool_calls) == 1
|
| 479 |
+
assert response.tool_calls[0].name == "search_documents"
|
| 480 |
+
assert response.tool_calls[0].arguments == {"query": "path params"}
|
| 481 |
+
assert response.content == "" # tool call replaces content
|
| 482 |
+
|
| 483 |
+
@pytest.mark.asyncio
|
| 484 |
+
async def test_fallback_injects_tool_prompt(self, provider):
|
| 485 |
+
"""When tool calling is unsupported, tool descriptions injected as system prompt."""
|
| 486 |
+
mock_response = _ok_response(content="Just a text answer.")
|
| 487 |
+
|
| 488 |
+
with respx.mock:
|
| 489 |
+
route = respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 490 |
+
return_value=httpx.Response(200, json=mock_response)
|
| 491 |
+
)
|
| 492 |
+
await provider.complete(
|
| 493 |
+
[Message(role=Role.USER, content="hello")],
|
| 494 |
+
tools=[SEARCH_TOOL],
|
| 495 |
+
)
|
| 496 |
+
sent_body = json.loads(route.calls[0].request.content)
|
| 497 |
+
|
| 498 |
+
# System message should contain tool descriptions
|
| 499 |
+
system_msg = sent_body["messages"][0]
|
| 500 |
+
assert system_msg["role"] == "system"
|
| 501 |
+
assert "search_documents" in system_msg["content"]
|
| 502 |
+
assert "tool_calls" in system_msg["content"]
|
| 503 |
+
|
| 504 |
+
@pytest.mark.asyncio
|
| 505 |
+
async def test_fallback_handles_non_dict_arguments(self, provider):
|
| 506 |
+
"""Non-dict arguments in prompt-based JSON degrades to empty dict, not crash."""
|
| 507 |
+
tool_json = json.dumps(
|
| 508 |
+
{"tool_calls": [{"name": "search_documents", "arguments": "oops"}]}
|
| 509 |
+
)
|
| 510 |
+
mock_response = _ok_response(content=tool_json)
|
| 511 |
+
|
| 512 |
+
with respx.mock:
|
| 513 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 514 |
+
return_value=httpx.Response(200, json=mock_response)
|
| 515 |
+
)
|
| 516 |
+
response = await provider.complete(
|
| 517 |
+
[Message(role=Role.USER, content="test")],
|
| 518 |
+
tools=[SEARCH_TOOL],
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
assert len(response.tool_calls) == 1
|
| 522 |
+
assert response.tool_calls[0].name == "search_documents"
|
| 523 |
+
assert response.tool_calls[0].arguments == {}
|
| 524 |
+
|
| 525 |
+
@pytest.mark.asyncio
|
| 526 |
+
async def test_fallback_returns_text_when_no_tool_json(self, provider):
|
| 527 |
+
"""When model responds with plain text (not JSON), return as content."""
|
| 528 |
+
mock_response = _ok_response(content="I don't know how to use tools.")
|
| 529 |
+
|
| 530 |
+
with respx.mock:
|
| 531 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 532 |
+
return_value=httpx.Response(200, json=mock_response)
|
| 533 |
+
)
|
| 534 |
+
response = await provider.complete(
|
| 535 |
+
[Message(role=Role.USER, content="test")],
|
| 536 |
+
tools=[SEARCH_TOOL],
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
assert response.tool_calls == []
|
| 540 |
+
assert response.content == "I don't know how to use tools."
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
# --- Retry and timeout ---
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class TestSelfHostedRetryAndTimeout:
|
| 547 |
+
@pytest.fixture
|
| 548 |
+
def provider(self, monkeypatch):
|
| 549 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 550 |
+
config = AppConfig(
|
| 551 |
+
provider=ProviderConfig(default="selfhosted"),
|
| 552 |
+
retry=RetryConfig(max_retries=2, base_delay=0.01, max_delay=0.05),
|
| 553 |
+
)
|
| 554 |
+
return SelfHostedProvider(config)
|
| 555 |
+
|
| 556 |
+
@pytest.mark.asyncio
|
| 557 |
+
async def test_retries_on_429_then_succeeds(self, provider):
|
| 558 |
+
"""Provider retries on 429 and succeeds on next attempt."""
|
| 559 |
+
call_count = 0
|
| 560 |
+
|
| 561 |
+
def side_effect(request):
|
| 562 |
+
nonlocal call_count
|
| 563 |
+
call_count += 1
|
| 564 |
+
if call_count == 1:
|
| 565 |
+
return httpx.Response(429, json={"error": "rate limited"})
|
| 566 |
+
return httpx.Response(200, json=_ok_response())
|
| 567 |
+
|
| 568 |
+
with respx.mock:
|
| 569 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 570 |
+
side_effect=side_effect
|
| 571 |
+
)
|
| 572 |
+
response = await provider.complete(
|
| 573 |
+
[Message(role=Role.USER, content="test")]
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
assert response.content == "ok"
|
| 577 |
+
assert call_count == 2
|
| 578 |
+
|
| 579 |
+
@pytest.mark.asyncio
|
| 580 |
+
async def test_raises_rate_limit_after_exhausting_retries(self, provider):
|
| 581 |
+
"""Provider raises ProviderRateLimitError after all retries exhausted."""
|
| 582 |
+
with respx.mock:
|
| 583 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 584 |
+
return_value=httpx.Response(429, json={"error": "rate limited"})
|
| 585 |
+
)
|
| 586 |
+
with pytest.raises(ProviderRateLimitError, match="Rate limited"):
|
| 587 |
+
await provider.complete(
|
| 588 |
+
[Message(role=Role.USER, content="test")]
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
@pytest.mark.asyncio
|
| 592 |
+
async def test_raises_timeout_error(self, provider):
|
| 593 |
+
"""Provider raises ProviderTimeoutError on httpx timeout."""
|
| 594 |
+
with respx.mock:
|
| 595 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 596 |
+
side_effect=httpx.ReadTimeout("timed out")
|
| 597 |
+
)
|
| 598 |
+
with pytest.raises(ProviderTimeoutError, match="timed out"):
|
| 599 |
+
await provider.complete(
|
| 600 |
+
[Message(role=Role.USER, content="test")]
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
# --- Env var fallback ---
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class TestSelfHostedEnvVars:
|
| 608 |
+
def test_reads_base_url_from_env(self, monkeypatch):
|
| 609 |
+
monkeypatch.setenv("MODAL_VLLM_URL", "http://my-modal-url:8000/v1")
|
| 610 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 611 |
+
provider = SelfHostedProvider(config)
|
| 612 |
+
assert provider.base_url == "http://my-modal-url:8000/v1"
|
| 613 |
+
|
| 614 |
+
def test_reads_auth_token_from_env(self, monkeypatch):
|
| 615 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 616 |
+
monkeypatch.setenv("MODAL_AUTH_TOKEN", "secret-token-123")
|
| 617 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 618 |
+
provider = SelfHostedProvider(config)
|
| 619 |
+
assert provider.client.headers.get("authorization") == "Bearer secret-token-123"
|
| 620 |
+
|
| 621 |
+
def test_no_auth_header_when_no_token(self, monkeypatch):
|
| 622 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 623 |
+
monkeypatch.delenv("MODAL_AUTH_TOKEN", raising=False)
|
| 624 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 625 |
+
provider = SelfHostedProvider(config)
|
| 626 |
+
assert "authorization" not in {
|
| 627 |
+
k.lower() for k in provider.client.headers.keys()
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
# --- Streaming ---
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class TestSelfHostedStream:
|
| 635 |
+
@pytest.fixture
|
| 636 |
+
def provider(self, monkeypatch):
|
| 637 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 638 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 639 |
+
return SelfHostedProvider(config)
|
| 640 |
+
|
| 641 |
+
@pytest.mark.asyncio
|
| 642 |
+
async def test_stream_yields_content_chunks(self, provider):
|
| 643 |
+
"""stream_complete() yields text chunks from SSE stream."""
|
| 644 |
+
sse_body = (
|
| 645 |
+
'data: {"choices":[{"delta":{"content":"Hello "}}]}\n\n'
|
| 646 |
+
'data: {"choices":[{"delta":{"content":"world"}}]}\n\n'
|
| 647 |
+
"data: [DONE]\n\n"
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
with respx.mock:
|
| 651 |
+
respx.post(f"{FAKE_URL}/chat/completions").mock(
|
| 652 |
+
return_value=httpx.Response(
|
| 653 |
+
200,
|
| 654 |
+
stream=httpx.ByteStream(sse_body.encode()),
|
| 655 |
+
headers={"content-type": "text/event-stream"},
|
| 656 |
+
)
|
| 657 |
+
)
|
| 658 |
+
chunks = []
|
| 659 |
+
async for chunk in provider.stream_complete(
|
| 660 |
+
[Message(role=Role.USER, content="Hi")]
|
| 661 |
+
):
|
| 662 |
+
chunks.append(chunk)
|
| 663 |
+
|
| 664 |
+
assert chunks == ["Hello ", "world"]
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
# --- format_tools ---
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
class TestSelfHostedFormatTools:
|
| 671 |
+
def test_format_tools_uses_openai_schema(self, monkeypatch):
|
| 672 |
+
monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
|
| 673 |
+
config = AppConfig(provider=ProviderConfig(default="selfhosted"))
|
| 674 |
+
provider = SelfHostedProvider(config)
|
| 675 |
+
tools = [
|
| 676 |
+
ToolDefinition(
|
| 677 |
+
name="search_documents",
|
| 678 |
+
description="Search docs",
|
| 679 |
+
parameters={
|
| 680 |
+
"type": "object",
|
| 681 |
+
"properties": {"query": {"type": "string"}},
|
| 682 |
+
"required": ["query"],
|
| 683 |
+
},
|
| 684 |
+
)
|
| 685 |
+
]
|
| 686 |
+
formatted = provider.format_tools(tools)
|
| 687 |
+
assert formatted[0]["type"] == "function"
|
| 688 |
+
assert formatted[0]["function"]["name"] == "search_documents"
|
| 689 |
+
assert formatted[0]["function"]["parameters"]["required"] == ["query"]
|
|
@@ -151,6 +151,101 @@ class TestMetricsEndpoint:
|
|
| 151 |
assert "errors_total" in data
|
| 152 |
assert "avg_cost_per_query_usd" in data
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
class TestMiddleware:
|
| 156 |
@pytest.mark.asyncio
|
|
|
|
| 151 |
assert "errors_total" in data
|
| 152 |
assert "avg_cost_per_query_usd" in data
|
| 153 |
|
| 154 |
+
@pytest.mark.asyncio
|
| 155 |
+
async def test_prometheus_endpoint_returns_text_exposition(self, test_app):
|
| 156 |
+
async with AsyncClient(
|
| 157 |
+
transport=ASGITransport(app=test_app), base_url="http://test"
|
| 158 |
+
) as client:
|
| 159 |
+
response = await client.get("/metrics/prometheus")
|
| 160 |
+
assert response.status_code == 200
|
| 161 |
+
assert "text/plain" in response.headers["content-type"]
|
| 162 |
+
body = response.text
|
| 163 |
+
assert "# TYPE agent_bench_requests_total counter" in body
|
| 164 |
+
assert "agent_bench_requests_total " in body
|
| 165 |
+
assert "# TYPE agent_bench_latency_p95_ms gauge" in body
|
| 166 |
+
assert "agent_bench_latency_p95_ms " in body
|
| 167 |
+
assert "# TYPE agent_bench_errors_total counter" in body
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class TestHealthCheckProbesProvider:
|
| 171 |
+
@pytest.mark.asyncio
|
| 172 |
+
async def test_healthy_when_provider_health_check_passes(self, test_app):
|
| 173 |
+
"""MockProvider.health_check() returns True (default), so status=healthy."""
|
| 174 |
+
async with AsyncClient(
|
| 175 |
+
transport=ASGITransport(app=test_app), base_url="http://test"
|
| 176 |
+
) as client:
|
| 177 |
+
response = await client.get("/health")
|
| 178 |
+
assert response.status_code == 200
|
| 179 |
+
data = response.json()
|
| 180 |
+
assert data["status"] == "healthy"
|
| 181 |
+
assert data["provider_available"] is True
|
| 182 |
+
|
| 183 |
+
@pytest.mark.asyncio
|
| 184 |
+
async def test_degraded_when_provider_health_check_fails(self):
|
| 185 |
+
"""Provider whose health_check() returns False -> status=degraded."""
|
| 186 |
+
from fastapi import FastAPI
|
| 187 |
+
|
| 188 |
+
class UnhealthyProvider(MockProvider):
|
| 189 |
+
async def health_check(self) -> bool:
|
| 190 |
+
return False
|
| 191 |
+
|
| 192 |
+
app = FastAPI()
|
| 193 |
+
registry = ToolRegistry()
|
| 194 |
+
registry.register(FakeSearchTool())
|
| 195 |
+
provider = UnhealthyProvider()
|
| 196 |
+
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=1)
|
| 197 |
+
app.state.orchestrator = orchestrator
|
| 198 |
+
app.state.store = HybridStore(dimension=384)
|
| 199 |
+
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
|
| 200 |
+
app.state.system_prompt = "test"
|
| 201 |
+
app.state.start_time = time.time()
|
| 202 |
+
app.state.metrics = MetricsCollector()
|
| 203 |
+
app.add_middleware(RequestMiddleware)
|
| 204 |
+
from agent_bench.serving.routes import router
|
| 205 |
+
app.include_router(router)
|
| 206 |
+
|
| 207 |
+
async with AsyncClient(
|
| 208 |
+
transport=ASGITransport(app=app), base_url="http://test"
|
| 209 |
+
) as client:
|
| 210 |
+
response = await client.get("/health")
|
| 211 |
+
assert response.status_code == 200
|
| 212 |
+
data = response.json()
|
| 213 |
+
assert data["status"] == "degraded"
|
| 214 |
+
assert data["provider_available"] is False
|
| 215 |
+
|
| 216 |
+
@pytest.mark.asyncio
|
| 217 |
+
async def test_degraded_when_provider_health_check_raises(self):
|
| 218 |
+
"""Provider whose health_check() raises -> status=degraded."""
|
| 219 |
+
from fastapi import FastAPI
|
| 220 |
+
|
| 221 |
+
class CrashingProvider(MockProvider):
|
| 222 |
+
async def health_check(self) -> bool:
|
| 223 |
+
raise ConnectionError("upstream unreachable")
|
| 224 |
+
|
| 225 |
+
app = FastAPI()
|
| 226 |
+
registry = ToolRegistry()
|
| 227 |
+
registry.register(FakeSearchTool())
|
| 228 |
+
provider = CrashingProvider()
|
| 229 |
+
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=1)
|
| 230 |
+
app.state.orchestrator = orchestrator
|
| 231 |
+
app.state.store = HybridStore(dimension=384)
|
| 232 |
+
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
|
| 233 |
+
app.state.system_prompt = "test"
|
| 234 |
+
app.state.start_time = time.time()
|
| 235 |
+
app.state.metrics = MetricsCollector()
|
| 236 |
+
app.add_middleware(RequestMiddleware)
|
| 237 |
+
from agent_bench.serving.routes import router
|
| 238 |
+
app.include_router(router)
|
| 239 |
+
|
| 240 |
+
async with AsyncClient(
|
| 241 |
+
transport=ASGITransport(app=app), base_url="http://test"
|
| 242 |
+
) as client:
|
| 243 |
+
response = await client.get("/health")
|
| 244 |
+
assert response.status_code == 200
|
| 245 |
+
data = response.json()
|
| 246 |
+
assert data["status"] == "degraded"
|
| 247 |
+
assert data["provider_available"] is False
|
| 248 |
+
|
| 249 |
|
| 250 |
class TestMiddleware:
|
| 251 |
@pytest.mark.asyncio
|