From cc3872cf74358fdf1d8dfd76f09a0f188dfc86cd Mon Sep 17 00:00:00 2001 From: Brian Wiborg Date: Thu, 2 Oct 2025 21:49:48 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20middleware=20?= =?UTF-8?q?in=20apps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ohmyapi/builtin/auth/__init__.py | 2 +- src/ohmyapi/core/runtime.py | 37 +++++++++++++++++++++++++--- src/ohmyapi/middleware/cors.py | 26 +++++++++++++++++++ 3 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 src/ohmyapi/middleware/cors.py diff --git a/src/ohmyapi/builtin/auth/__init__.py b/src/ohmyapi/builtin/auth/__init__.py index 8c2daf9..3867c30 100644 --- a/src/ohmyapi/builtin/auth/__init__.py +++ b/src/ohmyapi/builtin/auth/__init__.py @@ -1 +1 @@ -from . import models, permissions, routes +from . import middlewares, models, permissions, routes diff --git a/src/ohmyapi/core/runtime.py b/src/ohmyapi/core/runtime.py index 423f3c5..010c67d 100644 --- a/src/ohmyapi/core/runtime.py +++ b/src/ohmyapi/core/runtime.py @@ -7,7 +7,7 @@ import sys from http import HTTPStatus from pathlib import Path from types import ModuleType -from typing import Any, Dict, Generator, List, Optional, Type +from typing import Any, Dict, Generator, List, Optional, Tuple, Type import click from aerich import Command as AerichCommand @@ -88,10 +88,13 @@ class Project: def configure_app(self, app: FastAPI) -> FastAPI: """ - Attach project routes and event handlers to given FastAPI instance. + Attach project middlewares and routes and event handlers to given + FastAPI instance. """ - # Attach project routes. + # Attach project middlewares and routes. for app_name, app_def in self._apps.items(): + for middleware, kwargs in app_def.middlewares: + app.add_middleware(middleware, **kwargs) for router in app_def.routers: app.include_router(router) @@ -229,11 +232,15 @@ class App: # Reference to this app's routes modules. self._routers: Dict[str, ModuleType] = {} + # Reference to this apps middlewares. + self._middlewares: List[Tuple[Any, Dict[str, Any]]] = [] + # Import the app, so its __init__.py runs. mod: ModuleType = importlib.import_module(name) self.__load_models(f"{self.name}.models") self.__load_routes(f"{self.name}.routes") + self.__load_middlewares(f"{self.name}.middlewares") def __repr__(self): return json.dumps(self.dict(), indent=2) @@ -319,6 +326,18 @@ class App: # Walk the walk. walk(mod_name) + def __load_middlewares(self, mod_name): + try: + mod = importlib.import_module(mod_name) + except ModuleNotFoundError: + print(f"no middlewares detected: {mod_name}") + return + + getter = getattr(mod, "get", None) + if getter is not None: + for middleware in getter(): + self._middlewares.append(middleware) + def __serialize_route(self, route): """ Convert APIRoute to JSON-serializable dict. @@ -332,6 +351,12 @@ class App: def __serialize_router(self): return [self.__serialize_route(route) for route in self.routes] + def __serialize_middleware(self): + out = [] + for m in self.middlewares: + out.append((m[0].__name__, m[1])) + return out + @property def models(self) -> List[ModuleType]: """ @@ -363,6 +388,11 @@ class App: out.extend(r.routes) return out + @property + def middlewares(self): + """Returns the list of this app's middlewares.""" + return self._middlewares + def dict(self) -> Dict[str, Any]: """ Convenience method for serializing the runtime data. @@ -371,5 +401,6 @@ class App: "models": [ f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"] ], + "middlewares": self.__serialize_middleware(), "routes": self.__serialize_router(), } diff --git a/src/ohmyapi/middleware/cors.py b/src/ohmyapi/middleware/cors.py new file mode 100644 index 0000000..49852cd --- /dev/null +++ b/src/ohmyapi/middleware/cors.py @@ -0,0 +1,26 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from typing import Any, Dict, List, Tuple + +import settings + +DEFAULT_ORIGINS = ["http://localhost", "http://localhost:8000"] +DEFAULT_CREDENTIALS = False +DEFAULT_METHODS = ["*"] +DEFAULT_HEADERS = ["*"] + +CORS_CONFIG: Dict[str, Any] = getattr(settings, "MIDDLEWARE_CORS", {}) + +if not isinstance(CORS_CONFIG, dict): + raise ValueError("MIDDLEWARE_CORS must be of type dict") + +middleware = [ + (CORSMiddleware, { + "allow_origins": CORS_CONFIG.get("ALLOW_ORIGINS", DEFAULT_ORIGINS), + "allow_credentials": CORS_CONFIG.get("ALLOW_CREDENTIALS", DEFAULT_CREDENTIALS), + "allow_methods": CORS_CONFIG.get("ALLOW_METHODS", DEFAULT_METHODS), + "allow_headers": CORS_CONFIG.get("ALLOW_HEADERS", DEFAULT_HEADERS), + }), +] +