♻️ Refactor core.runtime

- rewrite how apps are loaded into scope
- rewrite how apps are collected for Tortoise and Aerich
- rewrite how routes are collected for FastAPI
- support packages for models and routes with arbitrary nesting
  - no need to expose models and routes in __init__.py
  - OhMyAPI will recursively iterate through all submodules
This commit is contained in:
Brian Wiborg 2025-10-01 20:43:56 +02:00
parent 642359bdeb
commit 2232726e7c
No known key found for this signature in database
4 changed files with 107 additions and 82 deletions

View file

@ -11,9 +11,7 @@ from typing import List
router = APIRouter(prefix="/tournament") router = APIRouter(prefix="/tournament")
@router.get( @router.get("/", tags=["tournament"], response_model=List[models.Tournament.Schema()])
"/", tags=["tournament"], response_model=List[models.Tournament.Schema()]
)
async def list(): async def list():
"""List all tournaments.""" """List all tournaments."""
return await models.Tournament.Schema().from_queryset(models.Tournament.all()) return await models.Tournament.Schema().from_queryset(models.Tournament.all())
@ -30,9 +28,7 @@ async def post(tournament: models.Tournament.Schema(readonly=True)):
@router.get("/{id}", tags=["tournament"], response_model=models.Tournament.Schema()) @router.get("/{id}", tags=["tournament"], response_model=models.Tournament.Schema())
async def get(id: str): async def get(id: str):
"""Get tournament by id.""" """Get tournament by id."""
return await models.Tournament.Schema().from_queryset( return await models.Tournament.Schema().from_queryset(models.Tournament.get(id=id))
models.Tournament.get(id=id)
)
@router.put( @router.put(

View file

@ -1,12 +1,12 @@
# ohmyapi/core/runtime.py # ohmyapi/core/runtime.py
import copy
import importlib import importlib
import importlib.util import importlib.util
import json import json
import pkgutil import pkgutil
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Generator, List, Optional from types import ModuleType
from typing import Any, Dict, Generator, List, Optional, Type
import click import click
from aerich import Command as AerichCommand from aerich import Command as AerichCommand
@ -107,7 +107,7 @@ class Project:
} }
for app_name, app in self._apps.items(): for app_name, app in self._apps.items():
modules = list(dict.fromkeys(app.model_modules)) modules = list(app.models.keys())
if modules: if modules:
config["apps"][app_name] = { config["apps"][app_name] = {
"models": modules, "models": modules,
@ -123,8 +123,9 @@ class Project:
raise RuntimeError(f"App '{app_label}' is not registered") raise RuntimeError(f"App '{app_label}' is not registered")
# Get a fresh copy of the config (without aerich.models anywhere) # Get a fresh copy of the config (without aerich.models anywhere)
tortoise_cfg = copy.deepcopy(self.build_tortoise_config(db_url=db_url)) tortoise_cfg = self.build_tortoise_config(db_url=db_url)
tortoise_cfg["apps"] = {app_label: tortoise_cfg["apps"][app_label]}
# Append aerich.models to the models list of the target app only # Append aerich.models to the models list of the target app only
tortoise_cfg["apps"][app_label]["models"].append("aerich.models") tortoise_cfg["apps"][app_label]["models"].append("aerich.models")
@ -199,33 +200,17 @@ class App:
self.project = project self.project = project
self.name = name self.name = name
# The list of module paths (e.g. "ohmyapi_auth.models") for Tortoise and Aerich # Reference to this app's models modules.
self.model_modules: List[str] = [] self._models: Dict[str, ModuleType] = {}
# The APIRouter # Reference to this app's routes modules.
self.router: APIRouter = APIRouter() self._routers: Dict[str, ModuleType] = {}
# Import the app, so its __init__.py runs. # Import the app, so its __init__.py runs.
importlib.import_module(self.name) mod: ModuleType = importlib.import_module(name)
# Load the models self.__load_models(f"{self.name}.models")
try: self.__load_routes(f"{self.name}.routes")
models_mod = importlib.import_module(f"{self.name}.models")
self.model_modules.append(f"{self.name}.models")
except ModuleNotFoundError:
pass
# Locate the APIRouter
try:
routes_mod = importlib.import_module(f"{self.name}.routes")
for attr_name in dir(routes_mod):
if attr_name.startswith("__"):
continue
attr = getattr(routes_mod, attr_name)
if isinstance(attr, APIRouter):
self.router.include_router(attr)
except ModuleNotFoundError:
pass
def __repr__(self): def __repr__(self):
return json.dumps(self.dict(), indent=2) return json.dumps(self.dict(), indent=2)
@ -233,7 +218,72 @@ class App:
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()
def _serialize_route(self, route): def __load_models(self, mod_name: str):
try:
importlib.import_module(mod_name)
except ModuleNotFoundError:
print(f"no models detected: {mod_name}")
return
visited: set[str] = set()
out: Dict[str, ModuleType] = {}
def walk(mod_name: str):
mod = importlib.import_module(mod_name)
if mod_name in visited:
return
visited.add(mod_name)
for name, value in vars(mod).copy().items():
if (
isinstance(value, type)
and issubclass(value, Model)
and not name == Model.__name__
):
out[mod_name] = out.get(mod_name, []) + [value]
# if it's a package, recurse into submodules
if hasattr(mod, "__path__"):
for _, subname, _ in pkgutil.iter_modules(
mod.__path__, mod.__name__ + "."
):
walk(subname)
walk(mod_name)
self._models = out
def __load_routes(self, mod_name: str):
try:
importlib.import_module(mod_name)
except ModuleNotFound:
print(f"no routes detected: {mod_name}")
return
visited: set[str] = set()
out: Dict[str, ModuleType] = {}
def walk(mod_name: str):
mod = importlib.import_module(mod_name)
if mod.__name__ in visited:
return
visited.add(mod.__name__)
for name, value in vars(mod).copy().items():
if isinstance(value, APIRouter) and not name == APIRouter.__name__:
out[mod_name] = out.get(mod_name, []) + [value]
# if it's a package, recurse into submodules
if hasattr(mod, "__path__"):
for _, subname, _ in pkgutil.iter_modules(
mod.__path__, mod.__name__ + "."
):
submod = importlib.import_module(subname)
walk(submod)
walk(mod_name)
self._routers = out
def __serialize_route(self, route):
"""Convert APIRoute to JSON-serializable dict.""" """Convert APIRoute to JSON-serializable dict."""
return { return {
"path": route.path, "path": route.path,
@ -248,27 +298,31 @@ class App:
"tags": getattr(route, "tags", None), "tags": getattr(route, "tags", None),
} }
def _serialize_router(self): def __serialize_router(self):
return [self._serialize_route(route) for route in self.routes] return [self.__serialize_route(route) for route in self.routes]
def dict(self) -> Dict[str, Any]: @property
def models(self) -> List[ModuleType]:
out = []
for module in self._models:
for model in self._models[module]:
out.append(model)
return { return {
"models": [m.__name__ for m in self.models], module: out,
"routes": self._serialize_router(),
} }
@property
def models(self) -> Generator[Model, None, None]:
for mod in self.model_modules:
models_mod = importlib.import_module(mod)
for obj in models_mod.__dict__.values():
if (
isinstance(obj, type)
and getattr(obj, "_meta", None) is not None
and obj.__name__ != "Model"
):
yield obj
@property @property
def routes(self): def routes(self):
return self.router.routes router = APIRouter()
for routes_mod in self._routers:
for r in self._routers[routes_mod]:
router.include_router(r)
return router.routes
def dict(self) -> Dict[str, Any]:
return {
"models": [
f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"]
],
"routes": self.__serialize_router(),
}

View file

@ -30,10 +30,12 @@ UUID.__get_pydantic_core_schema__ = classmethod(__uuid_schema_monkey_patch)
class ModelMeta(type(TortoiseModel)): class ModelMeta(type(TortoiseModel)):
def __new__(cls, name, bases, attrs): def __new__(mcls, name, bases, attrs):
new_cls = super().__new__(cls, name, bases, attrs) # Grab the Schema class for further processing.
schema_opts = attrs.get("Schema", None)
schema_opts = getattr(new_cls, "Schema", None) # Let Tortoise's Metaclass do it's thing.
new_cls = super().__new__(mcls, name, bases, attrs)
class BoundSchema: class BoundSchema:
def __call__(self, readonly: bool = False): def __call__(self, readonly: bool = False):

View file

@ -1,29 +1,2 @@
import importlib
import inspect
import pkgutil
import pathlib
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from http import HTTPStatus from http import HTTPStatus
from typing import Generator
def package_routers(
package_name: str,
package_path: str | pathlib.Path) -> Generator[APIRouter, None, None]:
"""
Discover all APIRouter instances in submodules of the given package.
"""
if isinstance(package_path, str):
package_path = pathlib.Path(package_path).parent
for module_info in pkgutil.iter_modules([str(package_path)]):
if module_info.name.startswith("_"):
continue # skip private modules like __init__.py
module_fqname = f"{package_name}.{module_info.name}"
module = importlib.import_module(module_fqname)
for _, obj in inspect.getmembers(module):
if isinstance(obj, APIRouter):
yield obj