| """Rate limiting implementation for Gradio application. | |
| Implements token bucket algorithm with multiple strategies: | |
| - ThreadSafeTokenBucket: In-memory, thread-safe rate limiting | |
| - AsyncTokenBucket: Async-compatible token bucket | |
| - HybridRateLimiter: Redis primary with memory fallback | |
| - TieredRateLimiter: Multi-tier rate limiting (anonymous, authenticated, premium) | |
| - GradioRateLimitMiddleware: Gradio-specific middleware for extracting client info | |
| """ | |
| import time | |
| import asyncio | |
| import logging | |
| import hashlib | |
| from typing import Optional, Dict, Any, Tuple, TYPE_CHECKING | |
| from dataclasses import dataclass | |
| from threading import Lock | |
| from enum import Enum | |
| if TYPE_CHECKING: | |
| import gradio as gr | |
| logger = logging.getLogger(__name__) | |
| class UserTier(str, Enum): | |
| """User tier enumeration for tiered rate limiting.""" | |
| ANONYMOUS = "anonymous" | |
| AUTHENTICATED = "authenticated" | |
| PREMIUM = "premium" | |


@dataclass
class RateLimitInfo:
    """Rate limit information for user feedback.

    Attributes:
        allowed: Whether the request is allowed
        remaining: Remaining requests in current window
        reset_time: Unix timestamp when the limit resets
        retry_after: Seconds until next request is allowed
    """

    allowed: bool
    remaining: int
    reset_time: float
    retry_after: Optional[float] = None


class ThreadSafeTokenBucket:
    """Thread-safe token bucket rate limiter using in-memory storage.

    Implements the token bucket algorithm with configurable capacity and refill rate.
    Thread-safe for concurrent access.

    Attributes:
        capacity: Maximum number of tokens in the bucket
        refill_rate: Number of tokens added per second
        tokens: Current number of available tokens
        last_refill: Timestamp of last refill
    """

    def __init__(self, capacity: int, refill_rate: float):
        """Initialise token bucket.

        Args:
            capacity: Maximum bucket capacity (max requests)
            refill_rate: Tokens added per second (requests per second)
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = float(capacity)
        self.last_refill = time.time()
        self._lock = Lock()

    def _refill(self) -> None:
        """Refill tokens based on elapsed time since last refill."""
        now = time.time()
        elapsed = now - self.last_refill
        # Add tokens based on elapsed time and refill rate
        tokens_to_add = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)
        self.last_refill = now

    def consume(self, tokens: int = 1) -> RateLimitInfo:
        """Attempt to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume

        Returns:
            RateLimitInfo with consumption result and metadata
        """
        with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                remaining = int(self.tokens)
                # Calculate reset time (when bucket will be full again)
                tokens_needed = self.capacity - self.tokens
                reset_time = time.time() + (tokens_needed / self.refill_rate)
                return RateLimitInfo(
                    allowed=True,
                    remaining=remaining,
                    reset_time=reset_time,
                    retry_after=None
                )
            else:
                # Calculate when enough tokens will be available
                tokens_needed = tokens - self.tokens
                retry_after = tokens_needed / self.refill_rate
                reset_time = time.time() + retry_after
                return RateLimitInfo(
                    allowed=False,
                    remaining=0,
                    reset_time=reset_time,
                    retry_after=retry_after
                )

    def reset(self) -> None:
        """Reset the bucket to full capacity."""
        with self._lock:
            self.tokens = float(self.capacity)
            self.last_refill = time.time()
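

# A minimal usage sketch for ThreadSafeTokenBucket (illustrative only, not part
# of the module API): a bucket of 5 requests refilling at 1 token/second, with
# the caller inspecting RateLimitInfo to decide whether to proceed or back off.
#
#     bucket = ThreadSafeTokenBucket(capacity=5, refill_rate=1.0)
#     info = bucket.consume()
#     if info.allowed:
#         pass  # handle the request; info.remaining tokens are left
#     else:
#         time.sleep(info.retry_after)  # wait until a token is available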


class AsyncTokenBucket:
    """Async-compatible token bucket rate limiter.

    Uses asyncio locks for async/await compatibility.
    Suitable for async Gradio handlers and coroutines.

    Attributes:
        capacity: Maximum number of tokens in the bucket
        refill_rate: Number of tokens added per second
        tokens: Current number of available tokens
        last_refill: Timestamp of last refill
    """

    def __init__(self, capacity: int, refill_rate: float):
        """Initialise async token bucket.

        Args:
            capacity: Maximum bucket capacity (max requests)
            refill_rate: Tokens added per second (requests per second)
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = float(capacity)
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    def _refill(self) -> None:
        """Refill tokens based on elapsed time since last refill."""
        now = time.time()
        elapsed = now - self.last_refill
        tokens_to_add = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)
        self.last_refill = now

    async def consume(self, tokens: int = 1) -> RateLimitInfo:
        """Attempt to consume tokens from the bucket (async).

        Args:
            tokens: Number of tokens to consume

        Returns:
            RateLimitInfo with consumption result and metadata
        """
        async with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                remaining = int(self.tokens)
                tokens_needed = self.capacity - self.tokens
                reset_time = time.time() + (tokens_needed / self.refill_rate)
                return RateLimitInfo(
                    allowed=True,
                    remaining=remaining,
                    reset_time=reset_time,
                    retry_after=None
                )
            else:
                tokens_needed = tokens - self.tokens
                retry_after = tokens_needed / self.refill_rate
                reset_time = time.time() + retry_after
                return RateLimitInfo(
                    allowed=False,
                    remaining=0,
                    reset_time=reset_time,
                    retry_after=retry_after
                )

    async def reset(self) -> None:
        """Reset the bucket to full capacity (async)."""
        async with self._lock:
            self.tokens = float(self.capacity)
            self.last_refill = time.time()
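

# A minimal async usage sketch (illustrative only): the same token-bucket flow
# as above, but awaited from inside a coroutine such as an async Gradio handler.
#
#     async def handler():
#         bucket = AsyncTokenBucket(capacity=5, refill_rate=1.0)
#         info = await bucket.consume()
#         if not info.allowed:
#             await asyncio.sleep(info.retry_after)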


class HybridRateLimiter:
    """Hybrid rate limiter with Redis primary and in-memory fallback.

    Uses Redis for distributed rate limiting when available, gracefully
    degrades to in-memory token buckets when Redis is unavailable.
    Implements atomic Redis operations using Lua scripts for consistency.

    Attributes:
        capacity: Maximum bucket capacity
        refill_rate: Tokens per second
        redis_client: Optional Redis client
        memory_buckets: In-memory fallback buckets
        use_redis: Whether Redis is available
    """

    # Lua script for atomic token bucket operations in Redis
    REDIS_CONSUME_SCRIPT = """
    local key = KEYS[1]
    local capacity = tonumber(ARGV[1])
    local refill_rate = tonumber(ARGV[2])
    local tokens_requested = tonumber(ARGV[3])
    local now = tonumber(ARGV[4])

    -- Get current state or initialise
    local data = redis.call('HMGET', key, 'tokens', 'last_refill')
    local tokens = tonumber(data[1]) or capacity
    local last_refill = tonumber(data[2]) or now

    -- Refill tokens based on elapsed time
    local elapsed = now - last_refill
    local tokens_to_add = elapsed * refill_rate
    tokens = math.min(capacity, tokens + tokens_to_add)

    -- Try to consume
    if tokens >= tokens_requested then
        tokens = tokens - tokens_requested
        redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
        redis.call('EXPIRE', key, math.ceil(capacity / refill_rate) * 2)
        local tokens_needed = capacity - tokens
        local reset_time = now + (tokens_needed / refill_rate)
        return {1, math.floor(tokens), reset_time, -1}
    else
        -- Not enough tokens
        local tokens_needed = tokens_requested - tokens
        local retry_after = tokens_needed / refill_rate
        local reset_time = now + retry_after
        redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
        redis.call('EXPIRE', key, math.ceil(capacity / refill_rate) * 2)
        return {0, 0, reset_time, retry_after}
    end
    """

    def __init__(
        self,
        capacity: int,
        refill_rate: float,
        redis_url: Optional[str] = None,
        key_prefix: str = "ratelimit"
    ):
        """Initialise hybrid rate limiter.

        Args:
            capacity: Maximum bucket capacity
            refill_rate: Tokens per second
            redis_url: Optional Redis connection URL
            key_prefix: Prefix for Redis keys
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.key_prefix = key_prefix

        # Memory fallback
        self.memory_buckets: Dict[str, ThreadSafeTokenBucket] = {}
        self._memory_lock = Lock()

        # Try to initialise Redis
        self.redis_client = None
        self.redis_script_sha = None
        self.use_redis = False

        if redis_url:
            try:
                import redis
                self.redis_client = redis.from_url(
                    redis_url,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2
                )
                # Test connection
                self.redis_client.ping()
                # Load Lua script
                self.redis_script_sha = self.redis_client.script_load(
                    self.REDIS_CONSUME_SCRIPT
                )
                self.use_redis = True
                logger.info("Redis rate limiter initialised successfully")
            except Exception as e:
                logger.warning(
                    f"Redis initialisation failed, using in-memory fallback: {e}"
                )
                self.redis_client = None
                self.use_redis = False

    def _get_memory_bucket(self, identifier: str) -> ThreadSafeTokenBucket:
        """Get or create in-memory bucket for identifier.

        Args:
            identifier: Client identifier (IP, user ID, etc.)

        Returns:
            ThreadSafeTokenBucket instance
        """
        with self._memory_lock:
            if identifier not in self.memory_buckets:
                self.memory_buckets[identifier] = ThreadSafeTokenBucket(
                    self.capacity,
                    self.refill_rate
                )
            return self.memory_buckets[identifier]

    def _consume_redis(self, identifier: str, tokens: int = 1) -> RateLimitInfo:
        """Consume tokens using Redis.

        Args:
            identifier: Client identifier
            tokens: Number of tokens to consume

        Returns:
            RateLimitInfo with result
        """
        try:
            key = f"{self.key_prefix}:{identifier}"
            now = time.time()
            # Execute Lua script
            result = self.redis_client.evalsha(
                self.redis_script_sha,
                1,
                key,
                self.capacity,
                self.refill_rate,
                tokens,
                now
            )
            allowed = bool(result[0])
            remaining = int(result[1])
            reset_time = float(result[2])
            retry_after = float(result[3]) if result[3] != -1 else None
            return RateLimitInfo(
                allowed=allowed,
                remaining=remaining,
                reset_time=reset_time,
                retry_after=retry_after
            )
        except Exception as e:
            logger.error(f"Redis consume error, falling back to memory: {e}")
            # Fall back to memory on error
            self.use_redis = False
            return self._get_memory_bucket(identifier).consume(tokens)

    def consume(self, identifier: str, tokens: int = 1) -> RateLimitInfo:
        """Consume tokens for the given identifier.

        Args:
            identifier: Client identifier (IP, user ID, etc.)
            tokens: Number of tokens to consume

        Returns:
            RateLimitInfo with consumption result
        """
        if self.use_redis and self.redis_client:
            return self._consume_redis(identifier, tokens)
        else:
            return self._get_memory_bucket(identifier).consume(tokens)

    def reset(self, identifier: str) -> None:
        """Reset rate limit for identifier.

        Args:
            identifier: Client identifier to reset
        """
        if self.use_redis and self.redis_client:
            try:
                key = f"{self.key_prefix}:{identifier}"
                self.redis_client.delete(key)
            except Exception as e:
                logger.error(f"Redis reset error: {e}")
        # Also clear from memory
        with self._memory_lock:
            if identifier in self.memory_buckets:
                del self.memory_buckets[identifier]
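

# A minimal usage sketch for HybridRateLimiter (illustrative only): the Redis
# URL and identifier below are assumptions; if the connection fails, the
# limiter silently falls back to per-process in-memory buckets.
#
#     limiter = HybridRateLimiter(
#         capacity=30,
#         refill_rate=0.5,
#         redis_url="redis://localhost:6379/0",  # hypothetical URL
#     )
#     info = limiter.consume("203.0.113.7")  # e.g. a client IP
#     if not info.allowed:
#         print(f"Blocked; retry in {info.retry_after:.1f}s")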


class TieredRateLimiter:
    """Multi-tier rate limiter with different limits per user tier.

    Supports different rate limits for:
    - Anonymous users (IP-based)
    - Authenticated users (user ID-based)
    - Premium users (higher limits)

    Attributes:
        tier_limits: Dictionary mapping tiers to (capacity, refill_rate)
        limiters: Dictionary of HybridRateLimiters per tier
    """

    def __init__(
        self,
        tier_limits: Dict[UserTier, Tuple[int, float]],
        redis_url: Optional[str] = None
    ):
        """Initialise tiered rate limiter.

        Args:
            tier_limits: Dict mapping UserTier to (capacity, refill_rate)
            redis_url: Optional Redis connection URL
        """
        self.tier_limits = tier_limits
        # Create limiter for each tier
        self.limiters: Dict[UserTier, HybridRateLimiter] = {}
        for tier, (capacity, refill_rate) in tier_limits.items():
            self.limiters[tier] = HybridRateLimiter(
                capacity=capacity,
                refill_rate=refill_rate,
                redis_url=redis_url,
                key_prefix=f"ratelimit:{tier.value}"
            )
        logger.info(f"Tiered rate limiter initialised with {len(tier_limits)} tiers")

    def consume(
        self,
        identifier: str,
        tier: UserTier = UserTier.ANONYMOUS,
        tokens: int = 1
    ) -> RateLimitInfo:
        """Consume tokens for identifier at specified tier.

        Args:
            identifier: Client identifier (IP or user ID)
            tier: User tier for rate limit lookup
            tokens: Number of tokens to consume

        Returns:
            RateLimitInfo with consumption result
        """
        if tier not in self.limiters:
            logger.warning(f"Unknown tier {tier}, using ANONYMOUS")
            tier = UserTier.ANONYMOUS
        return self.limiters[tier].consume(identifier, tokens)

    def reset(self, identifier: str, tier: UserTier = UserTier.ANONYMOUS) -> None:
        """Reset rate limit for identifier at tier.

        Args:
            identifier: Client identifier
            tier: User tier
        """
        if tier in self.limiters:
            self.limiters[tier].reset(identifier)
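

# A plausible tier configuration (the numbers are illustrative assumptions, not
# the application's actual limits): anonymous users get a small burst with a
# slow refill, authenticated users more, premium users the most.
#
#     tier_limits = {
#         UserTier.ANONYMOUS: (10, 10 / 60),        # 10-request burst, ~10/minute refill
#         UserTier.AUTHENTICATED: (30, 30 / 60),
#         UserTier.PREMIUM: (120, 120 / 60),
#     }
#     tiered = TieredRateLimiter(tier_limits, redis_url=None)
#     info = tiered.consume("user-42", tier=UserTier.AUTHENTICATED)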


class GradioRateLimitMiddleware:
    """Gradio-specific rate limiting middleware.

    Extracts client information from gr.Request and applies rate limiting.
    Designed to wrap Gradio handler functions and enforce rate limits.

    Attributes:
        limiter: TieredRateLimiter instance
        get_user_tier: Optional function to determine user tier from request
    """

    def __init__(
        self,
        limiter: TieredRateLimiter,
        get_user_tier=None
    ):
        """Initialise Gradio rate limit middleware.

        Args:
            limiter: TieredRateLimiter instance
            get_user_tier: Optional callable to extract user tier from request
                Signature: (gr.Request) -> Tuple[str, UserTier]
                Returns (identifier, tier)
        """
        self.limiter = limiter
        self.get_user_tier = get_user_tier or self._default_get_user_tier

    @staticmethod
    def _default_get_user_tier(request: Optional["gr.Request"]) -> Tuple[str, UserTier]:
        """Default user tier extraction from Gradio request.

        Extracts client IP address and returns anonymous tier.

        Args:
            request: Gradio request object

        Returns:
            Tuple of (identifier, tier)
        """
        if not request:
            # No request context, use default identifier
            return "default", UserTier.ANONYMOUS

        # Extract client IP from request
        # Gradio stores client info in request.client
        client_ip = "unknown"
        try:
            if hasattr(request, "client") and request.client:
                if hasattr(request.client, "host"):
                    client_ip = request.client.host
                elif isinstance(request.client, str):
                    client_ip = request.client
            # Also check headers for forwarded IPs (behind proxy)
            if hasattr(request, "headers"):
                forwarded = request.headers.get("X-Forwarded-For")
                if forwarded:
                    # Take first IP if multiple
                    client_ip = forwarded.split(",")[0].strip()
        except Exception as e:
            logger.warning(f"Error extracting client IP: {e}")

        # Hash IP for privacy (optional, can use raw IP)
        identifier = hashlib.sha256(client_ip.encode()).hexdigest()[:16]
        return identifier, UserTier.ANONYMOUS

    def check_rate_limit(
        self,
        request: Optional["gr.Request"] = None,
        tokens: int = 1,
        session_state: Optional[dict] = None
    ) -> RateLimitInfo:
        """Check rate limit for request.

        Args:
            request: Gradio request object
            tokens: Number of tokens to consume
            session_state: Optional session state dict for authenticated users

        Returns:
            RateLimitInfo with result
        """
        # If custom get_user_tier accepts session_state, pass it
        try:
            identifier, tier = self.get_user_tier(request, session_state)
        except TypeError:
            # Fallback for get_user_tier functions that don't accept session_state
            identifier, tier = self.get_user_tier(request)
        return self.limiter.consume(identifier, tier, tokens)

    def enforce(
        self,
        request: Optional["gr.Request"] = None,
        tokens: int = 1,
        error_message: Optional[str] = None,
        session_state: Optional[dict] = None
    ) -> None:
        """Enforce rate limit, raise gr.Error if exceeded.

        Args:
            request: Gradio request object
            tokens: Number of tokens to consume
            error_message: Custom error message (optional)
            session_state: Optional session state dict for authenticated users

        Raises:
            gr.Error: If rate limit exceeded
        """
        import gradio as gr

        info = self.check_rate_limit(request, tokens, session_state)
        if not info.allowed:
            retry_seconds = int(info.retry_after) if info.retry_after else 60
            if error_message:
                message = error_message
            else:
                message = (
                    f"Rate limit exceeded. Please try again in {retry_seconds} seconds. "
                    f"This helps us maintain service quality for all users."
                )
            logger.warning(
                f"Rate limit exceeded for request (retry after {retry_seconds}s)"
            )
            raise gr.Error(message)
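

# A minimal integration sketch for a Gradio handler (illustrative only; the
# handler name, Blocks layout, and tier limits below are assumptions, not part
# of this module). Gradio injects the request into parameters annotated with
# gr.Request, which the middleware then uses to identify the client.
#
#     import gradio as gr
#
#     middleware = GradioRateLimitMiddleware(
#         TieredRateLimiter({UserTier.ANONYMOUS: (10, 10 / 60)})
#     )
#
#     def generate(prompt, request: gr.Request = None):
#         middleware.enforce(request)  # raises gr.Error when the limit is hit
#         return f"echo: {prompt}"
#
#     with gr.Blocks() as demo:
#         box = gr.Textbox()
#         out = gr.Textbox()
#         box.submit(generate, inputs=box, outputs=out)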