mashrur950 commited on
Commit
a2d87ed
Β·
1 Parent(s): 62ca1b3

fix authetication

Browse files
Files changed (2) hide show
  1. app.py +80 -11
  2. server.py +102 -5
app.py CHANGED
@@ -99,11 +99,54 @@ if __name__ == "__main__":
99
  from starlette.requests import Request
100
  from database.api_keys import generate_api_key as db_generate_api_key
101
 
102
- # NOTE: Custom middleware for FastMCP has known limitations (GitHub issue #817)
103
- # Headers/state set in middleware are NOT accessible inside FastMCP tools.
104
- # For local development, use SKIP_AUTH=true with ENV=development in .env
105
- # For production, API key validation happens via extract_api_key_from_request() in server.py
106
- logger.info("[Auth] Using FastMCP built-in request handling for authentication")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  @mcp.custom_route("/", methods=["GET"])
109
  async def landing_page(request):
@@ -430,12 +473,38 @@ if __name__ == "__main__":
430
  logger.info("[OK] API key generation page added at /generate-key")
431
  logger.info("[OK] MCP SSE endpoint available at /sse")
432
 
433
- # Run MCP server with SSE transport (includes both /sse and custom routes)
434
- mcp.run(
435
- transport="sse",
436
- host=HF_SPACE_HOST,
437
- port=HF_SPACE_PORT
438
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
  except Exception as e:
441
  logger.error(f"Failed to start server: {e}")
 
99
  from starlette.requests import Request
100
  from database.api_keys import generate_api_key as db_generate_api_key
101
 
102
+ # Import session auth store from server.py
103
+ from server import store_session_api_key, _session_auth_store
104
+
105
+ # =====================================================================
106
+ # API KEY CAPTURE MIDDLEWARE (Starlette)
107
+ # This simple middleware captures the API key from SSE connections
108
+ # and stores it for later lookup by session_id during tool calls.
109
+ #
110
+ # Flow:
111
+ # 1. SSE connects: /sse?api_key=xxx β†’ Store api_key temporarily
112
+ # 2. Messages come: /messages/?session_id=yyy β†’ Link session to api_key
113
+ # 3. FastMCP middleware retrieves api_key from session store
114
+ # =====================================================================
115
+ from starlette.middleware.base import BaseHTTPMiddleware
116
+
117
+ class ApiKeyCaptureMiddleware(BaseHTTPMiddleware):
118
+ """Captures API key from SSE and links it to session_id"""
119
+
120
+ async def dispatch(self, request: Request, call_next):
121
+ api_key = request.query_params.get('api_key')
122
+ session_id = request.query_params.get('session_id')
123
+ path = request.url.path
124
+
125
+ # SSE connection with api_key - store for later linking
126
+ if api_key and path == '/sse':
127
+ _session_auth_store['_pending_api_key'] = api_key
128
+ logger.info(f"[ApiKeyCapture] SSE with api_key {api_key[:15]}...")
129
+
130
+ # Messages request - link session to api_key
131
+ if session_id and path.startswith('/messages'):
132
+ # Check if session already has api_key
133
+ if session_id not in _session_auth_store:
134
+ # Get from URL if present
135
+ if api_key:
136
+ store_session_api_key(session_id, api_key)
137
+ logger.info(f"[ApiKeyCapture] Session {session_id[:12]}... linked via URL")
138
+ # Otherwise use pending key from SSE connection
139
+ elif '_pending_api_key' in _session_auth_store:
140
+ pending = _session_auth_store.pop('_pending_api_key')
141
+ store_session_api_key(session_id, pending)
142
+ logger.info(f"[ApiKeyCapture] Session {session_id[:12]}... linked via pending")
143
+
144
+ return await call_next(request)
145
+
146
+ # Add the capture middleware - must be added before mcp.run()
147
+ # We'll add it via mcp.settings or directly to the ASGI app
148
+ logger.info("[Auth] ApiKeyCaptureMiddleware configured")
149
+ logger.info("[Auth] FastMCP ApiKeyAuthMiddleware handles authentication in tools")
150
 
151
  @mcp.custom_route("/", methods=["GET"])
152
  async def landing_page(request):
 
473
  logger.info("[OK] API key generation page added at /generate-key")
474
  logger.info("[OK] MCP SSE endpoint available at /sse")
475
 
476
+ # Try to use SSE app with custom middleware, fallback to standard mcp.run()
477
+ try:
478
+ from starlette.middleware import Middleware
479
+ import uvicorn
480
+
481
+ # Create middleware list
482
+ custom_middleware = [
483
+ Middleware(ApiKeyCaptureMiddleware)
484
+ ]
485
+
486
+ # Get SSE app with custom middleware
487
+ sse_app = mcp.sse_app(custom_middleware=custom_middleware)
488
+ logger.info("[Auth] Running with ApiKeyCaptureMiddleware via sse_app()")
489
+
490
+ # Run with uvicorn
491
+ uvicorn.run(
492
+ sse_app,
493
+ host=HF_SPACE_HOST,
494
+ port=HF_SPACE_PORT,
495
+ log_level="info"
496
+ )
497
+ except (AttributeError, TypeError) as e:
498
+ # Fallback: FastMCP version doesn't support custom_middleware
499
+ logger.warning(f"[Auth] custom_middleware not supported: {e}")
500
+ logger.warning("[Auth] Falling back to standard mcp.run()")
501
+
502
+ # Run MCP server with SSE transport (includes both /sse and custom routes)
503
+ mcp.run(
504
+ transport="sse",
505
+ host=HF_SPACE_HOST,
506
+ port=HF_SPACE_PORT
507
+ )
508
 
509
  except Exception as e:
510
  logger.error(f"Failed to start server: {e}")
server.py CHANGED
@@ -115,6 +115,81 @@ mcp = FastMCP(
115
  version="1.0.0"
116
  )
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # Initialize shared services
119
  logger.info("Initializing FleetMind MCP Server...")
120
  geocoding_service = GeocodingService()
@@ -131,18 +206,40 @@ except Exception as e:
131
  # AUTHENTICATION - API KEY SYSTEM
132
  # ============================================================================
133
 
134
- def get_authenticated_user():
135
  """
136
- Get authenticated user by validating API key from request context.
 
 
 
 
 
 
 
137
 
138
  Returns:
139
  User info dict with user_id, email, scopes, name or None if not authenticated
140
  """
141
  try:
142
- # METHOD 1: Extract API key from current HTTP request
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  api_key = extract_api_key_from_request()
144
 
145
- # METHOD 2: Fallback to environment variable for testing
146
  if not api_key:
147
  api_key = os.getenv("FLEETMIND_API_KEY")
148
  if api_key:
@@ -159,7 +256,7 @@ def get_authenticated_user():
159
  logger.warning(f"❌ Invalid API key: {api_key[:10]}...")
160
  return None
161
 
162
- # METHOD 3: Development bypass mode (local testing only)
163
  # SECURITY: Only allow SKIP_AUTH in development environments
164
  # Check both ENV and ENVIRONMENT variables for compatibility
165
  env = os.getenv("ENV") or os.getenv("ENVIRONMENT", "production")
 
115
  version="1.0.0"
116
  )
117
 
118
+ # ============================================================================
119
+ # FASTMCP AUTHENTICATION MIDDLEWARE
120
+ # Uses FastMCP's native middleware system to properly pass user context to tools
121
+ # Reference: https://gelembjuk.com/blog/post/authentication-remote-mcp-server-python/
122
+ # ============================================================================
123
+ try:
124
+ from fastmcp.server.middleware import Middleware, MiddlewareContext
125
+ from fastmcp.exceptions import ToolError
126
+
127
+ class ApiKeyAuthMiddleware(Middleware):
128
+ """
129
+ FastMCP middleware for API key authentication.
130
+ Validates API key from request query params and injects user info into context.
131
+ """
132
+
133
+ async def on_call_tool(self, context: MiddlewareContext, call_next):
134
+ """Intercept tool calls to validate authentication"""
135
+ try:
136
+ from fastmcp.server.dependencies import get_http_request
137
+ request = get_http_request()
138
+
139
+ # Get API key from query params
140
+ api_key = request.query_params.get('api_key')
141
+
142
+ # Also check session store if no direct api_key
143
+ if not api_key:
144
+ session_id = request.query_params.get('session_id')
145
+ if session_id:
146
+ api_key = get_api_key_from_session(session_id)
147
+
148
+ if api_key:
149
+ # Validate API key and get user info
150
+ from database.api_keys import verify_api_key
151
+ user_info = verify_api_key(api_key)
152
+
153
+ if user_info:
154
+ # Store user info in context for tools to access
155
+ context.fastmcp_context.set_state("user_id", user_info['user_id'])
156
+ context.fastmcp_context.set_state("user_email", user_info['email'])
157
+ context.fastmcp_context.set_state("user_scopes", user_info.get('scopes', []))
158
+ context.fastmcp_context.set_state("user_name", user_info.get('name', ''))
159
+ logger.info(f"βœ… Auth middleware: User {user_info['email']} authenticated")
160
+ return await call_next(context)
161
+ else:
162
+ logger.warning(f"❌ Auth middleware: Invalid API key {api_key[:10]}...")
163
+
164
+ # Check SKIP_AUTH for development
165
+ env = os.getenv("ENV") or os.getenv("ENVIRONMENT", "production")
166
+ skip_auth = os.getenv("SKIP_AUTH", "false").lower() == "true"
167
+
168
+ if skip_auth and env.lower() != "production":
169
+ # Development mode - use dev user
170
+ context.fastmcp_context.set_state("user_id", "dev-user")
171
+ context.fastmcp_context.set_state("user_email", "[email protected]")
172
+ context.fastmcp_context.set_state("user_scopes", ["admin"])
173
+ context.fastmcp_context.set_state("user_name", "Development User")
174
+ logger.warning(f"⚠️ Auth middleware: SKIP_AUTH enabled, using dev user")
175
+ return await call_next(context)
176
+
177
+ # No valid authentication
178
+ raise ToolError("Authentication required. Please provide a valid API key.")
179
+
180
+ except ImportError:
181
+ # get_http_request not available (stdio transport)
182
+ logger.debug("Auth middleware: No HTTP request available")
183
+ return await call_next(context)
184
+
185
+ # Register the middleware with FastMCP
186
+ mcp.add_middleware(ApiKeyAuthMiddleware())
187
+ logger.info("[Auth] FastMCP ApiKeyAuthMiddleware registered")
188
+
189
+ except ImportError as e:
190
+ logger.warning(f"[Auth] Could not import FastMCP middleware: {e}")
191
+ logger.warning("[Auth] Falling back to per-tool authentication")
192
+
193
  # Initialize shared services
