diff --git a/src/ohmyapi/builtin/auth/routes.py b/src/ohmyapi/builtin/auth/routes.py index cdb735e..aa09e44 100644 --- a/src/ohmyapi/builtin/auth/routes.py +++ b/src/ohmyapi/builtin/auth/routes.py @@ -26,9 +26,64 @@ REFRESH_TOKEN_EXPIRE_SECONDS = getattr( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") -def create_token(data: dict, expires_in: int) -> str: - to_encode = data.copy() - to_encode.update({"exp": int(time.time()) + expires_in}) +class ClaimsUser(BaseModel): + username: str + email: str + is_admin: bool + is_staff: bool + + +class Claims(BaseModel): + type: str + sub: str + user: ClaimsUser + roles: List[str] + exp: str + + +class AccessToken(BaseModel): + token_type: str + access_token: str + + +class RefreshToken(AccessToken): + refresh_token: str + + +class LoginRequest(BaseModel): + username: str + password: str + + +class TokenType(str, Enum): + """ + Helper for indicating the token type when generating claims. + """ + + access = "access" + refresh = "refresh" + + +def claims( + token_type: TokenType, user: User, groups: List[Group] = [] +) -> Claims: + return Claims( + type=token_type, + sub=str(user.id), + user=ClaimsUser( + username=user.username, + email=user.email, + is_admin=user.is_admin, + is_staff=user.is_staff, + ), + roles=[g.name for g in groups], + exp="", + ) + + +def create_token(claims: Claims, expires_in: int) -> str: + to_encode = claims.model_dump() + to_encode['exp'] = int(time.time()) + expires_in token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM) if isinstance(token, bytes): token = token.decode("utf-8") @@ -48,29 +103,6 @@ def decode_token(token: str) -> Dict: ) -class TokenType(str, Enum): - """ - Helper for indicating the token type when generating claims. - """ - - access = "access" - refresh = "refresh" - - -def claims( - token_type: TokenType, user: User, groups: List[Group] = [] -) -> Dict[str, Any]: - return { - "type": token_type, - "sub": str(user.id), - "user": { - "username": user.username, - "email": user.email, - }, - "roles": [g.name for g in groups], - } - - async def get_token(token: str = Depends(oauth2_scheme)) -> Dict: """Dependency: token introspection""" payload = decode_token(token) @@ -127,12 +159,7 @@ async def require_group( return current_user -class LoginRequest(BaseModel): - username: str - password: str - - -@router.post("/login") +@router.post("/login", response_model=RefreshToken) async def login(form_data: LoginRequest = Body(...)): """Login with username & password, returns access and refresh tokens.""" user = await User.authenticate(form_data.username, form_data.password) @@ -148,14 +175,14 @@ async def login(form_data: LoginRequest = Body(...)): claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS ) - return { - "access_token": access_token, - "refresh_token": refresh_token, - "token_type": "bearer", - } + return RefreshToken( + token_type="bearer", + access_token=access_token, + refresh_token=refresh_token, + ) -@router.post("/refresh") +@router.post("/refresh", response_model=AccessToken) async def refresh_token(refresh_token: str): """Exchange refresh token for new access token.""" payload = decode_token(refresh_token) @@ -174,15 +201,15 @@ async def refresh_token(refresh_token: str): new_access = create_token( claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS ) - return {"access_token": new_access, "token_type": "bearer"} + return AccessToken(token_type="bearer", access_token=access_token) -@router.get("/introspect") +@router.get("/introspect", response_model=Dict[str, Any]) async def introspect(token: Dict = Depends(get_token)): return token -@router.get("/me") +@router.get("/me", response_model=User.Schema.model) async def me(user: User = Depends(get_current_user)): """Return the currently authenticated user.""" - return User.Schema.one.from_orm(user) + return await User.Schema.model.from_tortoise_orm(user) diff --git a/src/ohmyapi/router.py b/src/ohmyapi/router.py index 3a0c442..ed96e12 100644 --- a/src/ohmyapi/router.py +++ b/src/ohmyapi/router.py @@ -1,3 +1,2 @@ from fastapi import APIRouter, Depends, HTTPException from http import HTTPStatus -