♻️ Refactor ohmyapi_auth

- improved type-safety
- created and defined response_models
This commit is contained in:
Brian Wiborg 2025-09-28 19:23:22 +02:00
parent 64d6ca369f
commit 90f257ae38
No known key found for this signature in database
2 changed files with 69 additions and 43 deletions

View file

@ -26,9 +26,64 @@ REFRESH_TOKEN_EXPIRE_SECONDS = getattr(
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
def create_token(data: dict, expires_in: int) -> str:
to_encode = data.copy()
to_encode.update({"exp": int(time.time()) + expires_in})
class ClaimsUser(BaseModel):
username: str
email: str
is_admin: bool
is_staff: bool
class Claims(BaseModel):
type: str
sub: str
user: ClaimsUser
roles: List[str]
exp: str
class AccessToken(BaseModel):
token_type: str
access_token: str
class RefreshToken(AccessToken):
refresh_token: str
class LoginRequest(BaseModel):
username: str
password: str
class TokenType(str, Enum):
"""
Helper for indicating the token type when generating claims.
"""
access = "access"
refresh = "refresh"
def claims(
token_type: TokenType, user: User, groups: List[Group] = []
) -> Claims:
return Claims(
type=token_type,
sub=str(user.id),
user=ClaimsUser(
username=user.username,
email=user.email,
is_admin=user.is_admin,
is_staff=user.is_staff,
),
roles=[g.name for g in groups],
exp="",
)
def create_token(claims: Claims, expires_in: int) -> str:
to_encode = claims.model_dump()
to_encode['exp'] = int(time.time()) + expires_in
token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
if isinstance(token, bytes):
token = token.decode("utf-8")
@ -48,29 +103,6 @@ def decode_token(token: str) -> Dict:
)
class TokenType(str, Enum):
"""
Helper for indicating the token type when generating claims.
"""
access = "access"
refresh = "refresh"
def claims(
token_type: TokenType, user: User, groups: List[Group] = []
) -> Dict[str, Any]:
return {
"type": token_type,
"sub": str(user.id),
"user": {
"username": user.username,
"email": user.email,
},
"roles": [g.name for g in groups],
}
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
"""Dependency: token introspection"""
payload = decode_token(token)
@ -127,12 +159,7 @@ async def require_group(
return current_user
class LoginRequest(BaseModel):
username: str
password: str
@router.post("/login")
@router.post("/login", response_model=RefreshToken)
async def login(form_data: LoginRequest = Body(...)):
"""Login with username & password, returns access and refresh tokens."""
user = await User.authenticate(form_data.username, form_data.password)
@ -148,14 +175,14 @@ async def login(form_data: LoginRequest = Body(...)):
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
return RefreshToken(
token_type="bearer",
access_token=access_token,
refresh_token=refresh_token,
)
@router.post("/refresh")
@router.post("/refresh", response_model=AccessToken)
async def refresh_token(refresh_token: str):
"""Exchange refresh token for new access token."""
payload = decode_token(refresh_token)
@ -174,15 +201,15 @@ async def refresh_token(refresh_token: str):
new_access = create_token(
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
)
return {"access_token": new_access, "token_type": "bearer"}
return AccessToken(token_type="bearer", access_token=access_token)
@router.get("/introspect")
@router.get("/introspect", response_model=Dict[str, Any])
async def introspect(token: Dict = Depends(get_token)):
return token
@router.get("/me")
@router.get("/me", response_model=User.Schema.model)
async def me(user: User = Depends(get_current_user)):
"""Return the currently authenticated user."""
return User.Schema.one.from_orm(user)
return await User.Schema.model.from_tortoise_orm(user)

View file

@ -1,3 +1,2 @@
from fastapi import APIRouter, Depends, HTTPException
from http import HTTPStatus