Linted and formatted all files

This commit is contained in:
2025-05-28 21:52:39 -05:00
parent e46f13930d
commit 02913c7385
31 changed files with 1264 additions and 766 deletions

View File

@@ -1,9 +1,14 @@
import boto3
from fastapi import HTTPException, status
from app.models.auth import CognitoUser
from app.utils.auth import calculate_secret_hash
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE)
from app.utils.constants import (
AWS_REGION,
COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET,
USER_ROLE_ATTRIBUTE,
)
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
"""
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
"""
auth_params = {
"USERNAME": username,
"PASSWORD": password
}
auth_params = {"USERNAME": username, "PASSWORD": password}
# If a client secret is required, add SECRET_HASH
if COGNITO_CLIENT_SECRET:
auth_params["SECRET_HASH"] = calculate_secret_hash(
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET)
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
)
try:
response = cognito_client.initiate_auth(
AuthFlow="USER_PASSWORD_AUTH",
AuthParameters=auth_params,
ClientId=COGNITO_CLIENT_ID
ClientId=COGNITO_CLIENT_ID,
)
return response["AuthenticationResult"]
except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password"
detail="Invalid username or password",
)
except cognito_client.exceptions.UserNotFoundException:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An error occurred during authentication: {str(e)}"
detail=f"An error occurred during authentication: {str(e)}",
)
def get_user_from_token(access_token: str) -> CognitoUser:
"""
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles.
Verify the token by calling GetUser in Cognito and
retrieve user attributes including roles.
"""
try:
user_response = cognito_client.get_user(AccessToken=access_token)
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
for attr in attributes:
if attr["Name"] == USER_ROLE_ATTRIBUTE:
# Assume roles are stored as a comma-separated string
user_roles = [r.strip()
for r in attr["Value"].split(",") if r.strip()]
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
break
return CognitoUser(username=username, roles=user_roles)
except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token."
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
)
except cognito_client.exceptions.UserNotFoundException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or invalid token."
detail="User not found or invalid token.",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token verification failed: {str(e)}"
detail=f"Token verification failed: {str(e)}",
)

View File

@@ -1,6 +1,6 @@
import os
from functools import wraps
from typing import Callable
import os
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
else:
from app.auth.cognito import get_user_from_token
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="signin",
scheme_name="Bearer"
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
"""
@@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable:
if not needed_roles.issubset(user_roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have the required roles to access this endpoint.",
detail=(
"You do not have the required roles to access this endpoint."
),
)
return endpoint(*args, user=user, **kwargs)

View File

@@ -1,12 +1,9 @@
from fastapi import HTTPException, status
from app.models.auth import CognitoUser
MOCK_USERS = {
"testuser": {
"username": "testuser",
"roles": ["admin"]
}
}
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
def mock_get_user_from_token(token: str) -> CognitoUser:
"""
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
return CognitoUser(**MOCK_USERS["testuser"])
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid mock token - use 'testuser'"
detail="Invalid mock token - use 'testuser'",
)
def mock_initiate_auth(username: str, password: str) -> dict:
"""
Mock version of initiate_auth for local testing
Accepts any username/password and returns a mock token
"""
return {
"AccessToken": "testuser",
"ExpiresIn": 3600,
"TokenType": "Bearer"
}
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}

View File

