Compare commits

...

72 Commits

Author SHA1 Message Date
a42d4c30a6 Started (incomplete) implementation of stream verification scheduler and endpoints
All checks were successful
AWS Deploy on Push / build (push) Successful in 5m18s
2025-06-17 17:12:39 -05:00
abb467749b Implemented bulk upload by passing a json structure. Added delete all channels, groups and priorities
All checks were successful
AWS Deploy on Push / build (push) Successful in 2m17s
2025-06-12 18:49:20 -05:00
b8ac25e301 Introduced groups and added all related endpoints
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m39s
2025-06-10 23:02:46 -05:00
729eabf27f Updated documentation
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-29 18:46:40 -05:00
34c446bcfa Make sure DB credentials are available when running userdata (fix-2)
All checks were successful
AWS Deploy on Push / build (push) Successful in 10m26s
2025-05-29 17:52:53 -05:00
d4cc74ea8c Make sure DB credentials are available when running userdata (fix-1)
Some checks failed
AWS Deploy on Push / build (push) Failing after 39s
2025-05-29 17:48:23 -05:00
21b73b6843 Make sure DB credentials are available when running userdata
Some checks failed
AWS Deploy on Push / build (push) Failing after 41s
2025-05-29 17:43:08 -05:00
e743daf9f7 Moved creation of the instance after database creation
All checks were successful
AWS Deploy on Push / build (push) Successful in 6m51s
2025-05-29 17:16:08 -05:00
b0d98551b8 Fixed install of postgres client on Amazon Linux 2023
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m55s
2025-05-29 16:37:42 -05:00
eaab1ef998 Changed project name to be IPTV Manager Service
All checks were successful
AWS Deploy on Push / build (push) Successful in 8m29s
2025-05-29 16:09:52 -05:00
e25f8c1ecd Run unit test upon committing new code
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m3s
2025-05-28 23:41:12 -05:00
95bf0f9701 Created unit tests for check_streams
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m5s
2025-05-28 23:31:04 -05:00
f7a1c20066 Created unit tests for playlist.py
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-28 22:58:29 -05:00
bf6f156fec Created unit tests for priorities.py router
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m6s
2025-05-28 22:31:31 -05:00
7e25ec6755 Test refactoring
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m9s
2025-05-28 22:22:20 -05:00
6d506122d9 Add pre-commit commands to install script
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m7s
2025-05-28 22:05:13 -05:00
02913c7385 Linted and formatted all files 2025-05-28 21:52:39 -05:00
e46f13930d Added ruff linter and formatter
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m5s
2025-05-28 20:50:10 -05:00
903f190ee2 Added more unit tests for routers
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m10s
2025-05-27 23:41:00 -05:00
32af6bbdb5 Added unit tests for all auth classes
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-27 18:44:05 -05:00
7ee7a0e644 Added unit tests for cognito.py
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m10s
2025-05-27 18:31:18 -05:00
9474a3ca44 Added unit tests for main.py
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m7s
2025-05-27 18:13:32 -05:00
1ab8599dde Complete coverage for database.py
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m5s
2025-05-27 18:04:56 -05:00
fb5215b92a Cleanup test database unit tests
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-27 17:57:28 -05:00
cebbb9c1a8 Added pytest configuration and first 4 unit tests
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m4s
2025-05-27 17:37:05 -05:00
4b1a7e9bea Added alembic upgrade head to workflow script
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m7s
2025-05-26 21:43:44 -05:00
21cc99eff6 Added in_use and priority_id field for channels urls. Added priorities table. Setup sql alchemy migration. Generate first migration.
All checks were successful
AWS Deploy on Push / build (push) Successful in 2m4s
2025-05-26 21:24:41 -05:00
76dc8908de Moved install script to scripts folder
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-23 14:50:33 -05:00
07dab76e3b Install Python packages with --ignore-installed to prevent conflicts with RPM packages
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m53s
2025-05-23 14:11:51 -05:00
c21b34f5fe Fixed packages installation to conform with Amazon Linux 2023
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-23 13:26:31 -05:00
b1942354d9 Switch to AMAZON Linux 2023
All checks were successful
AWS Deploy on Push / build (push) Successful in 10m24s
2025-05-23 12:57:21 -05:00
3937269bb9 Use EC2 instance public address (free tier compatible)
Some checks failed
AWS Deploy on Push / build (push) Has been cancelled
2025-05-23 12:25:23 -05:00
8c7ed421c9 Using Optional[str] instead of str | None for optional fields
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m7s
2025-05-23 11:49:13 -05:00
c96ee307db Moved channel URLs to channels_urls table. Create CRUD endpoints for new table.
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-23 11:36:04 -05:00
f11d533fac Allow PostgreSQL port for tunneling restricted to developer IP
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m35s
2025-05-23 09:30:55 -05:00
99d26a8f53 Remove ingress rule to allow remote access to database since database is now on a private subnet
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m32s
2025-05-23 01:14:44 -05:00
260fcb311b Place ec2 explicitly in public subnet
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m20s
2025-05-22 23:18:05 -05:00
1e82418cad Place rds database in private subnet
Some checks failed
AWS Deploy on Push / build (push) Failing after 7m50s
2025-05-22 22:37:03 -05:00
9c690fe6a6 Fixed name on new ingress rule
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m33s
2025-05-22 11:38:14 -05:00
9e8df169fc Refactored DB schema. Channels table to have uuid as primary key. Modified endpoints accordingly
Some checks failed
AWS Deploy on Push / build (push) Failing after 1m9s
2025-05-22 11:34:04 -05:00
5ee6cb4be4 Moved endpoints to routers
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m12s
2025-05-22 09:34:20 -05:00
c1e3a6ef26 Updated README
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m22s
2025-05-21 16:57:27 -05:00
cb793ef5e1 Add SendCommand permissions
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m40s
2025-05-21 16:29:34 -05:00
be719a6e34 Fixed process of updating app on running instances
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m29s
2025-05-21 16:16:02 -05:00
5767124031 Fixed database credential retrieval
All checks were successful
AWS Deploy on Push / build (push) Successful in 2m41s
2025-05-21 15:05:12 -05:00
c6f7e9cb2b Changed db instance class to t3
All checks were successful
AWS Deploy on Push / build (push) Successful in 9m5s
2025-05-21 14:32:22 -05:00
eeb0f1c844 Need at least 2 AZs for RDS subnet group
Some checks failed
AWS Deploy on Push / build (push) Failing after 1m52s
2025-05-21 14:24:50 -05:00
4cb3811d17 RDS to use public subnet
Some checks failed
AWS Deploy on Push / build (push) Failing after 1m40s
2025-05-21 14:16:48 -05:00
489281f3eb Added PostgreSQL RDS database. Added channels protected endpoints. Added scripts and docker config to run application locally in dev mode.
Some checks failed
AWS Deploy on Push / build (push) Failing after 41s
2025-05-21 14:02:01 -05:00
b947ac67f0 Fixed typo
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m15s
2025-05-20 17:02:59 -05:00
dd2446a01a Updated README
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m16s
2025-05-20 16:41:19 -05:00
639adba7eb Moved repo url and email for letsencrypt to env variables
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m15s
2025-05-20 16:24:31 -05:00
5698e7f26b Moved ssh public key to an environment variable
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m14s
2025-05-20 15:51:34 -05:00
df3fc2f37c Fixed acme.sh installation in userdata.sh - Part 2
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m17s
2025-05-20 15:02:11 -05:00
594ce0c67a Fixed acme.sh installation in userdata.sh
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m20s
2025-05-20 14:48:26 -05:00
37be1f3f91 Fixed replacing of DOMAIN_NAME in userdata.sh
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m18s
2025-05-20 14:11:03 -05:00
732667cf64 Added SSL cert generation and installation. Moved variables to ENV
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m15s
2025-05-20 12:45:55 -05:00
5bc7a72a92 Initial README version representing the current progress
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m11s
2025-05-16 14:37:35 -05:00
a5dfc1b493 Added scripts to create and delete users. Moved all scripts to new scripts folder
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m21s
2025-05-16 14:19:14 -05:00
0b69ffd67c Switch to cognito user/password authentication. Major code refactor - Fix 4
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m21s
2025-05-16 13:39:29 -05:00
127d81adac Switch to cognito user/password authentication. Major code refactor - Fix 3
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m20s
2025-05-16 13:22:10 -05:00
658f7998ef Switch to cognito user/password authentication. Major code refactor - Fix 2
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m25s
2025-05-16 13:00:09 -05:00
c4f19999dc Switch to cognito user/password authentication. Major code refactor - Fix 1
All checks were successful
AWS Deploy on Push / build (push) Successful in 4m11s
2025-05-16 11:11:16 -05:00
c221a8cded Switch to cognito user/password authentication. Major code refactor.
Some checks failed
AWS Deploy on Push / build (push) Failing after 48s
2025-05-16 11:05:54 -05:00
8d1997fa5a Added cognito authentication - Fix 12
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m10s
2025-05-15 17:27:51 -05:00
795a25961f Added cognito authentication - Fix 11
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m11s
2025-05-15 17:23:52 -05:00
d55c383bc4 Added cognito authentication - Fix 10
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m9s
2025-05-15 17:14:33 -05:00
5c17e4b1e9 Added cognito authentication - Fix 9
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-15 17:08:01 -05:00
30ccf86c86 Added cognito authentication - Fix 8
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m13s
2025-05-15 16:52:54 -05:00
ae040fc49e Added cognito authentication - Fix 7
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m14s
2025-05-15 16:42:42 -05:00
47befceb17 Added cognito authentication - Fix 6
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m32s
2025-05-15 16:29:57 -05:00
7f282049ac Added cognito authentication - Fix 5
Some checks failed
AWS Deploy on Push / build (push) Failing after 1m0s
2025-05-15 16:24:37 -05:00
80 changed files with 7799 additions and 363 deletions

19
.env.example Normal file
View File

@@ -0,0 +1,19 @@
# Environment variables
# Scheduler configuration
STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax)
STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all)
# For use with Docker Compose to run application locally
MOCK_AUTH=true/false
DB_USER=MyDBUser
DB_PASSWORD=MyDBPassword
DB_HOST=MyDBHost
DB_NAME=iptv_manager
FREEDNS_User=MyFreeDNSUsername
FREEDNS_Password=MyFreeDNSPassword
DOMAIN_NAME=mydomain.com
SSH_PUBLIC_KEY="ssh-rsa AAAAB3NzaC1yc2EMYPUBLICKEY7+"
REPO_URL="https://git.example.com/user/repo.git"
LETSENCRYPT_EMAIL="admin@example.com"

View File

@@ -39,6 +39,13 @@ jobs:
- name: Deploy to AWS - name: Deploy to AWS
run: cdk deploy --app="python3 ${PWD}/app.py" --require-approval=never run: cdk deploy --app="python3 ${PWD}/app.py" --require-approval=never
env:
FREEDNS_User: ${{ secrets.FREEDNS_USER }}
FREEDNS_Password: ${{ secrets.FREEDNS_PASSWORD }}
DOMAIN_NAME: ${{ secrets.DOMAIN_NAME }}
SSH_PUBLIC_KEY: ${{ secrets.SSH_PUBLIC_KEY }}
REPO_URL: ${{ secrets.REPO_URL }}
LETSENCRYPT_EMAIL: ${{ secrets.LETSENCRYPT_EMAIL }}
- name: Install AWS CLI - name: Install AWS CLI
run: | run: |
@@ -50,20 +57,23 @@ jobs:
- name: Update application on instance - name: Update application on instance
run: | run: |
INSTANCE_IDS=$(aws ec2 describe-instances \ INSTANCE_IDS=$(aws ec2 describe-instances \
--filters "Name=tag:Name,Values=IptvUpdater/IptvUpdaterInstance" \ --region us-east-2 \
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
"Name=instance-state-name,Values=running" \ "Name=instance-state-name,Values=running" \
--query "Reservations[].Instances[].InstanceId" \ --query "Reservations[].Instances[].InstanceId" \
--output text) --output text)
for INSTANCE_ID in $INSTANCE_IDS; do for INSTANCE_ID in $INSTANCE_IDS; do
aws ssm send-command \ aws ssm send-command \
--region us-east-2 \
--instance-ids "$INSTANCE_ID" \ --instance-ids "$INSTANCE_ID" \
--document-name "AWS-RunShellScript" \ --document-name "AWS-RunShellScript" \
--parameters 'commands=[ --parameters 'commands=[
"cd /home/ec2-user/iptv-updater-aws", "cd /home/ec2-user/iptv-manager-service",
"git pull", "git pull",
"pip3 install -r requirements.txt", "pip3 install -r requirements.txt",
"sudo systemctl restart iptv-updater" "alembic upgrade head",
"sudo systemctl restart iptv-manager"
]' ]'
done done

6
.gitignore vendored
View File

@@ -4,10 +4,16 @@ __pycache__
.pytest_cache .pytest_cache
.env .env
.venv .venv
*.pid
*.log
*.egg-info *.egg-info
.coverage .coverage
.roomodes
cdk.out/ cdk.out/
node_modules/ node_modules/
data/
.roo/
.ruru/
# CDK asset staging directory # CDK asset staging directory
.cdk.staging .cdk.staging

16
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,16 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.12
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: pytest
language: system
pass_filenames: false
always_run: true

84
.vscode/settings.json vendored
View File

@@ -1,23 +1,105 @@
{ {
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"editor.formatOnSave": true,
"editor.defaultFormatter": "charliermarsh.ruff",
"ruff.importStrategy": "fromEnvironment",
"ruff.path": ["${workspaceFolder}"],
"cSpell.words": [ "cSpell.words": [
"addopts",
"adminpassword",
"altinstall", "altinstall",
"apscheduler",
"asyncio",
"autoflush",
"autoupdate",
"autouse",
"awscli",
"awscliv", "awscliv",
"boto", "boto",
"botocore",
"BURSTABLE",
"cabletv", "cabletv",
"capsys",
"CDUF",
"cduflogo",
"cdulogo",
"CDUNF",
"cdunflogo",
"certbot", "certbot",
"certifi", "certifi",
"cfulogo",
"CLEU",
"cleulogo",
"CLUF",
"cluflogo",
"clulogo",
"cpulogo",
"crond",
"cronie",
"cuflgo",
"CUNF",
"cunflogo",
"cuulogo",
"datname",
"deadstreams",
"delenv",
"delogo",
"devel", "devel",
"dflogo",
"dmlogo",
"dotenv", "dotenv",
"EXTINF",
"EXTM",
"fastapi", "fastapi",
"filterwarnings",
"fiorinis", "fiorinis",
"freedns",
"fullchain",
"gitea", "gitea",
"httpx",
"iptv", "iptv",
"isort",
"KHTML",
"lclogo",
"LETSENCRYPT",
"levelname",
"mpegurl",
"nohup", "nohup",
"nopriority",
"ondelete",
"onupdate",
"passlib", "passlib",
"PGPASSWORD",
"poolclass",
"psql",
"psycopg",
"pycache",
"pycodestyle",
"pyflakes",
"pyjwt", "pyjwt",
"pytest",
"PYTHONDONTWRITEBYTECODE",
"PYTHONUNBUFFERED",
"pyupgrade",
"reloadcmd",
"roomodes",
"ruru",
"sessionmaker",
"sqlalchemy",
"sqliteuuid",
"starlette", "starlette",
"stefano", "stefano",
"testadmin",
"testdb",
"testpass",
"testpaths",
"testuser",
"uflogo",
"umlogo",
"usefixtures",
"uvicorn", "uvicorn",
"venv" "venv",
"wrongpass"
] ]
} }

150
README.md
View File

@@ -1 +1,149 @@
# To do # IPTV Manager Service
A FastAPI-based service for managing IPTV playlists and channel priorities. The application provides secure endpoints for user authentication, channel management, and playlist generation.
## ✨ Features
- **JWT Authentication**: Secure login using AWS Cognito
- **Channel Management**: CRUD operations for IPTV channels
- **Playlist Generation**: Create M3U playlists with channel priorities
- **Stream Monitoring**: Background checks for channel availability
- **Priority Management**: Set channel priorities for playlist ordering
## 🛠️ Technology Stack
- **Backend**: Python 3.11, FastAPI
- **Database**: PostgreSQL (SQLAlchemy ORM)
- **Authentication**: AWS Cognito
- **Infrastructure**: AWS CDK (API Gateway, Lambda, RDS)
- **Testing**: Pytest with 85%+ coverage
- **CI/CD**: Pre-commit hooks, Alembic migrations
## 🚀 Getting Started
### Prerequisites
- Python 3.11+
- Docker
- AWS CLI (for deployment)
### Installation
```bash
# Clone repository
git clone https://github.com/your-repo/iptv-manager-service.git
cd iptv-manager-service
# Setup environment
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
cp .env.example .env # Update with your values
# Run installation script
./scripts/install.sh
```
### Running Locally
```bash
# Start development environment
./scripts/start_local_dev.sh
# Stop development environment
./scripts/stop_local_dev.sh
```
## ☁️ AWS Deployment
The infrastructure is defined in CDK. Use the provided scripts:
```bash
# Deploy AWS infrastructure
./scripts/deploy.sh
# Destroy AWS infrastructure
./scripts/destroy.sh
# Create Cognito test user
./scripts/create_cognito_user.sh
# Delete Cognito user
./scripts/delete_cognito_user.sh
```
Key AWS components:
- API Gateway
- Lambda functions
- RDS PostgreSQL
- Cognito User Pool
## 🤖 Continuous Integration/Deployment
This project includes a Gitea Actions workflow (`.gitea/workflows/deploy.yml`) for automated deployment to AWS. The workflow is fully compatible with GitHub Actions and can be easily adapted by:
1. Placing the workflow file in the `.github/workflows/` directory
2. Setting up the required secrets in your CI/CD environment:
- `AWS_ACCESS_KEY_ID`
- `AWS_SECRET_ACCESS_KEY`
- `AWS_DEFAULT_REGION`
The workflow automatically deploys the infrastructure and application when changes are pushed to the main branch.
## 📚 API Documentation
Access interactive docs at:
- Swagger UI: `http://localhost:8000/docs`
- ReDoc: `http://localhost:8000/redoc`
### Key Endpoints
| Endpoint | Method | Description |
| ------------- | ------ | --------------------- |
| `/auth/login` | POST | User authentication |
| `/channels` | GET | List all channels |
| `/playlist` | GET | Generate M3U playlist |
| `/priorities` | POST | Set channel priority |
## 🧪 Testing
Run the full test suite:
```bash
pytest
```
Test coverage includes:
- Authentication workflows
- Channel CRUD operations
- Playlist generation logic
- Stream monitoring
- Database operations
## 📂 Project Structure
```txt
iptv-manager-service/
├── app/ # Core application
│ ├── auth/ # Cognito authentication
│ ├── iptv/ # Playlist logic
│ ├── models/ # Database models
│ ├── routers/ # API endpoints
│ ├── utils/ # Helper functions
│ └── main.py # App entry point
├── infrastructure/ # AWS CDK stack
├── docker/ # Docker configs
├── scripts/ # Deployment scripts
├── tests/ # Comprehensive tests
├── alembic/ # Database migrations
├── .gitea/ # Gitea CI/CD workflows
│ └── workflows/
└── ... # Config files
```
## 📝 License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

141
alembic.ini Normal file
View File

@@ -0,0 +1,141 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
alembic/README Normal file
View File

@@ -0,0 +1 @@
Generic single-database configuration.

79
alembic/env.py Normal file
View File

@@ -0,0 +1,79 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
from app.models.db import Base
from app.utils.database import get_db_credentials
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
# Override sqlalchemy.url with dynamic credentials
if not context.is_offline_mode():
config.set_main_option("sqlalchemy.url", get_db_credentials())
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,110 @@
"""add groups table and migrate group_title data
Revision ID: 0a455608256f
Revises: 95b61a92455a
Create Date: 2025-06-10 09:22:11.820035
"""
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0a455608256f'
down_revision: Union[str, None] = '95b61a92455a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('groups',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
# Create temporary table for group mapping
group_mapping = op.create_table(
'group_mapping',
sa.Column('group_title', sa.String(), nullable=False),
sa.Column('group_id', sa.UUID(), nullable=False)
)
# Get existing group titles and create groups
conn = op.get_bind()
distinct_groups = conn.execute(
sa.text("SELECT DISTINCT group_title FROM channels")
).fetchall()
for group in distinct_groups:
group_title = group[0]
group_id = str(uuid.uuid4())
conn.execute(
sa.text(
"INSERT INTO groups (id, name, sort_order) "
"VALUES (:id, :name, 0)"
).bindparams(id=group_id, name=group_title)
)
conn.execute(
group_mapping.insert().values(
group_title=group_title,
group_id=group_id
)
)
# Add group_id column (nullable first)
op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True))
# Update channels with group_ids
conn.execute(
sa.text(
"UPDATE channels c SET group_id = gm.group_id "
"FROM group_mapping gm WHERE c.group_title = gm.group_title"
)
)
# Now make group_id non-nullable and add constraints
op.alter_column('channels', 'group_id', nullable=False)
op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique')
op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name'])
op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id'])
# Clean up and drop group_title
op.drop_table('group_mapping')
op.drop_column('channels', 'group_title')
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True))
# Restore group_title values from groups table
conn = op.get_bind()
conn.execute(
sa.text(
"UPDATE channels c SET group_title = g.name "
"FROM groups g WHERE c.group_id = g.id"
)
)
# Now make group_title non-nullable
op.alter_column('channels', 'group_title', nullable=False)
# Drop constraints and columns
op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey')
op.drop_constraint('uix_group_id_name', 'channels', type_='unique')
op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name'])
op.drop_column('channels', 'group_id')
op.drop_table('groups')
# ### end Alembic commands ###

