Compare commits
74 Commits
35745c43bd
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| a42d4c30a6 | |||
| abb467749b | |||
| b8ac25e301 | |||
| 729eabf27f | |||
| 34c446bcfa | |||
| d4cc74ea8c | |||
| 21b73b6843 | |||
| e743daf9f7 | |||
| b0d98551b8 | |||
| eaab1ef998 | |||
| e25f8c1ecd | |||
| 95bf0f9701 | |||
| f7a1c20066 | |||
| bf6f156fec | |||
| 7e25ec6755 | |||
| 6d506122d9 | |||
| 02913c7385 | |||
| e46f13930d | |||
| 903f190ee2 | |||
| 32af6bbdb5 | |||
| 7ee7a0e644 | |||
| 9474a3ca44 | |||
| 1ab8599dde | |||
| fb5215b92a | |||
| cebbb9c1a8 | |||
| 4b1a7e9bea | |||
| 21cc99eff6 | |||
| 76dc8908de | |||
| 07dab76e3b | |||
| c21b34f5fe | |||
| b1942354d9 | |||
| 3937269bb9 | |||
| 8c7ed421c9 | |||
| c96ee307db | |||
| f11d533fac | |||
| 99d26a8f53 | |||
| 260fcb311b | |||
| 1e82418cad | |||
| 9c690fe6a6 | |||
| 9e8df169fc | |||
| 5ee6cb4be4 | |||
| c1e3a6ef26 | |||
| cb793ef5e1 | |||
| be719a6e34 | |||
| 5767124031 | |||
| c6f7e9cb2b | |||
| eeb0f1c844 | |||
| 4cb3811d17 | |||
| 489281f3eb | |||
| b947ac67f0 | |||
| dd2446a01a | |||
| 639adba7eb | |||
| 5698e7f26b | |||
| df3fc2f37c | |||
| 594ce0c67a | |||
| 37be1f3f91 | |||
| 732667cf64 | |||
| 5bc7a72a92 | |||
| a5dfc1b493 | |||
| 0b69ffd67c | |||
| 127d81adac | |||
| 658f7998ef | |||
| c4f19999dc | |||
| c221a8cded | |||
| 8d1997fa5a | |||
| 795a25961f | |||
| d55c383bc4 | |||
| 5c17e4b1e9 | |||
| 30ccf86c86 | |||
| ae040fc49e | |||
| 47befceb17 | |||
| 7f282049ac | |||
| 38e5a94701 | |||
| 7b7ff78030 |
19
.env.example
Normal file
19
.env.example
Normal file
@@ -0,0 +1,19 @@
|
||||
|
||||
# Environment variables
|
||||
# Scheduler configuration
|
||||
STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax)
|
||||
STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all)
|
||||
|
||||
# For use with Docker Compose to run application locally
|
||||
MOCK_AUTH=true/false
|
||||
DB_USER=MyDBUser
|
||||
DB_PASSWORD=MyDBPassword
|
||||
DB_HOST=MyDBHost
|
||||
DB_NAME=iptv_manager
|
||||
|
||||
FREEDNS_User=MyFreeDNSUsername
|
||||
FREEDNS_Password=MyFreeDNSPassword
|
||||
DOMAIN_NAME=mydomain.com
|
||||
SSH_PUBLIC_KEY="ssh-rsa AAAAB3NzaC1yc2EMYPUBLICKEY7+"
|
||||
REPO_URL="https://git.example.com/user/repo.git"
|
||||
LETSENCRYPT_EMAIL="admin@example.com"
|
||||
@@ -39,6 +39,13 @@ jobs:
|
||||
|
||||
- name: Deploy to AWS
|
||||
run: cdk deploy --app="python3 ${PWD}/app.py" --require-approval=never
|
||||
env:
|
||||
FREEDNS_User: ${{ secrets.FREEDNS_USER }}
|
||||
FREEDNS_Password: ${{ secrets.FREEDNS_PASSWORD }}
|
||||
DOMAIN_NAME: ${{ secrets.DOMAIN_NAME }}
|
||||
SSH_PUBLIC_KEY: ${{ secrets.SSH_PUBLIC_KEY }}
|
||||
REPO_URL: ${{ secrets.REPO_URL }}
|
||||
LETSENCRYPT_EMAIL: ${{ secrets.LETSENCRYPT_EMAIL }}
|
||||
|
||||
- name: Install AWS CLI
|
||||
run: |
|
||||
@@ -50,20 +57,23 @@ jobs:
|
||||
- name: Update application on instance
|
||||
run: |
|
||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||
--filters "Name=tag:Name,Values=IptvUpdater/IptvUpdaterInstance" \
|
||||
--region us-east-2 \
|
||||
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||
"Name=instance-state-name,Values=running" \
|
||||
--query "Reservations[].Instances[].InstanceId" \
|
||||
--output text)
|
||||
|
||||
for INSTANCE_ID in $INSTANCE_IDS; do
|
||||
aws ssm send-command \
|
||||
--region us-east-2 \
|
||||
--instance-ids "$INSTANCE_ID" \
|
||||
--document-name "AWS-RunShellScript" \
|
||||
--parameters 'commands=[
|
||||
"cd /home/ec2-user/iptv-updater-aws",
|
||||
"cd /home/ec2-user/iptv-manager-service",
|
||||
"git pull",
|
||||
"pip3 install -r requirements.txt",
|
||||
"sudo systemctl restart iptv-updater"
|
||||
"alembic upgrade head",
|
||||
"sudo systemctl restart iptv-manager"
|
||||
]'
|
||||
done
|
||||
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -4,10 +4,16 @@ __pycache__
|
||||
.pytest_cache
|
||||
.env
|
||||
.venv
|
||||
*.pid
|
||||
*.log
|
||||
*.egg-info
|
||||
.coverage
|
||||
.roomodes
|
||||
cdk.out/
|
||||
node_modules/
|
||||
data/
|
||||
.roo/
|
||||
.ruru/
|
||||
|
||||
# CDK asset staging directory
|
||||
.cdk.staging
|
||||
|
||||
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.12
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest-check
|
||||
name: pytest-check
|
||||
entry: pytest
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
85
.vscode/settings.json
vendored
85
.vscode/settings.json
vendored
@@ -1,22 +1,105 @@
|
||||
{
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||
"editor.formatOnSave": true,
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"ruff.importStrategy": "fromEnvironment",
|
||||
"ruff.path": ["${workspaceFolder}"],
|
||||
"cSpell.words": [
|
||||
"addopts",
|
||||
"adminpassword",
|
||||
"altinstall",
|
||||
"apscheduler",
|
||||
"asyncio",
|
||||
"autoflush",
|
||||
"autoupdate",
|
||||
"autouse",
|
||||
"awscli",
|
||||
"awscliv",
|
||||
"boto",
|
||||
"botocore",
|
||||
"BURSTABLE",
|
||||
"cabletv",
|
||||
"capsys",
|
||||
"CDUF",
|
||||
"cduflogo",
|
||||
"cdulogo",
|
||||
"CDUNF",
|
||||
"cdunflogo",
|
||||
"certbot",
|
||||
"certifi",
|
||||
"cfulogo",
|
||||
"CLEU",
|
||||
"cleulogo",
|
||||
"CLUF",
|
||||
"cluflogo",
|
||||
"clulogo",
|
||||
"cpulogo",
|
||||
"crond",
|
||||
"cronie",
|
||||
"cuflgo",
|
||||
"CUNF",
|
||||
"cunflogo",
|
||||
"cuulogo",
|
||||
"datname",
|
||||
"deadstreams",
|
||||
"delenv",
|
||||
"delogo",
|
||||
"devel",
|
||||
"dflogo",
|
||||
"dmlogo",
|
||||
"dotenv",
|
||||
"EXTINF",
|
||||
"EXTM",
|
||||
"fastapi",
|
||||
"filterwarnings",
|
||||
"fiorinis",
|
||||
"freedns",
|
||||
"fullchain",
|
||||
"gitea",
|
||||
"httpx",
|
||||
"iptv",
|
||||
"isort",
|
||||
"KHTML",
|
||||
"lclogo",
|
||||
"LETSENCRYPT",
|
||||
"levelname",
|
||||
"mpegurl",
|
||||
"nohup",
|
||||
"nopriority",
|
||||
"ondelete",
|
||||
"onupdate",
|
||||
"passlib",
|
||||
"PGPASSWORD",
|
||||
"poolclass",
|
||||
"psql",
|
||||
"psycopg",
|
||||
"pycache",
|
||||
"pycodestyle",
|
||||
"pyflakes",
|
||||
"pyjwt",
|
||||
"pytest",
|
||||
"PYTHONDONTWRITEBYTECODE",
|
||||
"PYTHONUNBUFFERED",
|
||||
"pyupgrade",
|
||||
"reloadcmd",
|
||||
"roomodes",
|
||||
"ruru",
|
||||
"sessionmaker",
|
||||
"sqlalchemy",
|
||||
"sqliteuuid",
|
||||
"starlette",
|
||||
"stefano",
|
||||
"testadmin",
|
||||
"testdb",
|
||||
"testpass",
|
||||
"testpaths",
|
||||
"testuser",
|
||||
"uflogo",
|
||||
"umlogo",
|
||||
"usefixtures",
|
||||
"uvicorn",
|
||||
"venv"
|
||||
"venv",
|
||||
"wrongpass"
|
||||
]
|
||||
}
|
||||
150
README.md
150
README.md
@@ -1 +1,149 @@
|
||||
# To do
|
||||
# IPTV Manager Service
|
||||
|
||||
A FastAPI-based service for managing IPTV playlists and channel priorities. The application provides secure endpoints for user authentication, channel management, and playlist generation.
|
||||
|
||||
## ✨ Features
|
||||
|
||||
- **JWT Authentication**: Secure login using AWS Cognito
|
||||
- **Channel Management**: CRUD operations for IPTV channels
|
||||
- **Playlist Generation**: Create M3U playlists with channel priorities
|
||||
- **Stream Monitoring**: Background checks for channel availability
|
||||
- **Priority Management**: Set channel priorities for playlist ordering
|
||||
|
||||
## 🛠️ Technology Stack
|
||||
|
||||
- **Backend**: Python 3.11, FastAPI
|
||||
- **Database**: PostgreSQL (SQLAlchemy ORM)
|
||||
- **Authentication**: AWS Cognito
|
||||
- **Infrastructure**: AWS CDK (API Gateway, Lambda, RDS)
|
||||
- **Testing**: Pytest with 85%+ coverage
|
||||
- **CI/CD**: Pre-commit hooks, Alembic migrations
|
||||
|
||||
## 🚀 Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.11+
|
||||
- Docker
|
||||
- AWS CLI (for deployment)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/your-repo/iptv-manager-service.git
|
||||
cd iptv-manager-service
|
||||
|
||||
# Setup environment
|
||||
python -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
cp .env.example .env # Update with your values
|
||||
|
||||
# Run installation script
|
||||
./scripts/install.sh
|
||||
```
|
||||
|
||||
### Running Locally
|
||||
|
||||
```bash
|
||||
# Start development environment
|
||||
./scripts/start_local_dev.sh
|
||||
|
||||
# Stop development environment
|
||||
./scripts/stop_local_dev.sh
|
||||
```
|
||||
|
||||
## ☁️ AWS Deployment
|
||||
|
||||
The infrastructure is defined in CDK. Use the provided scripts:
|
||||
|
||||
```bash
|
||||
# Deploy AWS infrastructure
|
||||
./scripts/deploy.sh
|
||||
|
||||
# Destroy AWS infrastructure
|
||||
./scripts/destroy.sh
|
||||
|
||||
# Create Cognito test user
|
||||
./scripts/create_cognito_user.sh
|
||||
|
||||
# Delete Cognito user
|
||||
./scripts/delete_cognito_user.sh
|
||||
```
|
||||
|
||||
Key AWS components:
|
||||
|
||||
- API Gateway
|
||||
- Lambda functions
|
||||
- RDS PostgreSQL
|
||||
- Cognito User Pool
|
||||
|
||||
## 🤖 Continuous Integration/Deployment
|
||||
|
||||
This project includes a Gitea Actions workflow (`.gitea/workflows/deploy.yml`) for automated deployment to AWS. The workflow is fully compatible with GitHub Actions and can be easily adapted by:
|
||||
|
||||
1. Placing the workflow file in the `.github/workflows/` directory
|
||||
2. Setting up the required secrets in your CI/CD environment:
|
||||
- `AWS_ACCESS_KEY_ID`
|
||||
- `AWS_SECRET_ACCESS_KEY`
|
||||
- `AWS_DEFAULT_REGION`
|
||||
|
||||
The workflow automatically deploys the infrastructure and application when changes are pushed to the main branch.
|
||||
|
||||
## 📚 API Documentation
|
||||
|
||||
Access interactive docs at:
|
||||
|
||||
- Swagger UI: `http://localhost:8000/docs`
|
||||
- ReDoc: `http://localhost:8000/redoc`
|
||||
|
||||
### Key Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
| ------------- | ------ | --------------------- |
|
||||
| `/auth/login` | POST | User authentication |
|
||||
| `/channels` | GET | List all channels |
|
||||
| `/playlist` | GET | Generate M3U playlist |
|
||||
| `/priorities` | POST | Set channel priority |
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
Run the full test suite:
|
||||
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
Test coverage includes:
|
||||
|
||||
- Authentication workflows
|
||||
- Channel CRUD operations
|
||||
- Playlist generation logic
|
||||
- Stream monitoring
|
||||
- Database operations
|
||||
|
||||
## 📂 Project Structure
|
||||
|
||||
```txt
|
||||
iptv-manager-service/
|
||||
├── app/ # Core application
|
||||
│ ├── auth/ # Cognito authentication
|
||||
│ ├── iptv/ # Playlist logic
|
||||
│ ├── models/ # Database models
|
||||
│ ├── routers/ # API endpoints
|
||||
│ ├── utils/ # Helper functions
|
||||
│ └── main.py # App entry point
|
||||
├── infrastructure/ # AWS CDK stack
|
||||
├── docker/ # Docker configs
|
||||
├── scripts/ # Deployment scripts
|
||||
├── tests/ # Comprehensive tests
|
||||
├── alembic/ # Database migrations
|
||||
├── .gitea/ # Gitea CI/CD workflows
|
||||
│ └── workflows/
|
||||
└── ... # Config files
|
||||
```
|
||||
|
||||
## 📝 License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
141
alembic.ini
Normal file
141
alembic.ini
Normal file
@@ -0,0 +1,141 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
alembic/README
Normal file
1
alembic/README
Normal file
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
79
alembic/env.py
Normal file
79
alembic/env.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
from app.models.db import Base
|
||||
from app.utils.database import get_db_credentials
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# Override sqlalchemy.url with dynamic credentials
|
||||
if not context.is_offline_mode():
|
||||
config.set_main_option("sqlalchemy.url", get_db_credentials())
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,110 @@
|
||||
"""add groups table and migrate group_title data
|
||||
|
||||
Revision ID: 0a455608256f
|
||||
Revises: 95b61a92455a
|
||||
Create Date: 2025-06-10 09:22:11.820035
|
||||
|
||||
"""
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0a455608256f'
|
||||
down_revision: Union[str, None] = '95b61a92455a'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('groups',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
|
||||
# Create temporary table for group mapping
|
||||
group_mapping = op.create_table(
|
||||
'group_mapping',
|
||||
sa.Column('group_title', sa.String(), nullable=False),
|
||||
sa.Column('group_id', sa.UUID(), nullable=False)
|
||||
)
|
||||
|
||||
# Get existing group titles and create groups
|
||||
conn = op.get_bind()
|
||||
distinct_groups = conn.execute(
|
||||
sa.text("SELECT DISTINCT group_title FROM channels")
|
||||
).fetchall()
|
||||
|
||||
for group in distinct_groups:
|
||||
group_title = group[0]
|
||||
group_id = str(uuid.uuid4())
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"INSERT INTO groups (id, name, sort_order) "
|
||||
"VALUES (:id, :name, 0)"
|
||||
).bindparams(id=group_id, name=group_title)
|
||||
)
|
||||
conn.execute(
|
||||
group_mapping.insert().values(
|
||||
group_title=group_title,
|
||||
group_id=group_id
|
||||
)
|
||||
)
|
||||
|
||||
# Add group_id column (nullable first)
|
||||
op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True))
|
||||
|
||||
# Update channels with group_ids
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE channels c SET group_id = gm.group_id "
|
||||
"FROM group_mapping gm WHERE c.group_title = gm.group_title"
|
||||
)
|
||||
)
|
||||
|
||||
# Now make group_id non-nullable and add constraints
|
||||
op.alter_column('channels', 'group_id', nullable=False)
|
||||
op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique')
|
||||
op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name'])
|
||||
op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id'])
|
||||
|
||||
# Clean up and drop group_title
|
||||
op.drop_table('group_mapping')
|
||||
op.drop_column('channels', 'group_title')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
|
||||
# Restore group_title values from groups table
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE channels c SET group_title = g.name "
|
||||
"FROM groups g WHERE c.group_id = g.id"
|
||||
)
|
||||
)
|
||||
|
||||
# Now make group_title non-nullable
|
||||
op.alter_column('channels', 'group_title', nullable=False)
|
||||
|
||||
# Drop constraints and columns
|
||||
op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey')
|
||||
op.drop_constraint('uix_group_id_name', 'channels', type_='unique')
|
||||
op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name'])
|
||||
op.drop_column('channels', 'group_id')
|
||||
op.drop_table('groups')
|
||||
# ### end Alembic commands ###
|
||||
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""create initial tables
|
||||
|
||||
Revision ID: 95b61a92455a
|
||||
Revises:
|
||||
Create Date: 2025-05-29 14:42:16.239587
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '95b61a92455a'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('channels',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('tvg_id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('group_title', sa.String(), nullable=False),
|
||||
sa.Column('tvg_name', sa.String(), nullable=True),
|
||||
sa.Column('tvg_logo', sa.String(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('group_title', 'name', name='uix_group_title_name')
|
||||
)
|
||||
op.create_table('priorities',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('channels_urls',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('channel_id', sa.UUID(), nullable=False),
|
||||
sa.Column('url', sa.String(), nullable=False),
|
||||
sa.Column('in_use', sa.Boolean(), nullable=False),
|
||||
sa.Column('priority_id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['channel_id'], ['channels.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['priority_id'], ['priorities.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
# Seed initial priorities
|
||||
op.bulk_insert(
|
||||
sa.Table(
|
||||
'priorities',
|
||||
sa.MetaData(),
|
||||
sa.Column('id', sa.Integer),
|
||||
sa.Column('description', sa.String),
|
||||
),
|
||||
[
|
||||
{'id': 100, 'description': 'High'},
|
||||
{'id': 200, 'description': 'Medium'},
|
||||
{'id': 300, 'description': 'Low'},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# Remove seeded priorities
|
||||
op.execute("DELETE FROM priorities WHERE id IN (100, 200, 300);")
|
||||
|
||||
# Drop tables
|
||||
op.drop_table('channels_urls')
|
||||
op.drop_table('priorities')
|
||||
op.drop_table('channels')
|
||||
42
app.py
42
app.py
@@ -1,7 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
|
||||
import aws_cdk as cdk
|
||||
from infrastructure.stack import IptvUpdaterStack
|
||||
|
||||
from infrastructure.stack import IptvManagerStack
|
||||
|
||||
app = cdk.App()
|
||||
IptvUpdaterStack(app, "IptvUpdater")
|
||||
|
||||
# Read environment variables for FreeDNS credentials
|
||||
freedns_user = os.environ.get("FREEDNS_User")
|
||||
freedns_password = os.environ.get("FREEDNS_Password")
|
||||
domain_name = os.environ.get("DOMAIN_NAME")
|
||||
ssh_public_key = os.environ.get("SSH_PUBLIC_KEY")
|
||||
repo_url = os.environ.get("REPO_URL")
|
||||
letsencrypt_email = os.environ.get("LETSENCRYPT_EMAIL")
|
||||
|
||||
required_vars = {
|
||||
"FREEDNS_User": freedns_user,
|
||||
"FREEDNS_Password": freedns_password,
|
||||
"DOMAIN_NAME": domain_name,
|
||||
"SSH_PUBLIC_KEY": ssh_public_key,
|
||||
"REPO_URL": repo_url,
|
||||
"LETSENCRYPT_EMAIL": letsencrypt_email,
|
||||
}
|
||||
|
||||
# Check for missing required variables
|
||||
missing_vars = [k for k, v in required_vars.items() if not v]
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Missing required environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
IptvManagerStack(
|
||||
app,
|
||||
"IptvManagerStack",
|
||||
freedns_user=freedns_user,
|
||||
freedns_password=freedns_password,
|
||||
domain_name=domain_name,
|
||||
ssh_public_key=ssh_public_key,
|
||||
repo_url=repo_url,
|
||||
letsencrypt_email=letsencrypt_email,
|
||||
)
|
||||
|
||||
app.synth()
|
||||
82
app/auth/cognito.py
Normal file
82
app/auth/cognito.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import boto3
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.auth import calculate_secret_hash
|
||||
from app.utils.constants import (
|
||||
AWS_REGION,
|
||||
COGNITO_CLIENT_ID,
|
||||
COGNITO_CLIENT_SECRET,
|
||||
USER_ROLE_ATTRIBUTE,
|
||||
)
|
||||
|
||||
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
|
||||
|
||||
|
||||
def initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
|
||||
"""
|
||||
auth_params = {"USERNAME": username, "PASSWORD": password}
|
||||
|
||||
# If a client secret is required, add SECRET_HASH
|
||||
if COGNITO_CLIENT_SECRET:
|
||||
auth_params["SECRET_HASH"] = calculate_secret_hash(
|
||||
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
|
||||
)
|
||||
|
||||
try:
|
||||
response = cognito_client.initiate_auth(
|
||||
AuthFlow="USER_PASSWORD_AUTH",
|
||||
AuthParameters=auth_params,
|
||||
ClientId=COGNITO_CLIENT_ID,
|
||||
)
|
||||
return response["AuthenticationResult"]
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"An error occurred during authentication: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
def get_user_from_token(access_token: str) -> CognitoUser:
|
||||
"""
|
||||
Verify the token by calling GetUser in Cognito and
|
||||
retrieve user attributes including roles.
|
||||
"""
|
||||
try:
|
||||
user_response = cognito_client.get_user(AccessToken=access_token)
|
||||
username = user_response.get("Username", "")
|
||||
attributes = user_response.get("UserAttributes", [])
|
||||
user_roles = []
|
||||
|
||||
for attr in attributes:
|
||||
if attr["Name"] == USER_ROLE_ATTRIBUTE:
|
||||
# Assume roles are stored as a comma-separated string
|
||||
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
|
||||
break
|
||||
|
||||
return CognitoUser(username=username, roles=user_roles)
|
||||
except cognito_client.exceptions.NotAuthorizedException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
|
||||
)
|
||||
except cognito_client.exceptions.UserNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or invalid token.",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token verification failed: {str(e)}",
|
||||
)
|
||||
51
app/auth/dependencies.py
Normal file
51
app/auth/dependencies.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
# Use mock auth for local testing if MOCK_AUTH is set
|
||||
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
from app.auth.mock_auth import mock_get_user_from_token as get_user_from_token
|
||||
else:
|
||||
from app.auth.cognito import get_user_from_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
|
||||
"""
|
||||
Dependency to get the current user from the given token.
|
||||
This will verify the token with Cognito and return the user's information.
|
||||
"""
|
||||
return get_user_from_token(token)
|
||||
|
||||
|
||||
def require_roles(*required_roles: str) -> Callable:
|
||||
"""
|
||||
Decorator for role-based access control.
|
||||
Use on endpoints to enforce that the user possesses all required roles.
|
||||
"""
|
||||
|
||||
def decorator(endpoint: Callable) -> Callable:
|
||||
@wraps(endpoint)
|
||||
async def wrapper(
|
||||
*args, user: CognitoUser = Depends(get_current_user), **kwargs
|
||||
):
|
||||
user_roles = set(user.roles or [])
|
||||
needed_roles = set(required_roles)
|
||||
if not needed_roles.issubset(user_roles):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
"You do not have the required roles to access this endpoint."
|
||||
),
|
||||
)
|
||||
return endpoint(*args, user=user, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
26
app/auth/mock_auth.py
Normal file
26
app/auth/mock_auth.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
|
||||
|
||||
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
"""
|
||||
Mock version of get_user_from_token for local testing
|
||||
Accepts 'testuser' as a valid token and returns admin user
|
||||
"""
|
||||
if token == "testuser":
|
||||
return CognitoUser(**MOCK_USERS["testuser"])
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid mock token - use 'testuser'",
|
||||
)
|
||||
|
||||
|
||||
def mock_initiate_auth(username: str, password: str) -> dict:
|
||||
"""
|
||||
Mock version of initiate_auth for local testing
|
||||
Accepts any username/password and returns a mock token
|
||||
"""
|
||||
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
|
||||
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
import boto3
|
||||
import requests
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from fastapi.responses import RedirectResponse
|
||||
from typing import Optional
|
||||
|
||||
REGION = "us-east-2"
|
||||
USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID")
|
||||
CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
|
||||
DOMAIN = f"https://iptv-updater.auth.{REGION}.amazoncognito.com"
|
||||
REDIRECT_URI = "http://localhost:8000/auth/callback"
|
||||
|
||||
oauth2_scheme = OAuth2AuthorizationCodeBearer(
|
||||
authorizationUrl=f"{DOMAIN}/oauth2/authorize",
|
||||
tokenUrl=f"{DOMAIN}/oauth2/token"
|
||||
)
|
||||
|
||||
def exchange_code_for_token(code: str):
|
||||
token_url = f"{DOMAIN}/oauth2/token"
|
||||
data = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': CLIENT_ID,
|
||||
'code': code,
|
||||
'redirect_uri': REDIRECT_URI
|
||||
}
|
||||
|
||||
response = requests.post(token_url, data=data)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
raise HTTPException(status_code=400, detail="Failed to exchange code for token")
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
if not token:
|
||||
return RedirectResponse(
|
||||
f"{DOMAIN}/login?client_id={CLIENT_ID}"
|
||||
f"&response_type=code"
|
||||
f"&scope=openid"
|
||||
f"&redirect_uri={REDIRECT_URI}"
|
||||
)
|
||||
|
||||
try:
|
||||
cognito = boto3.client('cognito-idp', region_name=REGION)
|
||||
response = cognito.get_user(AccessToken=token)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
@@ -1,39 +1,59 @@
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import requests
|
||||
import argparse
|
||||
from utils.config import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
||||
from utils.constants import (
|
||||
IPTV_SERVER_ADMIN_PASSWORD,
|
||||
IPTV_SERVER_ADMIN_USER,
|
||||
IPTV_SERVER_URL,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='EPG Grabber')
|
||||
parser.add_argument('--playlist',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
||||
help='Path to playlist file')
|
||||
parser.add_argument('--output',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'),
|
||||
help='Path to output EPG XML file')
|
||||
parser.add_argument('--epg-sources',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'),
|
||||
help='Path to EPG sources JSON configuration file')
|
||||
parser.add_argument('--save-as-gz',
|
||||
action='store_true',
|
||||
parser = argparse.ArgumentParser(description="EPG Grabber")
|
||||
parser.add_argument(
|
||||
"--playlist",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||
),
|
||||
help="Path to playlist file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
|
||||
help="Path to output EPG XML file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epg-sources",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
|
||||
),
|
||||
help="Path to EPG sources JSON configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-as-gz",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help='Save an additional gzipped version of the EPG file')
|
||||
help="Save an additional gzipped version of the EPG file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_epg_sources(config_path):
|
||||
"""Load EPG sources from JSON configuration file"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
return config.get('epg_sources', [])
|
||||
return config.get("epg_sources", [])
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Error loading EPG sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_tvg_ids(playlist_path):
|
||||
"""
|
||||
Extracts unique tvg-id values from an M3U playlist file.
|
||||
@@ -51,9 +71,9 @@ def get_tvg_ids(playlist_path):
|
||||
# and ends with a double quote.
|
||||
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
|
||||
|
||||
with open(playlist_path, 'r', encoding='utf-8') as file:
|
||||
with open(playlist_path, encoding="utf-8") as file:
|
||||
for line in file:
|
||||
if line.startswith('#EXTINF'):
|
||||
if line.startswith("#EXTINF"):
|
||||
# Search for the tvg-id pattern in the line
|
||||
match = tvg_id_pattern.search(line)
|
||||
if match:
|
||||
@@ -64,13 +84,14 @@ def get_tvg_ids(playlist_path):
|
||||
|
||||
return list(unique_tvg_ids)
|
||||
|
||||
|
||||
def fetch_and_extract_xml(url):
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
print(f"Failed to fetch {url}")
|
||||
return None
|
||||
|
||||
if url.endswith('.gz'):
|
||||
if url.endswith(".gz"):
|
||||
try:
|
||||
decompressed_data = gzip.decompress(response.content)
|
||||
return ET.fromstring(decompressed_data)
|
||||
@@ -84,42 +105,46 @@ def fetch_and_extract_xml(url):
|
||||
print(f"Failed to parse XML from {url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
|
||||
root = ET.Element('tv')
|
||||
root = ET.Element("tv")
|
||||
|
||||
for url in urls:
|
||||
epg_data = fetch_and_extract_xml(url)
|
||||
if epg_data is None:
|
||||
continue
|
||||
|
||||
for channel in epg_data.findall('channel'):
|
||||
tvg_id = channel.get('id')
|
||||
for channel in epg_data.findall("channel"):
|
||||
tvg_id = channel.get("id")
|
||||
if tvg_id in tvg_ids:
|
||||
root.append(channel)
|
||||
|
||||
for programme in epg_data.findall('programme'):
|
||||
tvg_id = programme.get('channel')
|
||||
for programme in epg_data.findall("programme"):
|
||||
tvg_id = programme.get("channel")
|
||||
if tvg_id in tvg_ids:
|
||||
root.append(programme)
|
||||
|
||||
tree = ET.ElementTree(root)
|
||||
tree.write(output_file, encoding='utf-8', xml_declaration=True)
|
||||
tree.write(output_file, encoding="utf-8", xml_declaration=True)
|
||||
print(f"New EPG saved to {output_file}")
|
||||
|
||||
if save_as_gz:
|
||||
output_file_gz = output_file + '.gz'
|
||||
with gzip.open(output_file_gz, 'wb') as f:
|
||||
tree.write(f, encoding='utf-8', xml_declaration=True)
|
||||
output_file_gz = output_file + ".gz"
|
||||
with gzip.open(output_file_gz, "wb") as f:
|
||||
tree.write(f, encoding="utf-8", xml_declaration=True)
|
||||
print(f"New EPG saved to {output_file_gz}")
|
||||
|
||||
|
||||
def upload_epg(file_path):
|
||||
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
response = requests.post(
|
||||
IPTV_SERVER_URL + '/admin/epg',
|
||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
||||
files={'file': (os.path.basename(file_path), f)}
|
||||
IPTV_SERVER_URL + "/admin/epg",
|
||||
auth=requests.auth.HTTPBasicAuth(
|
||||
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||
),
|
||||
files={"file": (os.path.basename(file_path), f)},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -129,6 +154,7 @@ def upload_epg(file_path):
|
||||
except Exception as e:
|
||||
print(f"Upload error: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
playlist_file = args.playlist
|
||||
@@ -144,4 +170,4 @@ if __name__ == "__main__":
|
||||
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
|
||||
|
||||
if args.save_as_gz:
|
||||
upload_epg(output_file + '.gz')
|
||||
upload_epg(output_file + ".gz")
|
||||
@@ -1,26 +1,45 @@
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
from pathlib import Path
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from utils.check_streams import StreamValidator
|
||||
from utils.config import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
||||
from utils.constants import (
|
||||
EPG_URL,
|
||||
IPTV_SERVER_ADMIN_PASSWORD,
|
||||
IPTV_SERVER_ADMIN_USER,
|
||||
IPTV_SERVER_URL,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='IPTV playlist generator')
|
||||
parser.add_argument('--output',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
||||
help='Path to output playlist file')
|
||||
parser.add_argument('--channels',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'),
|
||||
help='Path to channels definition JSON file')
|
||||
parser.add_argument('--dead-channels-log',
|
||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'),
|
||||
help='Path to log file to store a list of dead channels')
|
||||
parser = argparse.ArgumentParser(description="IPTV playlist generator")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||
),
|
||||
help="Path to output playlist file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "channels.json"
|
||||
),
|
||||
help="Path to channels definition JSON file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dead-channels-log",
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
|
||||
),
|
||||
help="Path to log file to store a list of dead channels",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def find_working_stream(validator, urls):
|
||||
"""Test all URLs and return the first working one"""
|
||||
for url in urls:
|
||||
@@ -29,9 +48,10 @@ def find_working_stream(validator, urls):
|
||||
return url
|
||||
return None
|
||||
|
||||
|
||||
def create_playlist(channels_file, output_file):
|
||||
# Read channels from JSON file
|
||||
with open(channels_file, 'r', encoding='utf-8') as f:
|
||||
with open(channels_file, encoding="utf-8") as f:
|
||||
channels = json.load(f)
|
||||
|
||||
# Initialize validator
|
||||
@@ -41,9 +61,9 @@ def create_playlist(channels_file, output_file):
|
||||
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
|
||||
|
||||
for channel in channels:
|
||||
if 'urls' in channel: # Check if channel has URLs
|
||||
if "urls" in channel: # Check if channel has URLs
|
||||
# Find first working stream
|
||||
working_url = find_working_stream(validator, channel['urls'])
|
||||
working_url = find_working_stream(validator, channel["urls"])
|
||||
|
||||
if working_url:
|
||||
# Add channel to playlist
|
||||
@@ -51,24 +71,30 @@ def create_playlist(channels_file, output_file):
|
||||
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
|
||||
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
|
||||
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
|
||||
m3u8_content += f'{channel.get("name", "")}\n'
|
||||
m3u8_content += f'{working_url}\n'
|
||||
m3u8_content += f"{channel.get('name', '')}\n"
|
||||
m3u8_content += f"{working_url}\n"
|
||||
else:
|
||||
# Log dead channel
|
||||
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found')
|
||||
logging.info(
|
||||
f"Dead channel: {channel.get('name', 'Unknown')} - "
|
||||
"No working streams found"
|
||||
)
|
||||
|
||||
# Write playlist file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(m3u8_content)
|
||||
|
||||
|
||||
def upload_playlist(file_path):
|
||||
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
response = requests.post(
|
||||
IPTV_SERVER_URL + '/admin/playlist',
|
||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
||||
files={'file': (os.path.basename(file_path), f)}
|
||||
IPTV_SERVER_URL + "/admin/playlist",
|
||||
auth=requests.auth.HTTPBasicAuth(
|
||||
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||
),
|
||||
files={"file": (os.path.basename(file_path), f)},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
|
||||
except Exception as e:
|
||||
print(f"Upload error: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
channels_file = args.channels
|
||||
@@ -85,15 +112,15 @@ def main():
|
||||
dead_channels_log_file = args.dead_channels_log
|
||||
|
||||
# Clear previous log file
|
||||
with open(dead_channels_log_file, 'w') as f:
|
||||
f.write(f'Log created on {datetime.now()}\n')
|
||||
with open(dead_channels_log_file, "w") as f:
|
||||
f.write(f"Log created on {datetime.now()}\n")
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
filename=dead_channels_log_file,
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
format="%(asctime)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Create playlist
|
||||
@@ -104,5 +131,6 @@ def main():
|
||||
|
||||
print("Playlist creation completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
110
app/iptv/scheduler.py
Normal file
110
app/iptv/scheduler.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.iptv.stream_manager import StreamManager
|
||||
from app.models.db import ChannelDB
|
||||
from app.utils.database import get_db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamScheduler:
|
||||
"""Scheduler service for periodic stream validation tasks."""
|
||||
|
||||
def __init__(self, app: Optional[FastAPI] = None):
|
||||
"""
|
||||
Initialize the scheduler with optional FastAPI app integration.
|
||||
|
||||
Args:
|
||||
app: Optional FastAPI app instance for lifecycle integration
|
||||
"""
|
||||
self.scheduler = BackgroundScheduler()
|
||||
self.app = app
|
||||
self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10"))
|
||||
self.schedule_time = os.getenv(
|
||||
"STREAM_VALIDATION_SCHEDULE", "0 3 * * *"
|
||||
) # Default 3 AM daily
|
||||
logger.info(f"Scheduler initialized with app: {app is not None}")
|
||||
|
||||
def validate_streams_batch(self, db_session: Optional[Session] = None) -> None:
|
||||
"""
|
||||
Validate streams and update their status.
|
||||
When batch_size=0, validates all channels.
|
||||
|
||||
Args:
|
||||
db_session: Optional SQLAlchemy session
|
||||
"""
|
||||
db = db_session if db_session else get_db_session()
|
||||
try:
|
||||
manager = StreamManager(db)
|
||||
|
||||
# Get channels to validate
|
||||
query = db.query(ChannelDB)
|
||||
if self.batch_size > 0:
|
||||
query = query.limit(self.batch_size)
|
||||
channels = query.all()
|
||||
|
||||
for channel in channels:
|
||||
try:
|
||||
logger.info(f"Validating streams for channel {channel.id}")
|
||||
manager.validate_and_select_stream(str(channel.id))
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating channel {channel.id}: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info(f"Completed stream validation of {len(channels)} channels")
|
||||
finally:
|
||||
if db_session is None:
|
||||
db.close()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler and add jobs."""
|
||||
if not self.scheduler.running:
|
||||
# Add the scheduled job
|
||||
self.scheduler.add_job(
|
||||
self.validate_streams_batch,
|
||||
trigger=CronTrigger.from_crontab(self.schedule_time),
|
||||
id="daily_stream_validation",
|
||||
)
|
||||
|
||||
# Start the scheduler
|
||||
self.scheduler.start()
|
||||
logger.info(
|
||||
f"Stream scheduler started with daily validation job. "
|
||||
f"Running: {self.scheduler.running}"
|
||||
)
|
||||
|
||||
# Register shutdown handler if FastAPI app is provided
|
||||
if self.app:
|
||||
logger.info(
|
||||
f"Registering scheduler with FastAPI "
|
||||
f"app: {hasattr(self.app, 'state')}"
|
||||
)
|
||||
|
||||
@self.app.on_event("shutdown")
|
||||
def shutdown_scheduler():
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the scheduler gracefully."""
|
||||
if self.scheduler.running:
|
||||
self.scheduler.shutdown()
|
||||
logger.info("Stream scheduler stopped")
|
||||
|
||||
def trigger_manual_validation(self) -> None:
|
||||
"""Trigger manual validation of streams."""
|
||||
logger.info("Manually triggering stream validation")
|
||||
self.validate_streams_batch()
|
||||
|
||||
|
||||
def init_scheduler(app: FastAPI) -> StreamScheduler:
|
||||
"""Initialize and start the scheduler with FastAPI integration."""
|
||||
scheduler = StreamScheduler(app)
|
||||
scheduler.start()
|
||||
return scheduler
|
||||
151
app/iptv/stream_manager.py
Normal file
151
app/iptv/stream_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.db import ChannelURL
|
||||
from app.utils.check_streams import StreamValidator
|
||||
from app.utils.database import get_db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamManager:
|
||||
"""Service for managing and validating channel streams."""
|
||||
|
||||
def __init__(self, db_session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize StreamManager with optional database session.
|
||||
|
||||
Args:
|
||||
db_session: Optional SQLAlchemy session. If None, will create a new one.
|
||||
"""
|
||||
self.db = db_session if db_session else get_db_session()
|
||||
self.validator = StreamValidator()
|
||||
|
||||
def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]:
|
||||
"""
|
||||
Get all streams for a channel ordered by priority (lowest first),
|
||||
with same-priority streams randomized.
|
||||
|
||||
Args:
|
||||
channel_id: UUID of the channel to get streams for
|
||||
|
||||
Returns:
|
||||
List of ChannelURL objects ordered by priority
|
||||
"""
|
||||
try:
|
||||
# Get all streams for channel ordered by priority
|
||||
streams = (
|
||||
self.db.query(ChannelURL)
|
||||
.filter(ChannelURL.channel_id == channel_id)
|
||||
.order_by(ChannelURL.priority_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Group streams by priority and randomize same-priority streams
|
||||
grouped = {}
|
||||
for stream in streams:
|
||||
if stream.priority_id not in grouped:
|
||||
grouped[stream.priority_id] = []
|
||||
grouped[stream.priority_id].append(stream)
|
||||
|
||||
# Randomize same-priority streams and flatten
|
||||
randomized_streams = []
|
||||
for priority in sorted(grouped.keys()):
|
||||
random.shuffle(grouped[priority])
|
||||
randomized_streams.extend(grouped[priority])
|
||||
|
||||
return randomized_streams
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting streams for channel {channel_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def validate_and_select_stream(self, channel_id: str) -> Optional[str]:
|
||||
"""
|
||||
Find and validate a working stream for the given channel.
|
||||
|
||||
Args:
|
||||
channel_id: UUID of the channel to find a stream for
|
||||
|
||||
Returns:
|
||||
URL of the first working stream found, or None if none found
|
||||
"""
|
||||
try:
|
||||
streams = self.get_streams_for_channel(channel_id)
|
||||
if not streams:
|
||||
logger.warning(f"No streams found for channel {channel_id}")
|
||||
return None
|
||||
|
||||
working_stream = None
|
||||
|
||||
for stream in streams:
|
||||
logger.info(f"Validating stream {stream.url} for channel {channel_id}")
|
||||
is_valid, _ = self.validator.validate_stream(stream.url)
|
||||
|
||||
if is_valid:
|
||||
working_stream = stream
|
||||
break
|
||||
|
||||
if working_stream:
|
||||
self._update_stream_status(working_stream, streams)
|
||||
return working_stream.url
|
||||
else:
|
||||
logger.warning(f"No valid streams found for channel {channel_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating streams for channel {channel_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def _update_stream_status(
|
||||
self, working_stream: ChannelURL, all_streams: list[ChannelURL]
|
||||
) -> None:
|
||||
"""
|
||||
Update in_use status for streams (True for working stream, False for others).
|
||||
|
||||
Args:
|
||||
working_stream: The stream that was validated as working
|
||||
all_streams: All streams for the channel
|
||||
"""
|
||||
try:
|
||||
for stream in all_streams:
|
||||
stream.in_use = stream.id == working_stream.id
|
||||
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Updated stream status - set in_use=True for {working_stream.url}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error updating stream status: {str(e)}")
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
"""Close database session when StreamManager is destroyed."""
|
||||
if hasattr(self, "db"):
|
||||
self.db.close()
|
||||
|
||||
|
||||
def get_working_stream(
|
||||
channel_id: str, db_session: Optional[Session] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Convenience function to get a working stream for a channel.
|
||||
|
||||
Args:
|
||||
channel_id: UUID of the channel to get a stream for
|
||||
db_session: Optional SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
URL of the first working stream found, or None if none found
|
||||
"""
|
||||
manager = StreamManager(db_session)
|
||||
try:
|
||||
return manager.validate_and_select_stream(channel_id)
|
||||
finally:
|
||||
if db_session is None: # Only close if we created the session
|
||||
manager.__del__()
|
||||
113
app/main.py
113
app/main.py
@@ -1,43 +1,82 @@
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from app.cabletv.utils.auth import exchange_code_for_token, get_current_user, DOMAIN, CLIENT_ID
|
||||
from fastapi import FastAPI
|
||||
from fastapi.concurrency import asynccontextmanager
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
from app.routers import auth, channels, groups, playlist, priorities, scheduler
|
||||
from app.utils.database import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Initialize database tables on startup
|
||||
init_db()
|
||||
|
||||
# Initialize and start the stream scheduler
|
||||
scheduler = StreamScheduler(app)
|
||||
app.state.scheduler = scheduler # Store scheduler in app state
|
||||
scheduler.start()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown scheduler on app shutdown
|
||||
scheduler.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="IPTV Manager API",
|
||||
description="API for IPTV Manager service",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Ensure components object exists
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
|
||||
# Add schemas if they don't exist
|
||||
if "schemas" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["schemas"] = {}
|
||||
|
||||
# Add security scheme component
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
"Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
|
||||
}
|
||||
|
||||
# Add global security requirement
|
||||
openapi_schema["security"] = [{"Bearer": []}]
|
||||
|
||||
# Set OpenAPI version explicitly
|
||||
openapi_schema["openapi"] = "3.1.0"
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "IPTV Updater API"}
|
||||
return {"message": "IPTV Manager API"}
|
||||
|
||||
@app.get("/protected")
|
||||
async def protected_route(user = Depends(get_current_user)):
|
||||
if isinstance(user, RedirectResponse):
|
||||
return user
|
||||
return {"message": "Protected content", "user": user['Username']}
|
||||
|
||||
@app.get("/auth/callback")
|
||||
async def auth_callback(code: str):
|
||||
try:
|
||||
# Exchange the authorization code for tokens
|
||||
tokens = exchange_code_for_token(code)
|
||||
|
||||
# Create a response with the access token
|
||||
response = JSONResponse(content={
|
||||
"message": "Authentication successful",
|
||||
"access_token": tokens["access_token"]
|
||||
})
|
||||
|
||||
# Set the access token as a cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=tokens["access_token"],
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Authentication failed: {str(e)}"
|
||||
)
|
||||
# Include routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(channels.router)
|
||||
app.include_router(playlist.router)
|
||||
app.include_router(priorities.router)
|
||||
app.include_router(groups.router)
|
||||
app.include_router(scheduler.router)
|
||||
|
||||
27
app/models/__init__.py
Normal file
27
app/models/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from .db import Base, ChannelDB, ChannelURL, Group, Priority
|
||||
from .schemas import (
|
||||
ChannelCreate,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
GroupCreate,
|
||||
GroupResponse,
|
||||
GroupUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ChannelDB",
|
||||
"ChannelCreate",
|
||||
"ChannelUpdate",
|
||||
"ChannelResponse",
|
||||
"ChannelURL",
|
||||
"ChannelURLCreate",
|
||||
"ChannelURLResponse",
|
||||
"Group",
|
||||
"Priority",
|
||||
"GroupCreate",
|
||||
"GroupResponse",
|
||||
"GroupUpdate",
|
||||
]
|
||||
26
app/models/auth.py
Normal file
26
app/models/auth.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SigninRequest(BaseModel):
|
||||
"""Request model for the signin endpoint."""
|
||||
|
||||
username: str = Field(..., description="The user's username")
|
||||
password: str = Field(..., description="The user's password")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response model for successful authentication."""
|
||||
|
||||
access_token: str = Field(..., description="Access JWT token from Cognito")
|
||||
id_token: str = Field(..., description="ID JWT token from Cognito")
|
||||
refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
|
||||
token_type: str = Field(..., description="Type of the token returned")
|
||||
|
||||
|
||||
class CognitoUser(BaseModel):
|
||||
"""Model representing the user returned from token verification."""
|
||||
|
||||
username: str
|
||||
roles: list[str]
|
||||
139
app/models/db.py
Normal file
139
app/models/db.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
TEXT,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
|
||||
# Custom UUID type for SQLite compatibility
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite with proper comparison handling."""
|
||||
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid.UUID):
|
||||
return str(value)
|
||||
try:
|
||||
# Validate string format by attempting to create UUID
|
||||
uuid.UUID(value)
|
||||
return value
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"Invalid UUID string format: {value}")
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
def compare_values(self, x, y):
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
return str(x) == str(y)
|
||||
|
||||
|
||||
# Determine which UUID type to use based on environment
|
||||
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
UUID_COLUMN_TYPE = SQLiteUUID()
|
||||
else:
|
||||
UUID_COLUMN_TYPE = UUID(as_uuid=True)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Priority(Base):
|
||||
"""SQLAlchemy model for channel URL priorities"""
|
||||
|
||||
__tablename__ = "priorities"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
"""SQLAlchemy model for channel groups"""
|
||||
|
||||
__tablename__ = "groups"
|
||||
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
sort_order = Column(Integer, nullable=False, default=0)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationship with Channel
|
||||
channels = relationship("ChannelDB", back_populates="group")
|
||||
|
||||
|
||||
class ChannelDB(Base):
|
||||
"""SQLAlchemy model for IPTV channels"""
|
||||
|
||||
__tablename__ = "channels"
|
||||
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
tvg_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
|
||||
tvg_name = Column(String)
|
||||
|
||||
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
urls = relationship(
|
||||
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||
)
|
||||
group = relationship("Group", back_populates="channels")
|
||||
|
||||
|
||||
class ChannelURL(Base):
|
||||
"""SQLAlchemy model for channel URLs"""
|
||||
|
||||
__tablename__ = "channels_urls"
|
||||
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(
|
||||
UUID_COLUMN_TYPE,
|
||||
ForeignKey("channels.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
channel = relationship("ChannelDB", back_populates="urls")
|
||||
priority = relationship("Priority")
|
||||
140
app/models/schemas.py
Normal file
140
app/models/schemas.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PriorityBase(BaseModel):
|
||||
"""Base Pydantic model for priorities"""
|
||||
|
||||
id: int
|
||||
description: str
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PriorityCreate(PriorityBase):
|
||||
"""Pydantic model for creating priorities"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PriorityResponse(PriorityBase):
|
||||
"""Pydantic model for priority responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelURLCreate(BaseModel):
|
||||
"""Pydantic model for creating channel URLs"""
|
||||
|
||||
url: str
|
||||
priority_id: int = Field(
|
||||
default=100, ge=100, le=300
|
||||
) # Default to High, validate range
|
||||
|
||||
|
||||
class ChannelURLBase(ChannelURLCreate):
|
||||
"""Base Pydantic model for channel URL responses"""
|
||||
|
||||
id: UUID
|
||||
in_use: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
priority_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ChannelURLResponse(ChannelURLBase):
|
||||
"""Pydantic model for channel URL responses"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# New Group Schemas
|
||||
class GroupCreate(BaseModel):
|
||||
"""Pydantic model for creating groups"""
|
||||
|
||||
name: str
|
||||
sort_order: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class GroupUpdate(BaseModel):
|
||||
"""Pydantic model for updating groups"""
|
||||
|
||||
name: Optional[str] = None
|
||||
sort_order: Optional[int] = Field(None, ge=0)
|
||||
|
||||
|
||||
class GroupResponse(BaseModel):
|
||||
"""Pydantic model for group responses"""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
sort_order: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GroupSortUpdate(BaseModel):
|
||||
"""Pydantic model for updating a single group's sort order"""
|
||||
|
||||
sort_order: int = Field(ge=0)
|
||||
|
||||
|
||||
class GroupBulkSort(BaseModel):
|
||||
"""Pydantic model for bulk updating group sort orders"""
|
||||
|
||||
groups: list[dict] = Field(
|
||||
description="List of dicts with group_id and new sort_order",
|
||||
json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]},
|
||||
)
|
||||
|
||||
|
||||
class ChannelCreate(BaseModel):
|
||||
"""Pydantic model for creating channels"""
|
||||
|
||||
urls: list[ChannelURLCreate] # List of URL objects with priority
|
||||
name: str
|
||||
group_id: UUID
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
|
||||
|
||||
class ChannelURLUpdate(BaseModel):
|
||||
"""Pydantic model for updating channel URLs"""
|
||||
|
||||
url: Optional[str] = None
|
||||
in_use: Optional[bool] = None
|
||||
priority_id: Optional[int] = Field(default=None, ge=100, le=300)
|
||||
|
||||
|
||||
class ChannelUpdate(BaseModel):
|
||||
"""Pydantic model for updating channels (all fields optional)"""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
group_id: Optional[UUID] = None
|
||||
tvg_id: Optional[str] = Field(None, min_length=1)
|
||||
tvg_logo: Optional[str] = None
|
||||
tvg_name: Optional[str] = Field(None, min_length=1)
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
"""Pydantic model for channel responses"""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
group_id: UUID
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
urls: list[ChannelURLResponse] # List of URL objects without channel_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
0
app/routers/__init__.py
Normal file
0
app/routers/__init__.py
Normal file
22
app/routers/auth.py
Normal file
22
app/routers/auth.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.auth.cognito import initiate_auth
|
||||
from app.models.auth import SigninRequest, TokenResponse
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
|
||||
def signin(credentials: SigninRequest):
|
||||
"""
|
||||
Sign-in endpoint to authenticate the user with AWS Cognito
|
||||
using username and password.
|
||||
On success, returns JWT tokens (access_token, id_token, refresh_token).
|
||||
"""
|
||||
auth_result = initiate_auth(credentials.username, credentials.password)
|
||||
return TokenResponse(
|
||||
access_token=auth_result["AccessToken"],
|
||||
id_token=auth_result["IdToken"],
|
||||
refresh_token=auth_result.get("RefreshToken"),
|
||||
token_type="Bearer",
|
||||
)
|
||||
508
app/routers/channels.py
Normal file
508
app/routers/channels.py
Normal file
@@ -0,0 +1,508 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models import (
|
||||
ChannelCreate,
|
||||
ChannelDB,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURL,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
Group,
|
||||
Priority, # Added Priority import
|
||||
)
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.schemas import ChannelURLUpdate
|
||||
from app.utils.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/channels", tags=["channels"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_channel(
|
||||
channel: ChannelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new channel"""
|
||||
# Check if group exists
|
||||
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Group not found",
|
||||
)
|
||||
|
||||
# Check for duplicate channel (same group_id + name)
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_id == channel.group_id,
|
||||
ChannelDB.name == channel.name,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_id and name already exists",
|
||||
)
|
||||
|
||||
# Create channel without URLs first
|
||||
channel_data = channel.model_dump(exclude={"urls"})
|
||||
urls = channel.urls
|
||||
db_channel = ChannelDB(**channel_data)
|
||||
db.add(db_channel)
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
|
||||
# Add URLs with priority
|
||||
for url in urls:
|
||||
db_url = ChannelURL(
|
||||
channel_id=db_channel.id,
|
||||
url=url.url,
|
||||
priority_id=url.priority_id,
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.get("/{channel_id}", response_model=ChannelResponse)
|
||||
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
|
||||
"""Get a channel by id"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
return channel
|
||||
|
||||
|
||||
@router.put("/{channel_id}", response_model=ChannelResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel(
|
||||
channel_id: UUID,
|
||||
channel: ChannelUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel"""
|
||||
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not db_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
# Only check for duplicates if name or group_id are being updated
|
||||
if channel.name is not None or channel.group_id is not None:
|
||||
name = channel.name if channel.name is not None else db_channel.name
|
||||
group_id = (
|
||||
channel.group_id if channel.group_id is not None else db_channel.group_id
|
||||
)
|
||||
|
||||
# Check if new group exists
|
||||
if channel.group_id is not None:
|
||||
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Group not found",
|
||||
)
|
||||
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_id == group_id,
|
||||
ChannelDB.name == name,
|
||||
ChannelDB.id != channel_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_id and name already exists",
|
||||
)
|
||||
|
||||
# Update only provided fields
|
||||
update_data = channel.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_channel, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def delete_channels(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all channels"""
|
||||
count = 0
|
||||
try:
|
||||
count = db.query(ChannelDB).count()
|
||||
|
||||
# First delete all channels
|
||||
db.query(ChannelDB).delete()
|
||||
|
||||
# Then delete any URLs that are now orphaned (no channel references)
|
||||
db.query(ChannelURL).filter(
|
||||
~ChannelURL.channel_id.in_(db.query(ChannelDB.id))
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Then delete any groups that are now empty
|
||||
db.query(Group).filter(~Group.id.in_(db.query(ChannelDB.group_id))).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
print(f"Error deleting channels: {e}")
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete channels",
|
||||
)
|
||||
return {"deleted": count}
|
||||
|
||||
|
||||
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
db.delete(channel)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ChannelResponse])
|
||||
@require_roles("admin")
|
||||
def list_channels(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""List all channels with pagination"""
|
||||
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
# New endpoint to get channels by group
|
||||
@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse])
|
||||
def get_channels_by_group(
|
||||
group_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all channels for a specific group"""
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
return db.query(ChannelDB).filter(ChannelDB.group_id == group_id).all()
|
||||
|
||||
|
||||
# New endpoint to update a channel's group
|
||||
@router.put("/{channel_id}/group", response_model=ChannelResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel_group(
|
||||
channel_id: UUID,
|
||||
group_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel's group"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
# Check for duplicate channel name in new group
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_id == group_id,
|
||||
ChannelDB.name == channel.name,
|
||||
ChannelDB.id != channel_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same name already exists in target group",
|
||||
)
|
||||
|
||||
channel.group_id = group_id
|
||||
db.commit()
|
||||
db.refresh(channel)
|
||||
return channel
|
||||
|
||||
|
||||
# Bulk Upload and Reset Endpoints
|
||||
@router.post("/bulk-upload", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def bulk_upload_channels(
|
||||
channels: list[dict],
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Bulk upload channels from JSON array"""
|
||||
processed = 0
|
||||
|
||||
# Fetch all priorities from the database, ordered by id
|
||||
priorities = db.query(Priority).order_by(Priority.id).all()
|
||||
priority_map = {i: p.id for i, p in enumerate(priorities)}
|
||||
|
||||
# Get the highest priority_id (which corresponds to the lowest priority level)
|
||||
max_priority_id = None
|
||||
if priorities:
|
||||
max_priority_id = db.query(Priority.id).order_by(Priority.id.desc()).first()[0]
|
||||
|
||||
for channel_data in channels:
|
||||
try:
|
||||
# Get or create group
|
||||
group_name = channel_data.get("group-title")
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
group = db.query(Group).filter(Group.name == group_name).first()
|
||||
if not group:
|
||||
group = Group(name=group_name)
|
||||
db.add(group)
|
||||
db.flush() # Use flush to make the group available in the session
|
||||
db.refresh(group)
|
||||
|
||||
# Prepare channel data
|
||||
urls = channel_data.get("urls", [])
|
||||
if not isinstance(urls, list):
|
||||
urls = [urls]
|
||||
|
||||
# Assign priorities dynamically based on fetched priorities
|
||||
url_objects = []
|
||||
for i, url in enumerate(urls): # Process all URLs
|
||||
priority_id = priority_map.get(i)
|
||||
if priority_id is None:
|
||||
# If index is out of bounds,
|
||||
# assign the highest priority_id (lowest priority)
|
||||
if max_priority_id is not None:
|
||||
priority_id = max_priority_id
|
||||
else:
|
||||
print(
|
||||
f"Warning: No priorities defined in database. "
|
||||
f"Skipping URL {url}"
|
||||
)
|
||||
continue
|
||||
url_objects.append({"url": url, "priority_id": priority_id})
|
||||
|
||||
# Create channel object with required fields
|
||||
channel_obj = ChannelDB(
|
||||
tvg_id=channel_data.get("tvg-id", ""),
|
||||
name=channel_data.get("name", ""),
|
||||
group_id=group.id,
|
||||
tvg_name=channel_data.get("tvg-name", ""),
|
||||
tvg_logo=channel_data.get("tvg-logo", ""),
|
||||
)
|
||||
|
||||
# Upsert channel
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_id == group.id,
|
||||
ChannelDB.name == channel_obj.name,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
# Update existing
|
||||
existing_channel.tvg_id = channel_obj.tvg_id
|
||||
existing_channel.tvg_name = channel_obj.tvg_name
|
||||
existing_channel.tvg_logo = channel_obj.tvg_logo
|
||||
|
||||
# Clear and recreate URLs
|
||||
db.query(ChannelURL).filter(
|
||||
ChannelURL.channel_id == existing_channel.id
|
||||
).delete()
|
||||
|
||||
for url in url_objects:
|
||||
db_url = ChannelURL(
|
||||
channel_id=existing_channel.id,
|
||||
url=url["url"],
|
||||
priority_id=url["priority_id"],
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
else:
|
||||
# Create new
|
||||
db.add(channel_obj)
|
||||
db.flush() # Flush to get the new channel's ID
|
||||
db.refresh(channel_obj)
|
||||
|
||||
# Add URLs for new channel
|
||||
for url in url_objects:
|
||||
db_url = ChannelURL(
|
||||
channel_id=channel_obj.id,
|
||||
url=url["url"],
|
||||
priority_id=url["priority_id"],
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
|
||||
db.commit() # Commit all changes for this channel atomically
|
||||
processed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing channel: {channel_data.get('name', 'Unknown')}")
|
||||
print(f"Exception details: {e}")
|
||||
db.rollback() # Rollback the entire transaction for the failed channel
|
||||
continue
|
||||
|
||||
return {"processed": processed}
|
||||
|
||||
|
||||
# URL Management Endpoints
|
||||
@router.post(
|
||||
"/{channel_id}/urls",
|
||||
response_model=ChannelURLResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
@require_roles("admin")
|
||||
def add_channel_url(
|
||||
channel_id: UUID,
|
||||
url: ChannelURLCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Add a new URL to a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
db_url = ChannelURL(
|
||||
channel_id=channel_id,
|
||||
url=url.url,
|
||||
priority_id=url.priority_id,
|
||||
in_use=False, # Default to not in use
|
||||
)
|
||||
db.add(db_url)
|
||||
db.commit()
|
||||
db.refresh(db_url)
|
||||
return db_url
|
||||
|
||||
|
||||
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel_url(
|
||||
channel_id: UUID,
|
||||
url_id: UUID,
|
||||
url_update: ChannelURLUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel URL (url, in_use, or priority_id)"""
|
||||
db_url = (
|
||||
db.query(ChannelURL)
|
||||
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
|
||||
)
|
||||
|
||||
if url_update.url is not None:
|
||||
db_url.url = url_update.url
|
||||
if url_update.in_use is not None:
|
||||
db_url.in_use = url_update.in_use
|
||||
if url_update.priority_id is not None:
|
||||
db_url.priority_id = url_update.priority_id
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_url)
|
||||
return db_url
|
||||
|
||||
|
||||
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel_url(
|
||||
channel_id: UUID,
|
||||
url_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a URL from a channel"""
|
||||
url = (
|
||||
db.query(ChannelURL)
|
||||
.filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
|
||||
)
|
||||
|
||||
db.delete(url)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
|
||||
@require_roles("admin")
|
||||
def list_channel_urls(
|
||||
channel_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""List all URLs for a channel"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
|
||||
191
app/routers/groups.py
Normal file
191
app/routers/groups.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models import Group
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.schemas import (
|
||||
GroupBulkSort,
|
||||
GroupCreate,
|
||||
GroupResponse,
|
||||
GroupSortUpdate,
|
||||
GroupUpdate,
|
||||
)
|
||||
from app.utils.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_group(
|
||||
group: GroupCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new channel group"""
|
||||
# Check for duplicate group name
|
||||
existing_group = db.query(Group).filter(Group.name == group.name).first()
|
||||
if existing_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Group with this name already exists",
|
||||
)
|
||||
|
||||
db_group = Group(**group.model_dump())
|
||||
db.add(db_group)
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@router.get("/{group_id}", response_model=GroupResponse)
|
||||
def get_group(group_id: UUID, db: Session = Depends(get_db)):
|
||||
"""Get a group by id"""
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
return group
|
||||
|
||||
|
||||
@router.put("/{group_id}", response_model=GroupResponse)
|
||||
@require_roles("admin")
|
||||
def update_group(
|
||||
group_id: UUID,
|
||||
group: GroupUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a group's name or sort order"""
|
||||
db_group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not db_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
# Check for duplicate name if name is being updated
|
||||
if group.name is not None and group.name != db_group.name:
|
||||
existing_group = db.query(Group).filter(Group.name == group.name).first()
|
||||
if existing_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Group with this name already exists",
|
||||
)
|
||||
|
||||
# Update only provided fields
|
||||
update_data = group.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_group, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def delete_groups(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all groups that have no channels (skip groups with channels)"""
|
||||
groups = db.query(Group).all()
|
||||
deleted = 0
|
||||
skipped = 0
|
||||
|
||||
for group in groups:
|
||||
if not group.channels:
|
||||
db.delete(group)
|
||||
deleted += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
db.commit()
|
||||
return {"deleted": deleted, "skipped": skipped}
|
||||
|
||||
|
||||
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_group(
|
||||
group_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a group (only if it has no channels)"""
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
# Check if group has any channels
|
||||
if group.channels:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete group with existing channels",
|
||||
)
|
||||
|
||||
db.delete(group)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/", response_model=list[GroupResponse])
|
||||
def list_groups(db: Session = Depends(get_db)):
|
||||
"""List all groups sorted by sort_order"""
|
||||
return db.query(Group).order_by(Group.sort_order).all()
|
||||
|
||||
|
||||
@router.put("/{group_id}/sort", response_model=GroupResponse)
|
||||
@require_roles("admin")
|
||||
def update_group_sort_order(
|
||||
group_id: UUID,
|
||||
sort_update: GroupSortUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a single group's sort order"""
|
||||
db_group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not db_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
db_group.sort_order = sort_update.sort_order
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@router.post("/reorder", response_model=list[GroupResponse])
|
||||
@require_roles("admin")
|
||||
def bulk_update_sort_orders(
|
||||
bulk_sort: GroupBulkSort,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Bulk update group sort orders"""
|
||||
groups_to_update = []
|
||||
|
||||
for group_data in bulk_sort.groups:
|
||||
group_id = group_data["group_id"]
|
||||
sort_order = group_data["sort_order"]
|
||||
|
||||
group = db.query(Group).filter(Group.id == str(group_id)).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group with id {group_id} not found",
|
||||
)
|
||||
|
||||
group.sort_order = sort_order
|
||||
groups_to_update.append(group)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Return all groups in their new order
|
||||
return db.query(Group).order_by(Group.sort_order).all()
|
||||
156
app/routers/playlist.py
Normal file
156
app/routers/playlist.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.iptv.stream_manager import StreamManager
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.database import get_db_session
|
||||
|
||||
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# In-memory store for validation processes
|
||||
validation_processes: dict[str, dict] = {}
|
||||
|
||||
|
||||
class ProcessStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class StreamValidationRequest(BaseModel):
|
||||
"""Request model for stream validation endpoint"""
|
||||
|
||||
channel_id: Optional[str] = None
|
||||
|
||||
|
||||
class ValidatedStream(BaseModel):
|
||||
"""Model for a validated working stream"""
|
||||
|
||||
channel_id: str
|
||||
stream_url: str
|
||||
|
||||
|
||||
class ValidationProcessResponse(BaseModel):
|
||||
"""Response model for validation process initiation"""
|
||||
|
||||
process_id: str
|
||||
status: ProcessStatus
|
||||
message: str
|
||||
|
||||
|
||||
class ValidationResultResponse(BaseModel):
|
||||
"""Response model for validation results"""
|
||||
|
||||
process_id: str
|
||||
status: ProcessStatus
|
||||
working_streams: Optional[list[ValidatedStream]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session):
|
||||
"""Background task to validate streams"""
|
||||
try:
|
||||
validation_processes[process_id]["status"] = ProcessStatus.IN_PROGRESS
|
||||
manager = StreamManager(db)
|
||||
|
||||
if channel_id:
|
||||
stream_url = manager.validate_and_select_stream(channel_id)
|
||||
if stream_url:
|
||||
validation_processes[process_id]["result"] = {
|
||||
"working_streams": [
|
||||
ValidatedStream(channel_id=channel_id, stream_url=stream_url)
|
||||
]
|
||||
}
|
||||
else:
|
||||
validation_processes[process_id]["error"] = (
|
||||
f"No working streams found for channel {channel_id}"
|
||||
)
|
||||
else:
|
||||
# TODO: Implement validation for all channels
|
||||
validation_processes[process_id]["error"] = (
|
||||
"Validation of all channels not yet implemented"
|
||||
)
|
||||
|
||||
validation_processes[process_id]["status"] = ProcessStatus.COMPLETED
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating streams: {str(e)}")
|
||||
validation_processes[process_id]["status"] = ProcessStatus.FAILED
|
||||
validation_processes[process_id]["error"] = str(e)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/validate-streams",
|
||||
summary="Start stream validation process",
|
||||
response_model=ValidationProcessResponse,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
responses={202: {"description": "Validation process started successfully"}},
|
||||
)
|
||||
async def start_stream_validation(
|
||||
request: StreamValidationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db_session),
|
||||
):
|
||||
"""
|
||||
Start asynchronous validation of streams.
|
||||
|
||||
- Returns immediately with a process ID
|
||||
- Use GET /validate-streams/{process_id} to check status
|
||||
"""
|
||||
process_id = str(uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.PENDING,
|
||||
"channel_id": request.channel_id,
|
||||
}
|
||||
|
||||
background_tasks.add_task(run_stream_validation, process_id, request.channel_id, db)
|
||||
|
||||
return {
|
||||
"process_id": process_id,
|
||||
"status": ProcessStatus.PENDING,
|
||||
"message": "Validation process started",
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/validate-streams/{process_id}",
|
||||
summary="Check validation process status",
|
||||
response_model=ValidationResultResponse,
|
||||
responses={
|
||||
200: {"description": "Process status and results"},
|
||||
404: {"description": "Process not found"},
|
||||
},
|
||||
)
|
||||
async def get_validation_status(
|
||||
process_id: str, user: CognitoUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Check status of a stream validation process.
|
||||
|
||||
Returns current status and results if completed.
|
||||
"""
|
||||
if process_id not in validation_processes:
|
||||
raise HTTPException(status_code=404, detail="Process not found")
|
||||
|
||||
process = validation_processes[process_id]
|
||||
response = {"process_id": process_id, "status": process["status"]}
|
||||
|
||||
if process["status"] == ProcessStatus.COMPLETED:
|
||||
if "error" in process:
|
||||
response["error"] = process["error"]
|
||||
else:
|
||||
response["working_streams"] = process["result"]["working_streams"]
|
||||
elif process["status"] == ProcessStatus.FAILED:
|
||||
response["error"] = process["error"]
|
||||
|
||||
return response
|
||||
120
app/routers/priorities.py
Normal file
120
app/routers/priorities.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.db import Priority
|
||||
from app.models.schemas import PriorityCreate, PriorityResponse
|
||||
from app.utils.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/priorities", tags=["priorities"])
|
||||
|
||||
|
||||
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_priority(
|
||||
priority: PriorityCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new priority"""
|
||||
# Check if priority with this ID already exists
|
||||
existing = db.get(Priority, priority.id)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Priority with ID {priority.id} already exists",
|
||||
)
|
||||
|
||||
db_priority = Priority(**priority.model_dump())
|
||||
db.add(db_priority)
|
||||
db.commit()
|
||||
db.refresh(db_priority)
|
||||
return db_priority
|
||||
|
||||
|
||||
@router.get("/", response_model=list[PriorityResponse])
|
||||
@require_roles("admin")
|
||||
def list_priorities(
|
||||
db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
|
||||
):
|
||||
"""List all priorities"""
|
||||
return db.query(Priority).all()
|
||||
|
||||
|
||||
@router.get("/{priority_id}", response_model=PriorityResponse)
|
||||
@require_roles("admin")
|
||||
def get_priority(
|
||||
priority_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Get a priority by id"""
|
||||
priority = db.get(Priority, priority_id)
|
||||
if not priority:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
|
||||
)
|
||||
return priority
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def delete_priorities(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all priorities not in use by channel URLs"""
|
||||
from app.models.db import ChannelURL
|
||||
|
||||
priorities = db.query(Priority).all()
|
||||
deleted = 0
|
||||
skipped = 0
|
||||
|
||||
for priority in priorities:
|
||||
in_use = db.scalar(
|
||||
select(ChannelURL).where(ChannelURL.priority_id == priority.id).limit(1)
|
||||
)
|
||||
|
||||
if not in_use:
|
||||
db.delete(priority)
|
||||
deleted += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
db.commit()
|
||||
return {"deleted": deleted, "skipped": skipped}
|
||||
|
||||
|
||||
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_priority(
|
||||
priority_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a priority (if not in use)"""
|
||||
from app.models.db import ChannelURL
|
||||
|
||||
# Check if priority exists
|
||||
priority = db.get(Priority, priority_id)
|
||||
if not priority:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
|
||||
)
|
||||
|
||||
# Check if priority is in use
|
||||
in_use = db.scalar(
|
||||
select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
|
||||
)
|
||||
|
||||
if in_use:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Cannot delete priority that is in use by channel URLs",
|
||||
)
|
||||
|
||||
db.execute(delete(Priority).where(Priority.id == priority_id))
|
||||
db.commit()
|
||||
return None
|
||||
57
app/routers/scheduler.py
Normal file
57
app/routers/scheduler.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.database import get_db
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/scheduler",
|
||||
tags=["scheduler"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
async def get_scheduler(request: Request) -> StreamScheduler:
|
||||
"""Get the scheduler instance from the app state."""
|
||||
if not hasattr(request.app.state.scheduler, "scheduler"):
|
||||
raise HTTPException(status_code=500, detail="Scheduler not initialized")
|
||||
return request.app.state.scheduler
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
@require_roles("admin")
|
||||
def scheduler_health(
|
||||
scheduler: StreamScheduler = Depends(get_scheduler),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Check scheduler health status (admin only)."""
|
||||
try:
|
||||
job = scheduler.scheduler.get_job("daily_stream_validation")
|
||||
next_run = str(job.next_run_time) if job and job.next_run_time else None
|
||||
|
||||
return {
|
||||
"status": "running" if scheduler.scheduler.running else "stopped",
|
||||
"next_run": next_run,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to check scheduler health: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/trigger")
|
||||
@require_roles("admin")
|
||||
def trigger_validation(
|
||||
scheduler: StreamScheduler = Depends(get_scheduler),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Manually trigger stream validation (admin only)."""
|
||||
scheduler.trigger_manual_validation()
|
||||
return JSONResponse(
|
||||
status_code=202, content={"message": "Stream validation triggered"}
|
||||
)
|
||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
14
app/utils/auth.py
Normal file
14
app/utils/auth.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
|
||||
"""
|
||||
Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
|
||||
"""
|
||||
msg = username + client_id
|
||||
dig = hmac.new(
|
||||
client_secret.encode("utf-8"), msg.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
return base64.b64encode(dig).decode()
|
||||
@@ -1,32 +1,41 @@
|
||||
import os
|
||||
import argparse
|
||||
import requests
|
||||
import logging
|
||||
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError
|
||||
import os
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||
|
||||
|
||||
class StreamValidator:
|
||||
def __init__(self, timeout=10, user_agent=None):
|
||||
self.timeout = timeout
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
|
||||
})
|
||||
self.session.headers.update(
|
||||
{
|
||||
"User-Agent": user_agent
|
||||
or (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def validate_stream(self, url):
|
||||
"""Validate a media stream URL with multiple fallback checks"""
|
||||
try:
|
||||
headers = {'Range': 'bytes=0-1024'}
|
||||
headers = {"Range": "bytes=0-1024"}
|
||||
with self.session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
allow_redirects=True
|
||||
allow_redirects=True,
|
||||
) as response:
|
||||
if response.status_code not in [200, 206]:
|
||||
return False, f"Invalid status code: {response.status_code}"
|
||||
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if not self._is_valid_content_type(content_type):
|
||||
return False, f"Invalid content type: {content_type}"
|
||||
|
||||
@@ -49,47 +58,50 @@ class StreamValidator:
|
||||
|
||||
def _is_valid_content_type(self, content_type):
|
||||
valid_types = [
|
||||
'video/mp2t', 'application/vnd.apple.mpegurl',
|
||||
'application/dash+xml', 'video/mp4',
|
||||
'video/webm', 'application/octet-stream',
|
||||
'application/x-mpegURL'
|
||||
"video/mp2t",
|
||||
"application/vnd.apple.mpegurl",
|
||||
"application/dash+xml",
|
||||
"video/mp4",
|
||||
"video/webm",
|
||||
"application/octet-stream",
|
||||
"application/x-mpegURL",
|
||||
]
|
||||
if content_type is None:
|
||||
return False
|
||||
return any(ct in content_type for ct in valid_types)
|
||||
|
||||
def parse_playlist(self, file_path):
|
||||
"""Extract stream URLs from M3U playlist file"""
|
||||
urls = []
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
urls.append(line)
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading playlist file: {str(e)}")
|
||||
raise
|
||||
return urls
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Validate streaming URLs from command line arguments or playlist files',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
description=(
|
||||
"Validate streaming URLs from command line arguments or playlist files"
|
||||
),
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
'sources',
|
||||
nargs='+',
|
||||
help='List of URLs or file paths containing stream URLs'
|
||||
"sources", nargs="+", help="List of URLs or file paths containing stream URLs"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--timeout',
|
||||
type=int,
|
||||
default=20,
|
||||
help='Timeout in seconds for stream checks'
|
||||
"--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
default='deadstreams.txt',
|
||||
help='Output file name for inactive streams'
|
||||
"--output",
|
||||
default="deadstreams.txt",
|
||||
help="Output file name for inactive streams",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -97,8 +109,8 @@ def main():
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()]
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
|
||||
)
|
||||
|
||||
validator = StreamValidator(timeout=args.timeout)
|
||||
@@ -127,9 +139,10 @@ def main():
|
||||
|
||||
# Save dead streams to file
|
||||
if dead_streams:
|
||||
with open(args.output, 'w') as f:
|
||||
f.write('\n'.join(dead_streams))
|
||||
with open(args.output, "w") as f:
|
||||
f.write("\n".join(dead_streams))
|
||||
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,12 +1,21 @@
|
||||
# Utility functions and constants
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from a .env file if it exists
|
||||
load_dotenv()
|
||||
|
||||
# AWS related constants
|
||||
AWS_REGION = os.environ.get("AWS_REGION", "us-east-2")
|
||||
COGNITO_USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID")
|
||||
COGNITO_CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
|
||||
COGNITO_CLIENT_SECRET = os.environ.get("COGNITO_CLIENT_SECRET", None)
|
||||
USER_ROLE_ATTRIBUTE = "zoneinfo"
|
||||
|
||||
IPTV_SERVER_URL = os.getenv("IPTV_SERVER_URL", "https://iptv.fiorinis.com")
|
||||
|
||||
# Super iptv-server admin credentials for basic auth
|
||||
# iptv-server super admin credentials for basic auth
|
||||
# Reads from environment variables IPTV_SERVER_ADMIN_USER and IPTV_SERVER_ADMIN_PASSWORD
|
||||
IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
|
||||
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
|
||||
61
app/utils/database.py
Normal file
61
app/utils/database.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import boto3
|
||||
from requests import Session
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.models import Base
|
||||
|
||||
from .constants import AWS_REGION
|
||||
|
||||
|
||||
def get_db_credentials():
|
||||
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
|
||||
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
return (
|
||||
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
|
||||
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
|
||||
)
|
||||
|
||||
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
||||
try:
|
||||
host = ssm.get_parameter(Name="/iptv-manager/DB_HOST", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
user = ssm.get_parameter(Name="/iptv-manager/DB_USER", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
password = ssm.get_parameter(
|
||||
Name="/iptv-manager/DB_PASSWORD", WithDecryption=True
|
||||
)["Parameter"]["Value"]
|
||||
dbname = ssm.get_parameter(Name="/iptv-manager/DB_NAME", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
|
||||
|
||||
|
||||
# Initialize engine and session maker
|
||||
engine = create_engine(get_db_credentials())
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Initialize database by creating all tables"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_db_session() -> Session:
|
||||
"""Get a direct database session (non-generator version)"""
|
||||
return SessionLocal()
|
||||
23
deploy.sh
23
deploy.sh
@@ -1,23 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Deploy infrastructure
|
||||
cdk deploy
|
||||
|
||||
# Update application on running instances
|
||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||
--filters "Name=tag:Name,Values=IptvUpdater/IptvUpdaterInstance" \
|
||||
"Name=instance-state-name,Values=running" \
|
||||
--query "Reservations[].Instances[].InstanceId" \
|
||||
--output text)
|
||||
|
||||
for INSTANCE_ID in $INSTANCE_IDS; do
|
||||
echo "Updating application on instance: $INSTANCE_ID"
|
||||
aws ssm send-command \
|
||||
--instance-ids "$INSTANCE_ID" \
|
||||
--document-name "AWS-RunShellScript" \
|
||||
--parameters '{"commands":["cd /home/ec2-user/iptv-updater-aws && git pull && pip3 install -r requirements.txt && sudo systemctl restart iptv-updater"]}' \
|
||||
--no-cli-pager \
|
||||
--no-paginate
|
||||
done
|
||||
|
||||
echo "Deployment and instance update complete"
|
||||
@@ -1,4 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Destroy infrastructure
|
||||
cdk destroy
|
||||
28
docker/Dockerfile
Normal file
28
docker/Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
# Use official Python image
|
||||
FROM python:3.9-slim
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
|
||||
# Set work directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
python3-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 8000
|
||||
|
||||
# Command to run the application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
17
docker/docker-compose-db.yml
Normal file
17
docker/docker-compose-db.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:13
|
||||
container_name: postgres
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: iptv_manager
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
32
docker/docker-compose-local.yml
Normal file
32
docker/docker-compose-local.yml
Normal file
@@ -0,0 +1,32 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:13
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: iptv_manager
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
|
||||
app:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile
|
||||
environment:
|
||||
DB_USER: postgres
|
||||
DB_PASSWORD: postgres
|
||||
DB_HOST: postgres
|
||||
DB_NAME: iptv_manager
|
||||
MOCK_AUTH: "true"
|
||||
ports:
|
||||
- "8000:8000"
|
||||
depends_on:
|
||||
- postgres
|
||||
command: uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
@@ -1,67 +1,84 @@
|
||||
import os
|
||||
from aws_cdk import (
|
||||
Stack,
|
||||
aws_ec2 as ec2,
|
||||
aws_iam as iam,
|
||||
aws_cognito as cognito,
|
||||
CfnOutput
|
||||
)
|
||||
|
||||
from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
|
||||
from aws_cdk import aws_cognito as cognito
|
||||
from aws_cdk import aws_ec2 as ec2
|
||||
from aws_cdk import aws_iam as iam
|
||||
from aws_cdk import aws_rds as rds
|
||||
from aws_cdk import aws_ssm as ssm
|
||||
from constructs import Construct
|
||||
|
||||
class IptvUpdaterStack(Stack):
|
||||
def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
|
||||
|
||||
class IptvManagerStack(Stack):
|
||||
def __init__(
|
||||
self,
|
||||
scope: Construct,
|
||||
construct_id: str,
|
||||
freedns_user: str,
|
||||
freedns_password: str,
|
||||
domain_name: str,
|
||||
ssh_public_key: str,
|
||||
repo_url: str,
|
||||
letsencrypt_email: str,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(scope, construct_id, **kwargs)
|
||||
|
||||
# Create VPC
|
||||
vpc = ec2.Vpc(self, "IptvUpdaterVPC",
|
||||
max_azs=1, # Use only one AZ for free tier
|
||||
vpc = ec2.Vpc(
|
||||
self,
|
||||
"IptvManagerVPC",
|
||||
max_azs=2, # Need at least 2 AZs for RDS subnet group
|
||||
nat_gateways=0, # No NAT Gateway to stay in free tier
|
||||
subnet_configuration=[
|
||||
ec2.SubnetConfiguration(
|
||||
name="public",
|
||||
subnet_type=ec2.SubnetType.PUBLIC,
|
||||
cidr_mask=24
|
||||
)
|
||||
]
|
||||
name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
|
||||
),
|
||||
ec2.SubnetConfiguration(
|
||||
name="private",
|
||||
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
|
||||
cidr_mask=24,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Security Group
|
||||
security_group = ec2.SecurityGroup(
|
||||
self, "IptvUpdaterSG",
|
||||
vpc=vpc,
|
||||
allow_all_outbound=True
|
||||
self, "IptvManagerSG", vpc=vpc, allow_all_outbound=True
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(),
|
||||
ec2.Port.tcp(443),
|
||||
"Allow HTTPS traffic"
|
||||
ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(),
|
||||
ec2.Port.tcp(80),
|
||||
"Allow HTTP traffic"
|
||||
ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.any_ipv4(),
|
||||
ec2.Port.tcp(22),
|
||||
"Allow SSH traffic"
|
||||
ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
|
||||
)
|
||||
|
||||
# Key pair for IPTV Updater instance
|
||||
# Allow PostgreSQL port for tunneling restricted to developer IP
|
||||
security_group.add_ingress_rule(
|
||||
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
|
||||
ec2.Port.tcp(5432),
|
||||
"Allow PostgreSQL traffic for tunneling",
|
||||
)
|
||||
|
||||
# Key pair for IPTV Manager instance
|
||||
key_pair = ec2.KeyPair(
|
||||
self,
|
||||
"IptvUpdaterKeyPair",
|
||||
key_pair_name="iptv-updater-key",
|
||||
public_key_material="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDZcD20mfOQ/al6VMXWWUcUILYIHyIy5gY9RmC6koicaMYTJ078yMUnYw9DONEIfvxZwceTrtdUWv7vOXjknh1bberF78Pi3vwwFKazdgjbyZnkM9r2N40+PqfOFYf03B53VIA8jdae6gaD3VvYdlwV5dqqW3N+JE/+7sXHLeENnqxeS5jNLoRKDIxHgV4MBnewWNdp77ZEHZFrG3Fj/0lnqq2EfDh+5E/Vhkc/mzkuXu2l5Z1szw2rcjJz4IYUjgnClorBVlmWwwiICrHbyVCYivaaVvpVmZoBy7WW1P8KFAiot4G6C0Klyn4sy2AwCQ7u65TyDtbCR89VuuwW0zLgERAfWMdhn/5HdIudcTScXEsUsADTqxb2x0IXVbs3uhxHCUcc7BVg0S7dpAbmpK+80NlDCH38LFcyrYASRD6/Le2skVePIFt0Tw1OnwPH/QqPG0Y9vLZJl8779pg+kpk+o1MaRnczPq9Sk9zr3dR4Sv82CObnjTeY5LCTvs06JBCtcey/vLk7RQYs43uXZgg746aWIcM0lgHPivh2JpUOAd/Kj1v4aEXTRgHyPPLQ4KKwnALcd4s6+ytXwf4gxumRmPf/7P6HYJQvY8pfVda8jEvu08rvSVMXb09Jq4tzS/JTX2flfEdF0qIyTNHnK/lum27fgP0yq6Pq24IWYXha9w== stefano@MSI"
|
||||
"IptvManagerKeyPair",
|
||||
key_pair_name="iptv-manager-key",
|
||||
public_key_material=ssh_public_key,
|
||||
)
|
||||
|
||||
# Create IAM role for EC2
|
||||
role = iam.Role(
|
||||
self, "IptvUpdaterRole",
|
||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com")
|
||||
self,
|
||||
"IptvManagerRole",
|
||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
|
||||
)
|
||||
|
||||
# Add SSM managed policy
|
||||
@@ -71,70 +88,65 @@ class IptvUpdaterStack(Stack):
|
||||
)
|
||||
)
|
||||
|
||||
# Add EC2 describe permissions
|
||||
role.add_to_policy(
|
||||
iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
|
||||
)
|
||||
|
||||
# Add SSM SendCommand permissions
|
||||
role.add_to_policy(
|
||||
iam.PolicyStatement(
|
||||
actions=["ssm:SendCommand"],
|
||||
resources=[
|
||||
# Allow on all EC2 instances
|
||||
f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
|
||||
# Required for the RunShellScript document
|
||||
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# Add Cognito permissions to instance role
|
||||
role.add_managed_policy(
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name(
|
||||
"AmazonCognitoReadOnly"
|
||||
)
|
||||
)
|
||||
|
||||
# EC2 Instance
|
||||
instance = ec2.Instance(
|
||||
self, "IptvUpdaterInstance",
|
||||
vpc=vpc,
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T2,
|
||||
ec2.InstanceSize.MICRO
|
||||
),
|
||||
machine_image=ec2.AmazonLinuxImage(
|
||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2
|
||||
),
|
||||
security_group=security_group,
|
||||
key_pair=key_pair,
|
||||
role=role
|
||||
)
|
||||
|
||||
# Create Elastic IP
|
||||
eip = ec2.CfnEIP(
|
||||
self, "IptvUpdaterEIP",
|
||||
domain="vpc",
|
||||
instance_id=instance.instance_id
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
|
||||
)
|
||||
|
||||
# Add Cognito User Pool
|
||||
user_pool = cognito.UserPool(
|
||||
self, "IptvUpdaterUserPool",
|
||||
user_pool_name="iptv-updater-users",
|
||||
self,
|
||||
"IptvManagerUserPool",
|
||||
user_pool_name="iptv-manager-users",
|
||||
self_sign_up_enabled=False, # Only admins can create users
|
||||
password_policy=cognito.PasswordPolicy(
|
||||
min_length=8,
|
||||
require_lowercase=True,
|
||||
require_digits=True,
|
||||
require_symbols=True,
|
||||
require_uppercase=True
|
||||
require_uppercase=True,
|
||||
),
|
||||
account_recovery=cognito.AccountRecovery.EMAIL_ONLY
|
||||
account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
|
||||
removal_policy=RemovalPolicy.DESTROY,
|
||||
)
|
||||
|
||||
# Add App Client with the correct callback URL
|
||||
client = user_pool.add_client("IptvUpdaterClient",
|
||||
client = user_pool.add_client(
|
||||
"IptvManagerClient",
|
||||
access_token_validity=Duration.minutes(60),
|
||||
id_token_validity=Duration.minutes(60),
|
||||
refresh_token_validity=Duration.days(1),
|
||||
auth_flows=cognito.AuthFlow(user_password=True),
|
||||
o_auth=cognito.OAuthSettings(
|
||||
flows=cognito.OAuthFlows(
|
||||
authorization_code_grant=True
|
||||
flows=cognito.OAuthFlows(implicit_code_grant=True)
|
||||
),
|
||||
scopes=[cognito.OAuthScope.OPENID],
|
||||
callback_urls=[
|
||||
"http://localhost:8000/auth/callback", # For local testing
|
||||
"https://*.amazonaws.com/auth/callback" # Will match EC2 public DNS
|
||||
]
|
||||
)
|
||||
prevent_user_existence_errors=True,
|
||||
generate_secret=True,
|
||||
enable_token_revocation=True,
|
||||
)
|
||||
|
||||
# Add domain for hosted UI
|
||||
domain = user_pool.add_domain("IptvUpdaterDomain",
|
||||
cognito_domain=cognito.CognitoDomainOptions(
|
||||
domain_prefix="iptv-updater"
|
||||
)
|
||||
domain = user_pool.add_domain(
|
||||
"IptvManagerDomain",
|
||||
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-manager"),
|
||||
)
|
||||
|
||||
# Read the userdata script with proper path resolution
|
||||
@@ -144,20 +156,152 @@ class IptvUpdaterStack(Stack):
|
||||
|
||||
# Creates a userdata object for Linux hosts
|
||||
userdata = ec2.UserData.for_linux()
|
||||
|
||||
# Add environment variables for acme.sh from parameters
|
||||
userdata.add_commands(
|
||||
f'export FREEDNS_User="{freedns_user}"',
|
||||
f'export FREEDNS_Password="{freedns_password}"',
|
||||
f'export DOMAIN_NAME="{domain_name}"',
|
||||
f'export REPO_URL="{repo_url}"',
|
||||
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
|
||||
)
|
||||
|
||||
# Adds one or more commands to the userdata object.
|
||||
userdata.add_commands(
|
||||
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment',
|
||||
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment'
|
||||
(
|
||||
f'echo "COGNITO_USER_POOL_ID='
|
||||
f'{user_pool.user_pool_id}" >> /etc/environment'
|
||||
),
|
||||
(
|
||||
f'echo "COGNITO_CLIENT_ID='
|
||||
f'{client.user_pool_client_id}" >> /etc/environment'
|
||||
),
|
||||
(
|
||||
f'echo "COGNITO_CLIENT_SECRET='
|
||||
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
|
||||
),
|
||||
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
|
||||
)
|
||||
userdata.add_commands(str(userdata_file, 'utf-8'))
|
||||
userdata.add_commands(str(userdata_file, "utf-8"))
|
||||
|
||||
# Create RDS Security Group
|
||||
rds_sg = ec2.SecurityGroup(
|
||||
self,
|
||||
"RdsSecurityGroup",
|
||||
vpc=vpc,
|
||||
description="Security group for RDS PostgreSQL",
|
||||
)
|
||||
rds_sg.add_ingress_rule(
|
||||
security_group,
|
||||
ec2.Port.tcp(5432),
|
||||
"Allow PostgreSQL access from EC2 instance",
|
||||
)
|
||||
|
||||
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
|
||||
db = rds.DatabaseInstance(
|
||||
self,
|
||||
"IptvManagerDB",
|
||||
engine=rds.DatabaseInstanceEngine.postgres(
|
||||
version=rds.PostgresEngineVersion.VER_13
|
||||
),
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
|
||||
),
|
||||
vpc=vpc,
|
||||
vpc_subnets=ec2.SubnetSelection(
|
||||
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED
|
||||
),
|
||||
security_groups=[rds_sg],
|
||||
allocated_storage=10,
|
||||
max_allocated_storage=10,
|
||||
database_name="iptv_manager",
|
||||
removal_policy=RemovalPolicy.DESTROY,
|
||||
deletion_protection=False,
|
||||
publicly_accessible=False, # Avoid public IPv4 charges
|
||||
)
|
||||
|
||||
# Add RDS permissions to instance role
|
||||
role.add_managed_policy(
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
|
||||
)
|
||||
|
||||
# Store DB connection info in SSM Parameter Store
|
||||
db_host_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBHostParam",
|
||||
parameter_name="/iptv-manager/DB_HOST",
|
||||
string_value=db.db_instance_endpoint_address,
|
||||
)
|
||||
db_name_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBNameParam",
|
||||
parameter_name="/iptv-manager/DB_NAME",
|
||||
string_value="iptv_manager",
|
||||
)
|
||||
db_user_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBUserParam",
|
||||
parameter_name="/iptv-manager/DB_USER",
|
||||
string_value=db.secret.secret_value_from_json("username").to_string(),
|
||||
)
|
||||
db_pass_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBPassParam",
|
||||
parameter_name="/iptv-manager/DB_PASSWORD",
|
||||
string_value=db.secret.secret_value_from_json("password").to_string(),
|
||||
)
|
||||
|
||||
# Add SSM read permissions to instance role
|
||||
role.add_managed_policy(
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
|
||||
)
|
||||
|
||||
# EC2 Instance (created after all dependencies are ready)
|
||||
instance = ec2.Instance(
|
||||
self,
|
||||
"IptvManagerInstance",
|
||||
vpc=vpc,
|
||||
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
||||
),
|
||||
machine_image=ec2.AmazonLinuxImage(
|
||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
||||
),
|
||||
security_group=security_group,
|
||||
key_pair=key_pair,
|
||||
role=role,
|
||||
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
||||
associate_public_ip_address=True,
|
||||
)
|
||||
|
||||
# Ensure instance depends on SSM parameters being created
|
||||
instance.node.add_dependency(db)
|
||||
instance.node.add_dependency(db_host_param)
|
||||
instance.node.add_dependency(db_name_param)
|
||||
instance.node.add_dependency(db_user_param)
|
||||
instance.node.add_dependency(db_pass_param)
|
||||
|
||||
# Option: 2: Create Elastic IP (not free tier compatible)
|
||||
# eip = ec2.CfnEIP(
|
||||
# self, "IptvManagerEIP",
|
||||
# domain="vpc",
|
||||
# instance_id=instance.instance_id
|
||||
# )
|
||||
|
||||
# Update instance with userdata
|
||||
instance.add_user_data(userdata.render())
|
||||
|
||||
# Outputs
|
||||
CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
|
||||
CfnOutput(self, "DBEndpoint", value=db.db_instance_endpoint_address)
|
||||
# Option: 1: Use EC2 instance public IP (free tier compatible)
|
||||
CfnOutput(self, "InstancePublicIP", value=instance.instance_public_ip)
|
||||
# Option: 2: Use EIP (not free tier compatible)
|
||||
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
|
||||
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
|
||||
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
|
||||
CfnOutput(self, "CognitoDomainUrl",
|
||||
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com"
|
||||
CfnOutput(
|
||||
self,
|
||||
"CognitoDomainUrl",
|
||||
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
|
||||
)
|
||||
@@ -1,29 +1,79 @@
|
||||
#!/bin/sh
|
||||
|
||||
yum update -y
|
||||
yum install -y python3-pip git
|
||||
amazon-linux-extras install nginx1
|
||||
# Update system and install required packages
|
||||
dnf update -y
|
||||
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx postgresql15.x86_64 awscli
|
||||
|
||||
pip3 install --upgrade pip
|
||||
pip3 install certbot certbot-nginx
|
||||
# Start and enable crond service
|
||||
systemctl start crond
|
||||
systemctl enable crond
|
||||
|
||||
cd /home/ec2-user
|
||||
|
||||
git clone https://git.fiorinis.com/Home/iptv-updater-aws.git
|
||||
cd iptv-updater-aws
|
||||
git clone ${REPO_URL}
|
||||
cd iptv-manager-service
|
||||
|
||||
pip3 install -r requirements.txt
|
||||
# Install Python packages with --ignore-installed to prevent conflicts with RPM packages
|
||||
pip3 install --ignore-installed -r requirements.txt
|
||||
|
||||
# Retrieve DB credentials from SSM Parameter Store with retries
|
||||
echo "Attempting to retrieve DB credentials from SSM..."
|
||||
for i in {1..30}; do
|
||||
DB_HOST=$(aws ssm get-parameter --name "/iptv-manager/DB_HOST" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
DB_NAME=$(aws ssm get-parameter --name "/iptv-manager/DB_NAME" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
DB_USER=$(aws ssm get-parameter --name "/iptv-manager/DB_USER" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
DB_PASSWORD=$(aws ssm get-parameter --name "/iptv-manager/DB_PASSWORD" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
|
||||
if [ -n "$DB_HOST" ] && [ -n "$DB_NAME" ] && [ -n "$DB_USER" ] && [ -n "$DB_PASSWORD" ]; then
|
||||
echo "Successfully retrieved all DB credentials"
|
||||
break
|
||||
fi
|
||||
|
||||
echo "Waiting for SSM parameters to be available... (attempt $i/30)"
|
||||
sleep 5
|
||||
done
|
||||
|
||||
if [ -z "$DB_HOST" ] || [ -z "$DB_NAME" ] || [ -z "$DB_USER" ] || [ -z "$DB_PASSWORD" ]; then
|
||||
echo "ERROR: Failed to retrieve all required DB credentials after 30 attempts"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DB_HOST
|
||||
export DB_NAME
|
||||
export DB_USER
|
||||
export DB_PASSWORD
|
||||
|
||||
# Set PGPASSWORD for psql to use
|
||||
export PGPASSWORD=$DB_PASSWORD
|
||||
|
||||
# Wait for PostgreSQL to be ready
|
||||
echo "Waiting for PostgreSQL to start..."
|
||||
until psql -h $DB_HOST -U $DB_USER -d postgres -c '\q'; do
|
||||
sleep 1
|
||||
done
|
||||
echo "PostgreSQL is ready."
|
||||
|
||||
# Create database if it does not exist
|
||||
DB_EXISTS=$(psql -h $DB_HOST -U $DB_USER -d postgres -tc "SELECT 1 FROM pg_database WHERE datname = '$DB_NAME';")
|
||||
if [ -z "$DB_EXISTS" ]; then
|
||||
echo "Creating database $DB_NAME..."
|
||||
psql -h $DB_HOST -U $DB_USER -d postgres -c "CREATE DATABASE $DB_NAME;"
|
||||
echo "Database $DB_NAME created."
|
||||
fi
|
||||
|
||||
# Run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Create systemd service file
|
||||
cat << 'EOF' > /etc/systemd/system/iptv-updater.service
|
||||
cat << 'EOF' > /etc/systemd/system/iptv-manager.service
|
||||
[Unit]
|
||||
Description=IPTV Updater Service
|
||||
Description=IPTV Manager Service
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=ec2-user
|
||||
WorkingDirectory=/home/ec2-user/iptv-updater-aws
|
||||
WorkingDirectory=/home/ec2-user/iptv-manager-service
|
||||
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
|
||||
EnvironmentFile=/etc/environment
|
||||
Restart=always
|
||||
@@ -32,17 +82,42 @@ Restart=always
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# Ensure root has a crontab before installing acme.sh
|
||||
crontab -u root -l >/dev/null 2>&1 || (echo "" | crontab -u root -)
|
||||
|
||||
# Install and configure acme.sh
|
||||
curl https://get.acme.sh | sh -s email="${LETSENCRYPT_EMAIL}"
|
||||
|
||||
# Configure acme.sh to use DNS API for FreeDNS
|
||||
. "/.acme.sh/acme.sh.env"
|
||||
"/.acme.sh"/acme.sh --issue --dns dns_freedns -d ${DOMAIN_NAME} -d *.${DOMAIN_NAME}
|
||||
sudo mkdir -p /etc/nginx/ssl
|
||||
"/.acme.sh"/acme.sh --install-cert -d ${DOMAIN_NAME} -d *.${DOMAIN_NAME} \
|
||||
--key-file /etc/nginx/ssl/${DOMAIN_NAME}.pem \
|
||||
--fullchain-file /etc/nginx/ssl/cert.pem \
|
||||
--reloadcmd "service nginx force-reload"
|
||||
|
||||
# Create nginx config
|
||||
cat << 'EOF' > /etc/nginx/conf.d/iptvUpdater.conf
|
||||
cat << EOF > /etc/nginx/conf.d/iptvManager.conf
|
||||
server {
|
||||
listen 80;
|
||||
server_name $HOSTNAME;
|
||||
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
||||
return 301 https://\$host\$request_uri;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl;
|
||||
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/cert.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/${DOMAIN_NAME}.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:8000;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header Host \$host;
|
||||
proxy_set_header X-Real-IP \$remote_addr;
|
||||
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto \$scheme;
|
||||
}
|
||||
}
|
||||
EOF
|
||||
@@ -50,5 +125,5 @@ EOF
|
||||
# Start nginx service
|
||||
systemctl enable nginx
|
||||
systemctl start nginx
|
||||
systemctl enable iptv-updater
|
||||
systemctl start iptv-updater
|
||||
systemctl enable iptv-manager
|
||||
systemctl start iptv-manager
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Install dependencies and deploy infrastructure
|
||||
npm install -g aws-cdk
|
||||
python3 -m pip install -r requirements.txt
|
||||
31
pyproject.toml
Normal file
31
pyproject.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
exclude = [
|
||||
"alembic/versions/*.py", # Auto-generated Alembic migration files
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"UP", # pyupgrade
|
||||
"W", # pycodestyle warnings
|
||||
]
|
||||
ignore = []
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"F811", # redefinition of unused name
|
||||
"F401", # unused import
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["app"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov=app --cov-report=term-missing --cov-fail-under=70"
|
||||
testpaths = ["tests"]
|
||||
28
pytest.ini
Normal file
28
pytest.ini
Normal file
@@ -0,0 +1,28 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_functions = test_*
|
||||
asyncio_mode = auto
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning:botocore.auth
|
||||
ignore:The 'app' shortcut is now deprecated:DeprecationWarning:httpx._client
|
||||
|
||||
# Coverage configuration
|
||||
addopts =
|
||||
--cov=app
|
||||
--cov-report=term-missing
|
||||
|
||||
# Test environment variables
|
||||
env =
|
||||
MOCK_AUTH=true
|
||||
DB_USER=test_user
|
||||
DB_PASSWORD=test_password
|
||||
DB_HOST=localhost
|
||||
DB_NAME=iptv_manager_test
|
||||
|
||||
# Test markers
|
||||
markers =
|
||||
slow: mark tests as slow running
|
||||
integration: integration tests
|
||||
unit: unit tests
|
||||
db: tests requiring database
|
||||
@@ -8,3 +8,15 @@ requests==2.31.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
boto3==1.28.0
|
||||
starlette>=0.27.0
|
||||
pyjwt==2.7.0
|
||||
sqlalchemy==2.0.23
|
||||
psycopg2-binary==2.9.9
|
||||
alembic==1.16.1
|
||||
pytest==8.1.1
|
||||
pytest-asyncio==0.23.6
|
||||
pytest-mock==3.12.0
|
||||
pytest-cov==4.1.0
|
||||
pytest-env==1.1.1
|
||||
httpx==0.27.0
|
||||
pre-commit
|
||||
apscheduler==3.10.4
|
||||
36
scripts/create_cognito_user.sh
Executable file
36
scripts/create_cognito_user.sh
Executable file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
echo "Usage: $0 USER_POOL_ID USERNAME PASSWORD [--admin]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
USER_POOL_ID=$1
|
||||
USERNAME=$2
|
||||
PASSWORD=$3
|
||||
ADMIN_FLAG=${4:-""}
|
||||
|
||||
# Create user with temporary password
|
||||
CREATE_CMD="aws cognito-idp admin-create-user --no-cli-pager \
|
||||
--user-pool-id \"$USER_POOL_ID\" \
|
||||
--username \"$USERNAME\" \
|
||||
--temporary-password \"TempPass123!\" \
|
||||
--output json > /dev/null 2>&1"
|
||||
|
||||
if [ "$ADMIN_FLAG" == "--admin" ]; then
|
||||
CREATE_CMD+=" --user-attributes Name=zoneinfo,Value=admin"
|
||||
fi
|
||||
|
||||
eval "$CREATE_CMD"
|
||||
|
||||
# Set permanent password
|
||||
aws cognito-idp admin-set-user-password --no-cli-pager \
|
||||
--user-pool-id "$USER_POOL_ID" \
|
||||
--username "$USERNAME" \
|
||||
--password "$PASSWORD" \
|
||||
--permanent \
|
||||
--output json > /dev/null 2>&1
|
||||
|
||||
echo "User $USERNAME created successfully"
|
||||
18
scripts/delete_cognito_user.sh
Executable file
18
scripts/delete_cognito_user.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 USER_POOL_ID USERNAME"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
USER_POOL_ID=$1
|
||||
USERNAME=$2
|
||||
|
||||
aws cognito-idp admin-delete-user --no-cli-pager \
|
||||
--user-pool-id "$USER_POOL_ID" \
|
||||
--username "$USERNAME" \
|
||||
--output json > /dev/null 2>&1
|
||||
|
||||
echo "User $USERNAME deleted successfully"
|
||||
43
scripts/deploy.sh
Executable file
43
scripts/deploy.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Load environment variables from .env file if it exists
|
||||
if [ -f ${PWD}/.env ]; then
|
||||
# Use set -a to automatically export all variables
|
||||
set -a
|
||||
source ${PWD}/.env
|
||||
set +a
|
||||
fi
|
||||
|
||||
# Check if required environment variables are set
|
||||
if [ -z "$FREEDNS_User" ] ||
|
||||
[ -z "$FREEDNS_Password" ] ||
|
||||
[ -z "$DOMAIN_NAME" ] ||
|
||||
[ -z "$SSH_PUBLIC_KEY" ] ||
|
||||
[ -z "$REPO_URL" ] ||
|
||||
[ -z "$LETSENCRYPT_EMAIL" ]; then
|
||||
echo "Error: FREEDNS_User, FREEDNS_Password, DOMAIN_NAME, SSH_PUBLIC_KEY, REPO_URL, and LETSENCRYPT_EMAIL must be set as environment variables or in a .env file."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Deploy infrastructure
|
||||
cdk deploy --app="python3 ${PWD}/app.py"
|
||||
|
||||
# Update application on running instances
|
||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||
--region us-east-2 \
|
||||
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||
"Name=instance-state-name,Values=running" \
|
||||
--query "Reservations[].Instances[].InstanceId" \
|
||||
--output text)
|
||||
|
||||
for INSTANCE_ID in $INSTANCE_IDS; do
|
||||
echo "Updating application on instance: $INSTANCE_ID"
|
||||
aws ssm send-command \
|
||||
--instance-ids "$INSTANCE_ID" \
|
||||
--document-name "AWS-RunShellScript" \
|
||||
--parameters '{"commands":["cd /home/ec2-user/iptv-manager-service && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-manager"]}' \
|
||||
--no-cli-pager \
|
||||
--no-paginate
|
||||
done
|
||||
|
||||
echo "Deployment and instance update complete"
|
||||
23
scripts/destroy.sh
Executable file
23
scripts/destroy.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Load environment variables from .env file if it exists
|
||||
if [ -f ${PWD}/.env ]; then
|
||||
# Use set -a to automatically export all variables
|
||||
set -a
|
||||
source ${PWD}/.env
|
||||
set +a
|
||||
fi
|
||||
|
||||
# Check if required environment variables are set
|
||||
if [ -z "$FREEDNS_User" ] ||
|
||||
[ -z "$FREEDNS_Password" ] ||
|
||||
[ -z "$DOMAIN_NAME" ] ||
|
||||
[ -z "$SSH_PUBLIC_KEY" ] ||
|
||||
[ -z "$REPO_URL" ] ||
|
||||
[ -z "$LETSENCRYPT_EMAIL" ]; then
|
||||
echo "Error: FREEDNS_User, FREEDNS_Password, DOMAIN_NAME, SSH_PUBLIC_KEY, REPO_URL, and LETSENCRYPT_EMAIL must be set as environment variables or in a .env file."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Destroy infrastructure
|
||||
cdk destroy --app="python3 ${PWD}/app.py" --force
|
||||
13
scripts/install.sh
Executable file
13
scripts/install.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Install dependencies and deploy infrastructure
|
||||
npm install -g aws-cdk
|
||||
python3 -m pip install -r requirements.txt
|
||||
|
||||
# Install and configure pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit install-hooks
|
||||
pre-commit autoupdate
|
||||
|
||||
# Verify pytest setup
|
||||
python3 -m pytest
|
||||
28
scripts/start_local_dev.sh
Executable file
28
scripts/start_local_dev.sh
Executable file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Start PostgreSQL
|
||||
docker-compose -f docker/docker-compose-db.yml up -d
|
||||
|
||||
# Set environment variables
|
||||
export MOCK_AUTH=true
|
||||
export DB_HOST=localhost
|
||||
export DB_USER=postgres
|
||||
export DB_PASSWORD=postgres
|
||||
export DB_NAME=iptv_manager
|
||||
|
||||
echo "Ensuring database $DB_NAME exists using conditional DDL..."
|
||||
PGPASSWORD=$DB_PASSWORD docker exec -i postgres psql -U $DB_USER <<< "SELECT 'CREATE DATABASE $DB_NAME' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec"
|
||||
echo "Database $DB_NAME check complete."
|
||||
|
||||
# Run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Start FastAPI
|
||||
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
|
||||
echo $! > iptv-manager.pid
|
||||
|
||||
echo "Services started:"
|
||||
echo "- PostgreSQL running on localhost:5432"
|
||||
echo "- FastAPI running on http://127.0.0.1:8000"
|
||||
echo "- Mock auth enabled (use token: testuser)"
|
||||
19
scripts/stop_local_dev.sh
Executable file
19
scripts/stop_local_dev.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Stop FastAPI
|
||||
if [ -f iptv-manager.pid ]; then
|
||||
kill $(cat iptv-manager.pid)
|
||||
rm iptv-manager.pid
|
||||
echo "Stopped FastAPI"
|
||||
fi
|
||||
|
||||
# Clean up mock auth and database environment variables
|
||||
unset MOCK_AUTH
|
||||
unset DB_USER
|
||||
unset DB_PASSWORD
|
||||
unset DB_HOST
|
||||
unset DB_NAME
|
||||
|
||||
# Stop PostgreSQL
|
||||
docker-compose -f docker/docker-compose-db.yml down
|
||||
echo "Stopped PostgreSQL"
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/auth/__init__.py
Normal file
0
tests/auth/__init__.py
Normal file
169
tests/auth/test_cognito.py
Normal file
169
tests/auth/test_cognito.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# Test constants
|
||||
TEST_CLIENT_ID = "test_client_id"
|
||||
TEST_CLIENT_SECRET = "test_client_secret"
|
||||
|
||||
# Patch constants before importing the module
|
||||
with (
|
||||
patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
|
||||
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
|
||||
):
|
||||
from app.auth.cognito import get_user_from_token, initiate_auth
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.constants import USER_ROLE_ATTRIBUTE
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_cognito_client():
|
||||
with patch("app.auth.cognito.cognito_client") as mock_client:
|
||||
# Setup mock client and exceptions
|
||||
mock_client.exceptions = MagicMock()
|
||||
mock_client.exceptions.NotAuthorizedException = type(
|
||||
"NotAuthorizedException", (Exception,), {}
|
||||
)
|
||||
mock_client.exceptions.UserNotFoundException = type(
|
||||
"UserNotFoundException", (Exception,), {}
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_initiate_auth_success(mock_cognito_client):
|
||||
# Mock successful authentication response
|
||||
mock_cognito_client.initiate_auth.return_value = {
|
||||
"AuthenticationResult": {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
}
|
||||
|
||||
result = initiate_auth("test_user", "test_pass")
|
||||
assert result == {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
def test_initiate_auth_with_secret_hash(mock_cognito_client):
|
||||
with patch(
|
||||
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
|
||||
) as mock_hash:
|
||||
mock_cognito_client.initiate_auth.return_value = {
|
||||
"AuthenticationResult": {"AccessToken": "token"}
|
||||
}
|
||||
|
||||
initiate_auth("test_user", "test_pass")
|
||||
|
||||
# Verify calculate_secret_hash was called
|
||||
mock_hash.assert_called_once_with(
|
||||
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
|
||||
)
|
||||
|
||||
# Verify SECRET_HASH was included in auth params
|
||||
call_args = mock_cognito_client.initiate_auth.call_args[1]
|
||||
assert "SECRET_HASH" in call_args["AuthParameters"]
|
||||
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
|
||||
|
||||
|
||||
def test_initiate_auth_not_authorized(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = (
|
||||
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("invalid_user", "wrong_pass")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid username or password"
|
||||
|
||||
|
||||
def test_initiate_auth_user_not_found(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = (
|
||||
mock_cognito_client.exceptions.UserNotFoundException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("nonexistent_user", "any_pass")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert exc_info.value.detail == "User not found"
|
||||
|
||||
|
||||
def test_initiate_auth_generic_error(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("test_user", "test_pass")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "An error occurred during authentication" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_user_from_token_success(mock_cognito_client):
|
||||
mock_response = {
|
||||
"Username": "test_user",
|
||||
"UserAttributes": [
|
||||
{"Name": "sub", "Value": "123"},
|
||||
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
|
||||
],
|
||||
}
|
||||
mock_cognito_client.get_user.return_value = mock_response
|
||||
|
||||
result = get_user_from_token("valid_token")
|
||||
|
||||
assert isinstance(result, CognitoUser)
|
||||
assert result.username == "test_user"
|
||||
assert set(result.roles) == {"admin", "user"}
|
||||
|
||||
|
||||
def test_get_user_from_token_no_roles(mock_cognito_client):
|
||||
mock_response = {
|
||||
"Username": "test_user",
|
||||
"UserAttributes": [{"Name": "sub", "Value": "123"}],
|
||||
}
|
||||
mock_cognito_client.get_user.return_value = mock_response
|
||||
|
||||
result = get_user_from_token("valid_token")
|
||||
|
||||
assert isinstance(result, CognitoUser)
|
||||
assert result.username == "test_user"
|
||||
assert result.roles == []
|
||||
|
||||
|
||||
def test_get_user_from_token_invalid_token(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = (
|
||||
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("invalid_token")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid or expired token."
|
||||
|
||||
|
||||
def test_get_user_from_token_user_not_found(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = (
|
||||
mock_cognito_client.exceptions.UserNotFoundException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("token_for_nonexistent_user")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "User not found or invalid token."
|
||||
|
||||
|
||||
def test_get_user_from_token_generic_error(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = Exception("Some error")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("test_token")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Token verification failed" in exc_info.value.detail
|
||||
205
tests/auth/test_dependencies.py
Normal file
205
tests/auth/test_dependencies.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
# Mock user for testing
|
||||
TEST_USER = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "user"],
|
||||
groups=["test_group"],
|
||||
)
|
||||
|
||||
|
||||
# Mock the underlying get_user_from_token function
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
if token == "valid_token":
|
||||
return TEST_USER
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
|
||||
# Mock endpoint for testing the require_roles decorator
|
||||
@require_roles("admin")
|
||||
def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
return {"message": "Success", "user": user.username}
|
||||
|
||||
|
||||
# Patch the get_user_from_token function for testing
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
|
||||
)
|
||||
|
||||
|
||||
# Test get_current_user dependency
|
||||
def test_get_current_user_success():
|
||||
user = get_current_user("valid_token")
|
||||
assert user == TEST_USER
|
||||
assert user.username == "testuser"
|
||||
assert user.roles == ["admin", "user"]
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
get_current_user("invalid_token")
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
# Test require_roles decorator
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_success():
|
||||
# Create test user with required role
|
||||
user = CognitoUser(
|
||||
username="testuser", email="test@example.com", roles=["admin"], groups=[]
|
||||
)
|
||||
|
||||
result = await mock_protected_endpoint(user=user)
|
||||
assert result == {"message": "Success", "user": "testuser"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_missing_role():
|
||||
# Create test user without required role
|
||||
user = CognitoUser(
|
||||
username="testuser", email="test@example.com", roles=["user"], groups=[]
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_protected_endpoint(user=user)
|
||||
assert exc.value.status_code == 403
|
||||
assert (
|
||||
exc.value.detail
|
||||
== "You do not have the required roles to access this endpoint."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_no_roles():
|
||||
# Create test user with no roles
|
||||
user = CognitoUser(
|
||||
username="testuser", email="test@example.com", roles=[], groups=[]
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_protected_endpoint(user=user)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_multiple_roles():
|
||||
# Test requiring multiple roles
|
||||
@require_roles("admin", "super_user")
|
||||
def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
return {"message": "Success"}
|
||||
|
||||
# User with all required roles
|
||||
user_with_roles = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "super_user", "user"],
|
||||
groups=[],
|
||||
)
|
||||
result = await mock_multi_role_endpoint(user=user_with_roles)
|
||||
assert result == {"message": "Success"}
|
||||
|
||||
# User missing one required role
|
||||
user_missing_role = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "user"],
|
||||
groups=[],
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_multi_role_endpoint(user=user_missing_role)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth2_scheme_configuration():
|
||||
# Verify that we have a properly configured OAuth2PasswordBearer instance
|
||||
assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
|
||||
|
||||
# Create a mock request with no Authorization header
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"headers": [],
|
||||
"method": "GET",
|
||||
"scheme": "http",
|
||||
"path": "/",
|
||||
"query_string": b"",
|
||||
"client": ("127.0.0.1", 8000),
|
||||
}
|
||||
)
|
||||
|
||||
# Test that the scheme raises 401 when no token is provided
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await oauth2_scheme(mock_request)
|
||||
assert exc.value.status_code == 401
|
||||
assert exc.value.detail == "Not authenticated"
|
||||
|
||||
|
||||
def test_mock_auth_import(monkeypatch):
|
||||
# Save original env var value
|
||||
original_value = os.environ.get("MOCK_AUTH")
|
||||
|
||||
try:
|
||||
# Set MOCK_AUTH to true
|
||||
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||
|
||||
# Reload the dependencies module to trigger the import condition
|
||||
import app.auth.dependencies
|
||||
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
# Verify that mock_get_user_from_token was imported
|
||||
from app.auth.dependencies import get_user_from_token
|
||||
|
||||
assert get_user_from_token.__module__ == "app.auth.mock_auth"
|
||||
|
||||
finally:
|
||||
# Restore original env var
|
||||
if original_value is None:
|
||||
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||
|
||||
# Reload again to restore original state
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
|
||||
def test_cognito_auth_import(monkeypatch):
|
||||
"""Test that cognito auth is imported when MOCK_AUTH=false (covers line 14)"""
|
||||
# Save original env var value
|
||||
original_value = os.environ.get("MOCK_AUTH")
|
||||
|
||||
try:
|
||||
# Set MOCK_AUTH to false
|
||||
monkeypatch.setenv("MOCK_AUTH", "false")
|
||||
|
||||
# Reload the dependencies module to trigger the import condition
|
||||
import app.auth.dependencies
|
||||
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
# Verify that get_user_from_token was imported from app.auth.cognito
|
||||
from app.auth.dependencies import get_user_from_token
|
||||
|
||||
assert get_user_from_token.__module__ == "app.auth.cognito"
|
||||
|
||||
finally:
|
||||
# Restore original env var
|
||||
if original_value is None:
|
||||
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||
|
||||
# Reload again to restore original state
|
||||
importlib.reload(app.auth.dependencies)
|
||||
41
tests/auth/test_mock_auth.py
Normal file
41
tests/auth/test_mock_auth.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
|
||||
def test_mock_get_user_from_token_success():
|
||||
"""Test successful token validation returns expected user"""
|
||||
user = mock_get_user_from_token("testuser")
|
||||
assert isinstance(user, CognitoUser)
|
||||
assert user.username == "testuser"
|
||||
assert user.roles == ["admin"]
|
||||
|
||||
|
||||
def test_mock_get_user_from_token_invalid():
|
||||
"""Test invalid token raises expected exception"""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_get_user_from_token("invalid_token")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
|
||||
|
||||
|
||||
def test_mock_initiate_auth():
|
||||
"""Test mock authentication returns expected token response"""
|
||||
result = mock_initiate_auth("any_user", "any_password")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["AccessToken"] == "testuser"
|
||||
assert result["ExpiresIn"] == 3600
|
||||
assert result["TokenType"] == "Bearer"
|
||||
|
||||
|
||||
def test_mock_initiate_auth_different_credentials():
|
||||
"""Test mock authentication works with any credentials"""
|
||||
result1 = mock_initiate_auth("user1", "pass1")
|
||||
result2 = mock_initiate_auth("user2", "pass2")
|
||||
|
||||
# Should return same mock token regardless of credentials
|
||||
assert result1 == result2
|
||||
0
tests/iptv/__init__.py
Normal file
0
tests/iptv/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
144
tests/models/test_db.py
Normal file
144
tests/models/test_db.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID
|
||||
|
||||
# --- Test SQLiteUUID Type ---
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_none():
|
||||
"""Test SQLiteUUID.process_bind_param with None returns None"""
|
||||
uuid_type = SQLiteUUID()
|
||||
assert uuid_type.process_bind_param(None, None) is None
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_valid_uuid():
|
||||
"""Test SQLiteUUID.process_bind_param with valid UUID returns string"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
assert uuid_type.process_bind_param(test_uuid, None) == str(test_uuid)
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_valid_string():
|
||||
"""Test SQLiteUUID.process_bind_param with valid UUID string returns string"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid_str = "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert uuid_type.process_bind_param(test_uuid_str, None) == test_uuid_str
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_invalid_string():
|
||||
"""Test SQLiteUUID.process_bind_param raises ValueError for invalid UUID"""
|
||||
uuid_type = SQLiteUUID()
|
||||
with pytest.raises(ValueError, match="Invalid UUID string format"):
|
||||
uuid_type.process_bind_param("invalid-uuid", None)
|
||||
|
||||
|
||||
def test_sqliteuuid_process_result_value_none():
|
||||
"""Test SQLiteUUID.process_result_value with None returns None"""
|
||||
uuid_type = SQLiteUUID()
|
||||
assert uuid_type.process_result_value(None, None) is None
|
||||
|
||||
|
||||
def test_sqliteuuid_process_result_value_valid_string():
|
||||
"""Test SQLiteUUID.process_result_value converts string to UUID"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
result = uuid_type.process_result_value(str(test_uuid), None)
|
||||
assert isinstance(result, uuid.UUID)
|
||||
assert result == test_uuid
|
||||
|
||||
|
||||
def test_sqliteuuid_process_result_value_uuid_object():
|
||||
"""Test SQLiteUUID.process_result_value: UUID object returns itself."""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
result = uuid_type.process_result_value(test_uuid, None)
|
||||
assert isinstance(result, uuid.UUID)
|
||||
assert result is test_uuid # Ensure it's the same object, not a new one
|
||||
|
||||
|
||||
def test_sqliteuuid_compare_values_none():
|
||||
"""Test SQLiteUUID.compare_values handles None values"""
|
||||
uuid_type = SQLiteUUID()
|
||||
assert uuid_type.compare_values(None, None) is True
|
||||
assert uuid_type.compare_values(None, uuid.uuid4()) is False
|
||||
assert uuid_type.compare_values(uuid.uuid4(), None) is False
|
||||
|
||||
|
||||
def test_sqliteuuid_compare_values_uuid():
|
||||
"""Test SQLiteUUID.compare_values compares UUIDs as strings"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
assert uuid_type.compare_values(test_uuid, test_uuid) is True
|
||||
assert uuid_type.compare_values(test_uuid, uuid.uuid4()) is False
|
||||
|
||||
|
||||
def test_sqlite_uuid_comparison():
|
||||
"""Test SQLiteUUID comparison functionality (moved from db_mocks.py)"""
|
||||
uuid_type = SQLiteUUID()
|
||||
|
||||
# Test equal UUIDs
|
||||
uuid1 = uuid.uuid4()
|
||||
uuid2 = uuid.UUID(str(uuid1))
|
||||
assert uuid_type.compare_values(uuid1, uuid2) is True
|
||||
|
||||
# Test UUID vs string
|
||||
assert uuid_type.compare_values(uuid1, str(uuid1)) is True
|
||||
|
||||
# Test None comparisons
|
||||
assert uuid_type.compare_values(None, None) is True
|
||||
assert uuid_type.compare_values(uuid1, None) is False
|
||||
assert uuid_type.compare_values(None, uuid1) is False
|
||||
|
||||
# Test different UUIDs
|
||||
uuid3 = uuid.uuid4()
|
||||
assert uuid_type.compare_values(uuid1, uuid3) is False
|
||||
|
||||
|
||||
def test_sqlite_uuid_binding():
|
||||
"""Test SQLiteUUID binding parameter handling (moved from db_mocks.py)"""
|
||||
uuid_type = SQLiteUUID()
|
||||
|
||||
# Test UUID object binding
|
||||
uuid_obj = uuid.uuid4()
|
||||
assert uuid_type.process_bind_param(uuid_obj, None) == str(uuid_obj)
|
||||
|
||||
# Test valid UUID string binding
|
||||
uuid_str = str(uuid.uuid4())
|
||||
assert uuid_type.process_bind_param(uuid_str, None) == uuid_str
|
||||
|
||||
# Test None handling
|
||||
assert uuid_type.process_bind_param(None, None) is None
|
||||
|
||||
# Test invalid UUID string
|
||||
with pytest.raises(ValueError):
|
||||
uuid_type.process_bind_param("invalid-uuid", None)
|
||||
|
||||
|
||||
# --- Test UUID Column Type Configuration ---
|
||||
|
||||
|
||||
def test_uuid_column_type_default():
|
||||
"""Test UUID_COLUMN_TYPE uses SQLiteUUID in test environment"""
|
||||
assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"MOCK_AUTH": "false"})
|
||||
def test_uuid_column_type_postgres():
|
||||
"""Test UUID_COLUMN_TYPE uses Postgres UUID when MOCK_AUTH=false"""
|
||||
# Need to re-import to get the patched environment
|
||||
from importlib import reload
|
||||
|
||||
from app import models
|
||||
|
||||
reload(models.db)
|
||||
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
|
||||
|
||||
from app.models.db import UUID_COLUMN_TYPE
|
||||
|
||||
assert isinstance(UUID_COLUMN_TYPE, PostgresUUID)
|
||||
0
tests/routers/__init__.py
Normal file
0
tests/routers/__init__.py
Normal file
43
tests/routers/mocks.py
Normal file
43
tests/routers/mocks.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
|
||||
|
||||
class MockScheduler:
|
||||
"""Base mock APScheduler instance"""
|
||||
|
||||
running = True
|
||||
start = Mock()
|
||||
shutdown = Mock()
|
||||
add_job = Mock()
|
||||
remove_job = Mock()
|
||||
get_job = Mock(return_value=None)
|
||||
|
||||
def __init__(self, running=True):
|
||||
self.running = running
|
||||
|
||||
|
||||
def create_trigger_mock(triggered_ref: dict) -> callable:
|
||||
"""Create a mock trigger function that updates a reference when called"""
|
||||
|
||||
def trigger_mock():
|
||||
triggered_ref["value"] = True
|
||||
|
||||
return trigger_mock
|
||||
|
||||
|
||||
async def mock_get_scheduler(
|
||||
request: Request, scheduler_class=MockScheduler, running=True, **kwargs
|
||||
) -> StreamScheduler:
|
||||
"""Mock dependency for get_scheduler with customization options"""
|
||||
scheduler = StreamScheduler()
|
||||
mock_scheduler = scheduler_class(running=running)
|
||||
|
||||
# Apply any additional attributes/methods
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_scheduler, key, value)
|
||||
|
||||
scheduler.scheduler = mock_scheduler
|
||||
return scheduler
|
||||
101
tests/routers/test_auth.py
Normal file
101
tests/routers/test_auth.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_auth():
|
||||
return {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_auth_no_refresh():
|
||||
return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
|
||||
|
||||
|
||||
def test_signin_success(mock_successful_auth):
|
||||
"""Test successful signin with all tokens"""
|
||||
with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
|
||||
response = client.post(
|
||||
"/auth/signin", json={"username": "testuser", "password": "testpass"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
assert data["id_token"] == "mock_id_token"
|
||||
assert data["refresh_token"] == "mock_refresh_token"
|
||||
assert data["token_type"] == "Bearer"
|
||||
|
||||
|
||||
def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
|
||||
"""Test successful signin without refresh token"""
|
||||
with patch(
|
||||
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
|
||||
):
|
||||
response = client.post(
|
||||
"/auth/signin", json={"username": "testuser", "password": "testpass"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
assert data["id_token"] == "mock_id_token"
|
||||
assert data["refresh_token"] is None
|
||||
assert data["token_type"] == "Bearer"
|
||||
|
||||
|
||||
def test_signin_invalid_input():
|
||||
"""Test signin with invalid input format"""
|
||||
# Missing password
|
||||
response = client.post("/auth/signin", json={"username": "testuser"})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Missing username
|
||||
response = client.post("/auth/signin", json={"password": "testpass"})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Empty payload
|
||||
response = client.post("/auth/signin", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_signin_auth_failure():
|
||||
"""Test signin with authentication failure"""
|
||||
with patch("app.routers.auth.initiate_auth") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
response = client.post(
|
||||
"/auth/signin", json={"username": "testuser", "password": "wrongpass"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert data["detail"] == "Invalid username or password"
|
||||
|
||||
|
||||
def test_signin_user_not_found():
|
||||
"""Test signin with non-existent user"""
|
||||
with patch("app.routers.auth.initiate_auth") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
response = client.post(
|
||||
"/auth/signin", json={"username": "nonexistent", "password": "testpass"}
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["detail"] == "User not found"
|
||||
1588
tests/routers/test_channels.py
Normal file
1588
tests/routers/test_channels.py
Normal file
File diff suppressed because it is too large
Load Diff
461
tests/routers/test_groups.py
Normal file
461
tests/routers/test_groups.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.utils.database import get_db
|
||||
|
||||
# Import mocks and fixtures
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import (
|
||||
MockChannelDB,
|
||||
MockGroup,
|
||||
create_mock_priorities_and_group,
|
||||
)
|
||||
|
||||
# --- Test Cases For Group Creation ---
|
||||
|
||||
|
||||
def test_create_group_success(db_session: Session, admin_user_client):
|
||||
group_data = {"name": "Test Group", "sort_order": 1}
|
||||
response = admin_user_client.post("/groups/", json=group_data)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Group"
|
||||
assert data["sort_order"] == 1
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert "updated_at" in data
|
||||
|
||||
# Verify in DB
|
||||
db_group = (
|
||||
db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first()
|
||||
)
|
||||
assert db_group is not None
|
||||
assert db_group.sort_order == 1
|
||||
|
||||
|
||||
def test_create_group_duplicate(db_session: Session, admin_user_client):
|
||||
# Create initial group
|
||||
initial_group = MockGroup(
|
||||
id=uuid.uuid4(),
|
||||
name="Duplicate Group",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(initial_group)
|
||||
db_session.commit()
|
||||
|
||||
# Attempt to create duplicate
|
||||
response = admin_user_client.post(
|
||||
"/groups/", json={"name": "Duplicate Group", "sort_order": 2}
|
||||
)
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_create_group_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client
|
||||
):
|
||||
response = non_admin_user_client.post(
|
||||
"/groups/", json={"name": "Forbidden Group", "sort_order": 1}
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Test Cases For Get Group ---
|
||||
|
||||
|
||||
def test_get_group_success(db_session: Session, admin_user_client):
|
||||
# Create a group first
|
||||
test_group = MockGroup(
|
||||
id=uuid.uuid4(),
|
||||
name="Get Me Group",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.get(f"/groups/{test_group.id}")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["id"] == str(test_group.id)
|
||||
assert data["name"] == "Get Me Group"
|
||||
assert data["sort_order"] == 1
|
||||
|
||||
|
||||
def test_get_group_not_found(db_session: Session, admin_user_client):
|
||||
random_uuid = uuid.uuid4()
|
||||
response = admin_user_client.get(f"/groups/{random_uuid}")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Group not found" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Test Cases For Update Group ---
|
||||
|
||||
|
||||
def test_update_group_success(db_session: Session, admin_user_client):
|
||||
# Create initial group
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Update Me",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
update_data = {"name": "Updated Name", "sort_order": 2}
|
||||
response = admin_user_client.put(f"/groups/{group_id}", json=update_data)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["sort_order"] == 2
|
||||
|
||||
# Verify in DB
|
||||
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||
assert db_group.name == "Updated Name"
|
||||
assert db_group.sort_order == 2
|
||||
|
||||
|
||||
def test_update_group_conflict(db_session: Session, admin_user_client):
|
||||
# Create two groups
|
||||
group1 = MockGroup(
|
||||
id=uuid.uuid4(),
|
||||
name="Group One",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
group2 = MockGroup(
|
||||
id=uuid.uuid4(),
|
||||
name="Group Two",
|
||||
sort_order=2,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add_all([group1, group2])
|
||||
db_session.commit()
|
||||
|
||||
# Try to rename group2 to conflict with group1
|
||||
response = admin_user_client.put(f"/groups/{group2.id}", json={"name": "Group One"})
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_group_not_found(db_session: Session, admin_user_client):
|
||||
random_uuid = uuid.uuid4()
|
||||
response = admin_user_client.put(
|
||||
f"/groups/{random_uuid}", json={"name": "Non-existent"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Group not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_group_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client, admin_user_client
|
||||
):
|
||||
# Create group with admin
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Admin Created",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
# Attempt update with non-admin
|
||||
response = non_admin_user_client.put(
|
||||
f"/groups/{group_id}", json={"name": "Non-Admin Update"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Test Cases For Delete Group ---
|
||||
|
||||
|
||||
def test_delete_all_groups_success(db_session, admin_user_client):
|
||||
"""Test reset groups endpoint"""
|
||||
# Create test data
|
||||
group1_id = create_mock_priorities_and_group(db_session, [], "Group A")
|
||||
group2_id = create_mock_priorities_and_group(db_session, [], "Group B")
|
||||
|
||||
# Add channel to group2
|
||||
channel_data = [
|
||||
{
|
||||
"group-title": "Group A",
|
||||
"tvg_id": "channel1.tv",
|
||||
"name": "Channel One",
|
||||
"url": ["http://test.com", "http://example.com"],
|
||||
}
|
||||
]
|
||||
admin_user_client.post("/channels/bulk-upload", json=channel_data)
|
||||
|
||||
# Reset groups
|
||||
response = admin_user_client.delete("/groups")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["deleted"] == 1 # Only group2 should be deleted
|
||||
assert response.json()["skipped"] == 1 # group1 has channels
|
||||
|
||||
# Verify group2 deleted, group1 remains
|
||||
assert (
|
||||
db_session.query(MockGroup).filter(MockGroup.id == group1_id).first()
|
||||
is not None
|
||||
)
|
||||
assert db_session.query(MockGroup).filter(MockGroup.id == group2_id).first() is None
|
||||
|
||||
|
||||
def test_delete_all_groups_forbidden_for_non_admin(db_session, non_admin_user_client):
|
||||
"""Test reset groups requires admin role"""
|
||||
response = non_admin_user_client.delete("/groups")
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_delete_group_success(db_session: Session, admin_user_client):
|
||||
# Create group
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Delete Me",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
# Verify exists before delete
|
||||
assert (
|
||||
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||
)
|
||||
|
||||
response = admin_user_client.delete(f"/groups/{group_id}")
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
# Verify deleted
|
||||
assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None
|
||||
|
||||
|
||||
def test_delete_group_with_channels_fails(db_session: Session, admin_user_client):
|
||||
# Create group with channel
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Group With Channels",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
|
||||
# Create channel in this group
|
||||
test_channel = MockChannelDB(
|
||||
id=uuid.uuid4(),
|
||||
tvg_id="channel1.tv",
|
||||
name="Channel 1",
|
||||
group_id=group_id,
|
||||
tvg_name="Channel1",
|
||||
tvg_logo="logo.png",
|
||||
)
|
||||
db_session.add(test_channel)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.delete(f"/groups/{group_id}")
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "existing channels" in response.json()["detail"]
|
||||
|
||||
# Verify group still exists
|
||||
assert (
|
||||
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||
)
|
||||
|
||||
|
||||
def test_delete_group_not_found(db_session: Session, admin_user_client):
|
||||
random_uuid = uuid.uuid4()
|
||||
response = admin_user_client.delete(f"/groups/{random_uuid}")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Group not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_group_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client, admin_user_client
|
||||
):
|
||||
# Create group with admin
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Admin Created",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
# Attempt delete with non-admin
|
||||
response = non_admin_user_client.delete(f"/groups/{group_id}")
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
# Verify group still exists
|
||||
assert (
|
||||
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||
)
|
||||
|
||||
|
||||
# --- Test Cases For List Groups ---
|
||||
|
||||
|
||||
def test_list_groups_empty(db_session: Session, admin_user_client):
|
||||
response = admin_user_client.get("/groups/")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_list_groups_with_data(db_session: Session, admin_user_client):
|
||||
# Create some groups
|
||||
groups = [
|
||||
MockGroup(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Group {i}",
|
||||
sort_order=i,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
db_session.add_all(groups)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.get("/groups/")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 3
|
||||
assert data[0]["sort_order"] == 0 # Should be sorted by sort_order
|
||||
assert data[1]["sort_order"] == 1
|
||||
assert data[2]["sort_order"] == 2
|
||||
|
||||
|
||||
# --- Test Cases For Sort Order Updates ---
|
||||
|
||||
|
||||
def test_update_group_sort_order_success(db_session: Session, admin_user_client):
|
||||
# Create group
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Sort Me",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["sort_order"] == 5
|
||||
|
||||
# Verify in DB
|
||||
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||
assert db_group.sort_order == 5
|
||||
|
||||
|
||||
def test_update_group_sort_order_not_found(db_session: Session, admin_user_client):
|
||||
"""Test that updating sort order for non-existent group returns 404"""
|
||||
random_uuid = uuid.uuid4()
|
||||
response = admin_user_client.put(
|
||||
f"/groups/{random_uuid}/sort", json={"sort_order": 5}
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Group not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client):
|
||||
# Create groups
|
||||
groups = [
|
||||
MockGroup(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Group {i}",
|
||||
sort_order=i,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
print(groups)
|
||||
db_session.add_all(groups)
|
||||
db_session.commit()
|
||||
|
||||
# Bulk update sort orders (reverse order)
|
||||
bulk_data = {
|
||||
"groups": [
|
||||
{"group_id": str(groups[0].id), "sort_order": 2},
|
||||
{"group_id": str(groups[1].id), "sort_order": 1},
|
||||
{"group_id": str(groups[2].id), "sort_order": 0},
|
||||
]
|
||||
}
|
||||
response = admin_user_client.post("/groups/reorder", json=bulk_data)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 3
|
||||
|
||||
# Create a dictionary for easy lookup of returned group data by ID
|
||||
returned_groups_map = {item["id"]: item for item in data}
|
||||
|
||||
# Verify each group has its expected new sort_order
|
||||
assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2
|
||||
assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1
|
||||
assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0
|
||||
|
||||
# Verify in DB
|
||||
db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all()
|
||||
assert db_groups[0].sort_order == 2
|
||||
assert db_groups[1].sort_order == 1
|
||||
assert db_groups[2].sort_order == 0
|
||||
|
||||
|
||||
def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client):
|
||||
# Create one group
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Valid Group",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add(test_group)
|
||||
db_session.commit()
|
||||
|
||||
# Try to update with invalid group
|
||||
bulk_data = {
|
||||
"groups": [
|
||||
{"group_id": str(group_id), "sort_order": 2},
|
||||
{"group_id": str(uuid.uuid4()), "sort_order": 1}, # Invalid group
|
||||
]
|
||||
}
|
||||
response = admin_user_client.post("/groups/reorder", json=bulk_data)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
# Verify original sort order unchanged
|
||||
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||
assert db_group.sort_order == 1
|
||||
261
tests/routers/test_playlist.py
Normal file
261
tests/routers/test_playlist.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
|
||||
# Import the router we're testing
|
||||
from app.routers.playlist import (
|
||||
ProcessStatus,
|
||||
ValidationProcessResponse,
|
||||
ValidationResultResponse,
|
||||
router,
|
||||
validation_processes,
|
||||
)
|
||||
from app.utils.database import get_db
|
||||
|
||||
# Import mocks and fixtures
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import MockChannelDB
|
||||
|
||||
# --- Test Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_manager():
|
||||
with patch("app.routers.playlist.StreamManager") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
# --- Test Cases For Stream Validation ---
|
||||
|
||||
|
||||
def test_start_stream_validation_success(
|
||||
db_session: Session, admin_user_client, mock_stream_manager
|
||||
):
|
||||
"""Test starting a stream validation process"""
|
||||
mock_instance = mock_stream_manager.return_value
|
||||
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
|
||||
|
||||
response = admin_user_client.post(
|
||||
"/playlist/validate-streams", json={"channel_id": "test-channel"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
data = response.json()
|
||||
assert "process_id" in data
|
||||
assert data["status"] == ProcessStatus.PENDING
|
||||
assert data["message"] == "Validation process started"
|
||||
|
||||
# Verify process was added to tracking
|
||||
process_id = data["process_id"]
|
||||
assert process_id in validation_processes
|
||||
# In test environment, background tasks run synchronously so status may be COMPLETED
|
||||
assert validation_processes[process_id]["status"] in [
|
||||
ProcessStatus.PENDING,
|
||||
ProcessStatus.COMPLETED,
|
||||
]
|
||||
assert validation_processes[process_id]["channel_id"] == "test-channel"
|
||||
|
||||
|
||||
def test_get_validation_status_pending(db_session: Session, admin_user_client):
|
||||
"""Test checking status of pending validation"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.PENDING,
|
||||
"channel_id": "test-channel",
|
||||
}
|
||||
|
||||
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["process_id"] == process_id
|
||||
assert data["status"] == ProcessStatus.PENDING
|
||||
assert data["working_streams"] is None
|
||||
assert data["error"] is None
|
||||
|
||||
|
||||
def test_get_validation_status_completed(db_session: Session, admin_user_client):
|
||||
"""Test checking status of completed validation"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.COMPLETED,
|
||||
"channel_id": "test-channel",
|
||||
"result": {
|
||||
"working_streams": [
|
||||
{"channel_id": "test-channel", "stream_url": "http://valid.stream.url"}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["process_id"] == process_id
|
||||
assert data["status"] == ProcessStatus.COMPLETED
|
||||
assert len(data["working_streams"]) == 1
|
||||
assert data["working_streams"][0]["channel_id"] == "test-channel"
|
||||
assert data["working_streams"][0]["stream_url"] == "http://valid.stream.url"
|
||||
assert data["error"] is None
|
||||
|
||||
|
||||
def test_get_validation_status_completed_with_error(
|
||||
db_session: Session, admin_user_client
|
||||
):
|
||||
"""Test checking status of completed validation with error"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.COMPLETED,
|
||||
"channel_id": "test-channel",
|
||||
"error": "No working streams found for channel test-channel",
|
||||
}
|
||||
|
||||
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["process_id"] == process_id
|
||||
assert data["status"] == ProcessStatus.COMPLETED
|
||||
assert data["working_streams"] is None
|
||||
assert data["error"] == "No working streams found for channel test-channel"
|
||||
|
||||
|
||||
def test_get_validation_status_failed(db_session: Session, admin_user_client):
|
||||
"""Test checking status of failed validation"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.FAILED,
|
||||
"channel_id": "test-channel",
|
||||
"error": "Validation error occurred",
|
||||
}
|
||||
|
||||
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["process_id"] == process_id
|
||||
assert data["status"] == ProcessStatus.FAILED
|
||||
assert data["working_streams"] is None
|
||||
assert data["error"] == "Validation error occurred"
|
||||
|
||||
|
||||
def test_get_validation_status_not_found(db_session: Session, admin_user_client):
|
||||
"""Test checking status of non-existent process"""
|
||||
random_uuid = str(uuid.uuid4())
|
||||
response = admin_user_client.get(f"/playlist/validate-streams/{random_uuid}")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Process not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_run_stream_validation_success(mock_stream_manager, db_session):
|
||||
"""Test the background validation task success case"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.PENDING,
|
||||
"channel_id": "test-channel",
|
||||
}
|
||||
|
||||
mock_instance = mock_stream_manager.return_value
|
||||
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
|
||||
|
||||
from app.routers.playlist import run_stream_validation
|
||||
|
||||
run_stream_validation(process_id, "test-channel", db_session)
|
||||
|
||||
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
|
||||
assert len(validation_processes[process_id]["result"]["working_streams"]) == 1
|
||||
assert (
|
||||
validation_processes[process_id]["result"]["working_streams"][0].channel_id
|
||||
== "test-channel"
|
||||
)
|
||||
assert (
|
||||
validation_processes[process_id]["result"]["working_streams"][0].stream_url
|
||||
== "http://valid.stream.url"
|
||||
)
|
||||
|
||||
|
||||
def test_run_stream_validation_failure(mock_stream_manager, db_session):
|
||||
"""Test the background validation task failure case"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.PENDING,
|
||||
"channel_id": "test-channel",
|
||||
}
|
||||
|
||||
mock_instance = mock_stream_manager.return_value
|
||||
mock_instance.validate_and_select_stream.return_value = None
|
||||
|
||||
from app.routers.playlist import run_stream_validation
|
||||
|
||||
run_stream_validation(process_id, "test-channel", db_session)
|
||||
|
||||
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
|
||||
assert "error" in validation_processes[process_id]
|
||||
assert "No working streams found" in validation_processes[process_id]["error"]
|
||||
|
||||
|
||||
def test_run_stream_validation_exception(mock_stream_manager, db_session):
|
||||
"""Test the background validation task exception case"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.PENDING,
|
||||
"channel_id": "test-channel",
|
||||
}
|
||||
|
||||
mock_instance = mock_stream_manager.return_value
|
||||
mock_instance.validate_and_select_stream.side_effect = Exception("Test error")
|
||||
|
||||
from app.routers.playlist import run_stream_validation
|
||||
|
||||
run_stream_validation(process_id, "test-channel", db_session)
|
||||
|
||||
assert validation_processes[process_id]["status"] == ProcessStatus.FAILED
|
||||
assert "error" in validation_processes[process_id]
|
||||
assert "Test error" in validation_processes[process_id]["error"]
|
||||
|
||||
|
||||
def test_start_stream_validation_no_channel_id(
|
||||
db_session: Session, admin_user_client, mock_stream_manager
|
||||
):
|
||||
"""Test starting validation without channel_id"""
|
||||
response = admin_user_client.post("/playlist/validate-streams", json={})
|
||||
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
data = response.json()
|
||||
assert "process_id" in data
|
||||
assert data["status"] == ProcessStatus.PENDING
|
||||
|
||||
# Verify process was added to tracking
|
||||
process_id = data["process_id"]
|
||||
assert process_id in validation_processes
|
||||
assert validation_processes[process_id]["status"] in [
|
||||
ProcessStatus.PENDING,
|
||||
ProcessStatus.COMPLETED,
|
||||
]
|
||||
assert validation_processes[process_id]["channel_id"] is None
|
||||
assert "not yet implemented" in validation_processes[process_id].get("error", "")
|
||||
|
||||
|
||||
def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session):
|
||||
"""Test background validation without channel_id"""
|
||||
process_id = str(uuid.uuid4())
|
||||
validation_processes[process_id] = {"status": ProcessStatus.PENDING}
|
||||
|
||||
from app.routers.playlist import run_stream_validation
|
||||
|
||||
run_stream_validation(process_id, None, db_session)
|
||||
|
||||
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
|
||||
assert "error" in validation_processes[process_id]
|
||||
assert "not yet implemented" in validation_processes[process_id]["error"]
|
||||
241
tests/routers/test_priorities.py
Normal file
241
tests/routers/test_priorities.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.routers.priorities import router as priorities_router
|
||||
|
||||
# Import fixtures and mocks
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import (
|
||||
MockChannelDB,
|
||||
MockChannelURL,
|
||||
MockGroup,
|
||||
MockPriority,
|
||||
create_mock_priorities_and_group,
|
||||
)
|
||||
|
||||
# --- Test Cases For Priority Creation ---
|
||||
|
||||
|
||||
def test_create_priority_success(db_session: Session, admin_user_client):
|
||||
priority_data = {"id": 100, "description": "Test Priority"}
|
||||
response = admin_user_client.post("/priorities/", json=priority_data)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["id"] == 100
|
||||
assert data["description"] == "Test Priority"
|
||||
|
||||
# Verify in DB
|
||||
db_priority = db_session.get(MockPriority, 100)
|
||||
assert db_priority is not None
|
||||
assert db_priority.description == "Test Priority"
|
||||
|
||||
|
||||
def test_create_priority_duplicate(db_session: Session, admin_user_client):
|
||||
# Create initial priority
|
||||
priority_data = {"id": 100, "description": "Original Priority"}
|
||||
response1 = admin_user_client.post("/priorities/", json=priority_data)
|
||||
assert response1.status_code == status.HTTP_201_CREATED
|
||||
|
||||
# Attempt to create with same ID
|
||||
response2 = admin_user_client.post("/priorities/", json=priority_data)
|
||||
assert response2.status_code == status.HTTP_409_CONFLICT
|
||||
assert "already exists" in response2.json()["detail"]
|
||||
|
||||
|
||||
def test_create_priority_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client
|
||||
):
|
||||
priority_data = {"id": 100, "description": "Test Priority"}
|
||||
response = non_admin_user_client.post("/priorities/", json=priority_data)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Test Cases For List Priorities ---
|
||||
|
||||
|
||||
def test_list_priorities_empty(db_session: Session, admin_user_client):
|
||||
response = admin_user_client.get("/priorities/")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_list_priorities_with_data(db_session: Session, admin_user_client):
|
||||
# Create some test priorities
|
||||
priorities = [
|
||||
MockPriority(id=100, description="High"),
|
||||
MockPriority(id=200, description="Medium"),
|
||||
MockPriority(id=300, description="Low"),
|
||||
]
|
||||
for priority in priorities:
|
||||
db_session.add(priority)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.get("/priorities/")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 3
|
||||
assert data[0]["id"] == 100
|
||||
assert data[0]["description"] == "High"
|
||||
assert data[1]["id"] == 200
|
||||
assert data[1]["description"] == "Medium"
|
||||
assert data[2]["id"] == 300
|
||||
assert data[2]["description"] == "Low"
|
||||
|
||||
|
||||
def test_list_priorities_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client
|
||||
):
|
||||
response = non_admin_user_client.get("/priorities/")
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Test Cases For Get Priority ---
|
||||
|
||||
|
||||
def test_get_priority_success(db_session: Session, admin_user_client):
|
||||
# Create a test priority
|
||||
priority = MockPriority(id=100, description="Test Priority")
|
||||
db_session.add(priority)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.get("/priorities/100")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["id"] == 100
|
||||
assert data["description"] == "Test Priority"
|
||||
|
||||
|
||||
def test_get_priority_not_found(db_session: Session, admin_user_client):
|
||||
response = admin_user_client.get("/priorities/999")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Priority not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_get_priority_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client
|
||||
):
|
||||
response = non_admin_user_client.get("/priorities/100")
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Test Cases For Delete Priority ---
|
||||
|
||||
|
||||
def test_delete_all_priorities_success(db_session, admin_user_client):
|
||||
"""Test reset priorities endpoint"""
|
||||
# Create test data
|
||||
priorities = [(100, "High"), (200, "Medium"), (300, "Low")]
|
||||
for id, desc in priorities:
|
||||
db_session.add(MockPriority(id=id, description=desc))
|
||||
db_session.commit()
|
||||
|
||||
# Create channel using priority 100
|
||||
create_mock_priorities_and_group(db_session, [], "Test Group")
|
||||
channel_data = [
|
||||
{
|
||||
"group-title": "Test Group",
|
||||
"tvg_id": "test.tv",
|
||||
"name": "Test Channel",
|
||||
"urls": ["http://test.com"],
|
||||
}
|
||||
]
|
||||
admin_user_client.post("/channels/bulk-upload", json=channel_data)
|
||||
|
||||
# Delete all priorities
|
||||
response = admin_user_client.delete("/priorities")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["deleted"] == 2 # Medium and Low priorities
|
||||
assert response.json()["skipped"] == 1 # High priority is in use
|
||||
|
||||
# Verify only priority 100 remains
|
||||
priorities = db_session.query(MockPriority).all()
|
||||
assert len(priorities) == 1
|
||||
assert priorities[0].id == 100
|
||||
|
||||
|
||||
def test_reset_priorities_forbidden_for_non_admin(db_session, non_admin_user_client):
|
||||
"""Test reset priorities requires admin role"""
|
||||
response = non_admin_user_client.delete("/priorities")
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_delete_priority_success(db_session: Session, admin_user_client):
|
||||
# Create a test priority
|
||||
priority = MockPriority(id=100, description="To Delete")
|
||||
db_session.add(priority)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.delete("/priorities/100")
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
# Verify priority is gone from DB
|
||||
db_priority = db_session.get(MockPriority, 100)
|
||||
assert db_priority is None
|
||||
|
||||
|
||||
def test_delete_priority_not_found(db_session: Session, admin_user_client):
|
||||
response = admin_user_client.delete("/priorities/999")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Priority not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_priority_in_use(db_session: Session, admin_user_client):
|
||||
# Create a priority and a channel URL using it
|
||||
priority = MockPriority(id=100, description="In Use")
|
||||
group_id = uuid.uuid4()
|
||||
test_group = MockGroup(
|
||||
id=group_id,
|
||||
name="Group With Channels",
|
||||
sort_order=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db_session.add_all([priority, test_group])
|
||||
db_session.commit()
|
||||
|
||||
# Create a channel first
|
||||
channel = MockChannelDB(
|
||||
name="Test Channel",
|
||||
tvg_id="test.tv",
|
||||
tvg_name="Test",
|
||||
tvg_logo="test.png",
|
||||
group_id=group_id,
|
||||
)
|
||||
db_session.add(channel)
|
||||
db_session.commit()
|
||||
|
||||
# Create URL associated with the channel and priority
|
||||
channel_url = MockChannelURL(
|
||||
url="http://test.com",
|
||||
priority_id=100,
|
||||
in_use=True,
|
||||
channel_id=channel.id, # Add the channel_id
|
||||
)
|
||||
db_session.add(channel_url)
|
||||
db_session.commit()
|
||||
|
||||
response = admin_user_client.delete("/priorities/100")
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
assert "in use by channel URLs" in response.json()["detail"]
|
||||
|
||||
# Verify priority still exists
|
||||
db_priority = db_session.get(MockPriority, 100)
|
||||
assert db_priority is not None
|
||||
|
||||
|
||||
def test_delete_priority_forbidden_for_non_admin(
|
||||
db_session: Session, non_admin_user_client
|
||||
):
|
||||
response = non_admin_user_client.delete("/priorities/100")
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
287
tests/routers/test_scheduler.py
Normal file
287
tests/routers/test_scheduler.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
from app.routers.scheduler import get_scheduler
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.utils.database import get_db
|
||||
from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import mock_get_db
|
||||
|
||||
# Scheduler Health Check Tests
|
||||
|
||||
|
||||
def test_scheduler_health_success(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for successful scheduler health check when accessed by an admin user.
|
||||
It mocks the scheduler to be running and have a next scheduled job.
|
||||
"""
|
||||
|
||||
# Define the expected next run time for the scheduler job.
|
||||
next_run = datetime.now(timezone.utc)
|
||||
|
||||
# Create a mock job object that simulates an APScheduler job.
|
||||
mock_job = Mock()
|
||||
mock_job.next_run_time = next_run
|
||||
|
||||
# Mock the `get_job` method to return our mock_job for a specific ID.
|
||||
def mock_get_job(job_id):
|
||||
if job_id == "daily_stream_validation":
|
||||
return mock_job
|
||||
return None
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=True,
|
||||
get_job=Mock(side_effect=mock_get_job), # Use the custom mock_get_job
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and content.
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "running"
|
||||
assert data["next_run"] == str(next_run)
|
||||
|
||||
|
||||
def test_scheduler_health_stopped(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for scheduler health check when the scheduler is in a stopped state.
|
||||
Ensures the API returns the correct status and no next run time.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency,
|
||||
# simulating a stopped scheduler.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=False,
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and content.
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "stopped"
|
||||
assert data["next_run"] is None
|
||||
|
||||
|
||||
def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case to ensure that non-admin users are forbidden from accessing
|
||||
the scheduler health endpoint.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=False,
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
non_admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = non_admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_scheduler_health_check_exception(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for handling exceptions during the scheduler health check.
|
||||
Ensures the API returns a 500 Internal Server Error when an exception occurs.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency that raises an exception.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request, running=True, get_job=Mock(side_effect=Exception("Test exception"))
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Failed to check scheduler health" in response.json()["detail"]
|
||||
|
||||
|
||||
# Scheduler Trigger Tests
|
||||
|
||||
|
||||
def test_trigger_validation_success(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for successful manual triggering
|
||||
of stream validation by an admin user.
|
||||
It verifies that the trigger method is called and
|
||||
the API returns a 202 Accepted status.
|
||||
"""
|
||||
# Use a mutable reference to check if the trigger method was called.
|
||||
triggered_ref = {"value": False}
|
||||
|
||||
# Initialize a custom mock scheduler.
|
||||
custom_scheduler = MockScheduler(running=True)
|
||||
custom_scheduler.get_job = Mock(return_value=None)
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency,
|
||||
# overriding `trigger_manual_validation`.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
scheduler = await mock_get_scheduler(
|
||||
request,
|
||||
running=True,
|
||||
)
|
||||
|
||||
# Replace the actual trigger method with our mock to track calls.
|
||||
scheduler.trigger_manual_validation = create_trigger_mock(
|
||||
triggered_ref=triggered_ref
|
||||
)
|
||||
|
||||
return scheduler
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to trigger stream validation.
|
||||
response = admin_user_client.post("/scheduler/trigger")
|
||||
|
||||
# Assert the response status code, message, and that the trigger was called.
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert response.json()["message"] == "Stream validation triggered"
|
||||
assert triggered_ref["value"] is True
|
||||
|
||||
|
||||
def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case to ensure that non-admin users are
|
||||
forbidden from manually triggering stream validation.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=True,
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
non_admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to trigger stream validation.
|
||||
response = non_admin_user_client.post("/scheduler/trigger")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_scheduler_initialized_in_app_state(admin_user_client):
|
||||
"""
|
||||
Test case for when the scheduler is initialized in the app state but its internal
|
||||
scheduler attribute is not set, which should still allow health check.
|
||||
"""
|
||||
scheduler = StreamScheduler()
|
||||
|
||||
# Set the scheduler instance in the test client's app state.
|
||||
admin_user_client.app.state.scheduler = scheduler
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override only get_db, allowing the real get_scheduler to be tested.
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code.
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
def test_scheduler_not_initialized_in_app_state(admin_user_client):
|
||||
"""
|
||||
Test case for when the scheduler is not properly initialized in the app state.
|
||||
This simulates a scenario where the internal scheduler attribute is missing,
|
||||
leading to a 500 Internal Server Error on health check.
|
||||
"""
|
||||
scheduler = StreamScheduler()
|
||||
del (
|
||||
scheduler.scheduler
|
||||
) # Simulate uninitialized scheduler by deleting the attribute
|
||||
|
||||
# Set the scheduler instance in the test client's app state.
|
||||
admin_user_client.app.state.scheduler = scheduler
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override only get_db, allowing the real get_scheduler to be tested.
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Scheduler not initialized" in response.json()["detail"]
|
||||
81
tests/test_main.py
Normal file
81
tests/test_main.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app, lifespan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client for FastAPI app"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_root_endpoint(client):
|
||||
"""Test root endpoint returns expected message"""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "IPTV Manager API"}
|
||||
|
||||
|
||||
def test_openapi_schema_generation(client):
|
||||
"""Test OpenAPI schema is properly generated"""
|
||||
# First call - generate schema
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
schema = response.json()
|
||||
assert schema["openapi"] == "3.1.0"
|
||||
assert "securitySchemes" in schema["components"]
|
||||
assert "Bearer" in schema["components"]["securitySchemes"]
|
||||
|
||||
# Test empty components initialization
|
||||
with patch("app.main.get_openapi", return_value={"info": {}}):
|
||||
# Clear cached schema
|
||||
app.openapi_schema = None
|
||||
# Get schema with empty response
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
schema = response.json()
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
|
||||
|
||||
def test_openapi_schema_caching(mocker):
|
||||
"""Test OpenAPI schema caching behavior"""
|
||||
# Clear any existing schema
|
||||
app.openapi_schema = None
|
||||
|
||||
# Mock get_openapi to return test schema
|
||||
mock_schema = {"test": "schema"}
|
||||
mocker.patch("app.main.get_openapi", return_value=mock_schema)
|
||||
|
||||
# First call - should call get_openapi
|
||||
schema = app.openapi()
|
||||
assert schema == mock_schema
|
||||
assert app.openapi_schema == mock_schema
|
||||
|
||||
# Second call - should return cached schema
|
||||
with patch("app.main.get_openapi") as mock_get_openapi:
|
||||
schema = app.openapi()
|
||||
assert schema == mock_schema
|
||||
mock_get_openapi.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_init_db(mocker):
|
||||
"""Test lifespan manager initializes database"""
|
||||
mock_init_db = mocker.patch("app.main.init_db")
|
||||
async with lifespan(app):
|
||||
pass # Just enter/exit context
|
||||
mock_init_db.assert_called_once()
|
||||
|
||||
|
||||
def test_router_inclusion():
|
||||
"""Test all routers are properly included"""
|
||||
route_paths = {route.path for route in app.routes}
|
||||
assert "/" in route_paths
|
||||
assert any(path.startswith("/auth") for path in route_paths)
|
||||
assert any(path.startswith("/channels") for path in route_paths)
|
||||
assert any(path.startswith("/playlist") for path in route_paths)
|
||||
assert any(path.startswith("/priorities") for path in route_paths)
|
||||
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
82
tests/utils/auth_test_fixtures.py
Normal file
82
tests/utils/auth_test_fixtures.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.models.auth import CognitoUser
|
||||
from app.routers.channels import router as channels_router
|
||||
from app.routers.groups import router as groups_router
|
||||
from app.routers.playlist import router as playlist_router
|
||||
from app.routers.priorities import router as priorities_router
|
||||
from app.utils.database import get_db
|
||||
from tests.utils.db_mocks import (
|
||||
MockBase,
|
||||
MockChannelDB,
|
||||
MockChannelURL,
|
||||
MockPriority,
|
||||
engine_mock,
|
||||
mock_get_db,
|
||||
)
|
||||
from tests.utils.db_mocks import session_mock as TestingSessionLocal
|
||||
|
||||
|
||||
def mock_get_current_user_admin():
|
||||
return CognitoUser(
|
||||
username="testadmin",
|
||||
email="testadmin@example.com",
|
||||
roles=["admin"],
|
||||
user_status="CONFIRMED",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
def mock_get_current_user_non_admin():
|
||||
return CognitoUser(
|
||||
username="testuser",
|
||||
email="testuser@example.com",
|
||||
roles=["user"], # Or any role other than admin
|
||||
user_status="CONFIRMED",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
# Create tables for each test function
|
||||
MockBase.metadata.create_all(bind=engine_mock)
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
# Drop tables after each test function
|
||||
MockBase.metadata.drop_all(bind=engine_mock)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def admin_user_client(db_session: Session):
|
||||
"""Yields a TestClient configured with an admin user."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(channels_router)
|
||||
test_app.include_router(priorities_router)
|
||||
test_app.include_router(playlist_router)
|
||||
test_app.include_router(groups_router)
|
||||
test_app.dependency_overrides[get_db] = mock_get_db
|
||||
test_app.dependency_overrides[get_current_user] = mock_get_current_user_admin
|
||||
with TestClient(test_app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def non_admin_user_client(db_session: Session):
|
||||
"""Yields a TestClient configured with a non-admin user."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(channels_router)
|
||||
test_app.include_router(priorities_router)
|
||||
test_app.include_router(playlist_router)
|
||||
test_app.include_router(groups_router)
|
||||
test_app.dependency_overrides[get_db] = mock_get_db
|
||||
test_app.dependency_overrides[get_current_user] = mock_get_current_user_non_admin
|
||||
with TestClient(test_app) as test_client:
|
||||
yield test_client
|
||||
149
tests/utils/db_mocks.py
Normal file
149
tests/utils/db_mocks.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
create_engine,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
|
||||
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
|
||||
|
||||
# Create a mock-specific Base class for testing
|
||||
MockBase = declarative_base()
|
||||
|
||||
|
||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||
class MockPriority(MockBase):
|
||||
__tablename__ = "priorities"
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class MockGroup(MockBase):
|
||||
__tablename__ = "groups"
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
sort_order = Column(Integer, nullable=False, default=0)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
channels = relationship("MockChannelDB", back_populates="group")
|
||||
|
||||
|
||||
class MockChannelDB(MockBase):
|
||||
__tablename__ = "channels"
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
tvg_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
|
||||
tvg_name = Column(String)
|
||||
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
group = relationship("MockGroup", back_populates="channels")
|
||||
urls = relationship(
|
||||
"MockChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class MockChannelURL(MockBase):
|
||||
__tablename__ = "channels_urls"
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(
|
||||
UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
channel = relationship("MockChannelDB", back_populates="urls")
|
||||
|
||||
|
||||
def create_mock_priorities_and_group(db_session, priorities, group_name):
|
||||
"""Create mock priorities and group for testing purposes.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session object
|
||||
priorities: List of (id, description) tuples for priorities to create
|
||||
group_name: Name for the new mock group
|
||||
|
||||
Returns:
|
||||
UUID: The ID of the created group
|
||||
"""
|
||||
# Create priorities
|
||||
priority_objects = [
|
||||
MockPriority(id=priority_id, description=description)
|
||||
for priority_id, description in priorities
|
||||
]
|
||||
|
||||
# Create group
|
||||
group = MockGroup(name=group_name)
|
||||
db_session.add_all(priority_objects + [group])
|
||||
db_session.commit()
|
||||
|
||||
return group.id
|
||||
|
||||
|
||||
# Create test engine
|
||||
engine_mock = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# Create test session
|
||||
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
|
||||
|
||||
|
||||
# Mock the actual database functions
|
||||
def mock_get_db():
|
||||
db = session_mock()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Fixture for mocking environment variables"""
|
||||
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||
monkeypatch.setenv("DB_USER", "testuser")
|
||||
monkeypatch.setenv("DB_PASSWORD", "testpass")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_NAME", "testdb")
|
||||
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssm():
|
||||
"""Fixture for mocking boto3 SSM client"""
|
||||
with patch("boto3.client") as mock_client:
|
||||
mock_ssm = MagicMock()
|
||||
mock_client.return_value = mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
|
||||
yield mock_ssm
|
||||
309
tests/utils/test_check_streams.py
Normal file
309
tests/utils/test_check_streams.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||
|
||||
from app.utils.check_streams import StreamValidator, main
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validator():
|
||||
"""Create a StreamValidator instance for testing"""
|
||||
return StreamValidator(timeout=1)
|
||||
|
||||
|
||||
def test_validator_init():
|
||||
"""Test StreamValidator initialization with default and custom values"""
|
||||
# Test with default user agent
|
||||
validator = StreamValidator()
|
||||
assert validator.timeout == 10
|
||||
assert "Mozilla" in validator.session.headers["User-Agent"]
|
||||
|
||||
# Test with custom values
|
||||
custom_agent = "CustomAgent/1.0"
|
||||
validator = StreamValidator(timeout=5, user_agent=custom_agent)
|
||||
assert validator.timeout == 5
|
||||
assert validator.session.headers["User-Agent"] == custom_agent
|
||||
|
||||
|
||||
def test_is_valid_content_type(validator):
|
||||
"""Test content type validation"""
|
||||
valid_types = [
|
||||
"video/mp4",
|
||||
"video/mp2t",
|
||||
"application/vnd.apple.mpegurl",
|
||||
"application/dash+xml",
|
||||
"video/webm",
|
||||
"application/octet-stream",
|
||||
"application/x-mpegURL",
|
||||
"video/mp4; charset=utf-8", # Test with additional parameters
|
||||
]
|
||||
|
||||
invalid_types = [
|
||||
"text/html",
|
||||
"application/json",
|
||||
"image/jpeg",
|
||||
"",
|
||||
]
|
||||
|
||||
for content_type in valid_types:
|
||||
assert validator._is_valid_content_type(content_type)
|
||||
|
||||
for content_type in invalid_types:
|
||||
assert not validator._is_valid_content_type(content_type)
|
||||
|
||||
# Test None case explicitly
|
||||
assert not validator._is_valid_content_type(None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,content_type,should_succeed",
|
||||
[
|
||||
(200, "video/mp4", True),
|
||||
(206, "video/mp4", True), # Partial content
|
||||
(404, "video/mp4", False),
|
||||
(500, "video/mp4", False),
|
||||
(200, "text/html", False),
|
||||
(200, "application/json", False),
|
||||
],
|
||||
)
|
||||
def test_validate_stream_response_handling(status_code, content_type, should_succeed):
|
||||
"""Test stream validation with different response scenarios"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.headers = {"Content-Type": content_type}
|
||||
mock_response.iter_content.return_value = iter([b"some content"])
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__enter__.return_value = mock_response
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert valid == should_succeed
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_stream_connection_error():
|
||||
"""Test stream validation with connection error"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Connection Error" in message
|
||||
|
||||
|
||||
def test_validate_stream_timeout():
|
||||
"""Test stream validation with timeout"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = Timeout("Request timed out")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "timeout" in message.lower()
|
||||
|
||||
|
||||
def test_validate_stream_http_error():
|
||||
"""Test stream validation with HTTP error"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = HTTPError("HTTP Error occurred")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "HTTP Error" in message
|
||||
|
||||
|
||||
def test_validate_stream_request_exception():
|
||||
"""Test stream validation with general request exception"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = RequestException("Request failed")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Request Exception" in message
|
||||
|
||||
|
||||
def test_validate_stream_content_read_error():
|
||||
"""Test stream validation when content reading fails"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "video/mp4"}
|
||||
mock_response.iter_content.side_effect = ConnectionError("Read failed")
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__enter__.return_value = mock_response
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Connection failed during content read" in message
|
||||
|
||||
|
||||
def test_validate_stream_general_exception():
|
||||
"""Test validate_stream with an unexpected exception"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = Exception("Unexpected error")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Validation error" in message
|
||||
|
||||
|
||||
def test_parse_playlist(validator, tmp_path):
|
||||
"""Test playlist file parsing"""
|
||||
playlist_content = """
|
||||
#EXTM3U
|
||||
#EXTINF:-1,Channel 1
|
||||
http://example.com/stream1
|
||||
#EXTINF:-1,Channel 2
|
||||
http://example.com/stream2
|
||||
|
||||
http://example.com/stream3
|
||||
"""
|
||||
playlist_file = tmp_path / "test_playlist.m3u"
|
||||
playlist_file.write_text(playlist_content)
|
||||
|
||||
urls = validator.parse_playlist(str(playlist_file))
|
||||
assert len(urls) == 3
|
||||
assert urls == [
|
||||
"http://example.com/stream1",
|
||||
"http://example.com/stream2",
|
||||
"http://example.com/stream3",
|
||||
]
|
||||
|
||||
|
||||
def test_parse_playlist_error(validator):
|
||||
"""Test playlist parsing with non-existent file"""
|
||||
with pytest.raises(Exception):
|
||||
validator.parse_playlist("nonexistent_file.m3u")
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
def test_main_with_urls(mock_validator_class, mock_logging, tmp_path, capsys):
|
||||
"""Test main function with direct URLs"""
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", "http://example.com/stream1", "http://example.com/stream2"]
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
# Verify validator was called correctly
|
||||
assert mock_validator.validate_stream.call_count == 2
|
||||
mock_validator.validate_stream.assert_any_call("http://example.com/stream1")
|
||||
mock_validator.validate_stream.assert_any_call("http://example.com/stream2")
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
def test_main_with_playlist(mock_validator_class, mock_logging, tmp_path):
|
||||
"""Test main function with a playlist file"""
|
||||
# Create test playlist
|
||||
playlist_content = "http://example.com/stream1\nhttp://example.com/stream2"
|
||||
playlist_file = tmp_path / "test.m3u"
|
||||
playlist_file.write_text(playlist_content)
|
||||
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
mock_validator.parse_playlist.return_value = [
|
||||
"http://example.com/stream1",
|
||||
"http://example.com/stream2",
|
||||
]
|
||||
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", str(playlist_file)]
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
# Verify validator was called correctly
|
||||
mock_validator.parse_playlist.assert_called_once_with(str(playlist_file))
|
||||
assert mock_validator.validate_stream.call_count == 2
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
def test_main_with_dead_streams(mock_validator_class, mock_logging, tmp_path):
|
||||
"""Test main function handling dead streams"""
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
mock_validator.validate_stream.return_value = (False, "Stream is dead")
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", "http://example.com/dead1", "http://example.com/dead2"]
|
||||
|
||||
# Mock file operations
|
||||
mock_file = mock_open()
|
||||
with patch("sys.argv", test_args), patch("builtins.open", mock_file):
|
||||
main()
|
||||
|
||||
# Verify dead streams were written to file
|
||||
mock_file().write.assert_called_once_with(
|
||||
"http://example.com/dead1\nhttp://example.com/dead2"
|
||||
)
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
@patch("os.path.isfile")
|
||||
def test_main_with_playlist_error(
|
||||
mock_isfile, mock_validator_class, mock_logging, tmp_path
|
||||
):
|
||||
"""Test main function handling playlist parsing errors"""
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
|
||||
# Configure mock validator behavior
|
||||
error_msg = "Failed to parse playlist"
|
||||
mock_validator.parse_playlist.side_effect = [
|
||||
Exception(error_msg), # First call fails
|
||||
["http://example.com/stream1"], # Second call succeeds
|
||||
]
|
||||
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||
|
||||
# Configure isfile mock to return True for our test files
|
||||
mock_isfile.side_effect = lambda x: x in ["/invalid.m3u", "/valid.m3u"]
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", "/invalid.m3u", "/valid.m3u"]
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
# Verify error was logged correctly
|
||||
mock_logging.error.assert_called_with(
|
||||
"Failed to process file /invalid.m3u: Failed to parse playlist"
|
||||
)
|
||||
|
||||
# Verify processing continued with valid playlist
|
||||
mock_validator.parse_playlist.assert_called_with("/valid.m3u")
|
||||
assert (
|
||||
mock_validator.validate_stream.call_count == 1
|
||||
) # Called for the URL from valid playlist
|
||||
69
tests/utils/test_database.py
Normal file
69
tests/utils/test_database.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.database import get_db, get_db_credentials
|
||||
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
|
||||
|
||||
|
||||
def test_get_db_credentials_env(mock_env):
|
||||
"""Test getting DB credentials from environment variables"""
|
||||
conn_str = get_db_credentials()
|
||||
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
|
||||
|
||||
|
||||
def test_get_db_credentials_ssm(mock_ssm):
|
||||
"""Test getting DB credentials from SSM"""
|
||||
os.environ.pop("MOCK_AUTH", None)
|
||||
conn_str = get_db_credentials()
|
||||
expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
|
||||
assert expected_conn in conn_str
|
||||
mock_ssm.get_parameter.assert_called()
|
||||
|
||||
|
||||
def test_get_db_credentials_ssm_exception(mock_ssm):
|
||||
"""Test SSM credential fetching failure raises RuntimeError"""
|
||||
os.environ.pop("MOCK_AUTH", None)
|
||||
mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_db_credentials()
|
||||
|
||||
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_session_creation():
|
||||
"""Test database session creation"""
|
||||
session = session_mock()
|
||||
assert isinstance(session, Session)
|
||||
session.close()
|
||||
|
||||
|
||||
def test_get_db_generator():
|
||||
"""Test get_db dependency generator"""
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
assert isinstance(db, Session)
|
||||
try:
|
||||
next(db_gen) # Should raise StopIteration
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
def test_init_db(mocker, mock_env):
|
||||
"""Test database initialization creates tables"""
|
||||
mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
|
||||
|
||||
# Mock get_db_credentials to return SQLite test connection
|
||||
mocker.patch(
|
||||
"app.utils.database.get_db_credentials",
|
||||
return_value="sqlite:///:memory:",
|
||||
)
|
||||
|
||||
from app.utils.database import engine, init_db
|
||||
|
||||
init_db()
|
||||
|
||||
# Verify create_all was called with the engine
|
||||
mock_create_all.assert_called_once_with(bind=engine)
|
||||
Reference in New Issue
Block a user