194
  logger.info("Initializing FleetMind MCP Server...")
195
  geocoding_service = GeocodingService()
 
206
  # AUTHENTICATION - API KEY SYSTEM
207
  # ============================================================================
208
 
209
+ def get_authenticated_user(ctx=None):
210
  """
211
+ Get authenticated user from multiple sources:
212
+ 1. FastMCP Context state (set by middleware) - PREFERRED
213
+ 2. HTTP request API key extraction (fallback)
214
+ 3. Environment variable (fallback for testing)
215
+ 4. SKIP_AUTH bypass (development only)
216
+
217
+ Args:
218
+ ctx: Optional FastMCP Context object from tool function
219
 
220
  Returns:
221
  User info dict with user_id, email, scopes, name or None if not authenticated
222
  """
223
  try:
224
+ # METHOD 1: Get user from FastMCP Context state (set by middleware)
225
+ # This is the preferred method when using FastMCP middleware
226
+ if ctx is not None:
227
+ try:
228
+ user_id = ctx.get_state("user_id")
229
+ if user_id:
230
+ return {
231
+ 'user_id': user_id,
232
+ 'email': ctx.get_state("user_email") or "",
233
+ 'scopes': ctx.get_state("user_scopes") or [],
234
+ 'name': ctx.get_state("user_name") or ""
235
+ }
236
+ except Exception:
237
+ pass # Context state not available, try other methods
238
+
239
+ # METHOD 2: Extract API key from current HTTP request
240
  api_key = extract_api_key_from_request()
241
 
242
+ # METHOD 3: Fallback to environment variable for testing
243
  if not api_key:
244
  api_key = os.getenv("FLEETMIND_API_KEY")
245
  if api_key:
 
256
  logger.warning(f"❌ Invalid API key: {api_key[:10]}...")
257
  return None
258
 
259
+ # METHOD 4: Development bypass mode (local testing only)
260
  # SECURITY: Only allow SKIP_AUTH in development environments
261
  # Check both ENV and ENVIRONMENT variables for compatibility
262
  env = os.getenv("ENV") or os.getenv("ENVIRONMENT", "production")