From 02913c738597dd01fdfeb678d099166eeda382a7 Mon Sep 17 00:00:00 2001 From: Stefano Date: Wed, 28 May 2025 21:52:39 -0500 Subject: [PATCH] Linted and formatted all files --- .vscode/settings.json | 3 + alembic/env.py | 12 +- app.py | 16 +- app/auth/cognito.py | 39 +-- app/auth/dependencies.py | 12 +- app/auth/mock_auth.py | 18 +- app/iptv/createEpg.py | 106 ++++--- app/iptv/createPlaylist.py | 102 +++--- app/main.py | 19 +- app/models/__init__.py | 19 +- app/models/auth.py | 14 +- app/models/db.py | 54 +++- app/models/schemas.py | 35 +- app/routers/auth.py | 12 +- app/routers/channels.py | 179 ++++++----- app/routers/playlist.py | 12 +- app/routers/priorities.py | 53 ++-- app/utils/auth.py | 8 +- app/utils/check_streams.py | 85 ++--- app/utils/constants.py | 2 +- app/utils/database.py | 35 +- infrastructure/stack.py | 236 +++++++------- pyproject.toml | 13 +- tests/auth/test_cognito.py | 98 +++--- tests/auth/test_dependencies.py | 97 +++--- tests/auth/test_mock_auth.py | 13 +- tests/routers/test_auth.py | 70 ++-- tests/routers/test_channels.py | 544 +++++++++++++++++++++++--------- tests/test_main.py | 20 +- tests/utils/db_mocks.py | 62 +++- tests/utils/test_database.py | 42 +-- 31 files changed, 1264 insertions(+), 766 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 01047c9..9373c89 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -39,6 +39,7 @@ "dflogo", "dmlogo", "dotenv", + "EXTINF", "fastapi", "filterwarnings", "fiorinis", @@ -47,6 +48,7 @@ "gitea", "iptv", "isort", + "KHTML", "lclogo", "LETSENCRYPT", "nohup", @@ -76,6 +78,7 @@ "testpaths", "uflogo", "umlogo", + "usefixtures", "uvicorn", "venv", "wrongpass" diff --git a/alembic/env.py b/alembic/env.py index b329a4b..8426f76 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,12 +1,10 @@ -import os from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool +from sqlalchemy import engine_from_config, pool from alembic import context -from app.utils.database import get_db_credentials from app.models.db import Base +from app.utils.database import get_db_credentials # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -22,7 +20,7 @@ target_metadata = Base.metadata # Override sqlalchemy.url with dynamic credentials if not context.is_offline_mode(): - config.set_main_option('sqlalchemy.url', get_db_credentials()) + config.set_main_option("sqlalchemy.url", get_db_credentials()) # other values from the config, defined by the needs of env.py, # can be acquired: @@ -68,9 +66,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/app.py b/app.py index d967164..495aab6 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import os + import aws_cdk as cdk + from infrastructure.stack import IptvUpdaterStack app = cdk.App() @@ -19,21 +21,25 @@ required_vars = { "DOMAIN_NAME": domain_name, "SSH_PUBLIC_KEY": ssh_public_key, "REPO_URL": repo_url, - "LETSENCRYPT_EMAIL": letsencrypt_email + "LETSENCRYPT_EMAIL": letsencrypt_email, } # Check for missing required variables missing_vars = [k for k, v in required_vars.items() if not v] if missing_vars: - raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}") + raise ValueError( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) -IptvUpdaterStack(app, "IptvUpdaterStack", +IptvUpdaterStack( + app, + "IptvUpdaterStack", freedns_user=freedns_user, freedns_password=freedns_password, domain_name=domain_name, ssh_public_key=ssh_public_key, repo_url=repo_url, - letsencrypt_email=letsencrypt_email + letsencrypt_email=letsencrypt_email, ) -app.synth() \ No newline at end of file +app.synth() diff --git a/app/auth/cognito.py b/app/auth/cognito.py index 243dd0e..e251d1e 100644 --- a/app/auth/cognito.py +++ b/app/auth/cognito.py @@ -1,9 +1,14 @@ import boto3 from fastapi import HTTPException, status + from app.models.auth import CognitoUser from app.utils.auth import calculate_secret_hash -from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID, - COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE) +from app.utils.constants import ( + AWS_REGION, + COGNITO_CLIENT_ID, + COGNITO_CLIENT_SECRET, + USER_ROLE_ATTRIBUTE, +) cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION) @@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict: """ Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH. """ - auth_params = { - "USERNAME": username, - "PASSWORD": password - } + auth_params = {"USERNAME": username, "PASSWORD": password} # If a client secret is required, add SECRET_HASH if COGNITO_CLIENT_SECRET: auth_params["SECRET_HASH"] = calculate_secret_hash( - username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET) + username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET + ) try: response = cognito_client.initiate_auth( AuthFlow="USER_PASSWORD_AUTH", AuthParameters=auth_params, - ClientId=COGNITO_CLIENT_ID + ClientId=COGNITO_CLIENT_ID, ) return response["AuthenticationResult"] except cognito_client.exceptions.NotAuthorizedException: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password" + detail="Invalid username or password", ) except cognito_client.exceptions.UserNotFoundException: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"An error occurred during authentication: {str(e)}" + detail=f"An error occurred during authentication: {str(e)}", ) def get_user_from_token(access_token: str) -> CognitoUser: """ - Verify the token by calling GetUser in Cognito and retrieve user attributes including roles. + Verify the token by calling GetUser in Cognito and + retrieve user attributes including roles. """ try: user_response = cognito_client.get_user(AccessToken=access_token) @@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser: for attr in attributes: if attr["Name"] == USER_ROLE_ATTRIBUTE: # Assume roles are stored as a comma-separated string - user_roles = [r.strip() - for r in attr["Value"].split(",") if r.strip()] + user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()] break return CognitoUser(username=username, roles=user_roles) except cognito_client.exceptions.NotAuthorizedException: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token." + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token." ) except cognito_client.exceptions.UserNotFoundException: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not found or invalid token." + detail="User not found or invalid token.", ) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Token verification failed: {str(e)}" + detail=f"Token verification failed: {str(e)}", ) diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py index 5726772..debd019 100644 --- a/app/auth/dependencies.py +++ b/app/auth/dependencies.py @@ -1,6 +1,6 @@ +import os from functools import wraps from typing import Callable -import os from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true": else: from app.auth.cognito import get_user_from_token -oauth2_scheme = OAuth2PasswordBearer( - tokenUrl="signin", - scheme_name="Bearer" -) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer") + def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser: """ @@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable: if not needed_roles.issubset(user_roles): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="You do not have the required roles to access this endpoint.", + detail=( + "You do not have the required roles to access this endpoint." + ), ) return endpoint(*args, user=user, **kwargs) diff --git a/app/auth/mock_auth.py b/app/auth/mock_auth.py index a7a1598..fc6dff6 100644 --- a/app/auth/mock_auth.py +++ b/app/auth/mock_auth.py @@ -1,12 +1,9 @@ from fastapi import HTTPException, status + from app.models.auth import CognitoUser -MOCK_USERS = { - "testuser": { - "username": "testuser", - "roles": ["admin"] - } -} +MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}} + def mock_get_user_from_token(token: str) -> CognitoUser: """ @@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser: return CognitoUser(**MOCK_USERS["testuser"]) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid mock token - use 'testuser'" + detail="Invalid mock token - use 'testuser'", ) + def mock_initiate_auth(username: str, password: str) -> dict: """ Mock version of initiate_auth for local testing Accepts any username/password and returns a mock token """ - return { - "AccessToken": "testuser", - "ExpiresIn": 3600, - "TokenType": "Bearer" - } \ No newline at end of file + return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"} diff --git a/app/iptv/createEpg.py b/app/iptv/createEpg.py index d51c88d..0cf9e48 100644 --- a/app/iptv/createEpg.py +++ b/app/iptv/createEpg.py @@ -1,39 +1,59 @@ -import os -import re +import argparse import gzip import json +import os +import re import xml.etree.ElementTree as ET + import requests -import argparse -from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL +from utils.constants import ( + IPTV_SERVER_ADMIN_PASSWORD, + IPTV_SERVER_ADMIN_USER, + IPTV_SERVER_URL, +) + def parse_arguments(): - parser = argparse.ArgumentParser(description='EPG Grabber') - parser.add_argument('--playlist', - default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), - help='Path to playlist file') - parser.add_argument('--output', - default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'), - help='Path to output EPG XML file') - parser.add_argument('--epg-sources', - default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'), - help='Path to EPG sources JSON configuration file') - parser.add_argument('--save-as-gz', - action='store_true', - default=True, - help='Save an additional gzipped version of the EPG file') + parser = argparse.ArgumentParser(description="EPG Grabber") + parser.add_argument( + "--playlist", + default=os.path.join( + os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8" + ), + help="Path to playlist file", + ) + parser.add_argument( + "--output", + default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"), + help="Path to output EPG XML file", + ) + parser.add_argument( + "--epg-sources", + default=os.path.join( + os.path.dirname(os.path.dirname(__file__)), "epg_sources.json" + ), + help="Path to EPG sources JSON configuration file", + ) + parser.add_argument( + "--save-as-gz", + action="store_true", + default=True, + help="Save an additional gzipped version of the EPG file", + ) return parser.parse_args() + def load_epg_sources(config_path): """Load EPG sources from JSON configuration file""" try: - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, encoding="utf-8") as f: config = json.load(f) - return config.get('epg_sources', []) + return config.get("epg_sources", []) except (FileNotFoundError, json.JSONDecodeError) as e: print(f"Error loading EPG sources: {e}") return [] - + + def get_tvg_ids(playlist_path): """ Extracts unique tvg-id values from an M3U playlist file. @@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path): # and ends with a double quote. tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"') - with open(playlist_path, 'r', encoding='utf-8') as file: + with open(playlist_path, encoding="utf-8") as file: for line in file: - if line.startswith('#EXTINF'): + if line.startswith("#EXTINF"): # Search for the tvg-id pattern in the line match = tvg_id_pattern.search(line) if match: # Extract the captured group (the value inside the quotes) tvg_id = match.group(1) - if tvg_id: # Ensure the extracted id is not empty + if tvg_id: # Ensure the extracted id is not empty unique_tvg_ids.add(tvg_id) return list(unique_tvg_ids) + def fetch_and_extract_xml(url): response = requests.get(url) if response.status_code != 200: print(f"Failed to fetch {url}") return None - if url.endswith('.gz'): + if url.endswith(".gz"): try: decompressed_data = gzip.decompress(response.content) return ET.fromstring(decompressed_data) @@ -84,44 +105,48 @@ def fetch_and_extract_xml(url): print(f"Failed to parse XML from {url}: {e}") return None + def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True): - root = ET.Element('tv') + root = ET.Element("tv") for url in urls: epg_data = fetch_and_extract_xml(url) if epg_data is None: continue - for channel in epg_data.findall('channel'): - tvg_id = channel.get('id') + for channel in epg_data.findall("channel"): + tvg_id = channel.get("id") if tvg_id in tvg_ids: root.append(channel) - for programme in epg_data.findall('programme'): - tvg_id = programme.get('channel') + for programme in epg_data.findall("programme"): + tvg_id = programme.get("channel") if tvg_id in tvg_ids: root.append(programme) tree = ET.ElementTree(root) - tree.write(output_file, encoding='utf-8', xml_declaration=True) + tree.write(output_file, encoding="utf-8", xml_declaration=True) print(f"New EPG saved to {output_file}") if save_as_gz: - output_file_gz = output_file + '.gz' - with gzip.open(output_file_gz, 'wb') as f: - tree.write(f, encoding='utf-8', xml_declaration=True) + output_file_gz = output_file + ".gz" + with gzip.open(output_file_gz, "wb") as f: + tree.write(f, encoding="utf-8", xml_declaration=True) print(f"New EPG saved to {output_file_gz}") + def upload_epg(file_path): """Uploads gzipped EPG file to IPTV server using HTTP Basic Auth""" try: - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: response = requests.post( - IPTV_SERVER_URL + '/admin/epg', - auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), - files={'file': (os.path.basename(file_path), f)} + IPTV_SERVER_URL + "/admin/epg", + auth=requests.auth.HTTPBasicAuth( + IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD + ), + files={"file": (os.path.basename(file_path), f)}, ) - + if response.status_code == 200: print("EPG successfully uploaded to server") else: @@ -129,6 +154,7 @@ def upload_epg(file_path): except Exception as e: print(f"Upload error: {str(e)}") + if __name__ == "__main__": args = parse_arguments() playlist_file = args.playlist @@ -144,4 +170,4 @@ if __name__ == "__main__": filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz) if args.save_as_gz: - upload_epg(output_file + '.gz') + upload_epg(output_file + ".gz") diff --git a/app/iptv/createPlaylist.py b/app/iptv/createPlaylist.py index 96956b6..e834830 100644 --- a/app/iptv/createPlaylist.py +++ b/app/iptv/createPlaylist.py @@ -1,26 +1,45 @@ -import os import argparse import json import logging -import requests -from pathlib import Path +import os from datetime import datetime + +import requests from utils.check_streams import StreamValidator -from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL +from utils.constants import ( + EPG_URL, + IPTV_SERVER_ADMIN_PASSWORD, + IPTV_SERVER_ADMIN_USER, + IPTV_SERVER_URL, +) + def parse_arguments(): - parser = argparse.ArgumentParser(description='IPTV playlist generator') - parser.add_argument('--output', - default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), - help='Path to output playlist file') - parser.add_argument('--channels', - default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'), - help='Path to channels definition JSON file') - parser.add_argument('--dead-channels-log', - default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'), - help='Path to log file to store a list of dead channels') + parser = argparse.ArgumentParser(description="IPTV playlist generator") + parser.add_argument( + "--output", + default=os.path.join( + os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8" + ), + help="Path to output playlist file", + ) + parser.add_argument( + "--channels", + default=os.path.join( + os.path.dirname(os.path.dirname(__file__)), "channels.json" + ), + help="Path to channels definition JSON file", + ) + parser.add_argument( + "--dead-channels-log", + default=os.path.join( + os.path.dirname(os.path.dirname(__file__)), "dead_channels.log" + ), + help="Path to log file to store a list of dead channels", + ) return parser.parse_args() + def find_working_stream(validator, urls): """Test all URLs and return the first working one""" for url in urls: @@ -29,48 +48,55 @@ def find_working_stream(validator, urls): return url return None + def create_playlist(channels_file, output_file): # Read channels from JSON file - with open(channels_file, 'r', encoding='utf-8') as f: + with open(channels_file, encoding="utf-8") as f: channels = json.load(f) # Initialize validator validator = StreamValidator(timeout=45) - + # Prepare M3U8 header m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n' - + for channel in channels: - if 'urls' in channel: # Check if channel has URLs + if "urls" in channel: # Check if channel has URLs # Find first working stream - working_url = find_working_stream(validator, channel['urls']) - + working_url = find_working_stream(validator, channel["urls"]) + if working_url: # Add channel to playlist m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" ' m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" ' m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" ' m3u8_content += f'group-title="{channel.get("group-title", "")}", ' - m3u8_content += f'{channel.get("name", "")}\n' - m3u8_content += f'{working_url}\n' + m3u8_content += f"{channel.get('name', '')}\n" + m3u8_content += f"{working_url}\n" else: # Log dead channel - logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found') + logging.info( + f"Dead channel: {channel.get('name', 'Unknown')} - " + "No working streams found" + ) # Write playlist file - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: f.write(m3u8_content) + def upload_playlist(file_path): """Uploads playlist file to IPTV server using HTTP Basic Auth""" try: - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: response = requests.post( - IPTV_SERVER_URL + '/admin/playlist', - auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), - files={'file': (os.path.basename(file_path), f)} + IPTV_SERVER_URL + "/admin/playlist", + auth=requests.auth.HTTPBasicAuth( + IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD + ), + files={"file": (os.path.basename(file_path), f)}, ) - + if response.status_code == 200: print("Playlist successfully uploaded to server") else: @@ -78,6 +104,7 @@ def upload_playlist(file_path): except Exception as e: print(f"Upload error: {str(e)}") + def main(): args = parse_arguments() channels_file = args.channels @@ -85,24 +112,25 @@ def main(): dead_channels_log_file = args.dead_channels_log # Clear previous log file - with open(dead_channels_log_file, 'w') as f: - f.write(f'Log created on {datetime.now()}\n') + with open(dead_channels_log_file, "w") as f: + f.write(f"Log created on {datetime.now()}\n") # Configure logging logging.basicConfig( filename=dead_channels_log_file, level=logging.INFO, - format='%(asctime)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' + format="%(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) - + # Create playlist create_playlist(channels_file, output_file) - #upload playlist to server + # upload playlist to server upload_playlist(output_file) - + print("Playlist creation completed!") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/app/main.py b/app/main.py index a4aa9a8..205660c 100644 --- a/app/main.py +++ b/app/main.py @@ -1,17 +1,18 @@ - -from fastapi.concurrency import asynccontextmanager -from app.routers import channels, auth, playlist, priorities from fastapi import FastAPI +from fastapi.concurrency import asynccontextmanager from fastapi.openapi.utils import get_openapi +from app.routers import auth, channels, playlist, priorities from app.utils.database import init_db + @asynccontextmanager async def lifespan(app: FastAPI): # Initialize database tables on startup init_db() yield + app = FastAPI( lifespan=lifespan, title="IPTV Updater API", @@ -19,6 +20,7 @@ app = FastAPI( version="1.0.0", ) + def custom_openapi(): if app.openapi_schema: return app.openapi_schema @@ -40,11 +42,7 @@ def custom_openapi(): # Add security scheme component openapi_schema["components"]["securitySchemes"] = { - "Bearer": { - "type": "http", - "scheme": "bearer", - "bearerFormat": "JWT" - } + "Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"} } # Add global security requirement @@ -56,14 +54,17 @@ def custom_openapi(): app.openapi_schema = openapi_schema return app.openapi_schema + app.openapi = custom_openapi + @app.get("/") async def root(): return {"message": "IPTV Updater API"} + # Include routers app.include_router(auth.router) app.include_router(channels.router) app.include_router(playlist.router) -app.include_router(priorities.router) \ No newline at end of file +app.include_router(priorities.router) diff --git a/app/models/__init__.py b/app/models/__init__.py index 4e7dcad..29b82a9 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,4 +1,19 @@ from .db import Base, ChannelDB, ChannelURL -from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse +from .schemas import ( + ChannelCreate, + ChannelResponse, + ChannelUpdate, + ChannelURLCreate, + ChannelURLResponse, +) -__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"] \ No newline at end of file +__all__ = [ + "Base", + "ChannelDB", + "ChannelCreate", + "ChannelUpdate", + "ChannelResponse", + "ChannelURL", + "ChannelURLCreate", + "ChannelURLResponse", +] diff --git a/app/models/auth.py b/app/models/auth.py index f9ddabc..6bd2c4d 100644 --- a/app/models/auth.py +++ b/app/models/auth.py @@ -1,20 +1,26 @@ -from typing import List, Optional +from typing import Optional + from pydantic import BaseModel, Field + class SigninRequest(BaseModel): """Request model for the signin endpoint.""" + username: str = Field(..., description="The user's username") password: str = Field(..., description="The user's password") + class TokenResponse(BaseModel): """Response model for successful authentication.""" + access_token: str = Field(..., description="Access JWT token from Cognito") id_token: str = Field(..., description="ID JWT token from Cognito") - refresh_token: Optional[str] = Field( - None, description="Refresh token from Cognito") + refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito") token_type: str = Field(..., description="Type of the token returned") + class CognitoUser(BaseModel): """Model representing the user returned from token verification.""" + username: str - roles: List[str] \ No newline at end of file + roles: list[str] diff --git a/app/models/db.py b/app/models/db.py index 46369c6..adf02ce 100644 --- a/app/models/db.py +++ b/app/models/db.py @@ -1,21 +1,33 @@ -from datetime import datetime, timezone import uuid -from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer +from datetime import datetime, timezone + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + String, + UniqueConstraint, +) from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import declarative_base -from sqlalchemy.orm import relationship +from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() + class Priority(Base): """SQLAlchemy model for channel URL priorities""" + __tablename__ = "priorities" - + id = Column(Integer, primary_key=True) description = Column(String, nullable=False) + class ChannelDB(Base): """SQLAlchemy model for IPTV channels""" + __tablename__ = "channels" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -25,27 +37,43 @@ class ChannelDB(Base): tvg_name = Column(String) __table_args__ = ( - UniqueConstraint('group_title', 'name', name='uix_group_title_name'), + UniqueConstraint("group_title", "name", name="uix_group_title_name"), ) tvg_logo = Column(String) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) - + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + # Relationship with ChannelURL - urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan") + urls = relationship( + "ChannelURL", back_populates="channel", cascade="all, delete-orphan" + ) + class ChannelURL(Base): """SQLAlchemy model for channel URLs""" + __tablename__ = "channels_urls" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) + channel_id = Column( + UUID(as_uuid=True), + ForeignKey("channels.id", ondelete="CASCADE"), + nullable=False, + ) url = Column(String, nullable=False) in_use = Column(Boolean, default=False, nullable=False) - priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) + priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) # Relationships channel = relationship("ChannelDB", back_populates="urls") - priority = relationship("Priority") \ No newline at end of file + priority = relationship("Priority") diff --git a/app/models/schemas.py b/app/models/schemas.py index 534ed68..22e122c 100644 --- a/app/models/schemas.py +++ b/app/models/schemas.py @@ -1,30 +1,43 @@ from datetime import datetime -from typing import List, Optional +from typing import Optional from uuid import UUID -from pydantic import BaseModel, Field, ConfigDict + +from pydantic import BaseModel, ConfigDict, Field + class PriorityBase(BaseModel): """Base Pydantic model for priorities""" + id: int description: str model_config = ConfigDict(from_attributes=True) + class PriorityCreate(PriorityBase): """Pydantic model for creating priorities""" + pass + class PriorityResponse(PriorityBase): """Pydantic model for priority responses""" + pass + class ChannelURLCreate(BaseModel): """Pydantic model for creating channel URLs""" + url: str - priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range + priority_id: int = Field( + default=100, ge=100, le=300 + ) # Default to High, validate range + class ChannelURLBase(ChannelURLCreate): """Base Pydantic model for channel URL responses""" + id: UUID in_use: bool created_at: datetime @@ -33,43 +46,53 @@ class ChannelURLBase(ChannelURLCreate): model_config = ConfigDict(from_attributes=True) + class ChannelURLResponse(ChannelURLBase): """Pydantic model for channel URL responses""" + pass + class ChannelCreate(BaseModel): """Pydantic model for creating channels""" - urls: List[ChannelURLCreate] # List of URL objects with priority + + urls: list[ChannelURLCreate] # List of URL objects with priority name: str group_title: str tvg_id: str tvg_logo: str tvg_name: str + class ChannelURLUpdate(BaseModel): """Pydantic model for updating channel URLs""" + url: Optional[str] = None in_use: Optional[bool] = None priority_id: Optional[int] = Field(default=None, ge=100, le=300) + class ChannelUpdate(BaseModel): """Pydantic model for updating channels (all fields optional)""" + name: Optional[str] = Field(None, min_length=1) group_title: Optional[str] = Field(None, min_length=1) tvg_id: Optional[str] = Field(None, min_length=1) tvg_logo: Optional[str] = None tvg_name: Optional[str] = Field(None, min_length=1) + class ChannelResponse(BaseModel): """Pydantic model for channel responses""" + id: UUID name: str group_title: str tvg_id: str tvg_logo: str tvg_name: str - urls: List[ChannelURLResponse] # List of URL objects without channel_id + urls: list[ChannelURLResponse] # List of URL objects without channel_id created_at: datetime updated_at: datetime - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) diff --git a/app/routers/auth.py b/app/routers/auth.py index f6e2f21..de8d45e 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,16 +1,16 @@ from fastapi import APIRouter + from app.auth.cognito import initiate_auth from app.models.auth import SigninRequest, TokenResponse -router = APIRouter( - prefix="/auth", - tags=["authentication"] -) +router = APIRouter(prefix="/auth", tags=["authentication"]) + @router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint") def signin(credentials: SigninRequest): """ - Sign-in endpoint to authenticate the user with AWS Cognito using username and password. + Sign-in endpoint to authenticate the user with AWS Cognito + using username and password. On success, returns JWT tokens (access_token, id_token, refresh_token). """ auth_result = initiate_auth(credentials.username, credentials.password) @@ -19,4 +19,4 @@ def signin(credentials: SigninRequest): id_token=auth_result["IdToken"], refresh_token=auth_result.get("RefreshToken"), token_type="Bearer", - ) \ No newline at end of file + ) diff --git a/app/routers/channels.py b/app/routers/channels.py index a592402..4c2118a 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -1,52 +1,54 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session -from typing import List from uuid import UUID -from sqlalchemy import and_ +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from app.auth.dependencies import get_current_user, require_roles from app.models import ( - ChannelDB, - ChannelURL, ChannelCreate, - ChannelUpdate, + ChannelDB, ChannelResponse, + ChannelUpdate, + ChannelURL, ChannelURLCreate, ChannelURLResponse, ) +from app.models.auth import CognitoUser from app.models.schemas import ChannelURLUpdate from app.utils.database import get_db -from app.auth.dependencies import get_current_user, require_roles -from app.models.auth import CognitoUser -router = APIRouter( - prefix="/channels", - tags=["channels"] -) +router = APIRouter(prefix="/channels", tags=["channels"]) + @router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED) @require_roles("admin") def create_channel( channel: ChannelCreate, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Create a new channel""" # Check for duplicate channel (same group_title + name) - existing_channel = db.query(ChannelDB).filter( - and_( - ChannelDB.group_title == channel.group_title, - ChannelDB.name == channel.name + existing_channel = ( + db.query(ChannelDB) + .filter( + and_( + ChannelDB.group_title == channel.group_title, + ChannelDB.name == channel.name, + ) ) - ).first() - + .first() + ) + if existing_channel: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail="Channel with same group_title and name already exists" + detail="Channel with same group_title and name already exists", ) # Create channel without URLs first - channel_data = channel.model_dump(exclude={'urls'}) + channel_data = channel.model_dump(exclude={"urls"}) urls = channel.urls db_channel = ChannelDB(**channel_data) db.add(db_channel) @@ -59,130 +61,142 @@ def create_channel( channel_id=db_channel.id, url=url.url, priority_id=url.priority_id, - in_use=False + in_use=False, ) db.add(db_url) - + db.commit() db.refresh(db_channel) return db_channel + @router.get("/{channel_id}", response_model=ChannelResponse) -def get_channel( - channel_id: UUID, - db: Session = Depends(get_db) -): +def get_channel(channel_id: UUID, db: Session = Depends(get_db)): """Get a channel by id""" channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() if not channel: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Channel not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" ) return channel + @router.put("/{channel_id}", response_model=ChannelResponse) @require_roles("admin") def update_channel( channel_id: UUID, channel: ChannelUpdate, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Update a channel""" db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() if not db_channel: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Channel not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" ) # Only check for duplicates if name or group_title are being updated if channel.name is not None or channel.group_title is not None: name = channel.name if channel.name is not None else db_channel.name - group_title = channel.group_title if channel.group_title is not None else db_channel.group_title - - existing_channel = db.query(ChannelDB).filter( - and_( - ChannelDB.group_title == group_title, - ChannelDB.name == name, - ChannelDB.id != channel_id + group_title = ( + channel.group_title + if channel.group_title is not None + else db_channel.group_title + ) + + existing_channel = ( + db.query(ChannelDB) + .filter( + and_( + ChannelDB.group_title == group_title, + ChannelDB.name == name, + ChannelDB.id != channel_id, + ) ) - ).first() - + .first() + ) + if existing_channel: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail="Channel with same group_title and name already exists" + detail="Channel with same group_title and name already exists", ) - + # Update only provided fields update_data = channel.model_dump(exclude_unset=True) for key, value in update_data.items(): setattr(db_channel, key, value) - + db.commit() db.refresh(db_channel) return db_channel + @router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT) @require_roles("admin") def delete_channel( channel_id: UUID, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Delete a channel""" channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() if not channel: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Channel not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" ) db.delete(channel) db.commit() return None -@router.get("/", response_model=List[ChannelResponse]) + +@router.get("/", response_model=list[ChannelResponse]) @require_roles("admin") def list_channels( skip: int = 0, limit: int = 100, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """List all channels with pagination""" return db.query(ChannelDB).offset(skip).limit(limit).all() + # URL Management Endpoints -@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED) + +@router.post( + "/{channel_id}/urls", + response_model=ChannelURLResponse, + status_code=status.HTTP_201_CREATED, +) @require_roles("admin") def add_channel_url( channel_id: UUID, url: ChannelURLCreate, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Add a new URL to a channel""" channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() if not channel: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Channel not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" ) db_url = ChannelURL( channel_id=channel_id, url=url.url, priority_id=url.priority_id, - in_use=False # Default to not in use + in_use=False, # Default to not in use ) db.add(db_url) db.commit() db.refresh(db_url) return db_url + @router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse) @require_roles("admin") def update_channel_url( @@ -190,72 +204,69 @@ def update_channel_url( url_id: UUID, url_update: ChannelURLUpdate, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Update a channel URL (url, in_use, or priority_id)""" - db_url = db.query(ChannelURL).filter( - and_( - ChannelURL.id == url_id, - ChannelURL.channel_id == channel_id - ) - ).first() - + db_url = ( + db.query(ChannelURL) + .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id)) + .first() + ) + if not db_url: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="URL not found" + status_code=status.HTTP_404_NOT_FOUND, detail="URL not found" ) - + if url_update.url is not None: db_url.url = url_update.url if url_update.in_use is not None: db_url.in_use = url_update.in_use if url_update.priority_id is not None: db_url.priority_id = url_update.priority_id - + db.commit() db.refresh(db_url) return db_url + @router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT) @require_roles("admin") def delete_channel_url( channel_id: UUID, url_id: UUID, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Delete a URL from a channel""" - url = db.query(ChannelURL).filter( - and_( - ChannelURL.id == url_id, - ChannelURL.channel_id == channel_id - ) - ).first() - + url = ( + db.query(ChannelURL) + .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id)) + .first() + ) + if not url: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="URL not found" + status_code=status.HTTP_404_NOT_FOUND, detail="URL not found" ) - + db.delete(url) db.commit() return None -@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse]) + +@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse]) @require_roles("admin") def list_channel_urls( channel_id: UUID, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """List all URLs for a channel""" channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() if not channel: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Channel not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" ) - - return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all() \ No newline at end of file + + return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all() diff --git a/app/routers/playlist.py b/app/routers/playlist.py index e5e4c09..c38e80a 100644 --- a/app/routers/playlist.py +++ b/app/routers/playlist.py @@ -1,17 +1,15 @@ from fastapi import APIRouter, Depends + from app.auth.dependencies import get_current_user from app.models.auth import CognitoUser -router = APIRouter( - prefix="/playlist", - tags=["playlist"] -) +router = APIRouter(prefix="/playlist", tags=["playlist"]) -@router.get("/protected", - summary="Protected endpoint for authenticated users") + +@router.get("/protected", summary="Protected endpoint for authenticated users") async def protected_route(user: CognitoUser = Depends(get_current_user)): """ Protected endpoint that requires authentication for all users. If the user is authenticated, returns success message. """ - return {"message": f"Hello {user.username}, you have access to support resources!"} \ No newline at end of file + return {"message": f"Hello {user.username}, you have access to support resources!"} diff --git a/app/routers/priorities.py b/app/routers/priorities.py index 2d863ed..1e43f82 100644 --- a/app/routers/priorities.py +++ b/app/routers/priorities.py @@ -1,25 +1,22 @@ from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import delete, select from sqlalchemy.orm import Session -from sqlalchemy import select, delete -from typing import List +from app.auth.dependencies import get_current_user, require_roles +from app.models.auth import CognitoUser from app.models.db import Priority from app.models.schemas import PriorityCreate, PriorityResponse from app.utils.database import get_db -from app.auth.dependencies import get_current_user, require_roles -from app.models.auth import CognitoUser -router = APIRouter( - prefix="/priorities", - tags=["priorities"] -) +router = APIRouter(prefix="/priorities", tags=["priorities"]) + @router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED) @require_roles("admin") def create_priority( priority: PriorityCreate, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Create a new priority""" # Check if priority with this ID already exists @@ -27,71 +24,69 @@ def create_priority( if existing: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail=f"Priority with ID {priority.id} already exists" + detail=f"Priority with ID {priority.id} already exists", ) - + db_priority = Priority(**priority.model_dump()) db.add(db_priority) db.commit() db.refresh(db_priority) return db_priority -@router.get("/", response_model=List[PriorityResponse]) + +@router.get("/", response_model=list[PriorityResponse]) @require_roles("admin") def list_priorities( - db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user) ): """List all priorities""" return db.query(Priority).all() + @router.get("/{priority_id}", response_model=PriorityResponse) @require_roles("admin") def get_priority( priority_id: int, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Get a priority by id""" priority = db.get(Priority, priority_id) if not priority: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Priority not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found" ) return priority + @router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT) @require_roles("admin") def delete_priority( priority_id: int, db: Session = Depends(get_db), - user: CognitoUser = Depends(get_current_user) + user: CognitoUser = Depends(get_current_user), ): """Delete a priority (if not in use)""" from app.models.db import ChannelURL - + # Check if priority exists priority = db.get(Priority, priority_id) if not priority: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Priority not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found" ) - + # Check if priority is in use in_use = db.scalar( - select(ChannelURL) - .where(ChannelURL.priority_id == priority_id) - .limit(1) + select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1) ) - + if in_use: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail="Cannot delete priority that is in use by channel URLs" + detail="Cannot delete priority that is in use by channel URLs", ) - + db.execute(delete(Priority).where(Priority.id == priority_id)) db.commit() - return None \ No newline at end of file + return None diff --git a/app/utils/auth.py b/app/utils/auth.py index e592867..516dd50 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -2,11 +2,13 @@ import base64 import hashlib import hmac + def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str: """ Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients. """ msg = username + client_id - dig = hmac.new(client_secret.encode('utf-8'), - msg.encode('utf-8'), hashlib.sha256).digest() - return base64.b64encode(dig).decode() \ No newline at end of file + dig = hmac.new( + client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256 + ).digest() + return base64.b64encode(dig).decode() diff --git a/app/utils/check_streams.py b/app/utils/check_streams.py index 026e9c9..55b3ca4 100644 --- a/app/utils/check_streams.py +++ b/app/utils/check_streams.py @@ -1,41 +1,50 @@ -import os import argparse -import requests import logging -from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError +import os + +import requests +from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout + class StreamValidator: def __init__(self, timeout=10, user_agent=None): self.timeout = timeout self.session = requests.Session() - self.session.headers.update({ - 'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36' - }) - + self.session.headers.update( + { + "User-Agent": user_agent + or ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36" + ) + } + ) + def validate_stream(self, url): """Validate a media stream URL with multiple fallback checks""" try: - headers = {'Range': 'bytes=0-1024'} + headers = {"Range": "bytes=0-1024"} with self.session.get( url, headers=headers, timeout=self.timeout, stream=True, - allow_redirects=True + allow_redirects=True, ) as response: if response.status_code not in [200, 206]: return False, f"Invalid status code: {response.status_code}" - - content_type = response.headers.get('Content-Type', '') + + content_type = response.headers.get("Content-Type", "") if not self._is_valid_content_type(content_type): return False, f"Invalid content type: {content_type}" - + try: next(response.iter_content(chunk_size=1024)) return True, "Stream is valid" except (ConnectionError, Timeout): return False, "Connection failed during content read" - + except HTTPError as e: return False, f"HTTP Error: {str(e)}" except ConnectionError as e: @@ -49,10 +58,13 @@ class StreamValidator: def _is_valid_content_type(self, content_type): valid_types = [ - 'video/mp2t', 'application/vnd.apple.mpegurl', - 'application/dash+xml', 'video/mp4', - 'video/webm', 'application/octet-stream', - 'application/x-mpegURL' + "video/mp2t", + "application/vnd.apple.mpegurl", + "application/dash+xml", + "video/mp4", + "video/webm", + "application/octet-stream", + "application/x-mpegURL", ] return any(ct in content_type for ct in valid_types) @@ -60,45 +72,43 @@ class StreamValidator: """Extract stream URLs from M3U playlist file""" urls = [] try: - with open(file_path, 'r') as f: + with open(file_path) as f: for line in f: line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): urls.append(line) except Exception as e: logging.error(f"Error reading playlist file: {str(e)}") raise return urls + def main(): parser = argparse.ArgumentParser( - description='Validate streaming URLs from command line arguments or playlist files', - formatter_class=argparse.ArgumentDefaultsHelpFormatter + description=( + "Validate streaming URLs from command line arguments or playlist files" + ), + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - 'sources', - nargs='+', - help='List of URLs or file paths containing stream URLs' + "sources", nargs="+", help="List of URLs or file paths containing stream URLs" ) parser.add_argument( - '--timeout', - type=int, - default=20, - help='Timeout in seconds for stream checks' + "--timeout", type=int, default=20, help="Timeout in seconds for stream checks" ) parser.add_argument( - '--output', - default='deadstreams.txt', - help='Output file name for inactive streams' + "--output", + default="deadstreams.txt", + help="Output file name for inactive streams", ) - + args = parser.parse_args() # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()] + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()], ) validator = StreamValidator(timeout=args.timeout) @@ -127,9 +137,10 @@ def main(): # Save dead streams to file if dead_streams: - with open(args.output, 'w') as f: - f.write('\n'.join(dead_streams)) + with open(args.output, "w") as f: + f.write("\n".join(dead_streams)) logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/app/utils/constants.py b/app/utils/constants.py index 8bc9aab..b72fc4c 100644 --- a/app/utils/constants.py +++ b/app/utils/constants.py @@ -21,4 +21,4 @@ IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin") IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword") # URL for the EPG XML file to place in the playlist's header -EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz") \ No newline at end of file +EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz") diff --git a/app/utils/database.py b/app/utils/database.py index 273dd36..be56764 100644 --- a/app/utils/database.py +++ b/app/utils/database.py @@ -1,11 +1,13 @@ import os + import boto3 -from app.models import Base -from .constants import AWS_REGION from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -from functools import lru_cache + +from app.models import Base + +from .constants import AWS_REGION + def get_db_credentials(): """Fetch and cache DB credentials from environment or SSM Parameter Store""" @@ -14,29 +16,40 @@ def get_db_credentials(): f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}" f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}" ) - - ssm = boto3.client('ssm', region_name=AWS_REGION) + + ssm = boto3.client("ssm", region_name=AWS_REGION) try: - host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value'] - user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value'] - password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value'] - dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value'] + host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[ + "Parameter" + ]["Value"] + user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[ + "Parameter" + ]["Value"] + password = ssm.get_parameter( + Name="/iptv-updater/DB_PASSWORD", WithDecryption=True + )["Parameter"]["Value"] + dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[ + "Parameter" + ]["Value"] return f"postgresql://{user}:{password}@{host}/{dbname}" except Exception as e: raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}") + # Initialize engine and session maker engine = create_engine(get_db_credentials()) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + def init_db(): """Initialize database by creating all tables""" Base.metadata.create_all(bind=engine) + def get_db(): """Dependency for getting database session""" db = SessionLocal() try: yield db finally: - db.close() \ No newline at end of file + db.close() diff --git a/infrastructure/stack.py b/infrastructure/stack.py index 1d63f58..e4679d2 100644 --- a/infrastructure/stack.py +++ b/infrastructure/stack.py @@ -1,80 +1,69 @@ import os -from aws_cdk import ( - Duration, - RemovalPolicy, - Stack, - aws_ec2 as ec2, - aws_iam as iam, - aws_cognito as cognito, - aws_rds as rds, - aws_ssm as ssm, - CfnOutput -) + +from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack +from aws_cdk import aws_cognito as cognito +from aws_cdk import aws_ec2 as ec2 +from aws_cdk import aws_iam as iam +from aws_cdk import aws_rds as rds +from aws_cdk import aws_ssm as ssm from constructs import Construct + class IptvUpdaterStack(Stack): def __init__( - self, - scope: Construct, - construct_id: str, - freedns_user: str, - freedns_password: str, - domain_name: str, - ssh_public_key: str, - repo_url: str, - letsencrypt_email: str, - **kwargs - ) -> None: + self, + scope: Construct, + construct_id: str, + freedns_user: str, + freedns_password: str, + domain_name: str, + ssh_public_key: str, + repo_url: str, + letsencrypt_email: str, + **kwargs, + ) -> None: super().__init__(scope, construct_id, **kwargs) # Create VPC - vpc = ec2.Vpc(self, "IptvUpdaterVPC", + vpc = ec2.Vpc( + self, + "IptvUpdaterVPC", max_azs=2, # Need at least 2 AZs for RDS subnet group nat_gateways=0, # No NAT Gateway to stay in free tier subnet_configuration=[ ec2.SubnetConfiguration( - name="public", - subnet_type=ec2.SubnetType.PUBLIC, - cidr_mask=24 + name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24 ), ec2.SubnetConfiguration( name="private", subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, - cidr_mask=24 - ) - ] + cidr_mask=24, + ), + ], ) # Security Group security_group = ec2.SecurityGroup( - self, "IptvUpdaterSG", - vpc=vpc, - allow_all_outbound=True + self, "IptvUpdaterSG", vpc=vpc, allow_all_outbound=True ) security_group.add_ingress_rule( - ec2.Peer.any_ipv4(), - ec2.Port.tcp(443), - "Allow HTTPS traffic" - ) - - security_group.add_ingress_rule( - ec2.Peer.any_ipv4(), - ec2.Port.tcp(80), - "Allow HTTP traffic" + ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic" ) security_group.add_ingress_rule( - ec2.Peer.any_ipv4(), - ec2.Port.tcp(22), - "Allow SSH traffic" + ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic" + ) + + security_group.add_ingress_rule( + ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic" ) # Allow PostgreSQL port for tunneling restricted to developer IP security_group.add_ingress_rule( ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP ec2.Port.tcp(5432), - "Allow PostgreSQL traffic for tunneling" + "Allow PostgreSQL traffic for tunneling", ) # Key pair for IPTV Updater instance @@ -82,13 +71,14 @@ class IptvUpdaterStack(Stack): self, "IptvUpdaterKeyPair", key_pair_name="iptv-updater-key", - public_key_material=ssh_public_key + public_key_material=ssh_public_key, ) # Create IAM role for EC2 role = iam.Role( - self, "IptvUpdaterRole", - assumed_by=iam.ServicePrincipal("ec2.amazonaws.com") + self, + "IptvUpdaterRole", + assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"), ) # Add SSM managed policy @@ -99,37 +89,36 @@ class IptvUpdaterStack(Stack): ) # Add EC2 describe permissions - role.add_to_policy(iam.PolicyStatement( - actions=["ec2:DescribeInstances"], - resources=["*"] - )) + role.add_to_policy( + iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"]) + ) # Add SSM SendCommand permissions - role.add_to_policy(iam.PolicyStatement( - actions=["ssm:SendCommand"], - resources=[ - f"arn:aws:ec2:{self.region}:{self.account}:instance/*", # Allow on all EC2 instances - f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript" # Required for the RunShellScript document - ] - )) + role.add_to_policy( + iam.PolicyStatement( + actions=["ssm:SendCommand"], + resources=[ + # Allow on all EC2 instances + f"arn:aws:ec2:{self.region}:{self.account}:instance/*", + # Required for the RunShellScript document + f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript", + ], + ) + ) # Add Cognito permissions to instance role role.add_managed_policy( - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonCognitoReadOnly" - ) + iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly") ) # EC2 Instance instance = ec2.Instance( - self, "IptvUpdaterInstance", + self, + "IptvUpdaterInstance", vpc=vpc, - vpc_subnets=ec2.SubnetSelection( - subnet_type=ec2.SubnetType.PUBLIC - ), + vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC), instance_type=ec2.InstanceType.of( - ec2.InstanceClass.T2, - ec2.InstanceSize.MICRO + ec2.InstanceClass.T2, ec2.InstanceSize.MICRO ), machine_image=ec2.AmazonLinuxImage( generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023 @@ -138,7 +127,7 @@ class IptvUpdaterStack(Stack): key_pair=key_pair, role=role, # Option: 1: Enable auto-assign public IP (free tier compatible) - associate_public_ip_address=True + associate_public_ip_address=True, ) # Option: 2: Create Elastic IP (not free tier compatible) @@ -150,7 +139,8 @@ class IptvUpdaterStack(Stack): # Add Cognito User Pool user_pool = cognito.UserPool( - self, "IptvUpdaterUserPool", + self, + "IptvUpdaterUserPool", user_pool_name="iptv-updater-users", self_sign_up_enabled=False, # Only admins can create users password_policy=cognito.PasswordPolicy( @@ -158,37 +148,33 @@ class IptvUpdaterStack(Stack): require_lowercase=True, require_digits=True, require_symbols=True, - require_uppercase=True + require_uppercase=True, ), account_recovery=cognito.AccountRecovery.EMAIL_ONLY, - removal_policy=RemovalPolicy.DESTROY + removal_policy=RemovalPolicy.DESTROY, ) # Add App Client with the correct callback URL - client = user_pool.add_client("IptvUpdaterClient", + client = user_pool.add_client( + "IptvUpdaterClient", access_token_validity=Duration.minutes(60), id_token_validity=Duration.minutes(60), refresh_token_validity=Duration.days(1), - auth_flows=cognito.AuthFlow( - user_password=True - ), + auth_flows=cognito.AuthFlow(user_password=True), o_auth=cognito.OAuthSettings( - flows=cognito.OAuthFlows( - implicit_code_grant=True - ) + flows=cognito.OAuthFlows(implicit_code_grant=True) ), prevent_user_existence_errors=True, generate_secret=True, - enable_token_revocation=True + enable_token_revocation=True, ) # Add domain for hosted UI - domain = user_pool.add_domain("IptvUpdaterDomain", - cognito_domain=cognito.CognitoDomainOptions( - domain_prefix="iptv-updater" - ) + domain = user_pool.add_domain( + "IptvUpdaterDomain", + cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-updater"), ) - + # Read the userdata script with proper path resolution script_dir = os.path.dirname(os.path.abspath(__file__)) userdata_path = os.path.join(script_dir, "userdata.sh") @@ -196,46 +182,56 @@ class IptvUpdaterStack(Stack): # Creates a userdata object for Linux hosts userdata = ec2.UserData.for_linux() - + # Add environment variables for acme.sh from parameters userdata.add_commands( f'export FREEDNS_User="{freedns_user}"', f'export FREEDNS_Password="{freedns_password}"', f'export DOMAIN_NAME="{domain_name}"', f'export REPO_URL="{repo_url}"', - f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"' + f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"', ) - + # Adds one or more commands to the userdata object. userdata.add_commands( - f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment', - f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment', - f'echo "COGNITO_CLIENT_SECRET={client.user_pool_client_secret.to_string()}" >> /etc/environment', - f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment' + ( + f'echo "COGNITO_USER_POOL_ID=' + f'{user_pool.user_pool_id}" >> /etc/environment' + ), + ( + f'echo "COGNITO_CLIENT_ID=' + f'{client.user_pool_client_id}" >> /etc/environment' + ), + ( + f'echo "COGNITO_CLIENT_SECRET=' + f'{client.user_pool_client_secret.to_string()}" >> /etc/environment' + ), + f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment', ) - userdata.add_commands(str(userdata_file, 'utf-8')) + userdata.add_commands(str(userdata_file, "utf-8")) # Create RDS Security Group rds_sg = ec2.SecurityGroup( - self, "RdsSecurityGroup", + self, + "RdsSecurityGroup", vpc=vpc, - description="Security group for RDS PostgreSQL" + description="Security group for RDS PostgreSQL", ) rds_sg.add_ingress_rule( security_group, ec2.Port.tcp(5432), - "Allow PostgreSQL access from EC2 instance" + "Allow PostgreSQL access from EC2 instance", ) # Create RDS PostgreSQL instance (free tier compatible - db.t3.micro) db = rds.DatabaseInstance( - self, "IptvUpdaterDB", + self, + "IptvUpdaterDB", engine=rds.DatabaseInstanceEngine.postgres( version=rds.PostgresEngineVersion.VER_13 ), instance_type=ec2.InstanceType.of( - ec2.InstanceClass.T3, - ec2.InstanceSize.MICRO + ec2.InstanceClass.T3, ec2.InstanceSize.MICRO ), vpc=vpc, vpc_subnets=ec2.SubnetSelection( @@ -247,39 +243,43 @@ class IptvUpdaterStack(Stack): database_name="iptv_updater", removal_policy=RemovalPolicy.DESTROY, deletion_protection=False, - publicly_accessible=False # Avoid public IPv4 charges + publicly_accessible=False, # Avoid public IPv4 charges ) # Add RDS permissions to instance role role.add_managed_policy( - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonRDSFullAccess" - ) + iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess") ) # Store DB connection info in SSM Parameter Store - ssm.StringParameter(self, "DBHostParam", + ssm.StringParameter( + self, + "DBHostParam", parameter_name="/iptv-updater/DB_HOST", - string_value=db.db_instance_endpoint_address + string_value=db.db_instance_endpoint_address, ) - ssm.StringParameter(self, "DBNameParam", + ssm.StringParameter( + self, + "DBNameParam", parameter_name="/iptv-updater/DB_NAME", - string_value="iptv_updater" + string_value="iptv_updater", ) - ssm.StringParameter(self, "DBUserParam", + ssm.StringParameter( + self, + "DBUserParam", parameter_name="/iptv-updater/DB_USER", - string_value=db.secret.secret_value_from_json("username").to_string() + string_value=db.secret.secret_value_from_json("username").to_string(), ) - ssm.StringParameter(self, "DBPassParam", + ssm.StringParameter( + self, + "DBPassParam", parameter_name="/iptv-updater/DB_PASSWORD", - string_value=db.secret.secret_value_from_json("password").to_string() + string_value=db.secret.secret_value_from_json("password").to_string(), ) # Add SSM read permissions to instance role role.add_managed_policy( - iam.ManagedPolicy.from_aws_managed_policy_name( - "AmazonSSMReadOnlyAccess" - ) + iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess") ) # Update instance with userdata @@ -293,6 +293,8 @@ class IptvUpdaterStack(Stack): # CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip) CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id) CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id) - CfnOutput(self, "CognitoDomainUrl", - value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com" - ) \ No newline at end of file + CfnOutput( + self, + "CognitoDomainUrl", + value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com", + ) diff --git a/pyproject.toml b/pyproject.toml index a07ce25..817ff43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,10 @@ [tool.ruff] line-length = 88 +exclude = [ + "alembic/versions/*.py", # Auto-generated Alembic migration files +] + +[tool.ruff.lint] select = [ "E", # pycodestyle errors "F", # pyflakes @@ -9,7 +14,13 @@ select = [ ] ignore = [] -[tool.ruff.isort] +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "F811", # redefinition of unused name + "F401", # unused import +] + +[tool.ruff.lint.isort] known-first-party = ["app"] [tool.ruff.format] diff --git a/tests/auth/test_cognito.py b/tests/auth/test_cognito.py index 01ef416..310179a 100644 --- a/tests/auth/test_cognito.py +++ b/tests/auth/test_cognito.py @@ -1,5 +1,6 @@ +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from fastapi import HTTPException, status # Test constants @@ -7,12 +8,15 @@ TEST_CLIENT_ID = "test_client_id" TEST_CLIENT_SECRET = "test_client_secret" # Patch constants before importing the module -with patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), \ - patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET): - from app.auth.cognito import initiate_auth, get_user_from_token +with ( + patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), + patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET), +): + from app.auth.cognito import get_user_from_token, initiate_auth from app.models.auth import CognitoUser from app.utils.constants import USER_ROLE_ATTRIBUTE + @pytest.fixture(autouse=True) def mock_cognito_client(): with patch("app.auth.cognito.cognito_client") as mock_client: @@ -26,13 +30,14 @@ def mock_cognito_client(): ) yield mock_client + def test_initiate_auth_success(mock_cognito_client): # Mock successful authentication response mock_cognito_client.initiate_auth.return_value = { "AuthenticationResult": { "AccessToken": "mock_access_token", "IdToken": "mock_id_token", - "RefreshToken": "mock_refresh_token" + "RefreshToken": "mock_refresh_token", } } @@ -40,104 +45,125 @@ def test_initiate_auth_success(mock_cognito_client): assert result == { "AccessToken": "mock_access_token", "IdToken": "mock_id_token", - "RefreshToken": "mock_refresh_token" + "RefreshToken": "mock_refresh_token", } + def test_initiate_auth_with_secret_hash(mock_cognito_client): - with patch("app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash") as mock_hash: + with patch( + "app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash" + ) as mock_hash: mock_cognito_client.initiate_auth.return_value = { "AuthenticationResult": {"AccessToken": "token"} } - - result = initiate_auth("test_user", "test_pass") - + + initiate_auth("test_user", "test_pass") + # Verify calculate_secret_hash was called - mock_hash.assert_called_once_with("test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET) - + mock_hash.assert_called_once_with( + "test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET + ) + # Verify SECRET_HASH was included in auth params call_args = mock_cognito_client.initiate_auth.call_args[1] assert "SECRET_HASH" in call_args["AuthParameters"] assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash" + def test_initiate_auth_not_authorized(mock_cognito_client): - mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.NotAuthorizedException() - + mock_cognito_client.initiate_auth.side_effect = ( + mock_cognito_client.exceptions.NotAuthorizedException() + ) + with pytest.raises(HTTPException) as exc_info: initiate_auth("invalid_user", "wrong_pass") - + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.detail == "Invalid username or password" + def test_initiate_auth_user_not_found(mock_cognito_client): - mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.UserNotFoundException() - + mock_cognito_client.initiate_auth.side_effect = ( + mock_cognito_client.exceptions.UserNotFoundException() + ) + with pytest.raises(HTTPException) as exc_info: initiate_auth("nonexistent_user", "any_pass") - + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert exc_info.value.detail == "User not found" + def test_initiate_auth_generic_error(mock_cognito_client): mock_cognito_client.initiate_auth.side_effect = Exception("Some error") - + with pytest.raises(HTTPException) as exc_info: initiate_auth("test_user", "test_pass") - + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "An error occurred during authentication" in exc_info.value.detail + def test_get_user_from_token_success(mock_cognito_client): mock_response = { "Username": "test_user", "UserAttributes": [ {"Name": "sub", "Value": "123"}, - {"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"} - ] + {"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"}, + ], } mock_cognito_client.get_user.return_value = mock_response - + result = get_user_from_token("valid_token") - + assert isinstance(result, CognitoUser) assert result.username == "test_user" assert set(result.roles) == {"admin", "user"} + def test_get_user_from_token_no_roles(mock_cognito_client): mock_response = { "Username": "test_user", - "UserAttributes": [{"Name": "sub", "Value": "123"}] + "UserAttributes": [{"Name": "sub", "Value": "123"}], } mock_cognito_client.get_user.return_value = mock_response - + result = get_user_from_token("valid_token") - + assert isinstance(result, CognitoUser) assert result.username == "test_user" assert result.roles == [] + def test_get_user_from_token_invalid_token(mock_cognito_client): - mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.NotAuthorizedException() - + mock_cognito_client.get_user.side_effect = ( + mock_cognito_client.exceptions.NotAuthorizedException() + ) + with pytest.raises(HTTPException) as exc_info: get_user_from_token("invalid_token") - + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.detail == "Invalid or expired token." + def test_get_user_from_token_user_not_found(mock_cognito_client): - mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.UserNotFoundException() - + mock_cognito_client.get_user.side_effect = ( + mock_cognito_client.exceptions.UserNotFoundException() + ) + with pytest.raises(HTTPException) as exc_info: get_user_from_token("token_for_nonexistent_user") - + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.detail == "User not found or invalid token." + def test_get_user_from_token_generic_error(mock_cognito_client): mock_cognito_client.get_user.side_effect = Exception("Some error") - + with pytest.raises(HTTPException) as exc_info: get_user_from_token("test_token") - + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert "Token verification failed" in exc_info.value.detail \ No newline at end of file + assert "Token verification failed" in exc_info.value.detail diff --git a/tests/auth/test_dependencies.py b/tests/auth/test_dependencies.py index 48085b3..5d73400 100644 --- a/tests/auth/test_dependencies.py +++ b/tests/auth/test_dependencies.py @@ -1,9 +1,11 @@ -import os -import pytest import importlib +import os + +import pytest +from fastapi import Depends, HTTPException, Request from fastapi.security import OAuth2PasswordBearer -from fastapi import HTTPException, Depends, Request -from app.auth.dependencies import get_current_user, require_roles, oauth2_scheme + +from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles from app.models.auth import CognitoUser # Mock user for testing @@ -11,24 +13,30 @@ TEST_USER = CognitoUser( username="testuser", email="test@example.com", roles=["admin", "user"], - groups=["test_group"] + groups=["test_group"], ) + # Mock the underlying get_user_from_token function def mock_get_user_from_token(token: str) -> CognitoUser: if token == "valid_token": return TEST_USER raise HTTPException(status_code=401, detail="Invalid token") + # Mock endpoint for testing the require_roles decorator @require_roles("admin") async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)): return {"message": "Success", "user": user.username} + # Patch the get_user_from_token function for testing @pytest.fixture(autouse=True) def mock_auth(monkeypatch): - monkeypatch.setattr("app.auth.dependencies.get_user_from_token", mock_get_user_from_token) + monkeypatch.setattr( + "app.auth.dependencies.get_user_from_token", mock_get_user_from_token + ) + # Test get_current_user dependency def test_get_current_user_success(): @@ -37,54 +45,53 @@ def test_get_current_user_success(): assert user.username == "testuser" assert user.roles == ["admin", "user"] + def test_get_current_user_invalid_token(): with pytest.raises(HTTPException) as exc: get_current_user("invalid_token") assert exc.value.status_code == 401 + # Test require_roles decorator @pytest.mark.asyncio async def test_require_roles_success(): # Create test user with required role user = CognitoUser( - username="testuser", - email="test@example.com", - roles=["admin"], - groups=[] + username="testuser", email="test@example.com", roles=["admin"], groups=[] ) - + result = await mock_protected_endpoint(user=user) assert result == {"message": "Success", "user": "testuser"} + @pytest.mark.asyncio async def test_require_roles_missing_role(): # Create test user without required role user = CognitoUser( - username="testuser", - email="test@example.com", - roles=["user"], - groups=[] + username="testuser", email="test@example.com", roles=["user"], groups=[] ) - + with pytest.raises(HTTPException) as exc: await mock_protected_endpoint(user=user) assert exc.value.status_code == 403 - assert exc.value.detail == "You do not have the required roles to access this endpoint." + assert ( + exc.value.detail + == "You do not have the required roles to access this endpoint." + ) + @pytest.mark.asyncio async def test_require_roles_no_roles(): # Create test user with no roles user = CognitoUser( - username="testuser", - email="test@example.com", - roles=[], - groups=[] + username="testuser", email="test@example.com", roles=[], groups=[] ) - + with pytest.raises(HTTPException) as exc: await mock_protected_endpoint(user=user) assert exc.value.status_code == 403 + @pytest.mark.asyncio async def test_require_roles_multiple_roles(): # Test requiring multiple roles @@ -97,7 +104,7 @@ async def test_require_roles_multiple_roles(): username="testuser", email="test@example.com", roles=["admin", "super_user", "user"], - groups=[] + groups=[], ) result = await mock_multi_role_endpoint(user=user_with_roles) assert result == {"message": "Success"} @@ -107,56 +114,62 @@ async def test_require_roles_multiple_roles(): username="testuser", email="test@example.com", roles=["admin", "user"], - groups=[] + groups=[], ) with pytest.raises(HTTPException) as exc: await mock_multi_role_endpoint(user=user_missing_role) assert exc.value.status_code == 403 + @pytest.mark.asyncio async def test_oauth2_scheme_configuration(): # Verify that we have a properly configured OAuth2PasswordBearer instance assert isinstance(oauth2_scheme, OAuth2PasswordBearer) - + # Create a mock request with no Authorization header - mock_request = Request(scope={ - 'type': 'http', - 'headers': [], - 'method': 'GET', - 'scheme': 'http', - 'path': '/', - 'query_string': b'', - 'client': ('127.0.0.1', 8000) - }) - + mock_request = Request( + scope={ + "type": "http", + "headers": [], + "method": "GET", + "scheme": "http", + "path": "/", + "query_string": b"", + "client": ("127.0.0.1", 8000), + } + ) + # Test that the scheme raises 401 when no token is provided with pytest.raises(HTTPException) as exc: await oauth2_scheme(mock_request) assert exc.value.status_code == 401 assert exc.value.detail == "Not authenticated" + def test_mock_auth_import(monkeypatch): # Save original env var value original_value = os.environ.get("MOCK_AUTH") - + try: # Set MOCK_AUTH to true monkeypatch.setenv("MOCK_AUTH", "true") - + # Reload the dependencies module to trigger the import condition import app.auth.dependencies + importlib.reload(app.auth.dependencies) - + # Verify that mock_get_user_from_token was imported from app.auth.dependencies import get_user_from_token - assert get_user_from_token.__module__ == 'app.auth.mock_auth' - + + assert get_user_from_token.__module__ == "app.auth.mock_auth" + finally: # Restore original env var if original_value is None: monkeypatch.delenv("MOCK_AUTH", raising=False) else: monkeypatch.setenv("MOCK_AUTH", original_value) - + # Reload again to restore original state - importlib.reload(app.auth.dependencies) \ No newline at end of file + importlib.reload(app.auth.dependencies) diff --git a/tests/auth/test_mock_auth.py b/tests/auth/test_mock_auth.py index a1bfb27..d937582 100644 --- a/tests/auth/test_mock_auth.py +++ b/tests/auth/test_mock_auth.py @@ -1,8 +1,10 @@ import pytest from fastapi import HTTPException + from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth from app.models.auth import CognitoUser + def test_mock_get_user_from_token_success(): """Test successful token validation returns expected user""" user = mock_get_user_from_token("testuser") @@ -10,27 +12,30 @@ def test_mock_get_user_from_token_success(): assert user.username == "testuser" assert user.roles == ["admin"] + def test_mock_get_user_from_token_invalid(): """Test invalid token raises expected exception""" with pytest.raises(HTTPException) as exc_info: mock_get_user_from_token("invalid_token") - + assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid mock token - use 'testuser'" + def test_mock_initiate_auth(): """Test mock authentication returns expected token response""" result = mock_initiate_auth("any_user", "any_password") - + assert isinstance(result, dict) assert result["AccessToken"] == "testuser" assert result["ExpiresIn"] == 3600 assert result["TokenType"] == "Bearer" + def test_mock_initiate_auth_different_credentials(): """Test mock authentication works with any credentials""" result1 = mock_initiate_auth("user1", "pass1") result2 = mock_initiate_auth("user2", "pass2") - + # Should return same mock token regardless of credentials - assert result1 == result2 \ No newline at end of file + assert result1 == result2 diff --git a/tests/routers/test_auth.py b/tests/routers/test_auth.py index a88513b..230bf18 100644 --- a/tests/routers/test_auth.py +++ b/tests/routers/test_auth.py @@ -1,34 +1,35 @@ from unittest.mock import patch + import pytest -from fastapi.testclient import TestClient from fastapi import HTTPException, status +from fastapi.testclient import TestClient + from app.main import app client = TestClient(app) + @pytest.fixture def mock_successful_auth(): return { "AccessToken": "mock_access_token", "IdToken": "mock_id_token", - "RefreshToken": "mock_refresh_token" + "RefreshToken": "mock_refresh_token", } + @pytest.fixture def mock_successful_auth_no_refresh(): - return { - "AccessToken": "mock_access_token", - "IdToken": "mock_id_token" - } + return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"} + def test_signin_success(mock_successful_auth): """Test successful signin with all tokens""" - with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth): + with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth): response = client.post( - "/auth/signin", - json={"username": "testuser", "password": "testpass"} + "/auth/signin", json={"username": "testuser", "password": "testpass"} ) - + assert response.status_code == 200 data = response.json() assert data["access_token"] == "mock_access_token" @@ -36,14 +37,16 @@ def test_signin_success(mock_successful_auth): assert data["refresh_token"] == "mock_refresh_token" assert data["token_type"] == "Bearer" + def test_signin_success_no_refresh(mock_successful_auth_no_refresh): """Test successful signin without refresh token""" - with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth_no_refresh): + with patch( + "app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh + ): response = client.post( - "/auth/signin", - json={"username": "testuser", "password": "testpass"} + "/auth/signin", json={"username": "testuser", "password": "testpass"} ) - + assert response.status_code == 200 data = response.json() assert data["access_token"] == "mock_access_token" @@ -51,57 +54,48 @@ def test_signin_success_no_refresh(mock_successful_auth_no_refresh): assert data["refresh_token"] is None assert data["token_type"] == "Bearer" + def test_signin_invalid_input(): """Test signin with invalid input format""" # Missing password - response = client.post( - "/auth/signin", - json={"username": "testuser"} - ) + response = client.post("/auth/signin", json={"username": "testuser"}) assert response.status_code == 422 # Missing username - response = client.post( - "/auth/signin", - json={"password": "testpass"} - ) + response = client.post("/auth/signin", json={"password": "testpass"}) assert response.status_code == 422 # Empty payload - response = client.post( - "/auth/signin", - json={} - ) + response = client.post("/auth/signin", json={}) assert response.status_code == 422 + def test_signin_auth_failure(): """Test signin with authentication failure""" - with patch('app.routers.auth.initiate_auth') as mock_auth: + with patch("app.routers.auth.initiate_auth") as mock_auth: mock_auth.side_effect = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password" + detail="Invalid username or password", ) response = client.post( - "/auth/signin", - json={"username": "testuser", "password": "wrongpass"} + "/auth/signin", json={"username": "testuser", "password": "wrongpass"} ) - + assert response.status_code == 401 data = response.json() assert data["detail"] == "Invalid username or password" + def test_signin_user_not_found(): """Test signin with non-existent user""" - with patch('app.routers.auth.initiate_auth') as mock_auth: + with patch("app.routers.auth.initiate_auth") as mock_auth: mock_auth.side_effect = HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) response = client.post( - "/auth/signin", - json={"username": "nonexistent", "password": "testpass"} + "/auth/signin", json={"username": "nonexistent", "password": "testpass"} ) - + assert response.status_code == 404 data = response.json() - assert data["detail"] == "User not found" \ No newline at end of file + assert data["detail"] == "User not found" diff --git a/tests/routers/test_channels.py b/tests/routers/test_channels.py index 6ac11a9..1c78125 100644 --- a/tests/routers/test_channels.py +++ b/tests/routers/test_channels.py @@ -1,29 +1,31 @@ -import pytest import uuid -from fastapi.testclient import TestClient + +import pytest from fastapi import FastAPI, status +from fastapi.testclient import TestClient from sqlalchemy import String from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user -from app.routers.channels import router as channels_router from app.models.auth import CognitoUser +from app.routers.channels import router as channels_router from app.utils.database import get_db # Import mocks from db_mocks from tests.utils.db_mocks import ( MockBase, - engine_mock, - session_mock as TestingSessionLocal, - mock_get_db, MockChannelDB, MockChannelURL, - MockPriority + MockPriority, + engine_mock, + mock_get_db, ) +from tests.utils.db_mocks import session_mock as TestingSessionLocal # Create a FastAPI instance for testing app = FastAPI() + # Mock current user def mock_get_current_user_admin(): return CognitoUser( @@ -31,24 +33,27 @@ def mock_get_current_user_admin(): email="testadmin@example.com", roles=["admin"], user_status="CONFIRMED", - enabled=True + enabled=True, ) + def mock_get_current_user_non_admin(): return CognitoUser( username="testuser", email="testuser@example.com", - roles=["user"], # Or any role other than admin + roles=["user"], # Or any role other than admin user_status="CONFIRMED", - enabled=True + enabled=True, ) + # Override dependencies app.dependency_overrides[get_db] = mock_get_db app.include_router(channels_router) client = TestClient(app) + @pytest.fixture(scope="function") def db_session(): # Create tables for each test function @@ -61,6 +66,7 @@ def db_session(): # Drop tables after each test function MockBase.metadata.drop_all(bind=engine_mock) + @pytest.fixture(scope="function") def admin_user_client(db_session: Session): """Yields a TestClient configured with an admin user.""" @@ -71,6 +77,7 @@ def admin_user_client(db_session: Session): with TestClient(test_app) as test_client: yield test_client + @pytest.fixture(scope="function") def non_admin_user_client(db_session: Session): """Yields a TestClient configured with a non-admin user.""" @@ -81,8 +88,10 @@ def non_admin_user_client(db_session: Session): with TestClient(test_app) as test_client: yield test_client + # --- Test Cases For Channel Creation --- + def test_create_channel_success(db_session: Session, admin_user_client: TestClient): # Setup a priority priority1 = MockPriority(id=100, description="High") @@ -95,11 +104,11 @@ def test_create_channel_success(db_session: Session, admin_user_client: TestClie "group_title": "Test Group", "tvg_name": "TestChannel1", "tvg_logo": "logo.png", - "urls": [ - {"url": "http://stream1.com/test", "priority_id": 100} - ] + "urls": [{"url": "http://stream1.com/test", "priority_id": 100}], } - response = admin_user_client.post("/channels/", json=channel_data) # No headers needed now + response = admin_user_client.post( + "/channels/", json=channel_data + ) # No headers needed now assert response.status_code == status.HTTP_201_CREATED data = response.json() assert data["name"] == "Test Channel 1" @@ -110,18 +119,25 @@ def test_create_channel_success(db_session: Session, admin_user_client: TestClie assert data["urls"][0]["priority_id"] == 100 # Verify in DB - db_channel = db_session.query(MockChannelDB).filter(MockChannelDB.name == "Test Channel 1").first() + db_channel = ( + db_session.query(MockChannelDB) + .filter(MockChannelDB.name == "Test Channel 1") + .first() + ) assert db_channel is not None assert db_channel.group_title == "Test Group" - + # Query URLs using exact string comparison - db_urls = db_session.query(MockChannelURL).filter( - MockChannelURL.channel_id.cast(String()) == db_channel.id - ).all() - + db_urls = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.channel_id.cast(String()) == db_channel.id) + .all() + ) + assert len(db_urls) == 1 assert db_urls[0].url == "http://stream1.com/test" + def test_create_channel_duplicate(db_session: Session, admin_user_client: TestClient): # Setup a priority priority1 = MockPriority(id=100, description="High") @@ -135,7 +151,7 @@ def test_create_channel_duplicate(db_session: Session, admin_user_client: TestCl "group_title": "Duplicate Group", "tvg_name": "DuplicateChannelName", "tvg_logo": "duplicate_logo.png", - "urls": [{"url": "http://stream_dup.com/test", "priority_id": 100}] + "urls": [{"url": "http://stream_dup.com/test", "priority_id": 100}], } response1 = admin_user_client.post("/channels/", json=initial_channel_data) assert response1.status_code == status.HTTP_201_CREATED @@ -145,26 +161,31 @@ def test_create_channel_duplicate(db_session: Session, admin_user_client: TestCl assert response2.status_code == status.HTTP_409_CONFLICT assert "already exists" in response2.json()["detail"] -def test_create_channel_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient): + +def test_create_channel_forbidden_for_non_admin( + db_session: Session, non_admin_user_client: TestClient +): # Setup a priority priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() - + channel_data = { "tvg_id": "channel_forbidden.tv", "name": "Forbidden Channel", "group_title": "Forbidden Group", "tvg_name": "ForbiddenChannelName", "tvg_logo": "forbidden_logo.png", - "urls": [{"url": "http://stream_forbidden.com/test", "priority_id": 100}] + "urls": [{"url": "http://stream_forbidden.com/test", "priority_id": 100}], } response = non_admin_user_client.post("/channels/", json=channel_data) assert response.status_code == status.HTTP_403_FORBIDDEN assert "required roles" in response.json()["detail"] + # --- Test Cases For Get Channel --- + def test_get_channel_success(db_session: Session, admin_user_client: TestClient): # Setup a priority priority1 = MockPriority(id=100, description="High") @@ -178,13 +199,15 @@ def test_get_channel_success(db_session: Session, admin_user_client: TestClient) "group_title": "Get Group", "tvg_name": "GetMeChannelName", "tvg_logo": "get_me_logo.png", - "urls": [{"url": "http://get_me.com/stream", "priority_id": 100}] + "urls": [{"url": "http://get_me.com/stream", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=channel_data_create) assert create_response.status_code == status.HTTP_201_CREATED created_channel_id = create_response.json()["id"] - app.dependency_overrides[get_current_user] = mock_get_current_user_admin # Or a generic authenticated user + app.dependency_overrides[get_current_user] = ( + mock_get_current_user_admin # Or a generic authenticated user + ) get_response = admin_user_client.get(f"/channels/{created_channel_id}") assert get_response.status_code == status.HTTP_200_OK data = get_response.json() @@ -203,8 +226,10 @@ def test_get_channel_not_found(db_session: Session, admin_user_client: TestClien assert "Channel not found" in response.json()["detail"] app.dependency_overrides.pop(get_current_user, None) + # --- Test Cases For Update Channel --- + def test_update_channel_success(db_session: Session, admin_user_client: TestClient): # Setup priority and create initial channel priority1 = MockPriority(id=100, description="High") @@ -217,17 +242,16 @@ def test_update_channel_success(db_session: Session, admin_user_client: TestClie "group_title": "Update Group", "tvg_name": "UpdateMeChannelName", "tvg_logo": "update_me_logo.png", - "urls": [{"url": "http://update_me.com/stream", "priority_id": 100}] + "urls": [{"url": "http://update_me.com/stream", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=initial_channel_data) assert create_response.status_code == status.HTTP_201_CREATED created_channel_id = create_response.json()["id"] - update_data = { - "name": "Updated Channel Name", - "tvg_logo": "new_logo.png" - } - response = admin_user_client.put(f"/channels/{created_channel_id}", json=update_data) + update_data = {"name": "Updated Channel Name", "tvg_logo": "new_logo.png"} + response = admin_user_client.put( + f"/channels/{created_channel_id}", json=update_data + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["id"] == created_channel_id @@ -236,11 +260,16 @@ def test_update_channel_success(db_session: Session, admin_user_client: TestClie assert data["tvg_logo"] == "new_logo.png" # Verify in DB - db_channel = db_session.query(MockChannelDB).filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)).first() + db_channel = ( + db_session.query(MockChannelDB) + .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .first() + ) assert db_channel is not None assert db_channel.name == "Updated Channel Name" assert db_channel.tvg_logo == "new_logo.png" + def test_update_channel_conflict(db_session: Session, admin_user_client: TestClient): # Setup priority priority1 = MockPriority(id=100, description="High") @@ -249,27 +278,36 @@ def test_update_channel_conflict(db_session: Session, admin_user_client: TestCli # Create channel 1 channel1_data = { - "tvg_id": "c1.tv", "name": "Channel One", "group_title": "Group A", - "tvg_name": "C1Name", "tvg_logo": "c1logo.png", - "urls": [{"url": "http://c1.com", "priority_id": 100}] + "tvg_id": "c1.tv", + "name": "Channel One", + "group_title": "Group A", + "tvg_name": "C1Name", + "tvg_logo": "c1logo.png", + "urls": [{"url": "http://c1.com", "priority_id": 100}], } admin_user_client.post("/channels/", json=channel1_data) # Create channel 2 channel2_data = { - "tvg_id": "c2.tv", "name": "Channel Two", "group_title": "Group B", - "tvg_name": "C2Name", "tvg_logo": "c2logo.png", - "urls": [{"url": "http://c2.com", "priority_id": 100}] + "tvg_id": "c2.tv", + "name": "Channel Two", + "group_title": "Group B", + "tvg_name": "C2Name", + "tvg_logo": "c2logo.png", + "urls": [{"url": "http://c2.com", "priority_id": 100}], } response_c2 = admin_user_client.post("/channels/", json=channel2_data) channel2_id = response_c2.json()["id"] # Attempt to update channel 2 to conflict with channel 1 update_conflict_data = {"name": "Channel One", "group_title": "Group A"} - response = admin_user_client.put(f"/channels/{channel2_id}", json=update_conflict_data) + response = admin_user_client.put( + f"/channels/{channel2_id}", json=update_conflict_data + ) assert response.status_code == status.HTTP_409_CONFLICT assert "already exists" in response.json()["detail"] + def test_update_channel_not_found(db_session: Session, admin_user_client: TestClient): random_uuid = uuid.uuid4() update_data = {"name": "Non Existent Update"} @@ -277,25 +315,38 @@ def test_update_channel_not_found(db_session: Session, admin_user_client: TestCl assert response.status_code == status.HTTP_404_NOT_FOUND assert "Channel not found" in response.json()["detail"] -def test_update_channel_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient): + +def test_update_channel_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): # Setup priority and create initial channel with admin priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() initial_channel_data = { - "tvg_id": "update_forbidden.tv", "name": "Update Forbidden", "group_title": "Forbidden Update Group", - "tvg_name": "UFName", "tvg_logo": "uflogo.png", - "urls": [{"url": "http://update_forbidden.com", "priority_id": 100}] + "tvg_id": "update_forbidden.tv", + "name": "Update Forbidden", + "group_title": "Forbidden Update Group", + "tvg_name": "UFName", + "tvg_logo": "uflogo.png", + "urls": [{"url": "http://update_forbidden.com", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=initial_channel_data) created_channel_id = create_response.json()["id"] update_data = {"name": "Attempted Update"} - response = non_admin_user_client.put(f"/channels/{created_channel_id}", json=update_data) + response = non_admin_user_client.put( + f"/channels/{created_channel_id}", json=update_data + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "required roles" in response.json()["detail"] + + # --- Test Cases For Delete Channel --- + def test_delete_channel_success(db_session: Session, admin_user_client: TestClient): # Setup priority and create initial channel priority1 = MockPriority(id=100, description="High") @@ -306,26 +357,41 @@ def test_delete_channel_success(db_session: Session, admin_user_client: TestClie "tvg_id": "delete_me.tv", "name": "Delete Me Channel", "group_title": "Delete Group", - "tvg_name": "DMName", "tvg_logo": "dmlogo.png", - "urls": [{"url": "http://delete_me.com/stream", "priority_id": 100}] + "tvg_name": "DMName", + "tvg_logo": "dmlogo.png", + "urls": [{"url": "http://delete_me.com/stream", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=initial_channel_data) assert create_response.status_code == status.HTTP_201_CREATED created_channel_id = create_response.json()["id"] # Verify it exists before delete - db_channel_before_delete = db_session.query(MockChannelDB).filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)).first() + db_channel_before_delete = ( + db_session.query(MockChannelDB) + .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .first() + ) assert db_channel_before_delete is not None delete_response = admin_user_client.delete(f"/channels/{created_channel_id}") assert delete_response.status_code == status.HTTP_204_NO_CONTENT # Verify it's gone from DB - db_channel_after_delete = db_session.query(MockChannelDB).filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)).first() + db_channel_after_delete = ( + db_session.query(MockChannelDB) + .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .first() + ) assert db_channel_after_delete is None # Also verify associated URLs are deleted (due to CASCADE in mock model) - db_urls_after_delete = db_session.query(MockChannelURL).filter(MockChannelURL.channel_id.cast(String()) == uuid.UUID(created_channel_id)).all() + db_urls_after_delete = ( + db_session.query(MockChannelURL) + .filter( + MockChannelURL.channel_id.cast(String()) == uuid.UUID(created_channel_id) + ) + .all() + ) assert len(db_urls_after_delete) == 0 @@ -335,15 +401,23 @@ def test_delete_channel_not_found(db_session: Session, admin_user_client: TestCl assert response.status_code == status.HTTP_404_NOT_FOUND assert "Channel not found" in response.json()["detail"] -def test_delete_channel_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient): + +def test_delete_channel_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): # Setup priority and create initial channel with admin priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() initial_channel_data = { - "tvg_id": "delete_forbidden.tv", "name": "Delete Forbidden", "group_title": "Forbidden Delete Group", - "tvg_name": "DFName", "tvg_logo": "dflogo.png", - "urls": [{"url": "http://delete_forbidden.com", "priority_id": 100}] + "tvg_id": "delete_forbidden.tv", + "name": "Delete Forbidden", + "group_title": "Forbidden Delete Group", + "tvg_name": "DFName", + "tvg_logo": "dflogo.png", + "urls": [{"url": "http://delete_forbidden.com", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=initial_channel_data) created_channel_id = create_response.json()["id"] @@ -353,16 +427,26 @@ def test_delete_channel_forbidden_for_non_admin(db_session: Session, non_admin_u assert "required roles" in response.json()["detail"] # Ensure channel was not deleted - db_channel_not_deleted = db_session.query(MockChannelDB).filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)).first() + db_channel_not_deleted = ( + db_session.query(MockChannelDB) + .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .first() + ) assert db_channel_not_deleted is not None + + # --- Test Cases For List Channels --- + def test_list_channels_empty(db_session: Session, admin_user_client: TestClient): response = admin_user_client.get("/channels/") assert response.status_code == status.HTTP_200_OK assert response.json() == [] -def test_list_channels_with_data_and_pagination(db_session: Session, admin_user_client: TestClient): + +def test_list_channels_with_data_and_pagination( + db_session: Session, admin_user_client: TestClient +): # Setup priority priority1 = MockPriority(id=100, description="High") db_session.add(priority1) @@ -371,9 +455,12 @@ def test_list_channels_with_data_and_pagination(db_session: Session, admin_user_ # Create some channels for i in range(5): channel_data = { - "tvg_id": f"list_c{i}.tv", "name": f"List Channel {i}", "group_title": "List Group", - "tvg_name": f"LCName{i}", "tvg_logo": f"lclogo{i}.png", - "urls": [{"url": f"http://list_c{i}.com", "priority_id": 100}] + "tvg_id": f"list_c{i}.tv", + "name": f"List Channel {i}", + "group_title": "List Group", + "tvg_name": f"LCName{i}", + "tvg_logo": f"lclogo{i}.png", + "urls": [{"url": f"http://list_c{i}.com", "priority_id": 100}], } admin_user_client.post("/channels/", json=channel_data) @@ -406,12 +493,17 @@ def test_list_channels_with_data_and_pagination(db_session: Session, admin_user_ assert response_skip_beyond.json() == [] -def test_list_channels_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient): +def test_list_channels_forbidden_for_non_admin( + db_session: Session, non_admin_user_client: TestClient +): response = non_admin_user_client.get("/channels/") assert response.status_code == status.HTTP_403_FORBIDDEN assert "required roles" in response.json()["detail"] + + # --- Test Cases For Add Channel URL --- + def test_add_channel_url_success(db_session: Session, admin_user_client: TestClient): # Setup priority and create a channel priority1 = MockPriority(id=100, description="High") @@ -420,25 +512,33 @@ def test_add_channel_url_success(db_session: Session, admin_user_client: TestCli db_session.commit() channel_data = { - "tvg_id": "channel_for_url.tv", "name": "Channel For URL", "group_title": "URL Group", - "tvg_name": "CFUName", "tvg_logo": "cfulogo.png", - "urls": [{"url": "http://initial.com/stream", "priority_id": 100}] + "tvg_id": "channel_for_url.tv", + "name": "Channel For URL", + "group_title": "URL Group", + "tvg_name": "CFUName", + "tvg_logo": "cfulogo.png", + "urls": [{"url": "http://initial.com/stream", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=channel_data) assert create_response.status_code == status.HTTP_201_CREATED created_channel_id = create_response.json()["id"] url_data = {"url": "http://new_stream.com/live", "priority_id": 200} - response = admin_user_client.post(f"/channels/{created_channel_id}/urls", json=url_data) + response = admin_user_client.post( + f"/channels/{created_channel_id}/urls", json=url_data + ) assert response.status_code == status.HTTP_201_CREATED data = response.json() assert data["url"] == "http://new_stream.com/live" assert data["priority_id"] == 200 - # assert data["channel_id"] == created_channel_id # ChannelURLResponse does not include channel_id - assert data["in_use"] is False # Default + assert data["in_use"] is False # Default # Verify in DB - db_url = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(data["id"])).first() + db_url = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(data["id"])) + .first() + ) assert db_url is not None assert db_url.url == "http://new_stream.com/live" assert db_url.priority_id == 200 @@ -446,10 +546,19 @@ def test_add_channel_url_success(db_session: Session, admin_user_client: TestCli # Check the channel now has two URLs # Re-fetch channel to get updated URLs list - db_session.expire_all() # Expire to ensure fresh data from DB if ChannelResponse is not dynamic - - # Let's verify by querying the database directly for the count of URLs for the channel - url_count = db_session.query(MockChannelURL).filter(MockChannelURL.channel_id.cast(String()) == uuid.UUID(created_channel_id)).count() + + # Expire to ensure fresh data from DB if ChannelResponse is not dynamic + db_session.expire_all() + + # Let's verify by querying the database directly + # for the count of URLs for the channel + url_count = ( + db_session.query(MockChannelURL) + .filter( + MockChannelURL.channel_id.cast(String()) == uuid.UUID(created_channel_id) + ) + .count() + ) assert url_count == 2 # And also check the response from get_channel @@ -459,7 +568,9 @@ def test_add_channel_url_success(db_session: Session, admin_user_client: TestCli assert len(channel_details["urls"]) == 2 -def test_add_channel_url_channel_not_found(db_session: Session, admin_user_client: TestClient): +def test_add_channel_url_channel_not_found( + db_session: Session, admin_user_client: TestClient +): # Setup priority priority1 = MockPriority(id=100, description="High") db_session.add(priority1) @@ -467,55 +578,75 @@ def test_add_channel_url_channel_not_found(db_session: Session, admin_user_clien random_channel_uuid = uuid.uuid4() url_data = {"url": "http://stream_no_channel.com", "priority_id": 100} - response = admin_user_client.post(f"/channels/{random_channel_uuid}/urls", json=url_data) + response = admin_user_client.post( + f"/channels/{random_channel_uuid}/urls", json=url_data + ) assert response.status_code == status.HTTP_404_NOT_FOUND assert "Channel not found" in response.json()["detail"] -def test_add_channel_url_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient): + +def test_add_channel_url_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): # Setup priority and create a channel with admin priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() channel_data = { - "tvg_id": "url_forbidden.tv", "name": "URL Forbidden", "group_title": "URL Forbidden Group", - "tvg_name": "UFName2", "tvg_logo": "uflogo2.png", - "urls": [{"url": "http://url_forbidden.com", "priority_id": 100}] + "tvg_id": "url_forbidden.tv", + "name": "URL Forbidden", + "group_title": "URL Forbidden Group", + "tvg_name": "UFName2", + "tvg_logo": "uflogo2.png", + "urls": [{"url": "http://url_forbidden.com", "priority_id": 100}], } create_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_response.json()["id"] url_data = {"url": "http://new_stream_forbidden.com", "priority_id": 100} - response = non_admin_user_client.post(f"/channels/{created_channel_id}/urls", json=url_data) + response = non_admin_user_client.post( + f"/channels/{created_channel_id}/urls", json=url_data + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "required roles" in response.json()["detail"] + # --- Test Cases For Update Channel URL --- + def test_update_channel_url_success(db_session: Session, admin_user_client: TestClient): # Setup priorities and create a channel with a URL priority1 = MockPriority(id=100, description="High") priority2 = MockPriority(id=200, description="Medium") - priority3 = MockPriority(id=300, description="Low") # New priority for update, Use valid priority ID + priority3 = MockPriority( + id=300, description="Low" + ) # New priority for update, Use valid priority ID db_session.add_all([priority1, priority2, priority3]) db_session.commit() channel_data = { - "tvg_id": "ch_update_url.tv", "name": "Channel Update URL", "group_title": "URL Update Group", - "tvg_name": "CUUName", "tvg_logo": "cuulogo.png", - "urls": [{"url": "http://original_url.com/stream", "priority_id": 100}] + "tvg_id": "ch_update_url.tv", + "name": "Channel Update URL", + "group_title": "URL Update Group", + "tvg_name": "CUUName", + "tvg_logo": "cuulogo.png", + "urls": [{"url": "http://original_url.com/stream", "priority_id": 100}], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] # Get the ID of the initially created URL initial_url_id = create_ch_response.json()["urls"][0]["id"] - update_url_data = { "url": "http://updated_url.com/live", "priority_id": 300, - "in_use": True + "in_use": True, } - response = admin_user_client.put(f"/channels/{created_channel_id}/urls/{initial_url_id}", json=update_url_data) + response = admin_user_client.put( + f"/channels/{created_channel_id}/urls/{initial_url_id}", json=update_url_data + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["id"] == initial_url_id @@ -524,22 +655,32 @@ def test_update_channel_url_success(db_session: Session, admin_user_client: Test assert data["in_use"] is True # Verify in DB - db_url = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)).first() + db_url = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)) + .first() + ) assert db_url is not None assert db_url.url == "http://updated_url.com/live" assert db_url.priority_id == 300 assert db_url.in_use is True -def test_update_channel_url_partial_success(db_session: Session, admin_user_client: TestClient): + +def test_update_channel_url_partial_success( + db_session: Session, admin_user_client: TestClient +): # Setup priorities and create a channel with a URL priority1 = MockPriority(id=100, description="High") db_session.add_all([priority1]) db_session.commit() channel_data = { - "tvg_id": "ch_partial_update_url.tv", "name": "Channel Partial Update URL", "group_title": "URL Partial Update Group", - "tvg_name": "CPUName", "tvg_logo": "cpulogo.png", - "urls": [{"url": "http://partial_original.com/stream", "priority_id": 100}] + "tvg_id": "ch_partial_update_url.tv", + "name": "Channel Partial Update URL", + "group_title": "URL Partial Update Group", + "tvg_name": "CPUName", + "tvg_logo": "cpulogo.png", + "urls": [{"url": "http://partial_original.com/stream", "priority_id": 100}], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] @@ -547,7 +688,9 @@ def test_update_channel_url_partial_success(db_session: Session, admin_user_clie # Update only 'in_use' update_url_data = {"in_use": True} - response = admin_user_client.put(f"/channels/{created_channel_id}/urls/{initial_url_id}", json=update_url_data) + response = admin_user_client.put( + f"/channels/{created_channel_id}/urls/{initial_url_id}", json=update_url_data + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["id"] == initial_url_id @@ -556,76 +699,118 @@ def test_update_channel_url_partial_success(db_session: Session, admin_user_clie assert data["in_use"] is True # Verify in DB - db_url = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)).first() + db_url = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)) + .first() + ) assert db_url is not None assert db_url.in_use is True assert db_url.url == "http://partial_original.com/stream" assert db_url.priority_id == 100 -def test_update_channel_url_url_not_found(db_session: Session, admin_user_client: TestClient): +def test_update_channel_url_url_not_found( + db_session: Session, admin_user_client: TestClient +): # Setup priority and create a channel priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() channel_data = { - "tvg_id": "ch_url_not_found.tv", "name": "Channel URL Not Found", "group_title": "URL Not Found Group", - "tvg_name": "CUNFName", "tvg_logo": "cunflogo.png", - "urls": [] + "tvg_id": "ch_url_not_found.tv", + "name": "Channel URL Not Found", + "group_title": "URL Not Found Group", + "tvg_name": "CUNFName", + "tvg_logo": "cunflogo.png", + "urls": [], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] - + random_url_uuid = uuid.uuid4() update_data = {"url": "http://does_not_matter.com"} - response = admin_user_client.put(f"/channels/{created_channel_id}/urls/{random_url_uuid}", json=update_data) + response = admin_user_client.put( + f"/channels/{created_channel_id}/urls/{random_url_uuid}", json=update_data + ) assert response.status_code == status.HTTP_404_NOT_FOUND assert "URL not found" in response.json()["detail"] -def test_update_channel_url_channel_id_mismatch_is_url_not_found(db_session: Session, admin_user_client: TestClient): - # This tests if a URL ID exists but is not associated with the given channel_id in the path + +def test_update_channel_url_channel_id_mismatch_is_url_not_found( + db_session: Session, admin_user_client: TestClient +): + # This tests if a URL ID exists but is not associated + # with the given channel_id in the path priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() # Create channel 1 with a URL - ch1_data = {"tvg_id": "ch1_url_mismatch.tv", "name": "CH1 URL Mismatch", "group_title": "G1", "tvg_name":"C1UMName", "tvg_logo":"c1umlogo.png", "urls": [{"url":"http://ch1.url", "priority_id":100}]} + ch1_data = { + "tvg_id": "ch1_url_mismatch.tv", + "name": "CH1 URL Mismatch", + "group_title": "G1", + "tvg_name": "C1UMName", + "tvg_logo": "c1umlogo.png", + "urls": [{"url": "http://ch1.url", "priority_id": 100}], + } ch1_resp = admin_user_client.post("/channels/", json=ch1_data) url_id_from_ch1 = ch1_resp.json()["urls"][0]["id"] # Create channel 2 - ch2_data = {"tvg_id": "ch2_url_mismatch.tv", "name": "CH2 URL Mismatch", "group_title": "G2", "tvg_name":"C2UMName", "tvg_logo":"c2umlogo.png", "urls": []} # priority_id not needed here + ch2_data = { + "tvg_id": "ch2_url_mismatch.tv", + "name": "CH2 URL Mismatch", + "group_title": "G2", + "tvg_name": "C2UMName", + "tvg_logo": "c2umlogo.png", + "urls": [], + } # priority_id not needed here ch2_resp = admin_user_client.post("/channels/", json=ch2_data) ch2_id = ch2_resp.json()["id"] # Try to update URL from CH1 using CH2's ID in path update_data = {"url": "http://mismatch_update.com"} - response = admin_user_client.put(f"/channels/{ch2_id}/urls/{url_id_from_ch1}", json=update_data) + response = admin_user_client.put( + f"/channels/{ch2_id}/urls/{url_id_from_ch1}", json=update_data + ) assert response.status_code == status.HTTP_404_NOT_FOUND assert "URL not found" in response.json()["detail"] -def test_update_channel_url_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient): +def test_update_channel_url_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): # Setup priority and create channel with URL using admin priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() channel_data = { - "tvg_id": "ch_update_url_forbidden.tv", "name": "Channel Update URL Forbidden", "group_title": "URL Update Forbidden Group", - "tvg_name": "CUFName", "tvg_logo": "cuflgo.png", - "urls": [{"url": "http://original_forbidden.com/stream", "priority_id": 100}] + "tvg_id": "ch_update_url_forbidden.tv", + "name": "Channel Update URL Forbidden", + "group_title": "URL Update Forbidden Group", + "tvg_name": "CUFName", + "tvg_logo": "cuflgo.png", + "urls": [{"url": "http://original_forbidden.com/stream", "priority_id": 100}], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] initial_url_id = create_ch_response.json()["urls"][0]["id"] update_url_data = {"url": "http://attempted_update_forbidden.com"} - response = non_admin_user_client.put(f"/channels/{created_channel_id}/urls/{initial_url_id}", json=update_url_data) + response = non_admin_user_client.put( + f"/channels/{created_channel_id}/urls/{initial_url_id}", json=update_url_data + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "required roles" in response.json()["detail"] + # --- Test Cases For Delete Channel URL --- + def test_delete_channel_url_success(db_session: Session, admin_user_client: TestClient): # Setup priority and create a channel with a URL priority1 = MockPriority(id=100, description="High") @@ -633,23 +818,36 @@ def test_delete_channel_url_success(db_session: Session, admin_user_client: Test db_session.commit() channel_data = { - "tvg_id": "ch_delete_url.tv", "name": "Channel Delete URL", "group_title": "URL Delete Group", - "tvg_name": "CDUName", "tvg_logo": "cdulogo.png", - "urls": [{"url": "http://delete_this_url.com/stream", "priority_id": 100}] + "tvg_id": "ch_delete_url.tv", + "name": "Channel Delete URL", + "group_title": "URL Delete Group", + "tvg_name": "CDUName", + "tvg_logo": "cdulogo.png", + "urls": [{"url": "http://delete_this_url.com/stream", "priority_id": 100}], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] url_to_delete_id = create_ch_response.json()["urls"][0]["id"] # Verify URL exists before delete - db_url_before = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_to_delete_id)).first() + db_url_before = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_to_delete_id)) + .first() + ) assert db_url_before is not None - delete_response = admin_user_client.delete(f"/channels/{created_channel_id}/urls/{url_to_delete_id}") + delete_response = admin_user_client.delete( + f"/channels/{created_channel_id}/urls/{url_to_delete_id}" + ) assert delete_response.status_code == status.HTTP_204_NO_CONTENT # Verify URL is gone from DB - db_url_after = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_to_delete_id)).first() + db_url_after = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_to_delete_id)) + .first() + ) assert db_url_after is None # Verify channel still exists and has no URLs @@ -658,25 +856,35 @@ def test_delete_channel_url_success(db_session: Session, admin_user_client: Test assert len(channel_response.json()["urls"]) == 0 -def test_delete_channel_url_url_not_found(db_session: Session, admin_user_client: TestClient): +def test_delete_channel_url_url_not_found( + db_session: Session, admin_user_client: TestClient +): # Setup priority and create a channel priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() channel_data = { - "tvg_id": "ch_del_url_not_found.tv", "name": "Channel Del URL Not Found", "group_title": "URL Del Not Found Group", - "tvg_name": "CDUNFName", "tvg_logo": "cdunflogo.png", - "urls": [] + "tvg_id": "ch_del_url_not_found.tv", + "name": "Channel Del URL Not Found", + "group_title": "URL Del Not Found Group", + "tvg_name": "CDUNFName", + "tvg_logo": "cdunflogo.png", + "urls": [], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] - + random_url_uuid = uuid.uuid4() - response = admin_user_client.delete(f"/channels/{created_channel_id}/urls/{random_url_uuid}") + response = admin_user_client.delete( + f"/channels/{created_channel_id}/urls/{random_url_uuid}" + ) assert response.status_code == status.HTTP_404_NOT_FOUND assert "URL not found" in response.json()["detail"] -def test_delete_channel_url_channel_id_mismatch_is_url_not_found(db_session: Session, admin_user_client: TestClient): + +def test_delete_channel_url_channel_id_mismatch_is_url_not_found( + db_session: Session, admin_user_client: TestClient +): priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() @@ -688,7 +896,8 @@ def test_delete_channel_url_channel_id_mismatch_is_url_not_found(db_session: Ses "tvg_name": "CH1 Del URL Mismatch", "tvg_logo": "ch1delogo.png", "group_title": "G1Del", - "urls": [{"url":"http://ch1del.url", "priority_id":100}]} + "urls": [{"url": "http://ch1del.url", "priority_id": 100}], + } ch1_resp = admin_user_client.post("/channels/", json=ch1_data) print(ch1_resp.json()) url_id_from_ch1 = ch1_resp.json()["urls"][0]["id"] @@ -700,7 +909,7 @@ def test_delete_channel_url_channel_id_mismatch_is_url_not_found(db_session: Ses "tvg_name": "CH2 Del URL Mismatch", "tvg_logo": "ch2delogo.png", "group_title": "G2Del", - "urls": [] + "urls": [], } ch2_resp = admin_user_client.post("/channels/", json=ch2_data) ch2_id = ch2_resp.json()["id"] @@ -711,34 +920,55 @@ def test_delete_channel_url_channel_id_mismatch_is_url_not_found(db_session: Ses assert "URL not found" in response.json()["detail"] # Ensure the original URL on CH1 was not deleted - db_url_ch1 = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_id_from_ch1)).first() + db_url_ch1 = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_id_from_ch1)) + .first() + ) assert db_url_ch1 is not None -def test_delete_channel_url_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient): +def test_delete_channel_url_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): # Setup priority and create channel with URL using admin priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() channel_data = { - "tvg_id": "ch_del_url_forbidden.tv", "name": "Channel Del URL Forbidden", "group_title": "URL Del Forbidden Group", - "tvg_name": "CDUFName", "tvg_logo": "cduflogo.png", - "urls": [{"url": "http://original_del_forbidden.com/stream", "priority_id": 100}] + "tvg_id": "ch_del_url_forbidden.tv", + "name": "Channel Del URL Forbidden", + "group_title": "URL Del Forbidden Group", + "tvg_name": "CDUFName", + "tvg_logo": "cduflogo.png", + "urls": [ + {"url": "http://original_del_forbidden.com/stream", "priority_id": 100} + ], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] initial_url_id = create_ch_response.json()["urls"][0]["id"] - response = non_admin_user_client.delete(f"/channels/{created_channel_id}/urls/{initial_url_id}") + response = non_admin_user_client.delete( + f"/channels/{created_channel_id}/urls/{initial_url_id}" + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "required roles" in response.json()["detail"] # Ensure URL was not deleted - db_url_not_deleted = db_session.query(MockChannelURL).filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)).first() + db_url_not_deleted = ( + db_session.query(MockChannelURL) + .filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)) + .first() + ) assert db_url_not_deleted is not None + # --- Test Cases For List Channel URLs --- + def test_list_channel_urls_success(db_session: Session, admin_user_client: TestClient): # Setup priorities and create a channel with multiple URLs priority1 = MockPriority(id=100, description="High") @@ -747,39 +977,48 @@ def test_list_channel_urls_success(db_session: Session, admin_user_client: TestC db_session.commit() channel_data = { - "tvg_id": "ch_list_urls.tv", "name": "Channel List URLs", "group_title": "URL List Group", - "tvg_name": "CLUName", "tvg_logo": "clulogo.png", + "tvg_id": "ch_list_urls.tv", + "name": "Channel List URLs", + "group_title": "URL List Group", + "tvg_name": "CLUName", + "tvg_logo": "clulogo.png", "urls": [ {"url": "http://list_url1.com/stream", "priority_id": 100}, - {"url": "http://list_url2.com/live", "priority_id": 200} - ] + {"url": "http://list_url2.com/live", "priority_id": 200}, + ], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] - - # URLs are added during channel creation, let's get their IDs for assertion if needed + + # URLs are added during channel creation, + # let's get their IDs for assertion if needed # For now, we'll just check the count and content based on what was provided. response = admin_user_client.get(f"/channels/{created_channel_id}/urls") assert response.status_code == status.HTTP_200_OK data = response.json() assert len(data) == 2 - - # Check if the URLs returned match what we expect (order might not be guaranteed by default) + + # Check if the URLs returned match what we expect + # (order might not be guaranteed by default) returned_urls_set = {(item["url"], item["priority_id"]) for item in data} expected_urls_set = { ("http://list_url1.com/stream", 100), - ("http://list_url2.com/live", 200) + ("http://list_url2.com/live", 200), } assert returned_urls_set == expected_urls_set + def test_list_channel_urls_empty(db_session: Session, admin_user_client: TestClient): # Create a channel with no URLs initially # No need to set up MockPriority if no URLs with priority_id are being created. channel_data = { - "tvg_id": "ch_list_empty_urls.tv", "name": "Channel List Empty URLs", "group_title": "URL List Empty Group", - "tvg_name": "CLEUName", "tvg_logo": "cleulogo.png", - "urls": [] + "tvg_id": "ch_list_empty_urls.tv", + "name": "Channel List Empty URLs", + "group_title": "URL List Empty Group", + "tvg_name": "CLEUName", + "tvg_logo": "cleulogo.png", + "urls": [], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] @@ -788,21 +1027,32 @@ def test_list_channel_urls_empty(db_session: Session, admin_user_client: TestCli assert response.status_code == status.HTTP_200_OK assert response.json() == [] -def test_list_channel_urls_channel_not_found(db_session: Session, admin_user_client: TestClient): + +def test_list_channel_urls_channel_not_found( + db_session: Session, admin_user_client: TestClient +): random_channel_uuid = uuid.uuid4() response = admin_user_client.get(f"/channels/{random_channel_uuid}/urls") assert response.status_code == status.HTTP_404_NOT_FOUND assert "Channel not found" in response.json()["detail"] -def test_list_channel_urls_forbidden_for_non_admin(db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient): + +def test_list_channel_urls_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): # Setup priority and create channel with admin priority1 = MockPriority(id=100, description="High") db_session.add(priority1) db_session.commit() channel_data = { - "tvg_id": "ch_list_url_forbidden.tv", "name": "Channel List URL Forbidden", "group_title": "URL List Forbidden Group", - "tvg_name": "CLUFName", "tvg_logo": "cluflogo.png", - "urls": [{"url": "http://list_url_forbidden.com", "priority_id": 100}] + "tvg_id": "ch_list_url_forbidden.tv", + "name": "Channel List URL Forbidden", + "group_title": "URL List Forbidden Group", + "tvg_name": "CLUFName", + "tvg_logo": "cluflogo.png", + "urls": [{"url": "http://list_url_forbidden.com", "priority_id": 100}], } create_ch_response = admin_user_client.post("/channels/", json=channel_data) created_channel_id = create_ch_response.json()["id"] diff --git a/tests/test_main.py b/tests/test_main.py index 609be9e..ff5ba5d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,19 +1,24 @@ +from unittest.mock import patch + import pytest from fastapi.testclient import TestClient + from app.main import app, lifespan -from unittest.mock import patch, MagicMock + @pytest.fixture def client(): """Test client for FastAPI app""" return TestClient(app) + def test_root_endpoint(client): """Test root endpoint returns expected message""" response = client.get("/") assert response.status_code == 200 assert response.json() == {"message": "IPTV Updater API"} + def test_openapi_schema_generation(client): """Test OpenAPI schema is properly generated""" # First call - generate schema @@ -23,7 +28,7 @@ def test_openapi_schema_generation(client): assert schema["openapi"] == "3.1.0" assert "securitySchemes" in schema["components"] assert "Bearer" in schema["components"]["securitySchemes"] - + # Test empty components initialization with patch("app.main.get_openapi", return_value={"info": {}}): # Clear cached schema @@ -35,26 +40,28 @@ def test_openapi_schema_generation(client): assert "components" in schema assert "schemas" in schema["components"] + def test_openapi_schema_caching(mocker): """Test OpenAPI schema caching behavior""" # Clear any existing schema app.openapi_schema = None - + # Mock get_openapi to return test schema mock_schema = {"test": "schema"} mocker.patch("app.main.get_openapi", return_value=mock_schema) - + # First call - should call get_openapi schema = app.openapi() assert schema == mock_schema assert app.openapi_schema == mock_schema - + # Second call - should return cached schema with patch("app.main.get_openapi") as mock_get_openapi: schema = app.openapi() assert schema == mock_schema mock_get_openapi.assert_not_called() + @pytest.mark.asyncio async def test_lifespan_init_db(mocker): """Test lifespan manager initializes database""" @@ -63,6 +70,7 @@ async def test_lifespan_init_db(mocker): pass # Just enter/exit context mock_init_db.assert_called_once() + def test_router_inclusion(): """Test all routers are properly included""" route_paths = {route.path for route in app.routes} @@ -70,4 +78,4 @@ def test_router_inclusion(): assert any(path.startswith("/auth") for path in route_paths) assert any(path.startswith("/channels") for path in route_paths) assert any(path.startswith("/playlist") for path in route_paths) - assert any(path.startswith("/priorities") for path in route_paths) \ No newline at end of file + assert any(path.startswith("/priorities") for path in route_paths) diff --git a/tests/utils/db_mocks.py b/tests/utils/db_mocks.py index 39de62c..4427282 100644 --- a/tests/utils/db_mocks.py +++ b/tests/utils/db_mocks.py @@ -1,17 +1,30 @@ -import os import uuid from datetime import datetime, timezone -from unittest.mock import patch, MagicMock -from sqlalchemy.pool import StaticPool -from sqlalchemy.orm import Session, sessionmaker, declarative_base -from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer +from unittest.mock import MagicMock, patch + import pytest +from sqlalchemy import ( + TEXT, + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + String, + TypeDecorator, + UniqueConstraint, + create_engine, +) +from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.pool import StaticPool # Create a mock-specific Base class for testing MockBase = declarative_base() + class SQLiteUUID(TypeDecorator): """Enables UUID support for SQLite.""" + impl = TEXT cache_ok = True @@ -25,12 +38,14 @@ class SQLiteUUID(TypeDecorator): return value return uuid.UUID(value) + # Model classes for testing - prefix with Mock to avoid pytest collection class MockPriority(MockBase): __tablename__ = "priorities" id = Column(Integer, primary_key=True) description = Column(String, nullable=False) + class MockChannelDB(MockBase): __tablename__ = "channels" id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) @@ -39,32 +54,45 @@ class MockChannelDB(MockBase): group_title = Column(String, nullable=False) tvg_name = Column(String) __table_args__ = ( - UniqueConstraint('group_title', 'name', name='uix_group_title_name'), + UniqueConstraint("group_title", "name", name="uix_group_title_name"), ) tvg_logo = Column(String) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + class MockChannelURL(MockBase): __tablename__ = "channels_urls" id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) - channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) + channel_id = Column( + SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False + ) url = Column(String, nullable=False) in_use = Column(Boolean, default=False, nullable=False) - priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) + priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + # Create test engine engine_mock = create_engine( "sqlite:///:memory:", connect_args={"check_same_thread": False}, - poolclass=StaticPool + poolclass=StaticPool, ) # Create test session session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock) + # Mock the actual database functions def mock_get_db(): db = session_mock() @@ -73,6 +101,7 @@ def mock_get_db(): finally: db.close() + @pytest.fixture(autouse=True) def mock_env(monkeypatch): """Fixture for mocking environment variables""" @@ -82,14 +111,13 @@ def mock_env(monkeypatch): monkeypatch.setenv("DB_HOST", "localhost") monkeypatch.setenv("DB_NAME", "testdb") monkeypatch.setenv("AWS_REGION", "us-east-1") - + + @pytest.fixture def mock_ssm(): """Fixture for mocking boto3 SSM client""" - with patch('boto3.client') as mock_client: + with patch("boto3.client") as mock_client: mock_ssm = MagicMock() mock_client.return_value = mock_ssm - mock_ssm.get_parameter.return_value = { - 'Parameter': {'Value': 'mocked_value'} - } - yield mock_ssm \ No newline at end of file + mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}} + yield mock_ssm diff --git a/tests/utils/test_database.py b/tests/utils/test_database.py index 999ab70..a5197b6 100644 --- a/tests/utils/test_database.py +++ b/tests/utils/test_database.py @@ -1,43 +1,45 @@ import os + import pytest -from unittest.mock import patch from sqlalchemy.orm import Session -from app.utils.database import get_db_credentials, get_db -from tests.utils.db_mocks import ( - session_mock, - mock_get_db, - mock_env, - mock_ssm -) + +from app.utils.database import get_db, get_db_credentials +from tests.utils.db_mocks import mock_env, mock_ssm, session_mock + def test_get_db_credentials_env(mock_env): """Test getting DB credentials from environment variables""" conn_str = get_db_credentials() assert conn_str == "postgresql://testuser:testpass@localhost/testdb" + def test_get_db_credentials_ssm(mock_ssm): """Test getting DB credentials from SSM""" os.environ.pop("MOCK_AUTH", None) conn_str = get_db_credentials() - assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str + expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" + assert expected_conn in conn_str mock_ssm.get_parameter.assert_called() + def test_get_db_credentials_ssm_exception(mock_ssm): """Test SSM credential fetching failure raises RuntimeError""" os.environ.pop("MOCK_AUTH", None) mock_ssm.get_parameter.side_effect = Exception("SSM timeout") - + with pytest.raises(RuntimeError) as excinfo: get_db_credentials() - + assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value) + def test_session_creation(): """Test database session creation""" session = session_mock() assert isinstance(session, Session) session.close() + def test_get_db_generator(): """Test get_db dependency generator""" db_gen = get_db() @@ -48,18 +50,20 @@ def test_get_db_generator(): except StopIteration: pass + def test_init_db(mocker, mock_env): """Test database initialization creates tables""" - mock_create_all = mocker.patch('app.models.Base.metadata.create_all') - + mock_create_all = mocker.patch("app.models.Base.metadata.create_all") + # Mock get_db_credentials to return SQLite test connection mocker.patch( - 'app.utils.database.get_db_credentials', - return_value="sqlite:///:memory:" + "app.utils.database.get_db_credentials", + return_value="sqlite:///:memory:", ) - - from app.utils.database import init_db, engine + + from app.utils.database import engine, init_db + init_db() - + # Verify create_all was called with the engine - mock_create_all.assert_called_once_with(bind=engine) \ No newline at end of file + mock_create_all.assert_called_once_with(bind=engine)