♻️ Refactor ohmyapi_auth

- improved type-safety
- created and defined response_models
This commit is contained in:
Brian Wiborg 2025-09-28 19:23:22 +02:00
parent 64d6ca369f
commit 90f257ae38
No known key found for this signature in database
2 changed files with 69 additions and 43 deletions

View file

@ -26,9 +26,64 @@ REFRESH_TOKEN_EXPIRE_SECONDS = getattr(
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
def create_token(data: dict, expires_in: int) -> str: class ClaimsUser(BaseModel):
to_encode = data.copy() username: str
to_encode.update({"exp": int(time.time()) + expires_in}) email: str
is_admin: bool
is_staff: bool
class Claims(BaseModel):
type: str
sub: str
user: ClaimsUser
roles: List[str]
exp: str
class AccessToken(BaseModel):
token_type: str
access_token: str
class RefreshToken(AccessToken):
refresh_token: str
class LoginRequest(BaseModel):
username: str
password: str
class TokenType(str, Enum):
"""
Helper for indicating the token type when generating claims.
"""
access = "access"
refresh = "refresh"
def claims(
token_type: TokenType, user: User, groups: List[Group] = []
) -> Claims:
return Claims(
type=token_type,
sub=str(user.id),
user=ClaimsUser(
username=user.username,
email=user.email,
is_admin=user.is_admin,
is_staff=user.is_staff,
),
roles=[g.name for g in groups],
exp="",
)
def create_token(claims: Claims, expires_in: int) -> str:
to_encode = claims.model_dump()
to_encode['exp'] = int(time.time()) + expires_in
token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
if isinstance(token, bytes): if isinstance(token, bytes):
token = token.decode("utf-8") token = token.decode("utf-8")
@ -48,29 +103,6 @@ def decode_token(token: str) -> Dict:
) )
class TokenType(str, Enum):
"""
Helper for indicating the token type when generating claims.
"""
access = "access"
refresh = "refresh"
def claims(
token_type: TokenType, user: User, groups: List[Group] = []
) -> Dict[str, Any]:
return {
"type": token_type,
"sub": str(user.id),
"user": {
"username": user.username,
"email": user.email,
},
"roles": [g.name for g in groups],
}
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict: async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
"""Dependency: token introspection""" """Dependency: token introspection"""
payload = decode_token(token) payload = decode_token(token)
@ -127,12 +159,7 @@ async def require_group(
return current_user return current_user
class LoginRequest(BaseModel): @router.post("/login", response_model=RefreshToken)
username: str
password: str
@router.post("/login")
async def login(form_data: LoginRequest = Body(...)): async def login(form_data: LoginRequest = Body(...)):
"""Login with username & password, returns access and refresh tokens.""" """Login with username & password, returns access and refresh tokens."""
user = await User.authenticate(form_data.username, form_data.password) user = await User.authenticate(form_data.username, form_data.password)
@ -148,14 +175,14 @@ async def login(form_data: LoginRequest = Body(...)):
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
) )
return { return RefreshToken(
"access_token": access_token, token_type="bearer",
"refresh_token": refresh_token, access_token=access_token,
"token_type": "bearer", refresh_token=refresh_token,
} )
@router.post("/refresh") @router.post("/refresh", response_model=AccessToken)
async def refresh_token(refresh_token: str): async def refresh_token(refresh_token: str):
"""Exchange refresh token for new access token.""" """Exchange refresh token for new access token."""
payload = decode_token(refresh_token) payload = decode_token(refresh_token)
@ -174,15 +201,15 @@ async def refresh_token(refresh_token: str):
new_access = create_token( new_access = create_token(
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
) )
return {"access_token": new_access, "token_type": "bearer"} return AccessToken(token_type="bearer", access_token=access_token)
@router.get("/introspect") @router.get("/introspect", response_model=Dict[str, Any])
async def introspect(token: Dict = Depends(get_token)): async def introspect(token: Dict = Depends(get_token)):
return token return token
@router.get("/me") @router.get("/me", response_model=User.Schema.model)
async def me(user: User = Depends(get_current_user)): async def me(user: User = Depends(get_current_user)):
"""Return the currently authenticated user.""" """Return the currently authenticated user."""
return User.Schema.one.from_orm(user) return await User.Schema.model.from_tortoise_orm(user)

View file

@ -1,3 +1,2 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from http import HTTPStatus from http import HTTPStatus