Compare commits

...

2 Commits

Author SHA1 Message Date
6d506122d9 Add pre-commit commands to install script
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m7s
2025-05-28 22:05:13 -05:00
02913c7385 Linted and formatted all files 2025-05-28 21:52:39 -05:00
32 changed files with 1268 additions and 766 deletions

View File

@@ -39,6 +39,7 @@
"dflogo", "dflogo",
"dmlogo", "dmlogo",
"dotenv", "dotenv",
"EXTINF",
"fastapi", "fastapi",
"filterwarnings", "filterwarnings",
"fiorinis", "fiorinis",
@@ -47,6 +48,7 @@
"gitea", "gitea",
"iptv", "iptv",
"isort", "isort",
"KHTML",
"lclogo", "lclogo",
"LETSENCRYPT", "LETSENCRYPT",
"nohup", "nohup",
@@ -76,6 +78,7 @@
"testpaths", "testpaths",
"uflogo", "uflogo",
"umlogo", "umlogo",
"usefixtures",
"uvicorn", "uvicorn",
"venv", "venv",
"wrongpass" "wrongpass"

View File

@@ -1,12 +1,10 @@
import os
from logging.config import fileConfig from logging.config import fileConfig
from sqlalchemy import engine_from_config from sqlalchemy import engine_from_config, pool
from sqlalchemy import pool
from alembic import context from alembic import context
from app.utils.database import get_db_credentials
from app.models.db import Base from app.models.db import Base
from app.utils.database import get_db_credentials
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
@@ -22,7 +20,7 @@ target_metadata = Base.metadata
# Override sqlalchemy.url with dynamic credentials # Override sqlalchemy.url with dynamic credentials
if not context.is_offline_mode(): if not context.is_offline_mode():
config.set_main_option('sqlalchemy.url', get_db_credentials()) config.set_main_option("sqlalchemy.url", get_db_credentials())
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
@@ -68,9 +66,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

14
app.py
View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import aws_cdk as cdk import aws_cdk as cdk
from infrastructure.stack import IptvUpdaterStack from infrastructure.stack import IptvUpdaterStack
app = cdk.App() app = cdk.App()
@@ -19,21 +21,25 @@ required_vars = {
"DOMAIN_NAME": domain_name, "DOMAIN_NAME": domain_name,
"SSH_PUBLIC_KEY": ssh_public_key, "SSH_PUBLIC_KEY": ssh_public_key,
"REPO_URL": repo_url, "REPO_URL": repo_url,
"LETSENCRYPT_EMAIL": letsencrypt_email "LETSENCRYPT_EMAIL": letsencrypt_email,
} }
# Check for missing required variables # Check for missing required variables
missing_vars = [k for k, v in required_vars.items() if not v] missing_vars = [k for k, v in required_vars.items() if not v]
if missing_vars: if missing_vars:
raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}") raise ValueError(
f"Missing required environment variables: {', '.join(missing_vars)}"
)
IptvUpdaterStack(app, "IptvUpdaterStack", IptvUpdaterStack(
app,
"IptvUpdaterStack",
freedns_user=freedns_user, freedns_user=freedns_user,
freedns_password=freedns_password, freedns_password=freedns_password,
domain_name=domain_name, domain_name=domain_name,
ssh_public_key=ssh_public_key, ssh_public_key=ssh_public_key,
repo_url=repo_url, repo_url=repo_url,
letsencrypt_email=letsencrypt_email letsencrypt_email=letsencrypt_email,
) )
app.synth() app.synth()

View File

@@ -1,9 +1,14 @@
import boto3 import boto3
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
from app.utils.auth import calculate_secret_hash from app.utils.auth import calculate_secret_hash
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID, from app.utils.constants import (
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE) AWS_REGION,
COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET,
USER_ROLE_ATTRIBUTE,
)
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION) cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
""" """
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH. Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
""" """
auth_params = { auth_params = {"USERNAME": username, "PASSWORD": password}
"USERNAME": username,
"PASSWORD": password
}
# If a client secret is required, add SECRET_HASH # If a client secret is required, add SECRET_HASH
if COGNITO_CLIENT_SECRET: if COGNITO_CLIENT_SECRET:
auth_params["SECRET_HASH"] = calculate_secret_hash( auth_params["SECRET_HASH"] = calculate_secret_hash(
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET) username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
)
try: try:
response = cognito_client.initiate_auth( response = cognito_client.initiate_auth(
AuthFlow="USER_PASSWORD_AUTH", AuthFlow="USER_PASSWORD_AUTH",
AuthParameters=auth_params, AuthParameters=auth_params,
ClientId=COGNITO_CLIENT_ID ClientId=COGNITO_CLIENT_ID,
) )
return response["AuthenticationResult"] return response["AuthenticationResult"]
except cognito_client.exceptions.NotAuthorizedException: except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password" detail="Invalid username or password",
) )
except cognito_client.exceptions.UserNotFoundException: except cognito_client.exceptions.UserNotFoundException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An error occurred during authentication: {str(e)}" detail=f"An error occurred during authentication: {str(e)}",
) )
def get_user_from_token(access_token: str) -> CognitoUser: def get_user_from_token(access_token: str) -> CognitoUser:
""" """
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles. Verify the token by calling GetUser in Cognito and
retrieve user attributes including roles.
""" """
try: try:
user_response = cognito_client.get_user(AccessToken=access_token) user_response = cognito_client.get_user(AccessToken=access_token)
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
for attr in attributes: for attr in attributes:
if attr["Name"] == USER_ROLE_ATTRIBUTE: if attr["Name"] == USER_ROLE_ATTRIBUTE:
# Assume roles are stored as a comma-separated string # Assume roles are stored as a comma-separated string
user_roles = [r.strip() user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
for r in attr["Value"].split(",") if r.strip()]
break break
return CognitoUser(username=username, roles=user_roles) return CognitoUser(username=username, roles=user_roles)
except cognito_client.exceptions.NotAuthorizedException: except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
detail="Invalid or expired token."
) )
except cognito_client.exceptions.UserNotFoundException: except cognito_client.exceptions.UserNotFoundException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or invalid token." detail="User not found or invalid token.",
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token verification failed: {str(e)}" detail=f"Token verification failed: {str(e)}",
) )

View File

@@ -1,6 +1,6 @@
import os
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable
import os
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
else: else:
from app.auth.cognito import get_user_from_token from app.auth.cognito import get_user_from_token
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
tokenUrl="signin",
scheme_name="Bearer"
)
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser: def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
""" """
@@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable:
if not needed_roles.issubset(user_roles): if not needed_roles.issubset(user_roles):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have the required roles to access this endpoint.", detail=(
"You do not have the required roles to access this endpoint."
),
) )
return endpoint(*args, user=user, **kwargs) return endpoint(*args, user=user, **kwargs)

View File

@@ -1,12 +1,9 @@
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
MOCK_USERS = { MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
"testuser": {
"username": "testuser",
"roles": ["admin"]
}
}
def mock_get_user_from_token(token: str) -> CognitoUser: def mock_get_user_from_token(token: str) -> CognitoUser:
""" """
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
return CognitoUser(**MOCK_USERS["testuser"]) return CognitoUser(**MOCK_USERS["testuser"])
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid mock token - use 'testuser'" detail="Invalid mock token - use 'testuser'",
) )
def mock_initiate_auth(username: str, password: str) -> dict: def mock_initiate_auth(username: str, password: str) -> dict:
""" """
Mock version of initiate_auth for local testing Mock version of initiate_auth for local testing
Accepts any username/password and returns a mock token Accepts any username/password and returns a mock token
""" """
return { return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
"AccessToken": "testuser",
"ExpiresIn": 3600,
"TokenType": "Bearer"
}

View File

