Spaces:
Sleeping
Sleeping
| """Tests for the SelfHostedProvider (OpenAI-compatible endpoint).""" | |
| import json | |
| import httpx | |
| import pytest | |
| import respx | |
| from agent_bench.core.config import ( | |
| AppConfig, | |
| ProviderConfig, | |
| RetryConfig, | |
| SelfHostedConfig, | |
| ) | |
| from agent_bench.core.provider import ( | |
| ProviderRateLimitError, | |
| ProviderTimeoutError, | |
| SelfHostedProvider, | |
| create_provider, | |
| ) | |
| from agent_bench.core.types import Message, Role, ToolDefinition | |
# --- Helpers ---

# Fake OpenAI-compatible base URL; the tests below intercept requests to it
# with respx, so no real server is contacted.
FAKE_URL = "http://fake-vllm:8000/v1"

# Single-tool definition shared by the tool-calling tests: one free-text
# ``query`` string parameter.
SEARCH_TOOL = ToolDefinition(
    name="search_documents",
    description="Search docs",
    parameters={
        "type": "object",
        "properties": {
            "query": {"type": "string"},
        },
    },
)
| def _ok_response(content="ok", tool_calls=None, prompt_tokens=10, completion_tokens=5): | |
| """Build a minimal OpenAI-format chat completion response.""" | |
| message: dict = {"role": "assistant", "content": content} | |
| if tool_calls: | |
| message["tool_calls"] = tool_calls | |
| message["content"] = None | |
| return { | |
| "id": "chatcmpl-test", | |
| "object": "chat.completion", | |
| "model": "mistralai/Mistral-7B-Instruct-v0.3", | |
| "choices": [{"index": 0, "message": message, "finish_reason": "stop"}], | |
| "usage": { | |
| "prompt_tokens": prompt_tokens, | |
| "completion_tokens": completion_tokens, | |
| "total_tokens": prompt_tokens + completion_tokens, | |
| }, | |
| } | |
def _probe_response_with_tool_calls():
    """Simulate a detection-probe reply in which the model emits a tool call.

    The single call targets ``test_probe`` with JSON-encoded arguments
    ``{"x": "hello"}``.
    """
    probe_call = {
        "id": "call_probe",
        "type": "function",
        "function": {
            "name": "test_probe",
            "arguments": json.dumps({"x": "hello"}),
        },
    }
    return _ok_response(tool_calls=[probe_call])
def _probe_response_without_tool_calls():
    """Simulate a detection-probe reply in which the model answers in plain text."""
    refusal = "I cannot use tools."
    return _ok_response(content=refusal)
# --- Factory ---
class TestSelfHostedFactory:
    """``create_provider`` dispatch for the self-hosted backend."""

    def test_factory_creates_selfhosted_provider(self, monkeypatch):
        """Factory returns SelfHostedProvider for 'selfhosted' config."""
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        provider = create_provider(
            AppConfig(provider=ProviderConfig(default="selfhosted"))
        )
        assert isinstance(provider, SelfHostedProvider)

    def test_factory_raises_for_unknown_provider(self):
        """An unrecognized provider name is rejected with ValueError."""
        bogus = AppConfig(provider=ProviderConfig(default="nonexistent"))
        with pytest.raises(ValueError, match="Unknown provider"):
            create_provider(bogus)
# --- Config-based settings ---
class TestSelfHostedConfig:
    """How SelfHostedProvider resolves base URL, API key and timeout."""

    @staticmethod
    def _selfhosted_app_config(**selfhosted_fields):
        """Build an AppConfig selecting "selfhosted" with the given SelfHostedConfig fields."""
        return AppConfig(
            provider=ProviderConfig(
                default="selfhosted",
                selfhosted=SelfHostedConfig(**selfhosted_fields),
            )
        )

    @staticmethod
    def _load_repo_config(filename):
        """Load ``configs/<filename>`` from the parent of this test file's directory."""
        from pathlib import Path

        from agent_bench.core.config import load_config

        repo_root = Path(__file__).resolve().parent.parent
        return load_config(repo_root / "configs" / filename)

    def test_reads_base_url_from_config(self, monkeypatch):
        """Config selfhosted.base_url takes precedence over env var."""
        monkeypatch.setenv("MODAL_VLLM_URL", "http://env-url:8000/v1")
        config = self._selfhosted_app_config(base_url="http://config-url:8000/v1")
        assert SelfHostedProvider(config).base_url == "http://config-url:8000/v1"

    def test_falls_back_to_env_when_config_empty(self, monkeypatch):
        """Empty config falls back to MODAL_VLLM_URL env var."""
        monkeypatch.setenv("MODAL_VLLM_URL", "http://env-url:8000/v1")
        provider = SelfHostedProvider(
            AppConfig(provider=ProviderConfig(default="selfhosted"))
        )
        assert provider.base_url == "http://env-url:8000/v1"

    def test_reads_api_key_from_config(self, monkeypatch):
        """selfhosted.api_key is sent as a Bearer Authorization header."""
        monkeypatch.delenv("MODAL_AUTH_TOKEN", raising=False)
        config = self._selfhosted_app_config(
            base_url=FAKE_URL, api_key="config-key-123"
        )
        headers = SelfHostedProvider(config).client.headers
        assert headers.get("authorization") == "Bearer config-key-123"

    def test_timeout_from_config(self, monkeypatch):
        """selfhosted.timeout_seconds becomes the HTTP client's read timeout."""
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        config = self._selfhosted_app_config(timeout_seconds=42.0)
        assert SelfHostedProvider(config).client.timeout.read == 42.0

    def test_config_yaml_selfhosted_block_not_dropped(self):
        """Pydantic accepts provider.selfhosted fields (regression for issue #3)."""
        selfhosted_raw = {
            "base_url": "http://yaml-url:8000/v1",
            "model_name": "meta-llama/Llama-3-8B",
            "api_key": "yaml-key",
            "timeout_seconds": 60.0,
        }
        config = AppConfig.model_validate(
            {"provider": {"default": "selfhosted", "selfhosted": selfhosted_raw}}
        )
        parsed = config.provider.selfhosted
        for field_name, expected in selfhosted_raw.items():
            assert getattr(parsed, field_name) == expected

    def test_loads_selfhosted_local_yaml_from_disk(self):
        """selfhosted_local.yaml loads from disk with correct selfhosted settings."""
        config = self._load_repo_config("selfhosted_local.yaml")
        assert config.provider.default == "selfhosted"
        selfhosted = config.provider.selfhosted
        assert selfhosted.base_url == ""  # env var fallback
        assert selfhosted.model_name == "mistralai/Mistral-7B-Instruct-v0.3"

    def test_loads_selfhosted_modal_yaml_from_disk(self):
        """selfhosted_modal.yaml loads from disk; base_url empty (env var fallback)."""
        config = self._load_repo_config("selfhosted_modal.yaml")
        assert config.provider.default == "selfhosted"
        assert config.provider.selfhosted.base_url == ""  # falls back to MODAL_VLLM_URL

    def test_default_fallback_port_does_not_collide_with_app(self, monkeypatch):
        """Default vLLM fallback URL must NOT use port 8000 (app's serving port)."""
        monkeypatch.delenv("MODAL_VLLM_URL", raising=False)
        provider = SelfHostedProvider(
            AppConfig(provider=ProviderConfig(default="selfhosted"))
        )
        assert ":8000" not in provider.base_url
# --- complete() ---
class TestSelfHostedComplete:
    """Response parsing in ``SelfHostedProvider.complete()``."""

    @pytest.fixture
    def provider(self, monkeypatch):
        """Provide a SelfHostedProvider whose base URL comes from MODAL_VLLM_URL.

        Registered as a fixture so the ``provider`` argument of the tests
        below resolves to this instance.
        """
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        config = AppConfig(provider=ProviderConfig(default="selfhosted"))
        return SelfHostedProvider(config)

    async def test_complete_parses_response(self, provider):
        """SelfHostedProvider.complete() parses OpenAI-format response."""
        mock_response = _ok_response(
            content="Path params use curly braces. [source: fastapi.md]",
            prompt_tokens=80,
            completion_tokens=20,
        )
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=mock_response)
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="How do path params work?")]
            )
        assert response.content == "Path params use curly braces. [source: fastapi.md]"
        assert response.tool_calls == []
        assert response.provider == "selfhosted"
        assert response.model == "mistralai/Mistral-7B-Instruct-v0.3"
        assert response.usage.input_tokens == 80
        assert response.usage.output_tokens == 20
        assert response.latency_ms > 0

    async def test_complete_parses_tool_calls(self, provider):
        """SelfHostedProvider.complete() parses native tool_calls."""
        # Pre-set tool support to skip the detection probe.
        provider._supports_tool_calling = True
        tool_response = _ok_response(
            tool_calls=[
                {
                    "id": "call_abc",
                    "type": "function",
                    "function": {
                        "name": "search_documents",
                        "arguments": json.dumps({"query": "path params"}),
                    },
                }
            ],
            prompt_tokens=60,
            completion_tokens=15,
        )
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=tool_response)
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="search for path params")],
                tools=[SEARCH_TOOL],
            )
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].id == "call_abc"
        assert response.tool_calls[0].name == "search_documents"
        assert response.tool_calls[0].arguments == {"query": "path params"}

    async def test_complete_handles_malformed_tool_args(self, provider):
        """Malformed JSON in tool arguments falls back to empty dict."""
        provider._supports_tool_calling = True
        mock_response = _ok_response(
            tool_calls=[
                {
                    "id": "call_bad",
                    "type": "function",
                    "function": {
                        "name": "search_documents",
                        "arguments": "not valid json{{{",
                    },
                }
            ],
        )
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=mock_response)
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="test")]
            )
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].arguments == {}
# --- Tool-calling detection ---
class TestSelfHostedToolDetection:
    """Probing whether the endpoint supports native tool calling."""

    @pytest.fixture
    def provider(self, monkeypatch):
        """Provide a SelfHostedProvider whose base URL comes from MODAL_VLLM_URL.

        Registered as a fixture so the ``provider`` argument of the tests
        below resolves to this instance.
        """
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        config = AppConfig(provider=ProviderConfig(default="selfhosted"))
        return SelfHostedProvider(config)

    @staticmethod
    def _is_probe_request(request):
        """Return True if *request* is the detection probe (it offers the ``test_probe`` tool)."""
        body = json.loads(request.content)
        return any(
            tool.get("function", {}).get("name") == "test_probe"
            for tool in body.get("tools", [])
        )

    async def test_detect_tool_calling_supported(self, provider):
        """Detection probe returns True when model responds with tool_calls."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(
                    200, json=_probe_response_with_tool_calls()
                )
            )
            result = await provider._detect_tool_calling()
        assert result is True

    async def test_detect_tool_calling_unsupported_400(self, provider):
        """Detection probe returns False on 400 (endpoint rejects tools)."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(
                    400, json={"error": "tools not supported"}
                )
            )
            result = await provider._detect_tool_calling()
        assert result is False

    async def test_detect_tool_calling_unsupported_no_tool_calls(self, provider):
        """Detection probe returns False when model ignores tools."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(
                    200, json=_probe_response_without_tool_calls()
                )
            )
            result = await provider._detect_tool_calling()
        assert result is False

    async def test_detect_transient_failure_returns_none(self, provider):
        """Transient failure (timeout, 5xx) returns None, not False."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                side_effect=httpx.ReadTimeout("cold start")
            )
            result = await provider._detect_tool_calling()
        assert result is None

    async def test_detect_5xx_returns_none(self, provider):
        """Server error returns None (transient), not False (definitive)."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(503, json={"error": "unavailable"})
            )
            result = await provider._detect_tool_calling()
        assert result is None

    async def test_detection_runs_once_then_cached(self, provider):
        """Detection probe fires on first call with tools, cached thereafter."""
        call_count = 0

        def side_effect(request):
            nonlocal call_count
            call_count += 1
            if self._is_probe_request(request):
                return httpx.Response(
                    200, json=_probe_response_with_tool_calls()
                )
            # Real request: answer with a native tool call.
            return httpx.Response(
                200,
                json=_ok_response(
                    tool_calls=[
                        {
                            "id": "call_real",
                            "type": "function",
                            "function": {
                                "name": "search_documents",
                                "arguments": json.dumps({"query": "test"}),
                            },
                        }
                    ],
                ),
            )

        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                side_effect=side_effect
            )
            # First call: probe + real = 2 requests
            await provider.complete(
                [Message(role=Role.USER, content="test")],
                tools=[SEARCH_TOOL],
            )
            # Second call: no probe = 1 request
            await provider.complete(
                [Message(role=Role.USER, content="test2")],
                tools=[SEARCH_TOOL],
            )
        assert call_count == 3  # 1 probe + 2 real
        assert provider._supports_tool_calling is True

    async def test_transient_failure_retries_on_next_call(self, provider):
        """Transient detection failure leaves _supports_tool_calling as None, retries."""
        call_count = 0

        def side_effect(request):
            nonlocal call_count
            call_count += 1
            if self._is_probe_request(request):
                if call_count == 1:
                    # First probe: transient failure
                    return httpx.Response(503, json={"error": "cold start"})
                # Second probe: success
                return httpx.Response(
                    200, json=_probe_response_with_tool_calls()
                )
            # Real request (fallback or native)
            return httpx.Response(200, json=_ok_response())

        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                side_effect=side_effect
            )
            # First call: probe fails (transient) + real (fallback) = 2
            await provider.complete(
                [Message(role=Role.USER, content="test")],
                tools=[SEARCH_TOOL],
            )
            assert provider._supports_tool_calling is None  # NOT cached
            # Second call: probe succeeds + real (native) = 2
            await provider.complete(
                [Message(role=Role.USER, content="test2")],
                tools=[SEARCH_TOOL],
            )
        assert provider._supports_tool_calling is True  # NOW cached
        assert call_count == 4  # 2 probes + 2 real
# --- Prompt-based fallback ---
class TestSelfHostedPromptFallback:
    """Prompt-injected tool calling used when native tool calling is off."""

    @pytest.fixture
    def provider(self, monkeypatch):
        """Provide a SelfHostedProvider forced into prompt-based fallback mode.

        Registered as a fixture so the ``provider`` argument of the tests
        below resolves to this instance, with ``_supports_tool_calling``
        preset to False.
        """
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        config = AppConfig(provider=ProviderConfig(default="selfhosted"))
        p = SelfHostedProvider(config)
        p._supports_tool_calling = False  # Force fallback mode
        return p

    async def test_fallback_parses_tool_call_from_text(self, provider):
        """When tool calling is unsupported, parse tool calls from model text."""
        tool_json = json.dumps(
            {"tool_calls": [{"name": "search_documents", "arguments": {"query": "path params"}}]}
        )
        mock_response = _ok_response(content=tool_json)
        with respx.mock:
            route = respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=mock_response)
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="search for path params")],
                tools=[SEARCH_TOOL],
            )
            # Inspect recorded calls while the mock router is still active.
            # Verify tools NOT in payload (prompt-based, not native).
            sent_body = json.loads(route.calls[0].request.content)
            assert "tools" not in sent_body
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search_documents"
        assert response.tool_calls[0].arguments == {"query": "path params"}
        assert response.content == ""  # tool call replaces content

    async def test_fallback_injects_tool_prompt(self, provider):
        """When tool calling is unsupported, tool descriptions injected as system prompt."""
        mock_response = _ok_response(content="Just a text answer.")
        with respx.mock:
            route = respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=mock_response)
            )
            await provider.complete(
                [Message(role=Role.USER, content="hello")],
                tools=[SEARCH_TOOL],
            )
            # Inspect recorded calls while the mock router is still active.
            sent_body = json.loads(route.calls[0].request.content)
        # System message should contain tool descriptions
        system_msg = sent_body["messages"][0]
        assert system_msg["role"] == "system"
        assert "search_documents" in system_msg["content"]
        assert "tool_calls" in system_msg["content"]

    async def test_fallback_handles_non_dict_arguments(self, provider):
        """Non-dict arguments in prompt-based JSON degrades to empty dict, not crash."""
        tool_json = json.dumps(
            {"tool_calls": [{"name": "search_documents", "arguments": "oops"}]}
        )
        mock_response = _ok_response(content=tool_json)
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=mock_response)
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="test")],
                tools=[SEARCH_TOOL],
            )
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search_documents"
        assert response.tool_calls[0].arguments == {}

    async def test_fallback_returns_text_when_no_tool_json(self, provider):
        """When model responds with plain text (not JSON), return as content."""
        mock_response = _ok_response(content="I don't know how to use tools.")
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(200, json=mock_response)
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="test")],
                tools=[SEARCH_TOOL],
            )
        assert response.tool_calls == []
        assert response.content == "I don't know how to use tools."
# --- Retry and timeout ---
class TestSelfHostedRetryAndTimeout:
    """Retry on 429 and error mapping for timeouts."""

    @pytest.fixture
    def provider(self, monkeypatch):
        """Provide a SelfHostedProvider with a short retry schedule.

        Registered as a fixture so the ``provider`` argument of the tests
        below resolves to this instance. It is configured with 2 retries and
        tiny delays so retry tests finish quickly.
        """
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        config = AppConfig(
            provider=ProviderConfig(default="selfhosted"),
            retry=RetryConfig(max_retries=2, base_delay=0.01, max_delay=0.05),
        )
        return SelfHostedProvider(config)

    async def test_retries_on_429_then_succeeds(self, provider):
        """Provider retries on 429 and succeeds on next attempt."""
        call_count = 0

        def side_effect(request):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return httpx.Response(429, json={"error": "rate limited"})
            return httpx.Response(200, json=_ok_response())

        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                side_effect=side_effect
            )
            response = await provider.complete(
                [Message(role=Role.USER, content="test")]
            )
        assert response.content == "ok"
        assert call_count == 2

    async def test_raises_rate_limit_after_exhausting_retries(self, provider):
        """Provider raises ProviderRateLimitError after all retries exhausted."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(429, json={"error": "rate limited"})
            )
            with pytest.raises(ProviderRateLimitError, match="Rate limited"):
                await provider.complete(
                    [Message(role=Role.USER, content="test")]
                )

    async def test_raises_timeout_error(self, provider):
        """Provider raises ProviderTimeoutError on httpx timeout."""
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                side_effect=httpx.ReadTimeout("timed out")
            )
            with pytest.raises(ProviderTimeoutError, match="timed out"):
                await provider.complete(
                    [Message(role=Role.USER, content="test")]
                )
