♻️ Refactor core.runtime
- rewrite how apps are loaded into scope - rewrite how apps are collected for Tortoise and Aerich - rewrite how routes are collected for FastAPI - support packages for models and routes with arbitrary nesting - no need to expose models and routes in __init__.py - OhMyAPI will recursively iterate through all submodules
This commit is contained in:
parent
642359bdeb
commit
2232726e7c
4 changed files with 107 additions and 82 deletions
|
|
@ -11,9 +11,7 @@ from typing import List
|
||||||
router = APIRouter(prefix="/tournament")
|
router = APIRouter(prefix="/tournament")
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get("/", tags=["tournament"], response_model=List[models.Tournament.Schema()])
|
||||||
"/", tags=["tournament"], response_model=List[models.Tournament.Schema()]
|
|
||||||
)
|
|
||||||
async def list():
|
async def list():
|
||||||
"""List all tournaments."""
|
"""List all tournaments."""
|
||||||
return await models.Tournament.Schema().from_queryset(models.Tournament.all())
|
return await models.Tournament.Schema().from_queryset(models.Tournament.all())
|
||||||
|
|
@ -30,9 +28,7 @@ async def post(tournament: models.Tournament.Schema(readonly=True)):
|
||||||
@router.get("/{id}", tags=["tournament"], response_model=models.Tournament.Schema())
|
@router.get("/{id}", tags=["tournament"], response_model=models.Tournament.Schema())
|
||||||
async def get(id: str):
|
async def get(id: str):
|
||||||
"""Get tournament by id."""
|
"""Get tournament by id."""
|
||||||
return await models.Tournament.Schema().from_queryset(
|
return await models.Tournament.Schema().from_queryset(models.Tournament.get(id=id))
|
||||||
models.Tournament.get(id=id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
# ohmyapi/core/runtime.py
|
# ohmyapi/core/runtime.py
|
||||||
import copy
|
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Generator, List, Optional
|
from types import ModuleType
|
||||||
|
from typing import Any, Dict, Generator, List, Optional, Type
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from aerich import Command as AerichCommand
|
from aerich import Command as AerichCommand
|
||||||
|
|
@ -107,7 +107,7 @@ class Project:
|
||||||
}
|
}
|
||||||
|
|
||||||
for app_name, app in self._apps.items():
|
for app_name, app in self._apps.items():
|
||||||
modules = list(dict.fromkeys(app.model_modules))
|
modules = list(app.models.keys())
|
||||||
if modules:
|
if modules:
|
||||||
config["apps"][app_name] = {
|
config["apps"][app_name] = {
|
||||||
"models": modules,
|
"models": modules,
|
||||||
|
|
@ -123,8 +123,9 @@ class Project:
|
||||||
raise RuntimeError(f"App '{app_label}' is not registered")
|
raise RuntimeError(f"App '{app_label}' is not registered")
|
||||||
|
|
||||||
# Get a fresh copy of the config (without aerich.models anywhere)
|
# Get a fresh copy of the config (without aerich.models anywhere)
|
||||||
tortoise_cfg = copy.deepcopy(self.build_tortoise_config(db_url=db_url))
|
tortoise_cfg = self.build_tortoise_config(db_url=db_url)
|
||||||
|
|
||||||
|
tortoise_cfg["apps"] = {app_label: tortoise_cfg["apps"][app_label]}
|
||||||
# Append aerich.models to the models list of the target app only
|
# Append aerich.models to the models list of the target app only
|
||||||
tortoise_cfg["apps"][app_label]["models"].append("aerich.models")
|
tortoise_cfg["apps"][app_label]["models"].append("aerich.models")
|
||||||
|
|
||||||
|
|
@ -199,33 +200,17 @@ class App:
|
||||||
self.project = project
|
self.project = project
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
# The list of module paths (e.g. "ohmyapi_auth.models") for Tortoise and Aerich
|
# Reference to this app's models modules.
|
||||||
self.model_modules: List[str] = []
|
self._models: Dict[str, ModuleType] = {}
|
||||||
|
|
||||||
# The APIRouter
|
# Reference to this app's routes modules.
|
||||||
self.router: APIRouter = APIRouter()
|
self._routers: Dict[str, ModuleType] = {}
|
||||||
|
|
||||||
# Import the app, so its __init__.py runs.
|
# Import the app, so its __init__.py runs.
|
||||||
importlib.import_module(self.name)
|
mod: ModuleType = importlib.import_module(name)
|
||||||
|
|
||||||
# Load the models
|
self.__load_models(f"{self.name}.models")
|
||||||
try:
|
self.__load_routes(f"{self.name}.routes")
|
||||||
models_mod = importlib.import_module(f"{self.name}.models")
|
|
||||||
self.model_modules.append(f"{self.name}.models")
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Locate the APIRouter
|
|
||||||
try:
|
|
||||||
routes_mod = importlib.import_module(f"{self.name}.routes")
|
|
||||||
for attr_name in dir(routes_mod):
|
|
||||||
if attr_name.startswith("__"):
|
|
||||||
continue
|
|
||||||
attr = getattr(routes_mod, attr_name)
|
|
||||||
if isinstance(attr, APIRouter):
|
|
||||||
self.router.include_router(attr)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return json.dumps(self.dict(), indent=2)
|
return json.dumps(self.dict(), indent=2)
|
||||||
|
|
@ -233,7 +218,72 @@ class App:
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
def _serialize_route(self, route):
|
def __load_models(self, mod_name: str):
|
||||||
|
try:
|
||||||
|
importlib.import_module(mod_name)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
print(f"no models detected: {mod_name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
visited: set[str] = set()
|
||||||
|
out: Dict[str, ModuleType] = {}
|
||||||
|
|
||||||
|
def walk(mod_name: str):
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
if mod_name in visited:
|
||||||
|
return
|
||||||
|
visited.add(mod_name)
|
||||||
|
|
||||||
|
for name, value in vars(mod).copy().items():
|
||||||
|
if (
|
||||||
|
isinstance(value, type)
|
||||||
|
and issubclass(value, Model)
|
||||||
|
and not name == Model.__name__
|
||||||
|
):
|
||||||
|
out[mod_name] = out.get(mod_name, []) + [value]
|
||||||
|
|
||||||
|
# if it's a package, recurse into submodules
|
||||||
|
if hasattr(mod, "__path__"):
|
||||||
|
for _, subname, _ in pkgutil.iter_modules(
|
||||||
|
mod.__path__, mod.__name__ + "."
|
||||||
|
):
|
||||||
|
walk(subname)
|
||||||
|
|
||||||
|
walk(mod_name)
|
||||||
|
self._models = out
|
||||||
|
|
||||||
|
def __load_routes(self, mod_name: str):
|
||||||
|
try:
|
||||||
|
importlib.import_module(mod_name)
|
||||||
|
except ModuleNotFound:
|
||||||
|
print(f"no routes detected: {mod_name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
visited: set[str] = set()
|
||||||
|
out: Dict[str, ModuleType] = {}
|
||||||
|
|
||||||
|
def walk(mod_name: str):
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
if mod.__name__ in visited:
|
||||||
|
return
|
||||||
|
visited.add(mod.__name__)
|
||||||
|
|
||||||
|
for name, value in vars(mod).copy().items():
|
||||||
|
if isinstance(value, APIRouter) and not name == APIRouter.__name__:
|
||||||
|
out[mod_name] = out.get(mod_name, []) + [value]
|
||||||
|
|
||||||
|
# if it's a package, recurse into submodules
|
||||||
|
if hasattr(mod, "__path__"):
|
||||||
|
for _, subname, _ in pkgutil.iter_modules(
|
||||||
|
mod.__path__, mod.__name__ + "."
|
||||||
|
):
|
||||||
|
submod = importlib.import_module(subname)
|
||||||
|
walk(submod)
|
||||||
|
|
||||||
|
walk(mod_name)
|
||||||
|
self._routers = out
|
||||||
|
|
||||||
|
def __serialize_route(self, route):
|
||||||
"""Convert APIRoute to JSON-serializable dict."""
|
"""Convert APIRoute to JSON-serializable dict."""
|
||||||
return {
|
return {
|
||||||
"path": route.path,
|
"path": route.path,
|
||||||
|
|
@ -248,27 +298,31 @@ class App:
|
||||||
"tags": getattr(route, "tags", None),
|
"tags": getattr(route, "tags", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _serialize_router(self):
|
def __serialize_router(self):
|
||||||
return [self._serialize_route(route) for route in self.routes]
|
return [self.__serialize_route(route) for route in self.routes]
|
||||||
|
|
||||||
def dict(self) -> Dict[str, Any]:
|
@property
|
||||||
|
def models(self) -> List[ModuleType]:
|
||||||
|
out = []
|
||||||
|
for module in self._models:
|
||||||
|
for model in self._models[module]:
|
||||||
|
out.append(model)
|
||||||
return {
|
return {
|
||||||
"models": [m.__name__ for m in self.models],
|
module: out,
|
||||||
"routes": self._serialize_router(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
|
||||||
def models(self) -> Generator[Model, None, None]:
|
|
||||||
for mod in self.model_modules:
|
|
||||||
models_mod = importlib.import_module(mod)
|
|
||||||
for obj in models_mod.__dict__.values():
|
|
||||||
if (
|
|
||||||
isinstance(obj, type)
|
|
||||||
and getattr(obj, "_meta", None) is not None
|
|
||||||
and obj.__name__ != "Model"
|
|
||||||
):
|
|
||||||
yield obj
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def routes(self):
|
def routes(self):
|
||||||
return self.router.routes
|
router = APIRouter()
|
||||||
|
for routes_mod in self._routers:
|
||||||
|
for r in self._routers[routes_mod]:
|
||||||
|
router.include_router(r)
|
||||||
|
return router.routes
|
||||||
|
|
||||||
|
def dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"models": [
|
||||||
|
f"{self.name}.{m.__name__}" for m in self.models[f"{self.name}.models"]
|
||||||
|
],
|
||||||
|
"routes": self.__serialize_router(),
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,10 +30,12 @@ UUID.__get_pydantic_core_schema__ = classmethod(__uuid_schema_monkey_patch)
|
||||||
|
|
||||||
|
|
||||||
class ModelMeta(type(TortoiseModel)):
|
class ModelMeta(type(TortoiseModel)):
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(mcls, name, bases, attrs):
|
||||||
new_cls = super().__new__(cls, name, bases, attrs)
|
# Grab the Schema class for further processing.
|
||||||
|
schema_opts = attrs.get("Schema", None)
|
||||||
|
|
||||||
schema_opts = getattr(new_cls, "Schema", None)
|
# Let Tortoise's Metaclass do it's thing.
|
||||||
|
new_cls = super().__new__(mcls, name, bases, attrs)
|
||||||
|
|
||||||
class BoundSchema:
|
class BoundSchema:
|
||||||
def __call__(self, readonly: bool = False):
|
def __call__(self, readonly: bool = False):
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,2 @@
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
import pkgutil
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
|
|
||||||
def package_routers(
|
|
||||||
package_name: str,
|
|
||||||
package_path: str | pathlib.Path) -> Generator[APIRouter, None, None]:
|
|
||||||
"""
|
|
||||||
Discover all APIRouter instances in submodules of the given package.
|
|
||||||
"""
|
|
||||||
if isinstance(package_path, str):
|
|
||||||
package_path = pathlib.Path(package_path).parent
|
|
||||||
|
|
||||||
for module_info in pkgutil.iter_modules([str(package_path)]):
|
|
||||||
if module_info.name.startswith("_"):
|
|
||||||
continue # skip private modules like __init__.py
|
|
||||||
|
|
||||||
module_fqname = f"{package_name}.{module_info.name}"
|
|
||||||
module = importlib.import_module(module_fqname)
|
|
||||||
|
|
||||||
for _, obj in inspect.getmembers(module):
|
|
||||||
if isinstance(obj, APIRouter):
|
|
||||||
yield obj
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue