From 30ccf86c866a31c6e4db55d086ce2da03bfd7cd1 Mon Sep 17 00:00:00 2001 From: Stefano Date: Thu, 15 May 2025 16:52:54 -0500 Subject: [PATCH] Added cognito authentication - Fix 8 --- app/cabletv/utils/auth.py | 68 +++++++++++++++++++++++++++------------ app/main.py | 33 +++++-------------- 2 files changed, 56 insertions(+), 45 deletions(-) diff --git a/app/cabletv/utils/auth.py b/app/cabletv/utils/auth.py index bf7826f..ae1024e 100644 --- a/app/cabletv/utils/auth.py +++ b/app/cabletv/utils/auth.py @@ -1,19 +1,23 @@ +import os +import boto3 +import requests +import jwt from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2AuthorizationCodeBearer from fastapi.security.utils import get_authorization_scheme_param -import os -import jwt +from fastapi.responses import RedirectResponse REGION = "us-east-2" USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID") CLIENT_ID = os.getenv("COGNITO_CLIENT_ID") DOMAIN = f"https://iptv-updater.auth.{REGION}.amazoncognito.com" -class CustomOAuth2(OAuth2AuthorizationCodeBearer): +class BrowserAwareOAuth2(OAuth2AuthorizationCodeBearer): async def __call__(self, request: Request) -> str: # Check if this is a browser request is_browser = "text/html" in request.headers.get("accept", "") + # Try to get token from cookie first, then header authorization = request.cookies.get("token") if not authorization: authorization = request.headers.get("Authorization") @@ -22,31 +26,55 @@ class CustomOAuth2(OAuth2AuthorizationCodeBearer): if not authorization or scheme.lower() != "bearer": if is_browser: - redirect_uri = str(request.base_url)[:-1] + "/auth/callback" # Remove trailing slash - raise HTTPException( - status_code=302, - headers={ - "Location": f"{DOMAIN}/login?client_id={CLIENT_ID}" - f"&response_type=code" - f"&scope=openid+email+profile" - f"&redirect_uri={redirect_uri}" - } - ) - else: - raise HTTPException( - status_code=401, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, + redirect_uri = str(request.base_url) + "auth/callback" + # Return redirect for browser requests + return RedirectResponse( + f"{DOMAIN}/login?client_id={CLIENT_ID}" + f"&response_type=code" + f"&scope=openid+email+profile" + f"&redirect_uri={redirect_uri}", + status_code=302 ) + # Return 401 for API requests + raise HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) return param -oauth2_scheme = CustomOAuth2( +# Update the oauth2_scheme to use our custom class +oauth2_scheme = BrowserAwareOAuth2( authorizationUrl=f"{DOMAIN}/oauth2/authorize", - tokenUrl=f"{DOMAIN}/oauth2/token", + tokenUrl=f"{DOMAIN}/oauth2/token" ) +def exchange_code_for_token(code: str, redirect_uri: str): + token_url = f"{DOMAIN}/oauth2/token" + data = { + 'grant_type': 'authorization_code', + 'client_id': CLIENT_ID, + 'code': code, + 'redirect_uri': redirect_uri + } + + response = requests.post(token_url, data=data) + if response.status_code == 200: + return response.json() + print(f"Token exchange failed: {response.text}") + raise HTTPException(status_code=400, detail="Failed to exchange code for token") + async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)): + if not token: + redirect_uri = str(request.base_url) + "auth/callback" + return RedirectResponse( + f"{DOMAIN}/login?client_id={CLIENT_ID}" + f"&response_type=code" + f"&scope=openid+email+profile" + f"&redirect_uri={redirect_uri}" + ) + try: decoded = jwt.decode( token, diff --git a/app/main.py b/app/main.py index ccce3bf..2951461 100644 --- a/app/main.py +++ b/app/main.py @@ -1,22 +1,12 @@ -from fastapi import FastAPI, Depends, HTTPException, Request, Response -from fastapi.responses import RedirectResponse, JSONResponse +from fastapi import FastAPI, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse from app.cabletv.utils.auth import get_current_user, exchange_code_for_token -from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.sessions import SessionMiddleware app = FastAPI() -# Add CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Add session middleware -app.add_middleware(SessionMiddleware, secret_key="your-secret-key") +@app.get("/") +async def root(): + return {"message": "IPTV Updater API"} @app.get("/protected") async def protected_route(request: Request, user = Depends(get_current_user)): @@ -28,17 +18,10 @@ async def auth_callback(request: Request, code: str): redirect_uri = str(request.base_url) tokens = exchange_code_for_token(code, redirect_uri) - # For browser requests, redirect to protected page - is_browser = "text/html" in request.headers.get("accept", "") - if is_browser: - response = RedirectResponse(url="/protected") - else: - response = JSONResponse(content={ - "message": "Authentication successful", - "id_token": tokens["id_token"] - }) + # Create redirect response to protected route + response = RedirectResponse(url="/protected", status_code=302) - # Set the token cookie + # Set token cookie response.set_cookie( key="token", value=tokens["id_token"],