@@ -1,39 +1,59 @@
import os import argparse
import re
import gzip import gzip
import json import json
import os
import re
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import requests import requests
import argparse from utils.constants import (
from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='EPG Grabber') parser = argparse.ArgumentParser(description="EPG Grabber")
parser.add_argument('--playlist', parser.add_argument(
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), "--playlist",
help='Path to playlist file') default=os.path.join(
parser.add_argument('--output', os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'), ),
help='Path to output EPG XML file') help="Path to playlist file",
parser.add_argument('--epg-sources', )
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'), parser.add_argument(
help='Path to EPG sources JSON configuration file') "--output",
parser.add_argument('--save-as-gz', default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
action='store_true', help="Path to output EPG XML file",
)
parser.add_argument(
"--epg-sources",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
),
help="Path to EPG sources JSON configuration file",
)
parser.add_argument(
"--save-as-gz",
action="store_true",
default=True, default=True,
help='Save an additional gzipped version of the EPG file') help="Save an additional gzipped version of the EPG file",
)
return parser.parse_args() return parser.parse_args()
def load_epg_sources(config_path): def load_epg_sources(config_path):
"""Load EPG sources from JSON configuration file""" """Load EPG sources from JSON configuration file"""
try: try:
with open(config_path, 'r', encoding='utf-8') as f: with open(config_path, encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
return config.get('epg_sources', []) return config.get("epg_sources", [])
except (FileNotFoundError, json.JSONDecodeError) as e: except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading EPG sources: {e}") print(f"Error loading EPG sources: {e}")
return [] return []
def get_tvg_ids(playlist_path): def get_tvg_ids(playlist_path):
""" """
Extracts unique tvg-id values from an M3U playlist file. Extracts unique tvg-id values from an M3U playlist file.
@@ -51,9 +71,9 @@ def get_tvg_ids(playlist_path):
# and ends with a double quote. # and ends with a double quote.
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"') tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
with open(playlist_path, 'r', encoding='utf-8') as file: with open(playlist_path, encoding="utf-8") as file:
for line in file: for line in file:
if line.startswith('#EXTINF'): if line.startswith("#EXTINF"):
# Search for the tvg-id pattern in the line # Search for the tvg-id pattern in the line
match = tvg_id_pattern.search(line) match = tvg_id_pattern.search(line)
if match: if match:
@@ -64,13 +84,14 @@ def get_tvg_ids(playlist_path):
return list(unique_tvg_ids) return list(unique_tvg_ids)
def fetch_and_extract_xml(url): def fetch_and_extract_xml(url):
response = requests.get(url) response = requests.get(url)
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to fetch {url}") print(f"Failed to fetch {url}")
return None return None
if url.endswith('.gz'): if url.endswith(".gz"):
try: try:
decompressed_data = gzip.decompress(response.content) decompressed_data = gzip.decompress(response.content)
return ET.fromstring(decompressed_data) return ET.fromstring(decompressed_data)
@@ -84,42 +105,46 @@ def fetch_and_extract_xml(url):
print(f"Failed to parse XML from {url}: {e}") print(f"Failed to parse XML from {url}: {e}")
return None return None
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True): def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
root = ET.Element('tv') root = ET.Element("tv")
for url in urls: for url in urls:
epg_data = fetch_and_extract_xml(url) epg_data = fetch_and_extract_xml(url)
if epg_data is None: if epg_data is None:
continue continue
for channel in epg_data.findall('channel'): for channel in epg_data.findall("channel"):
tvg_id = channel.get('id') tvg_id = channel.get("id")
if tvg_id in tvg_ids: if tvg_id in tvg_ids:
root.append(channel) root.append(channel)
for programme in epg_data.findall('programme'): for programme in epg_data.findall("programme"):
tvg_id = programme.get('channel') tvg_id = programme.get("channel")
if tvg_id in tvg_ids: if tvg_id in tvg_ids:
root.append(programme) root.append(programme)
tree = ET.ElementTree(root) tree = ET.ElementTree(root)
tree.write(output_file, encoding='utf-8', xml_declaration=True) tree.write(output_file, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file}") print(f"New EPG saved to {output_file}")
if save_as_gz: if save_as_gz:
output_file_gz = output_file + '.gz' output_file_gz = output_file + ".gz"
with gzip.open(output_file_gz, 'wb') as f: with gzip.open(output_file_gz, "wb") as f:
tree.write(f, encoding='utf-8', xml_declaration=True) tree.write(f, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file_gz}") print(f"New EPG saved to {output_file_gz}")
def upload_epg(file_path): def upload_epg(file_path):
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth""" """Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
response = requests.post( response = requests.post(
IPTV_SERVER_URL + '/admin/epg', IPTV_SERVER_URL + "/admin/epg",
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), auth=requests.auth.HTTPBasicAuth(
files={'file': (os.path.basename(file_path), f)} IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
) )
if response.status_code == 200: if response.status_code == 200:
@@ -129,6 +154,7 @@ def upload_epg(file_path):
except Exception as e: except Exception as e:
print(f"Upload error: {str(e)}") print(f"Upload error: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
args = parse_arguments() args = parse_arguments()
playlist_file = args.playlist playlist_file = args.playlist
@@ -144,4 +170,4 @@ if __name__ == "__main__":
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz) filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
if args.save_as_gz: if args.save_as_gz:
upload_epg(output_file + '.gz') upload_epg(output_file + ".gz")

View File

@@ -1,26 +1,45 @@
import os
import argparse import argparse
import json import json
import logging import logging
import requests import os
from pathlib import Path
from datetime import datetime from datetime import datetime
import requests
from utils.check_streams import StreamValidator from utils.check_streams import StreamValidator
from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL from utils.constants import (
EPG_URL,
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='IPTV playlist generator') parser = argparse.ArgumentParser(description="IPTV playlist generator")
parser.add_argument('--output', parser.add_argument(
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), "--output",
help='Path to output playlist file') default=os.path.join(
parser.add_argument('--channels', os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'), ),
help='Path to channels definition JSON file') help="Path to output playlist file",
parser.add_argument('--dead-channels-log', )
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'), parser.add_argument(
help='Path to log file to store a list of dead channels') "--channels",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "channels.json"
),
help="Path to channels definition JSON file",
)
parser.add_argument(
"--dead-channels-log",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
),
help="Path to log file to store a list of dead channels",
)
return parser.parse_args() return parser.parse_args()
def find_working_stream(validator, urls): def find_working_stream(validator, urls):
"""Test all URLs and return the first working one""" """Test all URLs and return the first working one"""
for url in urls: for url in urls:
@@ -29,9 +48,10 @@ def find_working_stream(validator, urls):
return url return url
return None return None
def create_playlist(channels_file, output_file): def create_playlist(channels_file, output_file):
# Read channels from JSON file # Read channels from JSON file
with open(channels_file, 'r', encoding='utf-8') as f: with open(channels_file, encoding="utf-8") as f:
channels = json.load(f) channels = json.load(f)
# Initialize validator # Initialize validator
@@ -41,9 +61,9 @@ def create_playlist(channels_file, output_file):
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n' m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
for channel in channels: for channel in channels:
if 'urls' in channel: # Check if channel has URLs if "urls" in channel: # Check if channel has URLs
# Find first working stream # Find first working stream
working_url = find_working_stream(validator, channel['urls']) working_url = find_working_stream(validator, channel["urls"])
if working_url: if working_url:
# Add channel to playlist # Add channel to playlist
@@ -51,24 +71,30 @@ def create_playlist(channels_file, output_file):
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" ' m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" ' m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
m3u8_content += f'group-title="{channel.get("group-title", "")}", ' m3u8_content += f'group-title="{channel.get("group-title", "")}", '
m3u8_content += f'{channel.get("name", "")}\n' m3u8_content += f"{channel.get('name', '')}\n"
m3u8_content += f'{working_url}\n' m3u8_content += f"{working_url}\n"
else: else:
# Log dead channel # Log dead channel
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found') logging.info(
f"Dead channel: {channel.get('name', 'Unknown')} - "
"No working streams found"
)
# Write playlist file # Write playlist file
with open(output_file, 'w', encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
f.write(m3u8_content) f.write(m3u8_content)
def upload_playlist(file_path): def upload_playlist(file_path):
"""Uploads playlist file to IPTV server using HTTP Basic Auth""" """Uploads playlist file to IPTV server using HTTP Basic Auth"""
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
response = requests.post( response = requests.post(
IPTV_SERVER_URL + '/admin/playlist', IPTV_SERVER_URL + "/admin/playlist",
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), auth=requests.auth.HTTPBasicAuth(
files={'file': (os.path.basename(file_path), f)} IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
) )
if response.status_code == 200: if response.status_code == 200:
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
except Exception as e: except Exception as e:
print(f"Upload error: {str(e)}") print(f"Upload error: {str(e)}")
def main(): def main():
args = parse_arguments() args = parse_arguments()
channels_file = args.channels channels_file = args.channels
@@ -85,24 +112,25 @@ def main():
dead_channels_log_file = args.dead_channels_log dead_channels_log_file = args.dead_channels_log
# Clear previous log file # Clear previous log file
with open(dead_channels_log_file, 'w') as f: with open(dead_channels_log_file, "w") as f:
f.write(f'Log created on {datetime.now()}\n') f.write(f"Log created on {datetime.now()}\n")
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
filename=dead_channels_log_file, filename=dead_channels_log_file,
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(message)s', format="%(asctime)s - %(message)s",
datefmt='%Y-%m-%d %H:%M:%S' datefmt="%Y-%m-%d %H:%M:%S",
) )
# Create playlist # Create playlist
create_playlist(channels_file, output_file) create_playlist(channels_file, output_file)
#upload playlist to server # upload playlist to server
upload_playlist(output_file) upload_playlist(output_file)
print("Playlist creation completed!") print("Playlist creation completed!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,17 +1,18 @@
from fastapi.concurrency import asynccontextmanager
from app.routers import channels, auth, playlist, priorities
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.concurrency import asynccontextmanager
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from app.routers import auth, channels, playlist, priorities
from app.utils.database import init_db from app.utils.database import init_db
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Initialize database tables on startup # Initialize database tables on startup
init_db() init_db()
yield yield
app = FastAPI( app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
title="IPTV Updater API", title="IPTV Updater API",
@@ -19,6 +20,7 @@ app = FastAPI(
version="1.0.0", version="1.0.0",
) )
def custom_openapi(): def custom_openapi():
if app.openapi_schema: if app.openapi_schema:
return app.openapi_schema return app.openapi_schema
@@ -40,11 +42,7 @@ def custom_openapi():
# Add security scheme component # Add security scheme component
openapi_schema["components"]["securitySchemes"] = { openapi_schema["components"]["securitySchemes"] = {
"Bearer": { "Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT"
}
} }
# Add global security requirement # Add global security requirement
@@ -56,12 +54,15 @@ def custom_openapi():
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema
app.openapi = custom_openapi app.openapi = custom_openapi
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "IPTV Updater API"} return {"message": "IPTV Updater API"}
# Include routers # Include routers
app.include_router(auth.router) app.include_router(auth.router)
app.include_router(channels.router) app.include_router(channels.router)

View File

@@ -1,4 +1,19 @@
from .db import Base, ChannelDB, ChannelURL from .db import Base, ChannelDB, ChannelURL
from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse from .schemas import (
ChannelCreate,
ChannelResponse,
ChannelUpdate,
ChannelURLCreate,
ChannelURLResponse,
)
__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"] __all__ = [
"Base",
"ChannelDB",
"ChannelCreate",
"ChannelUpdate",
"ChannelResponse",
"ChannelURL",
"ChannelURLCreate",
"ChannelURLResponse",
]

View File

@@ -1,20 +1,26 @@
from typing import List, Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class SigninRequest(BaseModel): class SigninRequest(BaseModel):
"""Request model for the signin endpoint.""" """Request model for the signin endpoint."""
username: str = Field(..., description="The user's username") username: str = Field(..., description="The user's username")
password: str = Field(..., description="The user's password") password: str = Field(..., description="The user's password")
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
"""Response model for successful authentication.""" """Response model for successful authentication."""
access_token: str = Field(..., description="Access JWT token from Cognito") access_token: str = Field(..., description="Access JWT token from Cognito")
id_token: str = Field(..., description="ID JWT token from Cognito") id_token: str = Field(..., description="ID JWT token from Cognito")
refresh_token: Optional[str] = Field( refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
None, description="Refresh token from Cognito")
token_type: str = Field(..., description="Type of the token returned") token_type: str = Field(..., description="Type of the token returned")
class CognitoUser(BaseModel): class CognitoUser(BaseModel):
"""Model representing the user returned from token verification.""" """Model representing the user returned from token verification."""
username: str username: str
roles: List[str] roles: list[str]

View File

@@ -1,21 +1,33 @@
from datetime import datetime, timezone
import uuid import uuid
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer from datetime import datetime, timezone
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm import relationship
Base = declarative_base() Base = declarative_base()
class Priority(Base): class Priority(Base):
"""SQLAlchemy model for channel URL priorities""" """SQLAlchemy model for channel URL priorities"""
__tablename__ = "priorities" __tablename__ = "priorities"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
description = Column(String, nullable=False) description = Column(String, nullable=False)
class ChannelDB(Base): class ChannelDB(Base):
"""SQLAlchemy model for IPTV channels""" """SQLAlchemy model for IPTV channels"""
__tablename__ = "channels" __tablename__ = "channels"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -25,26 +37,42 @@ class ChannelDB(Base):
tvg_name = Column(String) tvg_name = Column(String)
__table_args__ = ( __table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'), UniqueConstraint("group_title", "name", name="uix_group_title_name"),
) )
tvg_logo = Column(String) tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationship with ChannelURL # Relationship with ChannelURL
urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan") urls = relationship(
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
class ChannelURL(Base): class ChannelURL(Base):
"""SQLAlchemy model for channel URLs""" """SQLAlchemy model for channel URLs"""
__tablename__ = "channels_urls" __tablename__ = "channels_urls"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) channel_id = Column(
UUID(as_uuid=True),
ForeignKey("channels.id", ondelete="CASCADE"),
nullable=False,
)
url = Column(String, nullable=False) url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False) in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships # Relationships
channel = relationship("ChannelDB", back_populates="urls") channel = relationship("ChannelDB", back_populates="urls")

View File

@@ -1,30 +1,43 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class PriorityBase(BaseModel): class PriorityBase(BaseModel):
"""Base Pydantic model for priorities""" """Base Pydantic model for priorities"""
id: int id: int
description: str description: str
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class PriorityCreate(PriorityBase): class PriorityCreate(PriorityBase):
"""Pydantic model for creating priorities""" """Pydantic model for creating priorities"""
pass pass
class PriorityResponse(PriorityBase): class PriorityResponse(PriorityBase):
"""Pydantic model for priority responses""" """Pydantic model for priority responses"""
pass pass
class ChannelURLCreate(BaseModel): class ChannelURLCreate(BaseModel):
"""Pydantic model for creating channel URLs""" """Pydantic model for creating channel URLs"""
url: str url: str
priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range priority_id: int = Field(
default=100, ge=100, le=300
) # Default to High, validate range
class ChannelURLBase(ChannelURLCreate): class ChannelURLBase(ChannelURLCreate):
"""Base Pydantic model for channel URL responses""" """Base Pydantic model for channel URL responses"""
id: UUID id: UUID
in_use: bool in_use: bool
created_at: datetime created_at: datetime
@@ -33,42 +46,52 @@ class ChannelURLBase(ChannelURLCreate):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class ChannelURLResponse(ChannelURLBase): class ChannelURLResponse(ChannelURLBase):
"""Pydantic model for channel URL responses""" """Pydantic model for channel URL responses"""
pass pass
class ChannelCreate(BaseModel): class ChannelCreate(BaseModel):
"""Pydantic model for creating channels""" """Pydantic model for creating channels"""
urls: List[ChannelURLCreate] # List of URL objects with priority
urls: list[ChannelURLCreate] # List of URL objects with priority
name: str name: str
group_title: str group_title: str
tvg_id: str tvg_id: str
tvg_logo: str tvg_logo: str
tvg_name: str tvg_name: str
class ChannelURLUpdate(BaseModel): class ChannelURLUpdate(BaseModel):
"""Pydantic model for updating channel URLs""" """Pydantic model for updating channel URLs"""
url: Optional[str] = None url: Optional[str] = None
in_use: Optional[bool] = None in_use: Optional[bool] = None
priority_id: Optional[int] = Field(default=None, ge=100, le=300) priority_id: Optional[int] = Field(default=None, ge=100, le=300)
class ChannelUpdate(BaseModel): class ChannelUpdate(BaseModel):
"""Pydantic model for updating channels (all fields optional)""" """Pydantic model for updating channels (all fields optional)"""
name: Optional[str] = Field(None, min_length=1) name: Optional[str] = Field(None, min_length=1)
group_title: Optional[str] = Field(None, min_length=1) group_title: Optional[str] = Field(None, min_length=1)
tvg_id: Optional[str] = Field(None, min_length=1) tvg_id: Optional[str] = Field(None, min_length=1)
tvg_logo: Optional[str] = None tvg_logo: Optional[str] = None
tvg_name: Optional[str] = Field(None, min_length=1) tvg_name: Optional[str] = Field(None, min_length=1)
class ChannelResponse(BaseModel): class ChannelResponse(BaseModel):
"""Pydantic model for channel responses""" """Pydantic model for channel responses"""
id: UUID id: UUID
name: str name: str
group_title: str group_title: str
tvg_id: str tvg_id: str
tvg_logo: str tvg_logo: str
tvg_name: str tvg_name: str
urls: List[ChannelURLResponse] # List of URL objects without channel_id urls: list[ChannelURLResponse] # List of URL objects without channel_id
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime

View File

@@ -1,16 +1,16 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.auth.cognito import initiate_auth from app.auth.cognito import initiate_auth
from app.models.auth import SigninRequest, TokenResponse from app.models.auth import SigninRequest, TokenResponse
router = APIRouter( router = APIRouter(prefix="/auth", tags=["authentication"])
prefix="/auth",
tags=["authentication"]
)
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint") @router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest): def signin(credentials: SigninRequest):
""" """
Sign-in endpoint to authenticate the user with AWS Cognito using username and password. Sign-in endpoint to authenticate the user with AWS Cognito
using username and password.
On success, returns JWT tokens (access_token, id_token, refresh_token). On success, returns JWT tokens (access_token, id_token, refresh_token).
""" """
auth_result = initiate_auth(credentials.username, credentials.password) auth_result = initiate_auth(credentials.username, credentials.password)

View File

