Linted and formatted all files
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
import boto3
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.auth import calculate_secret_hash
|
||||
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE)
|
||||
from app.utils.constants import (
|
||||
AWS_REGION,
|
||||
COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET,
|
||||
USER_ROLE_ATTRIBUTE,
|
||||
)
|
||||
|
||||
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
|
||||
|
||||
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
|
||||
"""
|
||||
auth_params = {
|
||||
"USERNAME": username,
|
||||
"PASSWORD": password
|
||||
}
|
||||
auth_params = {"USERNAME": username, "PASSWORD": password}
|
||||
|
||||
# If a client secret is required, add SECRET_HASH
|
||||
if COGNITO_CLIENT_SECRET:
|
||||
auth_params["SECRET_HASH"] = calculate_secret_hash(
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET)
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
|
||||
)
|
||||
|
||||
try:
|
||||
response = cognito_client.initiate_auth(
|
||||
AuthFlow="USER_PASSWORD_AUTH",
|
||||
AuthParameters=auth_params,
|
||||
ClientId=COGNITO_CLIENT_ID
|
||||
ClientId=COGNITO_CLIENT_ID,
|
||||
)
|
||||
return response["AuthenticationResult"]
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password"
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"An error occurred during authentication: {str(e)}"
|
||||
detail=f"An error occurred during authentication: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
"""
|
||||
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles.
|
||||
Verify the token by calling GetUser in Cognito and
|
||||
retrieve user attributes including roles.
|
||||
"""
|
||||
try:
|
||||
user_response = cognito_client.get_user(AccessToken=access_token)
|
||||
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
for attr in attributes:
|
||||
if attr["Name"] == USER_ROLE_ATTRIBUTE:
|
||||
# Assume roles are stored as a comma-separated string
|
||||
user_roles = [r.strip()
|
||||
for r in attr["Value"].split(",") if r.strip()]
|
||||
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
|
||||
break
|
||||
|
||||
return CognitoUser(username=username, roles=user_roles)
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token."
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or invalid token."
|
||||
detail="User not found or invalid token.",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token verification failed: {str(e)}"
|
||||
detail=f"Token verification failed: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
import os
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
else:
|
||||
from app.auth.cognito import get_user_from_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="signin",
|
||||
scheme_name="Bearer"
|
||||
)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
|
||||
"""
|
||||
@@ -40,7 +38,9 @@ def require_roles(*required_roles: str) -> Callable:
|
||||
if not needed_roles.issubset(user_roles):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have the required roles to access this endpoint.",
|
||||
detail=(
|
||||
"You do not have the required roles to access this endpoint."
|
||||
),
|
||||
)
|
||||
return endpoint(*args, user=user, **kwargs)
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
MOCK_USERS = {
|
||||
"testuser": {
|
||||
"username": "testuser",
|
||||
"roles": ["admin"]
|
||||
}
|
||||
}
|
||||
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
|
||||
|
||||
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
"""
|
||||
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
return CognitoUser(**MOCK_USERS["testuser"])
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid mock token - use 'testuser'"
|
||||
detail="Invalid mock token - use 'testuser'",
|
||||
)
|
||||
|
||||
|
||||
def mock_initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Mock version of initiate_auth for local testing
|
||||
Accepts any username/password and returns a mock token
|
||||
"""
|
||||
return {
|
||||
"AccessToken": "testuser",
|
||||
"ExpiresIn": 3600,
|
||||
"TokenType": "Bearer"
|
||||
}
|
||||
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
|
||||
|
||||
@@ -1,39 +1,59 @@
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import requests
|
||||
import argparse
|
||||
from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
||||
from utils.constants import (
|
||||
IPTV_SERVER_ADMIN_PASSWORD,
|
||||
IPTV_SERVER_ADMIN_USER,
|
||||
IPTV_SERVER_URL,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='EPG Grabber')
|
||||
parser.add_argument('--playlist',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
||||
help='Path to playlist file')
|
||||
parser.add_argument('--output',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'),
|
||||
help='Path to output EPG XML file')
|
||||
parser.add_argument('--epg-sources',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'),
|
||||
help='Path to EPG sources JSON configuration file')
|
||||
parser.add_argument('--save-as-gz',
|
||||
action='store_true',
|
||||
default=True,
|
||||
help='Save an additional gzipped version of the EPG file')
|
||||
parser = argparse.ArgumentParser(description="EPG Grabber")
|
||||
parser.add_argument(
|
||||
"--playlist",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||
),
|
||||
help="Path to playlist file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
|
||||
help="Path to output EPG XML file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epg-sources",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
|
||||
),
|
||||
help="Path to EPG sources JSON configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-as-gz",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Save an additional gzipped version of the EPG file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_epg_sources(config_path):
|
||||
"""Load EPG sources from JSON configuration file"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
return config.get('epg_sources', [])
|
||||
return config.get("epg_sources", [])
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Error loading EPG sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
def get_tvg_ids(playlist_path):
|
||||
"""
|
||||
Extracts unique tvg-id values from an M3U playlist file.
|
||||
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
|
||||
# and ends with a double quote.
|
||||
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
|
||||
|
||||
with open(playlist_path, 'r', encoding='utf-8') as file:
|
||||
with open(playlist_path, encoding="utf-8") as file:
|
||||
for line in file:
|
||||
if line.startswith('#EXTINF'):
|
||||
if line.startswith("#EXTINF"):
|
||||
# Search for the tvg-id pattern in the line
|
||||
match = tvg_id_pattern.search(line)
|
||||
if match:
|
||||
# Extract the captured group (the value inside the quotes)
|
||||
tvg_id = match.group(1)
|
||||
if tvg_id: # Ensure the extracted id is not empty
|
||||
if tvg_id: # Ensure the extracted id is not empty
|
||||
unique_tvg_ids.add(tvg_id)
|
||||
|
||||
return list(unique_tvg_ids)
|
||||
|
||||
|
||||
def fetch_and_extract_xml(url):
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
print(f"Failed to fetch {url}")
|
||||
return None
|
||||
|
||||
if url.endswith('.gz'):
|
||||
if url.endswith(".gz"):
|
||||
try:
|
||||
decompressed_data = gzip.decompress(response.content)
|
||||
return ET.fromstring(decompressed_data)
|
||||
@@ -84,44 +105,48 @@ def fetch_and_extract_xml(url):
|
||||
print(f"Failed to parse XML from {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
|
||||
root = ET.Element('tv')
|
||||
root = ET.Element("tv")
|
||||
|
||||
for url in urls:
|
||||
epg_data = fetch_and_extract_xml(url)
|
||||
if epg_data is None:
|
||||
continue
|
||||
|
||||
for channel in epg_data.findall('channel'):
|
||||
tvg_id = channel.get('id')
|
||||
for channel in epg_data.findall("channel"):
|
||||
tvg_id = channel.get("id")
|
||||
if tvg_id in tvg_ids:
|
||||
root.append(channel)
|
||||
|
||||
for programme in epg_data.findall('programme'):
|
||||
tvg_id = programme.get('channel')
|
||||
for programme in epg_data.findall("programme"):
|
||||
tvg_id = programme.get("channel")
|
||||
if tvg_id in tvg_ids:
|
||||
root.append(programme)
|
||||
|
||||
tree = ET.ElementTree(root)
|
||||
tree.write(output_file, encoding='utf-8', xml_declaration=True)
|
||||
tree.write(output_file, encoding="utf-8", xml_declaration=True)
|
||||
print(f"New EPG saved to {output_file}")
|
||||
|
||||
if save_as_gz:
|
||||
output_file_gz = output_file + '.gz'
|
||||
with gzip.open(output_file_gz, 'wb') as f:
|
||||
tree.write(f, encoding='utf-8', xml_declaration=True)
|
||||
output_file_gz = output_file + ".gz"
|
||||
with gzip.open(output_file_gz, "wb") as f:
|
||||
tree.write(f, encoding="utf-8", xml_declaration=True)
|
||||
print(f"New EPG saved to {output_file_gz}")
|
||||
|
||||
|
||||
def upload_epg(file_path):
|
||||
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
response = requests.post(
|
||||
IPTV_SERVER_URL + '/admin/epg',
|
||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
||||
files={'file': (os.path.basename(file_path), f)}
|
||||
IPTV_SERVER_URL + "/admin/epg",
|
||||
auth=requests.auth.HTTPBasicAuth(
|
||||
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||
),
|
||||
files={"file": (os.path.basename(file_path), f)},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
print("EPG successfully uploaded to server")
|
||||
else:
|
||||
@@ -129,6 +154,7 @@ def upload_epg(file_path):
|
||||
except Exception as e:
|
||||
print(f"Upload error: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
playlist_file = args.playlist
|
||||
@@ -144,4 +170,4 @@ if __name__ == "__main__":
|
||||
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
|
||||
|
||||
if args.save_as_gz:
|
||||
upload_epg(output_file + '.gz')
|
||||
upload_epg(output_file + ".gz")
|
||||
|
||||
@@ -1,26 +1,45 @@
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
from pathlib import Path
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from utils.check_streams import StreamValidator
|
||||
from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
||||
from utils.constants import (
|
||||
EPG_URL,
|
||||
IPTV_SERVER_ADMIN_PASSWORD,
|
||||
IPTV_SERVER_ADMIN_USER,
|
||||
IPTV_SERVER_URL,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='IPTV playlist generator')
|
||||
parser.add_argument('--output',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
||||
help='Path to output playlist file')
|
||||
parser.add_argument('--channels',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'),
|
||||
help='Path to channels definition JSON file')
|
||||
parser.add_argument('--dead-channels-log',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'),
|
||||
help='Path to log file to store a list of dead channels')
|
||||
parser = argparse.ArgumentParser(description="IPTV playlist generator")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||
),
|
||||
help="Path to output playlist file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "channels.json"
|
||||
),
|
||||
help="Path to channels definition JSON file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dead-channels-log",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
|
||||
),
|
||||
help="Path to log file to store a list of dead channels",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def find_working_stream(validator, urls):
|
||||
"""Test all URLs and return the first working one"""
|
||||
for url in urls:
|
||||
@@ -29,48 +48,55 @@ def find_working_stream(validator, urls):
|
||||
return url
|
||||
return None
|
||||
|
||||
|
||||
def create_playlist(channels_file, output_file):
|
||||
# Read channels from JSON file
|
||||
with open(channels_file, 'r', encoding='utf-8') as f:
|
||||
with open(channels_file, encoding="utf-8") as f:
|
||||
channels = json.load(f)
|
||||
|
||||
# Initialize validator
|
||||
validator = StreamValidator(timeout=45)
|
||||
|
||||
|
||||
# Prepare M3U8 header
|
||||
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
|
||||
|
||||
|
||||
for channel in channels:
|
||||
if 'urls' in channel: # Check if channel has URLs
|
||||
if "urls" in channel: # Check if channel has URLs
|
||||
# Find first working stream
|
||||
working_url = find_working_stream(validator, channel['urls'])
|
||||
|
||||
working_url = find_working_stream(validator, channel["urls"])
|
||||
|
||||
if working_url:
|
||||
# Add channel to playlist
|
||||
m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" '
|
||||
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
|
||||
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
|
||||
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
|
||||
m3u8_content += f'{channel.get("name", "")}\n'
|
||||
m3u8_content += f'{working_url}\n'
|
||||
m3u8_content += f"{channel.get('name', '')}\n"
|
||||
m3u8_content += f"{working_url}\n"
|
||||
else:
|
||||
# Log dead channel
|
||||
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found')
|
||||
logging.info(
|
||||
f"Dead channel: {channel.get('name', 'Unknown')} - "
|
||||
"No working streams found"
|
||||
)
|
||||
|
||||
# Write playlist file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(m3u8_content)
|
||||
|
||||
|
||||
def upload_playlist(file_path):
|
||||
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
response = requests.post(
|
||||
IPTV_SERVER_URL + '/admin/playlist',
|
||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
||||
files={'file': (os.path.basename(file_path), f)}
|
||||
IPTV_SERVER_URL + "/admin/playlist",
|
||||
auth=requests.auth.HTTPBasicAuth(
|
||||
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||
),
|
||||
files={"file": (os.path.basename(file_path), f)},
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
print("Playlist successfully uploaded to server")
|
||||
else:
|
||||
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
|
||||
except Exception as e:
|
||||
print(f"Upload error: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
channels_file = args.channels
|
||||
@@ -85,24 +112,25 @@ def main():
|
||||
dead_channels_log_file = args.dead_channels_log
|
||||
|
||||
# Clear previous log file
|
||||
with open(dead_channels_log_file, 'w') as f:
|
||||
f.write(f'Log created on {datetime.now()}\n')
|
||||
with open(dead_channels_log_file, "w") as f:
|
||||
f.write(f"Log created on {datetime.now()}\n")
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
filename=dead_channels_log_file,
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
format="%(asctime)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
# Create playlist
|
||||
create_playlist(channels_file, output_file)
|
||||
|
||||
#upload playlist to server
|
||||
# upload playlist to server
|
||||
upload_playlist(output_file)
|
||||
|
||||
|
||||
print("Playlist creation completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
19
app/main.py
19
app/main.py
@@ -1,17 +1,18 @@
|
||||
|
||||
from fastapi.concurrency import asynccontextmanager
|
||||
from app.routers import channels, auth, playlist, priorities
|
||||
from fastapi import FastAPI
|
||||
from fastapi.concurrency import asynccontextmanager
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from app.routers import auth, channels, playlist, priorities
|
||||
from app.utils.database import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Initialize database tables on startup
|
||||
init_db()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="IPTV Updater API",
|
||||
@@ -19,6 +20,7 @@ app = FastAPI(
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
@@ -40,11 +42,7 @@ def custom_openapi():
|
||||
|
||||
# Add security scheme component
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
"Bearer": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT"
|
||||
}
|
||||
"Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
|
||||
}
|
||||
|
||||
# Add global security requirement
|
||||
@@ -56,14 +54,17 @@ def custom_openapi():
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "IPTV Updater API"}
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(channels.router)
|
||||
app.include_router(playlist.router)
|
||||
app.include_router(priorities.router)
|
||||
app.include_router(priorities.router)
|
||||
|
||||
@@ -1,4 +1,19 @@
|
||||
from .db import Base, ChannelDB, ChannelURL
|
||||
from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse
|
||||
from .schemas import (
|
||||
ChannelCreate,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
)
|
||||
|
||||
__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"]
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ChannelDB",
|
||||
"ChannelCreate",
|
||||
"ChannelUpdate",
|
||||
"ChannelResponse",
|
||||
"ChannelURL",
|
||||
"ChannelURLCreate",
|
||||
"ChannelURLResponse",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SigninRequest(BaseModel):
|
||||
"""Request model for the signin endpoint."""
|
||||
|
||||
username: str = Field(..., description="The user's username")
|
||||
password: str = Field(..., description="The user's password")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response model for successful authentication."""
|
||||
|
||||
access_token: str = Field(..., description="Access JWT token from Cognito")
|
||||
id_token: str = Field(..., description="ID JWT token from Cognito")
|
||||
refresh_token: Optional[str] = Field(
|
||||
None, description="Refresh token from Cognito")
|
||||
refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
|
||||
token_type: str = Field(..., description="Type of the token returned")
|
||||
|
||||
|
||||
class CognitoUser(BaseModel):
|
||||
"""Model representing the user returned from token verification."""
|
||||
|
||||
username: str
|
||||
roles: List[str]
|
||||
roles: list[str]
|
||||
|
||||
@@ -1,21 +1,33 @@
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Priority(Base):
|
||||
"""SQLAlchemy model for channel URL priorities"""
|
||||
|
||||
__tablename__ = "priorities"
|
||||
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class ChannelDB(Base):
|
||||
"""SQLAlchemy model for IPTV channels"""
|
||||
|
||||
__tablename__ = "channels"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
@@ -25,27 +37,43 @@ class ChannelDB(Base):
|
||||
tvg_name = Column(String)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
||||
)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationship with ChannelURL
|
||||
urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan")
|
||||
urls = relationship(
|
||||
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class ChannelURL(Base):
|
||||
"""SQLAlchemy model for channel URLs"""
|
||||
|
||||
__tablename__ = "channels_urls"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||
channel_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("channels.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
channel = relationship("ChannelDB", back_populates="urls")
|
||||
priority = relationship("Priority")
|
||||
priority = relationship("Priority")
|
||||
|
||||
@@ -1,30 +1,43 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PriorityBase(BaseModel):
|
||||
"""Base Pydantic model for priorities"""
|
||||
|
||||
id: int
|
||||
description: str
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PriorityCreate(PriorityBase):
|
||||
"""Pydantic model for creating priorities"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PriorityResponse(PriorityBase):
|
||||
"""Pydantic model for priority responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelURLCreate(BaseModel):
|
||||
"""Pydantic model for creating channel URLs"""
|
||||
|
||||
url: str
|
||||
priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range
|
||||
priority_id: int = Field(
|
||||
default=100, ge=100, le=300
|
||||
) # Default to High, validate range
|
||||
|
||||
|
||||
class ChannelURLBase(ChannelURLCreate):
|
||||
"""Base Pydantic model for channel URL responses"""
|
||||
|
||||
id: UUID
|
||||
in_use: bool
|
||||
created_at: datetime
|
||||
@@ -33,43 +46,53 @@ class ChannelURLBase(ChannelURLCreate):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChannelURLResponse(ChannelURLBase):
|
||||
"""Pydantic model for channel URL responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelCreate(BaseModel):
|
||||
"""Pydantic model for creating channels"""
|
||||
urls: List[ChannelURLCreate] # List of URL objects with priority
|
||||
|
||||
urls: list[ChannelURLCreate] # List of URL objects with priority
|
||||
name: str
|
||||
group_title: str
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
|
||||
|
||||
class ChannelURLUpdate(BaseModel):
|
||||
"""Pydantic model for updating channel URLs"""
|
||||
|
||||
url: Optional[str] = None
|
||||
in_use: Optional[bool] = None
|
||||
priority_id: Optional[int] = Field(default=None, ge=100, le=300)
|
||||
|
||||
|
||||
class ChannelUpdate(BaseModel):
|
||||
"""Pydantic model for updating channels (all fields optional)"""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
group_title: Optional[str] = Field(None, min_length=1)
|
||||
tvg_id: Optional[str] = Field(None, min_length=1)
|
||||
tvg_logo: Optional[str] = None
|
||||
tvg_name: Optional[str] = Field(None, min_length=1)
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
"""Pydantic model for channel responses"""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
group_title: str
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
urls: List[ChannelURLResponse] # List of URL objects without channel_id
|
||||
urls: list[ChannelURLResponse] # List of URL objects without channel_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.auth.cognito import initiate_auth
|
||||
from app.models.auth import SigninRequest, TokenResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/auth",
|
||||
tags=["authentication"]
|
||||
)
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
|
||||
def signin(credentials: SigninRequest):
|
||||
"""
|
||||
Sign-in endpoint to authenticate the user with AWS Cognito using username and password.
|
||||
Sign-in endpoint to authenticate the user with AWS Cognito
|
||||
using username and password.
|
||||
On success, returns JWT tokens (access_token, id_token, refresh_token).
|
||||
"""
|
||||
auth_result = initiate_auth(credentials.username, credentials.password)
|
||||
@@ -19,4 +19,4 @@ def signin(credentials: SigninRequest):
|
||||
id_token=auth_result["IdToken"],
|
||||
refresh_token=auth_result.get("RefreshToken"),
|
||||
token_type="Bearer",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,52 +1,54 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from sqlalchemy import and_
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models import (
|
||||
ChannelDB,
|
||||
ChannelURL,
|
||||
ChannelCreate,
|
||||
ChannelUpdate,
|
||||
ChannelDB,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURL,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
)
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.schemas import ChannelURLUpdate
|
||||
from app.utils.database import get_db
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/channels",
|
||||
tags=["channels"]
|
||||
)
|
||||
router = APIRouter(prefix="/channels", tags=["channels"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_channel(
|
||||
channel: ChannelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new channel"""
|
||||
# Check for duplicate channel (same group_title + name)
|
||||
existing_channel = db.query(ChannelDB).filter(
|
||||
and_(
|
||||
ChannelDB.group_title == channel.group_title,
|
||||
ChannelDB.name == channel.name
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_title == channel.group_title,
|
||||
ChannelDB.name == channel.name,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_title and name already exists"
|
||||
detail="Channel with same group_title and name already exists",
|
||||
)
|
||||
|
||||
# Create channel without URLs first
|
||||
channel_data = channel.model_dump(exclude={'urls'})
|
||||
channel_data = channel.model_dump(exclude={"urls"})
|
||||
urls = channel.urls
|
||||
db_channel = ChannelDB(**channel_data)
|
||||
db.add(db_channel)
|
||||
@@ -59,130 +61,142 @@ def create_channel(
|
||||
channel_id=db_channel.id,
|
||||
url=url.url,
|
||||
priority_id=url.priority_id,
|
||||
in_use=False
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.get("/{channel_id}", response_model=ChannelResponse)
|
||||
def get_channel(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
|
||||
"""Get a channel by id"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
return channel
|
||||
|
||||
|
||||
@router.put("/{channel_id}", response_model=ChannelResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel(
|
||||
channel_id: UUID,
|
||||
channel: ChannelUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel"""
|
||||
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not db_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
# Only check for duplicates if name or group_title are being updated
|
||||
if channel.name is not None or channel.group_title is not None:
|
||||
name = channel.name if channel.name is not None else db_channel.name
|
||||
group_title = channel.group_title if channel.group_title is not None else db_channel.group_title
|
||||
|
||||
existing_channel = db.query(ChannelDB).filter(
|
||||
and_(
|
||||
ChannelDB.group_title == group_title,
|
||||
ChannelDB.name == name,
|
||||
ChannelDB.id != channel_id
|
||||
group_title = (
|
||||
channel.group_title
|
||||
if channel.group_title is not None
|
||||
else db_channel.group_title
|
||||
)
|
||||
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_title == group_title,
|
||||
ChannelDB.name == name,
|
||||
ChannelDB.id != channel_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_title and name already exists"
|
||||
detail="Channel with same group_title and name already exists",
|
||||
)
|
||||
|
||||
|
||||
# Update only provided fields
|
||||
update_data = channel.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_channel, key, value)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
db.delete(channel)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
@router.get("/", response_model=List[ChannelResponse])
|
||||
|
||||
@router.get("/", response_model=list[ChannelResponse])
|
||||
@require_roles("admin")
|
||||
def list_channels(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""List all channels with pagination"""
|
||||
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
# URL Management Endpoints
|
||||
|
||||
@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/{channel_id}/urls",
|
||||
response_model=ChannelURLResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
@require_roles("admin")
|
||||
def add_channel_url(
|
||||
channel_id: UUID,
|
||||
url: ChannelURLCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Add a new URL to a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
db_url = ChannelURL(
|
||||
channel_id=channel_id,
|
||||
url=url.url,
|
||||
priority_id=url.priority_id,
|
||||
in_use=False # Default to not in use
|
||||
in_use=False, # Default to not in use
|
||||
)
|
||||
db.add(db_url)
|
||||
db.commit()
|
||||
db.refresh(db_url)
|
||||
return db_url
|
||||
|
||||
|
||||
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel_url(
|
||||
@@ -190,72 +204,69 @@ def update_channel_url(
|
||||
url_id: UUID,
|
||||
url_update: ChannelURLUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel URL (url, in_use, or priority_id)"""
|
||||
db_url = db.query(ChannelURL).filter(
|
||||
and_(
|
||||
ChannelURL.id == url_id,
|
||||
ChannelURL.channel_id == channel_id
|
||||
)
|
||||
).first()
|
||||
|
||||
db_url = (
|
||||
db.query(ChannelURL)
|
||||
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="URL not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
|
||||
)
|
||||
|
||||
|
||||
if url_update.url is not None:
|
||||
db_url.url = url_update.url
|
||||
if url_update.in_use is not None:
|
||||
db_url.in_use = url_update.in_use
|
||||
if url_update.priority_id is not None:
|
||||
db_url.priority_id = url_update.priority_id
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_url)
|
||||
return db_url
|
||||
|
||||
|
||||
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel_url(
|
||||
channel_id: UUID,
|
||||
url_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a URL from a channel"""
|
||||
url = db.query(ChannelURL).filter(
|
||||
and_(
|
||||
ChannelURL.id == url_id,
|
||||
ChannelURL.channel_id == channel_id
|
||||
)
|
||||
).first()
|
||||
|
||||
url = (
|
||||
db.query(ChannelURL)
|
||||
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="URL not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
|
||||
)
|
||||
|
||||
|
||||
db.delete(url)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse])
|
||||
|
||||
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
|
||||
@require_roles("admin")
|
||||
def list_channel_urls(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""List all URLs for a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Channel not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
|
||||
|
||||
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/playlist",
|
||||
tags=["playlist"]
|
||||
)
|
||||
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
||||
|
||||
@router.get("/protected",
|
||||
summary="Protected endpoint for authenticated users")
|
||||
|
||||
@router.get("/protected", summary="Protected endpoint for authenticated users")
|
||||
async def protected_route(user: CognitoUser = Depends(get_current_user)):
|
||||
"""
|
||||
Protected endpoint that requires authentication for all users.
|
||||
If the user is authenticated, returns success message.
|
||||
"""
|
||||
return {"message": f"Hello {user.username}, you have access to support resources!"}
|
||||
return {"message": f"Hello {user.username}, you have access to support resources!"}
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, delete
|
||||
from typing import List
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.db import Priority
|
||||
from app.models.schemas import PriorityCreate, PriorityResponse
|
||||
from app.utils.database import get_db
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/priorities",
|
||||
tags=["priorities"]
|
||||
)
|
||||
router = APIRouter(prefix="/priorities", tags=["priorities"])
|
||||
|
||||
|
||||
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_priority(
|
||||
priority: PriorityCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new priority"""
|
||||
# Check if priority with this ID already exists
|
||||
@@ -27,71 +24,69 @@ def create_priority(
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Priority with ID {priority.id} already exists"
|
||||
detail=f"Priority with ID {priority.id} already exists",
|
||||
)
|
||||
|
||||
|
||||
db_priority = Priority(**priority.model_dump())
|
||||
db.add(db_priority)
|
||||
db.commit()
|
||||
db.refresh(db_priority)
|
||||
return db_priority
|
||||
|
||||
@router.get("/", response_model=List[PriorityResponse])
|
||||
|
||||
@router.get("/", response_model=list[PriorityResponse])
|
||||
@require_roles("admin")
|
||||
def list_priorities(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
|
||||
):
|
||||
"""List all priorities"""
|
||||
return db.query(Priority).all()
|
||||
|
||||
|
||||
@router.get("/{priority_id}", response_model=PriorityResponse)
|
||||
@require_roles("admin")
|
||||
def get_priority(
|
||||
priority_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Get a priority by id"""
|
||||
priority = db.get(Priority, priority_id)
|
||||
if not priority:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Priority not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
|
||||
)
|
||||
return priority
|
||||
|
||||
|
||||
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_priority(
|
||||
priority_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user)
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a priority (if not in use)"""
|
||||
from app.models.db import ChannelURL
|
||||
|
||||
|
||||
# Check if priority exists
|
||||
priority = db.get(Priority, priority_id)
|
||||
if not priority:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Priority not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
|
||||
)
|
||||
|
||||
|
||||
# Check if priority is in use
|
||||
in_use = db.scalar(
|
||||
select(ChannelURL)
|
||||
.where(ChannelURL.priority_id == priority_id)
|
||||
.limit(1)
|
||||
select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
|
||||
)
|
||||
|
||||
|
||||
if in_use:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Cannot delete priority that is in use by channel URLs"
|
||||
detail="Cannot delete priority that is in use by channel URLs",
|
||||
)
|
||||
|
||||
|
||||
db.execute(delete(Priority).where(Priority.id == priority_id))
|
||||
db.commit()
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -2,11 +2,13 @@ import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
|
||||
"""
|
||||
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
|
||||
"""
|
||||
msg = username + client_id
|
||||
dig = hmac.new(client_secret.encode('utf-8'),
|
||||
msg.encode('utf-8'), hashlib.sha256).digest()
|
||||
return base64.b64encode(dig).decode()
|
||||
dig = hmac.new(
|
||||
client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
return base64.b64encode(dig).decode()
|
||||
|
||||
@@ -1,41 +1,50 @@
|
||||
import os
|
||||
import argparse
|
||||
import requests
|
||||
import logging
|
||||
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError
|
||||
import os
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||
|
||||
|
||||
class StreamValidator:
|
||||
def __init__(self, timeout=10, user_agent=None):
|
||||
self.timeout = timeout
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
|
||||
})
|
||||
|
||||
self.session.headers.update(
|
||||
{
|
||||
"User-Agent": user_agent
|
||||
or (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def validate_stream(self, url):
|
||||
"""Validate a media stream URL with multiple fallback checks"""
|
||||
try:
|
||||
headers = {'Range': 'bytes=0-1024'}
|
||||
headers = {"Range": "bytes=0-1024"}
|
||||
with self.session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
allow_redirects=True
|
||||
allow_redirects=True,
|
||||
) as response:
|
||||
if response.status_code not in [200, 206]:
|
||||
return False, f"Invalid status code: {response.status_code}"
|
||||
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if not self._is_valid_content_type(content_type):
|
||||
return False, f"Invalid content type: {content_type}"
|
||||
|
||||
|
||||
try:
|
||||
next(response.iter_content(chunk_size=1024))
|
||||
return True, "Stream is valid"
|
||||
except (ConnectionError, Timeout):
|
||||
return False, "Connection failed during content read"
|
||||
|
||||
|
||||
except HTTPError as e:
|
||||
return False, f"HTTP Error: {str(e)}"
|
||||
except ConnectionError as e:
|
||||
@@ -49,10 +58,13 @@ class StreamValidator:
|
||||
|
||||
def _is_valid_content_type(self, content_type):
|
||||
valid_types = [
|
||||
'video/mp2t', 'application/vnd.apple.mpegurl',
|
||||
'application/dash+xml', 'video/mp4',
|
||||
'video/webm', 'application/octet-stream',
|
||||
'application/x-mpegURL'
|
||||
"video/mp2t",
|
||||
"application/vnd.apple.mpegurl",
|
||||
"application/dash+xml",
|
||||
"video/mp4",
|
||||
"video/webm",
|
||||
"application/octet-stream",
|
||||
"application/x-mpegURL",
|
||||
]
|
||||
return any(ct in content_type for ct in valid_types)
|
||||
|
||||
@@ -60,45 +72,43 @@ class StreamValidator:
|
||||
"""Extract stream URLs from M3U playlist file"""
|
||||
urls = []
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
urls.append(line)
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading playlist file: {str(e)}")
|
||||
raise
|
||||
return urls
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Validate streaming URLs from command line arguments or playlist files',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
description=(
|
||||
"Validate streaming URLs from command line arguments or playlist files"
|
||||
),
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
'sources',
|
||||
nargs='+',
|
||||
help='List of URLs or file paths containing stream URLs'
|
||||
"sources", nargs="+", help="List of URLs or file paths containing stream URLs"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--timeout',
|
||||
type=int,
|
||||
default=20,
|
||||
help='Timeout in seconds for stream checks'
|
||||
"--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
default='deadstreams.txt',
|
||||
help='Output file name for inactive streams'
|
||||
"--output",
|
||||
default="deadstreams.txt",
|
||||
help="Output file name for inactive streams",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()]
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
|
||||
)
|
||||
|
||||
validator = StreamValidator(timeout=args.timeout)
|
||||
@@ -127,9 +137,10 @@ def main():
|
||||
|
||||
# Save dead streams to file
|
||||
if dead_streams:
|
||||
with open(args.output, 'w') as f:
|
||||
f.write('\n'.join(dead_streams))
|
||||
with open(args.output, "w") as f:
|
||||
f.write("\n".join(dead_streams))
|
||||
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -21,4 +21,4 @@ IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
|
||||
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
|
||||
|
||||
# URL for the EPG XML file to place in the playlist's header
|
||||
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")
|
||||
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
|
||||
import boto3
|
||||
from app.models import Base
|
||||
from .constants import AWS_REGION
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from functools import lru_cache
|
||||
|
||||
from app.models import Base
|
||||
|
||||
from .constants import AWS_REGION
|
||||
|
||||
|
||||
def get_db_credentials():
|
||||
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
|
||||
@@ -14,29 +16,40 @@ def get_db_credentials():
|
||||
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
|
||||
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
|
||||
)
|
||||
|
||||
ssm = boto3.client('ssm', region_name=AWS_REGION)
|
||||
|
||||
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
||||
try:
|
||||
host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value']
|
||||
user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value']
|
||||
password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value']
|
||||
dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value']
|
||||
host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
password = ssm.get_parameter(
|
||||
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
|
||||
)["Parameter"]["Value"]
|
||||
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
|
||||
|
||||
|
||||
# Initialize engine and session maker
|
||||
engine = create_engine(get_db_credentials())
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Initialize database by creating all tables"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user