From 2232726e7cf27ef482dc88528a400dc8d65917e0 Mon Sep 17 00:00:00 2001 From: Brian Wiborg Date: Wed, 1 Oct 2025 20:43:56 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20core.runtime?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - rewrite how apps are loaded into scope - rewrite how apps are collected for Tortoise and Aerich - rewrite how routes are collected for FastAPI - support packages for models and routes with arbitrary nesting - no need to expose models and routes in __init__.py - OhMyAPI will recursively iterate through all submodules --- src/ohmyapi/builtin/demo/routes.py | 8 +- src/ohmyapi/core/runtime.py | 146 ++++++++++++++++++++--------- src/ohmyapi/db/model/model.py | 8 +- src/ohmyapi/router.py | 27 ------ 4 files changed, 107 insertions(+), 82 deletions(-) diff --git a/src/ohmyapi/builtin/demo/routes.py b/src/ohmyapi/builtin/demo/routes.py index 0022039..24bdc83 100644 --- a/src/ohmyapi/builtin/demo/routes.py +++ b/src/ohmyapi/builtin/demo/routes.py @@ -11,9 +11,7 @@ from typing import List router = APIRouter(prefix="/tournament") -@router.get( - "/", tags=["tournament"], response_model=List[models.Tournament.Schema()] -) +@router.get("/", tags=["tournament"], response_model=List[models.Tournament.Schema()]) async def list(): """List all tournaments.""" return await models.Tournament.Schema().from_queryset(models.Tournament.all()) @@ -30,9 +28,7 @@ async def post(tournament: models.Tournament.Schema(readonly=True)): @router.get("/{id}", tags=["tournament"], response_model=models.Tournament.Schema()) async def get(id: str): """Get tournament by id.""" - return await models.Tournament.Schema().from_queryset( - models.Tournament.get(id=id) - ) + return await models.Tournament.Schema().from_queryset(models.Tournament.get(id=id)) @router.put( diff --git a/src/ohmyapi/core/runtime.py b/src/ohmyapi/core/runtime.py index 7b0c340..103853c 100644 --- a/src/ohmyapi/core/runtime.py +++ b/src/ohmyapi/core/runtime.py @@ -1,12 +1,12 @@ # ohmyapi/core/runtime.py -import copy import importlib import importlib.util import json import pkgutil import sys from pathlib import Path -from typing import Any, Dict, Generator, List, Optional +from types import ModuleType +from typing import Any, Dict, Generator, List, Optional, Type import click from aerich import Command as AerichCommand @@ -107,7 +107,7 @@ class Project: } for app_name, app in self._apps.items(): - modules = list(dict.fromkeys(app.model_modules)) + modules = list(app.models.keys()) if modules: config["apps"][app_name] = { "models": modules, @@ -123,8 +123,9 @@ class Project: raise RuntimeError(f"App '{app_label}' is not registered") # Get a fresh copy of the config (without aerich.models anywhere) - tortoise_cfg = copy.deepcopy(self.build_tortoise_config(db_url=db_url)) + tortoise_cfg = self.build_tortoise_config(db_url=db_url) + tortoise_cfg["apps"] = {app_label: tortoise_cfg["apps"][app_label]} # Append aerich.models to the models list of the target app only tortoise_cfg["apps"][app_label]["models"].append("aerich.models") @@ -199,33 +200,17 @@ class App: self.project = project self.name = name - # The list of module paths (e.g. "ohmyapi_auth.models") for Tortoise and Aerich - self.model_modules: List[str] = [] + # Reference to this app's models modules. + self._models: Dict[str, ModuleType] = {} - # The APIRouter - self.router: APIRouter = APIRouter() + # Reference to this app's routes modules. + self._routers: Dict[str, ModuleType] = {} # Import the app, so its __init__.py runs. - importlib.import_module(self.name) + mod: ModuleType = importlib.import_module(name) - # Load the models - try: - models_mod = importlib.import_module(f"{self.name}.models") - self.model_modules.append(f"{self.name}.models") - except ModuleNotFoundError: - pass - - # Locate the APIRouter - try: - routes_mod = importlib.import_module(f"{self.name}.routes") - for attr_name in dir(routes_mod): - if attr_name.startswith("__"): - continue - attr = getattr(routes_mod, attr_name) - if isinstance(attr, APIRouter): - self.router.include_router(attr) - except ModuleNotFoundError: - pass + self.__load_models(f"{self.name}.models") + self.__load_routes(f"{self.name}.routes") def __repr__(self): return json.dumps(self.dict(), indent=2) @@ -233,7 +218,72 @@ class App: def __str__(self): return self.__repr__() - def _serialize_route(self, route): + def __load_models(self, mod_name: str): + try: + importlib.import_module(mod_name) + except ModuleNotFoundError: + print(f"no models detected: {mod_name}") + return + + visited: set[str] = set() + out: Dict[str, ModuleType] = {} + + def walk(mod_name: str): + mod = importlib.import_module(mod_name) + if mod_name in visited: + return + visited.add(mod_name) + + for name, value in vars(mod).copy().items(): + if ( + isinstance(value, type) + and issubclass(value, Model) + and not name == Model.__name__ + ): + out[mod_name] = out.get(mod_name, []) + [value] + + # if it's a package, recurse into submodules + if hasattr(mod, "__path__"): + for _, subname, _ in pkgutil.iter_modules( + mod.__path__, mod.__name__ + "." + ): + walk(subname) + + walk(mod_name) + self._models = out + + def __load_routes(self, mod_name: str): + try: + importlib.import_module(mod_name) + except ModuleNotFound: + print(f"no routes detected: {mod_name}") + return + + visited: set[str] = set() + out: Dict[str, ModuleType] = {} + + def walk(mod_name: str): + mod = importlib.import_module(mod_name) + if mod.__name__ in visited: + return + visited.add(mod.__name__) + + for name, value in vars(mod).copy().items(): + if isinstance(value, APIRouter) and not name == APIRouter.__name__: + out[mod_name] = out.get(mod_name, []) + [value] + + # if it's a package, recurse into submodules + if hasattr(mod, "__path__"): + for _, subname, _ in pkgutil.iter_modules( + mod.__path__, mod.__name__ + "." + ): + submod = importlib.import_module(subname) + walk(submod) + + walk(mod_name) + self._routers = out + + def __serialize_route(self, route): """Convert APIRoute to JSON-serializable dict.""" return { "path": route.path, @@ -248,27 +298,31 @@ class App: "tags": getattr(route, "tags", None), } - def _serialize_router(self): - return [self._serialize_route(route) for route in self.routes] + def __serialize_router(self): + return [self.__serialize_route(route) for route in self.routes] - def dict(self) -> Dict[str, Any]: + @property + def models(self) -> List[ModuleType]: + out = [] + for module in self._models: + for model in self._models[module]: + out.append(model) return { - "models": [m.__name__ for m in self.models], - "routes": self._serialize_router(), + module: out, } - @property - def models(self) -> Generator[Model, None, None]: - for mod in self.model_modules: - models_mod = importlib.import_module(mod) - for obj in models_mod.__dict__.values(): - if ( - isinstance(obj, type) - and getattr(obj, "_meta", None) is not None - and obj.__name__ != "Model" - ): - yield obj - @property def routes(self): - return self.router.routes + router = APIRouter() + for routes_mod in self._routers: + for r in self._routers[routes_mod]: + router.include_router(r) + return router.routes + + def dict(self) -> Dict[str, Any]: + return { + "models": [ + f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"] + ], + "routes": self.__serialize_router(), + } diff --git a/src/ohmyapi/db/model/model.py b/src/ohmyapi/db/model/model.py index 04e5d9c..20359d6 100644 --- a/src/ohmyapi/db/model/model.py +++ b/src/ohmyapi/db/model/model.py @@ -30,10 +30,12 @@ UUID.__get_pydantic_core_schema__ = classmethod(__uuid_schema_monkey_patch) class ModelMeta(type(TortoiseModel)): - def __new__(cls, name, bases, attrs): - new_cls = super().__new__(cls, name, bases, attrs) + def __new__(mcls, name, bases, attrs): + # Grab the Schema class for further processing. + schema_opts = attrs.get("Schema", None) - schema_opts = getattr(new_cls, "Schema", None) + # Let Tortoise's Metaclass do it's thing. + new_cls = super().__new__(mcls, name, bases, attrs) class BoundSchema: def __call__(self, readonly: bool = False): diff --git a/src/ohmyapi/router.py b/src/ohmyapi/router.py index 52116e1..ed96e12 100644 --- a/src/ohmyapi/router.py +++ b/src/ohmyapi/router.py @@ -1,29 +1,2 @@ -import importlib -import inspect -import pkgutil -import pathlib - from fastapi import APIRouter, Depends, HTTPException from http import HTTPStatus -from typing import Generator - - -def package_routers( - package_name: str, - package_path: str | pathlib.Path) -> Generator[APIRouter, None, None]: - """ - Discover all APIRouter instances in submodules of the given package. - """ - if isinstance(package_path, str): - package_path = pathlib.Path(package_path).parent - - for module_info in pkgutil.iter_modules([str(package_path)]): - if module_info.name.startswith("_"): - continue # skip private modules like __init__.py - - module_fqname = f"{package_name}.{module_info.name}" - module = importlib.import_module(module_fqname) - - for _, obj in inspect.getmembers(module): - if isinstance(obj, APIRouter): - yield obj