Linted and formatted all files
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
import boto3
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.auth import calculate_secret_hash
|
||||
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE)
|
||||
from app.utils.constants import (
|
||||
AWS_REGION,
|
||||
COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET,
|
||||
USER_ROLE_ATTRIBUTE,
|
||||
)
|
||||
|
||||
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
|
||||
|
||||
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
|
||||
"""
|
||||
auth_params = {
|
||||
"USERNAME": username,
|
||||
"PASSWORD": password
|
||||
}
|
||||
auth_params = {"USERNAME": username, "PASSWORD": password}
|
||||
|
||||
# If a client secret is required, add SECRET_HASH
|
||||
if COGNITO_CLIENT_SECRET:
|
||||
auth_params["SECRET_HASH"] = calculate_secret_hash(
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET)
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
|
||||
)
|
||||
|
||||
try:
|
||||
response = cognito_client.initiate_auth(
|
||||
AuthFlow="USER_PASSWORD_AUTH",
|
||||
AuthParameters=auth_params,
|
||||
ClientId=COGNITO_CLIENT_ID
|
||||
ClientId=COGNITO_CLIENT_ID,
|
||||
)
|
||||
return response["AuthenticationResult"]
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password"
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"An error occurred during authentication: {str(e)}"
|
||||
detail=f"An error occurred during authentication: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
"""
|
||||
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles.
|
||||
Verify the token by calling GetUser in Cognito and
|
||||
retrieve user attributes including roles.
|
||||
"""
|
||||
try:
|
||||
user_response = cognito_client.get_user(AccessToken=access_token)
|
||||
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
for attr in attributes:
|
||||
if attr["Name"] == USER_ROLE_ATTRIBUTE:
|
||||
# Assume roles are stored as a comma-separated string
|
||||
user_roles = [r.strip()
|
||||
for r in attr["Value"].split(",") if r.strip()]
|
||||
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
|
||||
break
|
||||
|
||||
return CognitoUser(username=username, roles=user_roles)
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token."
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or invalid token."
|
||||
detail="User not found or invalid token.",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token verification failed: {str(e)}"
|
||||
detail=f"Token verification failed: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
import os
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
else:
|
||||
from app.auth.cognito import get_user_from_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="signin",
|
||||
scheme_name="Bearer"
|
||||
)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
|
||||
"""
|
||||
@@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable:
|
||||
if not needed_roles.issubset(user_roles):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have the required roles to access this endpoint.",
|
||||
detail=(
|
||||
"You do not have the required roles to access this endpoint."
|
||||
),
|
||||
)
|
||||
return endpoint(*args, user=user, **kwargs)
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
MOCK_USERS = {
|
||||
"testuser": {
|
||||
"username": "testuser",
|
||||
"roles": ["admin"]
|
||||
}
|
||||
}
|
||||
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
|
||||
|
||||
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
"""
|
||||
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
return CognitoUser(**MOCK_USERS["testuser"])
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid mock token - use 'testuser'"
|
||||
detail="Invalid mock token - use 'testuser'",
|
||||
)
|
||||
|
||||
|
||||
def mock_initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Mock version of initiate_auth for local testing
|
||||
Accepts any username/password and returns a mock token
|
||||
"""
|
||||
return {
|
||||
"AccessToken": "testuser",
|
||||
"ExpiresIn": 3600,
|
||||
"TokenType": "Bearer"
|
||||
}
|
||||
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
|
||||
|
||||
Reference in New Issue
Block a user