FastAPI — rate limit middleware

Snir Orlanczyk
3 min read · Oct 4, 2023


There are many cases where we want a rate-limiting mechanism in a web app, e.g. to protect a service that is expensive, external, or both from being overused.

So why middleware?

Well, it is simply a convenient place for something like a rate limiter, since it captures the request and can block it before any CPU-heavy work is done.

A middleware simply lets us intercept requests inside the app, before they reach the route handlers.

Getting started

To start, I am using the limits package to maintain the rate-limiting window. It is a simple package that supports multiple storage backends, implements multiple rate-limiting algorithms, and offers both sync and async implementations.
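
For intuition about what a moving-window algorithm does, here is a small in-memory sketch. It is not how limits implements it; the `MovingWindow` class and its injectable `clock` parameter are assumptions made purely for illustration.

```python
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict


class MovingWindow:
    """Illustrative in-memory moving-window limiter (hypothetical, not part of `limits`)."""

    def __init__(
        self,
        limit: int,
        window_seconds: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.limit = limit
        self.window_seconds = window_seconds
        self.clock = clock
        # One deque of hit timestamps per key, oldest first.
        self._hits: Dict[str, Deque[float]] = defaultdict(deque)

    def hit(self, key: str) -> bool:
        """Record a request for `key`; return True if it fits in the window."""
        now = self.clock()
        hits = self._hits[key]
        # Drop timestamps that have slid out of the window.
        while hits and now - hits[0] >= self.window_seconds:
            hits.popleft()
        if len(hits) >= self.limit:
            return False
        hits.append(now)
        return True
```

Because the window slides with every request, a client regains capacity gradually as its old hits expire, rather than all at once at a fixed boundary.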

# limits_wrapper.py
"""
This component is used as a wrapper for `limits` so we won't use its API directly in the throttler class.
"""

from limits import RateLimitItem, RateLimitItemPerMinute, storage, strategies

REDIS_URL: str = "redis://localhost:6379/0"
# Named `redis_storage` so it does not shadow the imported `storage` module.
redis_storage = storage.RedisStorage(REDIS_URL)
throttler = strategies.MovingWindowRateLimiter(redis_storage)


def hit(key: str, rate_per_minute: int, cost: int = 1) -> bool:
    """
    Hits the throttler and returns `True` if the request can pass and `False` if it needs to be blocked.

    :param key: the key that identifies the client that needs to be throttled
    :param rate_per_minute: the number of requests per minute to allow
    :param cost: the cost of the request in the time window
    :return: `True` if the request can pass and `False` if it needs to be blocked
    """
    item = rate_limit_item_for(rate_per_minute=rate_per_minute)
    return throttler.hit(item, key, cost=cost)


def rate_limit_item_for(rate_per_minute: int) -> RateLimitItem:
    """
    Builds the rate limit item for the given per-minute rate.

    :param rate_per_minute: the number of requests per minute to allow
    :return: a `RateLimitItem` object initialized with that rate limit
    """
    return RateLimitItemPerMinute(rate_per_minute)

This module wraps the limits library and exposes its functionality through our own functions, which keeps the touch points with the library in one place and makes it easier to replace later.
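
One practical payoff of this: code that depends only on the `hit(key, rate_per_minute, cost)` signature can be exercised with a stand-in that needs no Redis. The sketch below is an assumption about how that might look. `Hit`, `FixedBudget` and `allow_request` are hypothetical names, and the stand-in only counts hits without any time window.

```python
from collections import Counter
from typing import Protocol


class Hit(Protocol):
    """The only surface the rest of the app needs from the wrapper."""

    def __call__(self, key: str, rate_per_minute: int, cost: int = 1) -> bool: ...


class FixedBudget:
    """Test double: allows up to `rate_per_minute` total cost per key and never resets."""

    def __init__(self) -> None:
        self._used: Counter = Counter()

    def __call__(self, key: str, rate_per_minute: int, cost: int = 1) -> bool:
        if self._used[key] + cost > rate_per_minute:
            return False
        self._used[key] += cost
        return True


def allow_request(hit: Hit, client_id: str) -> bool:
    # Callers depend on the signature only, not on `limits` or Redis.
    return hit(key=client_id, rate_per_minute=2)
```

Swapping the real wrapper for `FixedBudget()` in a unit test then requires no change to the calling code.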

Rate limiting logic

Now we need to write the class that will manage the rate limiting using our wrapper:

from starlette.requests import Request
from starlette.exceptions import HTTPException
from starlette.status import HTTP_429_TOO_MANY_REQUESTS

from .limits_wrapper import hit


class RateLimitMiddleware:
    async def __call__(self, request: Request):
        key = "some-key"
        rate = 100

        if not hit(key=key, rate_per_minute=rate):
            raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail="request limit reached")

Adding flexibility

But this logic is hard-coded and does not add much. We can easily improve it by accepting functions that customize how the class behaves:

from typing import Any, Awaitable, Callable


class RateLimitMiddleware:
    def __init__(
        self,
        identifier: Callable[[Request], Awaitable[str]],
        callback: Callable[[Request], Awaitable[Any]],
        rate_provider: Callable[[Request], Awaitable[int]],
    ):
        self.identifier = identifier
        self.callback = callback
        self.rate_provider = rate_provider

    async def __call__(self, request: Request):
        # Resolve the key and the rate for this specific request.
        key = await self.identifier(request)
        rate = await self.rate_provider(request)

        # Over the limit: let the callback decide how to respond.
        if not hit(key=key, rate_per_minute=rate):
            return await self.callback(request)

By introducing these function parameters, we can supply custom logic that produces whatever we need.
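
For example, an identifier could key the limit on an API key instead of a network address. This is only a sketch under assumptions: the `x-api-key` header name is hypothetical, and a `SimpleNamespace` stands in for the Starlette `Request` so the snippet runs without a server (only `headers` and `client` are read).

```python
import asyncio
from types import SimpleNamespace
from typing import Any


async def api_key_identifier(request: Any) -> str:
    """Identify the caller by API key, falling back to the client host."""
    api_key = request.headers.get("x-api-key")
    if api_key:
        return f"key:{api_key}"
    # `request.client` may be None when the server supplies no client address.
    host = request.client.host if request.client else "unknown"
    return f"ip:{host}"


# Stand-in objects exposing the two attributes the function reads.
with_key = SimpleNamespace(
    headers={"x-api-key": "abc123"}, client=SimpleNamespace(host="10.0.0.5")
)
without_key = SimpleNamespace(headers={}, client=SimpleNamespace(host="10.0.0.5"))

print(asyncio.run(api_key_identifier(with_key)))     # key:abc123
print(asyncio.run(api_key_identifier(without_key)))  # ip:10.0.0.5
```

Prefixing the key with its kind keeps API-key counters and IP counters from ever colliding in the storage.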

Adding sensible defaults

We can even include default implementations where that makes sense for our use case. For example, we may want to rate-limit per IP by default, and to return an HTTP 429 by default when the rate limit has been reached.

async def _default_identifier(request: Request) -> str:
    # Use the client's IP address as the rate-limit key.
    return request.client.host


async def _default_callback(request: Request):
    raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail="request limit reached")


class RateLimitMiddleware:
    def __init__(
        self,
        # `rate_provider` has no default, so it must come before the parameters that do.
        rate_provider: Callable[[Request], Awaitable[int]],
        identifier: Callable[[Request], Awaitable[str]] = _default_identifier,
        callback: Callable[[Request], Awaitable[Any]] = _default_callback,
    ):
        self.identifier = identifier
        self.callback = callback
        self.rate_provider = rate_provider

    async def __call__(self, request: Request):
        key = await self.identifier(request)
        rate = await self.rate_provider(request)

        if not hit(key=key, rate_per_minute=rate):
            return await self.callback(request)

Using the Rate limiter

If we want to set our rate limit to 1000 requests per minute, and use the default callback and identifier:

async def rate_provider(request: Request) -> int:
    return 1000


rate_limit = RateLimitMiddleware(rate_provider=rate_provider)
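
Because the rate is resolved per request, a provider can also return different limits for different routes. A hedged sketch: the `RATES` table, the path prefixes and the numbers are made up for illustration, and a `SimpleNamespace` stands in for the request (only `request.url.path` is read).

```python
import asyncio
from types import SimpleNamespace
from typing import Any

# Hypothetical per-prefix limits (requests per minute); the longest matching prefix wins.
RATES = {"/reports": 10, "/search": 100}
DEFAULT_RATE = 1000


async def path_rate_provider(request: Any) -> int:
    """Return the per-minute rate for the request's path."""
    path = request.url.path
    matches = [prefix for prefix in RATES if path.startswith(prefix)]
    if not matches:
        return DEFAULT_RATE
    return RATES[max(matches, key=len)]


request = SimpleNamespace(url=SimpleNamespace(path="/reports/2023"))
print(asyncio.run(path_rate_provider(request)))  # 10
```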

Integrating with FastAPI

To integrate it with FastAPI, I prefer to add it as a dependency, since that means it can be applied per route, per router, or to the whole app, which gives a lot of flexibility. Using it as a Starlette middleware, by contrast, forces us to apply it app-wide.

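from fastapi import Depends, FastAPI, Request
from app.middlewares.rate_limiter import RateLimitMiddleware


async def rate_provider(request: Request) -> int:
    return 1000


# The rate provider is required, so build the limiter once and reuse the instance.
rate_limit = RateLimitMiddleware(rate_provider=rate_provider)

app = FastAPI()


@app.get("/", dependencies=[Depends(rate_limit)])
async def root():
    return {"message": "OK!"}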

Adding this dependency to a route or router enforces the rate limit on it.

