💄 Introduce black & isort

This commit is contained in:
Brian Wiborg 2025-09-28 15:41:01 +02:00
parent 9becfc857d
commit 6a90e4a44a
No known key found for this signature in database
14 changed files with 311 additions and 89 deletions

136
poetry.lock generated
View file

@ -163,6 +163,52 @@ files = [
[package.dependencies] [package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""} colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "black"
version = "25.9.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "black-25.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ce41ed2614b706fd55fd0b4a6909d06b5bab344ffbfadc6ef34ae50adba3d4f7"},
{file = "black-25.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ab0ce111ef026790e9b13bd216fa7bc48edd934ffc4cbf78808b235793cbc92"},
{file = "black-25.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f96b6726d690c96c60ba682955199f8c39abc1ae0c3a494a9c62c0184049a713"},
{file = "black-25.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:d119957b37cc641596063cd7db2656c5be3752ac17877017b2ffcdb9dfc4d2b1"},
{file = "black-25.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:456386fe87bad41b806d53c062e2974615825c7a52159cde7ccaeb0695fa28fa"},
{file = "black-25.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a16b14a44c1af60a210d8da28e108e13e75a284bf21a9afa6b4571f96ab8bb9d"},
{file = "black-25.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aaf319612536d502fdd0e88ce52d8f1352b2c0a955cc2798f79eeca9d3af0608"},
{file = "black-25.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:c0372a93e16b3954208417bfe448e09b0de5cc721d521866cd9e0acac3c04a1f"},
{file = "black-25.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1b9dc70c21ef8b43248f1d86aedd2aaf75ae110b958a7909ad8463c4aa0880b0"},
{file = "black-25.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8e46eecf65a095fa62e53245ae2795c90bdecabd53b50c448d0a8bcd0d2e74c4"},
{file = "black-25.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9101ee58ddc2442199a25cb648d46ba22cd580b00ca4b44234a324e3ec7a0f7e"},
{file = "black-25.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:77e7060a00c5ec4b3367c55f39cf9b06e68965a4f2e61cecacd6d0d9b7ec945a"},
{file = "black-25.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0172a012f725b792c358d57fe7b6b6e8e67375dd157f64fa7a3097b3ed3e2175"},
{file = "black-25.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3bec74ee60f8dfef564b573a96b8930f7b6a538e846123d5ad77ba14a8d7a64f"},
{file = "black-25.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b756fc75871cb1bcac5499552d771822fd9db5a2bb8db2a7247936ca48f39831"},
{file = "black-25.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:846d58e3ce7879ec1ffe816bb9df6d006cd9590515ed5d17db14e17666b2b357"},
{file = "black-25.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef69351df3c84485a8beb6f7b8f9721e2009e20ef80a8d619e2d1788b7816d47"},
{file = "black-25.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e3c1f4cd5e93842774d9ee4ef6cd8d17790e65f44f7cdbaab5f2cf8ccf22a823"},
{file = "black-25.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:154b06d618233fe468236ba1f0e40823d4eb08b26f5e9261526fde34916b9140"},
{file = "black-25.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:e593466de7b998374ea2585a471ba90553283fb9beefcfa430d84a2651ed5933"},
{file = "black-25.9.0-py3-none-any.whl", hash = "sha256:474b34c1342cdc157d307b56c4c65bce916480c4a8f6551fdc6bf9b486a7c4ae"},
{file = "black-25.9.0.tar.gz", hash = "sha256:0474bca9a0dd1b51791fcc507a4e02078a1c63f6d4e4ae5544b9848c7adfb619"},
]
[package.dependencies]
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
pytokens = ">=0.1.10"
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.10)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2025.8.3" version = "2025.8.3"
@ -367,7 +413,7 @@ version = "8.3.0"
description = "Composable command line interface toolkit" description = "Composable command line interface toolkit"
optional = false optional = false
python-versions = ">=3.10" python-versions = ">=3.10"
groups = ["main"] groups = ["main", "dev"]
files = [ files = [
{file = "click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc"}, {file = "click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc"},
{file = "click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4"}, {file = "click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4"},
@ -383,11 +429,11 @@ description = "Cross-platform colored terminal text."
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["main", "dev"] groups = ["main", "dev"]
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
files = [ files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
] ]
markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\"", dev = "sys_platform == \"win32\""}
[[package]] [[package]]
name = "crypto" name = "crypto"
@ -559,6 +605,22 @@ files = [
{file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"}, {file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"},
] ]
[[package]]
name = "isort"
version = "6.0.1"
description = "A Python utility / library to sort Python imports."
optional = false
python-versions = ">=3.9.0"
groups = ["dev"]
files = [
{file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"},
{file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"},
]
[package.extras]
colors = ["colorama"]
plugins = ["setuptools"]
[[package]] [[package]]
name = "jedi" name = "jedi"
version = "0.19.2" version = "0.19.2"
@ -719,6 +781,18 @@ files = [
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
] ]
[[package]]
name = "mypy-extensions"
version = "1.1.0"
description = "Type system extensions for programs checked with the mypy type checker."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505"},
{file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"},
]
[[package]] [[package]]
name = "naked" name = "naked"
version = "0.1.32" version = "0.1.32"
@ -735,6 +809,18 @@ files = [
pyyaml = "*" pyyaml = "*"
requests = "*" requests = "*"
[[package]]
name = "packaging"
version = "25.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
]
[[package]] [[package]]
name = "parso" name = "parso"
version = "0.8.5" version = "0.8.5"
@ -769,6 +855,18 @@ bcrypt = ["bcrypt (>=3.1.0)"]
build-docs = ["cloud-sptheme (>=1.10.1)", "sphinx (>=1.6)", "sphinxcontrib-fulltoc (>=1.2.0)"] build-docs = ["cloud-sptheme (>=1.10.1)", "sphinx (>=1.6)", "sphinxcontrib-fulltoc (>=1.2.0)"]
totp = ["cryptography"] totp = ["cryptography"]
[[package]]
name = "pathspec"
version = "0.12.1"
description = "Utility library for gitignore style pattern matching of file paths."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"},
{file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"},
]
[[package]] [[package]]
name = "pexpect" name = "pexpect"
version = "4.9.0" version = "4.9.0"
@ -785,6 +883,23 @@ files = [
[package.dependencies] [package.dependencies]
ptyprocess = ">=0.5" ptyprocess = ">=0.5"
[[package]]
name = "platformdirs"
version = "4.4.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85"},
{file = "platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf"},
]
[package.extras]
docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.4)", "pytest-cov (>=6)", "pytest-mock (>=3.14)"]
type = ["mypy (>=1.14.1)"]
[[package]] [[package]]
name = "prompt-toolkit" name = "prompt-toolkit"
version = "3.0.52" version = "3.0.52"
@ -1033,6 +1148,21 @@ files = [
{file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"},
] ]
[[package]]
name = "pytokens"
version = "0.1.10"
description = "A Fast, spec compliant Python 3.12+ tokenizer that runs on older Pythons."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pytokens-0.1.10-py3-none-any.whl", hash = "sha256:db7b72284e480e69fb085d9f251f66b3d2df8b7166059261258ff35f50fb711b"},
{file = "pytokens-0.1.10.tar.gz", hash = "sha256:c9a4bfa0be1d26aebce03e6884ba454e842f186a59ea43a6d3b25af58223c044"},
]
[package.extras]
dev = ["black", "build", "mypy", "pytest", "pytest-cov", "setuptools", "tox", "twine", "wheel"]
[[package]] [[package]]
name = "pytz" name = "pytz"
version = "2025.2" version = "2025.2"
@ -1378,4 +1508,4 @@ auth = ["argon2-cffi", "crypto", "passlib", "pyjwt", "python-multipart"]
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = ">=3.13" python-versions = ">=3.13"
content-hash = "16ae1b48820c723ca784e71a454fb4b686c94fc9f01fa81b086df5dcaf512074" content-hash = "145508f708df01d84d998947a87b95cfc269e197eb8bc7467e9748a3b8e210e5"

View file

@ -27,6 +27,8 @@ dependencies = [
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
ipython = ">=9.5.0,<10.0.0" ipython = ">=9.5.0,<10.0.0"
black = "^25.9.0"
isort = "^6.0.1"
[project.optional-dependencies] [project.optional-dependencies]
auth = ["passlib", "pyjwt", "crypto", "argon2-cffi", "python-multipart"] auth = ["passlib", "pyjwt", "crypto", "argon2-cffi", "python-multipart"]
@ -36,3 +38,23 @@ packages = [ { include = "ohmyapi", from = "src" } ]
[project.scripts] [project.scripts]
ohmyapi = "ohmyapi.cli:app" ohmyapi = "ohmyapi.cli:app"
[tool.black]
line-length = 88
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.venv
| build
| dist
)/
'''
[tool.isort]
profile = "black" # makes imports compatible with black
line_length = 88
multi_line_output = 3
include_trailing_comma = true

View file

@ -1,2 +1 @@
from . import db from . import db

View file

@ -1,4 +1 @@
from . import models from . import models, permissions, routes
from . import routes
from . import permissions

View file

@ -1,11 +1,12 @@
from ohmyapi.router import HTTPException
from ohmyapi.db import Model, field, pre_save, pre_delete
from functools import wraps from functools import wraps
from typing import Optional, List from typing import List, Optional
from uuid import UUID
from passlib.context import CryptContext from passlib.context import CryptContext
from tortoise.contrib.pydantic import pydantic_queryset_creator from tortoise.contrib.pydantic import pydantic_queryset_creator
from uuid import UUID
from ohmyapi.db import Model, field, pre_delete, pre_save
from ohmyapi.router import HTTPException
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
@ -22,10 +23,12 @@ class User(Model):
password_hash: str = field.CharField(max_length=128) password_hash: str = field.CharField(max_length=128)
is_admin: bool = field.BooleanField(default=False) is_admin: bool = field.BooleanField(default=False)
is_staff: bool = field.BooleanField(default=False) is_staff: bool = field.BooleanField(default=False)
groups: field.ManyToManyRelation[Group] = field.ManyToManyField("ohmyapi_auth.Group", related_name="users", through='user_groups') groups: field.ManyToManyRelation[Group] = field.ManyToManyField(
"ohmyapi_auth.Group", related_name="users", through="user_groups"
)
class Schema: class Schema:
exclude = 'password_hash', exclude = ("password_hash",)
def set_password(self, raw_password: str) -> None: def set_password(self, raw_password: str) -> None:
"""Hash and store the password.""" """Hash and store the password."""

View file

@ -1,8 +1,8 @@
from .routes import ( from .routes import (
get_token,
get_current_user, get_current_user,
require_authenticated, get_token,
require_admin, require_admin,
require_staff, require_authenticated,
require_group, require_group,
require_staff,
) )

View file

@ -3,13 +3,12 @@ from enum import Enum
from typing import Any, Dict, List from typing import Any, Dict, List
import jwt import jwt
import settings
from fastapi import APIRouter, Body, Depends, Header, HTTPException, status from fastapi import APIRouter, Body, Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from ohmyapi.builtin.auth.models import User, Group from ohmyapi.builtin.auth.models import Group, User
import settings
# Router # Router
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
@ -17,8 +16,12 @@ router = APIRouter(prefix="/auth", tags=["auth"])
# Secrets & config (should come from settings/env in real projects) # Secrets & config (should come from settings/env in real projects)
JWT_SECRET = getattr(settings, "JWT_SECRET", "changeme") JWT_SECRET = getattr(settings, "JWT_SECRET", "changeme")
JWT_ALGORITHM = getattr(settings, "JWT_ALGORITHM", "HS256") JWT_ALGORITHM = getattr(settings, "JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_SECONDS = getattr(settings, "JWT_ACCESS_TOKEN_EXPIRE_SECONDS", 15 * 60) ACCESS_TOKEN_EXPIRE_SECONDS = getattr(
REFRESH_TOKEN_EXPIRE_SECONDS = getattr(settings, "JWT_REFRESH_TOKEN_EXPIRE_SECONDS", 7 * 24 * 60 * 60) settings, "JWT_ACCESS_TOKEN_EXPIRE_SECONDS", 15 * 60
)
REFRESH_TOKEN_EXPIRE_SECONDS = getattr(
settings, "JWT_REFRESH_TOKEN_EXPIRE_SECONDS", 7 * 24 * 60 * 60
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
@ -36,30 +39,38 @@ def decode_token(token: str) -> Dict:
try: try:
return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
)
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
class TokenType(str, Enum): class TokenType(str, Enum):
""" """
Helper for indicating the token type when generating claims. Helper for indicating the token type when generating claims.
""" """
access = "access" access = "access"
refresh = "refresh" refresh = "refresh"
def claims(token_type: TokenType, user: User, groups: List[Group] = []) -> Dict[str, Any]: def claims(
token_type: TokenType, user: User, groups: List[Group] = []
) -> Dict[str, Any]:
return { return {
'type': token_type, "type": token_type,
'sub': str(user.id), "sub": str(user.id),
'user': { "user": {
'username': user.username, "username": user.username,
'email': user.email, "email": user.email,
}, },
'roles': [g.name for g in groups] "roles": [g.name for g in groups],
} }
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict: async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
"""Dependency: token introspection""" """Dependency: token introspection"""
payload = decode_token(token) payload = decode_token(token)
@ -71,11 +82,15 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
payload = decode_token(token) payload = decode_token(token)
user_id = payload.get("sub") user_id = payload.get("sub")
if user_id is None: if user_id is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload"
)
user = await User.filter(id=user_id).first() user = await User.filter(id=user_id).first()
if not user: if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
)
return user return user
@ -101,15 +116,13 @@ async def require_staff(current_user: User = Depends(get_current_user)) -> User:
async def require_group( async def require_group(
group_name: str, group_name: str, current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
) -> User: ) -> User:
"""Ensure the current user belongs to the given group.""" """Ensure the current user belongs to the given group."""
user_groups = await current_user.groups.all() user_groups = await current_user.groups.all()
if not any(g.name == group_name for g in user_groups): if not any(g.name == group_name for g in user_groups):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403, detail=f"User must belong to group '{group_name}'"
detail=f"User must belong to group '{group_name}'"
) )
return current_user return current_user
@ -124,15 +137,21 @@ async def login(form_data: LoginRequest = Body(...)):
"""Login with username & password, returns access and refresh tokens.""" """Login with username & password, returns access and refresh tokens."""
user = await User.authenticate(form_data.username, form_data.password) user = await User.authenticate(form_data.username, form_data.password)
if not user: if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
)
access_token = create_token(claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS) access_token = create_token(
refresh_token = create_token(claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS) claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
)
refresh_token = create_token(
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
)
return { return {
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
"token_type": "bearer" "token_type": "bearer",
} }
@ -141,14 +160,20 @@ async def refresh_token(refresh_token: str):
"""Exchange refresh token for new access token.""" """Exchange refresh token for new access token."""
payload = decode_token(refresh_token) payload = decode_token(refresh_token)
if payload.get("type") != "refresh": if payload.get("type") != "refresh":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
user_id = payload.get("sub") user_id = payload.get("sub")
user = await User.filter(id=user_id).first() user = await User.filter(id=user_id).first()
if not user: if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
)
new_access = create_token(claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS) new_access = create_token(
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
)
return {"access_token": new_access, "token_type": "bearer"} return {"access_token": new_access, "token_type": "bearer"}
@ -161,4 +186,3 @@ async def introspect(token: Dict = Depends(get_token)):
async def me(user: User = Depends(get_current_user)): async def me(user: User = Depends(get_current_user)):
"""Return the currently authenticated user.""" """Return the currently authenticated user."""
return User.Schema.one.from_orm(user) return User.Schema.one.from_orm(user)

View file

@ -2,14 +2,17 @@ import asyncio
import atexit import atexit
import importlib import importlib
import sys import sys
from getpass import getpass
from pathlib import Path
import typer import typer
import uvicorn import uvicorn
from getpass import getpass from ohmyapi.core import runtime, scaffolding
from ohmyapi.core import scaffolding, runtime
from pathlib import Path
app = typer.Typer(help="OhMyAPI — Django-flavored FastAPI scaffolding with tightly integrated TortoiseORM.") app = typer.Typer(
help="OhMyAPI — Django-flavored FastAPI scaffolding with tightly integrated TortoiseORM."
)
@app.command() @app.command()
@ -78,6 +81,7 @@ def shell(root: str = "."):
start_ipython(argv=[], user_ns=shell_vars, config=c) start_ipython(argv=[], user_ns=shell_vars, config=c)
except ImportError: except ImportError:
import code import code
code.interact(local=shell_vars, banner=banner) code.interact(local=shell_vars, banner=banner)
finally: finally:
loop.run_until_complete(cleanup()) loop.run_until_complete(cleanup())
@ -120,11 +124,15 @@ def createsuperuser(root: str = "."):
project_path = Path(root).resolve() project_path = Path(root).resolve()
project = runtime.Project(project_path) project = runtime.Project(project_path)
if not project.is_app_installed("ohmyapi_auth"): if not project.is_app_installed("ohmyapi_auth"):
print("Auth app not installed! Please add 'ohmyapi_auth' to your INSTALLED_APPS.") print(
"Auth app not installed! Please add 'ohmyapi_auth' to your INSTALLED_APPS."
)
return return
import asyncio import asyncio
import ohmyapi_auth import ohmyapi_auth
email = input("E-Mail: ") email = input("E-Mail: ")
username = input("Username: ") username = input("Username: ")
password1, password2 = "foo", "bar" password1, password2 = "foo", "bar"
@ -133,9 +141,10 @@ def createsuperuser(root: str = "."):
password2 = getpass("Repeat Password: ") password2 = getpass("Repeat Password: ")
if password1 != password2: if password1 != password2:
print("Passwords didn't match!") print("Passwords didn't match!")
user = ohmyapi_auth.models.User(email=email, username=username, is_staff=True, is_admin=True) user = ohmyapi_auth.models.User(
email=email, username=username, is_staff=True, is_admin=True
)
user.set_password(password1) user.set_password(password1)
asyncio.run(project.init_orm()) asyncio.run(project.init_orm())
asyncio.run(user.save()) asyncio.run(user.save())
asyncio.run(project.close_orm()) asyncio.run(project.close_orm())

View file

@ -11,8 +11,9 @@ from typing import Any, Dict, Generator, List, Optional
import click import click
from aerich import Command as AerichCommand from aerich import Command as AerichCommand
from aerich.exceptions import NotInitedError from aerich.exceptions import NotInitedError
from fastapi import APIRouter, FastAPI
from tortoise import Tortoise from tortoise import Tortoise
from fastapi import FastAPI, APIRouter
from ohmyapi.db.model import Model from ohmyapi.db.model import Model
@ -44,7 +45,9 @@ class Project:
orig = importlib.import_module(full) orig = importlib.import_module(full)
sys.modules[alias] = orig sys.modules[alias] = orig
try: try:
sys.modules[f"{alias}.models"] = importlib.import_module(f"{full}.models") sys.modules[f"{alias}.models"] = importlib.import_module(
f"{full}.models"
)
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
@ -52,7 +55,9 @@ class Project:
try: try:
self.settings = importlib.import_module("settings") self.settings = importlib.import_module("settings")
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to import project settings from {self.project_path}") from e raise RuntimeError(
f"Failed to import project settings from {self.project_path}"
) from e
# Load installed apps # Load installed apps
for app_name in getattr(self.settings, "INSTALLED_APPS", []): for app_name in getattr(self.settings, "INSTALLED_APPS", []):
@ -104,11 +109,16 @@ class Project:
for app_name, app in self._apps.items(): for app_name, app in self._apps.items():
modules = list(dict.fromkeys(app.model_modules)) modules = list(dict.fromkeys(app.model_modules))
if modules: if modules:
config["apps"][app_name] = {"models": modules, "default_connection": "default"} config["apps"][app_name] = {
"models": modules,
"default_connection": "default",
}
return config return config
def build_aerich_command(self, app_label: str, db_url: Optional[str] = None) -> AerichCommand: def build_aerich_command(
self, app_label: str, db_url: Optional[str] = None
) -> AerichCommand:
# Resolve label to flat_label # Resolve label to flat_label
if app_label in self._apps: if app_label in self._apps:
flat_label = app_label flat_label = app_label
@ -129,7 +139,7 @@ class Project:
return AerichCommand( return AerichCommand(
tortoise_config=tortoise_cfg, tortoise_config=tortoise_cfg,
app=flat_label, app=flat_label,
location=str(self.migrations_dir) location=str(self.migrations_dir),
) )
# --- ORM lifecycle --- # --- ORM lifecycle ---
@ -144,7 +154,9 @@ class Project:
await Tortoise.close_connections() await Tortoise.close_connections()
# --- Migration helpers --- # --- Migration helpers ---
async def makemigrations(self, app_label: str, name: str = "auto", db_url: Optional[str] = None) -> None: async def makemigrations(
self, app_label: str, name: str = "auto", db_url: Optional[str] = None
) -> None:
cmd = self.build_aerich_command(app_label, db_url=db_url) cmd = self.build_aerich_command(app_label, db_url=db_url)
async with cmd as c: async with cmd as c:
await c.init() await c.init()
@ -158,7 +170,9 @@ class Project:
await c.init_db(safe=True) await c.init_db(safe=True)
await c.migrate(name=name) await c.migrate(name=name)
async def migrate(self, app_label: Optional[str] = None, db_url: Optional[str] = None) -> None: async def migrate(
self, app_label: Optional[str] = None, db_url: Optional[str] = None
) -> None:
labels: List[str] labels: List[str]
if app_label: if app_label:
if app_label in self._apps: if app_label in self._apps:
@ -231,8 +245,11 @@ class App:
"name": route.name, "name": route.name,
"methods": list(route.methods), "methods": list(route.methods),
"endpoint": route.endpoint.__name__, # just the function name "endpoint": route.endpoint.__name__, # just the function name
"response_model": getattr(route, "response_model", None).__name__ "response_model": (
if getattr(route, "response_model", None) else None, getattr(route, "response_model", None).__name__
if getattr(route, "response_model", None)
else None
),
"tags": getattr(route, "tags", None), "tags": getattr(route, "tags", None),
} }
@ -241,8 +258,8 @@ class App:
def dict(self) -> Dict[str, Any]: def dict(self) -> Dict[str, Any]:
return { return {
'models': [m.__name__ for m in self.models], "models": [m.__name__ for m in self.models],
'routes': self._serialize_router(), "routes": self._serialize_router(),
} }
@property @property
@ -250,10 +267,13 @@ class App:
for mod in self.model_modules: for mod in self.model_modules:
models_mod = importlib.import_module(mod) models_mod = importlib.import_module(mod)
for obj in models_mod.__dict__.values(): for obj in models_mod.__dict__.values():
if isinstance(obj, type) and getattr(obj, "_meta", None) is not None and obj.__name__ != 'Model': if (
isinstance(obj, type)
and getattr(obj, "_meta", None) is not None
and obj.__name__ != "Model"
):
yield obj yield obj
@property @property
def routes(self): def routes(self):
return self.router.routes return self.router.routes

View file

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
# Base templates directory # Base templates directory
@ -8,14 +9,21 @@ env = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)))
def render_template_file(template_path: Path, context: dict, output_path: Path): def render_template_file(template_path: Path, context: dict, output_path: Path):
"""Render a single Jinja2 template file to disk.""" """Render a single Jinja2 template file to disk."""
template = env.get_template(str(template_path.relative_to(TEMPLATE_DIR)).replace("\\", "/")) template = env.get_template(
str(template_path.relative_to(TEMPLATE_DIR)).replace("\\", "/")
)
content = template.render(**context) content = template.render(**context)
output_path.parent.mkdir(exist_ok=True) output_path.parent.mkdir(exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
def render_template_dir(template_subdir: str, target_dir: Path, context: dict, subdir_name: str | None = None): def render_template_dir(
template_subdir: str,
target_dir: Path,
context: dict,
subdir_name: str | None = None,
):
""" """
Recursively render all *.j2 templates from TEMPLATE_DIR/template_subdir into target_dir. Recursively render all *.j2 templates from TEMPLATE_DIR/template_subdir into target_dir.
If subdir_name is given, files are placed inside target_dir/subdir_name. If subdir_name is given, files are placed inside target_dir/subdir_name.
@ -23,14 +31,18 @@ def render_template_dir(template_subdir: str, target_dir: Path, context: dict, s
template_dir = TEMPLATE_DIR / template_subdir template_dir = TEMPLATE_DIR / template_subdir
for root, _, files in template_dir.walk(): for root, _, files in template_dir.walk():
root_path = Path(root) root_path = Path(root)
rel_root = root_path.relative_to(template_dir) # path relative to template_subdir rel_root = root_path.relative_to(
template_dir
) # path relative to template_subdir
for f in files: for f in files:
if not f.endswith(".j2"): if not f.endswith(".j2"):
continue continue
template_rel_path = rel_root / f template_rel_path = rel_root / f
output_rel_path = Path(*template_rel_path.parts).with_suffix("") # remove .j2 output_rel_path = Path(*template_rel_path.parts).with_suffix(
""
) # remove .j2
# optionally wrap in subdir_name # optionally wrap in subdir_name
if subdir_name: if subdir_name:
@ -54,7 +66,11 @@ def startapp(name: str, project: str):
"""Create a new app inside a project: templates go into <project_dir>/<name>/""" """Create a new app inside a project: templates go into <project_dir>/<name>/"""
target_dir = Path(project) target_dir = Path(project)
target_dir.mkdir(exist_ok=True) target_dir.mkdir(exist_ok=True)
render_template_dir("app", target_dir, {"project_name": target_dir.resolve().name, "app_name": name}, subdir_name=name) render_template_dir(
"app",
target_dir,
{"project_name": target_dir.resolve().name, "app_name": name},
subdir_name=name,
)
print(f"✅ App '{name}' created in project '{target_dir}' successfully.") print(f"✅ App '{name}' created in project '{target_dir}' successfully.")
print(f"🔧 Remember to add '{name}' to your INSTALLED_APPS!") print(f"🔧 Remember to add '{name}' to your INSTALLED_APPS!")

View file

@ -1,10 +1,10 @@
from .model import Model, field
from tortoise.manager import Manager from tortoise.manager import Manager
from tortoise.queryset import QuerySet from tortoise.queryset import QuerySet
from tortoise.signals import ( from tortoise.signals import (
pre_delete,
post_delete, post_delete,
pre_save,
post_save, post_save,
pre_delete,
pre_save,
) )
from .model import Model, field

View file

@ -1,2 +1 @@
from tortoise.exceptions import * from tortoise.exceptions import *

View file

@ -1,22 +1,27 @@
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler
from tortoise import fields as field
from tortoise.models import Model as TortoiseModel
from tortoise.contrib.pydantic import pydantic_model_creator, pydantic_queryset_creator
from uuid import UUID from uuid import UUID
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from tortoise import fields as field
from tortoise.contrib.pydantic import pydantic_model_creator, pydantic_queryset_creator
from tortoise.models import Model as TortoiseModel
def __uuid_schema_monkey_patch(cls, source_type, handler): def __uuid_schema_monkey_patch(cls, source_type, handler):
# Always treat UUID as string schema # Always treat UUID as string schema
return core_schema.no_info_after_validator_function( return core_schema.no_info_after_validator_function(
# Accept UUID or str, always return UUID internally # Accept UUID or str, always return UUID internally
lambda v: v if isinstance(v, UUID) else UUID(str(v)), lambda v: v if isinstance(v, UUID) else UUID(str(v)),
core_schema.union_schema([ core_schema.union_schema(
core_schema.str_schema(), [
core_schema.is_instance_schema(UUID), core_schema.str_schema(),
]), core_schema.is_instance_schema(UUID),
]
),
# But when serializing, always str() # But when serializing, always str()
serialization=core_schema.plain_serializer_function_ser_schema(str, when_used="always"), serialization=core_schema.plain_serializer_function_ser_schema(
str, when_used="always"
),
) )
@ -64,4 +69,3 @@ class Model(TortoiseModel, metaclass=ModelMeta):
class Schema: class Schema:
include = None include = None
exclude = None exclude = None

View file

@ -1,2 +1 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException