Linted and formatted all files
This commit is contained in:
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -39,6 +39,7 @@
|
||||
"dflogo",
|
||||
"dmlogo",
|
||||
"dotenv",
|
||||
"EXTINF",
|
||||
"fastapi",
|
||||
"filterwarnings",
|
||||
"fiorinis",
|
||||
@@ -47,6 +48,7 @@
|
||||
"gitea",
|
||||
"iptv",
|
||||
"isort",
|
||||
"KHTML",
|
||||
"lclogo",
|
||||
"LETSENCRYPT",
|
||||
"nohup",
|
||||
@@ -76,6 +78,7 @@
|
||||
"testpaths",
|
||||
"uflogo",
|
||||
"umlogo",
|
||||
"usefixtures",
|
||||
"uvicorn",
|
||||
"venv",
|
||||
"wrongpass"
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
from app.utils.database import get_db_credentials
|
||||
from app.models.db import Base
|
||||
from app.utils.database import get_db_credentials
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -22,7 +20,7 @@ target_metadata = Base.metadata
|
||||
|
||||
# Override sqlalchemy.url with dynamic credentials
|
||||
if not context.is_offline_mode():
|
||||
config.set_main_option('sqlalchemy.url', get_db_credentials())
|
||||
config.set_main_option("sqlalchemy.url", get_db_credentials())
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@@ -68,9 +66,7 @@ def run_migrations_online() -> None:
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
16
app.py
16
app.py
@@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
|
||||
import aws_cdk as cdk
|
||||
|
||||
from infrastructure.stack import IptvUpdaterStack
|
||||
|
||||
app = cdk.App()
|
||||
@@ -19,21 +21,25 @@ required_vars = {
|
||||
"DOMAIN_NAME": domain_name,
|
||||
"SSH_PUBLIC_KEY": ssh_public_key,
|
||||
"REPO_URL": repo_url,
|
||||
"LETSENCRYPT_EMAIL": letsencrypt_email
|
||||
"LETSENCRYPT_EMAIL": letsencrypt_email,
|
||||
}
|
||||
|
||||
# Check for missing required variables
|
||||
missing_vars = [k for k, v in required_vars.items() if not v]
|
||||
if missing_vars:
|
||||
raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
|
||||
raise ValueError(
|
||||
f"Missing required environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
IptvUpdaterStack(app, "IptvUpdaterStack",
|
||||
IptvUpdaterStack(
|
||||
app,
|
||||
"IptvUpdaterStack",
|
||||
freedns_user=freedns_user,
|
||||
freedns_password=freedns_password,
|
||||
domain_name=domain_name,
|
||||
ssh_public_key=ssh_public_key,
|
||||
repo_url=repo_url,
|
||||
letsencrypt_email=letsencrypt_email
|
||||
letsencrypt_email=letsencrypt_email,
|
||||
)
|
||||
|
||||
app.synth()
|
||||
app.synth()
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import boto3
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.auth import calculate_secret_hash
|
||||
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE)
|
||||
from app.utils.constants import (
|
||||
AWS_REGION,
|
||||
COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET,
|
||||
USER_ROLE_ATTRIBUTE,
|
||||
)
|
||||
|
||||
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
|
||||
|
||||
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
|
||||
"""
|
||||
auth_params = {
|
||||
"USERNAME": username,
|
||||
"PASSWORD": password
|
||||
}
|
||||
auth_params = {"USERNAME": username, "PASSWORD": password}
|
||||
|
||||
# If a client secret is required, add SECRET_HASH
|
||||
if COGNITO_CLIENT_SECRET:
|
||||
auth_params["SECRET_HASH"] = calculate_secret_hash(
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET)
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
|
||||
)
|
||||
|
||||
try:
|
||||
response = cognito_client.initiate_auth(
|
||||
AuthFlow="USER_PASSWORD_AUTH",
|
||||
AuthParameters=auth_params,
|
||||
ClientId=COGNITO_CLIENT_ID
|
||||
ClientId=COGNITO_CLIENT_ID,
|
||||
)
|
||||
return response["AuthenticationResult"]
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password"
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"An error occurred during authentication: {str(e)}"
|
||||
detail=f"An error occurred during authentication: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
"""
|
||||
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles.
|
||||
Verify the token by calling GetUser in Cognito and
|
||||
retrieve user attributes including roles.
|
||||
"""
|
||||
try:
|
||||
user_response = cognito_client.get_user(AccessToken=access_token)
|
||||
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
for attr in attributes:
|
||||
if attr["Name"] == USER_ROLE_ATTRIBUTE:
|
||||
# Assume roles are stored as a comma-separated string
|
||||
user_roles = [r.strip()
|
||||
for r in attr["Value"].split(",") if r.strip()]
|
||||
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
|
||||
break
|
||||
|
||||
return CognitoUser(username=username, roles=user_roles)
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token."
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or invalid token."
|
||||
detail="User not found or invalid token.",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token verification failed: {str(e)}"
|
||||
detail=f"Token verification failed: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
import os
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
else:
|
||||
from app.auth.cognito import get_user_from_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="signin",
|
||||
scheme_name="Bearer"
|
||||
)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
|
||||
"""
|
||||
@@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable:
|
||||
if not needed_roles.issubset(user_roles):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have the required roles to access this endpoint.",
|
||||
detail=(
|
||||
"You do not have the required roles to access this endpoint."
|
||||
),
|
||||
)
|
||||
return endpoint(*args, user=user, **kwargs)
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
MOCK_USERS = {
|
||||
"testuser": {
|
||||
"username": "testuser",
|
||||
"roles": ["admin"]
|
||||
}
|
||||
}
|
||||
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
|
||||
|
||||
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
"""
|
||||
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
return CognitoUser(**MOCK_USERS["testuser"])
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid mock token - use 'testuser'"
|
||||
detail="Invalid mock token - use 'testuser'",
|
||||
)
|
||||
|
||||
|
||||
def mock_initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Mock version of initiate_auth for local testing
|
||||
Accepts any username/password and returns a mock token
|
||||
"""
|
||||
return {
|
||||
"AccessToken": "testuser",
|
||||
"ExpiresIn": 3600,
|
||||
"TokenType": "Bearer"
|
||||
}
|
||||
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
|
||||
|
||||
@@ -1,39 +1,59 @@
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import requests
|
||||
import argparse
|
||||
from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
||||
from utils.constants import (
|
||||
IPTV_SERVER_ADMIN_PASSWORD,
|
||||
IPTV_SERVER_ADMIN_USER,
|
||||
IPTV_SERVER_URL,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='EPG Grabber')
|
||||
parser.add_argument('--playlist',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
||||
help='Path to playlist file')
|
||||
parser.add_argument('--output',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'),
|
||||
help='Path to output EPG XML file')
|
||||
parser.add_argument('--epg-sources',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'),
|
||||
help='Path to EPG sources JSON configuration file')
|
||||
parser.add_argument('--save-as-gz',
|
||||
action='store_true',
|
||||
default=True,
|
||||
help='Save an additional gzipped version of the EPG file')
|
||||
parser = argparse.ArgumentParser(description="EPG Grabber")
|
||||
parser.add_argument(
|
||||
"--playlist",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||
),
|
||||
help="Path to playlist file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
|
||||
help="Path to output EPG XML file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epg-sources",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
|
||||
),
|
||||
help="Path to EPG sources JSON configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-as-gz",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Save an additional gzipped version of the EPG file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_epg_sources(config_path):
|
||||
"""Load EPG sources from JSON configuration file"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
return config.get('epg_sources', [])
|
||||
return config.get("epg_sources", [])
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Error loading EPG sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
def get_tvg_ids(playlist_path):
|
||||
"""
|
||||
Extracts unique tvg-id values from an M3U playlist file.
|
||||
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
|
||||
# and ends with a double quote.
|
||||
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
|
||||
|
||||
with open(playlist_path, 'r', encoding='utf-8') as file:
|
||||
with open(playlist_path, encoding="utf-8") as file:
|
||||
for line in file:
|
||||
if line.startswith('#EXTINF'):
|
||||
if line.startswith("#EXTINF"):
|
||||
# Search for the tvg-id pattern in the line
|
||||
match = tvg_id_pattern.search(line)
|
||||
if match:
|
||||
# Extract the captured group (the value inside the quotes)
|
||||
tvg_id = match.group(1)
|
||||
if tvg_id: # Ensure the extracted id is not empty
|
||||
if tvg_id: # Ensure the extracted id is not empty
|
||||
unique_tvg_ids.add(tvg_id)
|
||||
|
||||
return list(unique_tvg_ids)
|
||||
|
||||
|
||||
def fetch_and_extract_xml(url):
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
print(f"Failed to fetch {url}")
|
||||
return None
|
||||
|
||||
if url.endswith('.gz'):
|
||||
if url.endswith(".gz"):
|
||||
try:
|
||||
decompressed_data = gzip.decompress(response.content)
|
||||
return ET.fromstring(decompressed_data)
|
||||
@@ -84,44 +105,48 @@ def fetch_and_extract_xml(url):
|
||||
print(f"Failed to parse XML from {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
|
||||
root = ET.Element('tv')
|
||||
root = ET.Element("tv")
|
||||
|
||||
for url in urls:
|
||||
epg_data = fetch_and_extract_xml(url)
|
||||
if epg_data is None:
|
||||
continue
|
||||
|
||||
for channel in epg_data.findall('channel'):
|
||||
tvg_id = channel.get('id')
|
||||
for channel in epg_data.findall("channel"):
|
||||
tvg_id = channel.get("id")
|
||||
if tvg_id in tvg_ids:
|
||||
root.append(channel)
|
||||
|
||||
for programme in epg_data.findall('programme'):
|
||||
tvg_id = programme.get('channel')
|
||||
for programme in epg_data.findall("programme"):
|
||||
tvg_id = programme.get("channel")
|
||||
if tvg_id in tvg_ids:
|
||||
root.append(programme)
|
||||
|
||||
tree = ET.ElementTree(root)
|
||||
tree.write(output_file, encoding='utf-8', xml_declaration=True)
|
||||
tree.write(output_file, encoding="utf-8", xml_declaration=True)
|
||||
print(f"New EPG saved to {output_file}")
|
||||
|
||||
if save_as_gz:
|
||||
output_file_gz = output_file + '.gz'
|
||||
with gzip.open(output_file_gz, 'wb') as f:
|
||||
tree.write(f, encoding='utf-8', xml_declaration=True)
|
||||
output_file_gz = output_file + ".gz"
|
||||
with gzip.open(output_file_gz, "wb") as f:
|
||||
tree.write(f, encoding="utf-8", xml_declaration=True)
|
||||
print(f"New EPG saved to {output_file_gz}")
|
||||
|
||||
|
||||
def upload_epg(file_path):
|
||||
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
response = requests.post(
|
||||
IPTV_SERVER_URL + '/admin/epg',
|
||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
||||
files={'file': (os.path.basename(file_path), f)}
|
||||
IPTV_SERVER_URL + "/admin/epg",
|
||||
auth=requests.auth.HTTPBasicAuth(
|
||||
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||
),
|
||||
files={"file": (os.path.basename(file_path), f)},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
print("EPG successfully uploaded to server")
|
||||
else:
|
||||
@@ -129,6 +154,7 @@ def upload_epg(file_path):
|
||||
except Exception as e:
|
||||
print(f"Upload error: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
playlist_file = args.playlist
|
||||
@@ -144,4 +170,4 @@ if __name__ == "__main__":
|
||||
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
|
||||
|
||||
if args.save_as_gz:
|
||||
upload_epg(output_file + '.gz')
|
||||
upload_epg(output_file + ".gz")
|
||||
|
||||
@@ -1,26 +1,45 @@
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
from pathlib import Path
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from utils.check_streams import StreamValidator
|
||||
from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
||||
from utils.constants import (
|
||||
EPG_URL,
|
||||
IPTV_SERVER_ADMIN_PASSWORD,
|
||||
IPTV_SERVER_ADMIN_USER,
|
||||
IPTV_SERVER_URL,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='IPTV playlist generator')
|
||||
parser.add_argument('--output',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
||||
help='Path to output playlist file')
|
||||
parser.add_argument('--channels',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'),
|
||||
help='Path to channels definition JSON file')
|
||||
parser.add_argument('--dead-channels-log',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'),
|
||||
help='Path to log file to store a list of dead channels')
|
||||
parser = argparse.ArgumentParser(description="IPTV playlist generator")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||
),
|
||||
help="Path to output playlist file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "channels.json"
|
||||
),
|
||||
help="Path to channels definition JSON file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dead-channels-log",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
|
||||
),
|
||||
help="Path to log file to store a list of dead channels",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def find_working_stream(validator, urls):
|
||||
"""Test all URLs and return the first working one"""
|
||||
for url in urls:
|
||||
@@ -29,48 +48,55 @@ def find_working_stream(validator, urls):
|
||||
return url
|
||||
return None
|
||||
|
||||
|
||||
def create_playlist(channels_file, output_file):
|
||||
# Read channels from JSON file
|
||||
with open(channels_file, 'r', encoding='utf-8') as f:
|
||||
with open(channels_file, encoding="utf-8") as f:
|
||||
channels = json.load(f)
|
||||
|
||||
# Initialize validator
|
||||
validator = StreamValidator(timeout=45)
|
||||
|
||||
|
||||
# Prepare M3U8 header
|
||||
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
|
||||
|
||||
|
||||
for channel in channels:
|
||||
if 'urls' in channel: # Check if channel has URLs
|
||||
if "urls" in channel: # Check if channel has URLs
|
||||
# Find first working stream
|
||||
working_url = find_working_stream(validator, channel['urls'])
|
||||
|
||||
working_url = find_working_stream(validator, channel["urls"])
|
||||
|
||||
if working_url:
|
||||
# Add channel to playlist
|
||||
m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" '
|
||||
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
|
||||
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
|
||||
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
|
||||
m3u8_content += f'{channel.get("name", "")}\n'
|
||||
m3u8_content += f'{working_url}\n'
|
||||
m3u8_content += f"{channel.get('name', '')}\n"
|
||||
m3u8_content += f"{working_url}\n"
|
||||
else:
|
||||
# Log dead channel
|
||||
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found')
|
||||
logging.info(
|
||||
f"Dead channel: {channel.get('name', 'Unknown')} - "
|
||||
"No working streams found"
|
||||
)
|
||||
|
||||
# Write playlist file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(m3u8_content)
|
||||
|
||||
|
||||
def upload_playlist(file_path):
|
||||
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
response = requests.post(
|
||||
IPTV_SERVER_URL + '/admin/playlist',
|
||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
||||
files={'file': (os.path.basename(file_path), f)}
|
||||
IPTV_SERVER_URL + "/admin/playlist",
|
||||
auth=requests.auth.HTTPBasicAuth(
|
||||
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||
),
|
||||
files={"file": (os.path.basename(file_path), f)},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
print("Playlist successfully uploaded to server")
|
||||
else:
|
||||
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
|
||||
except Exception as e:
|
||||
print(f"Upload error: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
channels_file = args.channels
|
||||
@@ -85,24 +112,25 @@ def main():
|
||||
dead_channels_log_file = args.dead_channels_log
|
||||
|
||||
# Clear previous log file
|
||||
with open(dead_channels_log_file, 'w') as f:
|
||||
f.write(f'Log created on {datetime.now()}\n')
|
||||
with open(dead_channels_log_file, "w") as f:
|
||||
f.write(f"Log created on {datetime.now()}\n")
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
filename=dead_channels_log_file,
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
format="%(asctime)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
# Create playlist
|
||||
create_playlist(channels_file, output_file)
|
||||
|
||||
#upload playlist to server
|
||||
# upload playlist to server
|
||||
upload_playlist(output_file)
|
||||
|
||||
|
||||
print("Playlist creation completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
19
app/main.py
19
app/main.py
@@ -1,17 +1,18 @@
|
||||
|
||||
from fastapi.concurrency import asynccontextmanager
|
||||
from app.routers import channels, auth, playlist, priorities
|
||||
from fastapi import FastAPI
|
||||
from fastapi.concurrency import asynccontextmanager
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from app.routers import auth, channels, playlist, priorities
|
||||
from app.utils.database import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Initialize database tables on startup
|
||||
init_db()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="IPTV Updater API",
|
||||
@@ -19,6 +20,7 @@ app = FastAPI(
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
@@ -40,11 +42,7 @@ def custom_openapi():
|
||||
|
||||
# Add security scheme component
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
"Bearer": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT"
|
||||
}
|
||||
"Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
|
||||
}
|
||||
|
||||
# Add global security requirement
|
||||
@@ -56,14 +54,17 @@ def custom_openapi():
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "IPTV Updater API"}
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(channels.router)
|
||||
app.include_router(playlist.router)
|
||||
app.include_router(priorities.router)
|
||||
app.include_router(priorities.router)
|
||||
|
||||
@@ -1,4 +1,19 @@
|
||||
from .db import Base, ChannelDB, ChannelURL
|
||||
from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse
|
||||
from .schemas import (
|
||||
ChannelCreate,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
)
|
||||
|
||||
__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"]
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ChannelDB",
|
||||
"ChannelCreate",
|
||||
"ChannelUpdate",
|
||||
"ChannelResponse",
|
||||
"ChannelURL",
|
||||
"ChannelURLCreate",
|
||||
"ChannelURLResponse",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SigninRequest(BaseModel):
|
||||
"""Request model for the signin endpoint."""
|
||||
|
||||
username: str = Field(..., description="The user's username")
|
||||
password: str = Field(..., description="The user's password")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response model for successful authentication."""
|
||||
|
||||
access_token: str = Field(..., description="Access JWT token from Cognito")
|
||||
id_token: str = Field(..., description="ID JWT token from Cognito")
|
||||
refresh_token: Optional[str] = Field(
|
||||
None, description="Refresh token from Cognito")
|
||||
refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
|
||||
token_type: str = Field(..., description="Type of the token returned")
|
||||
|
||||
|
||||
class CognitoUser(BaseModel):
|
||||
"""Model representing the user returned from token verification."""
|
||||
|
||||
username: str
|
||||
roles: List[str]
|
||||
roles: list[str]
|
||||
|
||||
@@ -1,21 +1,33 @@
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Priority(Base):
|
||||
"""SQLAlchemy model for channel URL priorities"""
|
||||
|
||||
__tablename__ = "priorities"
|
||||
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class ChannelDB(Base):
|
||||
"""SQLAlchemy model for IPTV channels"""
|
||||
|
||||
__tablename__ = "channels"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
@@ -25,27 +37,43 @@ class ChannelDB(Base):
|
||||
tvg_name = Column(String)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
||||
)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationship with ChannelURL
|
||||
urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan")
|
||||
urls = relationship(
|
||||
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class ChannelURL(Base):
|
||||
"""SQLAlchemy model for channel URLs"""
|
||||
|
||||
__tablename__ = "channels_urls"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||
channel_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("channels.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
channel = relationship("ChannelDB", back_populates="urls")
|
||||
priority = relationship("Priority")
|
||||
priority = relationship("Priority")
|
||||
|
||||
@@ -1,30 +1,43 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PriorityBase(BaseModel):
|
||||
"""Base Pydantic model for priorities"""
|
||||
|
||||
id: int
|
||||
description: str
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PriorityCreate(PriorityBase):
|
||||
"""Pydantic model for creating priorities"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PriorityResponse(PriorityBase):
|
||||
"""Pydantic model for priority responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelURLCreate(BaseModel):
|
||||
"""Pydantic model for creating channel URLs"""
|
||||
|
||||
url: str
|
||||
priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range
|
||||
priority_id: int = Field(
|
||||
default=100, ge=100, le=300
|
||||
) # Default to High, validate range
|
||||
|
||||
|
||||
class ChannelURLBase(ChannelURLCreate):
|
||||
"""Base Pydantic model for channel URL responses"""
|
||||
|
||||
id: UUID
|
||||
in_use: bool
|
||||
created_at: datetime
|
||||
@@ -33,43 +46,53 @@ class ChannelURLBase(ChannelURLCreate):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChannelURLResponse(ChannelURLBase):
|
||||
"""Pydantic model for channel URL responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelCreate(BaseModel):
|
||||
"""Pydantic model for creating channels"""
|
||||
urls: List[ChannelURLCreate] # List of URL objects with priority
|
||||
|
||||
urls: list[ChannelURLCreate] # List of URL objects with priority
|
||||
name: str
|
||||
group_title: str
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
|
||||
|
||||
class ChannelURLUpdate(BaseModel):
|
||||
"""Pydantic model for updating channel URLs"""
|
||||
|
||||
url: Optional[str] = None
|
||||
in_use: Optional[bool] = None
|
||||
priority_id: Optional[int] = Field(default=None, ge=100, le=300)
|
||||
|
||||
|
||||
class ChannelUpdate(BaseModel):
|
||||
"""Pydantic model for updating channels (all fields optional)"""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
group_title: Optional[str] = Field(None, min_length=1)
|
||||
tvg_id: Optional[str] = Field(None, min_length=1)
|
||||
tvg_logo: Optional[str] = None
|
||||
tvg_name: Optional[str] = Field(None, min_length=1)
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
"""Pydantic model for channel responses"""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
group_title: str
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
urls: List[ChannelURLResponse] # List of URL objects without channel_id
|
||||
urls: list[ChannelURLResponse] # List of URL objects without channel_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.auth.cognito import initiate_auth
|
||||
from app.models.auth import SigninRequest, TokenResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/auth",
|
||||
tags=["authentication"]
|
||||
)
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
|
||||
def signin(credentials: SigninRequest):
|
||||
"""
|
||||
Sign-in endpoint to authenticate the user with AWS Cognito using username and password.
|
||||
Sign-in endpoint to authenticate the user with AWS Cognito
|
||||
using username and password.
|
||||
On success, returns JWT tokens (access_token, id_token, refresh_token).
|
||||
"""
|
||||
auth_result = initiate_auth(credentials.username, credentials.password)
|
||||
@@ -19,4 +19,4 @@ def signin(credentials: SigninRequest):
|
||||
id_token=auth_result["IdToken"],
|
||||
refresh_token=auth_result.get("RefreshToken"),
|
||||
token_type="Bearer",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,52 +1,54 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from sqlalchemy import and_
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models import (
|
||||
ChannelDB,
|
||||
ChannelURL,
|
||||
ChannelCreate,
|
||||
ChannelUpdate,
|
||||
ChannelDB,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURL,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
)
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.schemas import ChannelURLUpdate
|
||||
from app.utils.database import get_db
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/channels",
|
||||
tags=["channels"]
|
||||
)
|
||||
router = APIRouter(prefix="/channels", tags=["channels"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_channel(
|
||||
channel: ChannelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new channel"""
|
||||
# Check for duplicate channel (same group_title + name)
|
||||
existing_channel = db.query(ChannelDB).filter(
|
||||
and_(
|
||||
ChannelDB.group_title == channel.group_title,
|
||||
ChannelDB.name == channel.name
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_title == channel.group_title,
|
||||
ChannelDB.name == channel.name,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_title and name already exists"
|
||||
detail="Channel with same group_title and name already exists",
|
||||
)
|
||||
|
||||
# Create channel without URLs first
|
||||
channel_data = channel.model_dump(exclude={'urls'})
|
||||
channel_data = channel.model_dump(exclude={"urls"})
|
||||
urls = channel.urls
|
||||
db_channel = ChannelDB(**channel_data)
|
||||
db.add(db_channel)
|
||||
@@ -59,130 +61,142 @@ def create_channel(
|
||||
channel_id=db_channel.id,
|
||||
url=url.url,
|
||||
priority_id=url.priority_id,
|
||||
in_use=False
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.get("/{channel_id}", response_model=ChannelResponse)
|
||||
def get_channel(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
|
||||
"""Get a channel by id"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
return channel
|
||||
|
||||
|
||||
@router.put("/{channel_id}", response_model=ChannelResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel(
|
||||
channel_id: UUID,
|
||||
channel: ChannelUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel"""
|
||||
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not db_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
# Only check for duplicates if name or group_title are being updated
|
||||
if channel.name is not None or channel.group_title is not None:
|
||||
name = channel.name if channel.name is not None else db_channel.name
|
||||
group_title = channel.group_title if channel.group_title is not None else db_channel.group_title
|
||||
|
||||
existing_channel = db.query(ChannelDB).filter(
|
||||
and_(
|
||||
ChannelDB.group_title == group_title,
|
||||
ChannelDB.name == name,
|
||||
ChannelDB.id != channel_id
|
||||
group_title = (
|
||||
channel.group_title
|
||||
if channel.group_title is not None
|
||||
else db_channel.group_title
|
||||
)
|
||||
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_title == group_title,
|
||||
ChannelDB.name == name,
|
||||
ChannelDB.id != channel_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_title and name already exists"
|
||||
detail="Channel with same group_title and name already exists",
|
||||
)
|
||||
|
||||
|
||||
# Update only provided fields
|
||||
update_data = channel.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_channel, key, value)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
db.delete(channel)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
@router.get("/", response_model=List[ChannelResponse])
|
||||
|
||||
@router.get("/", response_model=list[ChannelResponse])
|
||||
@require_roles("admin")
|
||||
def list_channels(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""List all channels with pagination"""
|
||||
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
# URL Management Endpoints
|
||||
|
||||
@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/{channel_id}/urls",
|
||||
response_model=ChannelURLResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
@require_roles("admin")
|
||||
def add_channel_url(
|
||||
channel_id: UUID,
|
||||
url: ChannelURLCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Add a new URL to a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
db_url = ChannelURL(
|
||||
channel_id=channel_id,
|
||||
url=url.url,
|
||||
priority_id=url.priority_id,
|
||||
in_use=False # Default to not in use
|
||||
in_use=False, # Default to not in use
|
||||
)
|
||||
db.add(db_url)
|
||||
db.commit()
|
||||
db.refresh(db_url)
|
||||
return db_url
|
||||
|
||||
|
||||
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel_url(
|
||||
@@ -190,72 +204,69 @@ def update_channel_url(
|
||||
url_id: UUID,
|
||||
url_update: ChannelURLUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel URL (url, in_use, or priority_id)"""
|
||||
db_url = db.query(ChannelURL).filter(
|
||||
and_(
|
||||
ChannelURL.id == url_id,
|
||||
ChannelURL.channel_id == channel_id
|
||||
)
|
||||
).first()
|
||||
|
||||
db_url = (
|
||||
db.query(ChannelURL)
|
||||
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="URL not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
|
||||
)
|
||||
|
||||
|
||||
if url_update.url is not None:
|
||||
db_url.url = url_update.url
|
||||
if url_update.in_use is not None:
|
||||
db_url.in_use = url_update.in_use
|
||||
if url_update.priority_id is not None:
|
||||
db_url.priority_id = url_update.priority_id
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_url)
|
||||
return db_url
|
||||
|
||||
|
||||
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel_url(
|
||||
channel_id: UUID,
|
||||
url_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a URL from a channel"""
|
||||
url = db.query(ChannelURL).filter(
|
||||
and_(
|
||||
ChannelURL.id == url_id,
|
||||
ChannelURL.channel_id == channel_id
|
||||
)
|
||||
).first()
|
||||
|
||||
url = (
|
||||
db.query(ChannelURL)
|
||||
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="URL not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
|
||||
)
|
||||
|
||||
|
||||
db.delete(url)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse])
|
||||
|
||||
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
|
||||
@require_roles("admin")
|
||||
def list_channel_urls(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""List all URLs for a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
|
||||
|
||||
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/playlist",
|
||||
tags=["playlist"]
|
||||
)
|
||||
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
||||
|
||||
@router.get("/protected",
|
||||
summary="Protected endpoint for authenticated users")
|
||||
|
||||
@router.get("/protected", summary="Protected endpoint for authenticated users")
|
||||
async def protected_route(user: CognitoUser = Depends(get_current_user)):
|
||||
"""
|
||||
Protected endpoint that requires authentication for all users.
|
||||
If the user is authenticated, returns success message.
|
||||
"""
|
||||
return {"message": f"Hello {user.username}, you have access to support resources!"}
|
||||
return {"message": f"Hello {user.username}, you have access to support resources!"}
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, delete
|
||||
from typing import List
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.db import Priority
|
||||
from app.models.schemas import PriorityCreate, PriorityResponse
|
||||
from app.utils.database import get_db
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/priorities",
|
||||
tags=["priorities"]
|
||||
)
|
||||
router = APIRouter(prefix="/priorities", tags=["priorities"])
|
||||
|
||||
|
||||
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_priority(
|
||||
priority: PriorityCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new priority"""
|
||||
# Check if priority with this ID already exists
|
||||
@@ -27,71 +24,69 @@ def create_priority(
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Priority with ID {priority.id} already exists"
|
||||
detail=f"Priority with ID {priority.id} already exists",
|
||||
)
|
||||
|
||||
|
||||
db_priority = Priority(**priority.model_dump())
|
||||
db.add(db_priority)
|
||||
db.commit()
|
||||
db.refresh(db_priority)
|
||||
return db_priority
|
||||
|
||||
@router.get("/", response_model=List[PriorityResponse])
|
||||
|
||||
@router.get("/", response_model=list[PriorityResponse])
|
||||
@require_roles("admin")
|
||||
def list_priorities(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
|
||||
):
|
||||
"""List all priorities"""
|
||||
return db.query(Priority).all()
|
||||
|
||||
|
||||
@router.get("/{priority_id}", response_model=PriorityResponse)
|
||||
@require_roles("admin")
|
||||
def get_priority(
|
||||
priority_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Get a priority by id"""
|
||||
priority = db.get(Priority, priority_id)
|
||||
if not priority:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Priority not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
|
||||
)
|
||||
return priority
|
||||
|
||||
|
||||
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_priority(
|
||||
priority_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a priority (if not in use)"""
|
||||
from app.models.db import ChannelURL
|
||||
|
||||
|
||||
# Check if priority exists
|
||||
priority = db.get(Priority, priority_id)
|
||||
if not priority:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Priority not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
|
||||
)
|
||||
|
||||
|
||||
# Check if priority is in use
|
||||
in_use = db.scalar(
|
||||
select(ChannelURL)
|
||||
.where(ChannelURL.priority_id == priority_id)
|
||||
.limit(1)
|
||||
select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
|
||||
)
|
||||
|
||||
|
||||
if in_use:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Cannot delete priority that is in use by channel URLs"
|
||||
detail="Cannot delete priority that is in use by channel URLs",
|
||||
)
|
||||
|
||||
|
||||
db.execute(delete(Priority).where(Priority.id == priority_id))
|
||||
db.commit()
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -2,11 +2,13 @@ import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
|
||||
"""
|
||||
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
|
||||
"""
|
||||
msg = username + client_id
|
||||
dig = hmac.new(client_secret.encode('utf-8'),
|
||||
msg.encode('utf-8'), hashlib.sha256).digest()
|
||||
return base64.b64encode(dig).decode()
|
||||
dig = hmac.new(
|
||||
client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
return base64.b64encode(dig).decode()
|
||||
|
||||
@@ -1,41 +1,50 @@
|
||||
import os
|
||||
import argparse
|
||||
import requests
|
||||
import logging
|
||||
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError
|
||||
import os
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||
|
||||
|
||||
class StreamValidator:
|
||||
def __init__(self, timeout=10, user_agent=None):
|
||||
self.timeout = timeout
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
|
||||
})
|
||||
|
||||
self.session.headers.update(
|
||||
{
|
||||
"User-Agent": user_agent
|
||||
or (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def validate_stream(self, url):
|
||||
"""Validate a media stream URL with multiple fallback checks"""
|
||||
try:
|
||||
headers = {'Range': 'bytes=0-1024'}
|
||||
headers = {"Range": "bytes=0-1024"}
|
||||
with self.session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
allow_redirects=True
|
||||
allow_redirects=True,
|
||||
) as response:
|
||||
if response.status_code not in [200, 206]:
|
||||
return False, f"Invalid status code: {response.status_code}"
|
||||
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if not self._is_valid_content_type(content_type):
|
||||
return False, f"Invalid content type: {content_type}"
|
||||
|
||||
|
||||
try:
|
||||
next(response.iter_content(chunk_size=1024))
|
||||
return True, "Stream is valid"
|
||||
except (ConnectionError, Timeout):
|
||||
return False, "Connection failed during content read"
|
||||
|
||||
|
||||
except HTTPError as e:
|
||||
return False, f"HTTP Error: {str(e)}"
|
||||
except ConnectionError as e:
|
||||
@@ -49,10 +58,13 @@ class StreamValidator:
|
||||
|
||||
def _is_valid_content_type(self, content_type):
|
||||
valid_types = [
|
||||
'video/mp2t', 'application/vnd.apple.mpegurl',
|
||||
'application/dash+xml', 'video/mp4',
|
||||
'video/webm', 'application/octet-stream',
|
||||
'application/x-mpegURL'
|
||||
"video/mp2t",
|
||||
"application/vnd.apple.mpegurl",
|
||||
"application/dash+xml",
|
||||
"video/mp4",
|
||||
"video/webm",
|
||||
"application/octet-stream",
|
||||
"application/x-mpegURL",
|
||||
]
|
||||
return any(ct in content_type for ct in valid_types)
|
||||
|
||||
@@ -60,45 +72,43 @@ class StreamValidator:
|
||||
"""Extract stream URLs from M3U playlist file"""
|
||||
urls = []
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
urls.append(line)
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading playlist file: {str(e)}")
|
||||
raise
|
||||
return urls
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Validate streaming URLs from command line arguments or playlist files',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
description=(
|
||||
"Validate streaming URLs from command line arguments or playlist files"
|
||||
),
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
'sources',
|
||||
nargs='+',
|
||||
help='List of URLs or file paths containing stream URLs'
|
||||
"sources", nargs="+", help="List of URLs or file paths containing stream URLs"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--timeout',
|
||||
type=int,
|
||||
default=20,
|
||||
help='Timeout in seconds for stream checks'
|
||||
"--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
default='deadstreams.txt',
|
||||
help='Output file name for inactive streams'
|
||||
"--output",
|
||||
default="deadstreams.txt",
|
||||
help="Output file name for inactive streams",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()]
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
|
||||
)
|
||||
|
||||
validator = StreamValidator(timeout=args.timeout)
|
||||
@@ -127,9 +137,10 @@ def main():
|
||||
|
||||
# Save dead streams to file
|
||||
if dead_streams:
|
||||
with open(args.output, 'w') as f:
|
||||
f.write('\n'.join(dead_streams))
|
||||
with open(args.output, "w") as f:
|
||||
f.write("\n".join(dead_streams))
|
||||
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -21,4 +21,4 @@ IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
|
||||
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
|
||||
|
||||
# URL for the EPG XML file to place in the playlist's header
|
||||
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")
|
||||
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
|
||||
import boto3
|
||||
from app.models import Base
|
||||
from .constants import AWS_REGION
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from functools import lru_cache
|
||||
|
||||
from app.models import Base
|
||||
|
||||
from .constants import AWS_REGION
|
||||
|
||||
|
||||
def get_db_credentials():
|
||||
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
|
||||
@@ -14,29 +16,40 @@ def get_db_credentials():
|
||||
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
|
||||
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
|
||||
)
|
||||
|
||||
ssm = boto3.client('ssm', region_name=AWS_REGION)
|
||||
|
||||
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
||||
try:
|
||||
host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value']
|
||||
user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value']
|
||||
password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value']
|
||||
dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value']
|
||||
host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
password = ssm.get_parameter(
|
||||
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
|
||||
)["Parameter"]["Value"]
|
||||
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
|
||||
|
||||
|
||||
# Initialize engine and session maker
|
||||
engine = create_engine(get_db_credentials())
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Initialize database by creating all tables"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
||||
@@ -1,80 +1,69 @@
|
||||
import os
|
||||
from aws_cdk import (
|
||||
Duration,
|
||||
RemovalPolicy,
|
||||
Stack,
|
||||
aws_ec2 as ec2,
|
||||
aws_iam as iam,
|
||||
aws_cognito as cognito,
|
||||
aws_rds as rds,
|
||||
aws_ssm as ssm,
|
||||
CfnOutput
|
||||
)
|
||||
|
||||
from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
|
||||
from aws_cdk import aws_cognito as cognito
|
||||
from aws_cdk import aws_ec2 as ec2
|
||||
from aws_cdk import aws_iam as iam
|
||||
from aws_cdk import aws_rds as rds
|
||||
from aws_cdk import aws_ssm as ssm
|
||||
from constructs import Construct
|
||||
|
||||
|
||||
class IptvUpdaterStack(Stack):
|
||||
def __init__(
|
||||
self,
|
||||
scope: Construct,
|
||||
construct_id: str,
|
||||
freedns_user: str,
|
||||
freedns_password: str,
|
||||
domain_name: str,
|
||||
ssh_public_key: str,
|
||||
repo_url: str,
|
||||
letsencrypt_email: str,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self,
|
||||
scope: Construct,
|
||||
construct_id: str,
|
||||
freedns_user: str,
|
||||
freedns_password: str,
|
||||
domain_name: str,
|
||||
ssh_public_key: str,
|
||||
repo_url: str,
|
||||
letsencrypt_email: str,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(scope, construct_id, **kwargs)
|
||||
|
||||
# Create VPC
|
||||
vpc = ec2.Vpc(self, "IptvUpdaterVPC",
|
||||
vpc = ec2.Vpc(
|
||||
self,
|
||||
"IptvUpdaterVPC",
|
||||
max_azs=2, # Need at least 2 AZs for RDS subnet group
|
||||
nat_gateways=0, # No NAT Gateway to stay in free tier
|
||||
subnet_configuration=[
|
||||
ec2.SubnetConfiguration(
|
||||
name="public",
|
||||
subnet_type=ec2.SubnetType.PUBLIC,
|
||||
cidr_mask=24
|
||||
name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
|
||||
),
|
||||
ec2.SubnetConfiguration(
|
||||
name="private",
|
||||
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
|
||||
cidr_mask=24
|
||||
)
|
||||
]
|
||||
cidr_mask=24,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Security Group
|
||||
security_group = ec2.SecurityGroup(
|
||||
self, "IptvUpdaterSG",
|
||||
vpc=vpc,
|
||||
allow_all_outbound=True
|
||||
self, "IptvUpdaterSG", vpc=vpc, allow_all_outbound=True
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(),
|
||||
ec2.Port.tcp(443),
|
||||
"Allow HTTPS traffic"
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(),
|
||||
ec2.Port.tcp(80),
|
||||
"Allow HTTP traffic"
|
||||
ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(),
|
||||
ec2.Port.tcp(22),
|
||||
"Allow SSH traffic"
|
||||
ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
|
||||
)
|
||||
|
||||
# Allow PostgreSQL port for tunneling restricted to developer IP
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
|
||||
ec2.Port.tcp(5432),
|
||||
"Allow PostgreSQL traffic for tunneling"
|
||||
"Allow PostgreSQL traffic for tunneling",
|
||||
)
|
||||
|
||||
# Key pair for IPTV Updater instance
|
||||
@@ -82,13 +71,14 @@ class IptvUpdaterStack(Stack):
|
||||
self,
|
||||
"IptvUpdaterKeyPair",
|
||||
key_pair_name="iptv-updater-key",
|
||||
public_key_material=ssh_public_key
|
||||
public_key_material=ssh_public_key,
|
||||
)
|
||||
|
||||
# Create IAM role for EC2
|
||||
role = iam.Role(
|
||||
self, "IptvUpdaterRole",
|
||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com")
|
||||
self,
|
||||
"IptvUpdaterRole",
|
||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
|
||||
)
|
||||
|
||||
# Add SSM managed policy
|
||||
@@ -99,37 +89,36 @@ class IptvUpdaterStack(Stack):
|
||||
)
|
||||
|
||||
# Add EC2 describe permissions
|
||||
role.add_to_policy(iam.PolicyStatement(
|
||||
actions=["ec2:DescribeInstances"],
|
||||
resources=["*"]
|
||||
))
|
||||
role.add_to_policy(
|
||||
iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
|
||||
)
|
||||
|
||||
# Add SSM SendCommand permissions
|
||||
role.add_to_policy(iam.PolicyStatement(
|
||||
actions=["ssm:SendCommand"],
|
||||
resources=[
|
||||
f"arn:aws:ec2:{self.region}:{self.account}:instance/*", # Allow on all EC2 instances
|
||||
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript" # Required for the RunShellScript document
|
||||
]
|
||||
))
|
||||
role.add_to_policy(
|
||||
iam.PolicyStatement(
|
||||
actions=["ssm:SendCommand"],
|
||||
resources=[
|
||||
# Allow on all EC2 instances
|
||||
f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
|
||||
# Required for the RunShellScript document
|
||||
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# Add Cognito permissions to instance role
|
||||
role.add_managed_policy(
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name(
|
||||
"AmazonCognitoReadOnly"
|
||||
)
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
|
||||
)
|
||||
|
||||
# EC2 Instance
|
||||
instance = ec2.Instance(
|
||||
self, "IptvUpdaterInstance",
|
||||
self,
|
||||
"IptvUpdaterInstance",
|
||||
vpc=vpc,
|
||||
vpc_subnets=ec2.SubnetSelection(
|
||||
subnet_type=ec2.SubnetType.PUBLIC
|
||||
),
|
||||
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T2,
|
||||
ec2.InstanceSize.MICRO
|
||||
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
||||
),
|
||||
machine_image=ec2.AmazonLinuxImage(
|
||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
||||
@@ -138,7 +127,7 @@ class IptvUpdaterStack(Stack):
|
||||
key_pair=key_pair,
|
||||
role=role,
|
||||
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
||||
associate_public_ip_address=True
|
||||
associate_public_ip_address=True,
|
||||
)
|
||||
|
||||
# Option: 2: Create Elastic IP (not free tier compatible)
|
||||
@@ -150,7 +139,8 @@ class IptvUpdaterStack(Stack):
|
||||
|
||||
# Add Cognito User Pool
|
||||
user_pool = cognito.UserPool(
|
||||
self, "IptvUpdaterUserPool",
|
||||
self,
|
||||
"IptvUpdaterUserPool",
|
||||
user_pool_name="iptv-updater-users",
|
||||
self_sign_up_enabled=False, # Only admins can create users
|
||||
password_policy=cognito.PasswordPolicy(
|
||||
@@ -158,37 +148,33 @@ class IptvUpdaterStack(Stack):
|
||||
require_lowercase=True,
|
||||
require_digits=True,
|
||||
require_symbols=True,
|
||||
require_uppercase=True
|
||||
require_uppercase=True,
|
||||
),
|
||||
account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
|
||||
removal_policy=RemovalPolicy.DESTROY
|
||||
removal_policy=RemovalPolicy.DESTROY,
|
||||
)
|
||||
|
||||
# Add App Client with the correct callback URL
|
||||
client = user_pool.add_client("IptvUpdaterClient",
|
||||
client = user_pool.add_client(
|
||||
"IptvUpdaterClient",
|
||||
access_token_validity=Duration.minutes(60),
|
||||
id_token_validity=Duration.minutes(60),
|
||||
refresh_token_validity=Duration.days(1),
|
||||
auth_flows=cognito.AuthFlow(
|
||||
user_password=True
|
||||
),
|
||||
auth_flows=cognito.AuthFlow(user_password=True),
|
||||
o_auth=cognito.OAuthSettings(
|
||||
flows=cognito.OAuthFlows(
|
||||
implicit_code_grant=True
|
||||
)
|
||||
flows=cognito.OAuthFlows(implicit_code_grant=True)
|
||||
),
|
||||
prevent_user_existence_errors=True,
|
||||
generate_secret=True,
|
||||
enable_token_revocation=True
|
||||
enable_token_revocation=True,
|
||||
)
|
||||
|
||||
# Add domain for hosted UI
|
||||
domain = user_pool.add_domain("IptvUpdaterDomain",
|
||||
cognito_domain=cognito.CognitoDomainOptions(
|
||||
domain_prefix="iptv-updater"
|
||||
)
|
||||
domain = user_pool.add_domain(
|
||||
"IptvUpdaterDomain",
|
||||
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-updater"),
|
||||
)
|
||||
|
||||
|
||||
# Read the userdata script with proper path resolution
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
userdata_path = os.path.join(script_dir, "userdata.sh")
|
||||
@@ -196,46 +182,56 @@ class IptvUpdaterStack(Stack):
|
||||
|
||||
# Creates a userdata object for Linux hosts
|
||||
userdata = ec2.UserData.for_linux()
|
||||
|
||||
|
||||
# Add environment variables for acme.sh from parameters
|
||||
userdata.add_commands(
|
||||
f'export FREEDNS_User="{freedns_user}"',
|
||||
f'export FREEDNS_Password="{freedns_password}"',
|
||||
f'export DOMAIN_NAME="{domain_name}"',
|
||||
f'export REPO_URL="{repo_url}"',
|
||||
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"'
|
||||
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
|
||||
)
|
||||
|
||||
|
||||
# Adds one or more commands to the userdata object.
|
||||
userdata.add_commands(
|
||||
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment',
|
||||
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment',
|
||||
f'echo "COGNITO_CLIENT_SECRET={client.user_pool_client_secret.to_string()}" >> /etc/environment',
|
||||
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment'
|
||||
(
|
||||
f'echo "COGNITO_USER_POOL_ID='
|
||||
f'{user_pool.user_pool_id}" >> /etc/environment'
|
||||
),
|
||||
(
|
||||
f'echo "COGNITO_CLIENT_ID='
|
||||
f'{client.user_pool_client_id}" >> /etc/environment'
|
||||
),
|
||||
(
|
||||
f'echo "COGNITO_CLIENT_SECRET='
|
||||
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
|
||||
),
|
||||
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
|
||||
)
|
||||
userdata.add_commands(str(userdata_file, 'utf-8'))
|
||||
userdata.add_commands(str(userdata_file, "utf-8"))
|
||||
|
||||
# Create RDS Security Group
|
||||
rds_sg = ec2.SecurityGroup(
|
||||
self, "RdsSecurityGroup",
|
||||
self,
|
||||
"RdsSecurityGroup",
|
||||
vpc=vpc,
|
||||
description="Security group for RDS PostgreSQL"
|
||||
description="Security group for RDS PostgreSQL",
|
||||
)
|
||||
rds_sg.add_ingress_rule(
|
||||
security_group,
|
||||
ec2.Port.tcp(5432),
|
||||
"Allow PostgreSQL access from EC2 instance"
|
||||
"Allow PostgreSQL access from EC2 instance",
|
||||
)
|
||||
|
||||
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
|
||||
db = rds.DatabaseInstance(
|
||||
self, "IptvUpdaterDB",
|
||||
self,
|
||||
"IptvUpdaterDB",
|
||||
engine=rds.DatabaseInstanceEngine.postgres(
|
||||
version=rds.PostgresEngineVersion.VER_13
|
||||
),
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T3,
|
||||
ec2.InstanceSize.MICRO
|
||||
ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
|
||||
),
|
||||
vpc=vpc,
|
||||
vpc_subnets=ec2.SubnetSelection(
|
||||
@@ -247,39 +243,43 @@ class IptvUpdaterStack(Stack):
|
||||
database_name="iptv_updater",
|
||||
removal_policy=RemovalPolicy.DESTROY,
|
||||
deletion_protection=False,
|
||||
publicly_accessible=False # Avoid public IPv4 charges
|
||||
publicly_accessible=False, # Avoid public IPv4 charges
|
||||
)
|
||||
|
||||
# Add RDS permissions to instance role
|
||||
role.add_managed_policy(
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name(
|
||||
"AmazonRDSFullAccess"
|
||||
)
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
|
||||
)
|
||||
|
||||
# Store DB connection info in SSM Parameter Store
|
||||
ssm.StringParameter(self, "DBHostParam",
|
||||
ssm.StringParameter(
|
||||
self,
|
||||
"DBHostParam",
|
||||
parameter_name="/iptv-updater/DB_HOST",
|
||||
string_value=db.db_instance_endpoint_address
|
||||
string_value=db.db_instance_endpoint_address,
|
||||
)
|
||||
ssm.StringParameter(self, "DBNameParam",
|
||||
ssm.StringParameter(
|
||||
self,
|
||||
"DBNameParam",
|
||||
parameter_name="/iptv-updater/DB_NAME",
|
||||
string_value="iptv_updater"
|
||||
string_value="iptv_updater",
|
||||
)
|
||||
ssm.StringParameter(self, "DBUserParam",
|
||||
ssm.StringParameter(
|
||||
self,
|
||||
"DBUserParam",
|
||||
parameter_name="/iptv-updater/DB_USER",
|
||||
string_value=db.secret.secret_value_from_json("username").to_string()
|
||||
string_value=db.secret.secret_value_from_json("username").to_string(),
|
||||
)
|
||||
ssm.StringParameter(self, "DBPassParam",
|
||||
ssm.StringParameter(
|
||||
self,
|
||||
"DBPassParam",
|
||||
parameter_name="/iptv-updater/DB_PASSWORD",
|
||||
string_value=db.secret.secret_value_from_json("password").to_string()
|
||||
string_value=db.secret.secret_value_from_json("password").to_string(),
|
||||
)
|
||||
|
||||
# Add SSM read permissions to instance role
|
||||
role.add_managed_policy(
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name(
|
||||
"AmazonSSMReadOnlyAccess"
|
||||
)
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
|
||||
)
|
||||
|
||||
# Update instance with userdata
|
||||
@@ -293,6 +293,8 @@ class IptvUpdaterStack(Stack):
|
||||
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
|
||||
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
|
||||
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
|
||||
CfnOutput(self, "CognitoDomainUrl",
|
||||
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com"
|
||||
)
|
||||
CfnOutput(
|
||||
self,
|
||||
"CognitoDomainUrl",
|
||||
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
|
||||
)
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
exclude = [
|
||||
"alembic/versions/*.py", # Auto-generated Alembic migration files
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"F", # pyflakes
|
||||
@@ -9,7 +14,13 @@ select = [
|
||||
]
|
||||
ignore = []
|
||||
|
||||
[tool.ruff.isort]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"F811", # redefinition of unused name
|
||||
"F401", # unused import
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["app"]
|
||||
|
||||
[tool.ruff.format]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# Test constants
|
||||
@@ -7,12 +8,15 @@ TEST_CLIENT_ID = "test_client_id"
|
||||
TEST_CLIENT_SECRET = "test_client_secret"
|
||||
|
||||
# Patch constants before importing the module
|
||||
with patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), \
|
||||
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET):
|
||||
from app.auth.cognito import initiate_auth, get_user_from_token
|
||||
with (
|
||||
patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
|
||||
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
|
||||
):
|
||||
from app.auth.cognito import get_user_from_token, initiate_auth
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.constants import USER_ROLE_ATTRIBUTE
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_cognito_client():
|
||||
with patch("app.auth.cognito.cognito_client") as mock_client:
|
||||
@@ -26,13 +30,14 @@ def mock_cognito_client():
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_initiate_auth_success(mock_cognito_client):
|
||||
# Mock successful authentication response
|
||||
mock_cognito_client.initiate_auth.return_value = {
|
||||
"AuthenticationResult": {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token"
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,104 +45,125 @@ def test_initiate_auth_success(mock_cognito_client):
|
||||
assert result == {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token"
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
def test_initiate_auth_with_secret_hash(mock_cognito_client):
|
||||
with patch("app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash") as mock_hash:
|
||||
with patch(
|
||||
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
|
||||
) as mock_hash:
|
||||
mock_cognito_client.initiate_auth.return_value = {
|
||||
"AuthenticationResult": {"AccessToken": "token"}
|
||||
}
|
||||
|
||||
result = initiate_auth("test_user", "test_pass")
|
||||
|
||||
|
||||
initiate_auth("test_user", "test_pass")
|
||||
|
||||
# Verify calculate_secret_hash was called
|
||||
mock_hash.assert_called_once_with("test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET)
|
||||
|
||||
mock_hash.assert_called_once_with(
|
||||
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
|
||||
)
|
||||
|
||||
# Verify SECRET_HASH was included in auth params
|
||||
call_args = mock_cognito_client.initiate_auth.call_args[1]
|
||||
assert "SECRET_HASH" in call_args["AuthParameters"]
|
||||
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
|
||||
|
||||
|
||||
def test_initiate_auth_not_authorized(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
|
||||
mock_cognito_client.initiate_auth.side_effect = (
|
||||
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("invalid_user", "wrong_pass")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid username or password"
|
||||
|
||||
|
||||
def test_initiate_auth_user_not_found(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.UserNotFoundException()
|
||||
|
||||
mock_cognito_client.initiate_auth.side_effect = (
|
||||
mock_cognito_client.exceptions.UserNotFoundException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("nonexistent_user", "any_pass")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert exc_info.value.detail == "User not found"
|
||||
|
||||
|
||||
def test_initiate_auth_generic_error(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("test_user", "test_pass")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "An error occurred during authentication" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_user_from_token_success(mock_cognito_client):
|
||||
mock_response = {
|
||||
"Username": "test_user",
|
||||
"UserAttributes": [
|
||||
{"Name": "sub", "Value": "123"},
|
||||
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"}
|
||||
]
|
||||
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
|
||||
],
|
||||
}
|
||||
mock_cognito_client.get_user.return_value = mock_response
|
||||
|
||||
|
||||
result = get_user_from_token("valid_token")
|
||||
|
||||
|
||||
assert isinstance(result, CognitoUser)
|
||||
assert result.username == "test_user"
|
||||
assert set(result.roles) == {"admin", "user"}
|
||||
|
||||
|
||||
def test_get_user_from_token_no_roles(mock_cognito_client):
|
||||
mock_response = {
|
||||
"Username": "test_user",
|
||||
"UserAttributes": [{"Name": "sub", "Value": "123"}]
|
||||
"UserAttributes": [{"Name": "sub", "Value": "123"}],
|
||||
}
|
||||
mock_cognito_client.get_user.return_value = mock_response
|
||||
|
||||
|
||||
result = get_user_from_token("valid_token")
|
||||
|
||||
|
||||
assert isinstance(result, CognitoUser)
|
||||
assert result.username == "test_user"
|
||||
assert result.roles == []
|
||||
|
||||
|
||||
def test_get_user_from_token_invalid_token(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
|
||||
mock_cognito_client.get_user.side_effect = (
|
||||
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("invalid_token")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid or expired token."
|
||||
|
||||
|
||||
def test_get_user_from_token_user_not_found(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.UserNotFoundException()
|
||||
|
||||
mock_cognito_client.get_user.side_effect = (
|
||||
mock_cognito_client.exceptions.UserNotFoundException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("token_for_nonexistent_user")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "User not found or invalid token."
|
||||
|
||||
|
||||
def test_get_user_from_token_generic_error(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = Exception("Some error")
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("test_token")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Token verification failed" in exc_info.value.detail
|
||||
assert "Token verification failed" in exc_info.value.detail
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi import HTTPException, Depends, Request
|
||||
from app.auth.dependencies import get_current_user, require_roles, oauth2_scheme
|
||||
|
||||
from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
# Mock user for testing
|
||||
@@ -11,24 +13,30 @@ TEST_USER = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "user"],
|
||||
groups=["test_group"]
|
||||
groups=["test_group"],
|
||||
)
|
||||
|
||||
|
||||
# Mock the underlying get_user_from_token function
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
if token == "valid_token":
|
||||
return TEST_USER
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
|
||||
# Mock endpoint for testing the require_roles decorator
|
||||
@require_roles("admin")
|
||||
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
return {"message": "Success", "user": user.username}
|
||||
|
||||
|
||||
# Patch the get_user_from_token function for testing
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(monkeypatch):
|
||||
monkeypatch.setattr("app.auth.dependencies.get_user_from_token", mock_get_user_from_token)
|
||||
monkeypatch.setattr(
|
||||
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
|
||||
)
|
||||
|
||||
|
||||
# Test get_current_user dependency
|
||||
def test_get_current_user_success():
|
||||
@@ -37,54 +45,53 @@ def test_get_current_user_success():
|
||||
assert user.username == "testuser"
|
||||
assert user.roles == ["admin", "user"]
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
get_current_user("invalid_token")
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
# Test require_roles decorator
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_success():
|
||||
# Create test user with required role
|
||||
user = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin"],
|
||||
groups=[]
|
||||
username="testuser", email="test@example.com", roles=["admin"], groups=[]
|
||||
)
|
||||
|
||||
|
||||
result = await mock_protected_endpoint(user=user)
|
||||
assert result == {"message": "Success", "user": "testuser"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_missing_role():
|
||||
# Create test user without required role
|
||||
user = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["user"],
|
||||
groups=[]
|
||||
username="testuser", email="test@example.com", roles=["user"], groups=[]
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_protected_endpoint(user=user)
|
||||
assert exc.value.status_code == 403
|
||||
assert exc.value.detail == "You do not have the required roles to access this endpoint."
|
||||
assert (
|
||||
exc.value.detail
|
||||
== "You do not have the required roles to access this endpoint."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_no_roles():
|
||||
# Create test user with no roles
|
||||
user = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=[],
|
||||
groups=[]
|
||||
username="testuser", email="test@example.com", roles=[], groups=[]
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_protected_endpoint(user=user)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_multiple_roles():
|
||||
# Test requiring multiple roles
|
||||
@@ -97,7 +104,7 @@ async def test_require_roles_multiple_roles():
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "super_user", "user"],
|
||||
groups=[]
|
||||
groups=[],
|
||||
)
|
||||
result = await mock_multi_role_endpoint(user=user_with_roles)
|
||||
assert result == {"message": "Success"}
|
||||
@@ -107,56 +114,62 @@ async def test_require_roles_multiple_roles():
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "user"],
|
||||
groups=[]
|
||||
groups=[],
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_multi_role_endpoint(user=user_missing_role)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth2_scheme_configuration():
|
||||
# Verify that we have a properly configured OAuth2PasswordBearer instance
|
||||
assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
|
||||
|
||||
|
||||
# Create a mock request with no Authorization header
|
||||
mock_request = Request(scope={
|
||||
'type': 'http',
|
||||
'headers': [],
|
||||
'method': 'GET',
|
||||
'scheme': 'http',
|
||||
'path': '/',
|
||||
'query_string': b'',
|
||||
'client': ('127.0.0.1', 8000)
|
||||
})
|
||||
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"headers": [],
|
||||
"method": "GET",
|
||||
"scheme": "http",
|
||||
"path": "/",
|
||||
"query_string": b"",
|
||||
"client": ("127.0.0.1", 8000),
|
||||
}
|
||||
)
|
||||
|
||||
# Test that the scheme raises 401 when no token is provided
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await oauth2_scheme(mock_request)
|
||||
assert exc.value.status_code == 401
|
||||
assert exc.value.detail == "Not authenticated"
|
||||
|
||||
|
||||
def test_mock_auth_import(monkeypatch):
|
||||
# Save original env var value
|
||||
original_value = os.environ.get("MOCK_AUTH")
|
||||
|
||||
|
||||
try:
|
||||
# Set MOCK_AUTH to true
|
||||
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||
|
||||
|
||||
# Reload the dependencies module to trigger the import condition
|
||||
import app.auth.dependencies
|
||||
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
|
||||
# Verify that mock_get_user_from_token was imported
|
||||
from app.auth.dependencies import get_user_from_token
|
||||
assert get_user_from_token.__module__ == 'app.auth.mock_auth'
|
||||
|
||||
|
||||
assert get_user_from_token.__module__ == "app.auth.mock_auth"
|
||||
|
||||
finally:
|
||||
# Restore original env var
|
||||
if original_value is None:
|
||||
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||
|
||||
|
||||
# Reload again to restore original state
|
||||
importlib.reload(app.auth.dependencies)
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
|
||||
def test_mock_get_user_from_token_success():
|
||||
"""Test successful token validation returns expected user"""
|
||||
user = mock_get_user_from_token("testuser")
|
||||
@@ -10,27 +12,30 @@ def test_mock_get_user_from_token_success():
|
||||
assert user.username == "testuser"
|
||||
assert user.roles == ["admin"]
|
||||
|
||||
|
||||
def test_mock_get_user_from_token_invalid():
|
||||
"""Test invalid token raises expected exception"""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_get_user_from_token("invalid_token")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
|
||||
|
||||
|
||||
def test_mock_initiate_auth():
|
||||
"""Test mock authentication returns expected token response"""
|
||||
result = mock_initiate_auth("any_user", "any_password")
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["AccessToken"] == "testuser"
|
||||
assert result["ExpiresIn"] == 3600
|
||||
assert result["TokenType"] == "Bearer"
|
||||
|
||||
|
||||
def test_mock_initiate_auth_different_credentials():
|
||||
"""Test mock authentication works with any credentials"""
|
||||
result1 = mock_initiate_auth("user1", "pass1")
|
||||
result2 = mock_initiate_auth("user2", "pass2")
|
||||
|
||||
|
||||
# Should return same mock token regardless of credentials
|
||||
assert result1 == result2
|
||||
assert result1 == result2
|
||||
|
||||
@@ -1,34 +1,35 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_auth():
|
||||
return {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token"
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_auth_no_refresh():
|
||||
return {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token"
|
||||
}
|
||||
return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
|
||||
|
||||
|
||||
def test_signin_success(mock_successful_auth):
|
||||
"""Test successful signin with all tokens"""
|
||||
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth):
|
||||
with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser", "password": "testpass"}
|
||||
"/auth/signin", json={"username": "testuser", "password": "testpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
@@ -36,14 +37,16 @@ def test_signin_success(mock_successful_auth):
|
||||
assert data["refresh_token"] == "mock_refresh_token"
|
||||
assert data["token_type"] == "Bearer"
|
||||
|
||||
|
||||
def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
|
||||
"""Test successful signin without refresh token"""
|
||||
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth_no_refresh):
|
||||
with patch(
|
||||
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
|
||||
):
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser", "password": "testpass"}
|
||||
"/auth/signin", json={"username": "testuser", "password": "testpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
@@ -51,57 +54,48 @@ def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
|
||||
assert data["refresh_token"] is None
|
||||
assert data["token_type"] == "Bearer"
|
||||
|
||||
|
||||
def test_signin_invalid_input():
|
||||
"""Test signin with invalid input format"""
|
||||
# Missing password
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser"}
|
||||
)
|
||||
response = client.post("/auth/signin", json={"username": "testuser"})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Missing username
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"password": "testpass"}
|
||||
)
|
||||
response = client.post("/auth/signin", json={"password": "testpass"})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Empty payload
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={}
|
||||
)
|
||||
response = client.post("/auth/signin", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_signin_auth_failure():
|
||||
"""Test signin with authentication failure"""
|
||||
with patch('app.routers.auth.initiate_auth') as mock_auth:
|
||||
with patch("app.routers.auth.initiate_auth") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password"
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser", "password": "wrongpass"}
|
||||
"/auth/signin", json={"username": "testuser", "password": "wrongpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert data["detail"] == "Invalid username or password"
|
||||
|
||||
|
||||
def test_signin_user_not_found():
|
||||
"""Test signin with non-existent user"""
|
||||
with patch('app.routers.auth.initiate_auth') as mock_auth:
|
||||
with patch("app.routers.auth.initiate_auth") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "nonexistent", "password": "testpass"}
|
||||
"/auth/signin", json={"username": "nonexistent", "password": "testpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["detail"] == "User not found"
|
||||
assert data["detail"] == "User not found"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,24 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app, lifespan
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client for FastAPI app"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_root_endpoint(client):
|
||||
"""Test root endpoint returns expected message"""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "IPTV Updater API"}
|
||||
|
||||
|
||||
def test_openapi_schema_generation(client):
|
||||
"""Test OpenAPI schema is properly generated"""
|
||||
# First call - generate schema
|
||||
@@ -23,7 +28,7 @@ def test_openapi_schema_generation(client):
|
||||
assert schema["openapi"] == "3.1.0"
|
||||
assert "securitySchemes" in schema["components"]
|
||||
assert "Bearer" in schema["components"]["securitySchemes"]
|
||||
|
||||
|
||||
# Test empty components initialization
|
||||
with patch("app.main.get_openapi", return_value={"info": {}}):
|
||||
# Clear cached schema
|
||||
@@ -35,26 +40,28 @@ def test_openapi_schema_generation(client):
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
|
||||
|
||||
def test_openapi_schema_caching(mocker):
|
||||
"""Test OpenAPI schema caching behavior"""
|
||||
# Clear any existing schema
|
||||
app.openapi_schema = None
|
||||
|
||||
|
||||
# Mock get_openapi to return test schema
|
||||
mock_schema = {"test": "schema"}
|
||||
mocker.patch("app.main.get_openapi", return_value=mock_schema)
|
||||
|
||||
|
||||
# First call - should call get_openapi
|
||||
schema = app.openapi()
|
||||
assert schema == mock_schema
|
||||
assert app.openapi_schema == mock_schema
|
||||
|
||||
|
||||
# Second call - should return cached schema
|
||||
with patch("app.main.get_openapi") as mock_get_openapi:
|
||||
schema = app.openapi()
|
||||
assert schema == mock_schema
|
||||
mock_get_openapi.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_init_db(mocker):
|
||||
"""Test lifespan manager initializes database"""
|
||||
@@ -63,6 +70,7 @@ async def test_lifespan_init_db(mocker):
|
||||
pass # Just enter/exit context
|
||||
mock_init_db.assert_called_once()
|
||||
|
||||
|
||||
def test_router_inclusion():
|
||||
"""Test all routers are properly included"""
|
||||
route_paths = {route.path for route in app.routes}
|
||||
@@ -70,4 +78,4 @@ def test_router_inclusion():
|
||||
assert any(path.startswith("/auth") for path in route_paths)
|
||||
assert any(path.startswith("/channels") for path in route_paths)
|
||||
assert any(path.startswith("/playlist") for path in route_paths)
|
||||
assert any(path.startswith("/priorities") for path in route_paths)
|
||||
assert any(path.startswith("/priorities") for path in route_paths)
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.orm import Session, sessionmaker, declarative_base
|
||||
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import (
|
||||
TEXT,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
UniqueConstraint,
|
||||
create_engine,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Create a mock-specific Base class for testing
|
||||
MockBase = declarative_base()
|
||||
|
||||
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite."""
|
||||
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
@@ -25,12 +38,14 @@ class SQLiteUUID(TypeDecorator):
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
|
||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||
class MockPriority(MockBase):
|
||||
__tablename__ = "priorities"
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class MockChannelDB(MockBase):
|
||||
__tablename__ = "channels"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
@@ -39,32 +54,45 @@ class MockChannelDB(MockBase):
|
||||
group_title = Column(String, nullable=False)
|
||||
tvg_name = Column(String)
|
||||
__table_args__ = (
|
||||
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
||||
)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class MockChannelURL(MockBase):
|
||||
__tablename__ = "channels_urls"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||
channel_id = Column(
|
||||
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# Create test engine
|
||||
engine_mock = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# Create test session
|
||||
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
|
||||
|
||||
|
||||
# Mock the actual database functions
|
||||
def mock_get_db():
|
||||
db = session_mock()
|
||||
@@ -73,6 +101,7 @@ def mock_get_db():
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Fixture for mocking environment variables"""
|
||||
@@ -82,14 +111,13 @@ def mock_env(monkeypatch):
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_NAME", "testdb")
|
||||
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssm():
|
||||
"""Fixture for mocking boto3 SSM client"""
|
||||
with patch('boto3.client') as mock_client:
|
||||
with patch("boto3.client") as mock_client:
|
||||
mock_ssm = MagicMock()
|
||||
mock_client.return_value = mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {
|
||||
'Parameter': {'Value': 'mocked_value'}
|
||||
}
|
||||
yield mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
|
||||
yield mock_ssm
|
||||
|
||||
@@ -1,43 +1,45 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from sqlalchemy.orm import Session
|
||||
from app.utils.database import get_db_credentials, get_db
|
||||
from tests.utils.db_mocks import (
|
||||
session_mock,
|
||||
mock_get_db,
|
||||
mock_env,
|
||||
mock_ssm
|
||||
)
|
||||
|
||||
from app.utils.database import get_db, get_db_credentials
|
||||
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
|
||||
|
||||
|
||||
def test_get_db_credentials_env(mock_env):
|
||||
"""Test getting DB credentials from environment variables"""
|
||||
conn_str = get_db_credentials()
|
||||
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
|
||||
|
||||
|
||||
def test_get_db_credentials_ssm(mock_ssm):
|
||||
"""Test getting DB credentials from SSM"""
|
||||
os.environ.pop("MOCK_AUTH", None)
|
||||
conn_str = get_db_credentials()
|
||||
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
|
||||
expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
|
||||
assert expected_conn in conn_str
|
||||
mock_ssm.get_parameter.assert_called()
|
||||
|
||||
|
||||
def test_get_db_credentials_ssm_exception(mock_ssm):
|
||||
"""Test SSM credential fetching failure raises RuntimeError"""
|
||||
os.environ.pop("MOCK_AUTH", None)
|
||||
mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_db_credentials()
|
||||
|
||||
|
||||
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_session_creation():
|
||||
"""Test database session creation"""
|
||||
session = session_mock()
|
||||
assert isinstance(session, Session)
|
||||
session.close()
|
||||
|
||||
|
||||
def test_get_db_generator():
|
||||
"""Test get_db dependency generator"""
|
||||
db_gen = get_db()
|
||||
@@ -48,18 +50,20 @@ def test_get_db_generator():
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
def test_init_db(mocker, mock_env):
|
||||
"""Test database initialization creates tables"""
|
||||
mock_create_all = mocker.patch('app.models.Base.metadata.create_all')
|
||||
|
||||
mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
|
||||
|
||||
# Mock get_db_credentials to return SQLite test connection
|
||||
mocker.patch(
|
||||
'app.utils.database.get_db_credentials',
|
||||
return_value="sqlite:///:memory:"
|
||||
"app.utils.database.get_db_credentials",
|
||||
return_value="sqlite:///:memory:",
|
||||
)
|
||||
|
||||
from app.utils.database import init_db, engine
|
||||
|
||||
from app.utils.database import engine, init_db
|
||||
|
||||
init_db()
|
||||
|
||||
|
||||
# Verify create_all was called with the engine
|
||||
mock_create_all.assert_called_once_with(bind=engine)
|
||||
mock_create_all.assert_called_once_with(bind=engine)
|
||||
|
||||
Reference in New Issue
Block a user