# --- Env var fallback ---
class TestSelfHostedEnvVars:
    """MODAL_* environment variables used when config leaves fields unset."""

    @staticmethod
    def _default_provider():
        """Build a SelfHostedProvider from an otherwise-default selfhosted config."""
        return SelfHostedProvider(
            AppConfig(provider=ProviderConfig(default="selfhosted"))
        )

    def test_reads_base_url_from_env(self, monkeypatch):
        """MODAL_VLLM_URL supplies the base URL."""
        monkeypatch.setenv("MODAL_VLLM_URL", "http://my-modal-url:8000/v1")
        assert self._default_provider().base_url == "http://my-modal-url:8000/v1"

    def test_reads_auth_token_from_env(self, monkeypatch):
        """MODAL_AUTH_TOKEN is sent as a Bearer Authorization header."""
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        monkeypatch.setenv("MODAL_AUTH_TOKEN", "secret-token-123")
        headers = self._default_provider().client.headers
        assert headers.get("authorization") == "Bearer secret-token-123"

    def test_no_auth_header_when_no_token(self, monkeypatch):
        """Without MODAL_AUTH_TOKEN, no Authorization header is set at all."""
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        monkeypatch.delenv("MODAL_AUTH_TOKEN", raising=False)
        header_names = self._default_provider().client.headers.keys()
        assert not any(name.lower() == "authorization" for name in header_names)
# --- Streaming ---
class TestSelfHostedStream:
    """SSE streaming via ``stream_complete()``."""

    @pytest.fixture
    def provider(self, monkeypatch):
        """Provide a SelfHostedProvider whose base URL comes from MODAL_VLLM_URL.

        Registered as a fixture so the ``provider`` argument of the test
        below resolves to this instance.
        """
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        config = AppConfig(provider=ProviderConfig(default="selfhosted"))
        return SelfHostedProvider(config)

    async def test_stream_yields_content_chunks(self, provider):
        """stream_complete() yields text chunks from SSE stream."""
        sse_body = (
            'data: {"choices":[{"delta":{"content":"Hello "}}]}\n\n'
            'data: {"choices":[{"delta":{"content":"world"}}]}\n\n'
            "data: [DONE]\n\n"
        )
        with respx.mock:
            respx.post(f"{FAKE_URL}/chat/completions").mock(
                return_value=httpx.Response(
                    200,
                    stream=httpx.ByteStream(sse_body.encode()),
                    headers={"content-type": "text/event-stream"},
                )
            )
            chunks = [
                chunk
                async for chunk in provider.stream_complete(
                    [Message(role=Role.USER, content="Hi")]
                )
            ]
        assert chunks == ["Hello ", "world"]
# --- format_tools ---
class TestSelfHostedFormatTools:
    """Shape of the tool definitions produced by ``format_tools()``."""

    def test_format_tools_uses_openai_schema(self, monkeypatch):
        """Each tool is wrapped as ``{"type": "function", "function": {...}}``."""
        monkeypatch.setenv("MODAL_VLLM_URL", FAKE_URL)
        provider = SelfHostedProvider(
            AppConfig(provider=ProviderConfig(default="selfhosted"))
        )
        schema = {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        }
        tool = ToolDefinition(
            name="search_documents",
            description="Search docs",
            parameters=schema,
        )
        first = provider.format_tools([tool])[0]
        assert first["type"] == "function"
        function = first["function"]
        assert function["name"] == "search_documents"
        assert function["parameters"]["required"] == ["query"]