@@ -1,52 +1,54 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from uuid import UUID from uuid import UUID
from sqlalchemy import and_
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import ( from app.models import (
ChannelDB,
ChannelURL,
ChannelCreate, ChannelCreate,
ChannelUpdate, ChannelDB,
ChannelResponse, ChannelResponse,
ChannelUpdate,
ChannelURL,
ChannelURLCreate, ChannelURLCreate,
ChannelURLResponse, ChannelURLResponse,
) )
from app.models.auth import CognitoUser
from app.models.schemas import ChannelURLUpdate from app.models.schemas import ChannelURLUpdate
from app.utils.database import get_db from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter( router = APIRouter(prefix="/channels", tags=["channels"])
prefix="/channels",
tags=["channels"]
)
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin") @require_roles("admin")
def create_channel( def create_channel(
channel: ChannelCreate, channel: ChannelCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Create a new channel""" """Create a new channel"""
# Check for duplicate channel (same group_title + name) # Check for duplicate channel (same group_title + name)
existing_channel = db.query(ChannelDB).filter( existing_channel = (
db.query(ChannelDB)
.filter(
and_( and_(
ChannelDB.group_title == channel.group_title, ChannelDB.group_title == channel.group_title,
ChannelDB.name == channel.name ChannelDB.name == channel.name,
)
)
.first()
) )
).first()
if existing_channel: if existing_channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists" detail="Channel with same group_title and name already exists",
) )
# Create channel without URLs first # Create channel without URLs first
channel_data = channel.model_dump(exclude={'urls'}) channel_data = channel.model_dump(exclude={"urls"})
urls = channel.urls urls = channel.urls
db_channel = ChannelDB(**channel_data) db_channel = ChannelDB(**channel_data)
db.add(db_channel) db.add(db_channel)
@@ -59,7 +61,7 @@ def create_channel(
channel_id=db_channel.id, channel_id=db_channel.id,
url=url.url, url=url.url,
priority_id=url.priority_id, priority_id=url.priority_id,
in_use=False in_use=False,
) )
db.add(db_url) db.add(db_url)
@@ -67,53 +69,58 @@ def create_channel(
db.refresh(db_channel) db.refresh(db_channel)
return db_channel return db_channel
@router.get("/{channel_id}", response_model=ChannelResponse) @router.get("/{channel_id}", response_model=ChannelResponse)
def get_channel( def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
channel_id: UUID,
db: Session = Depends(get_db)
):
"""Get a channel by id""" """Get a channel by id"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
return channel return channel
@router.put("/{channel_id}", response_model=ChannelResponse) @router.put("/{channel_id}", response_model=ChannelResponse)
@require_roles("admin") @require_roles("admin")
def update_channel( def update_channel(
channel_id: UUID, channel_id: UUID,
channel: ChannelUpdate, channel: ChannelUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Update a channel""" """Update a channel"""
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not db_channel: if not db_channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
# Only check for duplicates if name or group_title are being updated # Only check for duplicates if name or group_title are being updated
if channel.name is not None or channel.group_title is not None: if channel.name is not None or channel.group_title is not None:
name = channel.name if channel.name is not None else db_channel.name name = channel.name if channel.name is not None else db_channel.name
group_title = channel.group_title if channel.group_title is not None else db_channel.group_title group_title = (
channel.group_title
if channel.group_title is not None
else db_channel.group_title
)
existing_channel = db.query(ChannelDB).filter( existing_channel = (
db.query(ChannelDB)
.filter(
and_( and_(
ChannelDB.group_title == group_title, ChannelDB.group_title == group_title,
ChannelDB.name == name, ChannelDB.name == name,
ChannelDB.id != channel_id ChannelDB.id != channel_id,
)
)
.first()
) )
).first()
if existing_channel: if existing_channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists" detail="Channel with same group_title and name already exists",
) )
# Update only provided fields # Update only provided fields
@@ -125,64 +132,71 @@ def update_channel(
db.refresh(db_channel) db.refresh(db_channel)
return db_channel return db_channel
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin") @require_roles("admin")
def delete_channel( def delete_channel(
channel_id: UUID, channel_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Delete a channel""" """Delete a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
db.delete(channel) db.delete(channel)
db.commit() db.commit()
return None return None
@router.get("/", response_model=List[ChannelResponse])
@router.get("/", response_model=list[ChannelResponse])
@require_roles("admin") @require_roles("admin")
def list_channels( def list_channels(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""List all channels with pagination""" """List all channels with pagination"""
return db.query(ChannelDB).offset(skip).limit(limit).all() return db.query(ChannelDB).offset(skip).limit(limit).all()
# URL Management Endpoints # URL Management Endpoints
@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/{channel_id}/urls",
response_model=ChannelURLResponse,
status_code=status.HTTP_201_CREATED,
)
@require_roles("admin") @require_roles("admin")
def add_channel_url( def add_channel_url(
channel_id: UUID, channel_id: UUID,
url: ChannelURLCreate, url: ChannelURLCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Add a new URL to a channel""" """Add a new URL to a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
db_url = ChannelURL( db_url = ChannelURL(
channel_id=channel_id, channel_id=channel_id,
url=url.url, url=url.url,
priority_id=url.priority_id, priority_id=url.priority_id,
in_use=False # Default to not in use in_use=False, # Default to not in use
) )
db.add(db_url) db.add(db_url)
db.commit() db.commit()
db.refresh(db_url) db.refresh(db_url)
return db_url return db_url
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse) @router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
@require_roles("admin") @require_roles("admin")
def update_channel_url( def update_channel_url(
@@ -190,20 +204,18 @@ def update_channel_url(
url_id: UUID, url_id: UUID,
url_update: ChannelURLUpdate, url_update: ChannelURLUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Update a channel URL (url, in_use, or priority_id)""" """Update a channel URL (url, in_use, or priority_id)"""
db_url = db.query(ChannelURL).filter( db_url = (
and_( db.query(ChannelURL)
ChannelURL.id == url_id, .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
ChannelURL.channel_id == channel_id .first()
) )
).first()
if not db_url: if not db_url:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
detail="URL not found"
) )
if url_update.url is not None: if url_update.url is not None:
@@ -217,45 +229,44 @@ def update_channel_url(
db.refresh(db_url) db.refresh(db_url)
return db_url return db_url
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin") @require_roles("admin")
def delete_channel_url( def delete_channel_url(
channel_id: UUID, channel_id: UUID,
url_id: UUID, url_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Delete a URL from a channel""" """Delete a URL from a channel"""
url = db.query(ChannelURL).filter( url = (
and_( db.query(ChannelURL)
ChannelURL.id == url_id, .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
ChannelURL.channel_id == channel_id .first()
) )
).first()
if not url: if not url:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
detail="URL not found"
) )
db.delete(url) db.delete(url)
db.commit() db.commit()
return None return None
@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse])
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
@require_roles("admin") @require_roles("admin")
def list_channel_urls( def list_channel_urls(
channel_id: UUID, channel_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""List all URLs for a channel""" """List all URLs for a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all() return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()

View File

@@ -1,14 +1,12 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from app.auth.dependencies import get_current_user from app.auth.dependencies import get_current_user
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
router = APIRouter( router = APIRouter(prefix="/playlist", tags=["playlist"])
prefix="/playlist",
tags=["playlist"]
)
@router.get("/protected",
summary="Protected endpoint for authenticated users") @router.get("/protected", summary="Protected endpoint for authenticated users")
async def protected_route(user: CognitoUser = Depends(get_current_user)): async def protected_route(user: CognitoUser = Depends(get_current_user)):
""" """
Protected endpoint that requires authentication for all users. Protected endpoint that requires authentication for all users.

View File

@@ -1,25 +1,22 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import select, delete
from typing import List
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
from app.models.db import Priority from app.models.db import Priority
from app.models.schemas import PriorityCreate, PriorityResponse from app.models.schemas import PriorityCreate, PriorityResponse
from app.utils.database import get_db from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter( router = APIRouter(prefix="/priorities", tags=["priorities"])
prefix="/priorities",
tags=["priorities"]
)
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin") @require_roles("admin")
def create_priority( def create_priority(
priority: PriorityCreate, priority: PriorityCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Create a new priority""" """Create a new priority"""
# Check if priority with this ID already exists # Check if priority with this ID already exists
@@ -27,7 +24,7 @@ def create_priority(
if existing: if existing:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail=f"Priority with ID {priority.id} already exists" detail=f"Priority with ID {priority.id} already exists",
) )
db_priority = Priority(**priority.model_dump()) db_priority = Priority(**priority.model_dump())
@@ -36,37 +33,38 @@ def create_priority(
db.refresh(db_priority) db.refresh(db_priority)
return db_priority return db_priority
@router.get("/", response_model=List[PriorityResponse])
@router.get("/", response_model=list[PriorityResponse])
@require_roles("admin") @require_roles("admin")
def list_priorities( def list_priorities(
db: Session = Depends(get_db), db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user)
): ):
"""List all priorities""" """List all priorities"""
return db.query(Priority).all() return db.query(Priority).all()
@router.get("/{priority_id}", response_model=PriorityResponse) @router.get("/{priority_id}", response_model=PriorityResponse)
@require_roles("admin") @require_roles("admin")
def get_priority( def get_priority(
priority_id: int, priority_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Get a priority by id""" """Get a priority by id"""
priority = db.get(Priority, priority_id) priority = db.get(Priority, priority_id)
if not priority: if not priority:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
detail="Priority not found"
) )
return priority return priority
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin") @require_roles("admin")
def delete_priority( def delete_priority(
priority_id: int, priority_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Delete a priority (if not in use)""" """Delete a priority (if not in use)"""
from app.models.db import ChannelURL from app.models.db import ChannelURL
@@ -75,21 +73,18 @@ def delete_priority(
priority = db.get(Priority, priority_id) priority = db.get(Priority, priority_id)
if not priority: if not priority:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
detail="Priority not found"
) )
# Check if priority is in use # Check if priority is in use
in_use = db.scalar( in_use = db.scalar(
select(ChannelURL) select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
.where(ChannelURL.priority_id == priority_id)
.limit(1)
) )
if in_use: if in_use:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete priority that is in use by channel URLs" detail="Cannot delete priority that is in use by channel URLs",
) )
db.execute(delete(Priority).where(Priority.id == priority_id)) db.execute(delete(Priority).where(Priority.id == priority_id))

View File

@@ -2,11 +2,13 @@ import base64
import hashlib import hashlib
import hmac import hmac
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str: def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
""" """
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients. Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
""" """
msg = username + client_id msg = username + client_id
dig = hmac.new(client_secret.encode('utf-8'), dig = hmac.new(
msg.encode('utf-8'), hashlib.sha256).digest() client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
).digest()
return base64.b64encode(dig).decode() return base64.b64encode(dig).decode()

View File

@@ -1,32 +1,41 @@
import os
import argparse import argparse
import requests
import logging import logging
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError import os
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
class StreamValidator: class StreamValidator:
def __init__(self, timeout=10, user_agent=None): def __init__(self, timeout=10, user_agent=None):
self.timeout = timeout self.timeout = timeout
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({ self.session.headers.update(
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36' {
}) "User-Agent": user_agent
or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
)
}
)
def validate_stream(self, url): def validate_stream(self, url):
"""Validate a media stream URL with multiple fallback checks""" """Validate a media stream URL with multiple fallback checks"""
try: try:
headers = {'Range': 'bytes=0-1024'} headers = {"Range": "bytes=0-1024"}
with self.session.get( with self.session.get(
url, url,
headers=headers, headers=headers,
timeout=self.timeout, timeout=self.timeout,
stream=True, stream=True,
allow_redirects=True allow_redirects=True,
) as response: ) as response:
if response.status_code not in [200, 206]: if response.status_code not in [200, 206]:
return False, f"Invalid status code: {response.status_code}" return False, f"Invalid status code: {response.status_code}"
content_type = response.headers.get('Content-Type', '') content_type = response.headers.get("Content-Type", "")
if not self._is_valid_content_type(content_type): if not self._is_valid_content_type(content_type):
return False, f"Invalid content type: {content_type}" return False, f"Invalid content type: {content_type}"
@@ -49,10 +58,13 @@ class StreamValidator:
def _is_valid_content_type(self, content_type): def _is_valid_content_type(self, content_type):
valid_types = [ valid_types = [
'video/mp2t', 'application/vnd.apple.mpegurl', "video/mp2t",
'application/dash+xml', 'video/mp4', "application/vnd.apple.mpegurl",
'video/webm', 'application/octet-stream', "application/dash+xml",
'application/x-mpegURL' "video/mp4",
"video/webm",
"application/octet-stream",
"application/x-mpegURL",
] ]
return any(ct in content_type for ct in valid_types) return any(ct in content_type for ct in valid_types)
@@ -60,36 +72,34 @@ class StreamValidator:
"""Extract stream URLs from M3U playlist file""" """Extract stream URLs from M3U playlist file"""
urls = [] urls = []
try: try:
with open(file_path, 'r') as f: with open(file_path) as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if line and not line.startswith('#'): if line and not line.startswith("#"):
urls.append(line) urls.append(line)
except Exception as e: except Exception as e:
logging.error(f"Error reading playlist file: {str(e)}") logging.error(f"Error reading playlist file: {str(e)}")
raise raise
return urls return urls
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Validate streaming URLs from command line arguments or playlist files', description=(
formatter_class=argparse.ArgumentDefaultsHelpFormatter "Validate streaming URLs from command line arguments or playlist files"
),
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument(
'sources', "sources", nargs="+", help="List of URLs or file paths containing stream URLs"
nargs='+',
help='List of URLs or file paths containing stream URLs'
) )
parser.add_argument( parser.add_argument(
'--timeout', "--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
type=int,
default=20,
help='Timeout in seconds for stream checks'
) )
parser.add_argument( parser.add_argument(
'--output', "--output",
default='deadstreams.txt', default="deadstreams.txt",
help='Output file name for inactive streams' help="Output file name for inactive streams",
) )
args = parser.parse_args() args = parser.parse_args()
@@ -97,8 +107,8 @@ def main():
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()] handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
) )
validator = StreamValidator(timeout=args.timeout) validator = StreamValidator(timeout=args.timeout)
@@ -127,9 +137,10 @@ def main():
# Save dead streams to file # Save dead streams to file
if dead_streams: if dead_streams:
with open(args.output, 'w') as f: with open(args.output, "w") as f:
f.write('\n'.join(dead_streams)) f.write("\n".join(dead_streams))
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.") logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,11 +1,13 @@
import os import os
import boto3 import boto3
from app.models import Base
from .constants import AWS_REGION
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from functools import lru_cache
from app.models import Base
from .constants import AWS_REGION
def get_db_credentials(): def get_db_credentials():
"""Fetch and cache DB credentials from environment or SSM Parameter Store""" """Fetch and cache DB credentials from environment or SSM Parameter Store"""
@@ -15,24 +17,35 @@ def get_db_credentials():
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}" f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
) )
ssm = boto3.client('ssm', region_name=AWS_REGION) ssm = boto3.client("ssm", region_name=AWS_REGION)
try: try:
host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value'] host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value'] "Parameter"
password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value'] ]["Value"]
dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value'] user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
"Parameter"
]["Value"]
password = ssm.get_parameter(
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
)["Parameter"]["Value"]
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
"Parameter"
]["Value"]
return f"postgresql://{user}:{password}@{host}/{dbname}" return f"postgresql://{user}:{password}@{host}/{dbname}"
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}") raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
# Initialize engine and session maker # Initialize engine and session maker
engine = create_engine(get_db_credentials()) engine = create_engine(get_db_credentials())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db(): def init_db():
"""Initialize database by creating all tables""" """Initialize database by creating all tables"""
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
def get_db(): def get_db():
"""Dependency for getting database session""" """Dependency for getting database session"""
db = SessionLocal() db = SessionLocal()

View File

@@ -1,17 +1,14 @@
import os import os
from aws_cdk import (
Duration, from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
RemovalPolicy, from aws_cdk import aws_cognito as cognito
Stack, from aws_cdk import aws_ec2 as ec2
aws_ec2 as ec2, from aws_cdk import aws_iam as iam
aws_iam as iam, from aws_cdk import aws_rds as rds
aws_cognito as cognito, from aws_cdk import aws_ssm as ssm
aws_rds as rds,
aws_ssm as ssm,
CfnOutput
)
from constructs import Construct from constructs import Construct
class IptvUpdaterStack(Stack): class IptvUpdaterStack(Stack):
def __init__( def __init__(
self, self,
@@ -23,58 +20,50 @@ class IptvUpdaterStack(Stack):
ssh_public_key: str, ssh_public_key: str,
repo_url: str, repo_url: str,
letsencrypt_email: str, letsencrypt_email: str,
**kwargs **kwargs,
) -> None: ) -> None:
super().__init__(scope, construct_id, **kwargs) super().__init__(scope, construct_id, **kwargs)
# Create VPC # Create VPC
vpc = ec2.Vpc(self, "IptvUpdaterVPC", vpc = ec2.Vpc(
self,
"IptvUpdaterVPC",
max_azs=2, # Need at least 2 AZs for RDS subnet group max_azs=2, # Need at least 2 AZs for RDS subnet group
nat_gateways=0, # No NAT Gateway to stay in free tier nat_gateways=0, # No NAT Gateway to stay in free tier
subnet_configuration=[ subnet_configuration=[
ec2.SubnetConfiguration( ec2.SubnetConfiguration(
name="public", name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
subnet_type=ec2.SubnetType.PUBLIC,
cidr_mask=24
), ),
ec2.SubnetConfiguration( ec2.SubnetConfiguration(
name="private", name="private",
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
cidr_mask=24 cidr_mask=24,
) ),
] ],
) )
# Security Group # Security Group
security_group = ec2.SecurityGroup( security_group = ec2.SecurityGroup(
self, "IptvUpdaterSG", self, "IptvUpdaterSG", vpc=vpc, allow_all_outbound=True
vpc=vpc,
allow_all_outbound=True
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
ec2.Port.tcp(443),
"Allow HTTPS traffic"
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
ec2.Port.tcp(80),
"Allow HTTP traffic"
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
ec2.Port.tcp(22),
"Allow SSH traffic"
) )
# Allow PostgreSQL port for tunneling restricted to developer IP # Allow PostgreSQL port for tunneling restricted to developer IP
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
ec2.Port.tcp(5432), ec2.Port.tcp(5432),
"Allow PostgreSQL traffic for tunneling" "Allow PostgreSQL traffic for tunneling",
) )
# Key pair for IPTV Updater instance # Key pair for IPTV Updater instance
@@ -82,13 +71,14 @@ class IptvUpdaterStack(Stack):
self, self,
"IptvUpdaterKeyPair", "IptvUpdaterKeyPair",
key_pair_name="iptv-updater-key", key_pair_name="iptv-updater-key",
public_key_material=ssh_public_key public_key_material=ssh_public_key,
) )
# Create IAM role for EC2 # Create IAM role for EC2
role = iam.Role( role = iam.Role(
self, "IptvUpdaterRole", self,
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com") "IptvUpdaterRole",
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
) )
# Add SSM managed policy # Add SSM managed policy
@@ -99,37 +89,36 @@ class IptvUpdaterStack(Stack):
) )
# Add EC2 describe permissions # Add EC2 describe permissions
role.add_to_policy(iam.PolicyStatement( role.add_to_policy(
actions=["ec2:DescribeInstances"], iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
resources=["*"] )
))
# Add SSM SendCommand permissions # Add SSM SendCommand permissions
role.add_to_policy(iam.PolicyStatement( role.add_to_policy(
iam.PolicyStatement(
actions=["ssm:SendCommand"], actions=["ssm:SendCommand"],
resources=[ resources=[
f"arn:aws:ec2:{self.region}:{self.account}:instance/*", # Allow on all EC2 instances # Allow on all EC2 instances
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript" # Required for the RunShellScript document f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
] # Required for the RunShellScript document
)) f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
],
)
)
# Add Cognito permissions to instance role # Add Cognito permissions to instance role
role.add_managed_policy( role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name( iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
"AmazonCognitoReadOnly"
)
) )
# EC2 Instance # EC2 Instance
instance = ec2.Instance( instance = ec2.Instance(
self, "IptvUpdaterInstance", self,
"IptvUpdaterInstance",
vpc=vpc, vpc=vpc,
vpc_subnets=ec2.SubnetSelection( vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
subnet_type=ec2.SubnetType.PUBLIC
),
instance_type=ec2.InstanceType.of( instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T2, ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
ec2.InstanceSize.MICRO
), ),
machine_image=ec2.AmazonLinuxImage( machine_image=ec2.AmazonLinuxImage(
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023 generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
@@ -138,7 +127,7 @@ class IptvUpdaterStack(Stack):
key_pair=key_pair, key_pair=key_pair,
role=role, role=role,
# Option: 1: Enable auto-assign public IP (free tier compatible) # Option: 1: Enable auto-assign public IP (free tier compatible)
associate_public_ip_address=True associate_public_ip_address=True,
) )
# Option: 2: Create Elastic IP (not free tier compatible) # Option: 2: Create Elastic IP (not free tier compatible)
@@ -150,7 +139,8 @@ class IptvUpdaterStack(Stack):
# Add Cognito User Pool # Add Cognito User Pool
user_pool = cognito.UserPool( user_pool = cognito.UserPool(
self, "IptvUpdaterUserPool", self,
"IptvUpdaterUserPool",
user_pool_name="iptv-updater-users", user_pool_name="iptv-updater-users",
self_sign_up_enabled=False, # Only admins can create users self_sign_up_enabled=False, # Only admins can create users
password_policy=cognito.PasswordPolicy( password_policy=cognito.PasswordPolicy(
@@ -158,35 +148,31 @@ class IptvUpdaterStack(Stack):
require_lowercase=True, require_lowercase=True,
require_digits=True, require_digits=True,
require_symbols=True, require_symbols=True,
require_uppercase=True require_uppercase=True,
), ),
account_recovery=cognito.AccountRecovery.EMAIL_ONLY, account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
removal_policy=RemovalPolicy.DESTROY removal_policy=RemovalPolicy.DESTROY,
) )
# Add App Client with the correct callback URL # Add App Client with the correct callback URL
client = user_pool.add_client("IptvUpdaterClient", client = user_pool.add_client(
"IptvUpdaterClient",
access_token_validity=Duration.minutes(60), access_token_validity=Duration.minutes(60),
id_token_validity=Duration.minutes(60), id_token_validity=Duration.minutes(60),
refresh_token_validity=Duration.days(1), refresh_token_validity=Duration.days(1),
auth_flows=cognito.AuthFlow( auth_flows=cognito.AuthFlow(user_password=True),
user_password=True
),
o_auth=cognito.OAuthSettings( o_auth=cognito.OAuthSettings(
flows=cognito.OAuthFlows( flows=cognito.OAuthFlows(implicit_code_grant=True)
implicit_code_grant=True
)
), ),
prevent_user_existence_errors=True, prevent_user_existence_errors=True,
generate_secret=True, generate_secret=True,
enable_token_revocation=True enable_token_revocation=True,
) )
# Add domain for hosted UI # Add domain for hosted UI
domain = user_pool.add_domain("IptvUpdaterDomain", domain = user_pool.add_domain(
cognito_domain=cognito.CognitoDomainOptions( "IptvUpdaterDomain",
domain_prefix="iptv-updater" cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-updater"),
)
) )
# Read the userdata script with proper path resolution # Read the userdata script with proper path resolution
@@ -203,39 +189,49 @@ class IptvUpdaterStack(Stack):
f'export FREEDNS_Password="{freedns_password}"', f'export FREEDNS_Password="{freedns_password}"',
f'export DOMAIN_NAME="{domain_name}"', f'export DOMAIN_NAME="{domain_name}"',
f'export REPO_URL="{repo_url}"', f'export REPO_URL="{repo_url}"',
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"' f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
) )
# Adds one or more commands to the userdata object. # Adds one or more commands to the userdata object.
userdata.add_commands( userdata.add_commands(
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment', (
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment', f'echo "COGNITO_USER_POOL_ID='
f'echo "COGNITO_CLIENT_SECRET={client.user_pool_client_secret.to_string()}" >> /etc/environment', f'{user_pool.user_pool_id}" >> /etc/environment'
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment' ),
(
f'echo "COGNITO_CLIENT_ID='
f'{client.user_pool_client_id}" >> /etc/environment'
),
(
f'echo "COGNITO_CLIENT_SECRET='
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
),
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
) )
userdata.add_commands(str(userdata_file, 'utf-8')) userdata.add_commands(str(userdata_file, "utf-8"))
# Create RDS Security Group # Create RDS Security Group
rds_sg = ec2.SecurityGroup( rds_sg = ec2.SecurityGroup(
self, "RdsSecurityGroup", self,
"RdsSecurityGroup",
vpc=vpc, vpc=vpc,
description="Security group for RDS PostgreSQL" description="Security group for RDS PostgreSQL",
) )
rds_sg.add_ingress_rule( rds_sg.add_ingress_rule(
security_group, security_group,
ec2.Port.tcp(5432), ec2.Port.tcp(5432),
"Allow PostgreSQL access from EC2 instance" "Allow PostgreSQL access from EC2 instance",
) )
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro) # Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
db = rds.DatabaseInstance( db = rds.DatabaseInstance(
self, "IptvUpdaterDB", self,
"IptvUpdaterDB",
engine=rds.DatabaseInstanceEngine.postgres( engine=rds.DatabaseInstanceEngine.postgres(
version=rds.PostgresEngineVersion.VER_13 version=rds.PostgresEngineVersion.VER_13
), ),
instance_type=ec2.InstanceType.of( instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T3, ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
ec2.InstanceSize.MICRO
), ),
vpc=vpc, vpc=vpc,
vpc_subnets=ec2.SubnetSelection( vpc_subnets=ec2.SubnetSelection(
@@ -247,39 +243,43 @@ class IptvUpdaterStack(Stack):
database_name="iptv_updater", database_name="iptv_updater",
removal_policy=RemovalPolicy.DESTROY, removal_policy=RemovalPolicy.DESTROY,
deletion_protection=False, deletion_protection=False,
publicly_accessible=False # Avoid public IPv4 charges publicly_accessible=False, # Avoid public IPv4 charges
) )
# Add RDS permissions to instance role # Add RDS permissions to instance role
role.add_managed_policy( role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name( iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
"AmazonRDSFullAccess"
)
) )
# Store DB connection info in SSM Parameter Store # Store DB connection info in SSM Parameter Store
ssm.StringParameter(self, "DBHostParam", ssm.StringParameter(
self,
"DBHostParam",
parameter_name="/iptv-updater/DB_HOST", parameter_name="/iptv-updater/DB_HOST",
string_value=db.db_instance_endpoint_address string_value=db.db_instance_endpoint_address,
) )
ssm.StringParameter(self, "DBNameParam", ssm.StringParameter(
self,
"DBNameParam",
parameter_name="/iptv-updater/DB_NAME", parameter_name="/iptv-updater/DB_NAME",
string_value="iptv_updater" string_value="iptv_updater",
) )
ssm.StringParameter(self, "DBUserParam", ssm.StringParameter(
self,
"DBUserParam",
parameter_name="/iptv-updater/DB_USER", parameter_name="/iptv-updater/DB_USER",
string_value=db.secret.secret_value_from_json("username").to_string() string_value=db.secret.secret_value_from_json("username").to_string(),
) )
ssm.StringParameter(self, "DBPassParam", ssm.StringParameter(
self,
"DBPassParam",
parameter_name="/iptv-updater/DB_PASSWORD", parameter_name="/iptv-updater/DB_PASSWORD",
string_value=db.secret.secret_value_from_json("password").to_string() string_value=db.secret.secret_value_from_json("password").to_string(),
) )
# Add SSM read permissions to instance role # Add SSM read permissions to instance role
role.add_managed_policy( role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name( iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
"AmazonSSMReadOnlyAccess"
)
) )
# Update instance with userdata # Update instance with userdata
@@ -293,6 +293,8 @@ class IptvUpdaterStack(Stack):
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip) # CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id) CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id) CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
CfnOutput(self, "CognitoDomainUrl", CfnOutput(
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com" self,
"CognitoDomainUrl",
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
) )

View File

@@ -1,5 +1,10 @@
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 88
exclude = [
"alembic/versions/*.py", # Auto-generated Alembic migration files
]
[tool.ruff.lint]
select = [ select = [
"E", # pycodestyle errors "E", # pycodestyle errors
"F", # pyflakes "F", # pyflakes
@@ -9,7 +14,13 @@ select = [
] ]
ignore = [] ignore = []
[tool.ruff.isort] [tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"F811", # redefinition of unused name
"F401", # unused import
]
[tool.ruff.lint.isort]
known-first-party = ["app"] known-first-party = ["app"]
[tool.ruff.format] [tool.ruff.format]

View File

@@ -4,6 +4,10 @@
npm install -g aws-cdk npm install -g aws-cdk
python3 -m pip install -r requirements.txt python3 -m pip install -r requirements.txt
# Install and configure pre-commit hooks
pre-commit install
pre-commit install-hooks
# Initialize and run database migrations # Initialize and run database migrations
alembic upgrade head alembic upgrade head

View File

@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest import pytest
from unittest.mock import patch, MagicMock
from fastapi import HTTPException, status from fastapi import HTTPException, status
# Test constants # Test constants
@@ -7,12 +8,15 @@ TEST_CLIENT_ID = "test_client_id"
TEST_CLIENT_SECRET = "test_client_secret" TEST_CLIENT_SECRET = "test_client_secret"
# Patch constants before importing the module # Patch constants before importing the module
with patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), \ with (
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET): patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
from app.auth.cognito import initiate_auth, get_user_from_token patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
):
from app.auth.cognito import get_user_from_token, initiate_auth
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
from app.utils.constants import USER_ROLE_ATTRIBUTE from app.utils.constants import USER_ROLE_ATTRIBUTE
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_cognito_client(): def mock_cognito_client():
with patch("app.auth.cognito.cognito_client") as mock_client: with patch("app.auth.cognito.cognito_client") as mock_client:
@@ -26,13 +30,14 @@ def mock_cognito_client():
) )
yield mock_client yield mock_client
def test_initiate_auth_success(mock_cognito_client): def test_initiate_auth_success(mock_cognito_client):
# Mock successful authentication response # Mock successful authentication response
mock_cognito_client.initiate_auth.return_value = { mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": { "AuthenticationResult": {
"AccessToken": "mock_access_token", "AccessToken": "mock_access_token",
"IdToken": "mock_id_token", "IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token" "RefreshToken": "mock_refresh_token",
} }
} }
@@ -40,27 +45,35 @@ def test_initiate_auth_success(mock_cognito_client):
assert result == { assert result == {
"AccessToken": "mock_access_token", "AccessToken": "mock_access_token",
"IdToken": "mock_id_token", "IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token" "RefreshToken": "mock_refresh_token",
} }
def test_initiate_auth_with_secret_hash(mock_cognito_client): def test_initiate_auth_with_secret_hash(mock_cognito_client):
with patch("app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash") as mock_hash: with patch(
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
) as mock_hash:
mock_cognito_client.initiate_auth.return_value = { mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": {"AccessToken": "token"} "AuthenticationResult": {"AccessToken": "token"}
} }
result = initiate_auth("test_user", "test_pass") initiate_auth("test_user", "test_pass")
# Verify calculate_secret_hash was called # Verify calculate_secret_hash was called
mock_hash.assert_called_once_with("test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET) mock_hash.assert_called_once_with(
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
)
# Verify SECRET_HASH was included in auth params # Verify SECRET_HASH was included in auth params
call_args = mock_cognito_client.initiate_auth.call_args[1] call_args = mock_cognito_client.initiate_auth.call_args[1]
assert "SECRET_HASH" in call_args["AuthParameters"] assert "SECRET_HASH" in call_args["AuthParameters"]
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash" assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
def test_initiate_auth_not_authorized(mock_cognito_client): def test_initiate_auth_not_authorized(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.NotAuthorizedException() mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
initiate_auth("invalid_user", "wrong_pass") initiate_auth("invalid_user", "wrong_pass")
@@ -68,8 +81,11 @@ def test_initiate_auth_not_authorized(mock_cognito_client):
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid username or password" assert exc_info.value.detail == "Invalid username or password"
def test_initiate_auth_user_not_found(mock_cognito_client): def test_initiate_auth_user_not_found(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.UserNotFoundException() mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
initiate_auth("nonexistent_user", "any_pass") initiate_auth("nonexistent_user", "any_pass")
@@ -77,6 +93,7 @@ def test_initiate_auth_user_not_found(mock_cognito_client):
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
assert exc_info.value.detail == "User not found" assert exc_info.value.detail == "User not found"
def test_initiate_auth_generic_error(mock_cognito_client): def test_initiate_auth_generic_error(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = Exception("Some error") mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
@@ -86,13 +103,14 @@ def test_initiate_auth_generic_error(mock_cognito_client):
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "An error occurred during authentication" in exc_info.value.detail assert "An error occurred during authentication" in exc_info.value.detail
def test_get_user_from_token_success(mock_cognito_client): def test_get_user_from_token_success(mock_cognito_client):
mock_response = { mock_response = {
"Username": "test_user", "Username": "test_user",
"UserAttributes": [ "UserAttributes": [
{"Name": "sub", "Value": "123"}, {"Name": "sub", "Value": "123"},
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"} {"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
] ],
} }
mock_cognito_client.get_user.return_value = mock_response mock_cognito_client.get_user.return_value = mock_response
@@ -102,10 +120,11 @@ def test_get_user_from_token_success(mock_cognito_client):
assert result.username == "test_user" assert result.username == "test_user"
assert set(result.roles) == {"admin", "user"} assert set(result.roles) == {"admin", "user"}
def test_get_user_from_token_no_roles(mock_cognito_client): def test_get_user_from_token_no_roles(mock_cognito_client):
mock_response = { mock_response = {
"Username": "test_user", "Username": "test_user",
"UserAttributes": [{"Name": "sub", "Value": "123"}] "UserAttributes": [{"Name": "sub", "Value": "123"}],
} }
mock_cognito_client.get_user.return_value = mock_response mock_cognito_client.get_user.return_value = mock_response
@@ -115,8 +134,11 @@ def test_get_user_from_token_no_roles(mock_cognito_client):
assert result.username == "test_user" assert result.username == "test_user"
assert result.roles == [] assert result.roles == []
def test_get_user_from_token_invalid_token(mock_cognito_client): def test_get_user_from_token_invalid_token(mock_cognito_client):
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.NotAuthorizedException() mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
get_user_from_token("invalid_token") get_user_from_token("invalid_token")
@@ -124,8 +146,11 @@ def test_get_user_from_token_invalid_token(mock_cognito_client):
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid or expired token." assert exc_info.value.detail == "Invalid or expired token."
def test_get_user_from_token_user_not_found(mock_cognito_client): def test_get_user_from_token_user_not_found(mock_cognito_client):
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.UserNotFoundException() mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
get_user_from_token("token_for_nonexistent_user") get_user_from_token("token_for_nonexistent_user")
@@ -133,6 +158,7 @@ def test_get_user_from_token_user_not_found(mock_cognito_client):
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "User not found or invalid token." assert exc_info.value.detail == "User not found or invalid token."
def test_get_user_from_token_generic_error(mock_cognito_client): def test_get_user_from_token_generic_error(mock_cognito_client):
mock_cognito_client.get_user.side_effect = Exception("Some error") mock_cognito_client.get_user.side_effect = Exception("Some error")

View File

@@ -1,9 +1,11 @@
import os
import pytest
import importlib import importlib
import os
import pytest
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi import HTTPException, Depends, Request
from app.auth.dependencies import get_current_user, require_roles, oauth2_scheme from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
# Mock user for testing # Mock user for testing
@@ -11,24 +13,30 @@ TEST_USER = CognitoUser(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
roles=["admin", "user"], roles=["admin", "user"],
groups=["test_group"] groups=["test_group"],
) )
# Mock the underlying get_user_from_token function # Mock the underlying get_user_from_token function
def mock_get_user_from_token(token: str) -> CognitoUser: def mock_get_user_from_token(token: str) -> CognitoUser:
if token == "valid_token": if token == "valid_token":
return TEST_USER return TEST_USER
raise HTTPException(status_code=401, detail="Invalid token") raise HTTPException(status_code=401, detail="Invalid token")
# Mock endpoint for testing the require_roles decorator # Mock endpoint for testing the require_roles decorator
@require_roles("admin") @require_roles("admin")
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)): async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
return {"message": "Success", "user": user.username} return {"message": "Success", "user": user.username}
# Patch the get_user_from_token function for testing # Patch the get_user_from_token function for testing
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_auth(monkeypatch): def mock_auth(monkeypatch):
monkeypatch.setattr("app.auth.dependencies.get_user_from_token", mock_get_user_from_token) monkeypatch.setattr(
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
)
# Test get_current_user dependency # Test get_current_user dependency
def test_get_current_user_success(): def test_get_current_user_success():
@@ -37,54 +45,53 @@ def test_get_current_user_success():
assert user.username == "testuser" assert user.username == "testuser"
assert user.roles == ["admin", "user"] assert user.roles == ["admin", "user"]
def test_get_current_user_invalid_token(): def test_get_current_user_invalid_token():
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
get_current_user("invalid_token") get_current_user("invalid_token")
assert exc.value.status_code == 401 assert exc.value.status_code == 401
# Test require_roles decorator # Test require_roles decorator
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_success(): async def test_require_roles_success():
# Create test user with required role # Create test user with required role
user = CognitoUser( user = CognitoUser(
username="testuser", username="testuser", email="test@example.com", roles=["admin"], groups=[]
email="test@example.com",
roles=["admin"],
groups=[]
) )
result = await mock_protected_endpoint(user=user) result = await mock_protected_endpoint(user=user)
assert result == {"message": "Success", "user": "testuser"} assert result == {"message": "Success", "user": "testuser"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_missing_role(): async def test_require_roles_missing_role():
# Create test user without required role # Create test user without required role
user = CognitoUser( user = CognitoUser(
username="testuser", username="testuser", email="test@example.com", roles=["user"], groups=[]
email="test@example.com",
roles=["user"],
groups=[]
) )
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user) await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
assert exc.value.detail == "You do not have the required roles to access this endpoint." assert (
exc.value.detail
== "You do not have the required roles to access this endpoint."
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_no_roles(): async def test_require_roles_no_roles():
# Create test user with no roles # Create test user with no roles
user = CognitoUser( user = CognitoUser(
username="testuser", username="testuser", email="test@example.com", roles=[], groups=[]
email="test@example.com",
roles=[],
groups=[]
) )
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user) await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_multiple_roles(): async def test_require_roles_multiple_roles():
# Test requiring multiple roles # Test requiring multiple roles
@@ -97,7 +104,7 @@ async def test_require_roles_multiple_roles():
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
roles=["admin", "super_user", "user"], roles=["admin", "super_user", "user"],
groups=[] groups=[],
) )
result = await mock_multi_role_endpoint(user=user_with_roles) result = await mock_multi_role_endpoint(user=user_with_roles)
assert result == {"message": "Success"} assert result == {"message": "Success"}
@@ -107,27 +114,30 @@ async def test_require_roles_multiple_roles():
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
roles=["admin", "user"], roles=["admin", "user"],
groups=[] groups=[],
) )
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await mock_multi_role_endpoint(user=user_missing_role) await mock_multi_role_endpoint(user=user_missing_role)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth2_scheme_configuration(): async def test_oauth2_scheme_configuration():
# Verify that we have a properly configured OAuth2PasswordBearer instance # Verify that we have a properly configured OAuth2PasswordBearer instance
assert isinstance(oauth2_scheme, OAuth2PasswordBearer) assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
# Create a mock request with no Authorization header # Create a mock request with no Authorization header
mock_request = Request(scope={ mock_request = Request(
'type': 'http', scope={
'headers': [], "type": "http",
'method': 'GET', "headers": [],
'scheme': 'http', "method": "GET",
'path': '/', "scheme": "http",
'query_string': b'', "path": "/",
'client': ('127.0.0.1', 8000) "query_string": b"",
}) "client": ("127.0.0.1", 8000),
}
)
# Test that the scheme raises 401 when no token is provided # Test that the scheme raises 401 when no token is provided
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
@@ -135,6 +145,7 @@ async def test_oauth2_scheme_configuration():
assert exc.value.status_code == 401 assert exc.value.status_code == 401
assert exc.value.detail == "Not authenticated" assert exc.value.detail == "Not authenticated"
def test_mock_auth_import(monkeypatch): def test_mock_auth_import(monkeypatch):
# Save original env var value # Save original env var value
original_value = os.environ.get("MOCK_AUTH") original_value = os.environ.get("MOCK_AUTH")
@@ -145,11 +156,13 @@ def test_mock_auth_import(monkeypatch):
# Reload the dependencies module to trigger the import condition # Reload the dependencies module to trigger the import condition
import app.auth.dependencies import app.auth.dependencies
importlib.reload(app.auth.dependencies) importlib.reload(app.auth.dependencies)
# Verify that mock_get_user_from_token was imported # Verify that mock_get_user_from_token was imported
from app.auth.dependencies import get_user_from_token from app.auth.dependencies import get_user_from_token
assert get_user_from_token.__module__ == 'app.auth.mock_auth'
assert get_user_from_token.__module__ == "app.auth.mock_auth"
finally: finally:
# Restore original env var # Restore original env var

View File

@@ -1,8 +1,10 @@
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
def test_mock_get_user_from_token_success(): def test_mock_get_user_from_token_success():
"""Test successful token validation returns expected user""" """Test successful token validation returns expected user"""
user = mock_get_user_from_token("testuser") user = mock_get_user_from_token("testuser")
@@ -10,6 +12,7 @@ def test_mock_get_user_from_token_success():
assert user.username == "testuser" assert user.username == "testuser"
assert user.roles == ["admin"] assert user.roles == ["admin"]
def test_mock_get_user_from_token_invalid(): def test_mock_get_user_from_token_invalid():
"""Test invalid token raises expected exception""" """Test invalid token raises expected exception"""
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -18,6 +21,7 @@ def test_mock_get_user_from_token_invalid():
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid mock token - use 'testuser'" assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
def test_mock_initiate_auth(): def test_mock_initiate_auth():
"""Test mock authentication returns expected token response""" """Test mock authentication returns expected token response"""
result = mock_initiate_auth("any_user", "any_password") result = mock_initiate_auth("any_user", "any_password")
@@ -27,6 +31,7 @@ def test_mock_initiate_auth():
assert result["ExpiresIn"] == 3600 assert result["ExpiresIn"] == 3600
assert result["TokenType"] == "Bearer" assert result["TokenType"] == "Bearer"
def test_mock_initiate_auth_different_credentials(): def test_mock_initiate_auth_different_credentials():
"""Test mock authentication works with any credentials""" """Test mock authentication works with any credentials"""
result1 = mock_initiate_auth("user1", "pass1") result1 = mock_initiate_auth("user1", "pass1")

View File

@@ -1,32 +1,33 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException, status from fastapi import HTTPException, status
from fastapi.testclient import TestClient
from app.main import app from app.main import app
client = TestClient(app) client = TestClient(app)
@pytest.fixture @pytest.fixture
def mock_successful_auth(): def mock_successful_auth():
return { return {
"AccessToken": "mock_access_token", "AccessToken": "mock_access_token",
"IdToken": "mock_id_token", "IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token" "RefreshToken": "mock_refresh_token",
} }
@pytest.fixture @pytest.fixture
def mock_successful_auth_no_refresh(): def mock_successful_auth_no_refresh():
return { return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token"
}
def test_signin_success(mock_successful_auth): def test_signin_success(mock_successful_auth):
"""Test successful signin with all tokens""" """Test successful signin with all tokens"""
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth): with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "testuser", "password": "testpass"}
json={"username": "testuser", "password": "testpass"}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -36,12 +37,14 @@ def test_signin_success(mock_successful_auth):
assert data["refresh_token"] == "mock_refresh_token" assert data["refresh_token"] == "mock_refresh_token"
assert data["token_type"] == "Bearer" assert data["token_type"] == "Bearer"
def test_signin_success_no_refresh(mock_successful_auth_no_refresh): def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
"""Test successful signin without refresh token""" """Test successful signin without refresh token"""
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth_no_refresh): with patch(
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
):
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "testuser", "password": "testpass"}
json={"username": "testuser", "password": "testpass"}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -51,55 +54,46 @@ def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
assert data["refresh_token"] is None assert data["refresh_token"] is None
assert data["token_type"] == "Bearer" assert data["token_type"] == "Bearer"
def test_signin_invalid_input(): def test_signin_invalid_input():
"""Test signin with invalid input format""" """Test signin with invalid input format"""
# Missing password # Missing password
response = client.post( response = client.post("/auth/signin", json={"username": "testuser"})
"/auth/signin",
json={"username": "testuser"}
)
assert response.status_code == 422 assert response.status_code == 422
# Missing username # Missing username
response = client.post( response = client.post("/auth/signin", json={"password": "testpass"})
"/auth/signin",
json={"password": "testpass"}
)
assert response.status_code == 422 assert response.status_code == 422
# Empty payload # Empty payload
response = client.post( response = client.post("/auth/signin", json={})
"/auth/signin",
json={}
)
assert response.status_code == 422 assert response.status_code == 422
def test_signin_auth_failure(): def test_signin_auth_failure():
"""Test signin with authentication failure""" """Test signin with authentication failure"""
with patch('app.routers.auth.initiate_auth') as mock_auth: with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException( mock_auth.side_effect = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password" detail="Invalid username or password",
) )
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "testuser", "password": "wrongpass"}
json={"username": "testuser", "password": "wrongpass"}
) )
assert response.status_code == 401 assert response.status_code == 401
data = response.json() data = response.json()
assert data["detail"] == "Invalid username or password" assert data["detail"] == "Invalid username or password"
def test_signin_user_not_found(): def test_signin_user_not_found():
"""Test signin with non-existent user""" """Test signin with non-existent user"""
with patch('app.routers.auth.initiate_auth') as mock_auth: with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException( mock_auth.side_effect = HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "nonexistent", "password": "testpass"}
json={"username": "nonexistent", "password": "testpass"}
) )
assert response.status_code == 404 assert response.status_code == 404

File diff suppressed because it is too large Load Diff

View File

@@ -1,19 +1,24 @@
from unittest.mock import patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.main import app, lifespan from app.main import app, lifespan
from unittest.mock import patch, MagicMock
@pytest.fixture @pytest.fixture
def client(): def client():
"""Test client for FastAPI app""" """Test client for FastAPI app"""
return TestClient(app) return TestClient(app)
def test_root_endpoint(client): def test_root_endpoint(client):
"""Test root endpoint returns expected message""" """Test root endpoint returns expected message"""
response = client.get("/") response = client.get("/")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "IPTV Updater API"} assert response.json() == {"message": "IPTV Updater API"}
def test_openapi_schema_generation(client): def test_openapi_schema_generation(client):
"""Test OpenAPI schema is properly generated""" """Test OpenAPI schema is properly generated"""
# First call - generate schema # First call - generate schema
@@ -35,6 +40,7 @@ def test_openapi_schema_generation(client):
assert "components" in schema assert "components" in schema
assert "schemas" in schema["components"] assert "schemas" in schema["components"]
def test_openapi_schema_caching(mocker): def test_openapi_schema_caching(mocker):
"""Test OpenAPI schema caching behavior""" """Test OpenAPI schema caching behavior"""
# Clear any existing schema # Clear any existing schema
@@ -55,6 +61,7 @@ def test_openapi_schema_caching(mocker):
assert schema == mock_schema assert schema == mock_schema
mock_get_openapi.assert_not_called() mock_get_openapi.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_lifespan_init_db(mocker): async def test_lifespan_init_db(mocker):
"""Test lifespan manager initializes database""" """Test lifespan manager initializes database"""
@@ -63,6 +70,7 @@ async def test_lifespan_init_db(mocker):
pass # Just enter/exit context pass # Just enter/exit context
mock_init_db.assert_called_once() mock_init_db.assert_called_once()
def test_router_inclusion(): def test_router_inclusion():
"""Test all routers are properly included""" """Test all routers are properly included"""
route_paths = {route.path for route in app.routes} route_paths = {route.path for route in app.routes}

View File

@@ -1,17 +1,30 @@
import os
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest.mock import patch, MagicMock from unittest.mock import MagicMock, patch
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
import pytest import pytest
from sqlalchemy import (
TEXT,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
TypeDecorator,
UniqueConstraint,
create_engine,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import StaticPool
# Create a mock-specific Base class for testing # Create a mock-specific Base class for testing
MockBase = declarative_base() MockBase = declarative_base()
class SQLiteUUID(TypeDecorator): class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite.""" """Enables UUID support for SQLite."""
impl = TEXT impl = TEXT
cache_ok = True cache_ok = True
@@ -25,12 +38,14 @@ class SQLiteUUID(TypeDecorator):
return value return value
return uuid.UUID(value) return uuid.UUID(value)
# Model classes for testing - prefix with Mock to avoid pytest collection # Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase): class MockPriority(MockBase):
__tablename__ = "priorities" __tablename__ = "priorities"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
description = Column(String, nullable=False) description = Column(String, nullable=False)
class MockChannelDB(MockBase): class MockChannelDB(MockBase):
__tablename__ = "channels" __tablename__ = "channels"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
@@ -39,32 +54,45 @@ class MockChannelDB(MockBase):
group_title = Column(String, nullable=False) group_title = Column(String, nullable=False)
tvg_name = Column(String) tvg_name = Column(String)
__table_args__ = ( __table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'), UniqueConstraint("group_title", "name", name="uix_group_title_name"),
) )
tvg_logo = Column(String) tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
class MockChannelURL(MockBase): class MockChannelURL(MockBase):
__tablename__ = "channels_urls" __tablename__ = "channels_urls"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) channel_id = Column(
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
)
url = Column(String, nullable=False) url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False) in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Create test engine # Create test engine
engine_mock = create_engine( engine_mock = create_engine(
"sqlite:///:memory:", "sqlite:///:memory:",
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
poolclass=StaticPool poolclass=StaticPool,
) )
# Create test session # Create test session
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock) session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
# Mock the actual database functions # Mock the actual database functions
def mock_get_db(): def mock_get_db():
db = session_mock() db = session_mock()
@@ -73,6 +101,7 @@ def mock_get_db():
finally: finally:
db.close() db.close()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_env(monkeypatch): def mock_env(monkeypatch):
"""Fixture for mocking environment variables""" """Fixture for mocking environment variables"""
@@ -83,13 +112,12 @@ def mock_env(monkeypatch):
monkeypatch.setenv("DB_NAME", "testdb") monkeypatch.setenv("DB_NAME", "testdb")
monkeypatch.setenv("AWS_REGION", "us-east-1") monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture @pytest.fixture
def mock_ssm(): def mock_ssm():
"""Fixture for mocking boto3 SSM client""" """Fixture for mocking boto3 SSM client"""
with patch('boto3.client') as mock_client: with patch("boto3.client") as mock_client:
mock_ssm = MagicMock() mock_ssm = MagicMock()
mock_client.return_value = mock_ssm mock_client.return_value = mock_ssm
mock_ssm.get_parameter.return_value = { mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
'Parameter': {'Value': 'mocked_value'}
}
yield mock_ssm yield mock_ssm

View File

@@ -1,27 +1,27 @@
import os import os
import pytest import pytest
from unittest.mock import patch
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.utils.database import get_db_credentials, get_db
from tests.utils.db_mocks import ( from app.utils.database import get_db, get_db_credentials
session_mock, from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
mock_get_db,
mock_env,
mock_ssm
)
def test_get_db_credentials_env(mock_env): def test_get_db_credentials_env(mock_env):
"""Test getting DB credentials from environment variables""" """Test getting DB credentials from environment variables"""
conn_str = get_db_credentials() conn_str = get_db_credentials()
assert conn_str == "postgresql://testuser:testpass@localhost/testdb" assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
def test_get_db_credentials_ssm(mock_ssm): def test_get_db_credentials_ssm(mock_ssm):
"""Test getting DB credentials from SSM""" """Test getting DB credentials from SSM"""
os.environ.pop("MOCK_AUTH", None) os.environ.pop("MOCK_AUTH", None)
conn_str = get_db_credentials() conn_str = get_db_credentials()
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
assert expected_conn in conn_str
mock_ssm.get_parameter.assert_called() mock_ssm.get_parameter.assert_called()
def test_get_db_credentials_ssm_exception(mock_ssm): def test_get_db_credentials_ssm_exception(mock_ssm):
"""Test SSM credential fetching failure raises RuntimeError""" """Test SSM credential fetching failure raises RuntimeError"""
os.environ.pop("MOCK_AUTH", None) os.environ.pop("MOCK_AUTH", None)
@@ -32,12 +32,14 @@ def test_get_db_credentials_ssm_exception(mock_ssm):
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value) assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
def test_session_creation(): def test_session_creation():
"""Test database session creation""" """Test database session creation"""
session = session_mock() session = session_mock()
assert isinstance(session, Session) assert isinstance(session, Session)
session.close() session.close()
def test_get_db_generator(): def test_get_db_generator():
"""Test get_db dependency generator""" """Test get_db dependency generator"""
db_gen = get_db() db_gen = get_db()
@@ -48,17 +50,19 @@ def test_get_db_generator():
except StopIteration: except StopIteration:
pass pass
def test_init_db(mocker, mock_env): def test_init_db(mocker, mock_env):
"""Test database initialization creates tables""" """Test database initialization creates tables"""
mock_create_all = mocker.patch('app.models.Base.metadata.create_all') mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
# Mock get_db_credentials to return SQLite test connection # Mock get_db_credentials to return SQLite test connection
mocker.patch( mocker.patch(
'app.utils.database.get_db_credentials', "app.utils.database.get_db_credentials",
return_value="sqlite:///:memory:" return_value="sqlite:///:memory:",
) )
from app.utils.database import init_db, engine from app.utils.database import engine, init_db
init_db() init_db()
# Verify create_all was called with the engine # Verify create_all was called with the engine