Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Simple MCP client example with OAuth authentication support. | |
| This client connects to an MCP server using streamable HTTP transport with OAuth. | |
| """ | |
| import asyncio | |
| import os | |
| import threading | |
| import time | |
| import webbrowser | |
| from datetime import timedelta | |
| from http.server import BaseHTTPRequestHandler, HTTPServer | |
| from typing import Any | |
| from urllib.parse import parse_qs, urlparse | |
| from mcp.client.auth import OAuthClientProvider, TokenStorage | |
| from mcp.client.session import ClientSession | |
| from mcp.client.sse import sse_client | |
| from mcp.client.streamable_http import streamablehttp_client | |
| from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken | |
class InMemoryTokenStorage(TokenStorage):
    """Token storage backed by plain instance attributes.

    Nothing is written anywhere else: the stored values live only as long
    as this object does.
    """

    def __init__(self) -> None:
        # Both slots stay None until the matching setter is called.
        self._client_info: OAuthClientInformationFull | None = None
        self._tokens: OAuthToken | None = None

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the stored client information, or None if none was set."""
        return self._client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store *client_info*, replacing any previous value."""
        self._client_info = client_info

    async def get_tokens(self) -> OAuthToken | None:
        """Return the stored tokens, or None if none were set."""
        return self._tokens

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Store *tokens*, replacing any previous value."""
        self._tokens = tokens
class CallbackHandler(BaseHTTPRequestHandler):
    """HTTP request handler that captures the OAuth redirect callback.

    Results are written into the ``callback_data`` mapping supplied at
    construction time, under the keys ``"authorization_code"``, ``"state"``
    and ``"error"``.
    """

    # Static page returned once an authorization code has been received.
    _SUCCESS_PAGE = (
        b"\n"
        b"<html>\n"
        b"<body>\n"
        b"<h1>Authorization Successful!</h1>\n"
        b"<p>You can close this window and return to the terminal.</p>\n"
        b"<script>setTimeout(() => window.close(), 2000);</script>\n"
        b"</body>\n"
        b"</html>\n"
    )

    def __init__(self, request, client_address, server, callback_data):
        """Attach the shared result mapping, then run the base constructor.

        Args:
            request: Passed through to ``BaseHTTPRequestHandler``.
            client_address: Passed through to ``BaseHTTPRequestHandler``.
            server: Passed through to ``BaseHTTPRequestHandler``.
            callback_data: Mutable mapping that ``do_GET`` fills in.
        """
        # Must be assigned first: the socketserver base constructor handles
        # the request (and so may call do_GET) before it returns.
        self.callback_data = callback_data
        super().__init__(request, client_address, server)

    def _send_html(self, status, body):
        """Send *status*, an HTML content-type header and *body* (bytes)."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle the GET request issued by the OAuth redirect.

        * ``code`` present: record the code and the optional ``state``,
          reply 200 with a success page.
        * otherwise ``error`` present: record the error, reply 400 with a
          page that shows the (HTML-escaped) error text.
        * anything else: reply 404 with no body.
        """
        # Function-scope import so this block does not rely on a new
        # file-level import.
        from html import escape

        query_params = parse_qs(urlparse(self.path).query)

        if "code" in query_params:
            self.callback_data["authorization_code"] = query_params["code"][0]
            self.callback_data["state"] = query_params.get("state", [None])[0]
            self._send_html(200, self._SUCCESS_PAGE)
        elif "error" in query_params:
            error = query_params["error"][0]
            # The raw value is kept for the caller; only the HTML rendering
            # is escaped.
            self.callback_data["error"] = error
            # The value comes straight from the request URL, so it must be
            # escaped before being embedded in markup.
            page = (
                "\n"
                "<html>\n"
                "<body>\n"
                "<h1>Authorization Failed</h1>\n"
                f"<p>Error: {escape(error)}</p>\n"
                "<p>You can close this window and return to the terminal.</p>\n"
                "</body>\n"
                "</html>\n"
            )
            self._send_html(400, page.encode())
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Suppress default logging.

        The parameter name ``format`` mirrors the base-class signature.
        """
        pass
class CallbackServer:
    """Background HTTP server that receives the OAuth callback.

    The server runs ``HTTPServer.serve_forever`` on a daemon thread and
    shares the ``callback_data`` dict with its request handler, which fills
    in ``"authorization_code"``, ``"state"`` and ``"error"``.
    """

    def __init__(self, port=3000, host="0.0.0.0"):
        """Prepare (but do not start) the server.

        Args:
            port: TCP port to listen on.
            host: Address to bind.
                NOTE(review): the default listens on every interface; pass
                ``"127.0.0.1"`` to restrict the listener to loopback.
        """
        self.host = host
        self.port = port
        self.server = None
        self.thread = None
        self.callback_data = {"authorization_code": None, "state": None, "error": None}

    def _create_handler_with_data(self):
        """Create a handler class with access to callback data.

        ``HTTPServer`` instantiates its handler with three positional
        arguments, so the extra ``callback_data`` argument is supplied by a
        subclass that closes over it.
        """
        callback_data = self.callback_data

        class DataCallbackHandler(CallbackHandler):
            def __init__(self, request, client_address, server):
                super().__init__(request, client_address, server, callback_data)

        return DataCallbackHandler

    def start(self):
        """Start the callback server in a background thread."""
        handler_class = self._create_handler_with_data()
        self.server = HTTPServer((self.host, self.port), handler_class)
        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.thread.start()
        print(f"π₯οΈ Started callback server on http://{self.host}:{self.port}")

    def stop(self):
        """Stop the callback server.

        Safe to call when the server was never started and safe to call
        more than once: the handles are cleared before shutting down, so
        later calls find nothing to do.
        """
        server, self.server = self.server, None
        thread, self.thread = self.thread, None
        if server is not None:
            server.shutdown()
            server.server_close()
        if thread is not None:
            thread.join(timeout=1)

    def wait_for_callback(self, timeout=300):
        """Block until the callback arrives, polling every 0.1 seconds.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            The received authorization code.

        Raises:
            RuntimeError: The callback carried an ``error`` parameter.
            TimeoutError: No callback arrived within *timeout* seconds.
        """
        # A monotonic clock keeps the timeout correct even if the system
        # wall clock is adjusted while waiting.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.callback_data["authorization_code"]:
                return self.callback_data["authorization_code"]
            if self.callback_data["error"]:
                raise RuntimeError(f"OAuth error: {self.callback_data['error']}")
            time.sleep(0.1)
        raise TimeoutError("Timeout waiting for OAuth callback")

    def get_state(self):
        """Get the received state parameter."""
        return self.callback_data["state"]
| class SimpleAuthClient: | |
| """Simple MCP client with auth support.""" | |
| def __init__(self, server_url: str, transport_type: str = "streamable_http"): | |
| self.server_url = server_url | |
| self.transport_type = transport_type | |
| self.session: ClientSession | None = None | |
| async def connect(self): | |
| """Connect to the MCP server.""" | |
| print(f"π Attempting to connect to {self.server_url}...") | |
| try: | |
| callback_server = CallbackServer(port=3030) | |
| callback_server.start() | |
| async def callback_handler() -> tuple[str, str | None]: | |
| """Wait for OAuth callback and return auth code and state.""" | |
| print("β³ Waiting for authorization callback...") | |
| try: | |
| auth_code = callback_server.wait_for_callback(timeout=300) | |
| return auth_code, callback_server.get_state() | |
| finally: | |
| callback_server.stop() | |
| client_metadata_dict = { | |
| "client_name": "Simple Auth Client", | |
| "redirect_uris": ["http://localhost:3030/callback"], | |
| "grant_types": ["authorization_code", "refresh_token"], | |
| "response_types": ["code"], | |
| "token_endpoint_auth_method": "client_secret_post", | |
| } | |
| async def _default_redirect_handler(authorization_url: str) -> None: | |
| """Default redirect handler that opens the URL in a browser.""" | |
| print(f"Opening browser for authorization: {authorization_url}") | |
| webbrowser.open(authorization_url) | |
| # Create OAuth authentication handler using the new interface | |
| oauth_auth = OAuthClientProvider( | |
| server_url=self.server_url.replace("/mcp", ""), | |
| client_metadata=OAuthClientMetadata.model_validate( | |
| client_metadata_dict | |
| ), | |
| storage=InMemoryTokenStorage(), | |
| redirect_handler=_default_redirect_handler, | |
| callback_handler=callback_handler, | |
| ) | |
| # Create transport with auth handler based on transport type | |
| if self.transport_type == "sse": | |
| print("π‘ Opening SSE transport connection with auth...") | |
| async with sse_client( | |
| url=self.server_url, | |
| auth=oauth_auth, | |
| timeout=60, | |
| ) as (read_stream, write_stream): | |
| await self._run_session(read_stream, write_stream, None) | |
| else: | |
| print("π‘ Opening StreamableHTTP transport connection with auth...") | |
| async with streamablehttp_client( | |
| url=self.server_url, | |
| auth=oauth_auth, | |
| timeout=timedelta(seconds=60), | |
| ) as (read_stream, write_stream, get_session_id): | |
| await self._run_session(read_stream, write_stream, get_session_id) | |
| except Exception as e: | |
| print(f"β Failed to connect: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| async def _run_session(self, read_stream, write_stream, get_session_id): | |
| """Run the MCP session with the given streams.""" | |
| print("π€ Initializing MCP session...") | |
| async with ClientSession(read_stream, write_stream) as session: | |
| self.session = session | |
| print("β‘ Starting session initialization...") | |
| await session.initialize() | |
| print("β¨ Session initialization complete!") | |
| print(f"\nβ Connected to MCP server at {self.server_url}") | |
| if get_session_id: | |
| session_id = get_session_id() | |
| if session_id: | |
| print(f"Session ID: {session_id}") | |
| # Run interactive loop | |
| await self.interactive_loop() | |
| async def list_tools(self): | |
| """List available tools from the server.""" | |
| if not self.session: | |
| print("β Not connected to server") | |
| return | |
| try: | |
| result = await self.session.list_tools() | |
| if hasattr(result, "tools") and result.tools: | |
| print("\nπ Available tools:") | |
| for i, tool in enumerate(result.tools, 1): | |
| print(f"{i}. {tool.name}") | |
| if tool.description: | |
| print(f" Description: {tool.description}") | |
| print() | |
| else: | |
| print("No tools available") | |
| except Exception as e: | |
| print(f"β Failed to list tools: {e}") | |
| async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): | |
| """Call a specific tool.""" | |
| if not self.session: | |
| print("β Not connected to server") | |
| return | |
| try: | |
| result = await self.session.call_tool(tool_name, arguments or {}) | |
| print(f"\nπ§ Tool '{tool_name}' result:") | |
| if hasattr(result, "content"): | |
| for content in result.content: | |
| if content.type == "text": | |
| print(content.text) | |
| else: | |
| print(content) | |
| else: | |
| print(result) | |
| except Exception as e: | |
| print(f"β Failed to call tool '{tool_name}': {e}") | |
| async def interactive_loop(self): | |
| """Run interactive command loop.""" | |
| print("\nπ― Interactive MCP Client") | |
| print("Commands:") | |
| print(" list - List available tools") | |
| print(" call <tool_name> [args] - Call a tool") | |
| print(" quit - Exit the client") | |
| print() | |
| while True: | |
| try: | |
| command = input("mcp> ").strip() | |
| if not command: | |
| continue | |
| if command == "quit": | |
| break | |
| elif command == "list": | |
| await self.list_tools() | |
| elif command.startswith("call "): | |
| parts = command.split(maxsplit=2) | |
| tool_name = parts[1] if len(parts) > 1 else "" | |
| if not tool_name: | |
| print("β Please specify a tool name") | |
| continue | |
| # Parse arguments (simple JSON-like format) | |
| arguments = {} | |
| if len(parts) > 2: | |
| import json | |
| try: | |
| arguments = json.loads(parts[2]) | |
| except json.JSONDecodeError: | |
| print("β Invalid arguments format (expected JSON)") | |
| continue | |
| await self.call_tool(tool_name, arguments) | |
| else: | |
| print( | |
| "β Unknown command. Try 'list', 'call <tool_name>', or 'quit'" | |
| ) | |
| except KeyboardInterrupt: | |
| print("\n\nπ Goodbye!") | |
| break | |
| except EOFError: | |
| break | |
async def main():
    """Build the server URL from the environment and run the client.

    Environment variables read (each has a default):
    ``MCP_SERVER_HOST``, ``MCP_SERVER_PORT`` and ``MCP_TRANSPORT_TYPE``.
    """
    host_server = os.getenv("MCP_SERVER_HOST", "applemuncy-rs.hf.space")
    server_port = os.getenv("MCP_SERVER_PORT", 443)
    transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http")

    if transport_type == "streamable_http":
        # Host-based HTTPS URL with the /mcp endpoint; the port setting is
        # not used for this transport.
        server_url = f"https://{host_server}/mcp"
    else:
        # Any other transport value targets a local SSE endpoint on the
        # configured port.
        server_url = f"http://localhost:{server_port}/sse"

    print("π Simple MCP Auth Client")
    print(f"Connecting to: {host_server}")
    print(f"Connecting to: {server_url}")
    print(f"Transport type: {transport_type}")

    # OAuth is handled by the client during the connection flow.
    auth_client = SimpleAuthClient(server_url, transport_type)
    await auth_client.connect()
def cli():
    """Synchronous entry point: run ``main`` to completion via asyncio."""
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
# Run the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli()