diff --git a/app/cabletv/utils/auth.py b/app/cabletv/utils/auth.py index 3fa4306..bf7826f 100644 --- a/app/cabletv/utils/auth.py +++ b/app/cabletv/utils/auth.py @@ -1,52 +1,52 @@ -import os -import boto3 -import requests -import jwt from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2AuthorizationCodeBearer -from fastapi.responses import RedirectResponse +from fastapi.security.utils import get_authorization_scheme_param +import os +import jwt REGION = "us-east-2" USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID") CLIENT_ID = os.getenv("COGNITO_CLIENT_ID") DOMAIN = f"https://iptv-updater.auth.{REGION}.amazoncognito.com" -# Remove the hardcoded REDIRECT_URI, we'll make it dynamic based on the request -class DynamicOAuth2(OAuth2AuthorizationCodeBearer): - async def __call__(self, request: Request): - self.redirect_uri = str(request.base_url) + "auth/callback" - return await super().__call__(request) +class CustomOAuth2(OAuth2AuthorizationCodeBearer): + async def __call__(self, request: Request) -> str: + # Check if this is a browser request + is_browser = "text/html" in request.headers.get("accept", "") + + authorization = request.cookies.get("token") + if not authorization: + authorization = request.headers.get("Authorization") + + scheme, param = get_authorization_scheme_param(authorization) + + if not authorization or scheme.lower() != "bearer": + if is_browser: + redirect_uri = str(request.base_url)[:-1] + "/auth/callback" # Remove trailing slash + raise HTTPException( + status_code=302, + headers={ + "Location": f"{DOMAIN}/login?client_id={CLIENT_ID}" + f"&response_type=code" + f"&scope=openid+email+profile" + f"&redirect_uri={redirect_uri}" + } + ) + else: + raise HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return param -oauth2_scheme = DynamicOAuth2( +oauth2_scheme = CustomOAuth2( authorizationUrl=f"{DOMAIN}/oauth2/authorize", - tokenUrl=f"{DOMAIN}/oauth2/token" + tokenUrl=f"{DOMAIN}/oauth2/token", ) -def exchange_code_for_token(code: str, redirect_uri: str): - token_url = f"{DOMAIN}/oauth2/token" - data = { - 'grant_type': 'authorization_code', - 'client_id': CLIENT_ID, - 'code': code, - 'redirect_uri': redirect_uri - } - - response = requests.post(token_url, data=data) - if response.status_code == 200: - return response.json() - print(f"Token exchange failed: {response.text}") - raise HTTPException(status_code=400, detail="Failed to exchange code for token") - async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)): - if not token: - redirect_uri = str(request.base_url) + "auth/callback" - return RedirectResponse( - f"{DOMAIN}/login?client_id={CLIENT_ID}" - f"&response_type=code" - f"&scope=openid+email+profile" - f"&redirect_uri={redirect_uri}" - ) - try: decoded = jwt.decode( token, diff --git a/app/main.py b/app/main.py index 84dcde9..ccce3bf 100644 --- a/app/main.py +++ b/app/main.py @@ -1,17 +1,25 @@ -from fastapi import FastAPI, Depends, HTTPException, Request +from fastapi import FastAPI, Depends, HTTPException, Request, Response from fastapi.responses import RedirectResponse, JSONResponse from app.cabletv.utils.auth import get_current_user, exchange_code_for_token +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.sessions import SessionMiddleware app = FastAPI() -@app.get("/") -async def root(): - return {"message": "IPTV Updater API"} +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Add session middleware +app.add_middleware(SessionMiddleware, secret_key="your-secret-key") @app.get("/protected") async def protected_route(request: Request, user = Depends(get_current_user)): - if isinstance(user, RedirectResponse): - return user return {"message": "Protected content", "user": user['Username']} @app.get("/auth/callback") @@ -20,11 +28,17 @@ async def auth_callback(request: Request, code: str): redirect_uri = str(request.base_url) tokens = exchange_code_for_token(code, redirect_uri) - response = JSONResponse(content={ - "message": "Authentication successful", - "id_token": tokens["id_token"] - }) + # For browser requests, redirect to protected page + is_browser = "text/html" in request.headers.get("accept", "") + if is_browser: + response = RedirectResponse(url="/protected") + else: + response = JSONResponse(content={ + "message": "Authentication successful", + "id_token": tokens["id_token"] + }) + # Set the token cookie response.set_cookie( key="token", value=tokens["id_token"],