@@ -1,39 +1,59 @@
import os
import re
import argparse
import gzip
import json
import os
import re
import xml.etree.ElementTree as ET
import requests
import argparse
from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
from utils.constants import (
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments():
parser = argparse.ArgumentParser(description='EPG Grabber')
parser.add_argument('--playlist',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
help='Path to playlist file')
parser.add_argument('--output',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'),
help='Path to output EPG XML file')
parser.add_argument('--epg-sources',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'),
help='Path to EPG sources JSON configuration file')
parser.add_argument('--save-as-gz',
action='store_true',
default=True,
help='Save an additional gzipped version of the EPG file')
parser = argparse.ArgumentParser(description="EPG Grabber")
parser.add_argument(
"--playlist",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
),
help="Path to playlist file",
)
parser.add_argument(
"--output",
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
help="Path to output EPG XML file",
)
parser.add_argument(
"--epg-sources",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
),
help="Path to EPG sources JSON configuration file",
)
parser.add_argument(
"--save-as-gz",
action="store_true",
default=True,
help="Save an additional gzipped version of the EPG file",
)
return parser.parse_args()
def load_epg_sources(config_path):
"""Load EPG sources from JSON configuration file"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, encoding="utf-8") as f:
config = json.load(f)
return config.get('epg_sources', [])
return config.get("epg_sources", [])
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading EPG sources: {e}")
return []
def get_tvg_ids(playlist_path):
"""
Extracts unique tvg-id values from an M3U playlist file.
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
# and ends with a double quote.
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
with open(playlist_path, 'r', encoding='utf-8') as file:
with open(playlist_path, encoding="utf-8") as file:
for line in file:
if line.startswith('#EXTINF'):
if line.startswith("#EXTINF"):
# Search for the tvg-id pattern in the line
match = tvg_id_pattern.search(line)
if match:
# Extract the captured group (the value inside the quotes)
tvg_id = match.group(1)
if tvg_id: # Ensure the extracted id is not empty
if tvg_id: # Ensure the extracted id is not empty
unique_tvg_ids.add(tvg_id)
return list(unique_tvg_ids)
def fetch_and_extract_xml(url):
response = requests.get(url)
if response.status_code != 200:
print(f"Failed to fetch {url}")
return None
if url.endswith('.gz'):
if url.endswith(".gz"):
try:
decompressed_data = gzip.decompress(response.content)
return ET.fromstring(decompressed_data)
@@ -84,44 +105,48 @@ def fetch_and_extract_xml(url):
print(f"Failed to parse XML from {url}: {e}")
return None
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
root = ET.Element('tv')
root = ET.Element("tv")
for url in urls:
epg_data = fetch_and_extract_xml(url)
if epg_data is None:
continue
for channel in epg_data.findall('channel'):
tvg_id = channel.get('id')
for channel in epg_data.findall("channel"):
tvg_id = channel.get("id")
if tvg_id in tvg_ids:
root.append(channel)
for programme in epg_data.findall('programme'):
tvg_id = programme.get('channel')
for programme in epg_data.findall("programme"):
tvg_id = programme.get("channel")
if tvg_id in tvg_ids:
root.append(programme)
tree = ET.ElementTree(root)
tree.write(output_file, encoding='utf-8', xml_declaration=True)
tree.write(output_file, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file}")
if save_as_gz:
output_file_gz = output_file + '.gz'
with gzip.open(output_file_gz, 'wb') as f:
tree.write(f, encoding='utf-8', xml_declaration=True)
output_file_gz = output_file + ".gz"
with gzip.open(output_file_gz, "wb") as f:
tree.write(f, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file_gz}")
def upload_epg(file_path):
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
try:
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
response = requests.post(
IPTV_SERVER_URL + '/admin/epg',
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
files={'file': (os.path.basename(file_path), f)}
IPTV_SERVER_URL + "/admin/epg",
auth=requests.auth.HTTPBasicAuth(
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
)
if response.status_code == 200:
print("EPG successfully uploaded to server")
else:
@@ -129,6 +154,7 @@ def upload_epg(file_path):
except Exception as e:
print(f"Upload error: {str(e)}")
if __name__ == "__main__":
args = parse_arguments()
playlist_file = args.playlist
@@ -144,4 +170,4 @@ if __name__ == "__main__":
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
if args.save_as_gz:
upload_epg(output_file + '.gz')
upload_epg(output_file + ".gz")

View File

@@ -1,26 +1,45 @@
import os
import argparse
import json
import logging
import requests
from pathlib import Path
import os
from datetime import datetime
import requests
from utils.check_streams import StreamValidator
from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
from utils.constants import (
EPG_URL,
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments():
parser = argparse.ArgumentParser(description='IPTV playlist generator')
parser.add_argument('--output',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
help='Path to output playlist file')
parser.add_argument('--channels',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'),
help='Path to channels definition JSON file')
parser.add_argument('--dead-channels-log',
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'),
help='Path to log file to store a list of dead channels')
parser = argparse.ArgumentParser(description="IPTV playlist generator")
parser.add_argument(
"--output",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
),
help="Path to output playlist file",
)
parser.add_argument(
"--channels",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "channels.json"
),
help="Path to channels definition JSON file",
)
parser.add_argument(
"--dead-channels-log",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
),
help="Path to log file to store a list of dead channels",
)
return parser.parse_args()
def find_working_stream(validator, urls):
"""Test all URLs and return the first working one"""
for url in urls:
@@ -29,48 +48,55 @@ def find_working_stream(validator, urls):
return url
return None
def create_playlist(channels_file, output_file):
# Read channels from JSON file
with open(channels_file, 'r', encoding='utf-8') as f:
with open(channels_file, encoding="utf-8") as f:
channels = json.load(f)
# Initialize validator
validator = StreamValidator(timeout=45)
# Prepare M3U8 header
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
for channel in channels:
if 'urls' in channel: # Check if channel has URLs
if "urls" in channel: # Check if channel has URLs
# Find first working stream
working_url = find_working_stream(validator, channel['urls'])
working_url = find_working_stream(validator, channel["urls"])
if working_url:
# Add channel to playlist
m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" '
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
m3u8_content += f'{channel.get("name", "")}\n'
m3u8_content += f'{working_url}\n'
m3u8_content += f"{channel.get('name', '')}\n"
m3u8_content += f"{working_url}\n"
else:
# Log dead channel
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found')
logging.info(
f"Dead channel: {channel.get('name', 'Unknown')} - "
"No working streams found"
)
# Write playlist file
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write(m3u8_content)
def upload_playlist(file_path):
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
try:
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
response = requests.post(
IPTV_SERVER_URL + '/admin/playlist',
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
files={'file': (os.path.basename(file_path), f)}
IPTV_SERVER_URL + "/admin/playlist",
auth=requests.auth.HTTPBasicAuth(
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
)
if response.status_code == 200:
print("Playlist successfully uploaded to server")
else:
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
except Exception as e:
print(f"Upload error: {str(e)}")
def main():
args = parse_arguments()
channels_file = args.channels
@@ -85,24 +112,25 @@ def main():
dead_channels_log_file = args.dead_channels_log
# Clear previous log file
with open(dead_channels_log_file, 'w') as f:
f.write(f'Log created on {datetime.now()}\n')
with open(dead_channels_log_file, "w") as f:
f.write(f"Log created on {datetime.now()}\n")
# Configure logging
logging.basicConfig(
filename=dead_channels_log_file,
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Create playlist
create_playlist(channels_file, output_file)
#upload playlist to server
# upload playlist to server
upload_playlist(output_file)
print("Playlist creation completed!")
if __name__ == "__main__":
main()
main()

View File

@@ -1,17 +1,18 @@
from fastapi.concurrency import asynccontextmanager
from app.routers import channels, auth, playlist, priorities
from fastapi import FastAPI
from fastapi.concurrency import asynccontextmanager
from fastapi.openapi.utils import get_openapi
from app.routers import auth, channels, playlist, priorities
from app.utils.database import init_db
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize database tables on startup
init_db()
yield
app = FastAPI(
lifespan=lifespan,
title="IPTV Updater API",
@@ -19,6 +20,7 @@ app = FastAPI(
version="1.0.0",
)
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
@@ -40,11 +42,7 @@ def custom_openapi():
# Add security scheme component
openapi_schema["components"]["securitySchemes"] = {
"Bearer": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT"
}
"Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
}
# Add global security requirement
@@ -56,14 +54,17 @@ def custom_openapi():
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
@app.get("/")
async def root():
return {"message": "IPTV Updater API"}
# Include routers
app.include_router(auth.router)
app.include_router(channels.router)
app.include_router(playlist.router)
app.include_router(priorities.router)
app.include_router(priorities.router)

View File

@@ -1,4 +1,19 @@
from .db import Base, ChannelDB, ChannelURL
from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse
from .schemas import (
ChannelCreate,
ChannelResponse,
ChannelUpdate,
ChannelURLCreate,
ChannelURLResponse,
)
__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"]
__all__ = [
"Base",
"ChannelDB",
"ChannelCreate",
"ChannelUpdate",
"ChannelResponse",
"ChannelURL",
"ChannelURLCreate",
"ChannelURLResponse",
]

View File

@@ -1,20 +1,26 @@
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field
class SigninRequest(BaseModel):
"""Request model for the signin endpoint."""
username: str = Field(..., description="The user's username")
password: str = Field(..., description="The user's password")
class TokenResponse(BaseModel):
"""Response model for successful authentication."""
access_token: str = Field(..., description="Access JWT token from Cognito")
id_token: str = Field(..., description="ID JWT token from Cognito")
refresh_token: Optional[str] = Field(
None, description="Refresh token from Cognito")
refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
token_type: str = Field(..., description="Type of the token returned")
class CognitoUser(BaseModel):
"""Model representing the user returned from token verification."""
username: str
roles: List[str]
roles: list[str]

View File

@@ -1,21 +1,33 @@
from datetime import datetime, timezone
import uuid
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from datetime import datetime, timezone
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class Priority(Base):
"""SQLAlchemy model for channel URL priorities"""
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class ChannelDB(Base):
"""SQLAlchemy model for IPTV channels"""
__tablename__ = "channels"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -25,27 +37,43 @@ class ChannelDB(Base):
tvg_name = Column(String)
__table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationship with ChannelURL
urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan")
urls = relationship(
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
class ChannelURL(Base):
"""SQLAlchemy model for channel URLs"""
__tablename__ = "channels_urls"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
channel_id = Column(
UUID(as_uuid=True),
ForeignKey("channels.id", ondelete="CASCADE"),
nullable=False,
)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
channel = relationship("ChannelDB", back_populates="urls")
priority = relationship("Priority")
priority = relationship("Priority")

View File

@@ -1,30 +1,43 @@
from datetime import datetime
from typing import List, Optional
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class PriorityBase(BaseModel):
"""Base Pydantic model for priorities"""
id: int
description: str
model_config = ConfigDict(from_attributes=True)
class PriorityCreate(PriorityBase):
"""Pydantic model for creating priorities"""
pass
class PriorityResponse(PriorityBase):
"""Pydantic model for priority responses"""
pass
class ChannelURLCreate(BaseModel):
"""Pydantic model for creating channel URLs"""
url: str
priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range
priority_id: int = Field(
default=100, ge=100, le=300
) # Default to High, validate range
class ChannelURLBase(ChannelURLCreate):
"""Base Pydantic model for channel URL responses"""
id: UUID
in_use: bool
created_at: datetime
@@ -33,43 +46,53 @@ class ChannelURLBase(ChannelURLCreate):
model_config = ConfigDict(from_attributes=True)
class ChannelURLResponse(ChannelURLBase):
"""Pydantic model for channel URL responses"""
pass
class ChannelCreate(BaseModel):
"""Pydantic model for creating channels"""
urls: List[ChannelURLCreate] # List of URL objects with priority
urls: list[ChannelURLCreate] # List of URL objects with priority
name: str
group_title: str
tvg_id: str
tvg_logo: str
tvg_name: str
class ChannelURLUpdate(BaseModel):
"""Pydantic model for updating channel URLs"""
url: Optional[str] = None
in_use: Optional[bool] = None
priority_id: Optional[int] = Field(default=None, ge=100, le=300)
class ChannelUpdate(BaseModel):
"""Pydantic model for updating channels (all fields optional)"""
name: Optional[str] = Field(None, min_length=1)
group_title: Optional[str] = Field(None, min_length=1)
tvg_id: Optional[str] = Field(None, min_length=1)
tvg_logo: Optional[str] = None
tvg_name: Optional[str] = Field(None, min_length=1)
class ChannelResponse(BaseModel):
"""Pydantic model for channel responses"""
id: UUID
name: str
group_title: str
tvg_id: str
tvg_logo: str
tvg_name: str
urls: List[ChannelURLResponse] # List of URL objects without channel_id
urls: list[ChannelURLResponse] # List of URL objects without channel_id
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View File

@@ -1,16 +1,16 @@
from fastapi import APIRouter
from app.auth.cognito import initiate_auth
from app.models.auth import SigninRequest, TokenResponse
router = APIRouter(
prefix="/auth",
tags=["authentication"]
)
router = APIRouter(prefix="/auth", tags=["authentication"])
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest):
"""
Sign-in endpoint to authenticate the user with AWS Cognito using username and password.
Sign-in endpoint to authenticate the user with AWS Cognito
using username and password.
On success, returns JWT tokens (access_token, id_token, refresh_token).
"""
auth_result = initiate_auth(credentials.username, credentials.password)
@@ -19,4 +19,4 @@ def signin(credentials: SigninRequest):
id_token=auth_result["IdToken"],
refresh_token=auth_result.get("RefreshToken"),
token_type="Bearer",
)
)

View File

@@ -1,52 +1,54 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from uuid import UUID
from sqlalchemy import and_
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import (
ChannelDB,
ChannelURL,
ChannelCreate,
ChannelUpdate,
ChannelDB,
ChannelResponse,
ChannelUpdate,
ChannelURL,
ChannelURLCreate,
ChannelURLResponse,
)
from app.models.auth import CognitoUser
from app.models.schemas import ChannelURLUpdate
from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter(
prefix="/channels",
tags=["channels"]
)
router = APIRouter(prefix="/channels", tags=["channels"])
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_channel(
channel: ChannelCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Create a new channel"""
# Check for duplicate channel (same group_title + name)
existing_channel = db.query(ChannelDB).filter(
and_(
ChannelDB.group_title == channel.group_title,
ChannelDB.name == channel.name
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_title == channel.group_title,
ChannelDB.name == channel.name,
)
)
).first()
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists"
detail="Channel with same group_title and name already exists",
)
# Create channel without URLs first
channel_data = channel.model_dump(exclude={'urls'})
channel_data = channel.model_dump(exclude={"urls"})
urls = channel.urls
db_channel = ChannelDB(**channel_data)
db.add(db_channel)
@@ -59,130 +61,142 @@ def create_channel(
channel_id=db_channel.id,
url=url.url,
priority_id=url.priority_id,
in_use=False
in_use=False,
)
db.add(db_url)
db.commit()
db.refresh(db_channel)
return db_channel
@router.get("/{channel_id}", response_model=ChannelResponse)
def get_channel(
channel_id: UUID,
db: Session = Depends(get_db)
):
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
"""Get a channel by id"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
return channel
@router.put("/{channel_id}", response_model=ChannelResponse)
@require_roles("admin")
def update_channel(
channel_id: UUID,
channel: ChannelUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel"""
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not db_channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
# Only check for duplicates if name or group_title are being updated
if channel.name is not None or channel.group_title is not None:
name = channel.name if channel.name is not None else db_channel.name
group_title = channel.group_title if channel.group_title is not None else db_channel.group_title
existing_channel = db.query(ChannelDB).filter(
and_(
ChannelDB.group_title == group_title,
ChannelDB.name == name,
ChannelDB.id != channel_id
group_title = (
channel.group_title
if channel.group_title is not None
else db_channel.group_title
)
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_title == group_title,
ChannelDB.name == name,
ChannelDB.id != channel_id,
)
)
).first()
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists"
detail="Channel with same group_title and name already exists",
)
# Update only provided fields
update_data = channel.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_channel, key, value)
db.commit()
db.refresh(db_channel)
return db_channel
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel(
channel_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Delete a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
db.delete(channel)
db.commit()
return None
@router.get("/", response_model=List[ChannelResponse])
@router.get("/", response_model=list[ChannelResponse])
@require_roles("admin")
def list_channels(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""List all channels with pagination"""
return db.query(ChannelDB).offset(skip).limit(limit).all()
# URL Management Endpoints
@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/{channel_id}/urls",
response_model=ChannelURLResponse,
status_code=status.HTTP_201_CREATED,
)
@require_roles("admin")
def add_channel_url(
channel_id: UUID,
url: ChannelURLCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Add a new URL to a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
db_url = ChannelURL(
channel_id=channel_id,
url=url.url,
priority_id=url.priority_id,
in_use=False # Default to not in use
in_use=False, # Default to not in use
)
db.add(db_url)
db.commit()
db.refresh(db_url)
return db_url
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
@require_roles("admin")
def update_channel_url(
@@ -190,72 +204,69 @@ def update_channel_url(
url_id: UUID,
url_update: ChannelURLUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel URL (url, in_use, or priority_id)"""
db_url = db.query(ChannelURL).filter(
and_(
ChannelURL.id == url_id,
ChannelURL.channel_id == channel_id
)
).first()
db_url = (
db.query(ChannelURL)
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
.first()
)
if not db_url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="URL not found"
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
)
if url_update.url is not None:
db_url.url = url_update.url
if url_update.in_use is not None:
db_url.in_use = url_update.in_use
if url_update.priority_id is not None:
db_url.priority_id = url_update.priority_id
db.commit()
db.refresh(db_url)
return db_url
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel_url(
channel_id: UUID,
url_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Delete a URL from a channel"""
url = db.query(ChannelURL).filter(
and_(
ChannelURL.id == url_id,
ChannelURL.channel_id == channel_id
)
).first()
url = (
db.query(ChannelURL)
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
.first()
)
if not url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="URL not found"
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
)
db.delete(url)
db.commit()
return None
@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse])
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
@require_roles("admin")
def list_channel_urls(
channel_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""List all URLs for a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Channel not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()

View File

@@ -1,17 +1,15 @@
from fastapi import APIRouter, Depends
from app.auth.dependencies import get_current_user
from app.models.auth import CognitoUser
router = APIRouter(
prefix="/playlist",
tags=["playlist"]
)
router = APIRouter(prefix="/playlist", tags=["playlist"])
@router.get("/protected",
summary="Protected endpoint for authenticated users")
@router.get("/protected", summary="Protected endpoint for authenticated users")
async def protected_route(user: CognitoUser = Depends(get_current_user)):
"""
Protected endpoint that requires authentication for all users.
If the user is authenticated, returns success message.
"""
return {"message": f"Hello {user.username}, you have access to support resources!"}
return {"message": f"Hello {user.username}, you have access to support resources!"}

View File

@@ -1,25 +1,22 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from sqlalchemy import select, delete
from typing import List
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
from app.models.db import Priority
from app.models.schemas import PriorityCreate, PriorityResponse
from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter(
prefix="/priorities",
tags=["priorities"]
)
router = APIRouter(prefix="/priorities", tags=["priorities"])
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_priority(
priority: PriorityCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Create a new priority"""
# Check if priority with this ID already exists
@@ -27,71 +24,69 @@ def create_priority(
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Priority with ID {priority.id} already exists"
detail=f"Priority with ID {priority.id} already exists",
)
db_priority = Priority(**priority.model_dump())
db.add(db_priority)
db.commit()
db.refresh(db_priority)
return db_priority
@router.get("/", response_model=List[PriorityResponse])
@router.get("/", response_model=list[PriorityResponse])
@require_roles("admin")
def list_priorities(
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
):
"""List all priorities"""
return db.query(Priority).all()
@router.get("/{priority_id}", response_model=PriorityResponse)
@require_roles("admin")
def get_priority(
priority_id: int,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Get a priority by id"""
priority = db.get(Priority, priority_id)
if not priority:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Priority not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
)
return priority
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_priority(
priority_id: int,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user),
):
"""Delete a priority (if not in use)"""
from app.models.db import ChannelURL
# Check if priority exists
priority = db.get(Priority, priority_id)
if not priority:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Priority not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
)
# Check if priority is in use
in_use = db.scalar(
select(ChannelURL)
.where(ChannelURL.priority_id == priority_id)
.limit(1)
select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
)
if in_use:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete priority that is in use by channel URLs"
detail="Cannot delete priority that is in use by channel URLs",
)
db.execute(delete(Priority).where(Priority.id == priority_id))
db.commit()
return None
return None

View File

@@ -2,11 +2,13 @@ import base64
import hashlib
import hmac
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
"""
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
"""
msg = username + client_id
dig = hmac.new(client_secret.encode('utf-8'),
msg.encode('utf-8'), hashlib.sha256).digest()
return base64.b64encode(dig).decode()
dig = hmac.new(
client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
).digest()
return base64.b64encode(dig).decode()

View File

@@ -1,41 +1,50 @@
import os
import argparse
import requests
import logging
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError
import os
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
class StreamValidator:
def __init__(self, timeout=10, user_agent=None):
self.timeout = timeout
self.session = requests.Session()
self.session.headers.update({
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
})
self.session.headers.update(
{
"User-Agent": user_agent
or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
)
}
)
def validate_stream(self, url):
"""Validate a media stream URL with multiple fallback checks"""
try:
headers = {'Range': 'bytes=0-1024'}
headers = {"Range": "bytes=0-1024"}
with self.session.get(
url,
headers=headers,
timeout=self.timeout,
stream=True,
allow_redirects=True
allow_redirects=True,
) as response:
if response.status_code not in [200, 206]:
return False, f"Invalid status code: {response.status_code}"
content_type = response.headers.get('Content-Type', '')
content_type = response.headers.get("Content-Type", "")
if not self._is_valid_content_type(content_type):
return False, f"Invalid content type: {content_type}"
try:
next(response.iter_content(chunk_size=1024))
return True, "Stream is valid"
except (ConnectionError, Timeout):
return False, "Connection failed during content read"
except HTTPError as e:
return False, f"HTTP Error: {str(e)}"
except ConnectionError as e:
@@ -49,10 +58,13 @@ class StreamValidator:
def _is_valid_content_type(self, content_type):
valid_types = [
'video/mp2t', 'application/vnd.apple.mpegurl',
'application/dash+xml', 'video/mp4',
'video/webm', 'application/octet-stream',
'application/x-mpegURL'
"video/mp2t",
"application/vnd.apple.mpegurl",
"application/dash+xml",
"video/mp4",
"video/webm",
"application/octet-stream",
"application/x-mpegURL",
]
return any(ct in content_type for ct in valid_types)
@@ -60,45 +72,43 @@ class StreamValidator:
"""Extract stream URLs from M3U playlist file"""
urls = []
try:
with open(file_path, 'r') as f:
with open(file_path) as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
if line and not line.startswith("#"):
urls.append(line)
except Exception as e:
logging.error(f"Error reading playlist file: {str(e)}")
raise
return urls
def main():
parser = argparse.ArgumentParser(
description='Validate streaming URLs from command line arguments or playlist files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
description=(
"Validate streaming URLs from command line arguments or playlist files"
),
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'sources',
nargs='+',
help='List of URLs or file paths containing stream URLs'
"sources", nargs="+", help="List of URLs or file paths containing stream URLs"
)
parser.add_argument(
'--timeout',
type=int,
default=20,
help='Timeout in seconds for stream checks'
"--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
)
parser.add_argument(
'--output',
default='deadstreams.txt',
help='Output file name for inactive streams'
"--output",
default="deadstreams.txt",
help="Output file name for inactive streams",
)
args = parser.parse_args()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()]
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
)
validator = StreamValidator(timeout=args.timeout)
@@ -127,9 +137,10 @@ def main():
# Save dead streams to file
if dead_streams:
with open(args.output, 'w') as f:
f.write('\n'.join(dead_streams))
with open(args.output, "w") as f:
f.write("\n".join(dead_streams))
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
if __name__ == "__main__":
main()
main()

View File

@@ -21,4 +21,4 @@ IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
# URL for the EPG XML file to place in the playlist's header
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")

View File

@@ -1,11 +1,13 @@
import os
import boto3
from app.models import Base
from .constants import AWS_REGION
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from functools import lru_cache
from app.models import Base
from .constants import AWS_REGION
def get_db_credentials():
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
@@ -14,29 +16,40 @@ def get_db_credentials():
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
)
ssm = boto3.client('ssm', region_name=AWS_REGION)
ssm = boto3.client("ssm", region_name=AWS_REGION)
try:
host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value']
user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value']
password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value']
dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value']
host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
"Parameter"
]["Value"]
user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
"Parameter"
]["Value"]
password = ssm.get_parameter(
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
)["Parameter"]["Value"]
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
"Parameter"
]["Value"]
return f"postgresql://{user}:{password}@{host}/{dbname}"
except Exception as e:
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
# Initialize engine and session maker
engine = create_engine(get_db_credentials())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
"""Initialize database by creating all tables"""
Base.metadata.create_all(bind=engine)
def get_db():
"""Dependency for getting database session"""
db = SessionLocal()
try:
yield db
finally:
db.close()
db.close()