"""IPTV Updater API.

FastAPI application exposing a Cognito-backed sign-in endpoint, two sample
protected endpoints, and CRUD endpoints for channels.
"""

from typing import List

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Security, status
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.auth.cognito import initiate_auth
from app.auth.dependencies import get_current_user, require_roles
from app.models import ChannelCreate, ChannelDB, ChannelResponse
from app.models.auth import CognitoUser, SigninRequest, TokenResponse
from app.utils.database import get_db

app = FastAPI(
    title="IPTV Updater API",
    description="API for IPTV Updater service",
    version="1.0.0",
)


def custom_openapi():
    """Build the OpenAPI schema once and return it.

    The schema produced by ``get_openapi`` is extended with an HTTP bearer
    (JWT) security scheme and a global security requirement, then cached on
    ``app.openapi_schema`` so later calls return the cached value.

    Returns:
        dict: The OpenAPI schema for this application.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Ensure the components object and its schemas mapping exist.
    components = openapi_schema.setdefault("components", {})
    components.setdefault("schemas", {})

    # Add security scheme component
    components["securitySchemes"] = {
        "Bearer": {
            "type": "http",
            "scheme": "bearer",
            "bearerFormat": "JWT",
        }
    }

    # Add global security requirement
    openapi_schema["security"] = [{"Bearer": []}]

    # Set OpenAPI version explicitly
    openapi_schema["openapi"] = "3.1.0"

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi


def _commit_or_conflict(db: Session, detail: str) -> None:
    """Commit the session, turning constraint violations into HTTP 409.

    Args:
        db: Session holding the pending changes.
        detail: Message sent to the client if the commit is rejected.

    Raises:
        HTTPException: 409 Conflict when the commit raises ``IntegrityError``.
    """
    try:
        db.commit()
    except IntegrityError as exc:
        # Roll back so the session is left clean, and report a client-visible
        # conflict instead of letting the error surface as an unhandled 500.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=detail,
        ) from exc


@app.get("/")
async def root():
    """Return a static message identifying the service."""
    return {"message": "IPTV Updater API"}


@app.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest):
    """
    Sign-in endpoint to authenticate the user with AWS Cognito using username and password.
    On success, returns JWT tokens (access_token, id_token, refresh_token).
    """
    auth_result = initiate_auth(credentials.username, credentials.password)
    return TokenResponse(
        access_token=auth_result["AccessToken"],
        id_token=auth_result["IdToken"],
        # The refresh token is read with .get(), so it may be absent.
        refresh_token=auth_result.get("RefreshToken"),
        token_type="Bearer",
    )


@app.get("/protected", summary="Protected endpoint for authenticated users")
async def protected_route(user: CognitoUser = Depends(get_current_user)):
    """
    Protected endpoint available to all authenticated users.
    If the user is authenticated, returns a success message.
    """
    return {"message": f"Hello {user.username}, you have access to support resources!"}


@app.get("/protected_admin", summary="Protected endpoint for Admin role")
@require_roles("admin")
def protected_admin_endpoint(user: CognitoUser = Depends(get_current_user)):
    """
    Protected endpoint that requires the 'admin' role.
    If the user has 'admin' role, returns success message.
    """
    return {"message": f"Hello {user.username}, you have admin privileges!"}


# Channel CRUD Endpoints


@app.post("/channels", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_channel(
    channel: ChannelCreate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Create a new channel (409 if it conflicts with existing data)"""
    db_channel = ChannelDB(**channel.model_dump())
    db.add(db_channel)
    _commit_or_conflict(db, "Channel conflicts with existing data")
    # Reload the row so database-populated values are included in the response.
    db.refresh(db_channel)
    return db_channel


@app.get("/channels/{tvg_id}", response_model=ChannelResponse)
def get_channel(
    tvg_id: str,
    db: Session = Depends(get_db),
):
    """Get a channel by tvg_id (404 if not found)"""
    channel = db.query(ChannelDB).filter(ChannelDB.tvg_id == tvg_id).first()
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Channel not found",
        )
    return channel


@app.put("/channels/{tvg_id}", response_model=ChannelResponse)
@require_roles("admin")
def update_channel(
    tvg_id: str,
    channel: ChannelCreate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Update a channel (404 if not found, 409 if it conflicts with existing data)"""
    db_channel = db.query(ChannelDB).filter(ChannelDB.tvg_id == tvg_id).first()
    if not db_channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Channel not found",
        )

    # Full replacement: every field of the payload overwrites the stored value.
    for key, value in channel.model_dump().items():
        setattr(db_channel, key, value)

    _commit_or_conflict(db, "Channel update conflicts with existing data")
    db.refresh(db_channel)
    return db_channel


@app.delete("/channels/{tvg_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel(
    tvg_id: str,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete a channel (404 if not found)"""
    channel = db.query(ChannelDB).filter(ChannelDB.tvg_id == tvg_id).first()
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Channel not found",
        )

    db.delete(channel)
    db.commit()
    return None


@app.get("/channels", response_model=List[ChannelResponse])
def list_channels(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List all channels with pagination, ordered by tvg_id"""
    # Reject negative values up front; some database backends refuse a
    # negative OFFSET/LIMIT, which would otherwise surface as a server error.
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="skip and limit must be non-negative",
        )

    # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so successive
    # pages could overlap or skip rows; order by tvg_id to keep pages stable.
    return (
        db.query(ChannelDB)
        .order_by(ChannelDB.tvg_id)
        .offset(skip)
        .limit(limit)
        .all()
    )