💄 Introduce black & isort

This commit is contained in:
Brian Wiborg 2025-09-28 15:41:01 +02:00
parent 9becfc857d
commit 6a90e4a44a
No known key found for this signature in database
14 changed files with 311 additions and 89 deletions

136
poetry.lock generated
View file

@ -163,6 +163,52 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "black"
version = "25.9.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "black-25.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ce41ed2614b706fd55fd0b4a6909d06b5bab344ffbfadc6ef34ae50adba3d4f7"},
{file = "black-25.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ab0ce111ef026790e9b13bd216fa7bc48edd934ffc4cbf78808b235793cbc92"},
{file = "black-25.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f96b6726d690c96c60ba682955199f8c39abc1ae0c3a494a9c62c0184049a713"},
{file = "black-25.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:d119957b37cc641596063cd7db2656c5be3752ac17877017b2ffcdb9dfc4d2b1"},
{file = "black-25.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:456386fe87bad41b806d53c062e2974615825c7a52159cde7ccaeb0695fa28fa"},
{file = "black-25.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a16b14a44c1af60a210d8da28e108e13e75a284bf21a9afa6b4571f96ab8bb9d"},
{file = "black-25.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aaf319612536d502fdd0e88ce52d8f1352b2c0a955cc2798f79eeca9d3af0608"},
{file = "black-25.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:c0372a93e16b3954208417bfe448e09b0de5cc721d521866cd9e0acac3c04a1f"},
{file = "black-25.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1b9dc70c21ef8b43248f1d86aedd2aaf75ae110b958a7909ad8463c4aa0880b0"},
{file = "black-25.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8e46eecf65a095fa62e53245ae2795c90bdecabd53b50c448d0a8bcd0d2e74c4"},
{file = "black-25.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9101ee58ddc2442199a25cb648d46ba22cd580b00ca4b44234a324e3ec7a0f7e"},
{file = "black-25.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:77e7060a00c5ec4b3367c55f39cf9b06e68965a4f2e61cecacd6d0d9b7ec945a"},
{file = "black-25.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0172a012f725b792c358d57fe7b6b6e8e67375dd157f64fa7a3097b3ed3e2175"},
{file = "black-25.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3bec74ee60f8dfef564b573a96b8930f7b6a538e846123d5ad77ba14a8d7a64f"},
{file = "black-25.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b756fc75871cb1bcac5499552d771822fd9db5a2bb8db2a7247936ca48f39831"},
{file = "black-25.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:846d58e3ce7879ec1ffe816bb9df6d006cd9590515ed5d17db14e17666b2b357"},
{file = "black-25.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef69351df3c84485a8beb6f7b8f9721e2009e20ef80a8d619e2d1788b7816d47"},
{file = "black-25.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e3c1f4cd5e93842774d9ee4ef6cd8d17790e65f44f7cdbaab5f2cf8ccf22a823"},
{file = "black-25.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:154b06d618233fe468236ba1f0e40823d4eb08b26f5e9261526fde34916b9140"},
{file = "black-25.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:e593466de7b998374ea2585a471ba90553283fb9beefcfa430d84a2651ed5933"},
{file = "black-25.9.0-py3-none-any.whl", hash = "sha256:474b34c1342cdc157d307b56c4c65bce916480c4a8f6551fdc6bf9b486a7c4ae"},
{file = "black-25.9.0.tar.gz", hash = "sha256:0474bca9a0dd1b51791fcc507a4e02078a1c63f6d4e4ae5544b9848c7adfb619"},
]
[package.dependencies]
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
pytokens = ">=0.1.10"
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.10)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "certifi"
version = "2025.8.3"
@ -367,7 +413,7 @@ version = "8.3.0"
description = "Composable command line interface toolkit"
optional = false
python-versions = ">=3.10"
groups = ["main"]
groups = ["main", "dev"]
files = [
{file = "click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc"},
{file = "click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4"},
@ -383,11 +429,11 @@ description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["main", "dev"]
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\"", dev = "sys_platform == \"win32\""}
[[package]]
name = "crypto"
@ -559,6 +605,22 @@ files = [
{file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"},
]
[[package]]
name = "isort"
version = "6.0.1"
description = "A Python utility / library to sort Python imports."
optional = false
python-versions = ">=3.9.0"
groups = ["dev"]
files = [
{file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"},
{file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"},
]
[package.extras]
colors = ["colorama"]
plugins = ["setuptools"]
[[package]]
name = "jedi"
version = "0.19.2"
@ -719,6 +781,18 @@ files = [
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "mypy-extensions"
version = "1.1.0"
description = "Type system extensions for programs checked with the mypy type checker."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505"},
{file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"},
]
[[package]]
name = "naked"
version = "0.1.32"
@ -735,6 +809,18 @@ files = [
pyyaml = "*"
requests = "*"
[[package]]
name = "packaging"
version = "25.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
]
[[package]]
name = "parso"
version = "0.8.5"
@ -769,6 +855,18 @@ bcrypt = ["bcrypt (>=3.1.0)"]
build-docs = ["cloud-sptheme (>=1.10.1)", "sphinx (>=1.6)", "sphinxcontrib-fulltoc (>=1.2.0)"]
totp = ["cryptography"]
[[package]]
name = "pathspec"
version = "0.12.1"
description = "Utility library for gitignore style pattern matching of file paths."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"},
{file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"},
]
[[package]]
name = "pexpect"
version = "4.9.0"
@ -785,6 +883,23 @@ files = [
[package.dependencies]
ptyprocess = ">=0.5"
[[package]]
name = "platformdirs"
version = "4.4.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85"},
{file = "platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf"},
]
[package.extras]
docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.4)", "pytest-cov (>=6)", "pytest-mock (>=3.14)"]
type = ["mypy (>=1.14.1)"]
[[package]]
name = "prompt-toolkit"
version = "3.0.52"
@ -1033,6 +1148,21 @@ files = [
{file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"},
]
[[package]]
name = "pytokens"
version = "0.1.10"
description = "A Fast, spec compliant Python 3.12+ tokenizer that runs on older Pythons."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pytokens-0.1.10-py3-none-any.whl", hash = "sha256:db7b72284e480e69fb085d9f251f66b3d2df8b7166059261258ff35f50fb711b"},
{file = "pytokens-0.1.10.tar.gz", hash = "sha256:c9a4bfa0be1d26aebce03e6884ba454e842f186a59ea43a6d3b25af58223c044"},
]
[package.extras]
dev = ["black", "build", "mypy", "pytest", "pytest-cov", "setuptools", "tox", "twine", "wheel"]
[[package]]
name = "pytz"
version = "2025.2"
@ -1378,4 +1508,4 @@ auth = ["argon2-cffi", "crypto", "passlib", "pyjwt", "python-multipart"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.13"
content-hash = "16ae1b48820c723ca784e71a454fb4b686c94fc9f01fa81b086df5dcaf512074"
content-hash = "145508f708df01d84d998947a87b95cfc269e197eb8bc7467e9748a3b8e210e5"

View file

@ -27,6 +27,8 @@ dependencies = [
[tool.poetry.group.dev.dependencies]
ipython = ">=9.5.0,<10.0.0"
black = "^25.9.0"
isort = "^6.0.1"
[project.optional-dependencies]
auth = ["passlib", "pyjwt", "crypto", "argon2-cffi", "python-multipart"]
@ -36,3 +38,23 @@ packages = [ { include = "ohmyapi", from = "src" } ]
[project.scripts]
ohmyapi = "ohmyapi.cli:app"
[tool.black]
line-length = 88
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.venv
| build
| dist
)/
'''
[tool.isort]
profile = "black" # makes imports compatible with black
line_length = 88
multi_line_output = 3
include_trailing_comma = true

View file

@ -1,2 +1 @@
from . import db

View file

@ -1,4 +1 @@
from . import models
from . import routes
from . import permissions
from . import models, permissions, routes

View file

@ -1,11 +1,12 @@
from ohmyapi.router import HTTPException
from ohmyapi.db import Model, field, pre_save, pre_delete
from functools import wraps
from typing import Optional, List
from typing import List, Optional
from uuid import UUID
from passlib.context import CryptContext
from tortoise.contrib.pydantic import pydantic_queryset_creator
from uuid import UUID
from ohmyapi.db import Model, field, pre_delete, pre_save
from ohmyapi.router import HTTPException
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
@ -22,10 +23,12 @@ class User(Model):
password_hash: str = field.CharField(max_length=128)
is_admin: bool = field.BooleanField(default=False)
is_staff: bool = field.BooleanField(default=False)
groups: field.ManyToManyRelation[Group] = field.ManyToManyField("ohmyapi_auth.Group", related_name="users", through='user_groups')
groups: field.ManyToManyRelation[Group] = field.ManyToManyField(
"ohmyapi_auth.Group", related_name="users", through="user_groups"
)
class Schema:
exclude = 'password_hash',
exclude = ("password_hash",)
def set_password(self, raw_password: str) -> None:
"""Hash and store the password."""

View file

@ -1,8 +1,8 @@
from .routes import (
get_token,
get_current_user,
require_authenticated,
get_token,
require_admin,
require_staff,
require_authenticated,
require_group,
require_staff,
)

View file

@ -3,13 +3,12 @@ from enum import Enum
from typing import Any, Dict, List
import jwt
import settings
from fastapi import APIRouter, Body, Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from ohmyapi.builtin.auth.models import User, Group
import settings
from ohmyapi.builtin.auth.models import Group, User
# Router
router = APIRouter(prefix="/auth", tags=["auth"])
@ -17,8 +16,12 @@ router = APIRouter(prefix="/auth", tags=["auth"])
# Secrets & config (should come from settings/env in real projects)
JWT_SECRET = getattr(settings, "JWT_SECRET", "changeme")
JWT_ALGORITHM = getattr(settings, "JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_SECONDS = getattr(settings, "JWT_ACCESS_TOKEN_EXPIRE_SECONDS", 15 * 60)
REFRESH_TOKEN_EXPIRE_SECONDS = getattr(settings, "JWT_REFRESH_TOKEN_EXPIRE_SECONDS", 7 * 24 * 60 * 60)
ACCESS_TOKEN_EXPIRE_SECONDS = getattr(
settings, "JWT_ACCESS_TOKEN_EXPIRE_SECONDS", 15 * 60
)
REFRESH_TOKEN_EXPIRE_SECONDS = getattr(
settings, "JWT_REFRESH_TOKEN_EXPIRE_SECONDS", 7 * 24 * 60 * 60
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
@ -36,30 +39,38 @@ def decode_token(token: str) -> Dict:
try:
return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
)
except jwt.InvalidTokenError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
class TokenType(str, Enum):
"""
Helper for indicating the token type when generating claims.
"""
access = "access"
refresh = "refresh"
def claims(token_type: TokenType, user: User, groups: List[Group] = []) -> Dict[str, Any]:
def claims(
token_type: TokenType, user: User, groups: List[Group] = []
) -> Dict[str, Any]:
return {
'type': token_type,
'sub': str(user.id),
'user': {
'username': user.username,
'email': user.email,
"type": token_type,
"sub": str(user.id),
"user": {
"username": user.username,
"email": user.email,
},
'roles': [g.name for g in groups]
"roles": [g.name for g in groups],
}
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
"""Dependency: token introspection"""
payload = decode_token(token)
@ -71,11 +82,15 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
payload = decode_token(token)
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload"
)
user = await User.filter(id=user_id).first()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
)
return user
@ -101,15 +116,13 @@ async def require_staff(current_user: User = Depends(get_current_user)) -> User:
async def require_group(
group_name: str,
current_user: User = Depends(get_current_user)
group_name: str, current_user: User = Depends(get_current_user)
) -> User:
"""Ensure the current user belongs to the given group."""
user_groups = await current_user.groups.all()
if not any(g.name == group_name for g in user_groups):
raise HTTPException(
status_code=403,
detail=f"User must belong to group '{group_name}'"
status_code=403, detail=f"User must belong to group '{group_name}'"
)
return current_user
@ -124,15 +137,21 @@ async def login(form_data: LoginRequest = Body(...)):
"""Login with username & password, returns access and refresh tokens."""
user = await User.authenticate(form_data.username, form_data.password)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
)
access_token = create_token(claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS)
refresh_token = create_token(claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS)
access_token = create_token(
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
)
refresh_token = create_token(
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer"
"token_type": "bearer",
}
@ -141,14 +160,20 @@ async def refresh_token(refresh_token: str):
"""Exchange refresh token for new access token."""
payload = decode_token(refresh_token)
if payload.get("type") != "refresh":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
user_id = payload.get("sub")
user = await User.filter(id=user_id).first()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
)
new_access = create_token(claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS)
new_access = create_token(
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
)
return {"access_token": new_access, "token_type": "bearer"}
@ -161,4 +186,3 @@ async def introspect(token: Dict = Depends(get_token)):
async def me(user: User = Depends(get_current_user)):
"""Return the currently authenticated user."""
return User.Schema.one.from_orm(user)

View file

@ -2,14 +2,17 @@ import asyncio
import atexit
import importlib
import sys
from getpass import getpass
from pathlib import Path
import typer
import uvicorn
from getpass import getpass
from ohmyapi.core import scaffolding, runtime
from pathlib import Path
from ohmyapi.core import runtime, scaffolding
app = typer.Typer(help="OhMyAPI — Django-flavored FastAPI scaffolding with tightly integrated TortoiseORM.")
app = typer.Typer(
help="OhMyAPI — Django-flavored FastAPI scaffolding with tightly integrated TortoiseORM."
)
@app.command()
@ -78,6 +81,7 @@ def shell(root: str = "."):
start_ipython(argv=[], user_ns=shell_vars, config=c)
except ImportError:
import code
code.interact(local=shell_vars, banner=banner)
finally:
loop.run_until_complete(cleanup())
@ -120,11 +124,15 @@ def createsuperuser(root: str = "."):
project_path = Path(root).resolve()
project = runtime.Project(project_path)
if not project.is_app_installed("ohmyapi_auth"):
print("Auth app not installed! Please add 'ohmyapi_auth' to your INSTALLED_APPS.")
print(
"Auth app not installed! Please add 'ohmyapi_auth' to your INSTALLED_APPS."
)
return
import asyncio
import ohmyapi_auth
email = input("E-Mail: ")
username = input("Username: ")
password1, password2 = "foo", "bar"
@ -133,9 +141,10 @@ def createsuperuser(root: str = "."):
password2 = getpass("Repeat Password: ")
if password1 != password2:
print("Passwords didn't match!")
user = ohmyapi_auth.models.User(email=email, username=username, is_staff=True, is_admin=True)
user = ohmyapi_auth.models.User(
email=email, username=username, is_staff=True, is_admin=True
)
user.set_password(password1)
asyncio.run(project.init_orm())
asyncio.run(user.save())
asyncio.run(project.close_orm())

View file

@ -11,8 +11,9 @@ from typing import Any, Dict, Generator, List, Optional
import click
from aerich import Command as AerichCommand
from aerich.exceptions import NotInitedError
from fastapi import APIRouter, FastAPI
from tortoise import Tortoise
from fastapi import FastAPI, APIRouter
from ohmyapi.db.model import Model
@ -44,7 +45,9 @@ class Project:
orig = importlib.import_module(full)
sys.modules[alias] = orig
try:
sys.modules[f"{alias}.models"] = importlib.import_module(f"{full}.models")
sys.modules[f"{alias}.models"] = importlib.import_module(
f"{full}.models"
)
except ModuleNotFoundError:
pass
@ -52,7 +55,9 @@ class Project:
try:
self.settings = importlib.import_module("settings")
except Exception as e:
raise RuntimeError(f"Failed to import project settings from {self.project_path}") from e
raise RuntimeError(
f"Failed to import project settings from {self.project_path}"
) from e
# Load installed apps
for app_name in getattr(self.settings, "INSTALLED_APPS", []):
@ -104,11 +109,16 @@ class Project:
for app_name, app in self._apps.items():
modules = list(dict.fromkeys(app.model_modules))
if modules:
config["apps"][app_name] = {"models": modules, "default_connection": "default"}
config["apps"][app_name] = {
"models": modules,
"default_connection": "default",
}
return config
def build_aerich_command(self, app_label: str, db_url: Optional[str] = None) -> AerichCommand:
def build_aerich_command(
self, app_label: str, db_url: Optional[str] = None
) -> AerichCommand:
# Resolve label to flat_label
if app_label in self._apps:
flat_label = app_label
@ -129,7 +139,7 @@ class Project:
return AerichCommand(
tortoise_config=tortoise_cfg,
app=flat_label,
location=str(self.migrations_dir)
location=str(self.migrations_dir),
)
# --- ORM lifecycle ---
@ -144,7 +154,9 @@ class Project:
await Tortoise.close_connections()
# --- Migration helpers ---
async def makemigrations(self, app_label: str, name: str = "auto", db_url: Optional[str] = None) -> None:
async def makemigrations(
self, app_label: str, name: str = "auto", db_url: Optional[str] = None
) -> None:
cmd = self.build_aerich_command(app_label, db_url=db_url)
async with cmd as c:
await c.init()
@ -158,7 +170,9 @@ class Project:
await c.init_db(safe=True)
await c.migrate(name=name)
async def migrate(self, app_label: Optional[str] = None, db_url: Optional[str] = None) -> None:
async def migrate(
self, app_label: Optional[str] = None, db_url: Optional[str] = None
) -> None:
labels: List[str]
if app_label:
if app_label in self._apps:
@ -231,8 +245,11 @@ class App:
"name": route.name,
"methods": list(route.methods),
"endpoint": route.endpoint.__name__, # just the function name
"response_model": getattr(route, "response_model", None).__name__
if getattr(route, "response_model", None) else None,
"response_model": (
getattr(route, "response_model", None).__name__
if getattr(route, "response_model", None)
else None
),
"tags": getattr(route, "tags", None),
}
@ -241,8 +258,8 @@ class App:
def dict(self) -> Dict[str, Any]:
return {
'models': [m.__name__ for m in self.models],
'routes': self._serialize_router(),
"models": [m.__name__ for m in self.models],
"routes": self._serialize_router(),
}
@property
@ -250,10 +267,13 @@ class App:
for mod in self.model_modules:
models_mod = importlib.import_module(mod)
for obj in models_mod.__dict__.values():
if isinstance(obj, type) and getattr(obj, "_meta", None) is not None and obj.__name__ != 'Model':
if (
isinstance(obj, type)
and getattr(obj, "_meta", None) is not None
and obj.__name__ != "Model"
):
yield obj
@property
def routes(self):
return self.router.routes

View file

@ -1,4 +1,5 @@
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
# Base templates directory
@ -8,14 +9,21 @@ env = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)))
def render_template_file(template_path: Path, context: dict, output_path: Path):
"""Render a single Jinja2 template file to disk."""
template = env.get_template(str(template_path.relative_to(TEMPLATE_DIR)).replace("\\", "/"))
template = env.get_template(
str(template_path.relative_to(TEMPLATE_DIR)).replace("\\", "/")
)
content = template.render(**context)
output_path.parent.mkdir(exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
f.write(content)
def render_template_dir(template_subdir: str, target_dir: Path, context: dict, subdir_name: str | None = None):
def render_template_dir(
template_subdir: str,
target_dir: Path,
context: dict,
subdir_name: str | None = None,
):
"""
Recursively render all *.j2 templates from TEMPLATE_DIR/template_subdir into target_dir.
If subdir_name is given, files are placed inside target_dir/subdir_name.
@ -23,14 +31,18 @@ def render_template_dir(template_subdir: str, target_dir: Path, context: dict, s
template_dir = TEMPLATE_DIR / template_subdir
for root, _, files in template_dir.walk():
root_path = Path(root)
rel_root = root_path.relative_to(template_dir) # path relative to template_subdir
rel_root = root_path.relative_to(
template_dir
) # path relative to template_subdir
for f in files:
if not f.endswith(".j2"):
continue
template_rel_path = rel_root / f
output_rel_path = Path(*template_rel_path.parts).with_suffix("") # remove .j2
output_rel_path = Path(*template_rel_path.parts).with_suffix(
""
) # remove .j2
# optionally wrap in subdir_name
if subdir_name:
@ -54,7 +66,11 @@ def startapp(name: str, project: str):
"""Create a new app inside a project: templates go into <project_dir>/<name>/"""
target_dir = Path(project)
target_dir.mkdir(exist_ok=True)
render_template_dir("app", target_dir, {"project_name": target_dir.resolve().name, "app_name": name}, subdir_name=name)
render_template_dir(
"app",
target_dir,
{"project_name": target_dir.resolve().name, "app_name": name},
subdir_name=name,
)
print(f"✅ App '{name}' created in project '{target_dir}' successfully.")
print(f"🔧 Remember to add '{name}' to your INSTALLED_APPS!")

View file

@ -1,10 +1,10 @@
from .model import Model, field
from tortoise.manager import Manager
from tortoise.queryset import QuerySet
from tortoise.signals import (
pre_delete,
post_delete,
pre_save,
post_save,
pre_delete,
pre_save,
)
from .model import Model, field

View file

@ -1,2 +1 @@
from tortoise.exceptions import *

View file

@ -1,22 +1,27 @@
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler
from tortoise import fields as field
from tortoise.models import Model as TortoiseModel
from tortoise.contrib.pydantic import pydantic_model_creator, pydantic_queryset_creator
from uuid import UUID
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from tortoise import fields as field
from tortoise.contrib.pydantic import pydantic_model_creator, pydantic_queryset_creator
from tortoise.models import Model as TortoiseModel
def __uuid_schema_monkey_patch(cls, source_type, handler):
# Always treat UUID as string schema
return core_schema.no_info_after_validator_function(
# Accept UUID or str, always return UUID internally
lambda v: v if isinstance(v, UUID) else UUID(str(v)),
core_schema.union_schema([
core_schema.str_schema(),
core_schema.is_instance_schema(UUID),
]),
core_schema.union_schema(
[
core_schema.str_schema(),
core_schema.is_instance_schema(UUID),
]
),
# But when serializing, always str()
serialization=core_schema.plain_serializer_function_ser_schema(str, when_used="always"),
serialization=core_schema.plain_serializer_function_ser_schema(
str, when_used="always"
),
)
@ -64,4 +69,3 @@ class Model(TortoiseModel, metaclass=ModelMeta):
class Schema:
include = None
exclude = None

View file

@ -1,2 +1 @@
from fastapi import APIRouter, Depends, HTTPException