♻️ Refactor core.runtime
- rewrite how apps are loaded into scope - rewrite how apps are collected for Tortoise and Aerich - rewrite how routes are collected for FastAPI - support packages for models and routes with arbitrary nesting - no need to expose models and routes in __init__.py - OhMyAPI will recursively iterate through all submodules
This commit is contained in:
parent
642359bdeb
commit
2232726e7c
4 changed files with 107 additions and 82 deletions
|
|
@ -11,9 +11,7 @@ from typing import List
|
|||
router = APIRouter(prefix="/tournament")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/", tags=["tournament"], response_model=List[models.Tournament.Schema()]
|
||||
)
|
||||
@router.get("/", tags=["tournament"], response_model=List[models.Tournament.Schema()])
|
||||
async def list():
|
||||
"""List all tournaments."""
|
||||
return await models.Tournament.Schema().from_queryset(models.Tournament.all())
|
||||
|
|
@ -30,9 +28,7 @@ async def post(tournament: models.Tournament.Schema(readonly=True)):
|
|||
@router.get("/{id}", tags=["tournament"], response_model=models.Tournament.Schema())
|
||||
async def get(id: str):
|
||||
"""Get tournament by id."""
|
||||
return await models.Tournament.Schema().from_queryset(
|
||||
models.Tournament.get(id=id)
|
||||
)
|
||||
return await models.Tournament.Schema().from_queryset(models.Tournament.get(id=id))
|
||||
|
||||
|
||||
@router.put(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
# ohmyapi/core/runtime.py
|
||||
import copy
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import pkgutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, Generator, List, Optional, Type
|
||||
|
||||
import click
|
||||
from aerich import Command as AerichCommand
|
||||
|
|
@ -107,7 +107,7 @@ class Project:
|
|||
}
|
||||
|
||||
for app_name, app in self._apps.items():
|
||||
modules = list(dict.fromkeys(app.model_modules))
|
||||
modules = list(app.models.keys())
|
||||
if modules:
|
||||
config["apps"][app_name] = {
|
||||
"models": modules,
|
||||
|
|
@ -123,8 +123,9 @@ class Project:
|
|||
raise RuntimeError(f"App '{app_label}' is not registered")
|
||||
|
||||
# Get a fresh copy of the config (without aerich.models anywhere)
|
||||
tortoise_cfg = copy.deepcopy(self.build_tortoise_config(db_url=db_url))
|
||||
tortoise_cfg = self.build_tortoise_config(db_url=db_url)
|
||||
|
||||
tortoise_cfg["apps"] = {app_label: tortoise_cfg["apps"][app_label]}
|
||||
# Append aerich.models to the models list of the target app only
|
||||
tortoise_cfg["apps"][app_label]["models"].append("aerich.models")
|
||||
|
||||
|
|
@ -199,33 +200,17 @@ class App:
|
|||
self.project = project
|
||||
self.name = name
|
||||
|
||||
# The list of module paths (e.g. "ohmyapi_auth.models") for Tortoise and Aerich
|
||||
self.model_modules: List[str] = []
|
||||
# Reference to this app's models modules.
|
||||
self._models: Dict[str, ModuleType] = {}
|
||||
|
||||
# The APIRouter
|
||||
self.router: APIRouter = APIRouter()
|
||||
# Reference to this app's routes modules.
|
||||
self._routers: Dict[str, ModuleType] = {}
|
||||
|
||||
# Import the app, so its __init__.py runs.
|
||||
importlib.import_module(self.name)
|
||||
mod: ModuleType = importlib.import_module(name)
|
||||
|
||||
# Load the models
|
||||
try:
|
||||
models_mod = importlib.import_module(f"{self.name}.models")
|
||||
self.model_modules.append(f"{self.name}.models")
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# Locate the APIRouter
|
||||
try:
|
||||
routes_mod = importlib.import_module(f"{self.name}.routes")
|
||||
for attr_name in dir(routes_mod):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
attr = getattr(routes_mod, attr_name)
|
||||
if isinstance(attr, APIRouter):
|
||||
self.router.include_router(attr)
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
self.__load_models(f"{self.name}.models")
|
||||
self.__load_routes(f"{self.name}.routes")
|
||||
|
||||
def __repr__(self):
|
||||
return json.dumps(self.dict(), indent=2)
|
||||
|
|
@ -233,7 +218,72 @@ class App:
|
|||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def _serialize_route(self, route):
|
||||
def __load_models(self, mod_name: str):
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
except ModuleNotFoundError:
|
||||
print(f"no models detected: {mod_name}")
|
||||
return
|
||||
|
||||
visited: set[str] = set()
|
||||
out: Dict[str, ModuleType] = {}
|
||||
|
||||
def walk(mod_name: str):
|
||||
mod = importlib.import_module(mod_name)
|
||||
if mod_name in visited:
|
||||
return
|
||||
visited.add(mod_name)
|
||||
|
||||
for name, value in vars(mod).copy().items():
|
||||
if (
|
||||
isinstance(value, type)
|
||||
and issubclass(value, Model)
|
||||
and not name == Model.__name__
|
||||
):
|
||||
out[mod_name] = out.get(mod_name, []) + [value]
|
||||
|
||||
# if it's a package, recurse into submodules
|
||||
if hasattr(mod, "__path__"):
|
||||
for _, subname, _ in pkgutil.iter_modules(
|
||||
mod.__path__, mod.__name__ + "."
|
||||
):
|
||||
walk(subname)
|
||||
|
||||
walk(mod_name)
|
||||
self._models = out
|
||||
|
||||
def __load_routes(self, mod_name: str):
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
except ModuleNotFound:
|
||||
print(f"no routes detected: {mod_name}")
|
||||
return
|
||||
|
||||
visited: set[str] = set()
|
||||
out: Dict[str, ModuleType] = {}
|
||||
|
||||
def walk(mod_name: str):
|
||||
mod = importlib.import_module(mod_name)
|
||||
if mod.__name__ in visited:
|
||||
return
|
||||
visited.add(mod.__name__)
|
||||
|
||||
for name, value in vars(mod).copy().items():
|
||||
if isinstance(value, APIRouter) and not name == APIRouter.__name__:
|
||||
out[mod_name] = out.get(mod_name, []) + [value]
|
||||
|
||||
# if it's a package, recurse into submodules
|
||||
if hasattr(mod, "__path__"):
|
||||
for _, subname, _ in pkgutil.iter_modules(
|
||||
mod.__path__, mod.__name__ + "."
|
||||
):
|
||||
submod = importlib.import_module(subname)
|
||||
walk(submod)
|
||||
|
||||
walk(mod_name)
|
||||
self._routers = out
|
||||
|
||||
def __serialize_route(self, route):
|
||||
"""Convert APIRoute to JSON-serializable dict."""
|
||||
return {
|
||||
"path": route.path,
|
||||
|
|
@ -248,27 +298,31 @@ class App:
|
|||
"tags": getattr(route, "tags", None),
|
||||
}
|
||||
|
||||
def _serialize_router(self):
|
||||
return [self._serialize_route(route) for route in self.routes]
|
||||
def __serialize_router(self):
|
||||
return [self.__serialize_route(route) for route in self.routes]
|
||||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
@property
|
||||
def models(self) -> List[ModuleType]:
|
||||
out = []
|
||||
for module in self._models:
|
||||
for model in self._models[module]:
|
||||
out.append(model)
|
||||
return {
|
||||
"models": [m.__name__ for m in self.models],
|
||||
"routes": self._serialize_router(),
|
||||
module: out,
|
||||
}
|
||||
|
||||
@property
|
||||
def models(self) -> Generator[Model, None, None]:
|
||||
for mod in self.model_modules:
|
||||
models_mod = importlib.import_module(mod)
|
||||
for obj in models_mod.__dict__.values():
|
||||
if (
|
||||
isinstance(obj, type)
|
||||
and getattr(obj, "_meta", None) is not None
|
||||
and obj.__name__ != "Model"
|
||||
):
|
||||
yield obj
|
||||
|
||||
@property
|
||||
def routes(self):
|
||||
return self.router.routes
|
||||
router = APIRouter()
|
||||
for routes_mod in self._routers:
|
||||
for r in self._routers[routes_mod]:
|
||||
router.include_router(r)
|
||||
return router.routes
|
||||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"models": [
|
||||
f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"]
|
||||
],
|
||||
"routes": self.__serialize_router(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,10 +30,12 @@ UUID.__get_pydantic_core_schema__ = classmethod(__uuid_schema_monkey_patch)
|
|||
|
||||
|
||||
class ModelMeta(type(TortoiseModel)):
|
||||
def __new__(cls, name, bases, attrs):
|
||||
new_cls = super().__new__(cls, name, bases, attrs)
|
||||
def __new__(mcls, name, bases, attrs):
|
||||
# Grab the Schema class for further processing.
|
||||
schema_opts = attrs.get("Schema", None)
|
||||
|
||||
schema_opts = getattr(new_cls, "Schema", None)
|
||||
# Let Tortoise's Metaclass do it's thing.
|
||||
new_cls = super().__new__(mcls, name, bases, attrs)
|
||||
|
||||
class BoundSchema:
|
||||
def __call__(self, readonly: bool = False):
|
||||
|
|
|
|||
|
|
@ -1,29 +1,2 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import pathlib
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from http import HTTPStatus
|
||||
from typing import Generator
|
||||
|
||||
|
||||
def package_routers(
|
||||
package_name: str,
|
||||
package_path: str | pathlib.Path) -> Generator[APIRouter, None, None]:
|
||||
"""
|
||||
Discover all APIRouter instances in submodules of the given package.
|
||||
"""
|
||||
if isinstance(package_path, str):
|
||||
package_path = pathlib.Path(package_path).parent
|
||||
|
||||
for module_info in pkgutil.iter_modules([str(package_path)]):
|
||||
if module_info.name.startswith("_"):
|
||||
continue # skip private modules like __init__.py
|
||||
|
||||
module_fqname = f"{package_name}.{module_info.name}"
|
||||
module = importlib.import_module(module_fqname)
|
||||
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if isinstance(obj, APIRouter):
|
||||
yield obj
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue