♻️ Refactor ohmyapi_auth
- improved type-safety - created and defined response_models
This commit is contained in:
parent
64d6ca369f
commit
90f257ae38
2 changed files with 69 additions and 43 deletions
|
|
@ -26,9 +26,64 @@ REFRESH_TOKEN_EXPIRE_SECONDS = getattr(
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
||||||
|
|
||||||
|
|
||||||
def create_token(data: dict, expires_in: int) -> str:
|
class ClaimsUser(BaseModel):
|
||||||
to_encode = data.copy()
|
username: str
|
||||||
to_encode.update({"exp": int(time.time()) + expires_in})
|
email: str
|
||||||
|
is_admin: bool
|
||||||
|
is_staff: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Claims(BaseModel):
|
||||||
|
type: str
|
||||||
|
sub: str
|
||||||
|
user: ClaimsUser
|
||||||
|
roles: List[str]
|
||||||
|
exp: str
|
||||||
|
|
||||||
|
|
||||||
|
class AccessToken(BaseModel):
|
||||||
|
token_type: str
|
||||||
|
access_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(AccessToken):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenType(str, Enum):
|
||||||
|
"""
|
||||||
|
Helper for indicating the token type when generating claims.
|
||||||
|
"""
|
||||||
|
|
||||||
|
access = "access"
|
||||||
|
refresh = "refresh"
|
||||||
|
|
||||||
|
|
||||||
|
def claims(
|
||||||
|
token_type: TokenType, user: User, groups: List[Group] = []
|
||||||
|
) -> Claims:
|
||||||
|
return Claims(
|
||||||
|
type=token_type,
|
||||||
|
sub=str(user.id),
|
||||||
|
user=ClaimsUser(
|
||||||
|
username=user.username,
|
||||||
|
email=user.email,
|
||||||
|
is_admin=user.is_admin,
|
||||||
|
is_staff=user.is_staff,
|
||||||
|
),
|
||||||
|
roles=[g.name for g in groups],
|
||||||
|
exp="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_token(claims: Claims, expires_in: int) -> str:
|
||||||
|
to_encode = claims.model_dump()
|
||||||
|
to_encode['exp'] = int(time.time()) + expires_in
|
||||||
token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
token = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||||
if isinstance(token, bytes):
|
if isinstance(token, bytes):
|
||||||
token = token.decode("utf-8")
|
token = token.decode("utf-8")
|
||||||
|
|
@ -48,29 +103,6 @@ def decode_token(token: str) -> Dict:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TokenType(str, Enum):
|
|
||||||
"""
|
|
||||||
Helper for indicating the token type when generating claims.
|
|
||||||
"""
|
|
||||||
|
|
||||||
access = "access"
|
|
||||||
refresh = "refresh"
|
|
||||||
|
|
||||||
|
|
||||||
def claims(
|
|
||||||
token_type: TokenType, user: User, groups: List[Group] = []
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": token_type,
|
|
||||||
"sub": str(user.id),
|
|
||||||
"user": {
|
|
||||||
"username": user.username,
|
|
||||||
"email": user.email,
|
|
||||||
},
|
|
||||||
"roles": [g.name for g in groups],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
|
async def get_token(token: str = Depends(oauth2_scheme)) -> Dict:
|
||||||
"""Dependency: token introspection"""
|
"""Dependency: token introspection"""
|
||||||
payload = decode_token(token)
|
payload = decode_token(token)
|
||||||
|
|
@ -127,12 +159,7 @@ async def require_group(
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
@router.post("/login", response_model=RefreshToken)
|
||||||
username: str
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login")
|
|
||||||
async def login(form_data: LoginRequest = Body(...)):
|
async def login(form_data: LoginRequest = Body(...)):
|
||||||
"""Login with username & password, returns access and refresh tokens."""
|
"""Login with username & password, returns access and refresh tokens."""
|
||||||
user = await User.authenticate(form_data.username, form_data.password)
|
user = await User.authenticate(form_data.username, form_data.password)
|
||||||
|
|
@ -148,14 +175,14 @@ async def login(form_data: LoginRequest = Body(...)):
|
||||||
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
|
claims(TokenType.refresh, user), REFRESH_TOKEN_EXPIRE_SECONDS
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return RefreshToken(
|
||||||
"access_token": access_token,
|
token_type="bearer",
|
||||||
"refresh_token": refresh_token,
|
access_token=access_token,
|
||||||
"token_type": "bearer",
|
refresh_token=refresh_token,
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh")
|
@router.post("/refresh", response_model=AccessToken)
|
||||||
async def refresh_token(refresh_token: str):
|
async def refresh_token(refresh_token: str):
|
||||||
"""Exchange refresh token for new access token."""
|
"""Exchange refresh token for new access token."""
|
||||||
payload = decode_token(refresh_token)
|
payload = decode_token(refresh_token)
|
||||||
|
|
@ -174,15 +201,15 @@ async def refresh_token(refresh_token: str):
|
||||||
new_access = create_token(
|
new_access = create_token(
|
||||||
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
|
claims(TokenType.access, user), ACCESS_TOKEN_EXPIRE_SECONDS
|
||||||
)
|
)
|
||||||
return {"access_token": new_access, "token_type": "bearer"}
|
return AccessToken(token_type="bearer", access_token=access_token)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/introspect")
|
@router.get("/introspect", response_model=Dict[str, Any])
|
||||||
async def introspect(token: Dict = Depends(get_token)):
|
async def introspect(token: Dict = Depends(get_token)):
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me")
|
@router.get("/me", response_model=User.Schema.model)
|
||||||
async def me(user: User = Depends(get_current_user)):
|
async def me(user: User = Depends(get_current_user)):
|
||||||
"""Return the currently authenticated user."""
|
"""Return the currently authenticated user."""
|
||||||
return User.Schema.one.from_orm(user)
|
return await User.Schema.model.from_tortoise_orm(user)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,2 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue