# Source: BrianIsaac, commit 84b71cf
# "fix: implement fixed window rate limiting and resolve analysis history persistence"
"""Rate limiting implementation for Gradio application.
Implements token bucket algorithm with multiple strategies:
- ThreadSafeTokenBucket: In-memory, thread-safe rate limiting
- AsyncTokenBucket: Async-compatible token bucket
- HybridRateLimiter: Redis primary with memory fallback
- TieredRateLimiter: Multi-tier rate limiting (anonymous, authenticated, premium)
- GradioRateLimitMiddleware: Gradio-specific middleware for extracting client info
"""
import asyncio
import hashlib
import inspect
import logging
import math
import time
from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import Optional, Dict, Any, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import gradio as gr
logger = logging.getLogger(__name__)
class UserTier(str, Enum):
"""User tier enumeration for tiered rate limiting."""
ANONYMOUS = "anonymous"
AUTHENTICATED = "authenticated"
PREMIUM = "premium"
@dataclass
class RateLimitInfo:
"""Rate limit information for user feedback.
Attributes:
allowed: Whether the request is allowed
remaining: Remaining requests in current window
reset_time: Unix timestamp when the limit resets
retry_after: Seconds until next request is allowed
"""
allowed: bool
remaining: int
reset_time: float
retry_after: Optional[float] = None
class ThreadSafeTokenBucket:
"""Thread-safe token bucket rate limiter using in-memory storage.
Implements the token bucket algorithm with configurable capacity and refill rate.
Thread-safe for concurrent access.
Attributes:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens added per second
tokens: Current number of available tokens
last_refill: Timestamp of last refill
"""
def __init__(self, capacity: int, refill_rate: float):
"""Initialise token bucket.
Args:
capacity: Maximum bucket capacity (max requests)
refill_rate: Tokens added per second (requests per second)
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = float(capacity)
self.last_refill = time.time()
self._lock = Lock()
def _refill(self) -> None:
"""Refill tokens based on elapsed time since last refill."""
now = time.time()
elapsed = now - self.last_refill
# Add tokens based on elapsed time and refill rate
tokens_to_add = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
self.last_refill = now
def consume(self, tokens: int = 1) -> RateLimitInfo:
"""Attempt to consume tokens from the bucket.
Args:
tokens: Number of tokens to consume
Returns:
RateLimitInfo with consumption result and metadata
"""
with self._lock:
self._refill()
if self.tokens >= tokens:
self.tokens -= tokens
remaining = int(self.tokens)
# Calculate reset time (when bucket will be full again)
tokens_needed = self.capacity - self.tokens
reset_time = time.time() + (tokens_needed / self.refill_rate)
return RateLimitInfo(
allowed=True,
remaining=remaining,
reset_time=reset_time,
retry_after=None
)
else:
# Calculate when enough tokens will be available
tokens_needed = tokens - self.tokens
retry_after = tokens_needed / self.refill_rate
reset_time = time.time() + retry_after
return RateLimitInfo(
allowed=False,
remaining=0,
reset_time=reset_time,
retry_after=retry_after
)
def reset(self) -> None:
"""Reset the bucket to full capacity."""
with self._lock:
self.tokens = float(self.capacity)
self.last_refill = time.time()
class AsyncTokenBucket:
"""Async-compatible token bucket rate limiter.
Uses asyncio locks for async/await compatibility.
Suitable for async Gradio handlers and coroutines.
Attributes:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens added per second
tokens: Current number of available tokens
last_refill: Timestamp of last refill
"""
def __init__(self, capacity: int, refill_rate: float):
"""Initialise async token bucket.
Args:
capacity: Maximum bucket capacity (max requests)
refill_rate: Tokens added per second (requests per second)
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = float(capacity)
self.last_refill = time.time()
self._lock = asyncio.Lock()
def _refill(self) -> None:
"""Refill tokens based on elapsed time since last refill."""
now = time.time()
elapsed = now - self.last_refill
tokens_to_add = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
self.last_refill = now
async def consume(self, tokens: int = 1) -> RateLimitInfo:
"""Attempt to consume tokens from the bucket (async).
Args:
tokens: Number of tokens to consume
Returns:
RateLimitInfo with consumption result and metadata
"""
async with self._lock:
self._refill()
if self.tokens >= tokens:
self.tokens -= tokens
remaining = int(self.tokens)
tokens_needed = self.capacity - self.tokens
reset_time = time.time() + (tokens_needed / self.refill_rate)
return RateLimitInfo(
allowed=True,
remaining=remaining,
reset_time=reset_time,
retry_after=None
)
else:
tokens_needed = tokens - self.tokens
retry_after = tokens_needed / self.refill_rate
reset_time = time.time() + retry_after
return RateLimitInfo(
allowed=False,
remaining=0,
reset_time=reset_time,
retry_after=retry_after
)
async def reset(self) -> None:
"""Reset the bucket to full capacity (async)."""
async with self._lock:
self.tokens = float(self.capacity)
self.last_refill = time.time()
class HybridRateLimiter:
"""Hybrid rate limiter with Redis primary and in-memory fallback.
Uses Redis for distributed rate limiting when available, gracefully
degrades to in-memory token buckets when Redis is unavailable.
Implements atomic Redis operations using Lua scripts for consistency.
Attributes:
capacity: Maximum bucket capacity
refill_rate: Tokens per second
redis_client: Optional Redis client
memory_buckets: In-memory fallback buckets
use_redis: Whether Redis is available
"""
# Lua script for atomic token bucket operations in Redis
REDIS_CONSUME_SCRIPT = """
local key = KEYS[1]
local capacity = tonumber(ARGV[1])
local refill_rate = tonumber(ARGV[2])
local tokens_requested = tonumber(ARGV[3])
local now = tonumber(ARGV[4])
-- Get current state or initialise
local data = redis.call('HMGET', key, 'tokens', 'last_refill')
local tokens = tonumber(data[1]) or capacity
local last_refill = tonumber(data[2]) or now
-- Refill tokens based on elapsed time
local elapsed = now - last_refill
local tokens_to_add = elapsed * refill_rate
tokens = math.min(capacity, tokens + tokens_to_add)
-- Try to consume
if tokens >= tokens_requested then
tokens = tokens - tokens_requested
redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
redis.call('EXPIRE', key, math.ceil(capacity / refill_rate) * 2)
local tokens_needed = capacity - tokens
local reset_time = now + (tokens_needed / refill_rate)
return {1, math.floor(tokens), reset_time, -1}
else
-- Not enough tokens
local tokens_needed = tokens_requested - tokens
local retry_after = tokens_needed / refill_rate
local reset_time = now + retry_after
redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
redis.call('EXPIRE', key, math.ceil(capacity / refill_rate) * 2)
return {0, 0, reset_time, retry_after}
end
"""
def __init__(
self,
capacity: int,
refill_rate: float,
redis_url: Optional[str] = None,
key_prefix: str = "ratelimit"
):
"""Initialise hybrid rate limiter.
Args:
capacity: Maximum bucket capacity
refill_rate: Tokens per second
redis_url: Optional Redis connection URL
key_prefix: Prefix for Redis keys
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.key_prefix = key_prefix
# Memory fallback
self.memory_buckets: Dict[str, ThreadSafeTokenBucket] = {}
self._memory_lock = Lock()
# Try to initialise Redis
self.redis_client = None
self.redis_script_sha = None
self.use_redis = False
if redis_url:
try:
import redis
self.redis_client = redis.from_url(
redis_url,
decode_responses=True,
socket_connect_timeout=2,
socket_timeout=2
)
# Test connection
self.redis_client.ping()
# Load Lua script
self.redis_script_sha = self.redis_client.script_load(
self.REDIS_CONSUME_SCRIPT
)
self.use_redis = True
logger.info("Redis rate limiter initialised successfully")
except Exception as e:
logger.warning(
f"Redis initialisation failed, using in-memory fallback: {e}"
)
self.redis_client = None
self.use_redis = False
def _get_memory_bucket(self, identifier: str) -> ThreadSafeTokenBucket:
"""Get or create in-memory bucket for identifier.
Args:
identifier: Client identifier (IP, user ID, etc.)
Returns:
ThreadSafeTokenBucket instance
"""
with self._memory_lock:
if identifier not in self.memory_buckets:
self.memory_buckets[identifier] = ThreadSafeTokenBucket(
self.capacity,
self.refill_rate
)
return self.memory_buckets[identifier]
def _consume_redis(self, identifier: str, tokens: int = 1) -> RateLimitInfo:
"""Consume tokens using Redis.
Args:
identifier: Client identifier
tokens: Number of tokens to consume
Returns:
RateLimitInfo with result
"""
try:
key = f"{self.key_prefix}:{identifier}"
now = time.time()
# Execute Lua script
result = self.redis_client.evalsha(
self.redis_script_sha,
1,
key,
self.capacity,
self.refill_rate,
tokens,
now
)
allowed = bool(result[0])
remaining = int(result[1])
reset_time = float(result[2])
retry_after = float(result[3]) if result[3] != -1 else None
return RateLimitInfo(
allowed=allowed,
remaining=remaining,
reset_time=reset_time,
retry_after=retry_after
)
except Exception as e:
logger.error(f"Redis consume error, falling back to memory: {e}")
# Fall back to memory on error
self.use_redis = False
return self._get_memory_bucket(identifier).consume(tokens)
def consume(self, identifier: str, tokens: int = 1) -> RateLimitInfo:
"""Consume tokens for the given identifier.
Args:
identifier: Client identifier (IP, user ID, etc.)
tokens: Number of tokens to consume
Returns:
RateLimitInfo with consumption result
"""
if self.use_redis and self.redis_client:
return self._consume_redis(identifier, tokens)
else:
return self._get_memory_bucket(identifier).consume(tokens)
def reset(self, identifier: str) -> None:
"""Reset rate limit for identifier.
Args:
identifier: Client identifier to reset
"""
if self.use_redis and self.redis_client:
try:
key = f"{self.key_prefix}:{identifier}"
self.redis_client.delete(key)
except Exception as e:
logger.error(f"Redis reset error: {e}")
# Also clear from memory
with self._memory_lock:
if identifier in self.memory_buckets:
del self.memory_buckets[identifier]
class TieredRateLimiter:
"""Multi-tier rate limiter with different limits per user tier.
Supports different rate limits for:
- Anonymous users (IP-based)
- Authenticated users (user ID-based)
- Premium users (higher limits)
Attributes:
tier_limits: Dictionary mapping tiers to (capacity, refill_rate)
limiters: Dictionary of HybridRateLimiters per tier
"""
def __init__(
self,
tier_limits: Dict[UserTier, Tuple[int, float]],
redis_url: Optional[str] = None
):
"""Initialise tiered rate limiter.
Args:
tier_limits: Dict mapping UserTier to (capacity, refill_rate)
redis_url: Optional Redis connection URL
"""
self.tier_limits = tier_limits
# Create limiter for each tier
self.limiters: Dict[UserTier, HybridRateLimiter] = {}
for tier, (capacity, refill_rate) in tier_limits.items():
self.limiters[tier] = HybridRateLimiter(
capacity=capacity,
refill_rate=refill_rate,
redis_url=redis_url,
key_prefix=f"ratelimit:{tier.value}"
)
logger.info(f"Tiered rate limiter initialised with {len(tier_limits)} tiers")
def consume(
self,
identifier: str,
tier: UserTier = UserTier.ANONYMOUS,
tokens: int = 1
) -> RateLimitInfo:
"""Consume tokens for identifier at specified tier.
Args:
identifier: Client identifier (IP or user ID)
tier: User tier for rate limit lookup
tokens: Number of tokens to consume
Returns:
RateLimitInfo with consumption result
"""
if tier not in self.limiters:
logger.warning(f"Unknown tier {tier}, using ANONYMOUS")
tier = UserTier.ANONYMOUS
return self.limiters[tier].consume(identifier, tokens)
def reset(self, identifier: str, tier: UserTier = UserTier.ANONYMOUS) -> None:
"""Reset rate limit for identifier at tier.
Args:
identifier: Client identifier
tier: User tier
"""
if tier in self.limiters:
self.limiters[tier].reset(identifier)
class GradioRateLimitMiddleware:
"""Gradio-specific rate limiting middleware.
Extracts client information from gr.Request and applies rate limiting.
Designed to wrap Gradio handler functions and enforce rate limits.
Attributes:
limiter: TieredRateLimiter instance
get_user_tier: Optional function to determine user tier from request
"""
def __init__(
self,
limiter: TieredRateLimiter,
get_user_tier=None
):
"""Initialise Gradio rate limit middleware.
Args:
limiter: TieredRateLimiter instance
get_user_tier: Optional callable to extract user tier from request
Signature: (gr.Request) -> Tuple[str, UserTier]
Returns (identifier, tier)
"""
self.limiter = limiter
self.get_user_tier = get_user_tier or self._default_get_user_tier
@staticmethod
def _default_get_user_tier(request: Optional["gr.Request"]) -> Tuple[str, UserTier]:
"""Default user tier extraction from Gradio request.
Extracts client IP address and returns anonymous tier.
Args:
request: Gradio request object
Returns:
Tuple of (identifier, tier)
"""
if not request:
# No request context, use default identifier
return "default", UserTier.ANONYMOUS
# Extract client IP from request
# Gradio stores client info in request.client
client_ip = "unknown"
try:
if hasattr(request, "client") and request.client:
if hasattr(request.client, "host"):
client_ip = request.client.host
elif isinstance(request.client, str):
client_ip = request.client
# Also check headers for forwarded IPs (behind proxy)
if hasattr(request, "headers"):
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# Take first IP if multiple
client_ip = forwarded.split(",")[0].strip()
except Exception as e:
logger.warning(f"Error extracting client IP: {e}")
# Hash IP for privacy (optional, can use raw IP)
identifier = hashlib.sha256(client_ip.encode()).hexdigest()[:16]
return identifier, UserTier.ANONYMOUS
def check_rate_limit(
self,
request: Optional["gr.Request"] = None,
tokens: int = 1,
session_state: Optional[dict] = None
) -> RateLimitInfo:
"""Check rate limit for request.
Args:
request: Gradio request object
tokens: Number of tokens to consume
session_state: Optional session state dict for authenticated users
Returns:
RateLimitInfo with result
"""
# If custom get_user_tier accepts session_state, pass it
try:
identifier, tier = self.get_user_tier(request, session_state)
except TypeError:
# Fallback for get_user_tier functions that don't accept session_state
identifier, tier = self.get_user_tier(request)
return self.limiter.consume(identifier, tier, tokens)
def enforce(
self,
request: Optional["gr.Request"] = None,
tokens: int = 1,
error_message: Optional[str] = None,
session_state: Optional[dict] = None
) -> None:
"""Enforce rate limit, raise gr.Error if exceeded.
Args:
request: Gradio request object
tokens: Number of tokens to consume
error_message: Custom error message (optional)
session_state: Optional session state dict for authenticated users
Raises:
gr.Error: If rate limit exceeded
"""
import gradio as gr
info = self.check_rate_limit(request, tokens, session_state)
if not info.allowed:
retry_seconds = int(info.retry_after) if info.retry_after else 60
if error_message:
message = error_message
else:
message = (
f"Rate limit exceeded. Please try again in {retry_seconds} seconds. "
f"This helps us maintain service quality for all users."
)
logger.warning(
f"Rate limit exceeded for request (retry after {retry_seconds}s)"
)
raise gr.Error(message)