Add support for middleware in apps

This commit is contained in:
Brian Wiborg 2025-10-02 21:49:48 +02:00
parent 5c632cbe8f
commit cc3872cf74
No known key found for this signature in database
3 changed files with 61 additions and 4 deletions

View file

@ -1 +1 @@
from . import models, permissions, routes from . import middlewares, models, permissions, routes

View file

@ -7,7 +7,7 @@ import sys
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Any, Dict, Generator, List, Optional, Type from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import click import click
from aerich import Command as AerichCommand from aerich import Command as AerichCommand
@ -88,10 +88,13 @@ class Project:
def configure_app(self, app: FastAPI) -> FastAPI: def configure_app(self, app: FastAPI) -> FastAPI:
""" """
Attach project routes and event handlers to given FastAPI instance. Attach project middlewares and routes and event handlers to given
FastAPI instance.
""" """
# Attach project routes. # Attach project middlewares and routes.
for app_name, app_def in self._apps.items(): for app_name, app_def in self._apps.items():
for middleware, kwargs in app_def.middlewares:
app.add_middleware(middleware, **kwargs)
for router in app_def.routers: for router in app_def.routers:
app.include_router(router) app.include_router(router)
@ -229,11 +232,15 @@ class App:
# Reference to this app's routes modules. # Reference to this app's routes modules.
self._routers: Dict[str, ModuleType] = {} self._routers: Dict[str, ModuleType] = {}
# Reference to this apps middlewares.
self._middlewares: List[Tuple[Any, Dict[str, Any]]] = []
# Import the app, so its __init__.py runs. # Import the app, so its __init__.py runs.
mod: ModuleType = importlib.import_module(name) mod: ModuleType = importlib.import_module(name)
self.__load_models(f"{self.name}.models") self.__load_models(f"{self.name}.models")
self.__load_routes(f"{self.name}.routes") self.__load_routes(f"{self.name}.routes")
self.__load_middlewares(f"{self.name}.middlewares")
def __repr__(self): def __repr__(self):
return json.dumps(self.dict(), indent=2) return json.dumps(self.dict(), indent=2)
@ -319,6 +326,18 @@ class App:
# Walk the walk. # Walk the walk.
walk(mod_name) walk(mod_name)
def __load_middlewares(self, mod_name):
try:
mod = importlib.import_module(mod_name)
except ModuleNotFoundError:
print(f"no middlewares detected: {mod_name}")
return
getter = getattr(mod, "get", None)
if getter is not None:
for middleware in getter():
self._middlewares.append(middleware)
def __serialize_route(self, route): def __serialize_route(self, route):
""" """
Convert APIRoute to JSON-serializable dict. Convert APIRoute to JSON-serializable dict.
@ -332,6 +351,12 @@ class App:
def __serialize_router(self): def __serialize_router(self):
return [self.__serialize_route(route) for route in self.routes] return [self.__serialize_route(route) for route in self.routes]
def __serialize_middleware(self):
out = []
for m in self.middlewares:
out.append((m[0].__name__, m[1]))
return out
@property @property
def models(self) -> List[ModuleType]: def models(self) -> List[ModuleType]:
""" """
@ -363,6 +388,11 @@ class App:
out.extend(r.routes) out.extend(r.routes)
return out return out
@property
def middlewares(self):
"""Returns the list of this app's middlewares."""
return self._middlewares
def dict(self) -> Dict[str, Any]: def dict(self) -> Dict[str, Any]:
""" """
Convenience method for serializing the runtime data. Convenience method for serializing the runtime data.
@ -371,5 +401,6 @@ class App:
"models": [ "models": [
f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"] f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"]
], ],
"middlewares": self.__serialize_middleware(),
"routes": self.__serialize_router(), "routes": self.__serialize_router(),
} }

View file

@ -0,0 +1,26 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from typing import Any, Dict, List, Tuple
import settings
DEFAULT_ORIGINS = ["http://localhost", "http://localhost:8000"]
DEFAULT_CREDENTIALS = False
DEFAULT_METHODS = ["*"]
DEFAULT_HEADERS = ["*"]
CORS_CONFIG: Dict[str, Any] = getattr(settings, "MIDDLEWARE_CORS", {})
if not isinstance(CORS_CONFIG, dict):
raise ValueError("MIDDLEWARE_CORS must be of type dict")
middleware = [
(CORSMiddleware, {
"allow_origins": CORS_CONFIG.get("ALLOW_ORIGINS", DEFAULT_ORIGINS),
"allow_credentials": CORS_CONFIG.get("ALLOW_CREDENTIALS", DEFAULT_CREDENTIALS),
"allow_methods": CORS_CONFIG.get("ALLOW_METHODS", DEFAULT_METHODS),
"allow_headers": CORS_CONFIG.get("ALLOW_HEADERS", DEFAULT_HEADERS),
}),
]