FastAPI — rate limit middleware
There are many cases where we want a rate limiting mechanism in a web app, e.g. to defend against overuse of an expensive or external service.
So why middleware?
Middleware is a convenient place for a rate limiter: it sits in front of the application, intercepts every incoming request, and can block it before any CPU-heavy work is done.
Getting started
To start, I am using the limits package to maintain the rate limiting window. It is a simple package that supports multiple storage backends, implements several rate limiting algorithms, and provides both sync and async APIs.
# limits_wrapper.py
"""
Wraps the `limits` package so the rest of the code never calls its API
directly; this keeps the touch points in one place.
"""
from limits import RateLimitItem, RateLimitItemPerMinute, storage, strategies

REDIS_URL: str = "redis://localhost:6379/0"

storage = storage.RedisStorage(REDIS_URL)
throttler = strategies.MovingWindowRateLimiter(storage)


def hit(key: str, rate_per_minute: int, cost: int = 1) -> bool:
    """
    Hits the throttler and returns `True` if the request may pass and
    `False` if it must be blocked.
    :param key: the key that identifies the client being throttled
    :param rate_per_minute: the number of requests per minute to allow
    :param cost: the cost of the request within the time window
    :return: `True` if the request may pass, `False` if it must be blocked
    """
    item = rate_limit_item_for(rate_per_minute=rate_per_minute)
    return throttler.hit(item, key, cost=cost)


def rate_limit_item_for(rate_per_minute: int) -> RateLimitItem:
    """
    Builds the rate limit item for the requested rate
    :param rate_per_minute: the number of requests per minute to allow
    :return: a `RateLimitItem` configured with the given rate
    """
    return RateLimitItemPerMinute(rate_per_minute)
This module wraps the limits library and exposes its functionality through our own interface, which makes it easier to manage the touch points with the library in case we ever want to replace it.
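To make the moving-window strategy concrete, here is a minimal, single-process sketch of the idea. This is an illustration only, not what `limits` does internally; the library adds shared storage (e.g. Redis), atomic updates, and key expiry on top of this.

```python
import time
from collections import defaultdict, deque
from typing import Deque, Dict, Optional


class MovingWindow:
    """Illustrative in-memory moving window over the last 60 seconds."""

    def __init__(self, rate_per_minute: int):
        self.rate = rate_per_minute
        self.hits: Dict[str, Deque[float]] = defaultdict(deque)

    def hit(self, key: str, now: Optional[float] = None) -> bool:
        now = time.monotonic() if now is None else now
        window = self.hits[key]
        # Drop timestamps that fell out of the 60-second window.
        while window and now - window[0] >= 60:
            window.popleft()
        if len(window) >= self.rate:
            return False  # limit reached, block the request
        window.append(now)
        return True


limiter = MovingWindow(rate_per_minute=2)
print(limiter.hit("1.2.3.4", now=0.0))   # True
print(limiter.hit("1.2.3.4", now=1.0))   # True
print(limiter.hit("1.2.3.4", now=2.0))   # False: two hits in the last minute
print(limiter.hit("1.2.3.4", now=61.0))  # True: the first hit has expired
```

Unlike this sketch, the Redis-backed limiter shares its window across processes, which is why we delegate the real work to `limits`.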
Rate limiting logic
Now we need to write the class that will manage rate limiting using the wrapper:
from starlette.requests import Request
from starlette.exceptions import HTTPException
from starlette.status import HTTP_429_TOO_MANY_REQUESTS

from .limits_wrapper import hit


class RateLimitMiddleware:
    async def __call__(self, request: Request):
        key = "some-key"
        rate = 100
        if not hit(key=key, rate_per_minute=rate):
            raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail="request limit reached")
Adding flexibility
But this logic is hard-coded and does not add much. We can easily improve it by accepting functions that let callers customize the behavior of the class:
from typing import Any, Awaitable, Callable


class RateLimitMiddleware:
    def __init__(
        self,
        identifier: Callable[[Request], Awaitable[str]],
        callback: Callable[[Request], Awaitable[Any]],
        rate_provider: Callable[[Request], Awaitable[int]],
    ):
        self.identifier = identifier
        self.callback = callback
        self.rate_provider = rate_provider

    async def __call__(self, request: Request):
        key = await self.identifier(request)
        rate = await self.rate_provider(request)
        if not hit(key=key, rate_per_minute=rate):
            return await self.callback(request)
By introducing these function parameters, we can plug in custom logic for identifying the client, choosing the rate, and handling blocked requests.
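As a sketch of what such functions could look like, we might throttle by API key instead of IP and give premium clients a higher rate. The `x-api-key` and `x-plan` header names here are made-up examples, not part of the original design:

```python
async def api_key_identifier(request) -> str:
    # Throttle per API key (hypothetical header), falling back to the
    # client IP when no key is present.
    return request.headers.get("x-api-key") or request.client.host


async def plan_rate_provider(request) -> int:
    # Give clients on a hypothetical "premium" plan a higher per-minute rate.
    return 1000 if request.headers.get("x-plan") == "premium" else 100
```

Both would simply be passed to `RateLimitMiddleware` in place of other implementations.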
Adding sensible defaults
We can even include default implementations where they make sense for our use case, for example rate limiting per client IP by default, and returning HTTP 429 by default when the limit has been reached.
async def _default_identifier(request: Request) -> str:
    return request.client.host


async def _default_callback(request: Request):
    raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail="request limit reached")


class RateLimitMiddleware:
    def __init__(
        self,
        rate_provider: Callable[[Request], Awaitable[int]],
        identifier: Callable[[Request], Awaitable[str]] = _default_identifier,
        callback: Callable[[Request], Awaitable[Any]] = _default_callback,
    ):
        self.identifier = identifier
        self.callback = callback
        self.rate_provider = rate_provider

    async def __call__(self, request: Request):
        key = await self.identifier(request)
        rate = await self.rate_provider(request)
        if not hit(key=key, rate_per_minute=rate):
            return await self.callback(request)
Using the rate limiter
If we want to set our rate limit to 1000 requests per minute, and use the default callback and identifier:
async def rate_provider(request: Request) -> int:
    return 1000


rate_limit = RateLimitMiddleware(rate_provider=rate_provider)
Integrating with FastAPI
To integrate it with FastAPI, I prefer to add it as a dependency, since that means it can be applied per route, per router, or app-wide, which gives a lot of flexibility; a Starlette middleware, by contrast, can only be applied at the app level.
from fastapi import FastAPI, Depends

from app.middlewares.rate_limiter import rate_limit

app = FastAPI()


@app.get("/", dependencies=[Depends(rate_limit)])
async def root():
    return {"message": "OK!"}
Adding this dependency to a route or router enforces the rate limit on everything it covers.