View File

@@ -0,0 +1,79 @@
"""create initial tables
Revision ID: 95b61a92455a
Revises:
Create Date: 2025-05-29 14:42:16.239587
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '95b61a92455a'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('channels',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tvg_id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('group_title', sa.String(), nullable=False),
sa.Column('tvg_name', sa.String(), nullable=True),
sa.Column('tvg_logo', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('group_title', 'name', name='uix_group_title_name')
)
op.create_table('priorities',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('channels_urls',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('channel_id', sa.UUID(), nullable=False),
sa.Column('url', sa.String(), nullable=False),
sa.Column('in_use', sa.Boolean(), nullable=False),
sa.Column('priority_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['channel_id'], ['channels.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['priority_id'], ['priorities.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
# Seed initial priorities
op.bulk_insert(
sa.Table(
'priorities',
sa.MetaData(),
sa.Column('id', sa.Integer),
sa.Column('description', sa.String),
),
[
{'id': 100, 'description': 'High'},
{'id': 200, 'description': 'Medium'},
{'id': 300, 'description': 'Low'},
]
)
def downgrade() -> None:
"""Downgrade schema."""
# Remove seeded priorities
op.execute("DELETE FROM priorities WHERE id IN (100, 200, 300);")
# Drop tables
op.drop_table('channels_urls')
op.drop_table('priorities')
op.drop_table('channels')

42
app.py
View File

@@ -1,7 +1,45 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os
import aws_cdk as cdk import aws_cdk as cdk
from infrastructure.stack import IptvUpdaterStack
from infrastructure.stack import IptvManagerStack
app = cdk.App() app = cdk.App()
IptvUpdaterStack(app, "IptvUpdater")
# Read environment variables for FreeDNS credentials
freedns_user = os.environ.get("FREEDNS_User")
freedns_password = os.environ.get("FREEDNS_Password")
domain_name = os.environ.get("DOMAIN_NAME")
ssh_public_key = os.environ.get("SSH_PUBLIC_KEY")
repo_url = os.environ.get("REPO_URL")
letsencrypt_email = os.environ.get("LETSENCRYPT_EMAIL")
required_vars = {
"FREEDNS_User": freedns_user,
"FREEDNS_Password": freedns_password,
"DOMAIN_NAME": domain_name,
"SSH_PUBLIC_KEY": ssh_public_key,
"REPO_URL": repo_url,
"LETSENCRYPT_EMAIL": letsencrypt_email,
}
# Check for missing required variables
missing_vars = [k for k, v in required_vars.items() if not v]
if missing_vars:
raise ValueError(
f"Missing required environment variables: {', '.join(missing_vars)}"
)
IptvManagerStack(
app,
"IptvManagerStack",
freedns_user=freedns_user,
freedns_password=freedns_password,
domain_name=domain_name,
ssh_public_key=ssh_public_key,
repo_url=repo_url,
letsencrypt_email=letsencrypt_email,
)
app.synth() app.synth()

82
app/auth/cognito.py Normal file
View File

@@ -0,0 +1,82 @@
import boto3
from fastapi import HTTPException, status
from app.models.auth import CognitoUser
from app.utils.auth import calculate_secret_hash
from app.utils.constants import (
AWS_REGION,
COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET,
USER_ROLE_ATTRIBUTE,
)
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
def initiate_auth(username: str, password: str) -> dict:
"""
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
"""
auth_params = {"USERNAME": username, "PASSWORD": password}
# If a client secret is required, add SECRET_HASH
if COGNITO_CLIENT_SECRET:
auth_params["SECRET_HASH"] = calculate_secret_hash(
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
)
try:
response = cognito_client.initiate_auth(
AuthFlow="USER_PASSWORD_AUTH",
AuthParameters=auth_params,
ClientId=COGNITO_CLIENT_ID,
)
return response["AuthenticationResult"]
except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password",
)
except cognito_client.exceptions.UserNotFoundException:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An error occurred during authentication: {str(e)}",
)
def get_user_from_token(access_token: str) -> CognitoUser:
"""
Verify the token by calling GetUser in Cognito and
retrieve user attributes including roles.
"""
try:
user_response = cognito_client.get_user(AccessToken=access_token)
username = user_response.get("Username", "")
attributes = user_response.get("UserAttributes", [])
user_roles = []
for attr in attributes:
if attr["Name"] == USER_ROLE_ATTRIBUTE:
# Assume roles are stored as a comma-separated string
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
break
return CognitoUser(username=username, roles=user_roles)
except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
)
except cognito_client.exceptions.UserNotFoundException:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or invalid token.",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token verification failed: {str(e)}",
)

51
app/auth/dependencies.py Normal file
View File

@@ -0,0 +1,51 @@
import os
from functools import wraps
from typing import Callable
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from app.models.auth import CognitoUser
# Use mock auth for local testing if MOCK_AUTH is set
if os.getenv("MOCK_AUTH", "").lower() == "true":
from app.auth.mock_auth import mock_get_user_from_token as get_user_from_token
else:
from app.auth.cognito import get_user_from_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
"""
Dependency to get the current user from the given token.
This will verify the token with Cognito and return the user's information.
"""
return get_user_from_token(token)
def require_roles(*required_roles: str) -> Callable:
"""
Decorator for role-based access control.
Use on endpoints to enforce that the user possesses all required roles.
"""
def decorator(endpoint: Callable) -> Callable:
@wraps(endpoint)
async def wrapper(
*args, user: CognitoUser = Depends(get_current_user), **kwargs
):
user_roles = set(user.roles or [])
needed_roles = set(required_roles)
if not needed_roles.issubset(user_roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
"You do not have the required roles to access this endpoint."
),
)
return endpoint(*args, user=user, **kwargs)
return wrapper
return decorator

26
app/auth/mock_auth.py Normal file
View File

@@ -0,0 +1,26 @@
from fastapi import HTTPException, status
from app.models.auth import CognitoUser
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
def mock_get_user_from_token(token: str) -> CognitoUser:
"""
Mock version of get_user_from_token for local testing
Accepts 'testuser' as a valid token and returns admin user
"""
if token == "testuser":
return CognitoUser(**MOCK_USERS["testuser"])
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid mock token - use 'testuser'",
)
def mock_initiate_auth(username: str, password: str) -> dict:
"""
Mock version of initiate_auth for local testing
Accepts any username/password and returns a mock token
"""
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}

View File

@@ -1,63 +0,0 @@
import os
import boto3
import requests
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.responses import RedirectResponse
REGION = "us-east-2"
USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID")
CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
DOMAIN = f"https://iptv-updater.auth.{REGION}.amazoncognito.com"
REDIRECT_URI = f"http://localhost:8000/auth/callback"
oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=f"{DOMAIN}/oauth2/authorize",
tokenUrl=f"{DOMAIN}/oauth2/token"
)
def exchange_code_for_token(code: str):
token_url = f"{DOMAIN}/oauth2/token"
data = {
'grant_type': 'authorization_code',
'client_id': CLIENT_ID,
'code': code,
'redirect_uri': REDIRECT_URI
}
response = requests.post(token_url, data=data)
if response.status_code == 200:
return response.json()
print(f"Token exchange failed: {response.text}") # Add logging
raise HTTPException(status_code=400, detail="Failed to exchange code for token")
async def get_current_user(token: str = Depends(oauth2_scheme)):
if not token:
return RedirectResponse(
f"{DOMAIN}/login?client_id={CLIENT_ID}"
f"&response_type=code"
f"&scope=openid+email+profile" # Added more scopes
f"&redirect_uri={REDIRECT_URI}"
)
try:
# Decode JWT token instead of using get_user
decoded = jwt.decode(
token,
options={"verify_signature": False} # We trust tokens from Cognito
)
return {
"Username": decoded.get("email") or decoded.get("sub"),
"UserAttributes": [
{"Name": k, "Value": v}
for k, v in decoded.items()
]
}
except Exception as e:
print(f"Token verification failed: {str(e)}") # Add logging
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)

View File

@@ -1,39 +1,59 @@
import os import argparse
import re
import gzip import gzip
import json import json
import os
import re
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import requests import requests
import argparse from utils.constants import (
from utils.config import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='EPG Grabber') parser = argparse.ArgumentParser(description="EPG Grabber")
parser.add_argument('--playlist', parser.add_argument(
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), "--playlist",
help='Path to playlist file') default=os.path.join(
parser.add_argument('--output', os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'), ),
help='Path to output EPG XML file') help="Path to playlist file",
parser.add_argument('--epg-sources', )
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'), parser.add_argument(
help='Path to EPG sources JSON configuration file') "--output",
parser.add_argument('--save-as-gz', default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
action='store_true', help="Path to output EPG XML file",
default=True, )
help='Save an additional gzipped version of the EPG file') parser.add_argument(
"--epg-sources",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
),
help="Path to EPG sources JSON configuration file",
)
parser.add_argument(
"--save-as-gz",
action="store_true",
default=True,
help="Save an additional gzipped version of the EPG file",
)
return parser.parse_args() return parser.parse_args()
def load_epg_sources(config_path): def load_epg_sources(config_path):
"""Load EPG sources from JSON configuration file""" """Load EPG sources from JSON configuration file"""
try: try:
with open(config_path, 'r', encoding='utf-8') as f: with open(config_path, encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
return config.get('epg_sources', []) return config.get("epg_sources", [])
except (FileNotFoundError, json.JSONDecodeError) as e: except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading EPG sources: {e}") print(f"Error loading EPG sources: {e}")
return [] return []
def get_tvg_ids(playlist_path): def get_tvg_ids(playlist_path):
""" """
Extracts unique tvg-id values from an M3U playlist file. Extracts unique tvg-id values from an M3U playlist file.
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
# and ends with a double quote. # and ends with a double quote.
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"') tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
with open(playlist_path, 'r', encoding='utf-8') as file: with open(playlist_path, encoding="utf-8") as file:
for line in file: for line in file:
if line.startswith('#EXTINF'): if line.startswith("#EXTINF"):
# Search for the tvg-id pattern in the line # Search for the tvg-id pattern in the line
match = tvg_id_pattern.search(line) match = tvg_id_pattern.search(line)
if match: if match:
# Extract the captured group (the value inside the quotes) # Extract the captured group (the value inside the quotes)
tvg_id = match.group(1) tvg_id = match.group(1)
if tvg_id: # Ensure the extracted id is not empty if tvg_id: # Ensure the extracted id is not empty
unique_tvg_ids.add(tvg_id) unique_tvg_ids.add(tvg_id)
return list(unique_tvg_ids) return list(unique_tvg_ids)
def fetch_and_extract_xml(url): def fetch_and_extract_xml(url):
response = requests.get(url) response = requests.get(url)
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to fetch {url}") print(f"Failed to fetch {url}")
return None return None
if url.endswith('.gz'): if url.endswith(".gz"):
try: try:
decompressed_data = gzip.decompress(response.content) decompressed_data = gzip.decompress(response.content)
return ET.fromstring(decompressed_data) return ET.fromstring(decompressed_data)
@@ -84,42 +105,46 @@ def fetch_and_extract_xml(url):
print(f"Failed to parse XML from {url}: {e}") print(f"Failed to parse XML from {url}: {e}")
return None return None
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True): def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
root = ET.Element('tv') root = ET.Element("tv")
for url in urls: for url in urls:
epg_data = fetch_and_extract_xml(url) epg_data = fetch_and_extract_xml(url)
if epg_data is None: if epg_data is None:
continue continue
for channel in epg_data.findall('channel'): for channel in epg_data.findall("channel"):
tvg_id = channel.get('id') tvg_id = channel.get("id")
if tvg_id in tvg_ids: if tvg_id in tvg_ids:
root.append(channel) root.append(channel)
for programme in epg_data.findall('programme'): for programme in epg_data.findall("programme"):
tvg_id = programme.get('channel') tvg_id = programme.get("channel")
if tvg_id in tvg_ids: if tvg_id in tvg_ids:
root.append(programme) root.append(programme)
tree = ET.ElementTree(root) tree = ET.ElementTree(root)
tree.write(output_file, encoding='utf-8', xml_declaration=True) tree.write(output_file, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file}") print(f"New EPG saved to {output_file}")
if save_as_gz: if save_as_gz:
output_file_gz = output_file + '.gz' output_file_gz = output_file + ".gz"
with gzip.open(output_file_gz, 'wb') as f: with gzip.open(output_file_gz, "wb") as f:
tree.write(f, encoding='utf-8', xml_declaration=True) tree.write(f, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file_gz}") print(f"New EPG saved to {output_file_gz}")
def upload_epg(file_path): def upload_epg(file_path):
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth""" """Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
response = requests.post( response = requests.post(
IPTV_SERVER_URL + '/admin/epg', IPTV_SERVER_URL + "/admin/epg",
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), auth=requests.auth.HTTPBasicAuth(
files={'file': (os.path.basename(file_path), f)} IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
) )
if response.status_code == 200: if response.status_code == 200:
@@ -129,6 +154,7 @@ def upload_epg(file_path):
except Exception as e: except Exception as e:
print(f"Upload error: {str(e)}") print(f"Upload error: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
args = parse_arguments() args = parse_arguments()
playlist_file = args.playlist playlist_file = args.playlist
@@ -144,4 +170,4 @@ if __name__ == "__main__":
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz) filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
if args.save_as_gz: if args.save_as_gz:
upload_epg(output_file + '.gz') upload_epg(output_file + ".gz")

View File

@@ -1,26 +1,45 @@
import os
import argparse import argparse
import json import json
import logging import logging
import requests import os
from pathlib import Path
from datetime import datetime from datetime import datetime
import requests
from utils.check_streams import StreamValidator from utils.check_streams import StreamValidator
from utils.config import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL from utils.constants import (
EPG_URL,
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='IPTV playlist generator') parser = argparse.ArgumentParser(description="IPTV playlist generator")
parser.add_argument('--output', parser.add_argument(
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), "--output",
help='Path to output playlist file') default=os.path.join(
parser.add_argument('--channels', os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'), ),
help='Path to channels definition JSON file') help="Path to output playlist file",
parser.add_argument('--dead-channels-log', )
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'), parser.add_argument(
help='Path to log file to store a list of dead channels') "--channels",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "channels.json"
),
help="Path to channels definition JSON file",
)
parser.add_argument(
"--dead-channels-log",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
),
help="Path to log file to store a list of dead channels",
)
return parser.parse_args() return parser.parse_args()
def find_working_stream(validator, urls): def find_working_stream(validator, urls):
"""Test all URLs and return the first working one""" """Test all URLs and return the first working one"""
for url in urls: for url in urls:
@@ -29,9 +48,10 @@ def find_working_stream(validator, urls):
return url return url
return None return None
def create_playlist(channels_file, output_file): def create_playlist(channels_file, output_file):
# Read channels from JSON file # Read channels from JSON file
with open(channels_file, 'r', encoding='utf-8') as f: with open(channels_file, encoding="utf-8") as f:
channels = json.load(f) channels = json.load(f)
# Initialize validator # Initialize validator
@@ -41,9 +61,9 @@ def create_playlist(channels_file, output_file):
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n' m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
for channel in channels: for channel in channels:
if 'urls' in channel: # Check if channel has URLs if "urls" in channel: # Check if channel has URLs
# Find first working stream # Find first working stream
working_url = find_working_stream(validator, channel['urls']) working_url = find_working_stream(validator, channel["urls"])
if working_url: if working_url:
# Add channel to playlist # Add channel to playlist
@@ -51,24 +71,30 @@ def create_playlist(channels_file, output_file):
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" ' m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" ' m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
m3u8_content += f'group-title="{channel.get("group-title", "")}", ' m3u8_content += f'group-title="{channel.get("group-title", "")}", '
m3u8_content += f'{channel.get("name", "")}\n' m3u8_content += f"{channel.get('name', '')}\n"
m3u8_content += f'{working_url}\n' m3u8_content += f"{working_url}\n"
else: else:
# Log dead channel # Log dead channel
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found') logging.info(
f"Dead channel: {channel.get('name', 'Unknown')} - "
"No working streams found"
)
# Write playlist file # Write playlist file
with open(output_file, 'w', encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
f.write(m3u8_content) f.write(m3u8_content)
def upload_playlist(file_path): def upload_playlist(file_path):
"""Uploads playlist file to IPTV server using HTTP Basic Auth""" """Uploads playlist file to IPTV server using HTTP Basic Auth"""
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
response = requests.post( response = requests.post(
IPTV_SERVER_URL + '/admin/playlist', IPTV_SERVER_URL + "/admin/playlist",
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), auth=requests.auth.HTTPBasicAuth(
files={'file': (os.path.basename(file_path), f)} IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
) )
if response.status_code == 200: if response.status_code == 200:
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
except Exception as e: except Exception as e:
print(f"Upload error: {str(e)}") print(f"Upload error: {str(e)}")
def main(): def main():
args = parse_arguments() args = parse_arguments()
channels_file = args.channels channels_file = args.channels
@@ -85,24 +112,25 @@ def main():
dead_channels_log_file = args.dead_channels_log dead_channels_log_file = args.dead_channels_log
# Clear previous log file # Clear previous log file
with open(dead_channels_log_file, 'w') as f: with open(dead_channels_log_file, "w") as f:
f.write(f'Log created on {datetime.now()}\n') f.write(f"Log created on {datetime.now()}\n")
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
filename=dead_channels_log_file, filename=dead_channels_log_file,
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(message)s', format="%(asctime)s - %(message)s",
datefmt='%Y-%m-%d %H:%M:%S' datefmt="%Y-%m-%d %H:%M:%S",
) )
# Create playlist # Create playlist
create_playlist(channels_file, output_file) create_playlist(channels_file, output_file)
#upload playlist to server # upload playlist to server
upload_playlist(output_file) upload_playlist(output_file)
print("Playlist creation completed!") print("Playlist creation completed!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

110
app/iptv/scheduler.py Normal file
View File

@@ -0,0 +1,110 @@
import logging
import os
from typing import Optional
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from sqlalchemy.orm import Session
from app.iptv.stream_manager import StreamManager
from app.models.db import ChannelDB
from app.utils.database import get_db_session
logger = logging.getLogger(__name__)
class StreamScheduler:
"""Scheduler service for periodic stream validation tasks."""
def __init__(self, app: Optional[FastAPI] = None):
"""
Initialize the scheduler with optional FastAPI app integration.
Args:
app: Optional FastAPI app instance for lifecycle integration
"""
self.scheduler = BackgroundScheduler()
self.app = app
self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10"))
self.schedule_time = os.getenv(
"STREAM_VALIDATION_SCHEDULE", "0 3 * * *"
) # Default 3 AM daily
logger.info(f"Scheduler initialized with app: {app is not None}")
def validate_streams_batch(self, db_session: Optional[Session] = None) -> None:
"""
Validate streams and update their status.
When batch_size=0, validates all channels.
Args:
db_session: Optional SQLAlchemy session
"""
db = db_session if db_session else get_db_session()
try:
manager = StreamManager(db)
# Get channels to validate
query = db.query(ChannelDB)
if self.batch_size > 0:
query = query.limit(self.batch_size)
channels = query.all()
for channel in channels:
try:
logger.info(f"Validating streams for channel {channel.id}")
manager.validate_and_select_stream(str(channel.id))
except Exception as e:
logger.error(f"Error validating channel {channel.id}: {str(e)}")
continue
logger.info(f"Completed stream validation of {len(channels)} channels")
finally:
if db_session is None:
db.close()
def start(self) -> None:
"""Start the scheduler and add jobs."""
if not self.scheduler.running:
# Add the scheduled job
self.scheduler.add_job(
self.validate_streams_batch,
trigger=CronTrigger.from_crontab(self.schedule_time),
id="daily_stream_validation",
)
# Start the scheduler
self.scheduler.start()
logger.info(
f"Stream scheduler started with daily validation job. "
f"Running: {self.scheduler.running}"
)
# Register shutdown handler if FastAPI app is provided
if self.app:
logger.info(
f"Registering scheduler with FastAPI "
f"app: {hasattr(self.app, 'state')}"
)
@self.app.on_event("shutdown")
def shutdown_scheduler():
self.shutdown()
def shutdown(self) -> None:
"""Shutdown the scheduler gracefully."""
if self.scheduler.running:
self.scheduler.shutdown()
logger.info("Stream scheduler stopped")
def trigger_manual_validation(self) -> None:
"""Trigger manual validation of streams."""
logger.info("Manually triggering stream validation")
self.validate_streams_batch()
def init_scheduler(app: FastAPI) -> StreamScheduler:
"""Initialize and start the scheduler with FastAPI integration."""
scheduler = StreamScheduler(app)
scheduler.start()
return scheduler

151
app/iptv/stream_manager.py Normal file
View File

@@ -0,0 +1,151 @@
import logging
import random
from typing import Optional
from sqlalchemy.orm import Session
from app.models.db import ChannelURL
from app.utils.check_streams import StreamValidator
from app.utils.database import get_db_session
logger = logging.getLogger(__name__)
class StreamManager:
"""Service for managing and validating channel streams."""
def __init__(self, db_session: Optional[Session] = None):
"""
Initialize StreamManager with optional database session.
Args:
db_session: Optional SQLAlchemy session. If None, will create a new one.
"""
self.db = db_session if db_session else get_db_session()
self.validator = StreamValidator()
def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]:
"""
Get all streams for a channel ordered by priority (lowest first),
with same-priority streams randomized.
Args:
channel_id: UUID of the channel to get streams for
Returns:
List of ChannelURL objects ordered by priority
"""
try:
# Get all streams for channel ordered by priority
streams = (
self.db.query(ChannelURL)
.filter(ChannelURL.channel_id == channel_id)
.order_by(ChannelURL.priority_id)
.all()
)
# Group streams by priority and randomize same-priority streams
grouped = {}
for stream in streams:
if stream.priority_id not in grouped:
grouped[stream.priority_id] = []
grouped[stream.priority_id].append(stream)
# Randomize same-priority streams and flatten
randomized_streams = []
for priority in sorted(grouped.keys()):
random.shuffle(grouped[priority])
randomized_streams.extend(grouped[priority])
return randomized_streams
except Exception as e:
logger.error(f"Error getting streams for channel {channel_id}: {str(e)}")
raise
def validate_and_select_stream(self, channel_id: str) -> Optional[str]:
"""
Find and validate a working stream for the given channel.
Args:
channel_id: UUID of the channel to find a stream for
Returns:
URL of the first working stream found, or None if none found
"""
try:
streams = self.get_streams_for_channel(channel_id)
if not streams:
logger.warning(f"No streams found for channel {channel_id}")
return None
working_stream = None
for stream in streams:
logger.info(f"Validating stream {stream.url} for channel {channel_id}")
is_valid, _ = self.validator.validate_stream(stream.url)
if is_valid:
working_stream = stream
break
if working_stream:
self._update_stream_status(working_stream, streams)
return working_stream.url
else:
logger.warning(f"No valid streams found for channel {channel_id}")
return None
except Exception as e:
logger.error(f"Error validating streams for channel {channel_id}: {str(e)}")
raise
def _update_stream_status(
self, working_stream: ChannelURL, all_streams: list[ChannelURL]
) -> None:
"""
Update in_use status for streams (True for working stream, False for others).
Args:
working_stream: The stream that was validated as working
all_streams: All streams for the channel
"""
try:
for stream in all_streams:
stream.in_use = stream.id == working_stream.id
self.db.commit()
logger.info(
f"Updated stream status - set in_use=True for {working_stream.url}"
)
except Exception as e:
self.db.rollback()
logger.error(f"Error updating stream status: {str(e)}")
raise
def __del__(self):
"""Close database session when StreamManager is destroyed."""
if hasattr(self, "db"):
self.db.close()
def get_working_stream(
channel_id: str, db_session: Optional[Session] = None
) -> Optional[str]:
"""
Convenience function to get a working stream for a channel.
Args:
channel_id: UUID of the channel to get a stream for
db_session: Optional SQLAlchemy session
Returns:
URL of the first working stream found, or None if none found
"""
manager = StreamManager(db_session)
try:
return manager.validate_and_select_stream(channel_id)
finally:
if db_session is None: # Only close if we created the session
manager.__del__()

View File

@@ -1,42 +1,82 @@
from fastapi import FastAPI, Depends, HTTPException from fastapi import FastAPI
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.concurrency import asynccontextmanager
from app.cabletv.utils.auth import exchange_code_for_token, get_current_user, DOMAIN, CLIENT_ID from fastapi.openapi.utils import get_openapi
from app.iptv.scheduler import StreamScheduler
from app.routers import auth, channels, groups, playlist, priorities, scheduler
from app.utils.database import init_db
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize database tables on startup
init_db()
# Initialize and start the stream scheduler
scheduler = StreamScheduler(app)
app.state.scheduler = scheduler # Store scheduler in app state
scheduler.start()
yield
# Shutdown scheduler on app shutdown
scheduler.shutdown()
app = FastAPI(
lifespan=lifespan,
title="IPTV Manager API",
description="API for IPTV Manager service",
version="1.0.0",
)
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# Ensure components object exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
# Add schemas if they don't exist
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# Add security scheme component
openapi_schema["components"]["securitySchemes"] = {
"Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
}
# Add global security requirement
openapi_schema["security"] = [{"Bearer": []}]
# Set OpenAPI version explicitly
openapi_schema["openapi"] = "3.1.0"
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
app = FastAPI()
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "IPTV Updater API"} return {"message": "IPTV Manager API"}
@app.get("/protected")
async def protected_route(user = Depends(get_current_user)):
if isinstance(user, RedirectResponse):
return user
return {"message": "Protected content", "user": user['Username']}
@app.get("/auth/callback") # Include routers
async def auth_callback(code: str): app.include_router(auth.router)
try: app.include_router(channels.router)
tokens = exchange_code_for_token(code) app.include_router(playlist.router)
app.include_router(priorities.router)
# Use id_token instead of access_token app.include_router(groups.router)
response = JSONResponse(content={ app.include_router(scheduler.router)
"message": "Authentication successful",
"id_token": tokens["id_token"] # Changed from access_token
})
# Store id_token in cookie
response.set_cookie(
key="token",
value=tokens["id_token"], # Changed from access_token
httponly=True,
secure=True,
samesite="lax"
)
return response
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Authentication failed: {str(e)}"
)

27
app/models/__init__.py Normal file
View File

@@ -0,0 +1,27 @@
from .db import Base, ChannelDB, ChannelURL, Group, Priority
from .schemas import (
ChannelCreate,
ChannelResponse,
ChannelUpdate,
ChannelURLCreate,
ChannelURLResponse,
GroupCreate,
GroupResponse,
GroupUpdate,
)
__all__ = [
"Base",
"ChannelDB",
"ChannelCreate",
"ChannelUpdate",
"ChannelResponse",
"ChannelURL",
"ChannelURLCreate",
"ChannelURLResponse",
"Group",
"Priority",
"GroupCreate",
"GroupResponse",
"GroupUpdate",
]

26
app/models/auth.py Normal file
View File

@@ -0,0 +1,26 @@
from typing import Optional
from pydantic import BaseModel, Field
class SigninRequest(BaseModel):
"""Request model for the signin endpoint."""
username: str = Field(..., description="The user's username")
password: str = Field(..., description="The user's password")
class TokenResponse(BaseModel):
"""Response model for successful authentication."""
access_token: str = Field(..., description="Access JWT token from Cognito")
id_token: str = Field(..., description="ID JWT token from Cognito")
refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
token_type: str = Field(..., description="Type of the token returned")
class CognitoUser(BaseModel):
"""Model representing the user returned from token verification."""
username: str
roles: list[str]

139
app/models/db.py Normal file
View File

@@ -0,0 +1,139 @@
import os
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
TEXT,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
TypeDecorator,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship
# Custom UUID type for SQLite compatibility
class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite with proper comparison handling."""
impl = TEXT
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid.UUID):
return str(value)
try:
# Validate string format by attempting to create UUID
uuid.UUID(value)
return value
except (ValueError, AttributeError):
raise ValueError(f"Invalid UUID string format: {value}")
def process_result_value(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(value)
def compare_values(self, x, y):
if x is None or y is None:
return x == y
return str(x) == str(y)
# Determine which UUID type to use based on environment
if os.getenv("MOCK_AUTH", "").lower() == "true":
UUID_COLUMN_TYPE = SQLiteUUID()
else:
UUID_COLUMN_TYPE = UUID(as_uuid=True)
Base = declarative_base()
class Priority(Base):
"""SQLAlchemy model for channel URL priorities"""
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class Group(Base):
"""SQLAlchemy model for channel groups"""
__tablename__ = "groups"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=False, unique=True)
sort_order = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationship with Channel
channels = relationship("ChannelDB", back_populates="group")
class ChannelDB(Base):
"""SQLAlchemy model for IPTV channels"""
__tablename__ = "channels"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
tvg_id = Column(String, nullable=False)
name = Column(String, nullable=False)
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
tvg_name = Column(String)
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
urls = relationship(
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
group = relationship("Group", back_populates="channels")
class ChannelURL(Base):
"""SQLAlchemy model for channel URLs"""
__tablename__ = "channels_urls"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
channel_id = Column(
UUID_COLUMN_TYPE,
ForeignKey("channels.id", ondelete="CASCADE"),
nullable=False,
)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
channel = relationship("ChannelDB", back_populates="urls")
priority = relationship("Priority")

140
app/models/schemas.py Normal file
View File

@@ -0,0 +1,140 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class PriorityBase(BaseModel):
"""Base Pydantic model for priorities"""
id: int
description: str
model_config = ConfigDict(from_attributes=True)
class PriorityCreate(PriorityBase):
"""Pydantic model for creating priorities"""
pass
class PriorityResponse(PriorityBase):
"""Pydantic model for priority responses"""
pass
class ChannelURLCreate(BaseModel):
"""Pydantic model for creating channel URLs"""
url: str
priority_id: int = Field(
default=100, ge=100, le=300
) # Default to High, validate range
class ChannelURLBase(ChannelURLCreate):
"""Base Pydantic model for channel URL responses"""
id: UUID
in_use: bool
created_at: datetime
updated_at: datetime
priority_id: int
model_config = ConfigDict(from_attributes=True)
class ChannelURLResponse(ChannelURLBase):
"""Pydantic model for channel URL responses"""
pass
# New Group Schemas
class GroupCreate(BaseModel):
"""Pydantic model for creating groups"""
name: str
sort_order: int = Field(default=0, ge=0)
class GroupUpdate(BaseModel):
"""Pydantic model for updating groups"""
name: Optional[str] = None
sort_order: Optional[int] = Field(None, ge=0)
class GroupResponse(BaseModel):
"""Pydantic model for group responses"""
id: UUID
name: str
sort_order: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class GroupSortUpdate(BaseModel):
"""Pydantic model for updating a single group's sort order"""
sort_order: int = Field(ge=0)
class GroupBulkSort(BaseModel):
"""Pydantic model for bulk updating group sort orders"""
groups: list[dict] = Field(
description="List of dicts with group_id and new sort_order",
json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]},
)
class ChannelCreate(BaseModel):
"""Pydantic model for creating channels"""
urls: list[ChannelURLCreate] # List of URL objects with priority
name: str
group_id: UUID
tvg_id: str
tvg_logo: str
tvg_name: str
class ChannelURLUpdate(BaseModel):
"""Pydantic model for updating channel URLs"""
url: Optional[str] = None
in_use: Optional[bool] = None
priority_id: Optional[int] = Field(default=None, ge=100, le=300)
class ChannelUpdate(BaseModel):
"""Pydantic model for updating channels (all fields optional)"""
name: Optional[str] = Field(None, min_length=1)
group_id: Optional[UUID] = None
tvg_id: Optional[str] = Field(None, min_length=1)
tvg_logo: Optional[str] = None
tvg_name: Optional[str] = Field(None, min_length=1)
class ChannelResponse(BaseModel):
"""Pydantic model for channel responses"""
id: UUID
name: str
group_id: UUID
tvg_id: str
tvg_logo: str
tvg_name: str
urls: list[ChannelURLResponse] # List of URL objects without channel_id
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)

