♻️ Refactor ohmyapi_auth
- improved type-safety - created and defined response_models
This commit is contained in:
parent
64d6ca369f
commit
90f257ae38
2 changed files with 69 additions and 43 deletions
|
|
@ -26,9 +26,64 @@ REFRESH_TOKEN_EXPIRE_SECONDS = getattr(
|
|||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
||||
|
||||
|
||||
def create_token(data: dict, expires_in: int) -> str:
|
||||
to_encode = data.copy()
|
||||
to_encode.update({"exp": int(time.time()) + expires_in})
|
||||
class ClaimsUser(BaseModel):
|
||||
username: str
|
||||
email: str
|
||||
is_admin: bool
|
||||
is_staff: bool
|
||||
|
||||
|
||||
class Claims(BaseModel):
|
||||
type: str
|
||||
sub: str
|
||||
user: ClaimsUser
|
||||
roles: List[str]
|
||||
exp: str
|
||||
|
||||
|
||||
class AccessToken(BaseModel):
|
||||
token_type: str
|
||||
access_token: str
|
||||
|
||||
|
||||
class RefreshToken(AccessToken):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
"""
|
||||
Helper for indicating the token type when generating claims.
|
||||
"""
|
||||
|
||||
access = "access"
|
||||
refresh = "refresh"
|
||||
|
||||
|
||||
def claims(
|
||||
token_type: TokenType, user: User, groups: List[Group] = []
|
||||
) -> Claims:
|
||||
return Claims(
|
||||
type=token_type,
|
||||
sub=str(user.id),
|
||||
user=ClaimsUser(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
is_admin=user.is_admin,
|
||||
is_staff=user.is_staff,
|
||||
),
|
||||
roles=[g.name for g in groups],
|
||||
exp="",
|
||||
)
|
||||
|
||||
|
||||
def create_token(claims: Claims, expires_in: int) -> str:
|
||||
to_encode = claims.model_dump()
|
||||
to_encode['exp'] = int(time.time()) + expires_in
|
||||
token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
if isinstance(token, bytes):
|
||||
token = token.decode("utf-8")
|
||||
|
|
@ -48,29 +103,6 @@ def decode_token(token: str) -> Dict:
|
|||
)
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
"""
|
||||
Helper for indicating the token type when generating claims.
|
||||
"""
|
||||
|
||||
access = "access"
|
||||
refresh = "refresh"
|
||||
|
||||
|
||||
def claims(
|
||||
token_type: TokenType, user: User, groups: List[Group] = []
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": token_type,
|
||||
"sub": str(user.id),
|
||||
"user": {
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
},
|
||||
"roles": [g.name for g in groups],
|
||||
}
|
||||
|
||||
|
||||
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
|
||||
"""Dependency: token introspection"""
|
||||
payload = decode_token(token)
|
||||
|
|
@ -127,12 +159,7 @@ async def require_group(
|
|||
return current_user
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
@router.post("/login", response_model=RefreshToken)
|
||||
async def login(form_data: LoginRequest = Body(...)):
|
||||
"""Login with username & password, returns access and refresh tokens."""
|
||||
user = await User.authenticate(form_data.username, form_data.password)
|
||||
|
|
@ -148,14 +175,14 @@ async def login(form_data: LoginRequest = Body(...)):
|
|||
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
return RefreshToken(
|
||||
token_type="bearer",
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
@router.post("/refresh", response_model=AccessToken)
|
||||
async def refresh_token(refresh_token: str):
|
||||
"""Exchange refresh token for new access token."""
|
||||
payload = decode_token(refresh_token)
|
||||
|
|
@ -174,15 +201,15 @@ async def refresh_token(refresh_token: str):
|
|||
new_access = create_token(
|
||||
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
|
||||
)
|
||||
return {"access_token": new_access, "token_type": "bearer"}
|
||||
return AccessToken(token_type="bearer", access_token=access_token)
|
||||
|
||||
|
||||
@router.get("/introspect")
|
||||
@router.get("/introspect", response_model=Dict[str, Any])
|
||||
async def introspect(token: Dict = Depends(get_token)):
|
||||
return token
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
@router.get("/me", response_model=User.Schema.model)
|
||||
async def me(user: User = Depends(get_current_user)):
|
||||
"""Return the currently authenticated user."""
|
||||
return User.Schema.one.from_orm(user)
|
||||
return await User.Schema.model.from_tortoise_orm(user)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from http import HTTPStatus
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue