176 lines
5.4 KiB
Python
176 lines
5.4 KiB
Python
from typing import List

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Query, Security, status
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.auth.cognito import initiate_auth
from app.auth.dependencies import get_current_user, require_roles
from app.models import ChannelCreate, ChannelDB, ChannelResponse
from app.models.auth import CognitoUser, SigninRequest, TokenResponse
from app.utils.database import get_db
|
|
|
|
# The FastAPI application instance. Its title/version/description are read
# back by custom_openapi() when the OpenAPI document is generated.
app = FastAPI(
    version="1.0.0",
    title="IPTV Updater API",
    description="API for IPTV Updater service",
)
|
|
|
|
def custom_openapi():
    """Build (once) and return the app's OpenAPI document with JWT bearer auth.

    The document is generated from ``app.routes`` on the first call and cached
    on ``app.openapi_schema``; later calls return the cached dict unchanged.

    On top of FastAPI's generated schema this:
      * registers an HTTP ``bearer`` security scheme named ``"Bearer"``
        (bearer format ``JWT``), merging it with any schemes already present;
      * applies that scheme as a document-wide security requirement;
      * labels the document as OpenAPI ``3.1.0``.

    Returns:
        dict: the OpenAPI document.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version="3.1.0",
        description=app.description,
        routes=app.routes,
    )

    components = openapi_schema.setdefault("components", {})
    components.setdefault("schemas", {})

    # Merge instead of replacing the whole mapping: if a route dependency
    # (e.g. an OAuth2/HTTPBearer security class) already contributed a
    # scheme, overwriting "securitySchemes" would leave the per-operation
    # "security" entries pointing at a scheme that no longer exists.
    components.setdefault("securitySchemes", {})["Bearer"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
    }

    # NOTE(review): a global requirement also marks public operations such as
    # "/" and "/signin" as needing a token in the docs; consider overriding
    # those operations with "security": [] if that is not intended.
    openapi_schema["security"] = [{"Bearer": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


# Replace FastAPI's default schema generator with the customized one above.
app.openapi = custom_openapi
|
|
|
|
@app.get("/")
async def root():
    """Return a static message identifying the service."""
    payload = {"message": "IPTV Updater API"}
    return payload
|
|
|
|
@app.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest):
    """
    Authenticate a user against AWS Cognito with a username and password.

    On success, returns the JWT tokens: access_token, id_token and, when the
    auth result contains one, refresh_token.
    """
    result = initiate_auth(credentials.username, credentials.password)
    tokens = {
        "access_token": result["AccessToken"],
        "id_token": result["IdToken"],
        # Optional in the auth result, hence .get() rather than indexing.
        "refresh_token": result.get("RefreshToken"),
        "token_type": "Bearer",
    }
    return TokenResponse(**tokens)
|
|
|
|
@app.get("/protected", summary="Protected endpoint for authenticated users")
async def protected_route(user: CognitoUser = Depends(get_current_user)):
    """
    Protected endpoint available to all authenticated users.

    If the user is authenticated, returns a greeting that includes their
    username.
    """
    # `user` is resolved by the get_current_user dependency before this runs.
    return {"message": f"Hello {user.username}, you have access to support resources!"}
|
|
|
|
@app.get("/protected_admin", summary="Protected endpoint for Admin role")
@require_roles("admin")
def protected_admin_endpoint(user: CognitoUser = Depends(get_current_user)):
    """
    Endpoint restricted to callers holding the 'admin' role.

    When the role check passes, responds with a greeting confirming admin
    privileges.
    """
    greeting = f"Hello {user.username}, you have admin privileges!"
    return {"message": greeting}
|
|
|
|
# Channel CRUD Endpoints
|
|
@app.post("/channels", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_channel(
    channel: ChannelCreate,
    db: Session = Depends(get_db),
    # Resolved via get_current_user; presumably also inspected by
    # @require_roles — NOTE(review): confirm before removing.
    user: CognitoUser = Depends(get_current_user),
):
    """Create a new channel (requires the 'admin' role).

    Returns:
        The persisted ChannelDB row, refreshed after commit.

    Raises:
        HTTPException: 409 if the insert violates a database constraint
            (presumably a duplicate tvg_id, if that column is unique —
            NOTE(review): confirm against ChannelDB).
    """
    db_channel = ChannelDB(**channel.model_dump())
    db.add(db_channel)
    try:
        db.commit()
    except IntegrityError as exc:
        # A failed commit leaves the session unusable until it is rolled
        # back; report the conflict as a client error instead of a 500.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Channel already exists or violates a database constraint",
        ) from exc
    db.refresh(db_channel)
    return db_channel
|
|
|
|
@app.get("/channels/{tvg_id}", response_model=ChannelResponse)
def get_channel(
    tvg_id: str,
    db: Session = Depends(get_db),
):
    """Fetch the first channel whose tvg_id matches the path parameter.

    Raises:
        HTTPException: 404 when no channel with that tvg_id exists.
    """
    query = db.query(ChannelDB)
    found = query.filter(ChannelDB.tvg_id == tvg_id).first()
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Channel not found",
    )
|
|
|
|
@app.put("/channels/{tvg_id}", response_model=ChannelResponse)
@require_roles("admin")
def update_channel(
    tvg_id: str,
    channel: ChannelCreate,
    db: Session = Depends(get_db),
    # Resolved via get_current_user; presumably also inspected by
    # @require_roles — NOTE(review): confirm before removing.
    user: CognitoUser = Depends(get_current_user),
):
    """Replace every field of an existing channel (requires the 'admin' role).

    All fields from ``channel`` are copied onto the stored row, so this is a
    full replacement (PUT semantics), not a partial update.

    Returns:
        The updated ChannelDB row, refreshed after commit.

    Raises:
        HTTPException: 404 when no channel with ``tvg_id`` exists.
        HTTPException: 409 if the update violates a database constraint
            (presumably renaming tvg_id to one already in use —
            NOTE(review): confirm against ChannelDB).
    """
    db_channel = db.query(ChannelDB).filter(ChannelDB.tvg_id == tvg_id).first()
    if not db_channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Channel not found",
        )

    for key, value in channel.model_dump().items():
        setattr(db_channel, key, value)

    try:
        db.commit()
    except IntegrityError as exc:
        # Roll back so the session is usable again, and report the conflict
        # as a client error instead of an unhandled 500.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Channel already exists or violates a database constraint",
        ) from exc
    db.refresh(db_channel)
    return db_channel
|
|
|
|
@app.delete("/channels/{tvg_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel(
    tvg_id: str,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Remove the channel identified by tvg_id (requires the 'admin' role).

    Responds with 204 No Content on success.

    Raises:
        HTTPException: 404 when no channel with that tvg_id exists.
    """
    target = db.query(ChannelDB).filter(ChannelDB.tvg_id == tvg_id).first()
    if target:
        db.delete(target)
        db.commit()
        return None
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Channel not found",
    )
|
|
|
|
@app.get("/channels", response_model=List[ChannelResponse])
def list_channels(
    skip: int = Query(0, ge=0, description="Number of channels to skip"),
    # Upper bound caps the size of a single response.
    limit: int = Query(100, ge=0, le=1000, description="Maximum number of channels to return"),
    db: Session = Depends(get_db),
):
    """List channels with offset/limit pagination.

    Rows are ordered by ``tvg_id``. Without an ORDER BY, SQL gives no row-order
    guarantee, so successive OFFSET pages could overlap or skip channels.
    (Assumes tvg_id is unique; if not, add a tiebreaker column —
    NOTE(review): confirm against ChannelDB.)

    Args:
        skip: Number of rows to skip; must be >= 0 (422 otherwise).
        limit: Maximum rows to return; must be in [0, 1000] (422 otherwise).
        db: Database session from ``get_db``.

    Returns:
        A list of ChannelDB rows for the requested page.
    """
    return (
        db.query(ChannelDB)
        .order_by(ChannelDB.tvg_id)
        .offset(skip)
        .limit(limit)
        .all()
    )