Linted and formatted all files

This commit is contained in:
2025-05-28 21:52:39 -05:00
parent e46f13930d
commit 02913c7385
31 changed files with 1264 additions and 766 deletions

View File

@@ -39,6 +39,7 @@
"dflogo",
"dmlogo",
"dotenv",
"EXTINF",
"fastapi",
"filterwarnings",
"fiorinis",
@@ -47,6 +48,7 @@
"gitea",
"iptv",
"isort",
"KHTML",
"lclogo",
"LETSENCRYPT",
"nohup",
@@ -76,6 +78,7 @@
"testpaths",
"uflogo",
"umlogo",
"usefixtures",
"uvicorn",
"venv",
"wrongpass"

View File

@@ -1,12 +1,10 @@
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy import engine_from_config, pool
from alembic import context
from app.utils.database import get_db_credentials
from app.models.db import Base
from app.utils.database import get_db_credentials
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -22,7 +20,7 @@ target_metadata = Base.metadata
# Override sqlalchemy.url with dynamic credentials
if not context.is_offline_mode():
config.set_main_option('sqlalchemy.url', get_db_credentials())
config.set_main_option("sqlalchemy.url", get_db_credentials())
# other values from the config, defined by the needs of env.py,
# can be acquired:
@@ -68,9 +66,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()

16
app.py
View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python3
import os
import aws_cdk as cdk
from infrastructure.stack import IptvUpdaterStack
app = cdk.App()
@@ -19,21 +21,25 @@ required_vars = {
"DOMAIN_NAME": domain_name,
"SSH_PUBLIC_KEY": ssh_public_key,
"REPO_URL": repo_url,
"LETSENCRYPT_EMAIL": letsencrypt_email
"LETSENCRYPT_EMAIL": letsencrypt_email,
}
# Check for missing required variables
missing_vars = [k for k, v in required_vars.items() if not v]
if missing_vars:
raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
raise ValueError(
f"Missing required environment variables: {', '.join(missing_vars)}"
)
IptvUpdaterStack(app, "IptvUpdaterStack",
IptvUpdaterStack(
app,
"IptvUpdaterStack",
freedns_user=freedns_user,
freedns_password=freedns_password,
domain_name=domain_name,
ssh_public_key=ssh_public_key,
repo_url=repo_url,
letsencrypt_email=letsencrypt_email
letsencrypt_email=letsencrypt_email,
)
app.synth()
app.synth()

View File

@@ -1,9 +1,14 @@
import boto3
from fastapi import HTTPException, status
from app.models.auth import CognitoUser
from app.utils.auth import calculate_secret_hash
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE)
from app.utils.constants import (
AWS_REGION,
COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET,
USER_ROLE_ATTRIBUTE,
)
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
"""
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
"""
auth_params = {
"USERNAME": username,
"PASSWORD": password
}
auth_params = {"USERNAME": username, "PASSWORD": password}
# If a client secret is required, add SECRET_HASH
if COGNITO_CLIENT_SECRET:
auth_params["SECRET_HASH"] = calculate_secret_hash(
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET)
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
)
try:
response = cognito_client.initiate_auth(
AuthFlow="USER_PASSWORD_AUTH",
AuthParameters=auth_params,
ClientId=COGNITO_CLIENT_ID
ClientId=COGNITO_CLIENT_ID,
)
return response["AuthenticationResult"]
except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password"
detail="Invalid username or password",
)
except cognito_client.exceptions.UserNotFoundException:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An error occurred during authentication: {str(e)}"
detail=f"An error occurred during authentication: {str(e)}",
)
def get_user_from_token(access_token: str) -> CognitoUser:
"""
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles.
Verify the token by calling GetUser in Cognito and
retrieve user attributes including roles.
"""
try:
user_response = cognito_client.get_user(AccessToken=access_token)
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
for attr in attributes:
if attr["Name"] == USER_ROLE_ATTRIBUTE:
# Assume roles are stored as a comma-separated string
user_roles = [r.strip()
for r in attr["Value"].split(",") if r.strip()]
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
break
return CognitoUser(username=username, roles=user_roles)
except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token."
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
)
except cognito_client.exceptions.UserNotFoundException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or invalid token."
detail="User not found or invalid token.",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token verification failed: {str(e)}"
detail=f"Token verification failed: {str(e)}",
)

View File

@@ -1,6 +1,6 @@
import os
from functools import wraps
from typing import Callable
import os
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
else:
from app.auth.cognito import get_user_from_token
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="signin",
scheme_name="Bearer"
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
"""
@@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable:
if not needed_roles.issubset(user_roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have the required roles to access this endpoint.",
detail=(
"You do not have the required roles to access this endpoint."
),
)
return endpoint(*args, user=user, **kwargs)

View File

@@ -1,12 +1,9 @@
from fastapi import HTTPException, status
from app.models.auth import CognitoUser
MOCK_USERS = {
"testuser": {
"username": "testuser",
"roles": ["admin"]
}
}
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
def mock_get_user_from_token(token: str) -> CognitoUser:
"""
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
return CognitoUser(**MOCK_USERS["testuser"])
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid mock token - use 'testuser'"
detail="Invalid mock token - use 'testuser'",
)
def mock_initiate_auth(username: str, password: str) -> dict:
"""
Mock version of initiate_auth for local testing
Accepts any username/password and returns a mock token
"""
return {
"AccessToken": "testuser",
"ExpiresIn": 3600,
"TokenType": "Bearer"
}
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}

View File

@@ -1,39 +1,59 @@
import os
import re
import argparse
import gzip
import json
import os
import re
import xml.etree.ElementTree as ET
import requests
import argparse
from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
from utils.constants import (
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments():
parser = argparse.ArgumentParser(description='EPG Grabber')
parser.add_argument('--playlist',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
help='Path to playlist file')
parser.add_argument('--output',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'),
help='Path to output EPG XML file')
parser.add_argument('--epg-sources',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'),
help='Path to EPG sources JSON configuration file')
parser.add_argument('--save-as-gz',
action='store_true',
default=True,
help='Save an additional gzipped version of the EPG file')
parser = argparse.ArgumentParser(description="EPG Grabber")
parser.add_argument(
"--playlist",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
),
help="Path to playlist file",
)
parser.add_argument(
"--output",
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
help="Path to output EPG XML file",
)
parser.add_argument(
"--epg-sources",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
),
help="Path to EPG sources JSON configuration file",
)
parser.add_argument(
"--save-as-gz",
action="store_true",
default=True,
help="Save an additional gzipped version of the EPG file",
)
return parser.parse_args()
def load_epg_sources(config_path):
"""Load EPG sources from JSON configuration file"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, encoding="utf-8") as f:
config = json.load(f)
return config.get('epg_sources', [])
return config.get("epg_sources", [])
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading EPG sources: {e}")
return []
def get_tvg_ids(playlist_path):
"""
Extracts unique tvg-id values from an M3U playlist file.
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
# and ends with a double quote.
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
with open(playlist_path, 'r', encoding='utf-8') as file:
with open(playlist_path, encoding="utf-8") as file:
for line in file:
if line.startswith('#EXTINF'):
if line.startswith("#EXTINF"):
# Search for the tvg-id pattern in the line
match = tvg_id_pattern.search(line)
if match:
# Extract the captured group (the value inside the quotes)
tvg_id = match.group(1)
if tvg_id: # Ensure the extracted id is not empty
if tvg_id: # Ensure the extracted id is not empty
unique_tvg_ids.add(tvg_id)
return list(unique_tvg_ids)
def fetch_and_extract_xml(url):
response = requests.get(url)
if response.status_code != 200:
print(f"Failed to fetch {url}")
return None
if url.endswith('.gz'):
if url.endswith(".gz"):
try:
decompressed_data = gzip.decompress(response.content)
return ET.fromstring(decompressed_data)
@@ -84,44 +105,48 @@ def fetch_and_extract_xml(url):
print(f"Failed to parse XML from {url}: {e}")
return None
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
root = ET.Element('tv')
root = ET.Element("tv")
for url in urls:
epg_data = fetch_and_extract_xml(url)
if epg_data is None:
continue
for channel in epg_data.findall('channel'):
tvg_id = channel.get('id')
for channel in epg_data.findall("channel"):
tvg_id = channel.get("id")
if tvg_id in tvg_ids:
root.append(channel)
for programme in epg_data.findall('programme'):
tvg_id = programme.get('channel')
for programme in epg_data.findall("programme"):
tvg_id = programme.get("channel")
if tvg_id in tvg_ids:
root.append(programme)
tree = ET.ElementTree(root)
tree.write(output_file, encoding='utf-8', xml_declaration=True)
tree.write(output_file, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file}")
if save_as_gz:
output_file_gz = output_file + '.gz'
with gzip.open(output_file_gz, 'wb') as f:
tree.write(f, encoding='utf-8', xml_declaration=True)
output_file_gz = output_file + ".gz"
with gzip.open(output_file_gz, "wb") as f:
tree.write(f, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file_gz}")
def upload_epg(file_path):
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
try:
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
response = requests.post(
IPTV_SERVER_URL + '/admin/epg',
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
files={'file': (os.path.basename(file_path), f)}
IPTV_SERVER_URL + "/admin/epg",
auth=requests.auth.HTTPBasicAuth(
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
)
if response.status_code == 200:
print("EPG successfully uploaded to server")
else:
@@ -129,6 +154,7 @@ def upload_epg(file_path):
except Exception as e:
print(f"Upload error: {str(e)}")
if __name__ == "__main__":
args = parse_arguments()
playlist_file = args.playlist
@@ -144,4 +170,4 @@ if __name__ == "__main__":
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
if args.save_as_gz:
upload_epg(output_file + '.gz')
upload_epg(output_file + ".gz")

View File

@@ -1,26 +1,45 @@
import os
import argparse
import json
import logging
import requests
from pathlib import Path
import os
from datetime import datetime
import requests
from utils.check_streams import StreamValidator
from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
from utils.constants import (
EPG_URL,
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments():
parser = argparse.ArgumentParser(description='IPTV playlist generator')
parser.add_argument('--output',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
help='Path to output playlist file')
parser.add_argument('--channels',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'),
help='Path to channels definition JSON file')
parser.add_argument('--dead-channels-log',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'),
help='Path to log file to store a list of dead channels')
parser = argparse.ArgumentParser(description="IPTV playlist generator")
parser.add_argument(
"--output",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
),
help="Path to output playlist file",
)
parser.add_argument(
"--channels",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "channels.json"
),
help="Path to channels definition JSON file",
)
parser.add_argument(
"--dead-channels-log",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
),
help="Path to log file to store a list of dead channels",
)
return parser.parse_args()
def find_working_stream(validator, urls):
"""Test all URLs and return the first working one"""
for url in urls:
@@ -29,48 +48,55 @@ def find_working_stream(validator, urls):
return url
return None
def create_playlist(channels_file, output_file):
# Read channels from JSON file
with open(channels_file, 'r', encoding='utf-8') as f:
with open(channels_file, encoding="utf-8") as f:
channels = json.load(f)
# Initialize validator
validator = StreamValidator(timeout=45)
# Prepare M3U8 header
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
for channel in channels:
if 'urls' in channel: # Check if channel has URLs
if "urls" in channel: # Check if channel has URLs
# Find first working stream
working_url = find_working_stream(validator, channel['urls'])
working_url = find_working_stream(validator, channel["urls"])
if working_url:
# Add channel to playlist
m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" '
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
m3u8_content += f'{channel.get("name", "")}\n'
m3u8_content += f'{working_url}\n'
m3u8_content += f"{channel.get('name', '')}\n"
m3u8_content += f"{working_url}\n"
else:
# Log dead channel
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found')
logging.info(
f"Dead channel: {channel.get('name', 'Unknown')} - "
"No working streams found"
)
# Write playlist file
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write(m3u8_content)
def upload_playlist(file_path):
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
try:
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
response = requests.post(
IPTV_SERVER_URL + '/admin/playlist',
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
files={'file': (os.path.basename(file_path), f)}
IPTV_SERVER_URL + "/admin/playlist",
auth=requests.auth.HTTPBasicAuth(
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
)
if response.status_code == 200:
print("Playlist successfully uploaded to server")
else:
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
except Exception as e:
print(f"Upload error: {str(e)}")
def main():
args = parse_arguments()
channels_file = args.channels
@@ -85,24 +112,25 @@ def main():
dead_channels_log_file = args.dead_channels_log
# Clear previous log file
with open(dead_channels_log_file, 'w') as f:
f.write(f'Log created on {datetime.now()}\n')
with open(dead_channels_log_file, "w") as f:
f.write(f"Log created on {datetime.now()}\n")
# Configure logging
logging.basicConfig(
filename=dead_channels_log_file,
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Create playlist
create_playlist(channels_file, output_file)
#upload playlist to server
# upload playlist to server
upload_playlist(output_file)
print("Playlist creation completed!")
if __name__ == "__main__":
main()
main()

View File

@@ -1,17 +1,18 @@
from fastapi.concurrency import asynccontextmanager
from app.routers import channels, auth, playlist, priorities
from fastapi import FastAPI
from fastapi.concurrency import asynccontextmanager
from fastapi.openapi.utils import get_openapi
from app.routers import auth, channels, playlist, priorities
from app.utils.database import init_db
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize database tables on startup
init_db()
yield
app = FastAPI(
lifespan=lifespan,
title="IPTV Updater API",
@@ -19,6 +20,7 @@ app = FastAPI(
version="1.0.0",
)
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
@@ -40,11 +42,7 @@ def custom_openapi():
# Add security scheme component
openapi_schema["components"]["securitySchemes"] = {
"Bearer": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT"
}
"Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
}
# Add global security requirement
@@ -56,14 +54,17 @@ def custom_openapi():
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
@app.get("/")
async def root():
return {"message": "IPTV Updater API"}
# Include routers
app.include_router(auth.router)
app.include_router(channels.router)
app.include_router(playlist.router)
app.include_router(priorities.router)
app.include_router(priorities.router)

View File

@@ -1,4 +1,19 @@
from .db import Base, ChannelDB, ChannelURL
from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse
from .schemas import (
ChannelCreate,
ChannelResponse,
ChannelUpdate,
ChannelURLCreate,
ChannelURLResponse,
)
__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"]
__all__ = [
"Base",
"ChannelDB",
"ChannelCreate",
"ChannelUpdate",
"ChannelResponse",
"ChannelURL",
"ChannelURLCreate",
"ChannelURLResponse",
]

View File

@@ -1,20 +1,26 @@
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field
class SigninRequest(BaseModel):
"""Request model for the signin endpoint."""
username: str = Field(..., description="The user's username")
password: str = Field(..., description="The user's password")
class TokenResponse(BaseModel):
"""Response model for successful authentication."""
access_token: str = Field(..., description="Access JWT token from Cognito")
id_token: str = Field(..., description="ID JWT token from Cognito")
refresh_token: Optional[str] = Field(
None, description="Refresh token from Cognito")
refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
token_type: str = Field(..., description="Type of the token returned")
class CognitoUser(BaseModel):
"""Model representing the user returned from token verification."""
username: str
roles: List[str]
roles: list[str]

View File

@@ -1,21 +1,33 @@
from datetime import datetime, timezone
import uuid
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from datetime import datetime, timezone
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class Priority(Base):
"""SQLAlchemy model for channel URL priorities"""
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class ChannelDB(Base):
"""SQLAlchemy model for IPTV channels"""
__tablename__ = "channels"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -25,27 +37,43 @@ class ChannelDB(Base):
tvg_name = Column(String)
__table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationship with ChannelURL
urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan")
urls = relationship(
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
class ChannelURL(Base):
"""SQLAlchemy model for channel URLs"""
__tablename__ = "channels_urls"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
channel_id = Column(
UUID(as_uuid=True),
ForeignKey("channels.id", ondelete="CASCADE"),
nullable=False,
)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
channel = relationship("ChannelDB", back_populates="urls")
priority = relationship("Priority")
priority = relationship("Priority")

View File

@@ -1,30 +1,43 @@
from datetime import datetime
from typing import List, Optional
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class PriorityBase(BaseModel):
"""Base Pydantic model for priorities"""
id: int
description: str
model_config = ConfigDict(from_attributes=True)
class PriorityCreate(PriorityBase):
"""Pydantic model for creating priorities"""
pass
class PriorityResponse(PriorityBase):
"""Pydantic model for priority responses"""
pass
class ChannelURLCreate(BaseModel):
"""Pydantic model for creating channel URLs"""
url: str
priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range
priority_id: int = Field(
default=100, ge=100, le=300
) # Default to High, validate range
class ChannelURLBase(ChannelURLCreate):
"""Base Pydantic model for channel URL responses"""
id: UUID
in_use: bool
created_at: datetime
@@ -33,43 +46,53 @@ class ChannelURLBase(ChannelURLCreate):
model_config = ConfigDict(from_attributes=True)
class ChannelURLResponse(ChannelURLBase):
"""Pydantic model for channel URL responses"""
pass
class ChannelCreate(BaseModel):
"""Pydantic model for creating channels"""
urls: List[ChannelURLCreate] # List of URL objects with priority
urls: list[ChannelURLCreate] # List of URL objects with priority
name: str
group_title: str
tvg_id: str
tvg_logo: str
tvg_name: str
class ChannelURLUpdate(BaseModel):
"""Pydantic model for updating channel URLs"""
url: Optional[str] = None
in_use: Optional[bool] = None
priority_id: Optional[int] = Field(default=None, ge=100, le=300)
class ChannelUpdate(BaseModel):
"""Pydantic model for updating channels (all fields optional)"""
name: Optional[str] = Field(None, min_length=1)
group_title: Optional[str] = Field(None, min_length=1)
tvg_id: Optional[str] = Field(None, min_length=1)
tvg_logo: Optional[str] = None
tvg_name: Optional[str] = Field(None, min_length=1)
class ChannelResponse(BaseModel):
"""Pydantic model for channel responses"""
id: UUID
name: str
group_title: str
tvg_id: str
tvg_logo: str
tvg_name: str
urls: List[ChannelURLResponse] # List of URL objects without channel_id
urls: list[ChannelURLResponse] # List of URL objects without channel_id
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View File

@@ -1,16 +1,16 @@
from fastapi import APIRouter
from app.auth.cognito import initiate_auth
from app.models.auth import SigninRequest, TokenResponse
router = APIRouter(
prefix="/auth",
tags=["authentication"]
)
router = APIRouter(prefix="/auth", tags=["authentication"])
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest):
"""
Sign-in endpoint to authenticate the user with AWS Cognito using username and password.
Sign-in endpoint to authenticate the user with AWS Cognito
using username and password.
On success, returns JWT tokens (access_token, id_token, refresh_token).
"""
auth_result = initiate_auth(credentials.username, credentials.password)
@@ -19,4 +19,4 @@ def signin(credentials: SigninRequest):
id_token=auth_result["IdToken"],
refresh_token=auth_result.get("RefreshToken"),
token_type="Bearer",
)
)

View File

@@ -1,52 +1,54 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from uuid import UUID
from sqlalchemy import and_
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import (
ChannelDB,
ChannelURL,
ChannelCreate,
ChannelUpdate,
ChannelDB,
ChannelResponse,
ChannelUpdate,
ChannelURL,
ChannelURLCreate,
ChannelURLResponse,
)
from app.models.auth import CognitoUser
from app.models.schemas import ChannelURLUpdate
from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter(
prefix="/channels",
tags=["channels"]
)
router = APIRouter(prefix="/channels", tags=["channels"])
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_channel(
channel: ChannelCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Create a new channel"""
# Check for duplicate channel (same group_title + name)
existing_channel = db.query(ChannelDB).filter(
and_(
ChannelDB.group_title == channel.group_title,
ChannelDB.name == channel.name
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_title == channel.group_title,
ChannelDB.name == channel.name,
)
)
).first()
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists"
detail="Channel with same group_title and name already exists",
)
# Create channel without URLs first
channel_data = channel.model_dump(exclude={'urls'})
channel_data = channel.model_dump(exclude={"urls"})
urls = channel.urls
db_channel = ChannelDB(**channel_data)
db.add(db_channel)
@@ -59,130 +61,142 @@ def create_channel(
channel_id=db_channel.id,
url=url.url,
priority_id=url.priority_id,
in_use=False
in_use=False,
)
db.add(db_url)
db.commit()
db.refresh(db_channel)
return db_channel
@router.get("/{channel_id}", response_model=ChannelResponse)
def get_channel(
channel_id: UUID,
db: Session = Depends(get_db)
):
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
"""Get a channel by id"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
return channel
@router.put("/{channel_id}", response_model=ChannelResponse)
@require_roles("admin")
def update_channel(
channel_id: UUID,
channel: ChannelUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel"""
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not db_channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
# Only check for duplicates if name or group_title are being updated
if channel.name is not None or channel.group_title is not None:
name = channel.name if channel.name is not None else db_channel.name
group_title = channel.group_title if channel.group_title is not None else db_channel.group_title
existing_channel = db.query(ChannelDB).filter(
and_(
ChannelDB.group_title == group_title,
ChannelDB.name == name,
ChannelDB.id != channel_id
group_title = (
channel.group_title
if channel.group_title is not None
else db_channel.group_title
)
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_title == group_title,
ChannelDB.name == name,
ChannelDB.id != channel_id,
)
)
).first()
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists"
detail="Channel with same group_title and name already exists",
)
# Update only provided fields
update_data = channel.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_channel, key, value)
db.commit()
db.refresh(db_channel)
return db_channel
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel(
channel_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Delete a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
db.delete(channel)
db.commit()
return None
@router.get("/", response_model=List[ChannelResponse])
@router.get("/", response_model=list[ChannelResponse])
@require_roles("admin")
def list_channels(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""List all channels with pagination"""
return db.query(ChannelDB).offset(skip).limit(limit).all()
# URL Management Endpoints
@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/{channel_id}/urls",
response_model=ChannelURLResponse,
status_code=status.HTTP_201_CREATED,
)
@require_roles("admin")
def add_channel_url(
channel_id: UUID,
url: ChannelURLCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Add a new URL to a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
db_url = ChannelURL(
channel_id=channel_id,
url=url.url,
priority_id=url.priority_id,
in_use=False # Default to not in use
in_use=False, # Default to not in use
)
db.add(db_url)
db.commit()
db.refresh(db_url)
return db_url
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
@require_roles("admin")
def update_channel_url(
@@ -190,72 +204,69 @@ def update_channel_url(
url_id: UUID,
url_update: ChannelURLUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel URL (url, in_use, or priority_id)"""
db_url = db.query(ChannelURL).filter(
and_(
ChannelURL.id == url_id,
ChannelURL.channel_id == channel_id
)
).first()
db_url = (
db.query(ChannelURL)
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
.first()
)
if not db_url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="URL not found"
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
)
if url_update.url is not None:
db_url.url = url_update.url
if url_update.in_use is not None:
db_url.in_use = url_update.in_use
if url_update.priority_id is not None:
db_url.priority_id = url_update.priority_id
db.commit()
db.refresh(db_url)
return db_url
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel_url(
channel_id: UUID,
url_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Delete a URL from a channel"""
url = db.query(ChannelURL).filter(
and_(
ChannelURL.id == url_id,
ChannelURL.channel_id == channel_id
)
).first()
url = (
db.query(ChannelURL)
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
.first()
)
if not url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="URL not found"
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
)
db.delete(url)
db.commit()
return None
@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse])
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
@require_roles("admin")
def list_channel_urls(
channel_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""List all URLs for a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()

View File

@@ -1,17 +1,15 @@
from fastapi import APIRouter, Depends
from app.auth.dependencies import get_current_user
from app.models.auth import CognitoUser
router = APIRouter(
prefix="/playlist",
tags=["playlist"]
)
router = APIRouter(prefix="/playlist", tags=["playlist"])
@router.get("/protected",
summary="Protected endpoint for authenticated users")
@router.get("/protected", summary="Protected endpoint for authenticated users")
async def protected_route(user: CognitoUser = Depends(get_current_user)):
"""
Protected endpoint that requires authentication for all users.
If the user is authenticated, returns success message.
"""
return {"message": f"Hello {user.username}, you have access to support resources!"}
return {"message": f"Hello {user.username}, you have access to support resources!"}

View File

@@ -1,25 +1,22 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from sqlalchemy import select, delete
from typing import List
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
from app.models.db import Priority
from app.models.schemas import PriorityCreate, PriorityResponse
from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter(
prefix="/priorities",
tags=["priorities"]
)
router = APIRouter(prefix="/priorities", tags=["priorities"])
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_priority(
priority: PriorityCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Create a new priority"""
# Check if priority with this ID already exists
@@ -27,71 +24,69 @@ def create_priority(
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Priority with ID {priority.id} already exists"
detail=f"Priority with ID {priority.id} already exists",
)
db_priority = Priority(**priority.model_dump())
db.add(db_priority)
db.commit()
db.refresh(db_priority)
return db_priority
@router.get("/", response_model=List[PriorityResponse])
@router.get("/", response_model=list[PriorityResponse])
@require_roles("admin")
def list_priorities(
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
):
"""List all priorities"""
return db.query(Priority).all()
@router.get("/{priority_id}", response_model=PriorityResponse)
@require_roles("admin")
def get_priority(
priority_id: int,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Get a priority by id"""
priority = db.get(Priority, priority_id)
if not priority:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Priority not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
)
return priority
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_priority(
priority_id: int,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Delete a priority (if not in use)"""
from app.models.db import ChannelURL
# Check if priority exists
priority = db.get(Priority, priority_id)
if not priority:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Priority not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
)
# Check if priority is in use
in_use = db.scalar(
select(ChannelURL)
.where(ChannelURL.priority_id == priority_id)
.limit(1)
select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
)
if in_use:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete priority that is in use by channel URLs"
detail="Cannot delete priority that is in use by channel URLs",
)
db.execute(delete(Priority).where(Priority.id == priority_id))
db.commit()
return None
return None

View File

@@ -2,11 +2,13 @@ import base64
import hashlib
import hmac
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
"""
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
"""
msg = username + client_id
dig = hmac.new(client_secret.encode('utf-8'),
msg.encode('utf-8'), hashlib.sha256).digest()
return base64.b64encode(dig).decode()
dig = hmac.new(
client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
).digest()
return base64.b64encode(dig).decode()

View File

@@ -1,41 +1,50 @@
import os
import argparse
import requests
import logging
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError
import os
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
class StreamValidator:
def __init__(self, timeout=10, user_agent=None):
self.timeout = timeout
self.session = requests.Session()
self.session.headers.update({
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
})
self.session.headers.update(
{
"User-Agent": user_agent
or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
)
}
)
def validate_stream(self, url):
"""Validate a media stream URL with multiple fallback checks"""
try:
headers = {'Range': 'bytes=0-1024'}
headers = {"Range": "bytes=0-1024"}
with self.session.get(
url,
headers=headers,
timeout=self.timeout,
stream=True,
allow_redirects=True
allow_redirects=True,
) as response:
if response.status_code not in [200, 206]:
return False, f"Invalid status code: {response.status_code}"
content_type = response.headers.get('Content-Type', '')
content_type = response.headers.get("Content-Type", "")
if not self._is_valid_content_type(content_type):
return False, f"Invalid content type: {content_type}"
try:
next(response.iter_content(chunk_size=1024))
return True, "Stream is valid"
except (ConnectionError, Timeout):
return False, "Connection failed during content read"
except HTTPError as e:
return False, f"HTTP Error: {str(e)}"
except ConnectionError as e:
@@ -49,10 +58,13 @@ class StreamValidator:
def _is_valid_content_type(self, content_type):
valid_types = [
'video/mp2t', 'application/vnd.apple.mpegurl',
'application/dash+xml', 'video/mp4',
'video/webm', 'application/octet-stream',
'application/x-mpegURL'
"video/mp2t",
"application/vnd.apple.mpegurl",
"application/dash+xml",
"video/mp4",
"video/webm",
"application/octet-stream",
"application/x-mpegURL",
]
return any(ct in content_type for ct in valid_types)
@@ -60,45 +72,43 @@ class StreamValidator:
"""Extract stream URLs from M3U playlist file"""
urls = []
try:
with open(file_path, 'r') as f:
with open(file_path) as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
if line and not line.startswith("#"):
urls.append(line)
except Exception as e:
logging.error(f"Error reading playlist file: {str(e)}")
raise
return urls
def main():
parser = argparse.ArgumentParser(
description='Validate streaming URLs from command line arguments or playlist files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
description=(
"Validate streaming URLs from command line arguments or playlist files"
),
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'sources',
nargs='+',
help='List of URLs or file paths containing stream URLs'
"sources", nargs="+", help="List of URLs or file paths containing stream URLs"
)
parser.add_argument(
'--timeout',
type=int,
default=20,
help='Timeout in seconds for stream checks'
"--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
)
parser.add_argument(
'--output',
default='deadstreams.txt',
help='Output file name for inactive streams'
"--output",
default="deadstreams.txt",
help="Output file name for inactive streams",
)
args = parser.parse_args()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()]
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
)
validator = StreamValidator(timeout=args.timeout)
@@ -127,9 +137,10 @@ def main():
# Save dead streams to file
if dead_streams:
with open(args.output, 'w') as f:
f.write('\n'.join(dead_streams))
with open(args.output, "w") as f:
f.write("\n".join(dead_streams))
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
if __name__ == "__main__":
main()
main()

View File

@@ -21,4 +21,4 @@ IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
# URL for the EPG XML file to place in the playlist's header
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")

View File

@@ -1,11 +1,13 @@
import os
import boto3
from app.models import Base
from .constants import AWS_REGION
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from functools import lru_cache
from app.models import Base
from .constants import AWS_REGION
def get_db_credentials():
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
@@ -14,29 +16,40 @@ def get_db_credentials():
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
)
ssm = boto3.client('ssm', region_name=AWS_REGION)
ssm = boto3.client("ssm", region_name=AWS_REGION)
try:
host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value']
user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value']
password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value']
dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value']
host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
"Parameter"
]["Value"]
user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
"Parameter"
]["Value"]
password = ssm.get_parameter(
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
)["Parameter"]["Value"]
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
"Parameter"
]["Value"]
return f"postgresql://{user}:{password}@{host}/{dbname}"
except Exception as e:
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
# Initialize engine and session maker
engine = create_engine(get_db_credentials())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
"""Initialize database by creating all tables"""
Base.metadata.create_all(bind=engine)
def get_db():
"""Dependency for getting database session"""
db = SessionLocal()
try:
yield db
finally:
db.close()
db.close()

View File

@@ -1,80 +1,69 @@
import os
from aws_cdk import (
Duration,
RemovalPolicy,
Stack,
aws_ec2 as ec2,
aws_iam as iam,
aws_cognito as cognito,
aws_rds as rds,
aws_ssm as ssm,
CfnOutput
)
from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
from aws_cdk import aws_cognito as cognito
from aws_cdk import aws_ec2 as ec2
from aws_cdk import aws_iam as iam
from aws_cdk import aws_rds as rds
from aws_cdk import aws_ssm as ssm
from constructs import Construct
class IptvUpdaterStack(Stack):
def __init__(
self,
scope: Construct,
construct_id: str,
freedns_user: str,
freedns_password: str,
domain_name: str,
ssh_public_key: str,
repo_url: str,
letsencrypt_email: str,
**kwargs
) -> None:
self,
scope: Construct,
construct_id: str,
freedns_user: str,
freedns_password: str,
domain_name: str,
ssh_public_key: str,
repo_url: str,
letsencrypt_email: str,
**kwargs,
) -> None:
super().__init__(scope, construct_id, **kwargs)
# Create VPC
vpc = ec2.Vpc(self, "IptvUpdaterVPC",
vpc = ec2.Vpc(
self,
"IptvUpdaterVPC",
max_azs=2, # Need at least 2 AZs for RDS subnet group
nat_gateways=0, # No NAT Gateway to stay in free tier
subnet_configuration=[
ec2.SubnetConfiguration(
name="public",
subnet_type=ec2.SubnetType.PUBLIC,
cidr_mask=24
name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
),
ec2.SubnetConfiguration(
name="private",
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
cidr_mask=24
)
]
cidr_mask=24,
),
],
)
# Security Group
security_group = ec2.SecurityGroup(
self, "IptvUpdaterSG",
vpc=vpc,
allow_all_outbound=True
self, "IptvUpdaterSG", vpc=vpc, allow_all_outbound=True
)
security_group.add_ingress_rule(
ec2.Peer.any_ipv4(),
ec2.Port.tcp(443),
"Allow HTTPS traffic"
)
security_group.add_ingress_rule(
ec2.Peer.any_ipv4(),
ec2.Port.tcp(80),
"Allow HTTP traffic"
ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
)
security_group.add_ingress_rule(
ec2.Peer.any_ipv4(),
ec2.Port.tcp(22),
"Allow SSH traffic"
ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
)
security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
)
# Allow PostgreSQL port for tunneling restricted to developer IP
security_group.add_ingress_rule(
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
ec2.Port.tcp(5432),
"Allow PostgreSQL traffic for tunneling"
"Allow PostgreSQL traffic for tunneling",
)
# Key pair for IPTV Updater instance
@@ -82,13 +71,14 @@ class IptvUpdaterStack(Stack):
self,
"IptvUpdaterKeyPair",
key_pair_name="iptv-updater-key",
public_key_material=ssh_public_key
public_key_material=ssh_public_key,
)
# Create IAM role for EC2
role = iam.Role(
self, "IptvUpdaterRole",
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com")
self,
"IptvUpdaterRole",
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
)
# Add SSM managed policy
@@ -99,37 +89,36 @@ class IptvUpdaterStack(Stack):
)
# Add EC2 describe permissions
role.add_to_policy(iam.PolicyStatement(
actions=["ec2:DescribeInstances"],
resources=["*"]
))
role.add_to_policy(
iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
)
# Add SSM SendCommand permissions
role.add_to_policy(iam.PolicyStatement(
actions=["ssm:SendCommand"],
resources=[
f"arn:aws:ec2:{self.region}:{self.account}:instance/*", # Allow on all EC2 instances
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript" # Required for the RunShellScript document
]
))
role.add_to_policy(
iam.PolicyStatement(
actions=["ssm:SendCommand"],
resources=[
# Allow on all EC2 instances
f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
# Required for the RunShellScript document
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
],
)
)
# Add Cognito permissions to instance role
role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(
"AmazonCognitoReadOnly"
)
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
)
# EC2 Instance
instance = ec2.Instance(
self, "IptvUpdaterInstance",
self,
"IptvUpdaterInstance",
vpc=vpc,
vpc_subnets=ec2.SubnetSelection(
subnet_type=ec2.SubnetType.PUBLIC
),
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T2,
ec2.InstanceSize.MICRO
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
),
machine_image=ec2.AmazonLinuxImage(
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
@@ -138,7 +127,7 @@ class IptvUpdaterStack(Stack):
key_pair=key_pair,
role=role,
# Option: 1: Enable auto-assign public IP (free tier compatible)
associate_public_ip_address=True
associate_public_ip_address=True,
)
# Option: 2: Create Elastic IP (not free tier compatible)
@@ -150,7 +139,8 @@ class IptvUpdaterStack(Stack):
# Add Cognito User Pool
user_pool = cognito.UserPool(
self, "IptvUpdaterUserPool",
self,
"IptvUpdaterUserPool",
user_pool_name="iptv-updater-users",
self_sign_up_enabled=False, # Only admins can create users
password_policy=cognito.PasswordPolicy(
@@ -158,37 +148,33 @@ class IptvUpdaterStack(Stack):
require_lowercase=True,
require_digits=True,
require_symbols=True,
require_uppercase=True
require_uppercase=True,
),
account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
removal_policy=RemovalPolicy.DESTROY
removal_policy=RemovalPolicy.DESTROY,
)
# Add App Client with the correct callback URL
client = user_pool.add_client("IptvUpdaterClient",
client = user_pool.add_client(
"IptvUpdaterClient",
access_token_validity=Duration.minutes(60),
id_token_validity=Duration.minutes(60),
refresh_token_validity=Duration.days(1),
auth_flows=cognito.AuthFlow(
user_password=True
),
auth_flows=cognito.AuthFlow(user_password=True),
o_auth=cognito.OAuthSettings(
flows=cognito.OAuthFlows(
implicit_code_grant=True
)
flows=cognito.OAuthFlows(implicit_code_grant=True)
),
prevent_user_existence_errors=True,
generate_secret=True,
enable_token_revocation=True
enable_token_revocation=True,
)
# Add domain for hosted UI
domain = user_pool.add_domain("IptvUpdaterDomain",
cognito_domain=cognito.CognitoDomainOptions(
domain_prefix="iptv-updater"
)
domain = user_pool.add_domain(
"IptvUpdaterDomain",
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-updater"),
)
# Read the userdata script with proper path resolution
script_dir = os.path.dirname(os.path.abspath(__file__))
userdata_path = os.path.join(script_dir, "userdata.sh")
@@ -196,46 +182,56 @@ class IptvUpdaterStack(Stack):
# Creates a userdata object for Linux hosts
userdata = ec2.UserData.for_linux()
# Add environment variables for acme.sh from parameters
userdata.add_commands(
f'export FREEDNS_User="{freedns_user}"',
f'export FREEDNS_Password="{freedns_password}"',
f'export DOMAIN_NAME="{domain_name}"',
f'export REPO_URL="{repo_url}"',
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"'
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
)
# Adds one or more commands to the userdata object.
userdata.add_commands(
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment',
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment',
f'echo "COGNITO_CLIENT_SECRET={client.user_pool_client_secret.to_string()}" >> /etc/environment',
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment'
(
f'echo "COGNITO_USER_POOL_ID='
f'{user_pool.user_pool_id}" >> /etc/environment'
),
(
f'echo "COGNITO_CLIENT_ID='
f'{client.user_pool_client_id}" >> /etc/environment'
),
(
f'echo "COGNITO_CLIENT_SECRET='
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
),
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
)
userdata.add_commands(str(userdata_file, 'utf-8'))
userdata.add_commands(str(userdata_file, "utf-8"))
# Create RDS Security Group
rds_sg = ec2.SecurityGroup(
self, "RdsSecurityGroup",
self,
"RdsSecurityGroup",
vpc=vpc,
description="Security group for RDS PostgreSQL"
description="Security group for RDS PostgreSQL",
)
rds_sg.add_ingress_rule(
security_group,
ec2.Port.tcp(5432),
"Allow PostgreSQL access from EC2 instance"
"Allow PostgreSQL access from EC2 instance",
)
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
db = rds.DatabaseInstance(
self, "IptvUpdaterDB",
self,
"IptvUpdaterDB",
engine=rds.DatabaseInstanceEngine.postgres(
version=rds.PostgresEngineVersion.VER_13
),
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T3,
ec2.InstanceSize.MICRO
ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
),
vpc=vpc,
vpc_subnets=ec2.SubnetSelection(
@@ -247,39 +243,43 @@ class IptvUpdaterStack(Stack):
database_name="iptv_updater",
removal_policy=RemovalPolicy.DESTROY,
deletion_protection=False,
publicly_accessible=False # Avoid public IPv4 charges
publicly_accessible=False, # Avoid public IPv4 charges
)
# Add RDS permissions to instance role
role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(
"AmazonRDSFullAccess"
)
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
)
# Store DB connection info in SSM Parameter Store
ssm.StringParameter(self, "DBHostParam",
ssm.StringParameter(
self,
"DBHostParam",
parameter_name="/iptv-updater/DB_HOST",
string_value=db.db_instance_endpoint_address
string_value=db.db_instance_endpoint_address,
)
ssm.StringParameter(self, "DBNameParam",
ssm.StringParameter(
self,
"DBNameParam",
parameter_name="/iptv-updater/DB_NAME",
string_value="iptv_updater"
string_value="iptv_updater",
)
ssm.StringParameter(self, "DBUserParam",
ssm.StringParameter(
self,
"DBUserParam",
parameter_name="/iptv-updater/DB_USER",
string_value=db.secret.secret_value_from_json("username").to_string()
string_value=db.secret.secret_value_from_json("username").to_string(),
)
ssm.StringParameter(self, "DBPassParam",
ssm.StringParameter(
self,
"DBPassParam",
parameter_name="/iptv-updater/DB_PASSWORD",
string_value=db.secret.secret_value_from_json("password").to_string()
string_value=db.secret.secret_value_from_json("password").to_string(),
)
# Add SSM read permissions to instance role
role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(
"AmazonSSMReadOnlyAccess"
)
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
)
# Update instance with userdata
@@ -293,6 +293,8 @@ class IptvUpdaterStack(Stack):
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
CfnOutput(self, "CognitoDomainUrl",
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com"
)
CfnOutput(
self,
"CognitoDomainUrl",
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
)

View File

@@ -1,5 +1,10 @@
[tool.ruff]
line-length = 88
exclude = [
"alembic/versions/*.py", # Auto-generated Alembic migration files
]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
@@ -9,7 +14,13 @@ select = [
]
ignore = []
[tool.ruff.isort]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"F811", # redefinition of unused name
"F401", # unused import
]
[tool.ruff.lint.isort]
known-first-party = ["app"]
[tool.ruff.format]

View File

@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from fastapi import HTTPException, status
# Test constants
@@ -7,12 +8,15 @@ TEST_CLIENT_ID = "test_client_id"
TEST_CLIENT_SECRET = "test_client_secret"
# Patch constants before importing the module
with patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), \
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET):
from app.auth.cognito import initiate_auth, get_user_from_token
with (
patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
):
from app.auth.cognito import get_user_from_token, initiate_auth
from app.models.auth import CognitoUser
from app.utils.constants import USER_ROLE_ATTRIBUTE
@pytest.fixture(autouse=True)
def mock_cognito_client():
with patch("app.auth.cognito.cognito_client") as mock_client:
@@ -26,13 +30,14 @@ def mock_cognito_client():
)
yield mock_client
def test_initiate_auth_success(mock_cognito_client):
# Mock successful authentication response
mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token"
"RefreshToken": "mock_refresh_token",
}
}
@@ -40,104 +45,125 @@ def test_initiate_auth_success(mock_cognito_client):
assert result == {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token"
"RefreshToken": "mock_refresh_token",
}
def test_initiate_auth_with_secret_hash(mock_cognito_client):
with patch("app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash") as mock_hash:
with patch(
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
) as mock_hash:
mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": {"AccessToken": "token"}
}
result = initiate_auth("test_user", "test_pass")
initiate_auth("test_user", "test_pass")
# Verify calculate_secret_hash was called
mock_hash.assert_called_once_with("test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET)
mock_hash.assert_called_once_with(
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
)
# Verify SECRET_HASH was included in auth params
call_args = mock_cognito_client.initiate_auth.call_args[1]
assert "SECRET_HASH" in call_args["AuthParameters"]
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
def test_initiate_auth_not_authorized(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.NotAuthorizedException()
mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info:
initiate_auth("invalid_user", "wrong_pass")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid username or password"
def test_initiate_auth_user_not_found(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.UserNotFoundException()
mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info:
initiate_auth("nonexistent_user", "any_pass")
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
assert exc_info.value.detail == "User not found"
def test_initiate_auth_generic_error(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
with pytest.raises(HTTPException) as exc_info:
initiate_auth("test_user", "test_pass")
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "An error occurred during authentication" in exc_info.value.detail
def test_get_user_from_token_success(mock_cognito_client):
mock_response = {
"Username": "test_user",
"UserAttributes": [
{"Name": "sub", "Value": "123"},
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"}
]
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
],
}
mock_cognito_client.get_user.return_value = mock_response
result = get_user_from_token("valid_token")
assert isinstance(result, CognitoUser)
assert result.username == "test_user"
assert set(result.roles) == {"admin", "user"}
def test_get_user_from_token_no_roles(mock_cognito_client):
mock_response = {
"Username": "test_user",
"UserAttributes": [{"Name": "sub", "Value": "123"}]
"UserAttributes": [{"Name": "sub", "Value": "123"}],
}
mock_cognito_client.get_user.return_value = mock_response
result = get_user_from_token("valid_token")
assert isinstance(result, CognitoUser)
assert result.username == "test_user"
assert result.roles == []
def test_get_user_from_token_invalid_token(mock_cognito_client):
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.NotAuthorizedException()
mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info:
get_user_from_token("invalid_token")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid or expired token."
def test_get_user_from_token_user_not_found(mock_cognito_client):
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.UserNotFoundException()
mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info:
get_user_from_token("token_for_nonexistent_user")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "User not found or invalid token."
def test_get_user_from_token_generic_error(mock_cognito_client):
mock_cognito_client.get_user.side_effect = Exception("Some error")
with pytest.raises(HTTPException) as exc_info:
get_user_from_token("test_token")
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "Token verification failed" in exc_info.value.detail
assert "Token verification failed" in exc_info.value.detail

View File

@@ -1,9 +1,11 @@
import os
import pytest
import importlib
import os
import pytest
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer
from fastapi import HTTPException, Depends, Request
from app.auth.dependencies import get_current_user, require_roles, oauth2_scheme
from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
from app.models.auth import CognitoUser
# Mock user for testing
@@ -11,24 +13,30 @@ TEST_USER = CognitoUser(
username="testuser",
email="test@example.com",
roles=["admin", "user"],
groups=["test_group"]
groups=["test_group"],
)
# Mock the underlying get_user_from_token function
def mock_get_user_from_token(token: str) -> CognitoUser:
if token == "valid_token":
return TEST_USER
raise HTTPException(status_code=401, detail="Invalid token")
# Mock endpoint for testing the require_roles decorator
@require_roles("admin")
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
return {"message": "Success", "user": user.username}
# Patch the get_user_from_token function for testing
@pytest.fixture(autouse=True)
def mock_auth(monkeypatch):
monkeypatch.setattr("app.auth.dependencies.get_user_from_token", mock_get_user_from_token)
monkeypatch.setattr(
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
)
# Test get_current_user dependency
def test_get_current_user_success():
@@ -37,54 +45,53 @@ def test_get_current_user_success():
assert user.username == "testuser"
assert user.roles == ["admin", "user"]
def test_get_current_user_invalid_token():
with pytest.raises(HTTPException) as exc:
get_current_user("invalid_token")
assert exc.value.status_code == 401
# Test require_roles decorator
@pytest.mark.asyncio
async def test_require_roles_success():
# Create test user with required role
user = CognitoUser(
username="testuser",
email="test@example.com",
roles=["admin"],
groups=[]
username="testuser", email="test@example.com", roles=["admin"], groups=[]
)
result = await mock_protected_endpoint(user=user)
assert result == {"message": "Success", "user": "testuser"}
@pytest.mark.asyncio
async def test_require_roles_missing_role():
# Create test user without required role
user = CognitoUser(
username="testuser",
email="test@example.com",
roles=["user"],
groups=[]
username="testuser", email="test@example.com", roles=["user"], groups=[]
)
with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403
assert exc.value.detail == "You do not have the required roles to access this endpoint."
assert (
exc.value.detail
== "You do not have the required roles to access this endpoint."
)
@pytest.mark.asyncio
async def test_require_roles_no_roles():
# Create test user with no roles
user = CognitoUser(
username="testuser",
email="test@example.com",
roles=[],
groups=[]
username="testuser", email="test@example.com", roles=[], groups=[]
)
with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_require_roles_multiple_roles():
# Test requiring multiple roles
@@ -97,7 +104,7 @@ async def test_require_roles_multiple_roles():
username="testuser",
email="test@example.com",
roles=["admin", "super_user", "user"],
groups=[]
groups=[],
)
result = await mock_multi_role_endpoint(user=user_with_roles)
assert result == {"message": "Success"}
@@ -107,56 +114,62 @@ async def test_require_roles_multiple_roles():
username="testuser",
email="test@example.com",
roles=["admin", "user"],
groups=[]
groups=[],
)
with pytest.raises(HTTPException) as exc:
await mock_multi_role_endpoint(user=user_missing_role)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_oauth2_scheme_configuration():
# Verify that we have a properly configured OAuth2PasswordBearer instance
assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
# Create a mock request with no Authorization header
mock_request = Request(scope={
'type': 'http',
'headers': [],
'method': 'GET',
'scheme': 'http',
'path': '/',
'query_string': b'',
'client': ('127.0.0.1', 8000)
})
mock_request = Request(
scope={
"type": "http",
"headers": [],
"method": "GET",
"scheme": "http",
"path": "/",
"query_string": b"",
"client": ("127.0.0.1", 8000),
}
)
# Test that the scheme raises 401 when no token is provided
with pytest.raises(HTTPException) as exc:
await oauth2_scheme(mock_request)
assert exc.value.status_code == 401
assert exc.value.detail == "Not authenticated"
def test_mock_auth_import(monkeypatch):
# Save original env var value
original_value = os.environ.get("MOCK_AUTH")
try:
# Set MOCK_AUTH to true
monkeypatch.setenv("MOCK_AUTH", "true")
# Reload the dependencies module to trigger the import condition
import app.auth.dependencies
importlib.reload(app.auth.dependencies)
# Verify that mock_get_user_from_token was imported
from app.auth.dependencies import get_user_from_token
assert get_user_from_token.__module__ == 'app.auth.mock_auth'
assert get_user_from_token.__module__ == "app.auth.mock_auth"
finally:
# Restore original env var
if original_value is None:
monkeypatch.delenv("MOCK_AUTH", raising=False)
else:
monkeypatch.setenv("MOCK_AUTH", original_value)
# Reload again to restore original state
importlib.reload(app.auth.dependencies)
importlib.reload(app.auth.dependencies)

View File

@@ -1,8 +1,10 @@
import pytest
from fastapi import HTTPException
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
from app.models.auth import CognitoUser
def test_mock_get_user_from_token_success():
"""Test successful token validation returns expected user"""
user = mock_get_user_from_token("testuser")
@@ -10,27 +12,30 @@ def test_mock_get_user_from_token_success():
assert user.username == "testuser"
assert user.roles == ["admin"]
def test_mock_get_user_from_token_invalid():
"""Test invalid token raises expected exception"""
with pytest.raises(HTTPException) as exc_info:
mock_get_user_from_token("invalid_token")
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
def test_mock_initiate_auth():
"""Test mock authentication returns expected token response"""
result = mock_initiate_auth("any_user", "any_password")
assert isinstance(result, dict)
assert result["AccessToken"] == "testuser"
assert result["ExpiresIn"] == 3600
assert result["TokenType"] == "Bearer"
def test_mock_initiate_auth_different_credentials():
"""Test mock authentication works with any credentials"""
result1 = mock_initiate_auth("user1", "pass1")
result2 = mock_initiate_auth("user2", "pass2")
# Should return same mock token regardless of credentials
assert result1 == result2
assert result1 == result2

View File

@@ -1,34 +1,35 @@
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException, status
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
@pytest.fixture
def mock_successful_auth():
return {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token"
"RefreshToken": "mock_refresh_token",
}
@pytest.fixture
def mock_successful_auth_no_refresh():
return {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token"
}
return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
def test_signin_success(mock_successful_auth):
"""Test successful signin with all tokens"""
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth):
with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
response = client.post(
"/auth/signin",
json={"username": "testuser", "password": "testpass"}
"/auth/signin", json={"username": "testuser", "password": "testpass"}
)
assert response.status_code == 200
data = response.json()
assert data["access_token"] == "mock_access_token"
@@ -36,14 +37,16 @@ def test_signin_success(mock_successful_auth):
assert data["refresh_token"] == "mock_refresh_token"
assert data["token_type"] == "Bearer"
def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
"""Test successful signin without refresh token"""
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth_no_refresh):
with patch(
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
):
response = client.post(
"/auth/signin",
json={"username": "testuser", "password": "testpass"}
"/auth/signin", json={"username": "testuser", "password": "testpass"}
)
assert response.status_code == 200
data = response.json()
assert data["access_token"] == "mock_access_token"
@@ -51,57 +54,48 @@ def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
assert data["refresh_token"] is None
assert data["token_type"] == "Bearer"
def test_signin_invalid_input():
"""Test signin with invalid input format"""
# Missing password
response = client.post(
"/auth/signin",
json={"username": "testuser"}
)
response = client.post("/auth/signin", json={"username": "testuser"})
assert response.status_code == 422
# Missing username
response = client.post(
"/auth/signin",
json={"password": "testpass"}
)
response = client.post("/auth/signin", json={"password": "testpass"})
assert response.status_code == 422
# Empty payload
response = client.post(
"/auth/signin",
json={}
)
response = client.post("/auth/signin", json={})
assert response.status_code == 422
def test_signin_auth_failure():
"""Test signin with authentication failure"""
with patch('app.routers.auth.initiate_auth') as mock_auth:
with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password"
detail="Invalid username or password",
)
response = client.post(
"/auth/signin",
json={"username": "testuser", "password": "wrongpass"}
"/auth/signin", json={"username": "testuser", "password": "wrongpass"}
)
assert response.status_code == 401
data = response.json()
assert data["detail"] == "Invalid username or password"
def test_signin_user_not_found():
"""Test signin with non-existent user"""
with patch('app.routers.auth.initiate_auth') as mock_auth:
with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
response = client.post(
"/auth/signin",
json={"username": "nonexistent", "password": "testpass"}
"/auth/signin", json={"username": "nonexistent", "password": "testpass"}
)
assert response.status_code == 404
data = response.json()
assert data["detail"] == "User not found"
assert data["detail"] == "User not found"

File diff suppressed because it is too large Load Diff

View File

@@ -1,19 +1,24 @@
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from app.main import app, lifespan
from unittest.mock import patch, MagicMock
@pytest.fixture
def client():
"""Test client for FastAPI app"""
return TestClient(app)
def test_root_endpoint(client):
"""Test root endpoint returns expected message"""
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "IPTV Updater API"}
def test_openapi_schema_generation(client):
"""Test OpenAPI schema is properly generated"""
# First call - generate schema
@@ -23,7 +28,7 @@ def test_openapi_schema_generation(client):
assert schema["openapi"] == "3.1.0"
assert "securitySchemes" in schema["components"]
assert "Bearer" in schema["components"]["securitySchemes"]
# Test empty components initialization
with patch("app.main.get_openapi", return_value={"info": {}}):
# Clear cached schema
@@ -35,26 +40,28 @@ def test_openapi_schema_generation(client):
assert "components" in schema
assert "schemas" in schema["components"]
def test_openapi_schema_caching(mocker):
"""Test OpenAPI schema caching behavior"""
# Clear any existing schema
app.openapi_schema = None
# Mock get_openapi to return test schema
mock_schema = {"test": "schema"}
mocker.patch("app.main.get_openapi", return_value=mock_schema)
# First call - should call get_openapi
schema = app.openapi()
assert schema == mock_schema
assert app.openapi_schema == mock_schema
# Second call - should return cached schema
with patch("app.main.get_openapi") as mock_get_openapi:
schema = app.openapi()
assert schema == mock_schema
mock_get_openapi.assert_not_called()
@pytest.mark.asyncio
async def test_lifespan_init_db(mocker):
"""Test lifespan manager initializes database"""
@@ -63,6 +70,7 @@ async def test_lifespan_init_db(mocker):
pass # Just enter/exit context
mock_init_db.assert_called_once()
def test_router_inclusion():
"""Test all routers are properly included"""
route_paths = {route.path for route in app.routes}
@@ -70,4 +78,4 @@ def test_router_inclusion():
assert any(path.startswith("/auth") for path in route_paths)
assert any(path.startswith("/channels") for path in route_paths)
assert any(path.startswith("/playlist") for path in route_paths)
assert any(path.startswith("/priorities") for path in route_paths)
assert any(path.startswith("/priorities") for path in route_paths)

View File

@@ -1,17 +1,30 @@
import os
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import (
TEXT,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
TypeDecorator,
UniqueConstraint,
create_engine,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import StaticPool
# Create a mock-specific Base class for testing
MockBase = declarative_base()
class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite."""
impl = TEXT
cache_ok = True
@@ -25,12 +38,14 @@ class SQLiteUUID(TypeDecorator):
return value
return uuid.UUID(value)
# Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase):
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class MockChannelDB(MockBase):
__tablename__ = "channels"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
@@ -39,32 +54,45 @@ class MockChannelDB(MockBase):
group_title = Column(String, nullable=False)
tvg_name = Column(String)
__table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
class MockChannelURL(MockBase):
__tablename__ = "channels_urls"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
channel_id = Column(
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Create test engine
engine_mock = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool
poolclass=StaticPool,
)
# Create test session
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
# Mock the actual database functions
def mock_get_db():
db = session_mock()
@@ -73,6 +101,7 @@ def mock_get_db():
finally:
db.close()
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
"""Fixture for mocking environment variables"""
@@ -82,14 +111,13 @@ def mock_env(monkeypatch):
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_NAME", "testdb")
monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture
def mock_ssm():
"""Fixture for mocking boto3 SSM client"""
with patch('boto3.client') as mock_client:
with patch("boto3.client") as mock_client:
mock_ssm = MagicMock()
mock_client.return_value = mock_ssm
mock_ssm.get_parameter.return_value = {
'Parameter': {'Value': 'mocked_value'}
}
yield mock_ssm
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
yield mock_ssm

View File

@@ -1,43 +1,45 @@
import os
import pytest
from unittest.mock import patch
from sqlalchemy.orm import Session
from app.utils.database import get_db_credentials, get_db
from tests.utils.db_mocks import (
session_mock,
mock_get_db,
mock_env,
mock_ssm
)
from app.utils.database import get_db, get_db_credentials
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
def test_get_db_credentials_env(mock_env):
"""Test getting DB credentials from environment variables"""
conn_str = get_db_credentials()
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
def test_get_db_credentials_ssm(mock_ssm):
"""Test getting DB credentials from SSM"""
os.environ.pop("MOCK_AUTH", None)
conn_str = get_db_credentials()
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
assert expected_conn in conn_str
mock_ssm.get_parameter.assert_called()
def test_get_db_credentials_ssm_exception(mock_ssm):
"""Test SSM credential fetching failure raises RuntimeError"""
os.environ.pop("MOCK_AUTH", None)
mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
with pytest.raises(RuntimeError) as excinfo:
get_db_credentials()
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
def test_session_creation():
"""Test database session creation"""
session = session_mock()
assert isinstance(session, Session)
session.close()
def test_get_db_generator():
"""Test get_db dependency generator"""
db_gen = get_db()
@@ -48,18 +50,20 @@ def test_get_db_generator():
except StopIteration:
pass
def test_init_db(mocker, mock_env):
"""Test database initialization creates tables"""
mock_create_all = mocker.patch('app.models.Base.metadata.create_all')
mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
# Mock get_db_credentials to return SQLite test connection
mocker.patch(
'app.utils.database.get_db_credentials',
return_value="sqlite:///:memory:"
"app.utils.database.get_db_credentials",
return_value="sqlite:///:memory:",
)
from app.utils.database import init_db, engine
from app.utils.database import engine, init_db
init_db()
# Verify create_all was called with the engine
mock_create_all.assert_called_once_with(bind=engine)
mock_create_all.assert_called_once_with(bind=engine)