0
app/routers/__init__.py Normal file
View File

22
app/routers/auth.py Normal file
View File

@@ -0,0 +1,22 @@
from fastapi import APIRouter
from app.auth.cognito import initiate_auth
from app.models.auth import SigninRequest, TokenResponse
router = APIRouter(prefix="/auth", tags=["authentication"])
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest):
"""
Sign-in endpoint to authenticate the user with AWS Cognito
using username and password.
On success, returns JWT tokens (access_token, id_token, refresh_token).
"""
auth_result = initiate_auth(credentials.username, credentials.password)
return TokenResponse(
access_token=auth_result["AccessToken"],
id_token=auth_result["IdToken"],
refresh_token=auth_result.get("RefreshToken"),
token_type="Bearer",
)

508
app/routers/channels.py Normal file
View File

@@ -0,0 +1,508 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import (
ChannelCreate,
ChannelDB,
ChannelResponse,
ChannelUpdate,
ChannelURL,
ChannelURLCreate,
ChannelURLResponse,
Group,
Priority, # Added Priority import
)
from app.models.auth import CognitoUser
from app.models.schemas import ChannelURLUpdate
from app.utils.database import get_db
router = APIRouter(prefix="/channels", tags=["channels"])
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_channel(
channel: ChannelCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Create a new channel"""
# Check if group exists
group = db.query(Group).filter(Group.id == channel.group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Group not found",
)
# Check for duplicate channel (same group_id + name)
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_id == channel.group_id,
ChannelDB.name == channel.name,
)
)
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_id and name already exists",
)
# Create channel without URLs first
channel_data = channel.model_dump(exclude={"urls"})
urls = channel.urls
db_channel = ChannelDB(**channel_data)
db.add(db_channel)
db.commit()
db.refresh(db_channel)
# Add URLs with priority
for url in urls:
db_url = ChannelURL(
channel_id=db_channel.id,
url=url.url,
priority_id=url.priority_id,
in_use=False,
)
db.add(db_url)
db.commit()
db.refresh(db_channel)
return db_channel
@router.get("/{channel_id}", response_model=ChannelResponse)
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
"""Get a channel by id"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
return channel
@router.put("/{channel_id}", response_model=ChannelResponse)
@require_roles("admin")
def update_channel(
channel_id: UUID,
channel: ChannelUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel"""
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not db_channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
# Only check for duplicates if name or group_id are being updated
if channel.name is not None or channel.group_id is not None:
name = channel.name if channel.name is not None else db_channel.name
group_id = (
channel.group_id if channel.group_id is not None else db_channel.group_id
)
# Check if new group exists
if channel.group_id is not None:
group = db.query(Group).filter(Group.id == channel.group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Group not found",
)
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_id == group_id,
ChannelDB.name == name,
ChannelDB.id != channel_id,
)
)
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_id and name already exists",
)
# Update only provided fields
update_data = channel.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_channel, key, value)
db.commit()
db.refresh(db_channel)
return db_channel
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_channels(
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete all channels"""
count = 0
try:
count = db.query(ChannelDB).count()
# First delete all channels
db.query(ChannelDB).delete()
# Then delete any URLs that are now orphaned (no channel references)
db.query(ChannelURL).filter(
~ChannelURL.channel_id.in_(db.query(ChannelDB.id))
).delete(synchronize_session=False)
# Then delete any groups that are now empty
db.query(Group).filter(~Group.id.in_(db.query(ChannelDB.group_id))).delete(
synchronize_session=False
)
db.commit()
except Exception as e:
print(f"Error deleting channels: {e}")
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete channels",
)
return {"deleted": count}
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel(
channel_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
db.delete(channel)
db.commit()
return None
@router.get("/", response_model=list[ChannelResponse])
@require_roles("admin")
def list_channels(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""List all channels with pagination"""
return db.query(ChannelDB).offset(skip).limit(limit).all()
# New endpoint to get channels by group
@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse])
def get_channels_by_group(
group_id: UUID,
db: Session = Depends(get_db),
):
"""Get all channels for a specific group"""
group = db.query(Group).filter(Group.id == group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
)
return db.query(ChannelDB).filter(ChannelDB.group_id == group_id).all()
# New endpoint to update a channel's group
@router.put("/{channel_id}/group", response_model=ChannelResponse)
@require_roles("admin")
def update_channel_group(
channel_id: UUID,
group_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel's group"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
group = db.query(Group).filter(Group.id == group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
)
# Check for duplicate channel name in new group
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_id == group_id,
ChannelDB.name == channel.name,
ChannelDB.id != channel_id,
)
)
.first()
)
if existing_channel:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Channel with same name already exists in target group",
)
channel.group_id = group_id
db.commit()
db.refresh(channel)
return channel
# Bulk Upload and Reset Endpoints
@router.post("/bulk-upload", status_code=status.HTTP_200_OK)
@require_roles("admin")
def bulk_upload_channels(
channels: list[dict],
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Bulk upload channels from JSON array"""
processed = 0
# Fetch all priorities from the database, ordered by id
priorities = db.query(Priority).order_by(Priority.id).all()
priority_map = {i: p.id for i, p in enumerate(priorities)}
# Get the highest priority_id (which corresponds to the lowest priority level)
max_priority_id = None
if priorities:
max_priority_id = db.query(Priority.id).order_by(Priority.id.desc()).first()[0]
for channel_data in channels:
try:
# Get or create group
group_name = channel_data.get("group-title")
if not group_name:
continue
group = db.query(Group).filter(Group.name == group_name).first()
if not group:
group = Group(name=group_name)
db.add(group)
db.flush() # Use flush to make the group available in the session
db.refresh(group)
# Prepare channel data
urls = channel_data.get("urls", [])
if not isinstance(urls, list):
urls = [urls]
# Assign priorities dynamically based on fetched priorities
url_objects = []
for i, url in enumerate(urls): # Process all URLs
priority_id = priority_map.get(i)
if priority_id is None:
# If index is out of bounds,
# assign the highest priority_id (lowest priority)
if max_priority_id is not None:
priority_id = max_priority_id
else:
print(
f"Warning: No priorities defined in database. "
f"Skipping URL {url}"
)
continue
url_objects.append({"url": url, "priority_id": priority_id})
# Create channel object with required fields
channel_obj = ChannelDB(
tvg_id=channel_data.get("tvg-id", ""),
name=channel_data.get("name", ""),
group_id=group.id,
tvg_name=channel_data.get("tvg-name", ""),
tvg_logo=channel_data.get("tvg-logo", ""),
)
# Upsert channel
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_id == group.id,
ChannelDB.name == channel_obj.name,
)
)
.first()
)
if existing_channel:
# Update existing
existing_channel.tvg_id = channel_obj.tvg_id
existing_channel.tvg_name = channel_obj.tvg_name
existing_channel.tvg_logo = channel_obj.tvg_logo
# Clear and recreate URLs
db.query(ChannelURL).filter(
ChannelURL.channel_id == existing_channel.id
).delete()
for url in url_objects:
db_url = ChannelURL(
channel_id=existing_channel.id,
url=url["url"],
priority_id=url["priority_id"],
in_use=False,
)
db.add(db_url)
else:
# Create new
db.add(channel_obj)
db.flush() # Flush to get the new channel's ID
db.refresh(channel_obj)
# Add URLs for new channel
for url in url_objects:
db_url = ChannelURL(
channel_id=channel_obj.id,
url=url["url"],
priority_id=url["priority_id"],
in_use=False,
)
db.add(db_url)
db.commit() # Commit all changes for this channel atomically
processed += 1
except Exception as e:
print(f"Error processing channel: {channel_data.get('name', 'Unknown')}")
print(f"Exception details: {e}")
db.rollback() # Rollback the entire transaction for the failed channel
continue
return {"processed": processed}
# URL Management Endpoints
@router.post(
"/{channel_id}/urls",
response_model=ChannelURLResponse,
status_code=status.HTTP_201_CREATED,
)
@require_roles("admin")
def add_channel_url(
channel_id: UUID,
url: ChannelURLCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Add a new URL to a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
db_url = ChannelURL(
channel_id=channel_id,
url=url.url,
priority_id=url.priority_id,
in_use=False, # Default to not in use
)
db.add(db_url)
db.commit()
db.refresh(db_url)
return db_url
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
@require_roles("admin")
def update_channel_url(
channel_id: UUID,
url_id: UUID,
url_update: ChannelURLUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Update a channel URL (url, in_use, or priority_id)"""
db_url = (
db.query(ChannelURL)
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
.first()
)
if not db_url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
)
if url_update.url is not None:
db_url.url = url_update.url
if url_update.in_use is not None:
db_url.in_use = url_update.in_use
if url_update.priority_id is not None:
db_url.priority_id = url_update.priority_id
db.commit()
db.refresh(db_url)
return db_url
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel_url(
channel_id: UUID,
url_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete a URL from a channel"""
url = (
db.query(ChannelURL)
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
.first()
)
if not url:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
)
db.delete(url)
db.commit()
return None
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
@require_roles("admin")
def list_channel_urls(
channel_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""List all URLs for a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
)
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()

191
app/routers/groups.py Normal file
View File

@@ -0,0 +1,191 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import Group
from app.models.auth import CognitoUser
from app.models.schemas import (
GroupBulkSort,
GroupCreate,
GroupResponse,
GroupSortUpdate,
GroupUpdate,
)
from app.utils.database import get_db
router = APIRouter(prefix="/groups", tags=["groups"])
@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_group(
group: GroupCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Create a new channel group"""
# Check for duplicate group name
existing_group = db.query(Group).filter(Group.name == group.name).first()
if existing_group:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Group with this name already exists",
)
db_group = Group(**group.model_dump())
db.add(db_group)
db.commit()
db.refresh(db_group)
return db_group
@router.get("/{group_id}", response_model=GroupResponse)
def get_group(group_id: UUID, db: Session = Depends(get_db)):
"""Get a group by id"""
group = db.query(Group).filter(Group.id == group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
)
return group
@router.put("/{group_id}", response_model=GroupResponse)
@require_roles("admin")
def update_group(
group_id: UUID,
group: GroupUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Update a group's name or sort order"""
db_group = db.query(Group).filter(Group.id == group_id).first()
if not db_group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
)
# Check for duplicate name if name is being updated
if group.name is not None and group.name != db_group.name:
existing_group = db.query(Group).filter(Group.name == group.name).first()
if existing_group:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Group with this name already exists",
)
# Update only provided fields
update_data = group.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_group, key, value)
db.commit()
db.refresh(db_group)
return db_group
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_groups(
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete all groups that have no channels (skip groups with channels)"""
groups = db.query(Group).all()
deleted = 0
skipped = 0
for group in groups:
if not group.channels:
db.delete(group)
deleted += 1
else:
skipped += 1
db.commit()
return {"deleted": deleted, "skipped": skipped}
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_group(
group_id: UUID,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete a group (only if it has no channels)"""
group = db.query(Group).filter(Group.id == group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
)
# Check if group has any channels
if group.channels:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete group with existing channels",
)
db.delete(group)
db.commit()
return None
@router.get("/", response_model=list[GroupResponse])
def list_groups(db: Session = Depends(get_db)):
"""List all groups sorted by sort_order"""
return db.query(Group).order_by(Group.sort_order).all()
@router.put("/{group_id}/sort", response_model=GroupResponse)
@require_roles("admin")
def update_group_sort_order(
group_id: UUID,
sort_update: GroupSortUpdate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Update a single group's sort order"""
db_group = db.query(Group).filter(Group.id == group_id).first()
if not db_group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
)
db_group.sort_order = sort_update.sort_order
db.commit()
db.refresh(db_group)
return db_group
@router.post("/reorder", response_model=list[GroupResponse])
@require_roles("admin")
def bulk_update_sort_orders(
bulk_sort: GroupBulkSort,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Bulk update group sort orders"""
groups_to_update = []
for group_data in bulk_sort.groups:
group_id = group_data["group_id"]
sort_order = group_data["sort_order"]
group = db.query(Group).filter(Group.id == str(group_id)).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group with id {group_id} not found",
)
group.sort_order = sort_order
groups_to_update.append(group)
db.commit()
# Return all groups in their new order
return db.query(Group).order_by(Group.sort_order).all()

156
app/routers/playlist.py Normal file
View File

@@ -0,0 +1,156 @@
import logging
from enum import Enum
from typing import Optional
from uuid import uuid4
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
from app.iptv.stream_manager import StreamManager
from app.models.auth import CognitoUser
from app.utils.database import get_db_session
router = APIRouter(prefix="/playlist", tags=["playlist"])
logger = logging.getLogger(__name__)
# In-memory store for validation processes
validation_processes: dict[str, dict] = {}
class ProcessStatus(str, Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
class StreamValidationRequest(BaseModel):
"""Request model for stream validation endpoint"""
channel_id: Optional[str] = None
class ValidatedStream(BaseModel):
"""Model for a validated working stream"""
channel_id: str
stream_url: str
class ValidationProcessResponse(BaseModel):
"""Response model for validation process initiation"""
process_id: str
status: ProcessStatus
message: str
class ValidationResultResponse(BaseModel):
"""Response model for validation results"""
process_id: str
status: ProcessStatus
working_streams: Optional[list[ValidatedStream]] = None
error: Optional[str] = None
def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session):
"""Background task to validate streams"""
try:
validation_processes[process_id]["status"] = ProcessStatus.IN_PROGRESS
manager = StreamManager(db)
if channel_id:
stream_url = manager.validate_and_select_stream(channel_id)
if stream_url:
validation_processes[process_id]["result"] = {
"working_streams": [
ValidatedStream(channel_id=channel_id, stream_url=stream_url)
]
}
else:
validation_processes[process_id]["error"] = (
f"No working streams found for channel {channel_id}"
)
else:
# TODO: Implement validation for all channels
validation_processes[process_id]["error"] = (
"Validation of all channels not yet implemented"
)
validation_processes[process_id]["status"] = ProcessStatus.COMPLETED
except Exception as e:
logger.error(f"Error validating streams: {str(e)}")
validation_processes[process_id]["status"] = ProcessStatus.FAILED
validation_processes[process_id]["error"] = str(e)
@router.post(
"/validate-streams",
summary="Start stream validation process",
response_model=ValidationProcessResponse,
status_code=status.HTTP_202_ACCEPTED,
responses={202: {"description": "Validation process started successfully"}},
)
async def start_stream_validation(
request: StreamValidationRequest,
background_tasks: BackgroundTasks,
user: CognitoUser = Depends(get_current_user),
db: Session = Depends(get_db_session),
):
"""
Start asynchronous validation of streams.
- Returns immediately with a process ID
- Use GET /validate-streams/{process_id} to check status
"""
process_id = str(uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.PENDING,
"channel_id": request.channel_id,
}
background_tasks.add_task(run_stream_validation, process_id, request.channel_id, db)
return {
"process_id": process_id,
"status": ProcessStatus.PENDING,
"message": "Validation process started",
}
@router.get(
"/validate-streams/{process_id}",
summary="Check validation process status",
response_model=ValidationResultResponse,
responses={
200: {"description": "Process status and results"},
404: {"description": "Process not found"},
},
)
async def get_validation_status(
process_id: str, user: CognitoUser = Depends(get_current_user)
):
"""
Check status of a stream validation process.
Returns current status and results if completed.
"""
if process_id not in validation_processes:
raise HTTPException(status_code=404, detail="Process not found")
process = validation_processes[process_id]
response = {"process_id": process_id, "status": process["status"]}
if process["status"] == ProcessStatus.COMPLETED:
if "error" in process:
response["error"] = process["error"]
else:
response["working_streams"] = process["result"]["working_streams"]
elif process["status"] == ProcessStatus.FAILED:
response["error"] = process["error"]
return response

120
app/routers/priorities.py Normal file
View File

@@ -0,0 +1,120 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
from app.models.db import Priority
from app.models.schemas import PriorityCreate, PriorityResponse
from app.utils.database import get_db
router = APIRouter(prefix="/priorities", tags=["priorities"])
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_priority(
priority: PriorityCreate,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Create a new priority"""
# Check if priority with this ID already exists
existing = db.get(Priority, priority.id)
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Priority with ID {priority.id} already exists",
)
db_priority = Priority(**priority.model_dump())
db.add(db_priority)
db.commit()
db.refresh(db_priority)
return db_priority
@router.get("/", response_model=list[PriorityResponse])
@require_roles("admin")
def list_priorities(
db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
):
"""List all priorities"""
return db.query(Priority).all()
@router.get("/{priority_id}", response_model=PriorityResponse)
@require_roles("admin")
def get_priority(
priority_id: int,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Get a priority by id"""
priority = db.get(Priority, priority_id)
if not priority:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
)
return priority
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_priorities(
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete all priorities not in use by channel URLs"""
from app.models.db import ChannelURL
priorities = db.query(Priority).all()
deleted = 0
skipped = 0
for priority in priorities:
in_use = db.scalar(
select(ChannelURL).where(ChannelURL.priority_id == priority.id).limit(1)
)
if not in_use:
db.delete(priority)
deleted += 1
else:
skipped += 1
db.commit()
return {"deleted": deleted, "skipped": skipped}
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_priority(
priority_id: int,
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete a priority (if not in use)"""
from app.models.db import ChannelURL
# Check if priority exists
priority = db.get(Priority, priority_id)
if not priority:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
)
# Check if priority is in use
in_use = db.scalar(
select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
)
if in_use:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete priority that is in use by channel URLs",
)
db.execute(delete(Priority).where(Priority.id == priority_id))
db.commit()
return None

57
app/routers/scheduler.py Normal file
View File

@@ -0,0 +1,57 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.iptv.scheduler import StreamScheduler
from app.models.auth import CognitoUser
from app.utils.database import get_db
router = APIRouter(
prefix="/scheduler",
tags=["scheduler"],
responses={404: {"description": "Not found"}},
)
async def get_scheduler(request: Request) -> StreamScheduler:
"""Get the scheduler instance from the app state."""
if not hasattr(request.app.state.scheduler, "scheduler"):
raise HTTPException(status_code=500, detail="Scheduler not initialized")
return request.app.state.scheduler
@router.get("/health")
@require_roles("admin")
def scheduler_health(
scheduler: StreamScheduler = Depends(get_scheduler),
user: CognitoUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Check scheduler health status (admin only)."""
try:
job = scheduler.scheduler.get_job("daily_stream_validation")
next_run = str(job.next_run_time) if job and job.next_run_time else None
return {
"status": "running" if scheduler.scheduler.running else "stopped",
"next_run": next_run,
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to check scheduler health: {str(e)}"
)
@router.post("/trigger")
@require_roles("admin")
def trigger_validation(
scheduler: StreamScheduler = Depends(get_scheduler),
user: CognitoUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Manually trigger stream validation (admin only)."""
scheduler.trigger_manual_validation()
return JSONResponse(
status_code=202, content={"message": "Stream validation triggered"}
)

0
app/utils/__init__.py Normal file
View File

14
app/utils/auth.py Normal file
View File

@@ -0,0 +1,14 @@
import base64
import hashlib
import hmac
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
"""
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
"""
msg = username + client_id
dig = hmac.new(
client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
).digest()
return base64.b64encode(dig).decode()

View File

@@ -1,32 +1,41 @@
import os
import argparse import argparse
import requests
import logging import logging
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError import os
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
class StreamValidator: class StreamValidator:
def __init__(self, timeout=10, user_agent=None): def __init__(self, timeout=10, user_agent=None):
self.timeout = timeout self.timeout = timeout
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({ self.session.headers.update(
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36' {
}) "User-Agent": user_agent
or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
)
}
)
def validate_stream(self, url): def validate_stream(self, url):
"""Validate a media stream URL with multiple fallback checks""" """Validate a media stream URL with multiple fallback checks"""
try: try:
headers = {'Range': 'bytes=0-1024'} headers = {"Range": "bytes=0-1024"}
with self.session.get( with self.session.get(
url, url,
headers=headers, headers=headers,
timeout=self.timeout, timeout=self.timeout,
stream=True, stream=True,
allow_redirects=True allow_redirects=True,
) as response: ) as response:
if response.status_code not in [200, 206]: if response.status_code not in [200, 206]:
return False, f"Invalid status code: {response.status_code}" return False, f"Invalid status code: {response.status_code}"
content_type = response.headers.get('Content-Type', '') content_type = response.headers.get("Content-Type", "")
if not self._is_valid_content_type(content_type): if not self._is_valid_content_type(content_type):
return False, f"Invalid content type: {content_type}" return False, f"Invalid content type: {content_type}"
@@ -49,47 +58,50 @@ class StreamValidator:
def _is_valid_content_type(self, content_type): def _is_valid_content_type(self, content_type):
valid_types = [ valid_types = [
'video/mp2t', 'application/vnd.apple.mpegurl', "video/mp2t",
'application/dash+xml', 'video/mp4', "application/vnd.apple.mpegurl",
'video/webm', 'application/octet-stream', "application/dash+xml",
'application/x-mpegURL' "video/mp4",
"video/webm",
"application/octet-stream",
"application/x-mpegURL",
] ]
if content_type is None:
return False
return any(ct in content_type for ct in valid_types) return any(ct in content_type for ct in valid_types)
def parse_playlist(self, file_path): def parse_playlist(self, file_path):
"""Extract stream URLs from M3U playlist file""" """Extract stream URLs from M3U playlist file"""
urls = [] urls = []
try: try:
with open(file_path, 'r') as f: with open(file_path) as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if line and not line.startswith('#'): if line and not line.startswith("#"):
urls.append(line) urls.append(line)
except Exception as e: except Exception as e:
logging.error(f"Error reading playlist file: {str(e)}") logging.error(f"Error reading playlist file: {str(e)}")
raise raise
return urls return urls
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Validate streaming URLs from command line arguments or playlist files', description=(
formatter_class=argparse.ArgumentDefaultsHelpFormatter "Validate streaming URLs from command line arguments or playlist files"
),
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument(
'sources', "sources", nargs="+", help="List of URLs or file paths containing stream URLs"
nargs='+',
help='List of URLs or file paths containing stream URLs'
) )
parser.add_argument( parser.add_argument(
'--timeout', "--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
type=int,
default=20,
help='Timeout in seconds for stream checks'
) )
parser.add_argument( parser.add_argument(
'--output', "--output",
default='deadstreams.txt', default="deadstreams.txt",
help='Output file name for inactive streams' help="Output file name for inactive streams",
) )
args = parser.parse_args() args = parser.parse_args()
@@ -97,8 +109,8 @@ def main():
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()] handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
) )
validator = StreamValidator(timeout=args.timeout) validator = StreamValidator(timeout=args.timeout)
@@ -127,9 +139,10 @@ def main():
# Save dead streams to file # Save dead streams to file
if dead_streams: if dead_streams:
with open(args.output, 'w') as f: with open(args.output, "w") as f:
f.write('\n'.join(dead_streams)) f.write("\n".join(dead_streams))
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.") logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,12 +1,21 @@
# Utility functions and constants
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from a .env file if it exists # Load environment variables from a .env file if it exists
load_dotenv() load_dotenv()
# AWS related constants
AWS_REGION = os.environ.get("AWS_REGION", "us-east-2")
COGNITO_USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID")
COGNITO_CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
COGNITO_CLIENT_SECRET = os.environ.get("COGNITO_CLIENT_SECRET", None)
USER_ROLE_ATTRIBUTE = "zoneinfo"
IPTV_SERVER_URL = os.getenv("IPTV_SERVER_URL", "https://iptv.fiorinis.com") IPTV_SERVER_URL = os.getenv("IPTV_SERVER_URL", "https://iptv.fiorinis.com")
# Super iptv-server admin credentials for basic auth # iptv-server super admin credentials for basic auth
# Reads from environment variables IPTV_SERVER_ADMIN_USER and IPTV_SERVER_ADMIN_PASSWORD # Reads from environment variables IPTV_SERVER_ADMIN_USER and IPTV_SERVER_ADMIN_PASSWORD
IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin") IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword") IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")

61
app/utils/database.py Normal file
View File

@@ -0,0 +1,61 @@
import os
import boto3
from requests import Session
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.models import Base
from .constants import AWS_REGION
def get_db_credentials():
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
if os.getenv("MOCK_AUTH", "").lower() == "true":
return (
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
)
ssm = boto3.client("ssm", region_name=AWS_REGION)
try:
host = ssm.get_parameter(Name="/iptv-manager/DB_HOST", WithDecryption=True)[
"Parameter"
]["Value"]
user = ssm.get_parameter(Name="/iptv-manager/DB_USER", WithDecryption=True)[
"Parameter"
]["Value"]
password = ssm.get_parameter(
Name="/iptv-manager/DB_PASSWORD", WithDecryption=True
)["Parameter"]["Value"]
dbname = ssm.get_parameter(Name="/iptv-manager/DB_NAME", WithDecryption=True)[
"Parameter"
]["Value"]
return f"postgresql://{user}:{password}@{host}/{dbname}"
except Exception as e:
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
# Initialize engine and session maker
engine = create_engine(get_db_credentials())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
"""Initialize database by creating all tables"""
Base.metadata.create_all(bind=engine)
def get_db():
"""Dependency for getting database session"""
db = SessionLocal()
try:
yield db
finally:
db.close()
def get_db_session() -> Session:
"""Get a direct database session (non-generator version)"""
return SessionLocal()

View File

@@ -1,23 +0,0 @@
#!/bin/bash
# Deploy infrastructure
cdk deploy
# Update application on running instances
INSTANCE_IDS=$(aws ec2 describe-instances \
--filters "Name=tag:Name,Values=IptvUpdater/IptvUpdaterInstance" \
"Name=instance-state-name,Values=running" \
--query "Reservations[].Instances[].InstanceId" \
--output text)
for INSTANCE_ID in $INSTANCE_IDS; do
echo "Updating application on instance: $INSTANCE_ID"
aws ssm send-command \
--instance-ids "$INSTANCE_ID" \
--document-name "AWS-RunShellScript" \
--parameters '{"commands":["cd /home/ec2-user/iptv-updater-aws && git pull && pip3 install -r requirements.txt && sudo systemctl restart iptv-updater"]}' \
--no-cli-pager \
--no-paginate
done
echo "Deployment and instance update complete"

View File

@@ -1,4 +0,0 @@
#!/bin/bash
# Destroy infrastructure
cdk destroy

28
docker/Dockerfile Normal file
View File

@@ -0,0 +1,28 @@
# Use official Python image
FROM python:3.9-slim
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
# Set work directory
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Expose the port the app runs on
EXPOSE 8000
# Command to run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -0,0 +1,17 @@
version: '3.8'
services:
postgres:
image: postgres:13
container_name: postgres
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: iptv_manager
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
volumes:
postgres_data:

View File

@@ -0,0 +1,32 @@
version: '3.8'
services:
postgres:
image: postgres:13
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: iptv_manager
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
app:
build:
context: ..
dockerfile: docker/Dockerfile
environment:
DB_USER: postgres
DB_PASSWORD: postgres
DB_HOST: postgres
DB_NAME: iptv_manager
MOCK_AUTH: "true"
ports:
- "8000:8000"
depends_on:
- postgres
command: uvicorn app.main:app --host 0.0.0.0 --port 8000
volumes:
postgres_data:

View File

@@ -1,67 +1,84 @@
import os import os
from aws_cdk import (
Stack, from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
aws_ec2 as ec2, from aws_cdk import aws_cognito as cognito
aws_iam as iam, from aws_cdk import aws_ec2 as ec2
aws_cognito as cognito, from aws_cdk import aws_iam as iam
CfnOutput from aws_cdk import aws_rds as rds
) from aws_cdk import aws_ssm as ssm
from constructs import Construct from constructs import Construct
class IptvUpdaterStack(Stack):
def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: class IptvManagerStack(Stack):
def __init__(
self,
scope: Construct,
construct_id: str,
freedns_user: str,
freedns_password: str,
domain_name: str,
ssh_public_key: str,
repo_url: str,
letsencrypt_email: str,
**kwargs,
) -> None:
super().__init__(scope, construct_id, **kwargs) super().__init__(scope, construct_id, **kwargs)
# Create VPC # Create VPC
vpc = ec2.Vpc(self, "IptvUpdaterVPC", vpc = ec2.Vpc(
max_azs=1, # Use only one AZ for free tier self,
"IptvManagerVPC",
max_azs=2, # Need at least 2 AZs for RDS subnet group
nat_gateways=0, # No NAT Gateway to stay in free tier nat_gateways=0, # No NAT Gateway to stay in free tier
subnet_configuration=[ subnet_configuration=[
ec2.SubnetConfiguration( ec2.SubnetConfiguration(
name="public", name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
subnet_type=ec2.SubnetType.PUBLIC, ),
cidr_mask=24 ec2.SubnetConfiguration(
) name="private",
] subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
cidr_mask=24,
),
],
) )
# Security Group # Security Group
security_group = ec2.SecurityGroup( security_group = ec2.SecurityGroup(
self, "IptvUpdaterSG", self, "IptvManagerSG", vpc=vpc, allow_all_outbound=True
vpc=vpc,
allow_all_outbound=True
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
ec2.Port.tcp(443),
"Allow HTTPS traffic"
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
ec2.Port.tcp(80),
"Allow HTTP traffic"
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
ec2.Port.tcp(22),
"Allow SSH traffic"
) )
# Key pair for IPTV Updater instance # Allow PostgreSQL port for tunneling restricted to developer IP
security_group.add_ingress_rule(
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
ec2.Port.tcp(5432),
"Allow PostgreSQL traffic for tunneling",
)
# Key pair for IPTV Manager instance
key_pair = ec2.KeyPair( key_pair = ec2.KeyPair(
self, self,
"IptvUpdaterKeyPair", "IptvManagerKeyPair",
key_pair_name="iptv-updater-key", key_pair_name="iptv-manager-key",
public_key_material="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDZcD20mfOQ/al6VMXWWUcUILYIHyIy5gY9RmC6koicaMYTJ078yMUnYw9DONEIfvxZwceTrtdUWv7vOXjknh1bberF78Pi3vwwFKazdgjbyZnkM9r2N40+PqfOFYf03B53VIA8jdae6gaD3VvYdlwV5dqqW3N+JE/+7sXHLeENnqxeS5jNLoRKDIxHgV4MBnewWNdp77ZEHZFrG3Fj/0lnqq2EfDh+5E/Vhkc/mzkuXu2l5Z1szw2rcjJz4IYUjgnClorBVlmWwwiICrHbyVCYivaaVvpVmZoBy7WW1P8KFAiot4G6C0Klyn4sy2AwCQ7u65TyDtbCR89VuuwW0zLgERAfWMdhn/5HdIudcTScXEsUsADTqxb2x0IXVbs3uhxHCUcc7BVg0S7dpAbmpK+80NlDCH38LFcyrYASRD6/Le2skVePIFt0Tw1OnwPH/QqPG0Y9vLZJl8779pg+kpk+o1MaRnczPq9Sk9zr3dR4Sv82CObnjTeY5LCTvs06JBCtcey/vLk7RQYs43uXZgg746aWIcM0lgHPivh2JpUOAd/Kj1v4aEXTRgHyPPLQ4KKwnALcd4s6+ytXwf4gxumRmPf/7P6HYJQvY8pfVda8jEvu08rvSVMXb09Jq4tzS/JTX2flfEdF0qIyTNHnK/lum27fgP0yq6Pq24IWYXha9w== stefano@MSI" public_key_material=ssh_public_key,
) )
# Create IAM role for EC2 # Create IAM role for EC2
role = iam.Role( role = iam.Role(
self, "IptvUpdaterRole", self,
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com") "IptvManagerRole",
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
) )
# Add SSM managed policy # Add SSM managed policy
@@ -71,70 +88,65 @@ class IptvUpdaterStack(Stack):
) )
) )
# Add Cognito permissions to instance role # Add EC2 describe permissions
role.add_managed_policy( role.add_to_policy(
iam.ManagedPolicy.from_aws_managed_policy_name( iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
"AmazonCognitoReadOnly" )
# Add SSM SendCommand permissions
role.add_to_policy(
iam.PolicyStatement(
actions=["ssm:SendCommand"],
resources=[
# Allow on all EC2 instances
f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
# Required for the RunShellScript document
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
],
) )
) )
# EC2 Instance # Add Cognito permissions to instance role
instance = ec2.Instance( role.add_managed_policy(
self, "IptvUpdaterInstance", iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
vpc=vpc,
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T2,
ec2.InstanceSize.MICRO
),
machine_image=ec2.AmazonLinuxImage(
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2
),
security_group=security_group,
key_pair=key_pair,
role=role
)
# Create Elastic IP
eip = ec2.CfnEIP(
self, "IptvUpdaterEIP",
domain="vpc",
instance_id=instance.instance_id
) )
# Add Cognito User Pool # Add Cognito User Pool
user_pool = cognito.UserPool( user_pool = cognito.UserPool(
self, "IptvUpdaterUserPool", self,
user_pool_name="iptv-updater-users", "IptvManagerUserPool",
user_pool_name="iptv-manager-users",
self_sign_up_enabled=False, # Only admins can create users self_sign_up_enabled=False, # Only admins can create users
password_policy=cognito.PasswordPolicy( password_policy=cognito.PasswordPolicy(
min_length=8, min_length=8,
require_lowercase=True, require_lowercase=True,
require_digits=True, require_digits=True,
require_symbols=True, require_symbols=True,
require_uppercase=True require_uppercase=True,
), ),
account_recovery=cognito.AccountRecovery.EMAIL_ONLY account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
removal_policy=RemovalPolicy.DESTROY,
) )
# Add App Client with the correct callback URL # Add App Client with the correct callback URL
client = user_pool.add_client("IptvUpdaterClient", client = user_pool.add_client(
"IptvManagerClient",
access_token_validity=Duration.minutes(60),
id_token_validity=Duration.minutes(60),
refresh_token_validity=Duration.days(1),
auth_flows=cognito.AuthFlow(user_password=True),
o_auth=cognito.OAuthSettings( o_auth=cognito.OAuthSettings(
flows=cognito.OAuthFlows( flows=cognito.OAuthFlows(implicit_code_grant=True)
authorization_code_grant=True ),
), prevent_user_existence_errors=True,
scopes=[cognito.OAuthScope.OPENID], generate_secret=True,
callback_urls=[ enable_token_revocation=True,
"http://localhost:8000/auth/callback", # For local testing
"https://*.amazonaws.com/auth/callback" # Will match EC2 public DNS
]
)
) )
# Add domain for hosted UI # Add domain for hosted UI
domain = user_pool.add_domain("IptvUpdaterDomain", domain = user_pool.add_domain(
cognito_domain=cognito.CognitoDomainOptions( "IptvManagerDomain",
domain_prefix="iptv-updater" cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-manager"),
)
) )
# Read the userdata script with proper path resolution # Read the userdata script with proper path resolution
@@ -144,20 +156,152 @@ class IptvUpdaterStack(Stack):
# Creates a userdata object for Linux hosts # Creates a userdata object for Linux hosts
userdata = ec2.UserData.for_linux() userdata = ec2.UserData.for_linux()
# Add environment variables for acme.sh from parameters
userdata.add_commands(
f'export FREEDNS_User="{freedns_user}"',
f'export FREEDNS_Password="{freedns_password}"',
f'export DOMAIN_NAME="{domain_name}"',
f'export REPO_URL="{repo_url}"',
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
)
# Adds one or more commands to the userdata object. # Adds one or more commands to the userdata object.
userdata.add_commands( userdata.add_commands(
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment', (
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment' f'echo "COGNITO_USER_POOL_ID='
f'{user_pool.user_pool_id}" >> /etc/environment'
),
(
f'echo "COGNITO_CLIENT_ID='
f'{client.user_pool_client_id}" >> /etc/environment'
),
(
f'echo "COGNITO_CLIENT_SECRET='
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
),
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
) )
userdata.add_commands(str(userdata_file, 'utf-8')) userdata.add_commands(str(userdata_file, "utf-8"))
# Create RDS Security Group
rds_sg = ec2.SecurityGroup(
self,
"RdsSecurityGroup",
vpc=vpc,
description="Security group for RDS PostgreSQL",
)
rds_sg.add_ingress_rule(
security_group,
ec2.Port.tcp(5432),
"Allow PostgreSQL access from EC2 instance",
)
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
db = rds.DatabaseInstance(
self,
"IptvManagerDB",
engine=rds.DatabaseInstanceEngine.postgres(
version=rds.PostgresEngineVersion.VER_13
),
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
),
vpc=vpc,
vpc_subnets=ec2.SubnetSelection(
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED
),
security_groups=[rds_sg],
allocated_storage=10,
max_allocated_storage=10,
database_name="iptv_manager",
removal_policy=RemovalPolicy.DESTROY,
deletion_protection=False,
publicly_accessible=False, # Avoid public IPv4 charges
)
# Add RDS permissions to instance role
role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
)
# Store DB connection info in SSM Parameter Store
db_host_param = ssm.StringParameter(
self,
"DBHostParam",
parameter_name="/iptv-manager/DB_HOST",
string_value=db.db_instance_endpoint_address,
)
db_name_param = ssm.StringParameter(
self,
"DBNameParam",
parameter_name="/iptv-manager/DB_NAME",
string_value="iptv_manager",
)
db_user_param = ssm.StringParameter(
self,
"DBUserParam",
parameter_name="/iptv-manager/DB_USER",
string_value=db.secret.secret_value_from_json("username").to_string(),
)
db_pass_param = ssm.StringParameter(
self,
"DBPassParam",
parameter_name="/iptv-manager/DB_PASSWORD",
string_value=db.secret.secret_value_from_json("password").to_string(),
)
# Add SSM read permissions to instance role
role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
)
# EC2 Instance (created after all dependencies are ready)
instance = ec2.Instance(
self,
"IptvManagerInstance",
vpc=vpc,
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
),
machine_image=ec2.AmazonLinuxImage(
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
),
security_group=security_group,
key_pair=key_pair,
role=role,
# Option: 1: Enable auto-assign public IP (free tier compatible)
associate_public_ip_address=True,
)
# Ensure instance depends on SSM parameters being created
instance.node.add_dependency(db)
instance.node.add_dependency(db_host_param)
instance.node.add_dependency(db_name_param)
instance.node.add_dependency(db_user_param)
instance.node.add_dependency(db_pass_param)
# Option: 2: Create Elastic IP (not free tier compatible)
# eip = ec2.CfnEIP(
# self, "IptvManagerEIP",
# domain="vpc",
# instance_id=instance.instance_id
# )
# Update instance with userdata # Update instance with userdata
instance.add_user_data(userdata.render()) instance.add_user_data(userdata.render())
# Outputs # Outputs
CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip) CfnOutput(self, "DBEndpoint", value=db.db_instance_endpoint_address)
# Option: 1: Use EC2 instance public IP (free tier compatible)
CfnOutput(self, "InstancePublicIP", value=instance.instance_public_ip)
# Option: 2: Use EIP (not free tier compatible)
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id) CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id) CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
CfnOutput(self, "CognitoDomainUrl", CfnOutput(
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com" self,
"CognitoDomainUrl",
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
) )

View File

@@ -1,29 +1,79 @@
#!/bin/sh #!/bin/sh
yum update -y # Update system and install required packages
yum install -y python3-pip git dnf update -y
amazon-linux-extras install nginx1 dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx postgresql15.x86_64 awscli
pip3 install --upgrade pip # Start and enable crond service
pip3 install certbot certbot-nginx systemctl start crond
systemctl enable crond
cd /home/ec2-user cd /home/ec2-user
git clone https://git.fiorinis.com/Home/iptv-updater-aws.git git clone ${REPO_URL}
cd iptv-updater-aws cd iptv-manager-service
pip3 install -r requirements.txt # Install Python packages with --ignore-installed to prevent conflicts with RPM packages
pip3 install --ignore-installed -r requirements.txt
# Retrieve DB credentials from SSM Parameter Store with retries
echo "Attempting to retrieve DB credentials from SSM..."
for i in {1..30}; do
DB_HOST=$(aws ssm get-parameter --name "/iptv-manager/DB_HOST" --query "Parameter.Value" --output text 2>/dev/null)
DB_NAME=$(aws ssm get-parameter --name "/iptv-manager/DB_NAME" --query "Parameter.Value" --output text 2>/dev/null)
DB_USER=$(aws ssm get-parameter --name "/iptv-manager/DB_USER" --query "Parameter.Value" --output text 2>/dev/null)
DB_PASSWORD=$(aws ssm get-parameter --name "/iptv-manager/DB_PASSWORD" --query "Parameter.Value" --output text 2>/dev/null)
if [ -n "$DB_HOST" ] && [ -n "$DB_NAME" ] && [ -n "$DB_USER" ] && [ -n "$DB_PASSWORD" ]; then
echo "Successfully retrieved all DB credentials"
break
fi
echo "Waiting for SSM parameters to be available... (attempt $i/30)"
sleep 5
done
if [ -z "$DB_HOST" ] || [ -z "$DB_NAME" ] || [ -z "$DB_USER" ] || [ -z "$DB_PASSWORD" ]; then
echo "ERROR: Failed to retrieve all required DB credentials after 30 attempts"
exit 1
fi
export DB_HOST
export DB_NAME
export DB_USER
export DB_PASSWORD
# Set PGPASSWORD for psql to use
export PGPASSWORD=$DB_PASSWORD
# Wait for PostgreSQL to be ready
echo "Waiting for PostgreSQL to start..."
until psql -h $DB_HOST -U $DB_USER -d postgres -c '\q'; do
sleep 1
done
echo "PostgreSQL is ready."
# Create database if it does not exist
DB_EXISTS=$(psql -h $DB_HOST -U $DB_USER -d postgres -tc "SELECT 1 FROM pg_database WHERE datname = '$DB_NAME';")
if [ -z "$DB_EXISTS" ]; then
echo "Creating database $DB_NAME..."
psql -h $DB_HOST -U $DB_USER -d postgres -c "CREATE DATABASE $DB_NAME;"
echo "Database $DB_NAME created."
fi
# Run database migrations
alembic upgrade head
# Create systemd service file # Create systemd service file
cat << 'EOF' > /etc/systemd/system/iptv-updater.service cat << 'EOF' > /etc/systemd/system/iptv-manager.service
[Unit] [Unit]
Description=IPTV Updater Service Description=IPTV Manager Service
After=network.target After=network.target
[Service] [Service]
Type=simple Type=simple
User=ec2-user User=ec2-user
WorkingDirectory=/home/ec2-user/iptv-updater-aws WorkingDirectory=/home/ec2-user/iptv-manager-service
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000 ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
EnvironmentFile=/etc/environment EnvironmentFile=/etc/environment
Restart=always Restart=always
@@ -32,17 +82,42 @@ Restart=always
WantedBy=multi-user.target WantedBy=multi-user.target
EOF EOF
# Ensure root has a crontab before installing acme.sh
crontab -u root -l >/dev/null 2>&1 || (echo "" | crontab -u root -)
# Install and configure acme.sh
curl https://get.acme.sh | sh -s email="${LETSENCRYPT_EMAIL}"
# Configure acme.sh to use DNS API for FreeDNS
. "/.acme.sh/acme.sh.env"
"/.acme.sh"/acme.sh --issue --dns dns_freedns -d ${DOMAIN_NAME} -d *.${DOMAIN_NAME}
sudo mkdir -p /etc/nginx/ssl
"/.acme.sh"/acme.sh --install-cert -d ${DOMAIN_NAME} -d *.${DOMAIN_NAME} \
--key-file /etc/nginx/ssl/${DOMAIN_NAME}.pem \
--fullchain-file /etc/nginx/ssl/cert.pem \
--reloadcmd "service nginx force-reload"
# Create nginx config # Create nginx config
cat << 'EOF' > /etc/nginx/conf.d/iptvUpdater.conf cat << EOF > /etc/nginx/conf.d/iptvManager.conf
server { server {
listen 80; listen 80;
server_name $HOSTNAME; server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
return 301 https://\$host\$request_uri;
}
server {
listen 443 ssl;
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
ssl_certificate /etc/nginx/ssl/cert.pem;
ssl_certificate_key /etc/nginx/ssl/${DOMAIN_NAME}.pem;
location / { location / {
proxy_pass http://127.0.0.1:8000; proxy_pass http://127.0.0.1:8000;
proxy_set_header Host $host; proxy_set_header Host \$host;
proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Real-IP \$remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme; proxy_set_header X-Forwarded-Proto \$scheme;
} }
} }
EOF EOF
@@ -50,5 +125,5 @@ EOF
# Start nginx service # Start nginx service
systemctl enable nginx systemctl enable nginx
systemctl start nginx systemctl start nginx
systemctl enable iptv-updater systemctl enable iptv-manager
systemctl start iptv-updater systemctl start iptv-manager

View File

@@ -1,5 +0,0 @@
#!/bin/bash
# Install dependencies and deploy infrastructure
npm install -g aws-cdk
python3 -m pip install -r requirements.txt

31
pyproject.toml Normal file
View File

@@ -0,0 +1,31 @@
[tool.ruff]
line-length = 88
exclude = [
"alembic/versions/*.py", # Auto-generated Alembic migration files
]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
"UP", # pyupgrade
"W", # pycodestyle warnings
]
ignore = []
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"F811", # redefinition of unused name
"F401", # unused import
]
[tool.ruff.lint.isort]
known-first-party = ["app"]
[tool.ruff.format]
docstring-code-format = true
[tool.pytest.ini_options]
addopts = "--cov=app --cov-report=term-missing --cov-fail-under=70"
testpaths = ["tests"]

28
pytest.ini Normal file
View File

@@ -0,0 +1,28 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_functions = test_*
asyncio_mode = auto
filterwarnings =
ignore::DeprecationWarning:botocore.auth
ignore:The 'app' shortcut is now deprecated:DeprecationWarning:httpx._client
# Coverage configuration
addopts =
--cov=app
--cov-report=term-missing
# Test environment variables
env =
MOCK_AUTH=true
DB_USER=test_user
DB_PASSWORD=test_password
DB_HOST=localhost
DB_NAME=iptv_manager_test
# Test markers
markers =
slow: mark tests as slow running
integration: integration tests
unit: unit tests
db: tests requiring database

View File

@@ -9,3 +9,14 @@ passlib[bcrypt]==1.7.4
boto3==1.28.0 boto3==1.28.0
starlette>=0.27.0 starlette>=0.27.0
pyjwt==2.7.0 pyjwt==2.7.0
sqlalchemy==2.0.23
psycopg2-binary==2.9.9
alembic==1.16.1
pytest==8.1.1
pytest-asyncio==0.23.6
pytest-mock==3.12.0
pytest-cov==4.1.0
pytest-env==1.1.1
httpx==0.27.0
pre-commit
apscheduler==3.10.4

36
scripts/create_cognito_user.sh Executable file
View File

@@ -0,0 +1,36 @@
#!/bin/bash
set -e
if [ "$#" -lt 3 ]; then
echo "Usage: $0 USER_POOL_ID USERNAME PASSWORD [--admin]"
exit 1
fi
USER_POOL_ID=$1
USERNAME=$2
PASSWORD=$3
ADMIN_FLAG=${4:-""}
# Create user with temporary password
CREATE_CMD="aws cognito-idp admin-create-user --no-cli-pager \
--user-pool-id \"$USER_POOL_ID\" \
--username \"$USERNAME\" \
--temporary-password \"TempPass123!\" \
--output json > /dev/null 2>&1"
if [ "$ADMIN_FLAG" == "--admin" ]; then
CREATE_CMD+=" --user-attributes Name=zoneinfo,Value=admin"
fi
eval "$CREATE_CMD"
# Set permanent password
aws cognito-idp admin-set-user-password --no-cli-pager \
--user-pool-id "$USER_POOL_ID" \
--username "$USERNAME" \
--password "$PASSWORD" \
--permanent \
--output json > /dev/null 2>&1
echo "User $USERNAME created successfully"

18
scripts/delete_cognito_user.sh Executable file
View File

@@ -0,0 +1,18 @@
#!/bin/bash
set -e
if [ "$#" -ne 2 ]; then
echo "Usage: $0 USER_POOL_ID USERNAME"
exit 1
fi
USER_POOL_ID=$1
USERNAME=$2
aws cognito-idp admin-delete-user --no-cli-pager \
--user-pool-id "$USER_POOL_ID" \
--username "$USERNAME" \
--output json > /dev/null 2>&1
echo "User $USERNAME deleted successfully"

43
scripts/deploy.sh Executable file
View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Load environment variables from .env file if it exists
if [ -f ${PWD}/.env ]; then
# Use set -a to automatically export all variables
set -a
source ${PWD}/.env
set +a
fi
# Check if required environment variables are set
if [ -z "$FREEDNS_User" ] ||
[ -z "$FREEDNS_Password" ] ||
[ -z "$DOMAIN_NAME" ] ||
[ -z "$SSH_PUBLIC_KEY" ] ||
[ -z "$REPO_URL" ] ||
[ -z "$LETSENCRYPT_EMAIL" ]; then
echo "Error: FREEDNS_User, FREEDNS_Password, DOMAIN_NAME, SSH_PUBLIC_KEY, REPO_URL, and LETSENCRYPT_EMAIL must be set as environment variables or in a .env file."
exit 1
fi
# Deploy infrastructure
cdk deploy --app="python3 ${PWD}/app.py"
# Update application on running instances
INSTANCE_IDS=$(aws ec2 describe-instances \
--region us-east-2 \
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
"Name=instance-state-name,Values=running" \
--query "Reservations[].Instances[].InstanceId" \
--output text)
for INSTANCE_ID in $INSTANCE_IDS; do
echo "Updating application on instance: $INSTANCE_ID"
aws ssm send-command \
--instance-ids "$INSTANCE_ID" \
--document-name "AWS-RunShellScript" \
--parameters '{"commands":["cd /home/ec2-user/iptv-manager-service && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-manager"]}' \
--no-cli-pager \
--no-paginate
done
echo "Deployment and instance update complete"

23
scripts/destroy.sh Executable file
View File

@@ -0,0 +1,23 @@
#!/bin/bash
# Load environment variables from .env file if it exists
if [ -f ${PWD}/.env ]; then
# Use set -a to automatically export all variables
set -a
source ${PWD}/.env
set +a
fi
# Check if required environment variables are set
if [ -z "$FREEDNS_User" ] ||
[ -z "$FREEDNS_Password" ] ||
[ -z "$DOMAIN_NAME" ] ||
[ -z "$SSH_PUBLIC_KEY" ] ||
[ -z "$REPO_URL" ] ||
[ -z "$LETSENCRYPT_EMAIL" ]; then
echo "Error: FREEDNS_User, FREEDNS_Password, DOMAIN_NAME, SSH_PUBLIC_KEY, REPO_URL, and LETSENCRYPT_EMAIL must be set as environment variables or in a .env file."
exit 1
fi
# Destroy infrastructure
cdk destroy --app="python3 ${PWD}/app.py" --force

13
scripts/install.sh Executable file
View File

@@ -0,0 +1,13 @@
#!/bin/bash
# Install dependencies and deploy infrastructure
npm install -g aws-cdk
python3 -m pip install -r requirements.txt
# Install and configure pre-commit hooks
pre-commit install
pre-commit install-hooks
pre-commit autoupdate
# Verify pytest setup
python3 -m pytest

28
scripts/start_local_dev.sh Executable file
View File

@@ -0,0 +1,28 @@
#!/bin/bash
set -e
# Start PostgreSQL
docker-compose -f docker/docker-compose-db.yml up -d
# Set environment variables
export MOCK_AUTH=true
export DB_HOST=localhost
export DB_USER=postgres
export DB_PASSWORD=postgres
export DB_NAME=iptv_manager
echo "Ensuring database $DB_NAME exists using conditional DDL..."
PGPASSWORD=$DB_PASSWORD docker exec -i postgres psql -U $DB_USER <<< "SELECT 'CREATE DATABASE $DB_NAME' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec"
echo "Database $DB_NAME check complete."
# Run database migrations
alembic upgrade head
# Start FastAPI
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
echo $! > iptv-manager.pid
echo "Services started:"
echo "- PostgreSQL running on localhost:5432"
echo "- FastAPI running on http://127.0.0.1:8000"
echo "- Mock auth enabled (use token: testuser)"

19
scripts/stop_local_dev.sh Executable file
View File

@@ -0,0 +1,19 @@
#!/bin/bash
# Stop FastAPI
if [ -f iptv-manager.pid ]; then
kill $(cat iptv-manager.pid)
rm iptv-manager.pid
echo "Stopped FastAPI"
fi
# Clean up mock auth and database environment variables
unset MOCK_AUTH
unset DB_USER
unset DB_PASSWORD
unset DB_HOST
unset DB_NAME
# Stop PostgreSQL
docker-compose -f docker/docker-compose-db.yml down
echo "Stopped PostgreSQL"

0
tests/__init__.py Normal file
View File

0
tests/auth/__init__.py Normal file
View File

169
tests/auth/test_cognito.py Normal file
View File

@@ -0,0 +1,169 @@
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException, status
# Test constants
TEST_CLIENT_ID = "test_client_id"
TEST_CLIENT_SECRET = "test_client_secret"
# Patch constants before importing the module
with (
patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
):
from app.auth.cognito import get_user_from_token, initiate_auth
from app.models.auth import CognitoUser
from app.utils.constants import USER_ROLE_ATTRIBUTE
@pytest.fixture(autouse=True)
def mock_cognito_client():
with patch("app.auth.cognito.cognito_client") as mock_client:
# Setup mock client and exceptions
mock_client.exceptions = MagicMock()
mock_client.exceptions.NotAuthorizedException = type(
"NotAuthorizedException", (Exception,), {}
)
mock_client.exceptions.UserNotFoundException = type(
"UserNotFoundException", (Exception,), {}
)
yield mock_client
def test_initiate_auth_success(mock_cognito_client):
# Mock successful authentication response
mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token",
}
}
result = initiate_auth("test_user", "test_pass")
assert result == {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token",
}
def test_initiate_auth_with_secret_hash(mock_cognito_client):
with patch(
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
) as mock_hash:
mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": {"AccessToken": "token"}
}
initiate_auth("test_user", "test_pass")
# Verify calculate_secret_hash was called
mock_hash.assert_called_once_with(
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
)
# Verify SECRET_HASH was included in auth params
call_args = mock_cognito_client.initiate_auth.call_args[1]
assert "SECRET_HASH" in call_args["AuthParameters"]
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
def test_initiate_auth_not_authorized(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info:
initiate_auth("invalid_user", "wrong_pass")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid username or password"
def test_initiate_auth_user_not_found(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info:
initiate_auth("nonexistent_user", "any_pass")
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
assert exc_info.value.detail == "User not found"
def test_initiate_auth_generic_error(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
with pytest.raises(HTTPException) as exc_info:
initiate_auth("test_user", "test_pass")
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "An error occurred during authentication" in exc_info.value.detail
def test_get_user_from_token_success(mock_cognito_client):
mock_response = {
"Username": "test_user",
"UserAttributes": [
{"Name": "sub", "Value": "123"},
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
],
}
mock_cognito_client.get_user.return_value = mock_response
result = get_user_from_token("valid_token")
assert isinstance(result, CognitoUser)
assert result.username == "test_user"
assert set(result.roles) == {"admin", "user"}
def test_get_user_from_token_no_roles(mock_cognito_client):
mock_response = {
"Username": "test_user",
"UserAttributes": [{"Name": "sub", "Value": "123"}],
}
mock_cognito_client.get_user.return_value = mock_response
result = get_user_from_token("valid_token")
assert isinstance(result, CognitoUser)
assert result.username == "test_user"
assert result.roles == []
def test_get_user_from_token_invalid_token(mock_cognito_client):
mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info:
get_user_from_token("invalid_token")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid or expired token."
def test_get_user_from_token_user_not_found(mock_cognito_client):
mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info:
get_user_from_token("token_for_nonexistent_user")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "User not found or invalid token."
def test_get_user_from_token_generic_error(mock_cognito_client):
mock_cognito_client.get_user.side_effect = Exception("Some error")
with pytest.raises(HTTPException) as exc_info:
get_user_from_token("test_token")
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "Token verification failed" in exc_info.value.detail

View File

@@ -0,0 +1,205 @@
import importlib
import os
import pytest
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer
from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
from app.models.auth import CognitoUser
# Mock user for testing
TEST_USER = CognitoUser(
username="testuser",
email="test@example.com",
roles=["admin", "user"],
groups=["test_group"],
)
# Mock the underlying get_user_from_token function
def mock_get_user_from_token(token: str) -> CognitoUser:
if token == "valid_token":
return TEST_USER
raise HTTPException(status_code=401, detail="Invalid token")
# Mock endpoint for testing the require_roles decorator
@require_roles("admin")
def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
return {"message": "Success", "user": user.username}
# Patch the get_user_from_token function for testing
@pytest.fixture(autouse=True)
def mock_auth(monkeypatch):
monkeypatch.setattr(
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
)
# Test get_current_user dependency
def test_get_current_user_success():
user = get_current_user("valid_token")
assert user == TEST_USER
assert user.username == "testuser"
assert user.roles == ["admin", "user"]
def test_get_current_user_invalid_token():
with pytest.raises(HTTPException) as exc:
get_current_user("invalid_token")
assert exc.value.status_code == 401
# Test require_roles decorator
@pytest.mark.asyncio
async def test_require_roles_success():
# Create test user with required role
user = CognitoUser(
username="testuser", email="test@example.com", roles=["admin"], groups=[]
)
result = await mock_protected_endpoint(user=user)
assert result == {"message": "Success", "user": "testuser"}
@pytest.mark.asyncio
async def test_require_roles_missing_role():
# Create test user without required role
user = CognitoUser(
username="testuser", email="test@example.com", roles=["user"], groups=[]
)
with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403
assert (
exc.value.detail
== "You do not have the required roles to access this endpoint."
)
@pytest.mark.asyncio
async def test_require_roles_no_roles():
# Create test user with no roles
user = CognitoUser(
username="testuser", email="test@example.com", roles=[], groups=[]
)
with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_require_roles_multiple_roles():
# Test requiring multiple roles
@require_roles("admin", "super_user")
def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
return {"message": "Success"}
# User with all required roles
user_with_roles = CognitoUser(
username="testuser",
email="test@example.com",
roles=["admin", "super_user", "user"],
groups=[],
)
result = await mock_multi_role_endpoint(user=user_with_roles)
assert result == {"message": "Success"}
# User missing one required role
user_missing_role = CognitoUser(
username="testuser",
email="test@example.com",
roles=["admin", "user"],
groups=[],
)
with pytest.raises(HTTPException) as exc:
await mock_multi_role_endpoint(user=user_missing_role)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_oauth2_scheme_configuration():
# Verify that we have a properly configured OAuth2PasswordBearer instance
assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
# Create a mock request with no Authorization header
mock_request = Request(
scope={
"type": "http",
"headers": [],
"method": "GET",
"scheme": "http",
"path": "/",
"query_string": b"",
"client": ("127.0.0.1", 8000),
}
)
# Test that the scheme raises 401 when no token is provided
with pytest.raises(HTTPException) as exc:
await oauth2_scheme(mock_request)
assert exc.value.status_code == 401
assert exc.value.detail == "Not authenticated"
def test_mock_auth_import(monkeypatch):
# Save original env var value
original_value = os.environ.get("MOCK_AUTH")
try:
# Set MOCK_AUTH to true
monkeypatch.setenv("MOCK_AUTH", "true")
# Reload the dependencies module to trigger the import condition
import app.auth.dependencies
importlib.reload(app.auth.dependencies)
# Verify that mock_get_user_from_token was imported
from app.auth.dependencies import get_user_from_token
assert get_user_from_token.__module__ == "app.auth.mock_auth"
finally:
# Restore original env var
if original_value is None:
monkeypatch.delenv("MOCK_AUTH", raising=False)
else:
monkeypatch.setenv("MOCK_AUTH", original_value)
# Reload again to restore original state
importlib.reload(app.auth.dependencies)
def test_cognito_auth_import(monkeypatch):
"""Test that cognito auth is imported when MOCK_AUTH=false (covers line 14)"""
# Save original env var value
original_value = os.environ.get("MOCK_AUTH")
try:
# Set MOCK_AUTH to false
monkeypatch.setenv("MOCK_AUTH", "false")
# Reload the dependencies module to trigger the import condition
import app.auth.dependencies
importlib.reload(app.auth.dependencies)
# Verify that get_user_from_token was imported from app.auth.cognito
from app.auth.dependencies import get_user_from_token
assert get_user_from_token.__module__ == "app.auth.cognito"
finally:
# Restore original env var
if original_value is None:
monkeypatch.delenv("MOCK_AUTH", raising=False)
else:
monkeypatch.setenv("MOCK_AUTH", original_value)
# Reload again to restore original state
importlib.reload(app.auth.dependencies)

View File

@@ -0,0 +1,41 @@
import pytest
from fastapi import HTTPException
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
from app.models.auth import CognitoUser
def test_mock_get_user_from_token_success():
"""Test successful token validation returns expected user"""
user = mock_get_user_from_token("testuser")
assert isinstance(user, CognitoUser)
assert user.username == "testuser"
assert user.roles == ["admin"]
def test_mock_get_user_from_token_invalid():
"""Test invalid token raises expected exception"""
with pytest.raises(HTTPException) as exc_info:
mock_get_user_from_token("invalid_token")
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
def test_mock_initiate_auth():
"""Test mock authentication returns expected token response"""
result = mock_initiate_auth("any_user", "any_password")
assert isinstance(result, dict)
assert result["AccessToken"] == "testuser"
assert result["ExpiresIn"] == 3600
assert result["TokenType"] == "Bearer"
def test_mock_initiate_auth_different_credentials():
"""Test mock authentication works with any credentials"""
result1 = mock_initiate_auth("user1", "pass1")
result2 = mock_initiate_auth("user2", "pass2")
# Should return same mock token regardless of credentials
assert result1 == result2

0
tests/iptv/__init__.py Normal file
View File

0
tests/models/__init__.py Normal file
View File

144
tests/models/test_db.py Normal file
View File

@@ -0,0 +1,144 @@
import os
import uuid
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID
# --- Test SQLiteUUID Type ---
def test_sqliteuuid_process_bind_param_none():
"""Test SQLiteUUID.process_bind_param with None returns None"""
uuid_type = SQLiteUUID()
assert uuid_type.process_bind_param(None, None) is None
def test_sqliteuuid_process_bind_param_valid_uuid():
"""Test SQLiteUUID.process_bind_param with valid UUID returns string"""
uuid_type = SQLiteUUID()
test_uuid = uuid.uuid4()
assert uuid_type.process_bind_param(test_uuid, None) == str(test_uuid)
def test_sqliteuuid_process_bind_param_valid_string():
"""Test SQLiteUUID.process_bind_param with valid UUID string returns string"""
uuid_type = SQLiteUUID()
test_uuid_str = "550e8400-e29b-41d4-a716-446655440000"
assert uuid_type.process_bind_param(test_uuid_str, None) == test_uuid_str
def test_sqliteuuid_process_bind_param_invalid_string():
"""Test SQLiteUUID.process_bind_param raises ValueError for invalid UUID"""
uuid_type = SQLiteUUID()
with pytest.raises(ValueError, match="Invalid UUID string format"):
uuid_type.process_bind_param("invalid-uuid", None)
def test_sqliteuuid_process_result_value_none():
"""Test SQLiteUUID.process_result_value with None returns None"""
uuid_type = SQLiteUUID()
assert uuid_type.process_result_value(None, None) is None
def test_sqliteuuid_process_result_value_valid_string():
"""Test SQLiteUUID.process_result_value converts string to UUID"""
uuid_type = SQLiteUUID()
test_uuid = uuid.uuid4()
result = uuid_type.process_result_value(str(test_uuid), None)
assert isinstance(result, uuid.UUID)
assert result == test_uuid
def test_sqliteuuid_process_result_value_uuid_object():
"""Test SQLiteUUID.process_result_value: UUID object returns itself."""
uuid_type = SQLiteUUID()
test_uuid = uuid.uuid4()
result = uuid_type.process_result_value(test_uuid, None)
assert isinstance(result, uuid.UUID)
assert result is test_uuid # Ensure it's the same object, not a new one
def test_sqliteuuid_compare_values_none():
"""Test SQLiteUUID.compare_values handles None values"""
uuid_type = SQLiteUUID()
assert uuid_type.compare_values(None, None) is True
assert uuid_type.compare_values(None, uuid.uuid4()) is False
assert uuid_type.compare_values(uuid.uuid4(), None) is False
def test_sqliteuuid_compare_values_uuid():
"""Test SQLiteUUID.compare_values compares UUIDs as strings"""
uuid_type = SQLiteUUID()
test_uuid = uuid.uuid4()
assert uuid_type.compare_values(test_uuid, test_uuid) is True
assert uuid_type.compare_values(test_uuid, uuid.uuid4()) is False
def test_sqlite_uuid_comparison():
"""Test SQLiteUUID comparison functionality (moved from db_mocks.py)"""
uuid_type = SQLiteUUID()
# Test equal UUIDs
uuid1 = uuid.uuid4()
uuid2 = uuid.UUID(str(uuid1))
assert uuid_type.compare_values(uuid1, uuid2) is True
# Test UUID vs string
assert uuid_type.compare_values(uuid1, str(uuid1)) is True
# Test None comparisons
assert uuid_type.compare_values(None, None) is True
assert uuid_type.compare_values(uuid1, None) is False
assert uuid_type.compare_values(None, uuid1) is False
# Test different UUIDs
uuid3 = uuid.uuid4()
assert uuid_type.compare_values(uuid1, uuid3) is False
def test_sqlite_uuid_binding():
"""Test SQLiteUUID binding parameter handling (moved from db_mocks.py)"""
uuid_type = SQLiteUUID()
# Test UUID object binding
uuid_obj = uuid.uuid4()
assert uuid_type.process_bind_param(uuid_obj, None) == str(uuid_obj)
# Test valid UUID string binding
uuid_str = str(uuid.uuid4())
assert uuid_type.process_bind_param(uuid_str, None) == uuid_str
# Test None handling
assert uuid_type.process_bind_param(None, None) is None
# Test invalid UUID string
with pytest.raises(ValueError):
uuid_type.process_bind_param("invalid-uuid", None)
# --- Test UUID Column Type Configuration ---
def test_uuid_column_type_default():
"""Test UUID_COLUMN_TYPE uses SQLiteUUID in test environment"""
assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID)
@patch.dict(os.environ, {"MOCK_AUTH": "false"})
def test_uuid_column_type_postgres():
"""Test UUID_COLUMN_TYPE uses Postgres UUID when MOCK_AUTH=false"""
# Need to re-import to get the patched environment
from importlib import reload
from app import models
reload(models.db)
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
from app.models.db import UUID_COLUMN_TYPE
assert isinstance(UUID_COLUMN_TYPE, PostgresUUID)

View File

43
tests/routers/mocks.py Normal file
View File

@@ -0,0 +1,43 @@
from unittest.mock import Mock
from fastapi import Request
from app.iptv.scheduler import StreamScheduler
class MockScheduler:
"""Base mock APScheduler instance"""
running = True
start = Mock()
shutdown = Mock()
add_job = Mock()
remove_job = Mock()
get_job = Mock(return_value=None)
def __init__(self, running=True):
self.running = running
def create_trigger_mock(triggered_ref: dict) -> callable:
"""Create a mock trigger function that updates a reference when called"""
def trigger_mock():
triggered_ref["value"] = True
return trigger_mock
async def mock_get_scheduler(
request: Request, scheduler_class=MockScheduler, running=True, **kwargs
) -> StreamScheduler:
"""Mock dependency for get_scheduler with customization options"""
scheduler = StreamScheduler()
mock_scheduler = scheduler_class(running=running)
# Apply any additional attributes/methods
for key, value in kwargs.items():
setattr(mock_scheduler, key, value)
scheduler.scheduler = mock_scheduler
return scheduler

101
tests/routers/test_auth.py Normal file
View File

@@ -0,0 +1,101 @@
from unittest.mock import patch
import pytest
from fastapi import HTTPException, status
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
@pytest.fixture
def mock_successful_auth():
return {
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token",
}
@pytest.fixture
def mock_successful_auth_no_refresh():
return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
def test_signin_success(mock_successful_auth):
"""Test successful signin with all tokens"""
with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
response = client.post(
"/auth/signin", json={"username": "testuser", "password": "testpass"}
)
assert response.status_code == 200
data = response.json()
assert data["access_token"] == "mock_access_token"
assert data["id_token"] == "mock_id_token"
assert data["refresh_token"] == "mock_refresh_token"
assert data["token_type"] == "Bearer"
def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
"""Test successful signin without refresh token"""
with patch(
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
):
response = client.post(
"/auth/signin", json={"username": "testuser", "password": "testpass"}
)
assert response.status_code == 200
data = response.json()
assert data["access_token"] == "mock_access_token"
assert data["id_token"] == "mock_id_token"
assert data["refresh_token"] is None
assert data["token_type"] == "Bearer"
def test_signin_invalid_input():
"""Test signin with invalid input format"""
# Missing password
response = client.post("/auth/signin", json={"username": "testuser"})
assert response.status_code == 422
# Missing username
response = client.post("/auth/signin", json={"password": "testpass"})
assert response.status_code == 422
# Empty payload
response = client.post("/auth/signin", json={})
assert response.status_code == 422
def test_signin_auth_failure():
"""Test signin with authentication failure"""
with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password",
)
response = client.post(
"/auth/signin", json={"username": "testuser", "password": "wrongpass"}
)
assert response.status_code == 401
data = response.json()
assert data["detail"] == "Invalid username or password"
def test_signin_user_not_found():
"""Test signin with non-existent user"""
with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
response = client.post(
"/auth/signin", json={"username": "nonexistent", "password": "testpass"}
)
assert response.status_code == 404
data = response.json()
assert data["detail"] == "User not found"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,461 @@
import uuid
from datetime import datetime, timezone
from fastapi import status
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
from app.utils.database import get_db
# Import mocks and fixtures
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import (
MockChannelDB,
MockGroup,
create_mock_priorities_and_group,
)
# --- Test Cases For Group Creation ---
def test_create_group_success(db_session: Session, admin_user_client):
group_data = {"name": "Test Group", "sort_order": 1}
response = admin_user_client.post("/groups/", json=group_data)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["name"] == "Test Group"
assert data["sort_order"] == 1
assert "id" in data
assert "created_at" in data
assert "updated_at" in data
# Verify in DB
db_group = (
db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first()
)
assert db_group is not None
assert db_group.sort_order == 1
def test_create_group_duplicate(db_session: Session, admin_user_client):
# Create initial group
initial_group = MockGroup(
id=uuid.uuid4(),
name="Duplicate Group",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(initial_group)
db_session.commit()
# Attempt to create duplicate
response = admin_user_client.post(
"/groups/", json={"name": "Duplicate Group", "sort_order": 2}
)
assert response.status_code == status.HTTP_409_CONFLICT
assert "already exists" in response.json()["detail"]
def test_create_group_forbidden_for_non_admin(
db_session: Session, non_admin_user_client
):
response = non_admin_user_client.post(
"/groups/", json={"name": "Forbidden Group", "sort_order": 1}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
# --- Test Cases For Get Group ---
def test_get_group_success(db_session: Session, admin_user_client):
# Create a group first
test_group = MockGroup(
id=uuid.uuid4(),
name="Get Me Group",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
response = admin_user_client.get(f"/groups/{test_group.id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == str(test_group.id)
assert data["name"] == "Get Me Group"
assert data["sort_order"] == 1
def test_get_group_not_found(db_session: Session, admin_user_client):
random_uuid = uuid.uuid4()
response = admin_user_client.get(f"/groups/{random_uuid}")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Group not found" in response.json()["detail"]
# --- Test Cases For Update Group ---
def test_update_group_success(db_session: Session, admin_user_client):
# Create initial group
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Update Me",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
update_data = {"name": "Updated Name", "sort_order": 2}
response = admin_user_client.put(f"/groups/{group_id}", json=update_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["name"] == "Updated Name"
assert data["sort_order"] == 2
# Verify in DB
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
assert db_group.name == "Updated Name"
assert db_group.sort_order == 2
def test_update_group_conflict(db_session: Session, admin_user_client):
# Create two groups
group1 = MockGroup(
id=uuid.uuid4(),
name="Group One",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
group2 = MockGroup(
id=uuid.uuid4(),
name="Group Two",
sort_order=2,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add_all([group1, group2])
db_session.commit()
# Try to rename group2 to conflict with group1
response = admin_user_client.put(f"/groups/{group2.id}", json={"name": "Group One"})
assert response.status_code == status.HTTP_409_CONFLICT
assert "already exists" in response.json()["detail"]
def test_update_group_not_found(db_session: Session, admin_user_client):
random_uuid = uuid.uuid4()
response = admin_user_client.put(
f"/groups/{random_uuid}", json={"name": "Non-existent"}
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Group not found" in response.json()["detail"]
def test_update_group_forbidden_for_non_admin(
db_session: Session, non_admin_user_client, admin_user_client
):
# Create group with admin
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Admin Created",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
# Attempt update with non-admin
response = non_admin_user_client.put(
f"/groups/{group_id}", json={"name": "Non-Admin Update"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
# --- Test Cases For Delete Group ---
def test_delete_all_groups_success(db_session, admin_user_client):
"""Test reset groups endpoint"""
# Create test data
group1_id = create_mock_priorities_and_group(db_session, [], "Group A")
group2_id = create_mock_priorities_and_group(db_session, [], "Group B")
# Add channel to group2
channel_data = [
{
"group-title": "Group A",
"tvg_id": "channel1.tv",
"name": "Channel One",
"url": ["http://test.com", "http://example.com"],
}
]
admin_user_client.post("/channels/bulk-upload", json=channel_data)
# Reset groups
response = admin_user_client.delete("/groups")
assert response.status_code == status.HTTP_200_OK
assert response.json()["deleted"] == 1 # Only group2 should be deleted
assert response.json()["skipped"] == 1 # group1 has channels
# Verify group2 deleted, group1 remains
assert (
db_session.query(MockGroup).filter(MockGroup.id == group1_id).first()
is not None
)
assert db_session.query(MockGroup).filter(MockGroup.id == group2_id).first() is None
def test_delete_all_groups_forbidden_for_non_admin(db_session, non_admin_user_client):
"""Test reset groups requires admin role"""
response = non_admin_user_client.delete("/groups")
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_delete_group_success(db_session: Session, admin_user_client):
# Create group
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Delete Me",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
# Verify exists before delete
assert (
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
)
response = admin_user_client.delete(f"/groups/{group_id}")
assert response.status_code == status.HTTP_204_NO_CONTENT
# Verify deleted
assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None
def test_delete_group_with_channels_fails(db_session: Session, admin_user_client):
# Create group with channel
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Group With Channels",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
# Create channel in this group
test_channel = MockChannelDB(
id=uuid.uuid4(),
tvg_id="channel1.tv",
name="Channel 1",
group_id=group_id,
tvg_name="Channel1",
tvg_logo="logo.png",
)
db_session.add(test_channel)
db_session.commit()
response = admin_user_client.delete(f"/groups/{group_id}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "existing channels" in response.json()["detail"]
# Verify group still exists
assert (
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
)
def test_delete_group_not_found(db_session: Session, admin_user_client):
random_uuid = uuid.uuid4()
response = admin_user_client.delete(f"/groups/{random_uuid}")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Group not found" in response.json()["detail"]
def test_delete_group_forbidden_for_non_admin(
db_session: Session, non_admin_user_client, admin_user_client
):
# Create group with admin
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Admin Created",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
# Attempt delete with non-admin
response = non_admin_user_client.delete(f"/groups/{group_id}")
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
# Verify group still exists
assert (
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
)
# --- Test Cases For List Groups ---
def test_list_groups_empty(db_session: Session, admin_user_client):
response = admin_user_client.get("/groups/")
assert response.status_code == status.HTTP_200_OK
assert response.json() == []
def test_list_groups_with_data(db_session: Session, admin_user_client):
# Create some groups
groups = [
MockGroup(
id=uuid.uuid4(),
name=f"Group {i}",
sort_order=i,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
for i in range(3)
]
db_session.add_all(groups)
db_session.commit()
response = admin_user_client.get("/groups/")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 3
assert data[0]["sort_order"] == 0 # Should be sorted by sort_order
assert data[1]["sort_order"] == 1
assert data[2]["sort_order"] == 2
# --- Test Cases For Sort Order Updates ---
def test_update_group_sort_order_success(db_session: Session, admin_user_client):
# Create group
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Sort Me",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5})
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["sort_order"] == 5
# Verify in DB
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
assert db_group.sort_order == 5
def test_update_group_sort_order_not_found(db_session: Session, admin_user_client):
"""Test that updating sort order for non-existent group returns 404"""
random_uuid = uuid.uuid4()
response = admin_user_client.put(
f"/groups/{random_uuid}/sort", json={"sort_order": 5}
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Group not found" in response.json()["detail"]
def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client):
# Create groups
groups = [
MockGroup(
id=uuid.uuid4(),
name=f"Group {i}",
sort_order=i,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
for i in range(3)
]
print(groups)
db_session.add_all(groups)
db_session.commit()
# Bulk update sort orders (reverse order)
bulk_data = {
"groups": [
{"group_id": str(groups[0].id), "sort_order": 2},
{"group_id": str(groups[1].id), "sort_order": 1},
{"group_id": str(groups[2].id), "sort_order": 0},
]
}
response = admin_user_client.post("/groups/reorder", json=bulk_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 3
# Create a dictionary for easy lookup of returned group data by ID
returned_groups_map = {item["id"]: item for item in data}
# Verify each group has its expected new sort_order
assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2
assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1
assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0
# Verify in DB
db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all()
assert db_groups[0].sort_order == 2
assert db_groups[1].sort_order == 1
assert db_groups[2].sort_order == 0
def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client):
# Create one group
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Valid Group",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add(test_group)
db_session.commit()
# Try to update with invalid group
bulk_data = {
"groups": [
{"group_id": str(group_id), "sort_order": 2},
{"group_id": str(uuid.uuid4()), "sort_order": 1}, # Invalid group
]
}
response = admin_user_client.post("/groups/reorder", json=bulk_data)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"]
# Verify original sort order unchanged
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
assert db_group.sort_order == 1

View File

@@ -0,0 +1,261 @@
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi import status
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
# Import the router we're testing
from app.routers.playlist import (
ProcessStatus,
ValidationProcessResponse,
ValidationResultResponse,
router,
validation_processes,
)
from app.utils.database import get_db
# Import mocks and fixtures
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import MockChannelDB
# --- Test Fixtures ---
@pytest.fixture
def mock_stream_manager():
with patch("app.routers.playlist.StreamManager") as mock:
yield mock
# --- Test Cases For Stream Validation ---
def test_start_stream_validation_success(
db_session: Session, admin_user_client, mock_stream_manager
):
"""Test starting a stream validation process"""
mock_instance = mock_stream_manager.return_value
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
response = admin_user_client.post(
"/playlist/validate-streams", json={"channel_id": "test-channel"}
)
assert response.status_code == status.HTTP_202_ACCEPTED
data = response.json()
assert "process_id" in data
assert data["status"] == ProcessStatus.PENDING
assert data["message"] == "Validation process started"
# Verify process was added to tracking
process_id = data["process_id"]
assert process_id in validation_processes
# In test environment, background tasks run synchronously so status may be COMPLETED
assert validation_processes[process_id]["status"] in [
ProcessStatus.PENDING,
ProcessStatus.COMPLETED,
]
assert validation_processes[process_id]["channel_id"] == "test-channel"
def test_get_validation_status_pending(db_session: Session, admin_user_client):
"""Test checking status of pending validation"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.PENDING,
"channel_id": "test-channel",
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.PENDING
assert data["working_streams"] is None
assert data["error"] is None
def test_get_validation_status_completed(db_session: Session, admin_user_client):
"""Test checking status of completed validation"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.COMPLETED,
"channel_id": "test-channel",
"result": {
"working_streams": [
{"channel_id": "test-channel", "stream_url": "http://valid.stream.url"}
]
},
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.COMPLETED
assert len(data["working_streams"]) == 1
assert data["working_streams"][0]["channel_id"] == "test-channel"
assert data["working_streams"][0]["stream_url"] == "http://valid.stream.url"
assert data["error"] is None
def test_get_validation_status_completed_with_error(
db_session: Session, admin_user_client
):
"""Test checking status of completed validation with error"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.COMPLETED,
"channel_id": "test-channel",
"error": "No working streams found for channel test-channel",
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.COMPLETED
assert data["working_streams"] is None
assert data["error"] == "No working streams found for channel test-channel"
def test_get_validation_status_failed(db_session: Session, admin_user_client):
"""Test checking status of failed validation"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.FAILED,
"channel_id": "test-channel",
"error": "Validation error occurred",
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.FAILED
assert data["working_streams"] is None
assert data["error"] == "Validation error occurred"
def test_get_validation_status_not_found(db_session: Session, admin_user_client):
"""Test checking status of non-existent process"""
random_uuid = str(uuid.uuid4())
response = admin_user_client.get(f"/playlist/validate-streams/{random_uuid}")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Process not found" in response.json()["detail"]
def test_run_stream_validation_success(mock_stream_manager, db_session):
"""Test the background validation task success case"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.PENDING,
"channel_id": "test-channel",
}
mock_instance = mock_stream_manager.return_value
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
from app.routers.playlist import run_stream_validation
run_stream_validation(process_id, "test-channel", db_session)
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
assert len(validation_processes[process_id]["result"]["working_streams"]) == 1
assert (
validation_processes[process_id]["result"]["working_streams"][0].channel_id
== "test-channel"
)
assert (
validation_processes[process_id]["result"]["working_streams"][0].stream_url
== "http://valid.stream.url"
)
def test_run_stream_validation_failure(mock_stream_manager, db_session):
"""Test the background validation task failure case"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.PENDING,
"channel_id": "test-channel",
}
mock_instance = mock_stream_manager.return_value
mock_instance.validate_and_select_stream.return_value = None
from app.routers.playlist import run_stream_validation
run_stream_validation(process_id, "test-channel", db_session)
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
assert "error" in validation_processes[process_id]
assert "No working streams found" in validation_processes[process_id]["error"]
def test_run_stream_validation_exception(mock_stream_manager, db_session):
"""Test the background validation task exception case"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.PENDING,
"channel_id": "test-channel",
}
mock_instance = mock_stream_manager.return_value
mock_instance.validate_and_select_stream.side_effect = Exception("Test error")
from app.routers.playlist import run_stream_validation
run_stream_validation(process_id, "test-channel", db_session)
assert validation_processes[process_id]["status"] == ProcessStatus.FAILED
assert "error" in validation_processes[process_id]
assert "Test error" in validation_processes[process_id]["error"]
def test_start_stream_validation_no_channel_id(
db_session: Session, admin_user_client, mock_stream_manager
):
"""Test starting validation without channel_id"""
response = admin_user_client.post("/playlist/validate-streams", json={})
assert response.status_code == status.HTTP_202_ACCEPTED
data = response.json()
assert "process_id" in data
assert data["status"] == ProcessStatus.PENDING
# Verify process was added to tracking
process_id = data["process_id"]
assert process_id in validation_processes
assert validation_processes[process_id]["status"] in [
ProcessStatus.PENDING,
ProcessStatus.COMPLETED,
]
assert validation_processes[process_id]["channel_id"] is None
assert "not yet implemented" in validation_processes[process_id].get("error", "")
def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session):
"""Test background validation without channel_id"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {"status": ProcessStatus.PENDING}
from app.routers.playlist import run_stream_validation
run_stream_validation(process_id, None, db_session)
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
assert "error" in validation_processes[process_id]
assert "not yet implemented" in validation_processes[process_id]["error"]

View File

@@ -0,0 +1,241 @@
import uuid
from datetime import datetime, timezone
from fastapi import status
from sqlalchemy.orm import Session
from app.routers.priorities import router as priorities_router
# Import fixtures and mocks
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import (
MockChannelDB,
MockChannelURL,
MockGroup,
MockPriority,
create_mock_priorities_and_group,
)
# --- Test Cases For Priority Creation ---
def test_create_priority_success(db_session: Session, admin_user_client):
priority_data = {"id": 100, "description": "Test Priority"}
response = admin_user_client.post("/priorities/", json=priority_data)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["id"] == 100
assert data["description"] == "Test Priority"
# Verify in DB
db_priority = db_session.get(MockPriority, 100)
assert db_priority is not None
assert db_priority.description == "Test Priority"
def test_create_priority_duplicate(db_session: Session, admin_user_client):
# Create initial priority
priority_data = {"id": 100, "description": "Original Priority"}
response1 = admin_user_client.post("/priorities/", json=priority_data)
assert response1.status_code == status.HTTP_201_CREATED
# Attempt to create with same ID
response2 = admin_user_client.post("/priorities/", json=priority_data)
assert response2.status_code == status.HTTP_409_CONFLICT
assert "already exists" in response2.json()["detail"]
def test_create_priority_forbidden_for_non_admin(
db_session: Session, non_admin_user_client
):
priority_data = {"id": 100, "description": "Test Priority"}
response = non_admin_user_client.post("/priorities/", json=priority_data)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
# --- Test Cases For List Priorities ---
def test_list_priorities_empty(db_session: Session, admin_user_client):
response = admin_user_client.get("/priorities/")
assert response.status_code == status.HTTP_200_OK
assert response.json() == []
def test_list_priorities_with_data(db_session: Session, admin_user_client):
# Create some test priorities
priorities = [
MockPriority(id=100, description="High"),
MockPriority(id=200, description="Medium"),
MockPriority(id=300, description="Low"),
]
for priority in priorities:
db_session.add(priority)
db_session.commit()
response = admin_user_client.get("/priorities/")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 3
assert data[0]["id"] == 100
assert data[0]["description"] == "High"
assert data[1]["id"] == 200
assert data[1]["description"] == "Medium"
assert data[2]["id"] == 300
assert data[2]["description"] == "Low"
def test_list_priorities_forbidden_for_non_admin(
db_session: Session, non_admin_user_client
):
response = non_admin_user_client.get("/priorities/")
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
# --- Test Cases For Get Priority ---
def test_get_priority_success(db_session: Session, admin_user_client):
# Create a test priority
priority = MockPriority(id=100, description="Test Priority")
db_session.add(priority)
db_session.commit()
response = admin_user_client.get("/priorities/100")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == 100
assert data["description"] == "Test Priority"
def test_get_priority_not_found(db_session: Session, admin_user_client):
response = admin_user_client.get("/priorities/999")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Priority not found" in response.json()["detail"]
def test_get_priority_forbidden_for_non_admin(
db_session: Session, non_admin_user_client
):
response = non_admin_user_client.get("/priorities/100")
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
# --- Test Cases For Delete Priority ---
def test_delete_all_priorities_success(db_session, admin_user_client):
"""Test reset priorities endpoint"""
# Create test data
priorities = [(100, "High"), (200, "Medium"), (300, "Low")]
for id, desc in priorities:
db_session.add(MockPriority(id=id, description=desc))
db_session.commit()
# Create channel using priority 100
create_mock_priorities_and_group(db_session, [], "Test Group")
channel_data = [
{
"group-title": "Test Group",
"tvg_id": "test.tv",
"name": "Test Channel",
"urls": ["http://test.com"],
}
]
admin_user_client.post("/channels/bulk-upload", json=channel_data)
# Delete all priorities
response = admin_user_client.delete("/priorities")
assert response.status_code == status.HTTP_200_OK
assert response.json()["deleted"] == 2 # Medium and Low priorities
assert response.json()["skipped"] == 1 # High priority is in use
# Verify only priority 100 remains
priorities = db_session.query(MockPriority).all()
assert len(priorities) == 1
assert priorities[0].id == 100
def test_reset_priorities_forbidden_for_non_admin(db_session, non_admin_user_client):
"""Test reset priorities requires admin role"""
response = non_admin_user_client.delete("/priorities")
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_delete_priority_success(db_session: Session, admin_user_client):
# Create a test priority
priority = MockPriority(id=100, description="To Delete")
db_session.add(priority)
db_session.commit()
response = admin_user_client.delete("/priorities/100")
assert response.status_code == status.HTTP_204_NO_CONTENT
# Verify priority is gone from DB
db_priority = db_session.get(MockPriority, 100)
assert db_priority is None
def test_delete_priority_not_found(db_session: Session, admin_user_client):
response = admin_user_client.delete("/priorities/999")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "Priority not found" in response.json()["detail"]
def test_delete_priority_in_use(db_session: Session, admin_user_client):
# Create a priority and a channel URL using it
priority = MockPriority(id=100, description="In Use")
group_id = uuid.uuid4()
test_group = MockGroup(
id=group_id,
name="Group With Channels",
sort_order=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db_session.add_all([priority, test_group])
db_session.commit()
# Create a channel first
channel = MockChannelDB(
name="Test Channel",
tvg_id="test.tv",
tvg_name="Test",
tvg_logo="test.png",
group_id=group_id,
)
db_session.add(channel)
db_session.commit()
# Create URL associated with the channel and priority
channel_url = MockChannelURL(
url="http://test.com",
priority_id=100,
in_use=True,
channel_id=channel.id, # Add the channel_id
)
db_session.add(channel_url)
db_session.commit()
response = admin_user_client.delete("/priorities/100")
assert response.status_code == status.HTTP_409_CONFLICT
assert "in use by channel URLs" in response.json()["detail"]
# Verify priority still exists
db_priority = db_session.get(MockPriority, 100)
assert db_priority is not None
def test_delete_priority_forbidden_for_non_admin(
db_session: Session, non_admin_user_client
):
response = non_admin_user_client.delete("/priorities/100")
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]

View File

@@ -0,0 +1,287 @@
from datetime import datetime, timezone
from unittest.mock import Mock
from fastapi import HTTPException, Request, status
from app.iptv.scheduler import StreamScheduler
from app.routers.scheduler import get_scheduler
from app.routers.scheduler import router as scheduler_router
from app.utils.database import get_db
from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import mock_get_db
# Scheduler Health Check Tests
def test_scheduler_health_success(admin_user_client, monkeypatch):
"""
Test case for successful scheduler health check when accessed by an admin user.
It mocks the scheduler to be running and have a next scheduled job.
"""
# Define the expected next run time for the scheduler job.
next_run = datetime.now(timezone.utc)
# Create a mock job object that simulates an APScheduler job.
mock_job = Mock()
mock_job.next_run_time = next_run
# Mock the `get_job` method to return our mock_job for a specific ID.
def mock_get_job(job_id):
if job_id == "daily_stream_validation":
return mock_job
return None
# Create a custom mock for `get_scheduler` dependency.
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
return await mock_get_scheduler(
request,
running=True,
get_job=Mock(side_effect=mock_get_job), # Use the custom mock_get_job
)
# Include the scheduler router in the test application.
admin_user_client.app.include_router(scheduler_router)
# Override dependencies for the test.
admin_user_client.app.dependency_overrides[get_scheduler] = (
custom_mock_get_scheduler
)
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to the scheduler health endpoint.
response = admin_user_client.get("/scheduler/health")
# Assert the response status code and content.
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["status"] == "running"
assert data["next_run"] == str(next_run)
def test_scheduler_health_stopped(admin_user_client, monkeypatch):
"""
Test case for scheduler health check when the scheduler is in a stopped state.
Ensures the API returns the correct status and no next run time.
"""
# Create a custom mock for `get_scheduler` dependency,
# simulating a stopped scheduler.
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
return await mock_get_scheduler(
request,
running=False,
)
# Include the scheduler router in the test application.
admin_user_client.app.include_router(scheduler_router)
# Override dependencies for the test.
admin_user_client.app.dependency_overrides[get_scheduler] = (
custom_mock_get_scheduler
)
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to the scheduler health endpoint.
response = admin_user_client.get("/scheduler/health")
# Assert the response status code and content.
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["status"] == "stopped"
assert data["next_run"] is None
def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
"""
Test case to ensure that non-admin users are forbidden from accessing
the scheduler health endpoint.
"""
# Create a custom mock for `get_scheduler` dependency.
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
return await mock_get_scheduler(
request,
running=False,
)
# Include the scheduler router in the test application.
non_admin_user_client.app.include_router(scheduler_router)
# Override dependencies for the test.
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
custom_mock_get_scheduler
)
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to the scheduler health endpoint.
response = non_admin_user_client.get("/scheduler/health")
# Assert the response status code and error detail.
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
def test_scheduler_health_check_exception(admin_user_client, monkeypatch):
"""
Test case for handling exceptions during the scheduler health check.
Ensures the API returns a 500 Internal Server Error when an exception occurs.
"""
# Create a custom mock for `get_scheduler` dependency that raises an exception.
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
return await mock_get_scheduler(
request, running=True, get_job=Mock(side_effect=Exception("Test exception"))
)
# Include the scheduler router in the test application.
admin_user_client.app.include_router(scheduler_router)
# Override dependencies for the test.
admin_user_client.app.dependency_overrides[get_scheduler] = (
custom_mock_get_scheduler
)
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to the scheduler health endpoint.
response = admin_user_client.get("/scheduler/health")
# Assert the response status code and error detail.
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "Failed to check scheduler health" in response.json()["detail"]
# Scheduler Trigger Tests
def test_trigger_validation_success(admin_user_client, monkeypatch):
"""
Test case for successful manual triggering
of stream validation by an admin user.
It verifies that the trigger method is called and
the API returns a 202 Accepted status.
"""
# Use a mutable reference to check if the trigger method was called.
triggered_ref = {"value": False}
# Initialize a custom mock scheduler.
custom_scheduler = MockScheduler(running=True)
custom_scheduler.get_job = Mock(return_value=None)
# Create a custom mock for `get_scheduler` dependency,
# overriding `trigger_manual_validation`.
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
scheduler = await mock_get_scheduler(
request,
running=True,
)
# Replace the actual trigger method with our mock to track calls.
scheduler.trigger_manual_validation = create_trigger_mock(
triggered_ref=triggered_ref
)
return scheduler
# Include the scheduler router in the test application.
admin_user_client.app.include_router(scheduler_router)
# Override dependencies for the test.
admin_user_client.app.dependency_overrides[get_scheduler] = (
custom_mock_get_scheduler
)
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to trigger stream validation.
response = admin_user_client.post("/scheduler/trigger")
# Assert the response status code, message, and that the trigger was called.
assert response.status_code == status.HTTP_202_ACCEPTED
assert response.json()["message"] == "Stream validation triggered"
assert triggered_ref["value"] is True
def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
"""
Test case to ensure that non-admin users are
forbidden from manually triggering stream validation.
"""
# Create a custom mock for `get_scheduler` dependency.
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
return await mock_get_scheduler(
request,
running=True,
)
# Include the scheduler router in the test application.
non_admin_user_client.app.include_router(scheduler_router)
# Override dependencies for the test.
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
custom_mock_get_scheduler
)
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to trigger stream validation.
response = non_admin_user_client.post("/scheduler/trigger")
# Assert the response status code and error detail.
assert response.status_code == status.HTTP_403_FORBIDDEN
assert "required roles" in response.json()["detail"]
def test_scheduler_initialized_in_app_state(admin_user_client):
"""
Test case for when the scheduler is initialized in the app state but its internal
scheduler attribute is not set, which should still allow health check.
"""
scheduler = StreamScheduler()
# Set the scheduler instance in the test client's app state.
admin_user_client.app.state.scheduler = scheduler
# Include the scheduler router in the test application.
admin_user_client.app.include_router(scheduler_router)
# Override only get_db, allowing the real get_scheduler to be tested.
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to the scheduler health endpoint.
response = admin_user_client.get("/scheduler/health")
# Assert the response status code.
assert response.status_code == status.HTTP_200_OK
def test_scheduler_not_initialized_in_app_state(admin_user_client):
"""
Test case for when the scheduler is not properly initialized in the app state.
This simulates a scenario where the internal scheduler attribute is missing,
leading to a 500 Internal Server Error on health check.
"""
scheduler = StreamScheduler()
del (
scheduler.scheduler
) # Simulate uninitialized scheduler by deleting the attribute
# Set the scheduler instance in the test client's app state.
admin_user_client.app.state.scheduler = scheduler
# Include the scheduler router in the test application.
admin_user_client.app.include_router(scheduler_router)
# Override only get_db, allowing the real get_scheduler to be tested.
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
# Make the request to the scheduler health endpoint.
response = admin_user_client.get("/scheduler/health")
# Assert the response status code and error detail.
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "Scheduler not initialized" in response.json()["detail"]

81
tests/test_main.py Normal file
View File

@@ -0,0 +1,81 @@
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from app.main import app, lifespan
@pytest.fixture
def client():
"""Test client for FastAPI app"""
return TestClient(app)
def test_root_endpoint(client):
"""Test root endpoint returns expected message"""
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "IPTV Manager API"}
def test_openapi_schema_generation(client):
"""Test OpenAPI schema is properly generated"""
# First call - generate schema
response = client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
assert schema["openapi"] == "3.1.0"
assert "securitySchemes" in schema["components"]
assert "Bearer" in schema["components"]["securitySchemes"]
# Test empty components initialization
with patch("app.main.get_openapi", return_value={"info": {}}):
# Clear cached schema
app.openapi_schema = None
# Get schema with empty response
response = client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
assert "components" in schema
assert "schemas" in schema["components"]
def test_openapi_schema_caching(mocker):
"""Test OpenAPI schema caching behavior"""
# Clear any existing schema
app.openapi_schema = None
# Mock get_openapi to return test schema
mock_schema = {"test": "schema"}
mocker.patch("app.main.get_openapi", return_value=mock_schema)
# First call - should call get_openapi
schema = app.openapi()
assert schema == mock_schema
assert app.openapi_schema == mock_schema
# Second call - should return cached schema
with patch("app.main.get_openapi") as mock_get_openapi:
schema = app.openapi()
assert schema == mock_schema
mock_get_openapi.assert_not_called()
@pytest.mark.asyncio
async def test_lifespan_init_db(mocker):
"""Test lifespan manager initializes database"""
mock_init_db = mocker.patch("app.main.init_db")
async with lifespan(app):
pass # Just enter/exit context
mock_init_db.assert_called_once()
def test_router_inclusion():
"""Test all routers are properly included"""
route_paths = {route.path for route in app.routes}
assert "/" in route_paths
assert any(path.startswith("/auth") for path in route_paths)
assert any(path.startswith("/channels") for path in route_paths)
assert any(path.startswith("/playlist") for path in route_paths)
assert any(path.startswith("/priorities") for path in route_paths)

0
tests/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,82 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
from app.models.auth import CognitoUser
from app.routers.channels import router as channels_router
from app.routers.groups import router as groups_router
from app.routers.playlist import router as playlist_router
from app.routers.priorities import router as priorities_router
from app.utils.database import get_db
from tests.utils.db_mocks import (
MockBase,
MockChannelDB,
MockChannelURL,
MockPriority,
engine_mock,
mock_get_db,
)
from tests.utils.db_mocks import session_mock as TestingSessionLocal
def mock_get_current_user_admin():
return CognitoUser(
username="testadmin",
email="testadmin@example.com",
roles=["admin"],
user_status="CONFIRMED",
enabled=True,
)
def mock_get_current_user_non_admin():
return CognitoUser(
username="testuser",
email="testuser@example.com",
roles=["user"], # Or any role other than admin
user_status="CONFIRMED",
enabled=True,
)
@pytest.fixture(scope="function")
def db_session():
# Create tables for each test function
MockBase.metadata.create_all(bind=engine_mock)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
# Drop tables after each test function
MockBase.metadata.drop_all(bind=engine_mock)
@pytest.fixture(scope="function")
def admin_user_client(db_session: Session):
"""Yields a TestClient configured with an admin user."""
test_app = FastAPI()
test_app.include_router(channels_router)
test_app.include_router(priorities_router)
test_app.include_router(playlist_router)
test_app.include_router(groups_router)
test_app.dependency_overrides[get_db] = mock_get_db
test_app.dependency_overrides[get_current_user] = mock_get_current_user_admin
with TestClient(test_app) as test_client:
yield test_client
@pytest.fixture(scope="function")
def non_admin_user_client(db_session: Session):
"""Yields a TestClient configured with a non-admin user."""
test_app = FastAPI()
test_app.include_router(channels_router)
test_app.include_router(priorities_router)
test_app.include_router(playlist_router)
test_app.include_router(groups_router)
test_app.dependency_overrides[get_db] = mock_get_db
test_app.dependency_overrides[get_current_user] = mock_get_current_user_non_admin
with TestClient(test_app) as test_client:
yield test_client

149
tests/utils/db_mocks.py Normal file
View File

@@ -0,0 +1,149 @@
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
UniqueConstraint,
create_engine,
)
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.pool import StaticPool
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
# Create a mock-specific Base class for testing
MockBase = declarative_base()
# Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase):
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class MockGroup(MockBase):
__tablename__ = "groups"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=False, unique=True)
sort_order = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
channels = relationship("MockChannelDB", back_populates="group")
class MockChannelDB(MockBase):
__tablename__ = "channels"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
tvg_id = Column(String, nullable=False)
name = Column(String, nullable=False)
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
tvg_name = Column(String)
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
group = relationship("MockGroup", back_populates="channels")
urls = relationship(
"MockChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
class MockChannelURL(MockBase):
__tablename__ = "channels_urls"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
channel_id = Column(
UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
channel = relationship("MockChannelDB", back_populates="urls")
def create_mock_priorities_and_group(db_session, priorities, group_name):
"""Create mock priorities and group for testing purposes.
Args:
db_session: SQLAlchemy session object
priorities: List of (id, description) tuples for priorities to create
group_name: Name for the new mock group
Returns:
UUID: The ID of the created group
"""
# Create priorities
priority_objects = [
MockPriority(id=priority_id, description=description)
for priority_id, description in priorities
]
# Create group
group = MockGroup(name=group_name)
db_session.add_all(priority_objects + [group])
db_session.commit()
return group.id
# Create test engine
engine_mock = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
# Create test session
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
# Mock the actual database functions
def mock_get_db():
db = session_mock()
try:
yield db
finally:
db.close()
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
"""Fixture for mocking environment variables"""
monkeypatch.setenv("MOCK_AUTH", "true")
monkeypatch.setenv("DB_USER", "testuser")
monkeypatch.setenv("DB_PASSWORD", "testpass")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_NAME", "testdb")
monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture
def mock_ssm():
"""Fixture for mocking boto3 SSM client"""
with patch("boto3.client") as mock_client:
mock_ssm = MagicMock()
mock_client.return_value = mock_ssm
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
yield mock_ssm

View File

@@ -0,0 +1,309 @@
import os
from unittest.mock import MagicMock, Mock, mock_open, patch
import pytest
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
from app.utils.check_streams import StreamValidator, main
@pytest.fixture
def validator():
"""Create a StreamValidator instance for testing"""
return StreamValidator(timeout=1)
def test_validator_init():
"""Test StreamValidator initialization with default and custom values"""
# Test with default user agent
validator = StreamValidator()
assert validator.timeout == 10
assert "Mozilla" in validator.session.headers["User-Agent"]
# Test with custom values
custom_agent = "CustomAgent/1.0"
validator = StreamValidator(timeout=5, user_agent=custom_agent)
assert validator.timeout == 5
assert validator.session.headers["User-Agent"] == custom_agent
def test_is_valid_content_type(validator):
"""Test content type validation"""
valid_types = [
"video/mp4",
"video/mp2t",
"application/vnd.apple.mpegurl",
"application/dash+xml",
"video/webm",
"application/octet-stream",
"application/x-mpegURL",
"video/mp4; charset=utf-8", # Test with additional parameters
]
invalid_types = [
"text/html",
"application/json",
"image/jpeg",
"",
]
for content_type in valid_types:
assert validator._is_valid_content_type(content_type)
for content_type in invalid_types:
assert not validator._is_valid_content_type(content_type)
# Test None case explicitly
assert not validator._is_valid_content_type(None)
@pytest.mark.parametrize(
"status_code,content_type,should_succeed",
[
(200, "video/mp4", True),
(206, "video/mp4", True), # Partial content
(404, "video/mp4", False),
(500, "video/mp4", False),
(200, "text/html", False),
(200, "application/json", False),
],
)
def test_validate_stream_response_handling(status_code, content_type, should_succeed):
"""Test stream validation with different response scenarios"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.headers = {"Content-Type": content_type}
mock_response.iter_content.return_value = iter([b"some content"])
mock_session = MagicMock()
mock_session.get.return_value.__enter__.return_value = mock_response
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert valid == should_succeed
mock_session.get.assert_called_once()
def test_validate_stream_connection_error():
"""Test stream validation with connection error"""
mock_session = MagicMock()
mock_session.get.side_effect = ConnectionError("Connection failed")
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert not valid
assert "Connection Error" in message
def test_validate_stream_timeout():
"""Test stream validation with timeout"""
mock_session = MagicMock()
mock_session.get.side_effect = Timeout("Request timed out")
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert not valid
assert "timeout" in message.lower()
def test_validate_stream_http_error():
"""Test stream validation with HTTP error"""
mock_session = MagicMock()
mock_session.get.side_effect = HTTPError("HTTP Error occurred")
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert not valid
assert "HTTP Error" in message
def test_validate_stream_request_exception():
"""Test stream validation with general request exception"""
mock_session = MagicMock()
mock_session.get.side_effect = RequestException("Request failed")
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert not valid
assert "Request Exception" in message
def test_validate_stream_content_read_error():
"""Test stream validation when content reading fails"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "video/mp4"}
mock_response.iter_content.side_effect = ConnectionError("Read failed")
mock_session = MagicMock()
mock_session.get.return_value.__enter__.return_value = mock_response
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert not valid
assert "Connection failed during content read" in message
def test_validate_stream_general_exception():
"""Test validate_stream with an unexpected exception"""
mock_session = MagicMock()
mock_session.get.side_effect = Exception("Unexpected error")
with patch("requests.Session", return_value=mock_session):
validator = StreamValidator()
valid, message = validator.validate_stream("http://example.com/stream")
assert not valid
assert "Validation error" in message
def test_parse_playlist(validator, tmp_path):
"""Test playlist file parsing"""
playlist_content = """
#EXTM3U
#EXTINF:-1,Channel 1
http://example.com/stream1
#EXTINF:-1,Channel 2
http://example.com/stream2
http://example.com/stream3
"""
playlist_file = tmp_path / "test_playlist.m3u"
playlist_file.write_text(playlist_content)
urls = validator.parse_playlist(str(playlist_file))
assert len(urls) == 3
assert urls == [
"http://example.com/stream1",
"http://example.com/stream2",
"http://example.com/stream3",
]
def test_parse_playlist_error(validator):
"""Test playlist parsing with non-existent file"""
with pytest.raises(Exception):
validator.parse_playlist("nonexistent_file.m3u")
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_urls(mock_validator_class, mock_logging, tmp_path, capsys):
"""Test main function with direct URLs"""
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
mock_validator.validate_stream.return_value = (True, "Stream is valid")
# Setup test arguments
test_args = ["script", "http://example.com/stream1", "http://example.com/stream2"]
with patch("sys.argv", test_args):
main()
# Verify validator was called correctly
assert mock_validator.validate_stream.call_count == 2
mock_validator.validate_stream.assert_any_call("http://example.com/stream1")
mock_validator.validate_stream.assert_any_call("http://example.com/stream2")
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_playlist(mock_validator_class, mock_logging, tmp_path):
"""Test main function with a playlist file"""
# Create test playlist
playlist_content = "http://example.com/stream1\nhttp://example.com/stream2"
playlist_file = tmp_path / "test.m3u"
playlist_file.write_text(playlist_content)
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
mock_validator.parse_playlist.return_value = [
"http://example.com/stream1",
"http://example.com/stream2",
]
mock_validator.validate_stream.return_value = (True, "Stream is valid")
# Setup test arguments
test_args = ["script", str(playlist_file)]
with patch("sys.argv", test_args):
main()
# Verify validator was called correctly
mock_validator.parse_playlist.assert_called_once_with(str(playlist_file))
assert mock_validator.validate_stream.call_count == 2
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_dead_streams(mock_validator_class, mock_logging, tmp_path):
"""Test main function handling dead streams"""
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
mock_validator.validate_stream.return_value = (False, "Stream is dead")
# Setup test arguments
test_args = ["script", "http://example.com/dead1", "http://example.com/dead2"]
# Mock file operations
mock_file = mock_open()
with patch("sys.argv", test_args), patch("builtins.open", mock_file):
main()
# Verify dead streams were written to file
mock_file().write.assert_called_once_with(
"http://example.com/dead1\nhttp://example.com/dead2"
)
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
@patch("os.path.isfile")
def test_main_with_playlist_error(
mock_isfile, mock_validator_class, mock_logging, tmp_path
):
"""Test main function handling playlist parsing errors"""
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
# Configure mock validator behavior
error_msg = "Failed to parse playlist"
mock_validator.parse_playlist.side_effect = [
Exception(error_msg), # First call fails
["http://example.com/stream1"], # Second call succeeds
]
mock_validator.validate_stream.return_value = (True, "Stream is valid")
# Configure isfile mock to return True for our test files
mock_isfile.side_effect = lambda x: x in ["/invalid.m3u", "/valid.m3u"]
# Setup test arguments
test_args = ["script", "/invalid.m3u", "/valid.m3u"]
with patch("sys.argv", test_args):
main()
# Verify error was logged correctly
mock_logging.error.assert_called_with(
"Failed to process file /invalid.m3u: Failed to parse playlist"
)
# Verify processing continued with valid playlist
mock_validator.parse_playlist.assert_called_with("/valid.m3u")
assert (
mock_validator.validate_stream.call_count == 1
) # Called for the URL from valid playlist

View File

@@ -0,0 +1,69 @@
import os
import pytest
from sqlalchemy.orm import Session
from app.utils.database import get_db, get_db_credentials
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
def test_get_db_credentials_env(mock_env):
"""Test getting DB credentials from environment variables"""
conn_str = get_db_credentials()
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
def test_get_db_credentials_ssm(mock_ssm):
"""Test getting DB credentials from SSM"""
os.environ.pop("MOCK_AUTH", None)
conn_str = get_db_credentials()
expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
assert expected_conn in conn_str
mock_ssm.get_parameter.assert_called()
def test_get_db_credentials_ssm_exception(mock_ssm):
"""Test SSM credential fetching failure raises RuntimeError"""
os.environ.pop("MOCK_AUTH", None)
mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
with pytest.raises(RuntimeError) as excinfo:
get_db_credentials()
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
def test_session_creation():
"""Test database session creation"""
session = session_mock()
assert isinstance(session, Session)
session.close()
def test_get_db_generator():
"""Test get_db dependency generator"""
db_gen = get_db()
db = next(db_gen)
assert isinstance(db, Session)
try:
next(db_gen) # Should raise StopIteration
except StopIteration:
pass
def test_init_db(mocker, mock_env):
"""Test database initialization creates tables"""
mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
# Mock get_db_credentials to return SQLite test connection
mocker.patch(
"app.utils.database.get_db_credentials",
return_value="sqlite:///:memory:",
)
from app.utils.database import engine, init_db
init_db()
# Verify create_all was called with the engine
mock_create_all.assert_called_once_with(bind=engine)