Compare commits
74 Commits
35745c43bd
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| a42d4c30a6 | |||
| abb467749b | |||
| b8ac25e301 | |||
| 729eabf27f | |||
| 34c446bcfa | |||
| d4cc74ea8c | |||
| 21b73b6843 | |||
| e743daf9f7 | |||
| b0d98551b8 | |||
| eaab1ef998 | |||
| e25f8c1ecd | |||
| 95bf0f9701 | |||
| f7a1c20066 | |||
| bf6f156fec | |||
| 7e25ec6755 | |||
| 6d506122d9 | |||
| 02913c7385 | |||
| e46f13930d | |||
| 903f190ee2 | |||
| 32af6bbdb5 | |||
| 7ee7a0e644 | |||
| 9474a3ca44 | |||
| 1ab8599dde | |||
| fb5215b92a | |||
| cebbb9c1a8 | |||
| 4b1a7e9bea | |||
| 21cc99eff6 | |||
| 76dc8908de | |||
| 07dab76e3b | |||
| c21b34f5fe | |||
| b1942354d9 | |||
| 3937269bb9 | |||
| 8c7ed421c9 | |||
| c96ee307db | |||
| f11d533fac | |||
| 99d26a8f53 | |||
| 260fcb311b | |||
| 1e82418cad | |||
| 9c690fe6a6 | |||
| 9e8df169fc | |||
| 5ee6cb4be4 | |||
| c1e3a6ef26 | |||
| cb793ef5e1 | |||
| be719a6e34 | |||
| 5767124031 | |||
| c6f7e9cb2b | |||
| eeb0f1c844 | |||
| 4cb3811d17 | |||
| 489281f3eb | |||
| b947ac67f0 | |||
| dd2446a01a | |||
| 639adba7eb | |||
| 5698e7f26b | |||
| df3fc2f37c | |||
| 594ce0c67a | |||
| 37be1f3f91 | |||
| 732667cf64 | |||
| 5bc7a72a92 | |||
| a5dfc1b493 | |||
| 0b69ffd67c | |||
| 127d81adac | |||
| 658f7998ef | |||
| c4f19999dc | |||
| c221a8cded | |||
| 8d1997fa5a | |||
| 795a25961f | |||
| d55c383bc4 | |||
| 5c17e4b1e9 | |||
| 30ccf86c86 | |||
| ae040fc49e | |||
| 47befceb17 | |||
| 7f282049ac | |||
| 38e5a94701 | |||
| 7b7ff78030 |
19
.env.example
Normal file
19
.env.example
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
# Environment variables
|
||||||
|
# Scheduler configuration
|
||||||
|
STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax)
|
||||||
|
STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all)
|
||||||
|
|
||||||
|
# For use with Docker Compose to run application locally
|
||||||
|
MOCK_AUTH=true/false
|
||||||
|
DB_USER=MyDBUser
|
||||||
|
DB_PASSWORD=MyDBPassword
|
||||||
|
DB_HOST=MyDBHost
|
||||||
|
DB_NAME=iptv_manager
|
||||||
|
|
||||||
|
FREEDNS_User=MyFreeDNSUsername
|
||||||
|
FREEDNS_Password=MyFreeDNSPassword
|
||||||
|
DOMAIN_NAME=mydomain.com
|
||||||
|
SSH_PUBLIC_KEY="ssh-rsa AAAAB3NzaC1yc2EMYPUBLICKEY7+"
|
||||||
|
REPO_URL="https://git.example.com/user/repo.git"
|
||||||
|
LETSENCRYPT_EMAIL="admin@example.com"
|
||||||
@@ -39,6 +39,13 @@ jobs:
|
|||||||
|
|
||||||
- name: Deploy to AWS
|
- name: Deploy to AWS
|
||||||
run: cdk deploy --app="python3 ${PWD}/app.py" --require-approval=never
|
run: cdk deploy --app="python3 ${PWD}/app.py" --require-approval=never
|
||||||
|
env:
|
||||||
|
FREEDNS_User: ${{ secrets.FREEDNS_USER }}
|
||||||
|
FREEDNS_Password: ${{ secrets.FREEDNS_PASSWORD }}
|
||||||
|
DOMAIN_NAME: ${{ secrets.DOMAIN_NAME }}
|
||||||
|
SSH_PUBLIC_KEY: ${{ secrets.SSH_PUBLIC_KEY }}
|
||||||
|
REPO_URL: ${{ secrets.REPO_URL }}
|
||||||
|
LETSENCRYPT_EMAIL: ${{ secrets.LETSENCRYPT_EMAIL }}
|
||||||
|
|
||||||
- name: Install AWS CLI
|
- name: Install AWS CLI
|
||||||
run: |
|
run: |
|
||||||
@@ -50,20 +57,23 @@ jobs:
|
|||||||
- name: Update application on instance
|
- name: Update application on instance
|
||||||
run: |
|
run: |
|
||||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||||
--filters "Name=tag:Name,Values=IptvUpdater/IptvUpdaterInstance" \
|
--region us-east-2 \
|
||||||
|
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||||
"Name=instance-state-name,Values=running" \
|
"Name=instance-state-name,Values=running" \
|
||||||
--query "Reservations[].Instances[].InstanceId" \
|
--query "Reservations[].Instances[].InstanceId" \
|
||||||
--output text)
|
--output text)
|
||||||
|
|
||||||
for INSTANCE_ID in $INSTANCE_IDS; do
|
for INSTANCE_ID in $INSTANCE_IDS; do
|
||||||
aws ssm send-command \
|
aws ssm send-command \
|
||||||
|
--region us-east-2 \
|
||||||
--instance-ids "$INSTANCE_ID" \
|
--instance-ids "$INSTANCE_ID" \
|
||||||
--document-name "AWS-RunShellScript" \
|
--document-name "AWS-RunShellScript" \
|
||||||
--parameters 'commands=[
|
--parameters 'commands=[
|
||||||
"cd /home/ec2-user/iptv-updater-aws",
|
"cd /home/ec2-user/iptv-manager-service",
|
||||||
"git pull",
|
"git pull",
|
||||||
"pip3 install -r requirements.txt",
|
"pip3 install -r requirements.txt",
|
||||||
"sudo systemctl restart iptv-updater"
|
"alembic upgrade head",
|
||||||
|
"sudo systemctl restart iptv-manager"
|
||||||
]'
|
]'
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|||||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -4,10 +4,16 @@ __pycache__
|
|||||||
.pytest_cache
|
.pytest_cache
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
|
*.pid
|
||||||
|
*.log
|
||||||
*.egg-info
|
*.egg-info
|
||||||
.coverage
|
.coverage
|
||||||
|
.roomodes
|
||||||
cdk.out/
|
cdk.out/
|
||||||
node_modules/
|
node_modules/
|
||||||
|
data/
|
||||||
|
.roo/
|
||||||
|
.ruru/
|
||||||
|
|
||||||
# CDK asset staging directory
|
# CDK asset staging directory
|
||||||
.cdk.staging
|
.cdk.staging
|
||||||
|
|||||||
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.11.12
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
|
- id: ruff-format
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: pytest-check
|
||||||
|
name: pytest-check
|
||||||
|
entry: pytest
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
always_run: true
|
||||||
85
.vscode/settings.json
vendored
85
.vscode/settings.json
vendored
@@ -1,22 +1,105 @@
|
|||||||
{
|
{
|
||||||
|
"python.terminal.activateEnvironment": true,
|
||||||
|
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||||
|
"editor.formatOnSave": true,
|
||||||
|
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||||
|
"ruff.importStrategy": "fromEnvironment",
|
||||||
|
"ruff.path": ["${workspaceFolder}"],
|
||||||
"cSpell.words": [
|
"cSpell.words": [
|
||||||
|
"addopts",
|
||||||
|
"adminpassword",
|
||||||
"altinstall",
|
"altinstall",
|
||||||
|
"apscheduler",
|
||||||
|
"asyncio",
|
||||||
|
"autoflush",
|
||||||
|
"autoupdate",
|
||||||
|
"autouse",
|
||||||
|
"awscli",
|
||||||
"awscliv",
|
"awscliv",
|
||||||
"boto",
|
"boto",
|
||||||
|
"botocore",
|
||||||
|
"BURSTABLE",
|
||||||
"cabletv",
|
"cabletv",
|
||||||
|
"capsys",
|
||||||
|
"CDUF",
|
||||||
|
"cduflogo",
|
||||||
|
"cdulogo",
|
||||||
|
"CDUNF",
|
||||||
|
"cdunflogo",
|
||||||
"certbot",
|
"certbot",
|
||||||
"certifi",
|
"certifi",
|
||||||
|
"cfulogo",
|
||||||
|
"CLEU",
|
||||||
|
"cleulogo",
|
||||||
|
"CLUF",
|
||||||
|
"cluflogo",
|
||||||
|
"clulogo",
|
||||||
|
"cpulogo",
|
||||||
|
"crond",
|
||||||
|
"cronie",
|
||||||
|
"cuflgo",
|
||||||
|
"CUNF",
|
||||||
|
"cunflogo",
|
||||||
|
"cuulogo",
|
||||||
|
"datname",
|
||||||
|
"deadstreams",
|
||||||
|
"delenv",
|
||||||
|
"delogo",
|
||||||
"devel",
|
"devel",
|
||||||
|
"dflogo",
|
||||||
|
"dmlogo",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
|
"EXTINF",
|
||||||
|
"EXTM",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
|
"filterwarnings",
|
||||||
"fiorinis",
|
"fiorinis",
|
||||||
|
"freedns",
|
||||||
|
"fullchain",
|
||||||
"gitea",
|
"gitea",
|
||||||
|
"httpx",
|
||||||
"iptv",
|
"iptv",
|
||||||
|
"isort",
|
||||||
|
"KHTML",
|
||||||
|
"lclogo",
|
||||||
|
"LETSENCRYPT",
|
||||||
|
"levelname",
|
||||||
|
"mpegurl",
|
||||||
"nohup",
|
"nohup",
|
||||||
|
"nopriority",
|
||||||
|
"ondelete",
|
||||||
|
"onupdate",
|
||||||
"passlib",
|
"passlib",
|
||||||
|
"PGPASSWORD",
|
||||||
|
"poolclass",
|
||||||
|
"psql",
|
||||||
|
"psycopg",
|
||||||
|
"pycache",
|
||||||
|
"pycodestyle",
|
||||||
|
"pyflakes",
|
||||||
|
"pyjwt",
|
||||||
|
"pytest",
|
||||||
|
"PYTHONDONTWRITEBYTECODE",
|
||||||
|
"PYTHONUNBUFFERED",
|
||||||
|
"pyupgrade",
|
||||||
|
"reloadcmd",
|
||||||
|
"roomodes",
|
||||||
|
"ruru",
|
||||||
|
"sessionmaker",
|
||||||
|
"sqlalchemy",
|
||||||
|
"sqliteuuid",
|
||||||
"starlette",
|
"starlette",
|
||||||
"stefano",
|
"stefano",
|
||||||
|
"testadmin",
|
||||||
|
"testdb",
|
||||||
|
"testpass",
|
||||||
|
"testpaths",
|
||||||
|
"testuser",
|
||||||
|
"uflogo",
|
||||||
|
"umlogo",
|
||||||
|
"usefixtures",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"venv"
|
"venv",
|
||||||
|
"wrongpass"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
150
README.md
150
README.md
@@ -1 +1,149 @@
|
|||||||
# To do
|
# IPTV Manager Service
|
||||||
|
|
||||||
|
A FastAPI-based service for managing IPTV playlists and channel priorities. The application provides secure endpoints for user authentication, channel management, and playlist generation.
|
||||||
|
|
||||||
|
## ✨ Features
|
||||||
|
|
||||||
|
- **JWT Authentication**: Secure login using AWS Cognito
|
||||||
|
- **Channel Management**: CRUD operations for IPTV channels
|
||||||
|
- **Playlist Generation**: Create M3U playlists with channel priorities
|
||||||
|
- **Stream Monitoring**: Background checks for channel availability
|
||||||
|
- **Priority Management**: Set channel priorities for playlist ordering
|
||||||
|
|
||||||
|
## 🛠️ Technology Stack
|
||||||
|
|
||||||
|
- **Backend**: Python 3.11, FastAPI
|
||||||
|
- **Database**: PostgreSQL (SQLAlchemy ORM)
|
||||||
|
- **Authentication**: AWS Cognito
|
||||||
|
- **Infrastructure**: AWS CDK (API Gateway, Lambda, RDS)
|
||||||
|
- **Testing**: Pytest with 85%+ coverage
|
||||||
|
- **CI/CD**: Pre-commit hooks, Alembic migrations
|
||||||
|
|
||||||
|
## 🚀 Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.11+
|
||||||
|
- Docker
|
||||||
|
- AWS CLI (for deployment)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone repository
|
||||||
|
git clone https://github.com/your-repo/iptv-manager-service.git
|
||||||
|
cd iptv-manager-service
|
||||||
|
|
||||||
|
# Setup environment
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
cp .env.example .env # Update with your values
|
||||||
|
|
||||||
|
# Run installation script
|
||||||
|
./scripts/install.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Locally
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start development environment
|
||||||
|
./scripts/start_local_dev.sh
|
||||||
|
|
||||||
|
# Stop development environment
|
||||||
|
./scripts/stop_local_dev.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
## ☁️ AWS Deployment
|
||||||
|
|
||||||
|
The infrastructure is defined in CDK. Use the provided scripts:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Deploy AWS infrastructure
|
||||||
|
./scripts/deploy.sh
|
||||||
|
|
||||||
|
# Destroy AWS infrastructure
|
||||||
|
./scripts/destroy.sh
|
||||||
|
|
||||||
|
# Create Cognito test user
|
||||||
|
./scripts/create_cognito_user.sh
|
||||||
|
|
||||||
|
# Delete Cognito user
|
||||||
|
./scripts/delete_cognito_user.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Key AWS components:
|
||||||
|
|
||||||
|
- API Gateway
|
||||||
|
- Lambda functions
|
||||||
|
- RDS PostgreSQL
|
||||||
|
- Cognito User Pool
|
||||||
|
|
||||||
|
## 🤖 Continuous Integration/Deployment
|
||||||
|
|
||||||
|
This project includes a Gitea Actions workflow (`.gitea/workflows/deploy.yml`) for automated deployment to AWS. The workflow is fully compatible with GitHub Actions and can be easily adapted by:
|
||||||
|
|
||||||
|
1. Placing the workflow file in the `.github/workflows/` directory
|
||||||
|
2. Setting up the required secrets in your CI/CD environment:
|
||||||
|
- `AWS_ACCESS_KEY_ID`
|
||||||
|
- `AWS_SECRET_ACCESS_KEY`
|
||||||
|
- `AWS_DEFAULT_REGION`
|
||||||
|
|
||||||
|
The workflow automatically deploys the infrastructure and application when changes are pushed to the main branch.
|
||||||
|
|
||||||
|
## 📚 API Documentation
|
||||||
|
|
||||||
|
Access interactive docs at:
|
||||||
|
|
||||||
|
- Swagger UI: `http://localhost:8000/docs`
|
||||||
|
- ReDoc: `http://localhost:8000/redoc`
|
||||||
|
|
||||||
|
### Key Endpoints
|
||||||
|
|
||||||
|
| Endpoint | Method | Description |
|
||||||
|
| ------------- | ------ | --------------------- |
|
||||||
|
| `/auth/login` | POST | User authentication |
|
||||||
|
| `/channels` | GET | List all channels |
|
||||||
|
| `/playlist` | GET | Generate M3U playlist |
|
||||||
|
| `/priorities` | POST | Set channel priority |
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
Run the full test suite:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
Test coverage includes:
|
||||||
|
|
||||||
|
- Authentication workflows
|
||||||
|
- Channel CRUD operations
|
||||||
|
- Playlist generation logic
|
||||||
|
- Stream monitoring
|
||||||
|
- Database operations
|
||||||
|
|
||||||
|
## 📂 Project Structure
|
||||||
|
|
||||||
|
```txt
|
||||||
|
iptv-manager-service/
|
||||||
|
├── app/ # Core application
|
||||||
|
│ ├── auth/ # Cognito authentication
|
||||||
|
│ ├── iptv/ # Playlist logic
|
||||||
|
│ ├── models/ # Database models
|
||||||
|
│ ├── routers/ # API endpoints
|
||||||
|
│ ├── utils/ # Helper functions
|
||||||
|
│ └── main.py # App entry point
|
||||||
|
├── infrastructure/ # AWS CDK stack
|
||||||
|
├── docker/ # Docker configs
|
||||||
|
├── scripts/ # Deployment scripts
|
||||||
|
├── tests/ # Comprehensive tests
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
├── .gitea/ # Gitea CI/CD workflows
|
||||||
|
│ └── workflows/
|
||||||
|
└── ... # Config files
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📝 License
|
||||||
|
|
||||||
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||||
|
|||||||
141
alembic.ini
Normal file
141
alembic.ini
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts.
|
||||||
|
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||||
|
# format, relative to the token %(here)s which refers to the location of this
|
||||||
|
# ini file
|
||||||
|
script_location = %(here)s/alembic
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory. for multiple paths, the path separator
|
||||||
|
# is defined by "path_separator" below.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||||
|
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to <script_location>/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "path_separator"
|
||||||
|
# below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||||
|
|
||||||
|
# path_separator; This indicates what character is used to split lists of file
|
||||||
|
# paths, including version_locations and prepend_sys_path within configparser
|
||||||
|
# files such as alembic.ini.
|
||||||
|
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||||
|
# to provide os-dependent path splitting.
|
||||||
|
#
|
||||||
|
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||||
|
# take place if path_separator is not present in alembic.ini. If this
|
||||||
|
# option is omitted entirely, fallback logic is as follows:
|
||||||
|
#
|
||||||
|
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||||
|
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||||
|
# behavior of splitting on spaces and/or commas.
|
||||||
|
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||||
|
# behavior of splitting on spaces, commas, or colons.
|
||||||
|
#
|
||||||
|
# Valid values for path_separator are:
|
||||||
|
#
|
||||||
|
# path_separator = :
|
||||||
|
# path_separator = ;
|
||||||
|
# path_separator = space
|
||||||
|
# path_separator = newline
|
||||||
|
#
|
||||||
|
# Use os.pathsep. Default configuration used for new projects.
|
||||||
|
path_separator = os
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
# database URL. This is consumed by the user-maintained env.py script only.
|
||||||
|
# other means of configuring database URLs may be customized within the env.py
|
||||||
|
# file.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration. This is also consumed by the user-maintained
|
||||||
|
# env.py script only.
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
1
alembic/README
Normal file
1
alembic/README
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration.
|
||||||
79
alembic/env.py
Normal file
79
alembic/env.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from app.models.db import Base
|
||||||
|
from app.utils.database import get_db_credentials
|
||||||
|
|
||||||
|
# this is the Alembic Config object, which provides
|
||||||
|
# access to the values within the .ini file in use.
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging.
|
||||||
|
# This line sets up loggers basically.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
# Override sqlalchemy.url with dynamic credentials
|
||||||
|
if not context.is_offline_mode():
|
||||||
|
config.set_main_option("sqlalchemy.url", get_db_credentials())
|
||||||
|
|
||||||
|
# other values from the config, defined by the needs of env.py,
|
||||||
|
# can be acquired:
|
||||||
|
# my_important_option = config.get_main_option("my_important_option")
|
||||||
|
# ... etc.
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the
|
||||||
|
script output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
"""add groups table and migrate group_title data
|
||||||
|
|
||||||
|
Revision ID: 0a455608256f
|
||||||
|
Revises: 95b61a92455a
|
||||||
|
Create Date: 2025-06-10 09:22:11.820035
|
||||||
|
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '0a455608256f'
|
||||||
|
down_revision: Union[str, None] = '95b61a92455a'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('groups',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('name')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create temporary table for group mapping
|
||||||
|
group_mapping = op.create_table(
|
||||||
|
'group_mapping',
|
||||||
|
sa.Column('group_title', sa.String(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.UUID(), nullable=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get existing group titles and create groups
|
||||||
|
conn = op.get_bind()
|
||||||
|
distinct_groups = conn.execute(
|
||||||
|
sa.text("SELECT DISTINCT group_title FROM channels")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
for group in distinct_groups:
|
||||||
|
group_title = group[0]
|
||||||
|
group_id = str(uuid.uuid4())
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"INSERT INTO groups (id, name, sort_order) "
|
||||||
|
"VALUES (:id, :name, 0)"
|
||||||
|
).bindparams(id=group_id, name=group_title)
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
group_mapping.insert().values(
|
||||||
|
group_title=group_title,
|
||||||
|
group_id=group_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add group_id column (nullable first)
|
||||||
|
op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True))
|
||||||
|
|
||||||
|
# Update channels with group_ids
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"UPDATE channels c SET group_id = gm.group_id "
|
||||||
|
"FROM group_mapping gm WHERE c.group_title = gm.group_title"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now make group_id non-nullable and add constraints
|
||||||
|
op.alter_column('channels', 'group_id', nullable=False)
|
||||||
|
op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique')
|
||||||
|
op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name'])
|
||||||
|
op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id'])
|
||||||
|
|
||||||
|
# Clean up and drop group_title
|
||||||
|
op.drop_table('group_mapping')
|
||||||
|
op.drop_column('channels', 'group_title')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||||
|
|
||||||
|
# Restore group_title values from groups table
|
||||||
|
conn = op.get_bind()
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"UPDATE channels c SET group_title = g.name "
|
||||||
|
"FROM groups g WHERE c.group_id = g.id"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now make group_title non-nullable
|
||||||
|
op.alter_column('channels', 'group_title', nullable=False)
|
||||||
|
|
||||||
|
# Drop constraints and columns
|
||||||
|
op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey')
|
||||||
|
op.drop_constraint('uix_group_id_name', 'channels', type_='unique')
|
||||||
|
op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name'])
|
||||||
|
op.drop_column('channels', 'group_id')
|
||||||
|
op.drop_table('groups')
|
||||||
|
# ### end Alembic commands ###
|
||||||
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""create initial tables
|
||||||
|
|
||||||
|
Revision ID: 95b61a92455a
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-05-29 14:42:16.239587
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '95b61a92455a'
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('channels',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('tvg_id', sa.String(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('group_title', sa.String(), nullable=False),
|
||||||
|
sa.Column('tvg_name', sa.String(), nullable=True),
|
||||||
|
sa.Column('tvg_logo', sa.String(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('group_title', 'name', name='uix_group_title_name')
|
||||||
|
)
|
||||||
|
op.create_table('priorities',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('description', sa.String(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_table('channels_urls',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('channel_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('url', sa.String(), nullable=False),
|
||||||
|
sa.Column('in_use', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('priority_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['channel_id'], ['channels.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['priority_id'], ['priorities.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
# Seed initial priorities
|
||||||
|
op.bulk_insert(
|
||||||
|
sa.Table(
|
||||||
|
'priorities',
|
||||||
|
sa.MetaData(),
|
||||||
|
sa.Column('id', sa.Integer),
|
||||||
|
sa.Column('description', sa.String),
|
||||||
|
),
|
||||||
|
[
|
||||||
|
{'id': 100, 'description': 'High'},
|
||||||
|
{'id': 200, 'description': 'Medium'},
|
||||||
|
{'id': 300, 'description': 'Low'},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# Remove seeded priorities
|
||||||
|
op.execute("DELETE FROM priorities WHERE id IN (100, 200, 300);")
|
||||||
|
|
||||||
|
# Drop tables
|
||||||
|
op.drop_table('channels_urls')
|
||||||
|
op.drop_table('priorities')
|
||||||
|
op.drop_table('channels')
|
||||||
42
app.py
42
app.py
@@ -1,7 +1,45 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
|
||||||
import aws_cdk as cdk
|
import aws_cdk as cdk
|
||||||
from infrastructure.stack import IptvUpdaterStack
|
|
||||||
|
from infrastructure.stack import IptvManagerStack
|
||||||
|
|
||||||
app = cdk.App()
|
app = cdk.App()
|
||||||
IptvUpdaterStack(app, "IptvUpdater")
|
|
||||||
|
# Read environment variables for FreeDNS credentials
|
||||||
|
freedns_user = os.environ.get("FREEDNS_User")
|
||||||
|
freedns_password = os.environ.get("FREEDNS_Password")
|
||||||
|
domain_name = os.environ.get("DOMAIN_NAME")
|
||||||
|
ssh_public_key = os.environ.get("SSH_PUBLIC_KEY")
|
||||||
|
repo_url = os.environ.get("REPO_URL")
|
||||||
|
letsencrypt_email = os.environ.get("LETSENCRYPT_EMAIL")
|
||||||
|
|
||||||
|
required_vars = {
|
||||||
|
"FREEDNS_User": freedns_user,
|
||||||
|
"FREEDNS_Password": freedns_password,
|
||||||
|
"DOMAIN_NAME": domain_name,
|
||||||
|
"SSH_PUBLIC_KEY": ssh_public_key,
|
||||||
|
"REPO_URL": repo_url,
|
||||||
|
"LETSENCRYPT_EMAIL": letsencrypt_email,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check for missing required variables
|
||||||
|
missing_vars = [k for k, v in required_vars.items() if not v]
|
||||||
|
if missing_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing required environment variables: {', '.join(missing_vars)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
IptvManagerStack(
|
||||||
|
app,
|
||||||
|
"IptvManagerStack",
|
||||||
|
freedns_user=freedns_user,
|
||||||
|
freedns_password=freedns_password,
|
||||||
|
domain_name=domain_name,
|
||||||
|
ssh_public_key=ssh_public_key,
|
||||||
|
repo_url=repo_url,
|
||||||
|
letsencrypt_email=letsencrypt_email,
|
||||||
|
)
|
||||||
|
|
||||||
app.synth()
|
app.synth()
|
||||||
82
app/auth/cognito.py
Normal file
82
app/auth/cognito.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import boto3
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.utils.auth import calculate_secret_hash
|
||||||
|
from app.utils.constants import (
|
||||||
|
AWS_REGION,
|
||||||
|
COGNITO_CLIENT_ID,
|
||||||
|
COGNITO_CLIENT_SECRET,
|
||||||
|
USER_ROLE_ATTRIBUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
|
||||||
|
|
||||||
|
|
||||||
|
def initiate_auth(username: str, password: str) -> dict:
|
||||||
|
"""
|
||||||
|
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
|
||||||
|
"""
|
||||||
|
auth_params = {"USERNAME": username, "PASSWORD": password}
|
||||||
|
|
||||||
|
# If a client secret is required, add SECRET_HASH
|
||||||
|
if COGNITO_CLIENT_SECRET:
|
||||||
|
auth_params["SECRET_HASH"] = calculate_secret_hash(
|
||||||
|
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = cognito_client.initiate_auth(
|
||||||
|
AuthFlow="USER_PASSWORD_AUTH",
|
||||||
|
AuthParameters=auth_params,
|
||||||
|
ClientId=COGNITO_CLIENT_ID,
|
||||||
|
)
|
||||||
|
return response["AuthenticationResult"]
|
||||||
|
except cognito_client.exceptions.NotAuthorizedException:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid username or password",
|
||||||
|
)
|
||||||
|
except cognito_client.exceptions.UserNotFoundException:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"An error occurred during authentication: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_from_token(access_token: str) -> CognitoUser:
|
||||||
|
"""
|
||||||
|
Verify the token by calling GetUser in Cognito and
|
||||||
|
retrieve user attributes including roles.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user_response = cognito_client.get_user(AccessToken=access_token)
|
||||||
|
username = user_response.get("Username", "")
|
||||||
|
attributes = user_response.get("UserAttributes", [])
|
||||||
|
user_roles = []
|
||||||
|
|
||||||
|
for attr in attributes:
|
||||||
|
if attr["Name"] == USER_ROLE_ATTRIBUTE:
|
||||||
|
# Assume roles are stored as a comma-separated string
|
||||||
|
user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
|
||||||
|
break
|
||||||
|
|
||||||
|
return CognitoUser(username=username, roles=user_roles)
|
||||||
|
except cognito_client.exceptions.NotAuthorizedException:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
|
||||||
|
)
|
||||||
|
except cognito_client.exceptions.UserNotFoundException:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found or invalid token.",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Token verification failed: {str(e)}",
|
||||||
|
)
|
||||||
51
app/auth/dependencies.py
Normal file
51
app/auth/dependencies.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
|
||||||
|
# Use mock auth for local testing if MOCK_AUTH is set
|
||||||
|
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||||
|
from app.auth.mock_auth import mock_get_user_from_token as get_user_from_token
|
||||||
|
else:
|
||||||
|
from app.auth.cognito import get_user_from_token
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
|
||||||
|
"""
|
||||||
|
Dependency to get the current user from the given token.
|
||||||
|
This will verify the token with Cognito and return the user's information.
|
||||||
|
"""
|
||||||
|
return get_user_from_token(token)
|
||||||
|
|
||||||
|
|
||||||
|
def require_roles(*required_roles: str) -> Callable:
|
||||||
|
"""
|
||||||
|
Decorator for role-based access control.
|
||||||
|
Use on endpoints to enforce that the user possesses all required roles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(endpoint: Callable) -> Callable:
|
||||||
|
@wraps(endpoint)
|
||||||
|
async def wrapper(
|
||||||
|
*args, user: CognitoUser = Depends(get_current_user), **kwargs
|
||||||
|
):
|
||||||
|
user_roles = set(user.roles or [])
|
||||||
|
needed_roles = set(required_roles)
|
||||||
|
if not needed_roles.issubset(user_roles):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=(
|
||||||
|
"You do not have the required roles to access this endpoint."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return endpoint(*args, user=user, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
26
app/auth/mock_auth.py
Normal file
26
app/auth/mock_auth.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
|
||||||
|
MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
|
||||||
|
|
||||||
|
|
||||||
|
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||||
|
"""
|
||||||
|
Mock version of get_user_from_token for local testing
|
||||||
|
Accepts 'testuser' as a valid token and returns admin user
|
||||||
|
"""
|
||||||
|
if token == "testuser":
|
||||||
|
return CognitoUser(**MOCK_USERS["testuser"])
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid mock token - use 'testuser'",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_initiate_auth(username: str, password: str) -> dict:
|
||||||
|
"""
|
||||||
|
Mock version of initiate_auth for local testing
|
||||||
|
Accepts any username/password and returns a mock token
|
||||||
|
"""
|
||||||
|
return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
import os
|
|
||||||
import boto3
|
|
||||||
import requests
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
REGION = "us-east-2"
|
|
||||||
USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID")
|
|
||||||
CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
|
|
||||||
DOMAIN = f"https://iptv-updater.auth.{REGION}.amazoncognito.com"
|
|
||||||
REDIRECT_URI = "http://localhost:8000/auth/callback"
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2AuthorizationCodeBearer(
|
|
||||||
authorizationUrl=f"{DOMAIN}/oauth2/authorize",
|
|
||||||
tokenUrl=f"{DOMAIN}/oauth2/token"
|
|
||||||
)
|
|
||||||
|
|
||||||
def exchange_code_for_token(code: str):
|
|
||||||
token_url = f"{DOMAIN}/oauth2/token"
|
|
||||||
data = {
|
|
||||||
'grant_type': 'authorization_code',
|
|
||||||
'client_id': CLIENT_ID,
|
|
||||||
'code': code,
|
|
||||||
'redirect_uri': REDIRECT_URI
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(token_url, data=data)
|
|
||||||
if response.status_code == 200:
|
|
||||||
return response.json()
|
|
||||||
raise HTTPException(status_code=400, detail="Failed to exchange code for token")
|
|
||||||
|
|
||||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
||||||
if not token:
|
|
||||||
return RedirectResponse(
|
|
||||||
f"{DOMAIN}/login?client_id={CLIENT_ID}"
|
|
||||||
f"&response_type=code"
|
|
||||||
f"&scope=openid"
|
|
||||||
f"&redirect_uri={REDIRECT_URI}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
cognito = boto3.client('cognito-idp', region_name=REGION)
|
|
||||||
response = cognito.get_user(AccessToken=token)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid authentication credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
@@ -1,39 +1,59 @@
|
|||||||
import os
|
import argparse
|
||||||
import re
|
|
||||||
import gzip
|
import gzip
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import argparse
|
from utils.constants import (
|
||||||
from utils.config import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
IPTV_SERVER_ADMIN_PASSWORD,
|
||||||
|
IPTV_SERVER_ADMIN_USER,
|
||||||
|
IPTV_SERVER_URL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description='EPG Grabber')
|
parser = argparse.ArgumentParser(description="EPG Grabber")
|
||||||
parser.add_argument('--playlist',
|
parser.add_argument(
|
||||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
"--playlist",
|
||||||
help='Path to playlist file')
|
default=os.path.join(
|
||||||
parser.add_argument('--output',
|
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'),
|
),
|
||||||
help='Path to output EPG XML file')
|
help="Path to playlist file",
|
||||||
parser.add_argument('--epg-sources',
|
)
|
||||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'),
|
parser.add_argument(
|
||||||
help='Path to EPG sources JSON configuration file')
|
"--output",
|
||||||
parser.add_argument('--save-as-gz',
|
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
|
||||||
action='store_true',
|
help="Path to output EPG XML file",
|
||||||
default=True,
|
)
|
||||||
help='Save an additional gzipped version of the EPG file')
|
parser.add_argument(
|
||||||
|
"--epg-sources",
|
||||||
|
default=os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
|
||||||
|
),
|
||||||
|
help="Path to EPG sources JSON configuration file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-as-gz",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Save an additional gzipped version of the EPG file",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def load_epg_sources(config_path):
|
def load_epg_sources(config_path):
|
||||||
"""Load EPG sources from JSON configuration file"""
|
"""Load EPG sources from JSON configuration file"""
|
||||||
try:
|
try:
|
||||||
with open(config_path, 'r', encoding='utf-8') as f:
|
with open(config_path, encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
return config.get('epg_sources', [])
|
return config.get("epg_sources", [])
|
||||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||||
print(f"Error loading EPG sources: {e}")
|
print(f"Error loading EPG sources: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def get_tvg_ids(playlist_path):
|
def get_tvg_ids(playlist_path):
|
||||||
"""
|
"""
|
||||||
Extracts unique tvg-id values from an M3U playlist file.
|
Extracts unique tvg-id values from an M3U playlist file.
|
||||||
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
|
|||||||
# and ends with a double quote.
|
# and ends with a double quote.
|
||||||
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
|
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
|
||||||
|
|
||||||
with open(playlist_path, 'r', encoding='utf-8') as file:
|
with open(playlist_path, encoding="utf-8") as file:
|
||||||
for line in file:
|
for line in file:
|
||||||
if line.startswith('#EXTINF'):
|
if line.startswith("#EXTINF"):
|
||||||
# Search for the tvg-id pattern in the line
|
# Search for the tvg-id pattern in the line
|
||||||
match = tvg_id_pattern.search(line)
|
match = tvg_id_pattern.search(line)
|
||||||
if match:
|
if match:
|
||||||
# Extract the captured group (the value inside the quotes)
|
# Extract the captured group (the value inside the quotes)
|
||||||
tvg_id = match.group(1)
|
tvg_id = match.group(1)
|
||||||
if tvg_id: # Ensure the extracted id is not empty
|
if tvg_id: # Ensure the extracted id is not empty
|
||||||
unique_tvg_ids.add(tvg_id)
|
unique_tvg_ids.add(tvg_id)
|
||||||
|
|
||||||
return list(unique_tvg_ids)
|
return list(unique_tvg_ids)
|
||||||
|
|
||||||
|
|
||||||
def fetch_and_extract_xml(url):
|
def fetch_and_extract_xml(url):
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print(f"Failed to fetch {url}")
|
print(f"Failed to fetch {url}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if url.endswith('.gz'):
|
if url.endswith(".gz"):
|
||||||
try:
|
try:
|
||||||
decompressed_data = gzip.decompress(response.content)
|
decompressed_data = gzip.decompress(response.content)
|
||||||
return ET.fromstring(decompressed_data)
|
return ET.fromstring(decompressed_data)
|
||||||
@@ -84,42 +105,46 @@ def fetch_and_extract_xml(url):
|
|||||||
print(f"Failed to parse XML from {url}: {e}")
|
print(f"Failed to parse XML from {url}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
|
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
|
||||||
root = ET.Element('tv')
|
root = ET.Element("tv")
|
||||||
|
|
||||||
for url in urls:
|
for url in urls:
|
||||||
epg_data = fetch_and_extract_xml(url)
|
epg_data = fetch_and_extract_xml(url)
|
||||||
if epg_data is None:
|
if epg_data is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for channel in epg_data.findall('channel'):
|
for channel in epg_data.findall("channel"):
|
||||||
tvg_id = channel.get('id')
|
tvg_id = channel.get("id")
|
||||||
if tvg_id in tvg_ids:
|
if tvg_id in tvg_ids:
|
||||||
root.append(channel)
|
root.append(channel)
|
||||||
|
|
||||||
for programme in epg_data.findall('programme'):
|
for programme in epg_data.findall("programme"):
|
||||||
tvg_id = programme.get('channel')
|
tvg_id = programme.get("channel")
|
||||||
if tvg_id in tvg_ids:
|
if tvg_id in tvg_ids:
|
||||||
root.append(programme)
|
root.append(programme)
|
||||||
|
|
||||||
tree = ET.ElementTree(root)
|
tree = ET.ElementTree(root)
|
||||||
tree.write(output_file, encoding='utf-8', xml_declaration=True)
|
tree.write(output_file, encoding="utf-8", xml_declaration=True)
|
||||||
print(f"New EPG saved to {output_file}")
|
print(f"New EPG saved to {output_file}")
|
||||||
|
|
||||||
if save_as_gz:
|
if save_as_gz:
|
||||||
output_file_gz = output_file + '.gz'
|
output_file_gz = output_file + ".gz"
|
||||||
with gzip.open(output_file_gz, 'wb') as f:
|
with gzip.open(output_file_gz, "wb") as f:
|
||||||
tree.write(f, encoding='utf-8', xml_declaration=True)
|
tree.write(f, encoding="utf-8", xml_declaration=True)
|
||||||
print(f"New EPG saved to {output_file_gz}")
|
print(f"New EPG saved to {output_file_gz}")
|
||||||
|
|
||||||
|
|
||||||
def upload_epg(file_path):
|
def upload_epg(file_path):
|
||||||
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
|
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
IPTV_SERVER_URL + '/admin/epg',
|
IPTV_SERVER_URL + "/admin/epg",
|
||||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
auth=requests.auth.HTTPBasicAuth(
|
||||||
files={'file': (os.path.basename(file_path), f)}
|
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||||
|
),
|
||||||
|
files={"file": (os.path.basename(file_path), f)},
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
@@ -129,6 +154,7 @@ def upload_epg(file_path):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Upload error: {str(e)}")
|
print(f"Upload error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
playlist_file = args.playlist
|
playlist_file = args.playlist
|
||||||
@@ -144,4 +170,4 @@ if __name__ == "__main__":
|
|||||||
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
|
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
|
||||||
|
|
||||||
if args.save_as_gz:
|
if args.save_as_gz:
|
||||||
upload_epg(output_file + '.gz')
|
upload_epg(output_file + ".gz")
|
||||||
@@ -1,26 +1,45 @@
|
|||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import requests
|
||||||
from utils.check_streams import StreamValidator
|
from utils.check_streams import StreamValidator
|
||||||
from utils.config import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL
|
from utils.constants import (
|
||||||
|
EPG_URL,
|
||||||
|
IPTV_SERVER_ADMIN_PASSWORD,
|
||||||
|
IPTV_SERVER_ADMIN_USER,
|
||||||
|
IPTV_SERVER_URL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description='IPTV playlist generator')
|
parser = argparse.ArgumentParser(description="IPTV playlist generator")
|
||||||
parser.add_argument('--output',
|
parser.add_argument(
|
||||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'),
|
"--output",
|
||||||
help='Path to output playlist file')
|
default=os.path.join(
|
||||||
parser.add_argument('--channels',
|
os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
|
||||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'),
|
),
|
||||||
help='Path to channels definition JSON file')
|
help="Path to output playlist file",
|
||||||
parser.add_argument('--dead-channels-log',
|
)
|
||||||
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'),
|
parser.add_argument(
|
||||||
help='Path to log file to store a list of dead channels')
|
"--channels",
|
||||||
|
default=os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)), "channels.json"
|
||||||
|
),
|
||||||
|
help="Path to channels definition JSON file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dead-channels-log",
|
||||||
|
default=os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
|
||||||
|
),
|
||||||
|
help="Path to log file to store a list of dead channels",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def find_working_stream(validator, urls):
|
def find_working_stream(validator, urls):
|
||||||
"""Test all URLs and return the first working one"""
|
"""Test all URLs and return the first working one"""
|
||||||
for url in urls:
|
for url in urls:
|
||||||
@@ -29,9 +48,10 @@ def find_working_stream(validator, urls):
|
|||||||
return url
|
return url
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def create_playlist(channels_file, output_file):
|
def create_playlist(channels_file, output_file):
|
||||||
# Read channels from JSON file
|
# Read channels from JSON file
|
||||||
with open(channels_file, 'r', encoding='utf-8') as f:
|
with open(channels_file, encoding="utf-8") as f:
|
||||||
channels = json.load(f)
|
channels = json.load(f)
|
||||||
|
|
||||||
# Initialize validator
|
# Initialize validator
|
||||||
@@ -41,9 +61,9 @@ def create_playlist(channels_file, output_file):
|
|||||||
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
|
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
|
||||||
|
|
||||||
for channel in channels:
|
for channel in channels:
|
||||||
if 'urls' in channel: # Check if channel has URLs
|
if "urls" in channel: # Check if channel has URLs
|
||||||
# Find first working stream
|
# Find first working stream
|
||||||
working_url = find_working_stream(validator, channel['urls'])
|
working_url = find_working_stream(validator, channel["urls"])
|
||||||
|
|
||||||
if working_url:
|
if working_url:
|
||||||
# Add channel to playlist
|
# Add channel to playlist
|
||||||
@@ -51,24 +71,30 @@ def create_playlist(channels_file, output_file):
|
|||||||
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
|
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
|
||||||
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
|
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
|
||||||
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
|
m3u8_content += f'group-title="{channel.get("group-title", "")}", '
|
||||||
m3u8_content += f'{channel.get("name", "")}\n'
|
m3u8_content += f"{channel.get('name', '')}\n"
|
||||||
m3u8_content += f'{working_url}\n'
|
m3u8_content += f"{working_url}\n"
|
||||||
else:
|
else:
|
||||||
# Log dead channel
|
# Log dead channel
|
||||||
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found')
|
logging.info(
|
||||||
|
f"Dead channel: {channel.get('name', 'Unknown')} - "
|
||||||
|
"No working streams found"
|
||||||
|
)
|
||||||
|
|
||||||
# Write playlist file
|
# Write playlist file
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
f.write(m3u8_content)
|
f.write(m3u8_content)
|
||||||
|
|
||||||
|
|
||||||
def upload_playlist(file_path):
|
def upload_playlist(file_path):
|
||||||
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
|
"""Uploads playlist file to IPTV server using HTTP Basic Auth"""
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
IPTV_SERVER_URL + '/admin/playlist',
|
IPTV_SERVER_URL + "/admin/playlist",
|
||||||
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD),
|
auth=requests.auth.HTTPBasicAuth(
|
||||||
files={'file': (os.path.basename(file_path), f)}
|
IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
|
||||||
|
),
|
||||||
|
files={"file": (os.path.basename(file_path), f)},
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Upload error: {str(e)}")
|
print(f"Upload error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
channels_file = args.channels
|
channels_file = args.channels
|
||||||
@@ -85,24 +112,25 @@ def main():
|
|||||||
dead_channels_log_file = args.dead_channels_log
|
dead_channels_log_file = args.dead_channels_log
|
||||||
|
|
||||||
# Clear previous log file
|
# Clear previous log file
|
||||||
with open(dead_channels_log_file, 'w') as f:
|
with open(dead_channels_log_file, "w") as f:
|
||||||
f.write(f'Log created on {datetime.now()}\n')
|
f.write(f"Log created on {datetime.now()}\n")
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
filename=dead_channels_log_file,
|
filename=dead_channels_log_file,
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(message)s',
|
format="%(asctime)s - %(message)s",
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create playlist
|
# Create playlist
|
||||||
create_playlist(channels_file, output_file)
|
create_playlist(channels_file, output_file)
|
||||||
|
|
||||||
#upload playlist to server
|
# upload playlist to server
|
||||||
upload_playlist(output_file)
|
upload_playlist(output_file)
|
||||||
|
|
||||||
print("Playlist creation completed!")
|
print("Playlist creation completed!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
110
app/iptv/scheduler.py
Normal file
110
app/iptv/scheduler.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.iptv.stream_manager import StreamManager
|
||||||
|
from app.models.db import ChannelDB
|
||||||
|
from app.utils.database import get_db_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamScheduler:
|
||||||
|
"""Scheduler service for periodic stream validation tasks."""
|
||||||
|
|
||||||
|
def __init__(self, app: Optional[FastAPI] = None):
|
||||||
|
"""
|
||||||
|
Initialize the scheduler with optional FastAPI app integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: Optional FastAPI app instance for lifecycle integration
|
||||||
|
"""
|
||||||
|
self.scheduler = BackgroundScheduler()
|
||||||
|
self.app = app
|
||||||
|
self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10"))
|
||||||
|
self.schedule_time = os.getenv(
|
||||||
|
"STREAM_VALIDATION_SCHEDULE", "0 3 * * *"
|
||||||
|
) # Default 3 AM daily
|
||||||
|
logger.info(f"Scheduler initialized with app: {app is not None}")
|
||||||
|
|
||||||
|
def validate_streams_batch(self, db_session: Optional[Session] = None) -> None:
|
||||||
|
"""
|
||||||
|
Validate streams and update their status.
|
||||||
|
When batch_size=0, validates all channels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Optional SQLAlchemy session
|
||||||
|
"""
|
||||||
|
db = db_session if db_session else get_db_session()
|
||||||
|
try:
|
||||||
|
manager = StreamManager(db)
|
||||||
|
|
||||||
|
# Get channels to validate
|
||||||
|
query = db.query(ChannelDB)
|
||||||
|
if self.batch_size > 0:
|
||||||
|
query = query.limit(self.batch_size)
|
||||||
|
channels = query.all()
|
||||||
|
|
||||||
|
for channel in channels:
|
||||||
|
try:
|
||||||
|
logger.info(f"Validating streams for channel {channel.id}")
|
||||||
|
manager.validate_and_select_stream(str(channel.id))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error validating channel {channel.id}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Completed stream validation of {len(channels)} channels")
|
||||||
|
finally:
|
||||||
|
if db_session is None:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""Start the scheduler and add jobs."""
|
||||||
|
if not self.scheduler.running:
|
||||||
|
# Add the scheduled job
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self.validate_streams_batch,
|
||||||
|
trigger=CronTrigger.from_crontab(self.schedule_time),
|
||||||
|
id="daily_stream_validation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the scheduler
|
||||||
|
self.scheduler.start()
|
||||||
|
logger.info(
|
||||||
|
f"Stream scheduler started with daily validation job. "
|
||||||
|
f"Running: {self.scheduler.running}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register shutdown handler if FastAPI app is provided
|
||||||
|
if self.app:
|
||||||
|
logger.info(
|
||||||
|
f"Registering scheduler with FastAPI "
|
||||||
|
f"app: {hasattr(self.app, 'state')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@self.app.on_event("shutdown")
|
||||||
|
def shutdown_scheduler():
|
||||||
|
self.shutdown()
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""Shutdown the scheduler gracefully."""
|
||||||
|
if self.scheduler.running:
|
||||||
|
self.scheduler.shutdown()
|
||||||
|
logger.info("Stream scheduler stopped")
|
||||||
|
|
||||||
|
def trigger_manual_validation(self) -> None:
|
||||||
|
"""Trigger manual validation of streams."""
|
||||||
|
logger.info("Manually triggering stream validation")
|
||||||
|
self.validate_streams_batch()
|
||||||
|
|
||||||
|
|
||||||
|
def init_scheduler(app: FastAPI) -> StreamScheduler:
|
||||||
|
"""Initialize and start the scheduler with FastAPI integration."""
|
||||||
|
scheduler = StreamScheduler(app)
|
||||||
|
scheduler.start()
|
||||||
|
return scheduler
|
||||||
151
app/iptv/stream_manager.py
Normal file
151
app/iptv/stream_manager.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models.db import ChannelURL
|
||||||
|
from app.utils.check_streams import StreamValidator
|
||||||
|
from app.utils.database import get_db_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamManager:
    """Service for managing and validating channel streams.

    The manager closes its database session only when it created that
    session itself; a session supplied by the caller is never closed here.
    """

    def __init__(self, db_session: Optional[Session] = None):
        """
        Initialize StreamManager with optional database session.

        Args:
            db_session: Optional SQLAlchemy session. If None, a new session
                is created, owned, and eventually closed by this instance.
        """
        # Track ownership so we never close a session the caller still uses.
        self._owns_session = db_session is None
        self.db = db_session if db_session else get_db_session()
        self.validator = StreamValidator()

    def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]:
        """
        Get all streams for a channel ordered by priority (lowest first),
        with same-priority streams randomized.

        Args:
            channel_id: UUID of the channel to get streams for

        Returns:
            List of ChannelURL objects ordered by priority

        Raises:
            Exception: re-raised after logging when the query fails.
        """
        try:
            streams = (
                self.db.query(ChannelURL)
                .filter(ChannelURL.channel_id == channel_id)
                .order_by(ChannelURL.priority_id)
                .all()
            )

            # Group streams by priority so that equal-priority alternatives
            # can be shuffled for simple load distribution.
            grouped: dict = {}
            for stream in streams:
                grouped.setdefault(stream.priority_id, []).append(stream)

            # Shuffle within each priority bucket, then flatten in ascending
            # priority order (lower priority_id is tried first).
            randomized_streams = []
            for priority in sorted(grouped):
                random.shuffle(grouped[priority])
                randomized_streams.extend(grouped[priority])

            return randomized_streams

        except Exception as e:
            logger.error(f"Error getting streams for channel {channel_id}: {str(e)}")
            raise

    def validate_and_select_stream(self, channel_id: str) -> Optional[str]:
        """
        Find and validate a working stream for the given channel.

        Args:
            channel_id: UUID of the channel to find a stream for

        Returns:
            URL of the first working stream found, or None if none found

        Raises:
            Exception: re-raised after logging when lookup/validation fails.
        """
        try:
            streams = self.get_streams_for_channel(channel_id)
            if not streams:
                logger.warning(f"No streams found for channel {channel_id}")
                return None

            working_stream = None
            for stream in streams:
                logger.info(f"Validating stream {stream.url} for channel {channel_id}")
                is_valid, _ = self.validator.validate_stream(stream.url)
                if is_valid:
                    working_stream = stream
                    break

            if working_stream:
                # Persist which stream is active before reporting it back.
                self._update_stream_status(working_stream, streams)
                return working_stream.url

            logger.warning(f"No valid streams found for channel {channel_id}")
            return None

        except Exception as e:
            logger.error(f"Error validating streams for channel {channel_id}: {str(e)}")
            raise

    def _update_stream_status(
        self, working_stream: ChannelURL, all_streams: list[ChannelURL]
    ) -> None:
        """
        Update in_use status for streams (True for working stream, False for others).

        Args:
            working_stream: The stream that was validated as working
            all_streams: All streams for the channel

        Raises:
            Exception: re-raised after rollback when the commit fails.
        """
        try:
            for stream in all_streams:
                stream.in_use = stream.id == working_stream.id

            self.db.commit()
            logger.info(
                f"Updated stream status - set in_use=True for {working_stream.url}"
            )

        except Exception as e:
            self.db.rollback()
            logger.error(f"Error updating stream status: {str(e)}")
            raise

    def close(self) -> None:
        """Close the database session, but only if this instance created it."""
        if getattr(self, "_owns_session", False) and hasattr(self, "db"):
            self.db.close()

    def __del__(self):
        """Best-effort cleanup when the StreamManager is garbage-collected."""
        # Previously this closed any session, including caller-provided ones,
        # which could kill a session the caller was still using. Now only
        # sessions created by this instance are closed.
        self.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_working_stream(
    channel_id: str, db_session: Optional[Session] = None
) -> Optional[str]:
    """
    Convenience function to get a working stream for a channel.

    Args:
        channel_id: UUID of the channel to get a stream for
        db_session: Optional SQLAlchemy session; if omitted, a temporary
            session is created and closed before this function returns.

    Returns:
        URL of the first working stream found, or None if none found
    """
    manager = StreamManager(db_session)
    try:
        return manager.validate_and_select_stream(channel_id)
    finally:
        # Close the session directly instead of invoking __del__ by hand;
        # calling __del__ explicitly risks a second finalization at GC time.
        if db_session is None:  # Only close if we created the session
            manager.db.close()
|
||||||
113
app/main.py
113
app/main.py
@@ -1,43 +1,82 @@
|
|||||||
from fastapi import FastAPI, Depends, HTTPException
|
from fastapi import FastAPI
|
||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.concurrency import asynccontextmanager
|
||||||
from app.cabletv.utils.auth import exchange_code_for_token, get_current_user, DOMAIN, CLIENT_ID
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
from app.routers import auth, channels, groups, playlist, priorities, scheduler
|
||||||
|
from app.utils.database import init_db
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: set up the database and scheduler, tear down on exit."""
    # Create database tables before serving any requests.
    init_db()

    # Spin up the background stream scheduler and expose it on app state
    # so request handlers can reach it.
    stream_scheduler = StreamScheduler(app)
    app.state.scheduler = stream_scheduler
    stream_scheduler.start()

    yield

    # Application is shutting down: stop the scheduler.
    stream_scheduler.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
# Application instance; `lifespan` wires up database initialization and the
# background stream scheduler (see lifespan above).
app = FastAPI(
    lifespan=lifespan,
    title="IPTV Manager API",
    description="API for IPTV Manager service",
    version="1.0.0",
)
|
||||||
|
|
||||||
|
|
||||||
|
def custom_openapi():
    """Build (once) and cache an OpenAPI schema with a JWT bearer scheme."""
    # Serve the cached schema if it was already generated.
    if app.openapi_schema:
        return app.openapi_schema

    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Make sure the components/schemas containers exist before touching them.
    components = schema.setdefault("components", {})
    components.setdefault("schemas", {})

    # Declare the JWT bearer security scheme and require it globally.
    components["securitySchemes"] = {
        "Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
    }
    schema["security"] = [{"Bearer": []}]

    # Set OpenAPI version explicitly.
    schema["openapi"] = "3.1.0"

    app.openapi_schema = schema
    return app.openapi_schema


app.openapi = custom_openapi
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
return {"message": "IPTV Updater API"}
|
return {"message": "IPTV Manager API"}
|
||||||
|
|
||||||
@app.get("/protected")
|
|
||||||
async def protected_route(user = Depends(get_current_user)):
|
|
||||||
if isinstance(user, RedirectResponse):
|
|
||||||
return user
|
|
||||||
return {"message": "Protected content", "user": user['Username']}
|
|
||||||
|
|
||||||
@app.get("/auth/callback")
|
# Include routers
|
||||||
async def auth_callback(code: str):
|
app.include_router(auth.router)
|
||||||
try:
|
app.include_router(channels.router)
|
||||||
# Exchange the authorization code for tokens
|
app.include_router(playlist.router)
|
||||||
tokens = exchange_code_for_token(code)
|
app.include_router(priorities.router)
|
||||||
|
app.include_router(groups.router)
|
||||||
# Create a response with the access token
|
app.include_router(scheduler.router)
|
||||||
response = JSONResponse(content={
|
|
||||||
"message": "Authentication successful",
|
|
||||||
"access_token": tokens["access_token"]
|
|
||||||
})
|
|
||||||
|
|
||||||
# Set the access token as a cookie
|
|
||||||
response.set_cookie(
|
|
||||||
key="access_token",
|
|
||||||
value=tokens["access_token"],
|
|
||||||
httponly=True,
|
|
||||||
secure=True,
|
|
||||||
samesite="lax"
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Authentication failed: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|||||||
27
app/models/__init__.py
Normal file
27
app/models/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from .db import Base, ChannelDB, ChannelURL, Group, Priority
|
||||||
|
from .schemas import (
|
||||||
|
ChannelCreate,
|
||||||
|
ChannelResponse,
|
||||||
|
ChannelUpdate,
|
||||||
|
ChannelURLCreate,
|
||||||
|
ChannelURLResponse,
|
||||||
|
GroupCreate,
|
||||||
|
GroupResponse,
|
||||||
|
GroupUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Base",
|
||||||
|
"ChannelDB",
|
||||||
|
"ChannelCreate",
|
||||||
|
"ChannelUpdate",
|
||||||
|
"ChannelResponse",
|
||||||
|
"ChannelURL",
|
||||||
|
"ChannelURLCreate",
|
||||||
|
"ChannelURLResponse",
|
||||||
|
"Group",
|
||||||
|
"Priority",
|
||||||
|
"GroupCreate",
|
||||||
|
"GroupResponse",
|
||||||
|
"GroupUpdate",
|
||||||
|
]
|
||||||
26
app/models/auth.py
Normal file
26
app/models/auth.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SigninRequest(BaseModel):
    """Request model for the signin endpoint."""

    # Credentials are forwarded verbatim to Cognito's initiate_auth
    # (see app/routers/auth.py).
    username: str = Field(..., description="The user's username")
    password: str = Field(..., description="The user's password")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
    """Response model for successful authentication."""

    access_token: str = Field(..., description="Access JWT token from Cognito")
    id_token: str = Field(..., description="ID JWT token from Cognito")
    # Optional: Cognito does not always return a refresh token.
    refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
    token_type: str = Field(..., description="Type of the token returned")
|
||||||
|
|
||||||
|
|
||||||
|
class CognitoUser(BaseModel):
    """Model representing the user returned from token verification."""

    username: str
    # NOTE(review): presumably checked by require_roles() in the routers —
    # confirm the source of these role strings (Cognito groups vs claims).
    roles: list[str]
|
||||||
139
app/models/db.py
Normal file
139
app/models/db.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
TEXT,
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
TypeDecorator,
|
||||||
|
UniqueConstraint,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import declarative_base, relationship
|
||||||
|
|
||||||
|
|
||||||
|
# Custom UUID type for SQLite compatibility
|
||||||
|
class SQLiteUUID(TypeDecorator):
    """Enables UUID support for SQLite with proper comparison handling."""

    impl = TEXT
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Coerce a uuid.UUID (or UUID-formatted string) to its text form."""
        if value is None:
            return value
        if isinstance(value, uuid.UUID):
            return str(value)
        try:
            # Validate the string format by attempting to build a UUID.
            uuid.UUID(value)
        except (ValueError, AttributeError):
            raise ValueError(f"Invalid UUID string format: {value}")
        return value

    def process_result_value(self, value, dialect):
        """Convert the stored text back into a uuid.UUID instance."""
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)

    def compare_values(self, x, y):
        """Compare by string form so mixed str/UUID values compare equal."""
        if x is None or y is None:
            return x == y
        return str(x) == str(y)
|
||||||
|
|
||||||
|
|
||||||
|
# Determine which UUID type to use based on environment.
# MOCK_AUTH=true marks the local/dev setup (see .env.example), where the
# TEXT-backed SQLiteUUID shim replaces PostgreSQL's native UUID type.
if os.getenv("MOCK_AUTH", "").lower() == "true":
    UUID_COLUMN_TYPE = SQLiteUUID()
else:
    UUID_COLUMN_TYPE = UUID(as_uuid=True)

# Declarative base shared by every ORM model in this module.
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class Priority(Base):
    """SQLAlchemy model for channel URL priorities"""

    __tablename__ = "priorities"

    # Priority level; lower ids are tried first (StreamManager orders
    # streams by priority_id ascending).
    id = Column(Integer, primary_key=True)
    # Human-readable label for the priority level.
    description = Column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Group(Base):
    """SQLAlchemy model for channel groups"""

    __tablename__ = "groups"

    id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
    # Group names are globally unique.
    name = Column(String, nullable=False, unique=True)
    # Display ordering hint — presumably ascending; confirm with consumers.
    sort_order = Column(Integer, nullable=False, default=0)
    # Timestamps are timezone-aware UTC; updated_at refreshes on every update.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # Relationship with Channel
    channels = relationship("ChannelDB", back_populates="group")
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDB(Base):
    """SQLAlchemy model for IPTV channels"""

    __tablename__ = "channels"

    # One channel name may appear only once per group.
    __table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)

    id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
    tvg_name = Column(String)
    tvg_logo = Column(String)
    # Timestamps are timezone-aware UTC; updated_at refreshes on every update.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # Relationships: deleting a channel removes its URLs via delete-orphan.
    urls = relationship(
        "ChannelURL", back_populates="channel", cascade="all, delete-orphan"
    )
    group = relationship("Group", back_populates="channels")
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelURL(Base):
    """SQLAlchemy model for channel URLs"""

    __tablename__ = "channels_urls"

    id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
    channel_id = Column(
        UUID_COLUMN_TYPE,
        ForeignKey("channels.id", ondelete="CASCADE"),
        nullable=False,
    )
    url = Column(String, nullable=False)
    # Marks the stream currently selected as active for its channel
    # (set by StreamManager._update_stream_status).
    in_use = Column(Boolean, default=False, nullable=False)
    # Lower priority_id means the URL is tried earlier (see StreamManager).
    priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # Relationships
    channel = relationship("ChannelDB", back_populates="urls")
    priority = relationship("Priority")
|
||||||
140
app/models/schemas.py
Normal file
140
app/models/schemas.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityBase(BaseModel):
    """Base Pydantic model for priorities"""

    id: int
    description: str

    # Allow construction directly from Priority ORM rows.
    model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityCreate(PriorityBase):
    """Pydantic model for creating priorities"""

    # Identical to the base model; kept as a distinct type for API clarity.
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityResponse(PriorityBase):
    """Pydantic model for priority responses"""

    # Identical to the base model; kept as a distinct type for API clarity.
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelURLCreate(BaseModel):
    """Pydantic model for creating channel URLs"""

    url: str
    # Priority level constrained to the 100-300 range.
    priority_id: int = Field(
        default=100, ge=100, le=300
    )  # Default to High, validate range
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelURLBase(ChannelURLCreate):
    """Base Pydantic model for channel URL responses"""

    id: UUID
    in_use: bool
    created_at: datetime
    updated_at: datetime
    # NOTE(review): re-declaring priority_id drops the ge/le constraints and
    # default inherited from ChannelURLCreate — presumably intentional for
    # response serialization; confirm.
    priority_id: int

    model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelURLResponse(ChannelURLBase):
    """Pydantic model for channel URL responses"""

    # Identical to the base model; kept as a distinct type for API clarity.
    pass
|
||||||
|
|
||||||
|
|
||||||
|
# New Group Schemas
class GroupCreate(BaseModel):
    """Pydantic model for creating groups"""

    name: str
    # Non-negative display position.
    sort_order: int = Field(default=0, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupUpdate(BaseModel):
    """Pydantic model for updating groups"""

    # All fields optional so callers can send partial updates.
    name: Optional[str] = None
    sort_order: Optional[int] = Field(None, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupResponse(BaseModel):
    """Pydantic model for group responses"""

    id: UUID
    name: str
    sort_order: int
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from Group ORM rows.
    model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSortUpdate(BaseModel):
    """Pydantic model for updating a single group's sort order"""

    # Non-negative display position.
    sort_order: int = Field(ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupBulkSort(BaseModel):
    """Pydantic model for bulk updating group sort orders"""

    # NOTE(review): untyped dicts — each entry presumably carries group_id
    # and sort_order per the example; a typed item model would validate this.
    groups: list[dict] = Field(
        description="List of dicts with group_id and new sort_order",
        json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]},
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelCreate(BaseModel):
    """Pydantic model for creating channels"""

    urls: list[ChannelURLCreate]  # List of URL objects with priority
    name: str
    group_id: UUID
    tvg_id: str
    tvg_logo: str
    tvg_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelURLUpdate(BaseModel):
    """Pydantic model for updating channel URLs"""

    # All fields optional so callers can send partial updates.
    url: Optional[str] = None
    in_use: Optional[bool] = None
    # Same 100-300 range as ChannelURLCreate when provided.
    priority_id: Optional[int] = Field(default=None, ge=100, le=300)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelUpdate(BaseModel):
    """Pydantic model for updating channels (all fields optional)"""

    # min_length=1 rejects empty strings where a value is supplied.
    name: Optional[str] = Field(None, min_length=1)
    group_id: Optional[UUID] = None
    tvg_id: Optional[str] = Field(None, min_length=1)
    tvg_logo: Optional[str] = None
    tvg_name: Optional[str] = Field(None, min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelResponse(BaseModel):
    """Pydantic model for channel responses"""

    id: UUID
    name: str
    group_id: UUID
    tvg_id: str
    tvg_logo: str
    tvg_name: str
    urls: list[ChannelURLResponse]  # List of URL objects without channel_id
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from ChannelDB ORM rows.
    model_config = ConfigDict(from_attributes=True)
|
||||||
0
app/routers/__init__.py
Normal file
0
app/routers/__init__.py
Normal file
22
app/routers/auth.py
Normal file
22
app/routers/auth.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.auth.cognito import initiate_auth
|
||||||
|
from app.models.auth import SigninRequest, TokenResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
|
||||||
|
def signin(credentials: SigninRequest):
|
||||||
|
"""
|
||||||
|
Sign-in endpoint to authenticate the user with AWS Cognito
|
||||||
|
using username and password.
|
||||||
|
On success, returns JWT tokens (access_token, id_token, refresh_token).
|
||||||
|
"""
|
||||||
|
auth_result = initiate_auth(credentials.username, credentials.password)
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=auth_result["AccessToken"],
|
||||||
|
id_token=auth_result["IdToken"],
|
||||||
|
refresh_token=auth_result.get("RefreshToken"),
|
||||||
|
token_type="Bearer",
|
||||||
|
)
|
||||||
508
app/routers/channels.py
Normal file
508
app/routers/channels.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import and_
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, require_roles
|
||||||
|
from app.models import (
|
||||||
|
ChannelCreate,
|
||||||
|
ChannelDB,
|
||||||
|
ChannelResponse,
|
||||||
|
ChannelUpdate,
|
||||||
|
ChannelURL,
|
||||||
|
ChannelURLCreate,
|
||||||
|
ChannelURLResponse,
|
||||||
|
Group,
|
||||||
|
Priority, # Added Priority import
|
||||||
|
)
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.models.schemas import ChannelURLUpdate
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/channels", tags=["channels"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@require_roles("admin")
|
||||||
|
def create_channel(
|
||||||
|
channel: ChannelCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Create a new channel"""
|
||||||
|
# Check if group exists
|
||||||
|
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Group not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for duplicate channel (same group_id + name)
|
||||||
|
existing_channel = (
|
||||||
|
db.query(ChannelDB)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
ChannelDB.group_id == channel.group_id,
|
||||||
|
ChannelDB.name == channel.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Channel with same group_id and name already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create channel without URLs first
|
||||||
|
channel_data = channel.model_dump(exclude={"urls"})
|
||||||
|
urls = channel.urls
|
||||||
|
db_channel = ChannelDB(**channel_data)
|
||||||
|
db.add(db_channel)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_channel)
|
||||||
|
|
||||||
|
# Add URLs with priority
|
||||||
|
for url in urls:
|
||||||
|
db_url = ChannelURL(
|
||||||
|
channel_id=db_channel.id,
|
||||||
|
url=url.url,
|
||||||
|
priority_id=url.priority_id,
|
||||||
|
in_use=False,
|
||||||
|
)
|
||||||
|
db.add(db_url)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_channel)
|
||||||
|
return db_channel
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{channel_id}", response_model=ChannelResponse)
|
||||||
|
def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
|
||||||
|
"""Get a channel by id"""
|
||||||
|
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||||
|
if not channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||||
|
)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{channel_id}", response_model=ChannelResponse)
|
||||||
|
@require_roles("admin")
|
||||||
|
def update_channel(
|
||||||
|
channel_id: UUID,
|
||||||
|
channel: ChannelUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Update a channel"""
|
||||||
|
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||||
|
if not db_channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only check for duplicates if name or group_id are being updated
|
||||||
|
if channel.name is not None or channel.group_id is not None:
|
||||||
|
name = channel.name if channel.name is not None else db_channel.name
|
||||||
|
group_id = (
|
||||||
|
channel.group_id if channel.group_id is not None else db_channel.group_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if new group exists
|
||||||
|
if channel.group_id is not None:
|
||||||
|
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Group not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_channel = (
|
||||||
|
db.query(ChannelDB)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
ChannelDB.group_id == group_id,
|
||||||
|
ChannelDB.name == name,
|
||||||
|
ChannelDB.id != channel_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Channel with same group_id and name already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update only provided fields
|
||||||
|
update_data = channel.model_dump(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(db_channel, key, value)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_channel)
|
||||||
|
return db_channel
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||||
|
@require_roles("admin")
|
||||||
|
def delete_channels(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Delete all channels"""
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
count = db.query(ChannelDB).count()
|
||||||
|
|
||||||
|
# First delete all channels
|
||||||
|
db.query(ChannelDB).delete()
|
||||||
|
|
||||||
|
# Then delete any URLs that are now orphaned (no channel references)
|
||||||
|
db.query(ChannelURL).filter(
|
||||||
|
~ChannelURL.channel_id.in_(db.query(ChannelDB.id))
|
||||||
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
|
# Then delete any groups that are now empty
|
||||||
|
db.query(Group).filter(~Group.id.in_(db.query(ChannelDB.group_id))).delete(
|
||||||
|
synchronize_session=False
|
||||||
|
)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting channels: {e}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to delete channels",
|
||||||
|
)
|
||||||
|
return {"deleted": count}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@require_roles("admin")
|
||||||
|
def delete_channel(
|
||||||
|
channel_id: UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Delete a channel"""
|
||||||
|
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||||
|
if not channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||||
|
)
|
||||||
|
db.delete(channel)
|
||||||
|
db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=list[ChannelResponse])
|
||||||
|
@require_roles("admin")
|
||||||
|
def list_channels(
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""List all channels with pagination"""
|
||||||
|
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
|
||||||
|
# New endpoint to get channels by group
|
||||||
|
@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse])
|
||||||
|
def get_channels_by_group(
|
||||||
|
group_id: UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get all channels for a specific group"""
|
||||||
|
group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
return db.query(ChannelDB).filter(ChannelDB.group_id == group_id).all()
|
||||||
|
|
||||||
|
|
||||||
|
# New endpoint to update a channel's group
|
||||||
|
@router.put("/{channel_id}/group", response_model=ChannelResponse)
|
||||||
|
@require_roles("admin")
|
||||||
|
def update_channel_group(
|
||||||
|
channel_id: UUID,
|
||||||
|
group_id: UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Update a channel's group"""
|
||||||
|
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||||
|
if not channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for duplicate channel name in new group
|
||||||
|
existing_channel = (
|
||||||
|
db.query(ChannelDB)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
ChannelDB.group_id == group_id,
|
||||||
|
ChannelDB.name == channel.name,
|
||||||
|
ChannelDB.id != channel_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Channel with same name already exists in target group",
|
||||||
|
)
|
||||||
|
|
||||||
|
channel.group_id = group_id
|
||||||
|
db.commit()
|
||||||
|
db.refresh(channel)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
|
||||||
|
# Bulk Upload and Reset Endpoints
@router.post("/bulk-upload", status_code=status.HTTP_200_OK)
@require_roles("admin")
def bulk_upload_channels(
    channels: list[dict],
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Bulk upload channels from a JSON array.

    Each entry is an M3U-style dict ("group-title", "name", "tvg-id",
    "tvg-name", "tvg-logo", "urls").  Groups are created on demand,
    channels are upserted by (group, name), and URLs get priorities
    positionally — the i-th URL gets the i-th priority, with overflow
    falling back to the last (lowest) priority.  Each channel commits
    atomically; a failing entry is rolled back and skipped.

    Returns {"processed": <count of channels successfully stored>}.
    """
    import logging

    logger = logging.getLogger(__name__)

    processed = 0

    # Fetch all priorities once, ordered by id.  The last element of this
    # ordered list already carries the highest id, so no second query is
    # needed to find it.
    priorities = db.query(Priority).order_by(Priority.id).all()
    priority_map = {i: p.id for i, p in enumerate(priorities)}
    max_priority_id = priorities[-1].id if priorities else None

    for channel_data in channels:
        try:
            # Get or create group; entries without a group title are skipped.
            group_name = channel_data.get("group-title")
            if not group_name:
                continue

            group = db.query(Group).filter(Group.name == group_name).first()
            if not group:
                group = Group(name=group_name)
                db.add(group)
                db.flush()  # Make the new group (and its id) visible in the session
                db.refresh(group)

            # Normalize "urls" to a list.
            urls = channel_data.get("urls", [])
            if not isinstance(urls, list):
                urls = [urls]

            # Assign priorities dynamically based on the fetched priorities.
            url_objects = []
            for i, url in enumerate(urls):
                priority_id = priority_map.get(i, max_priority_id)
                if priority_id is None:
                    # No priorities exist at all — nothing sensible to assign.
                    logger.warning(
                        "No priorities defined in database. Skipping URL %s", url
                    )
                    continue
                url_objects.append({"url": url, "priority_id": priority_id})

            # Build the channel object with its required fields.
            channel_obj = ChannelDB(
                tvg_id=channel_data.get("tvg-id", ""),
                name=channel_data.get("name", ""),
                group_id=group.id,
                tvg_name=channel_data.get("tvg-name", ""),
                tvg_logo=channel_data.get("tvg-logo", ""),
            )

            # Upsert by (group, name).
            existing_channel = (
                db.query(ChannelDB)
                .filter(
                    and_(
                        ChannelDB.group_id == group.id,
                        ChannelDB.name == channel_obj.name,
                    )
                )
                .first()
            )

            if existing_channel:
                # Update metadata on the existing row.
                existing_channel.tvg_id = channel_obj.tvg_id
                existing_channel.tvg_name = channel_obj.tvg_name
                existing_channel.tvg_logo = channel_obj.tvg_logo

                # Clear and recreate the URL list.
                db.query(ChannelURL).filter(
                    ChannelURL.channel_id == existing_channel.id
                ).delete()
                target_channel_id = existing_channel.id
            else:
                # Create a new channel.
                db.add(channel_obj)
                db.flush()  # Flush to obtain the new channel's id
                db.refresh(channel_obj)
                target_channel_id = channel_obj.id

            # Insert the (re)built URL rows — identical for both branches.
            for url in url_objects:
                db.add(
                    ChannelURL(
                        channel_id=target_channel_id,
                        url=url["url"],
                        priority_id=url["priority_id"],
                        in_use=False,
                    )
                )

            db.commit()  # Commit all changes for this channel atomically
            processed += 1

        except Exception:
            logger.exception(
                "Error processing channel: %s", channel_data.get("name", "Unknown")
            )
            db.rollback()  # Discard the failed channel's partial changes
            continue

    return {"processed": processed}
|
||||||
|
|
||||||
|
# URL Management Endpoints
@router.post(
    "/{channel_id}/urls",
    response_model=ChannelURLResponse,
    status_code=status.HTTP_201_CREATED,
)
@require_roles("admin")
def add_channel_url(
    channel_id: UUID,
    url: ChannelURLCreate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Attach a new URL to an existing channel (404 if channel missing)."""
    channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
    if channel is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
        )

    new_url = ChannelURL(
        channel_id=channel_id,
        url=url.url,
        priority_id=url.priority_id,
        in_use=False,  # New URLs start out unused
    )
    db.add(new_url)
    db.commit()
    db.refresh(new_url)
    return new_url
|
||||||
|
|
||||||
|
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
@require_roles("admin")
def update_channel_url(
    channel_id: UUID,
    url_id: UUID,
    url_update: ChannelURLUpdate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Update a channel URL (url, in_use, or priority_id)."""
    record = (
        db.query(ChannelURL)
        .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
        .first()
    )
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
        )

    # Apply only the fields the caller actually supplied (None = untouched).
    for field in ("url", "in_use", "priority_id"):
        value = getattr(url_update, field)
        if value is not None:
            setattr(record, field, value)

    db.commit()
    db.refresh(record)
    return record
|
||||||
|
|
||||||
|
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_channel_url(
    channel_id: UUID,
    url_id: UUID,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Remove a URL from a channel (404 when the pair does not exist)."""
    record = (
        db.query(ChannelURL)
        .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
        .first()
    )
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
        )

    db.delete(record)
    db.commit()
    return None
|
||||||
|
|
||||||
|
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
@require_roles("admin")
def list_channel_urls(
    channel_id: UUID,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Return every URL attached to the given channel (404 if missing)."""
    # Validate the channel exists so a bad id yields 404, not an empty list.
    if db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
        )

    return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()
191
app/routers/groups.py
Normal file
191
app/routers/groups.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, require_roles
|
||||||
|
from app.models import Group
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.models.schemas import (
|
||||||
|
GroupBulkSort,
|
||||||
|
GroupCreate,
|
||||||
|
GroupResponse,
|
||||||
|
GroupSortUpdate,
|
||||||
|
GroupUpdate,
|
||||||
|
)
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_group(
    group: GroupCreate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Create a new channel group (409 when the name is already taken)."""
    # Group names are unique across the table.
    if db.query(Group).filter(Group.name == group.name).first() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Group with this name already exists",
        )

    new_group = Group(**group.model_dump())
    db.add(new_group)
    db.commit()
    db.refresh(new_group)
    return new_group
|
||||||
|
|
||||||
|
@router.get("/{group_id}", response_model=GroupResponse)
def get_group(group_id: UUID, db: Session = Depends(get_db)):
    """Fetch a single group by id, or raise 404."""
    result = db.query(Group).filter(Group.id == group_id).first()
    if result is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    return result
|
||||||
|
|
||||||
|
@router.put("/{group_id}", response_model=GroupResponse)
@require_roles("admin")
def update_group(
    group_id: UUID,
    group: GroupUpdate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Update a group's name or sort order.

    404 when the group is missing; 409 when renaming onto a name that
    another group already uses.
    """
    target = db.query(Group).filter(Group.id == group_id).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )

    # Renames must not collide with an existing group's name.
    if group.name is not None and group.name != target.name:
        clash = db.query(Group).filter(Group.name == group.name).first()
        if clash is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Group with this name already exists",
            )

    # Apply only the fields the caller actually provided.
    for field, value in group.model_dump(exclude_unset=True).items():
        setattr(target, field, value)

    db.commit()
    db.refresh(target)
    return target
|
||||||
|
|
||||||
|
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_groups(
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete every group that has no channels; groups with channels are kept.

    Returns counts: {"deleted": n, "skipped": m}.
    """
    deleted = 0
    skipped = 0

    for group in db.query(Group).all():
        if group.channels:
            skipped += 1
        else:
            db.delete(group)
            deleted += 1

    db.commit()
    return {"deleted": deleted, "skipped": skipped}
|
||||||
|
|
||||||
|
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_group(
    group_id: UUID,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete a group, refusing (400) if it still owns channels."""
    target = db.query(Group).filter(Group.id == group_id).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )

    # A group that still owns channels may not be deleted.
    if target.channels:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot delete group with existing channels",
        )

    db.delete(target)
    db.commit()
    return None
|
||||||
|
|
||||||
|
@router.get("/", response_model=list[GroupResponse])
def list_groups(db: Session = Depends(get_db)):
    """Return every group, ordered by its sort_order column."""
    ordered = db.query(Group).order_by(Group.sort_order)
    return ordered.all()
|
||||||
|
|
||||||
|
@router.put("/{group_id}/sort", response_model=GroupResponse)
@require_roles("admin")
def update_group_sort_order(
    group_id: UUID,
    sort_update: GroupSortUpdate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Set the sort order of a single group (404 when missing)."""
    target = db.query(Group).filter(Group.id == group_id).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )

    target.sort_order = sort_update.sort_order
    db.commit()
    db.refresh(target)
    return target
|
||||||
|
|
||||||
|
@router.post("/reorder", response_model=list[GroupResponse])
@require_roles("admin")
def bulk_update_sort_orders(
    bulk_sort: GroupBulkSort,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Apply a batch of {group_id, sort_order} updates in one transaction.

    404 on the first unknown group id; on success returns all groups in
    their new order.
    """
    for entry in bulk_sort.groups:
        target_id = entry["group_id"]
        target_order = entry["sort_order"]

        # NOTE(review): the id is compared as a string here, unlike the
        # direct UUID comparisons elsewhere in this file — confirm this
        # matches the column type.
        target = db.query(Group).filter(Group.id == str(target_id)).first()
        if target is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Group with id {target_id} not found",
            )

        target.sort_order = target_order

    db.commit()

    # Return all groups in their new order.
    return db.query(Group).order_by(Group.sort_order).all()
||||||
156
app/routers/playlist.py
Normal file
156
app/routers/playlist.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
from app.iptv.stream_manager import StreamManager
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.utils.database import get_db_session
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# In-memory registry of validation processes, keyed by process id.
# Each entry holds "status" and "channel_id", plus "result" or "error"
# once the background task finishes.
validation_processes: dict[str, dict] = {}


class ProcessStatus(str, Enum):
    """Lifecycle states of a stream-validation process."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"


class StreamValidationRequest(BaseModel):
    """Request model for stream validation endpoint."""

    # When omitted, validation targets all channels (not yet implemented).
    channel_id: Optional[str] = None


class ValidatedStream(BaseModel):
    """A stream URL that passed validation for a channel."""

    channel_id: str
    stream_url: str


class ValidationProcessResponse(BaseModel):
    """Response returned when a validation process is started."""

    process_id: str
    status: ProcessStatus
    message: str


class ValidationResultResponse(BaseModel):
    """Status/result payload for a validation process lookup."""

    process_id: str
    status: ProcessStatus
    working_streams: Optional[list[ValidatedStream]] = None
    error: Optional[str] = None
|
||||||
|
|
||||||
|
def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session):
    """Background task that validates streams and records the outcome.

    Mutates the validation_processes entry for *process_id* in place:
    status goes PENDING -> IN_PROGRESS -> COMPLETED (with "result" or
    "error") or FAILED on an unexpected exception.

    NOTE(review): *db* is the request-scoped session handed into the
    background task — confirm it is still open once the response has
    already been returned.
    """
    entry = validation_processes[process_id]
    try:
        entry["status"] = ProcessStatus.IN_PROGRESS
        manager = StreamManager(db)

        if not channel_id:
            # TODO: Implement validation for all channels
            entry["error"] = "Validation of all channels not yet implemented"
        else:
            stream_url = manager.validate_and_select_stream(channel_id)
            if stream_url:
                entry["result"] = {
                    "working_streams": [
                        ValidatedStream(channel_id=channel_id, stream_url=stream_url)
                    ]
                }
            else:
                entry["error"] = (
                    f"No working streams found for channel {channel_id}"
                )

        entry["status"] = ProcessStatus.COMPLETED

    except Exception as e:
        logger.error(f"Error validating streams: {str(e)}")
        entry["status"] = ProcessStatus.FAILED
        entry["error"] = str(e)
|
||||||
|
|
||||||
|
@router.post(
    "/validate-streams",
    summary="Start stream validation process",
    response_model=ValidationProcessResponse,
    status_code=status.HTTP_202_ACCEPTED,
    responses={202: {"description": "Validation process started successfully"}},
)
async def start_stream_validation(
    request: StreamValidationRequest,
    background_tasks: BackgroundTasks,
    user: CognitoUser = Depends(get_current_user),
    db: Session = Depends(get_db_session),
):
    """
    Start asynchronous validation of streams.

    - Returns immediately with a process ID
    - Use GET /validate-streams/{process_id} to check status
    """
    process_id = str(uuid4())

    # Register the process before scheduling the task so that an
    # immediate status poll can already find it.
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": request.channel_id,
    }

    background_tasks.add_task(run_stream_validation, process_id, request.channel_id, db)

    return {
        "process_id": process_id,
        "status": ProcessStatus.PENDING,
        "message": "Validation process started",
    }
|
||||||
|
|
||||||
|
@router.get(
    "/validate-streams/{process_id}",
    summary="Check validation process status",
    response_model=ValidationResultResponse,
    responses={
        200: {"description": "Process status and results"},
        404: {"description": "Process not found"},
    },
)
async def get_validation_status(
    process_id: str, user: CognitoUser = Depends(get_current_user)
):
    """
    Check status of a stream validation process.

    Returns current status and results if completed.
    """
    process = validation_processes.get(process_id)
    if process is None:
        raise HTTPException(status_code=404, detail="Process not found")

    current = process["status"]
    response = {"process_id": process_id, "status": current}

    if current == ProcessStatus.FAILED:
        response["error"] = process["error"]
    elif current == ProcessStatus.COMPLETED:
        # A COMPLETED run can still carry an "error" (e.g. no streams found).
        if "error" in process:
            response["error"] = process["error"]
        else:
            response["working_streams"] = process["result"]["working_streams"]

    return response
120
app/routers/priorities.py
Normal file
120
app/routers/priorities.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, require_roles
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.models.db import Priority
|
||||||
|
from app.models.schemas import PriorityCreate, PriorityResponse
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/priorities", tags=["priorities"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_priority(
    priority: PriorityCreate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Create a new priority (409 when the caller-chosen id exists)."""
    # Priority ids come from the request, so a duplicate is a client error.
    if db.get(Priority, priority.id) is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Priority with ID {priority.id} already exists",
        )

    record = Priority(**priority.model_dump())
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
||||||
|
|
||||||
|
@router.get("/", response_model=list[PriorityResponse])
@require_roles("admin")
def list_priorities(
    db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
):
    """Return every priority defined in the database."""
    all_priorities = db.query(Priority).all()
    return all_priorities
|
||||||
|
|
||||||
|
@router.get("/{priority_id}", response_model=PriorityResponse)
@require_roles("admin")
def get_priority(
    priority_id: int,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Fetch one priority by primary key, or raise 404."""
    record = db.get(Priority, priority_id)
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
        )
    return record
|
||||||
|
|
||||||
|
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_priorities(
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete all priorities not in use by channel URLs.

    Returns counts: {"deleted": n, "skipped": m} where skipped are
    priorities still referenced by at least one ChannelURL.
    """
    from app.models.db import ChannelURL

    # Collect the set of referenced priority ids with a single query
    # instead of one existence probe per priority (avoids N+1 queries).
    used_ids = {
        row[0] for row in db.query(ChannelURL.priority_id).distinct().all()
    }

    deleted = 0
    skipped = 0

    for priority in db.query(Priority).all():
        if priority.id in used_ids:
            skipped += 1
        else:
            db.delete(priority)
            deleted += 1

    db.commit()
    return {"deleted": deleted, "skipped": skipped}
|
||||||
|
|
||||||
|
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_priority(
    priority_id: int,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete a priority, refusing (409) while channel URLs reference it."""
    from app.models.db import ChannelURL

    # 404 when the priority does not exist.
    if db.get(Priority, priority_id) is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
        )

    # 409 when any channel URL still references it.
    referenced = db.scalar(
        select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
    )
    if referenced:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Cannot delete priority that is in use by channel URLs",
        )

    db.execute(delete(Priority).where(Priority.id == priority_id))
    db.commit()
    return None
||||||
57
app/routers/scheduler.py
Normal file
57
app/routers/scheduler.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, require_roles
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/scheduler",
|
||||||
|
tags=["scheduler"],
|
||||||
|
responses={404: {"description": "Not found"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_scheduler(request: Request) -> StreamScheduler:
    """Get the scheduler instance from the app state.

    Raises HTTP 500 when the app was started without a scheduler or the
    wrapper has no underlying scheduler attribute yet.
    """
    # Guard against app.state.scheduler being absent entirely: the previous
    # code accessed request.app.state.scheduler before any check, so a
    # missing attribute raised AttributeError (an unhandled 500) instead of
    # the intended HTTPException.
    scheduler = getattr(request.app.state, "scheduler", None)
    if scheduler is None or not hasattr(scheduler, "scheduler"):
        raise HTTPException(status_code=500, detail="Scheduler not initialized")
    return scheduler
|
|
||||||
|
|
||||||
|
@router.get("/health")
@require_roles("admin")
def scheduler_health(
    scheduler: StreamScheduler = Depends(get_scheduler),
    user: CognitoUser = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Check scheduler health status (admin only)."""
    try:
        job = scheduler.scheduler.get_job("daily_stream_validation")
        # The job may not exist or may have no scheduled run yet.
        next_run = str(job.next_run_time) if job and job.next_run_time else None
        running = scheduler.scheduler.running
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to check scheduler health: {str(e)}"
        )

    return {
        "status": "running" if running else "stopped",
        "next_run": next_run,
    }
|
||||||
|
|
||||||
|
@router.post("/trigger")
@require_roles("admin")
def trigger_validation(
    scheduler: StreamScheduler = Depends(get_scheduler),
    user: CognitoUser = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Manually trigger stream validation (admin only)."""
    scheduler.trigger_manual_validation()
    payload = {"message": "Stream validation triggered"}
    return JSONResponse(status_code=202, content=payload)
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
14
app/utils/auth.py
Normal file
14
app/utils/auth.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
    """
    Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.

    The hash is HMAC-SHA256 keyed with the client secret over the
    concatenation username + client_id, base64-encoded as a string.
    """
    message = f"{username}{client_id}".encode("utf-8")
    key = client_secret.encode("utf-8")
    digest = hmac.new(key, message, hashlib.sha256).digest()
    return base64.b64encode(digest).decode()
||||||
@@ -1,32 +1,41 @@
|
|||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import requests
|
|
||||||
import logging
|
import logging
|
||||||
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||||
|
|
||||||
|
|
||||||
class StreamValidator:
|
class StreamValidator:
|
||||||
def __init__(self, timeout=10, user_agent=None):
|
def __init__(self, timeout=10, user_agent=None):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.session.headers.update({
|
self.session.headers.update(
|
||||||
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
|
{
|
||||||
})
|
"User-Agent": user_agent
|
||||||
|
or (
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/120.0.0.0 Safari/537.36"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def validate_stream(self, url):
|
def validate_stream(self, url):
|
||||||
"""Validate a media stream URL with multiple fallback checks"""
|
"""Validate a media stream URL with multiple fallback checks"""
|
||||||
try:
|
try:
|
||||||
headers = {'Range': 'bytes=0-1024'}
|
headers = {"Range": "bytes=0-1024"}
|
||||||
with self.session.get(
|
with self.session.get(
|
||||||
url,
|
url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
stream=True,
|
stream=True,
|
||||||
allow_redirects=True
|
allow_redirects=True,
|
||||||
) as response:
|
) as response:
|
||||||
if response.status_code not in [200, 206]:
|
if response.status_code not in [200, 206]:
|
||||||
return False, f"Invalid status code: {response.status_code}"
|
return False, f"Invalid status code: {response.status_code}"
|
||||||
|
|
||||||
content_type = response.headers.get('Content-Type', '')
|
content_type = response.headers.get("Content-Type", "")
|
||||||
if not self._is_valid_content_type(content_type):
|
if not self._is_valid_content_type(content_type):
|
||||||
return False, f"Invalid content type: {content_type}"
|
return False, f"Invalid content type: {content_type}"
|
||||||
|
|
||||||
@@ -49,47 +58,50 @@ class StreamValidator:
|
|||||||
|
|
||||||
def _is_valid_content_type(self, content_type):
|
def _is_valid_content_type(self, content_type):
|
||||||
valid_types = [
|
valid_types = [
|
||||||
'video/mp2t', 'application/vnd.apple.mpegurl',
|
"video/mp2t",
|
||||||
'application/dash+xml', 'video/mp4',
|
"application/vnd.apple.mpegurl",
|
||||||
'video/webm', 'application/octet-stream',
|
"application/dash+xml",
|
||||||
'application/x-mpegURL'
|
"video/mp4",
|
||||||
|
"video/webm",
|
||||||
|
"application/octet-stream",
|
||||||
|
"application/x-mpegURL",
|
||||||
]
|
]
|
||||||
|
if content_type is None:
|
||||||
|
return False
|
||||||
return any(ct in content_type for ct in valid_types)
|
return any(ct in content_type for ct in valid_types)
|
||||||
|
|
||||||
def parse_playlist(self, file_path):
|
def parse_playlist(self, file_path):
|
||||||
"""Extract stream URLs from M3U playlist file"""
|
"""Extract stream URLs from M3U playlist file"""
|
||||||
urls = []
|
urls = []
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line and not line.startswith('#'):
|
if line and not line.startswith("#"):
|
||||||
urls.append(line)
|
urls.append(line)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error reading playlist file: {str(e)}")
|
logging.error(f"Error reading playlist file: {str(e)}")
|
||||||
raise
|
raise
|
||||||
return urls
|
return urls
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Validate streaming URLs from command line arguments or playlist files',
|
description=(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
"Validate streaming URLs from command line arguments or playlist files"
|
||||||
|
),
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'sources',
|
"sources", nargs="+", help="List of URLs or file paths containing stream URLs"
|
||||||
nargs='+',
|
|
||||||
help='List of URLs or file paths containing stream URLs'
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--timeout',
|
"--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
|
||||||
type=int,
|
|
||||||
default=20,
|
|
||||||
help='Timeout in seconds for stream checks'
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output',
|
"--output",
|
||||||
default='deadstreams.txt',
|
default="deadstreams.txt",
|
||||||
help='Output file name for inactive streams'
|
help="Output file name for inactive streams",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -97,8 +109,8 @@ def main():
|
|||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()]
|
handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
|
||||||
)
|
)
|
||||||
|
|
||||||
validator = StreamValidator(timeout=args.timeout)
|
validator = StreamValidator(timeout=args.timeout)
|
||||||
@@ -127,9 +139,10 @@ def main():
|
|||||||
|
|
||||||
# Save dead streams to file
|
# Save dead streams to file
|
||||||
if dead_streams:
|
if dead_streams:
|
||||||
with open(args.output, 'w') as f:
|
with open(args.output, "w") as f:
|
||||||
f.write('\n'.join(dead_streams))
|
f.write("\n".join(dead_streams))
|
||||||
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
|
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -1,12 +1,21 @@
|
|||||||
|
# Utility functions and constants
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load environment variables from a .env file if it exists
|
# Load environment variables from a .env file if it exists
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# AWS related constants
|
||||||
|
AWS_REGION = os.environ.get("AWS_REGION", "us-east-2")
|
||||||
|
COGNITO_USER_POOL_ID = os.getenv("COGNITO_USER_POOL_ID")
|
||||||
|
COGNITO_CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
|
||||||
|
COGNITO_CLIENT_SECRET = os.environ.get("COGNITO_CLIENT_SECRET", None)
|
||||||
|
USER_ROLE_ATTRIBUTE = "zoneinfo"
|
||||||
|
|
||||||
IPTV_SERVER_URL = os.getenv("IPTV_SERVER_URL", "https://iptv.fiorinis.com")
|
IPTV_SERVER_URL = os.getenv("IPTV_SERVER_URL", "https://iptv.fiorinis.com")
|
||||||
|
|
||||||
# Super iptv-server admin credentials for basic auth
|
# iptv-server super admin credentials for basic auth
|
||||||
# Reads from environment variables IPTV_SERVER_ADMIN_USER and IPTV_SERVER_ADMIN_PASSWORD
|
# Reads from environment variables IPTV_SERVER_ADMIN_USER and IPTV_SERVER_ADMIN_PASSWORD
|
||||||
IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
|
IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
|
||||||
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
|
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
|
||||||
61
app/utils/database.py
Normal file
61
app/utils/database.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from requests import Session
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
from .constants import AWS_REGION
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_credentials():
|
||||||
|
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
|
||||||
|
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||||
|
return (
|
||||||
|
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
|
||||||
|
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
||||||
|
try:
|
||||||
|
host = ssm.get_parameter(Name="/iptv-manager/DB_HOST", WithDecryption=True)[
|
||||||
|
"Parameter"
|
||||||
|
]["Value"]
|
||||||
|
user = ssm.get_parameter(Name="/iptv-manager/DB_USER", WithDecryption=True)[
|
||||||
|
"Parameter"
|
||||||
|
]["Value"]
|
||||||
|
password = ssm.get_parameter(
|
||||||
|
Name="/iptv-manager/DB_PASSWORD", WithDecryption=True
|
||||||
|
)["Parameter"]["Value"]
|
||||||
|
dbname = ssm.get_parameter(Name="/iptv-manager/DB_NAME", WithDecryption=True)[
|
||||||
|
"Parameter"
|
||||||
|
]["Value"]
|
||||||
|
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize engine and session maker
|
||||||
|
engine = create_engine(get_db_credentials())
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""Initialize database by creating all tables"""
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
"""Dependency for getting database session"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_session() -> Session:
|
||||||
|
"""Get a direct database session (non-generator version)"""
|
||||||
|
return SessionLocal()
|
||||||
23
deploy.sh
23
deploy.sh
@@ -1,23 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Deploy infrastructure
|
|
||||||
cdk deploy
|
|
||||||
|
|
||||||
# Update application on running instances
|
|
||||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
|
||||||
--filters "Name=tag:Name,Values=IptvUpdater/IptvUpdaterInstance" \
|
|
||||||
"Name=instance-state-name,Values=running" \
|
|
||||||
--query "Reservations[].Instances[].InstanceId" \
|
|
||||||
--output text)
|
|
||||||
|
|
||||||
for INSTANCE_ID in $INSTANCE_IDS; do
|
|
||||||
echo "Updating application on instance: $INSTANCE_ID"
|
|
||||||
aws ssm send-command \
|
|
||||||
--instance-ids "$INSTANCE_ID" \
|
|
||||||
--document-name "AWS-RunShellScript" \
|
|
||||||
--parameters '{"commands":["cd /home/ec2-user/iptv-updater-aws && git pull && pip3 install -r requirements.txt && sudo systemctl restart iptv-updater"]}' \
|
|
||||||
--no-cli-pager \
|
|
||||||
--no-paginate
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "Deployment and instance update complete"
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Destroy infrastructure
|
|
||||||
cdk destroy
|
|
||||||
28
docker/Dockerfile
Normal file
28
docker/Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Use official Python image
|
||||||
|
FROM python:3.9-slim
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE 1
|
||||||
|
ENV PYTHONUNBUFFERED 1
|
||||||
|
|
||||||
|
# Set work directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
gcc \
|
||||||
|
python3-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Expose the port the app runs on
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Command to run the application
|
||||||
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
17
docker/docker-compose-db.yml
Normal file
17
docker/docker-compose-db.yml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:13
|
||||||
|
container_name: postgres
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: iptv_manager
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
32
docker/docker-compose-local.yml
Normal file
32
docker/docker-compose-local.yml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:13
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: iptv_manager
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
app:
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: docker/Dockerfile
|
||||||
|
environment:
|
||||||
|
DB_USER: postgres
|
||||||
|
DB_PASSWORD: postgres
|
||||||
|
DB_HOST: postgres
|
||||||
|
DB_NAME: iptv_manager
|
||||||
|
MOCK_AUTH: "true"
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
command: uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
@@ -1,67 +1,84 @@
|
|||||||
import os
|
import os
|
||||||
from aws_cdk import (
|
|
||||||
Stack,
|
from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
|
||||||
aws_ec2 as ec2,
|
from aws_cdk import aws_cognito as cognito
|
||||||
aws_iam as iam,
|
from aws_cdk import aws_ec2 as ec2
|
||||||
aws_cognito as cognito,
|
from aws_cdk import aws_iam as iam
|
||||||
CfnOutput
|
from aws_cdk import aws_rds as rds
|
||||||
)
|
from aws_cdk import aws_ssm as ssm
|
||||||
from constructs import Construct
|
from constructs import Construct
|
||||||
|
|
||||||
class IptvUpdaterStack(Stack):
|
|
||||||
def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
|
class IptvManagerStack(Stack):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scope: Construct,
|
||||||
|
construct_id: str,
|
||||||
|
freedns_user: str,
|
||||||
|
freedns_password: str,
|
||||||
|
domain_name: str,
|
||||||
|
ssh_public_key: str,
|
||||||
|
repo_url: str,
|
||||||
|
letsencrypt_email: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
super().__init__(scope, construct_id, **kwargs)
|
super().__init__(scope, construct_id, **kwargs)
|
||||||
|
|
||||||
# Create VPC
|
# Create VPC
|
||||||
vpc = ec2.Vpc(self, "IptvUpdaterVPC",
|
vpc = ec2.Vpc(
|
||||||
max_azs=1, # Use only one AZ for free tier
|
self,
|
||||||
|
"IptvManagerVPC",
|
||||||
|
max_azs=2, # Need at least 2 AZs for RDS subnet group
|
||||||
nat_gateways=0, # No NAT Gateway to stay in free tier
|
nat_gateways=0, # No NAT Gateway to stay in free tier
|
||||||
subnet_configuration=[
|
subnet_configuration=[
|
||||||
ec2.SubnetConfiguration(
|
ec2.SubnetConfiguration(
|
||||||
name="public",
|
name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
|
||||||
subnet_type=ec2.SubnetType.PUBLIC,
|
),
|
||||||
cidr_mask=24
|
ec2.SubnetConfiguration(
|
||||||
)
|
name="private",
|
||||||
]
|
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
|
||||||
|
cidr_mask=24,
|
||||||
|
),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Security Group
|
# Security Group
|
||||||
security_group = ec2.SecurityGroup(
|
security_group = ec2.SecurityGroup(
|
||||||
self, "IptvUpdaterSG",
|
self, "IptvManagerSG", vpc=vpc, allow_all_outbound=True
|
||||||
vpc=vpc,
|
|
||||||
allow_all_outbound=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
security_group.add_ingress_rule(
|
security_group.add_ingress_rule(
|
||||||
ec2.Peer.any_ipv4(),
|
ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
|
||||||
ec2.Port.tcp(443),
|
|
||||||
"Allow HTTPS traffic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
security_group.add_ingress_rule(
|
security_group.add_ingress_rule(
|
||||||
ec2.Peer.any_ipv4(),
|
ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
|
||||||
ec2.Port.tcp(80),
|
|
||||||
"Allow HTTP traffic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
security_group.add_ingress_rule(
|
security_group.add_ingress_rule(
|
||||||
ec2.Peer.any_ipv4(),
|
ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
|
||||||
ec2.Port.tcp(22),
|
|
||||||
"Allow SSH traffic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Key pair for IPTV Updater instance
|
# Allow PostgreSQL port for tunneling restricted to developer IP
|
||||||
|
security_group.add_ingress_rule(
|
||||||
|
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
|
||||||
|
ec2.Port.tcp(5432),
|
||||||
|
"Allow PostgreSQL traffic for tunneling",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Key pair for IPTV Manager instance
|
||||||
key_pair = ec2.KeyPair(
|
key_pair = ec2.KeyPair(
|
||||||
self,
|
self,
|
||||||
"IptvUpdaterKeyPair",
|
"IptvManagerKeyPair",
|
||||||
key_pair_name="iptv-updater-key",
|
key_pair_name="iptv-manager-key",
|
||||||
public_key_material="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDZcD20mfOQ/al6VMXWWUcUILYIHyIy5gY9RmC6koicaMYTJ078yMUnYw9DONEIfvxZwceTrtdUWv7vOXjknh1bberF78Pi3vwwFKazdgjbyZnkM9r2N40+PqfOFYf03B53VIA8jdae6gaD3VvYdlwV5dqqW3N+JE/+7sXHLeENnqxeS5jNLoRKDIxHgV4MBnewWNdp77ZEHZFrG3Fj/0lnqq2EfDh+5E/Vhkc/mzkuXu2l5Z1szw2rcjJz4IYUjgnClorBVlmWwwiICrHbyVCYivaaVvpVmZoBy7WW1P8KFAiot4G6C0Klyn4sy2AwCQ7u65TyDtbCR89VuuwW0zLgERAfWMdhn/5HdIudcTScXEsUsADTqxb2x0IXVbs3uhxHCUcc7BVg0S7dpAbmpK+80NlDCH38LFcyrYASRD6/Le2skVePIFt0Tw1OnwPH/QqPG0Y9vLZJl8779pg+kpk+o1MaRnczPq9Sk9zr3dR4Sv82CObnjTeY5LCTvs06JBCtcey/vLk7RQYs43uXZgg746aWIcM0lgHPivh2JpUOAd/Kj1v4aEXTRgHyPPLQ4KKwnALcd4s6+ytXwf4gxumRmPf/7P6HYJQvY8pfVda8jEvu08rvSVMXb09Jq4tzS/JTX2flfEdF0qIyTNHnK/lum27fgP0yq6Pq24IWYXha9w== stefano@MSI"
|
public_key_material=ssh_public_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create IAM role for EC2
|
# Create IAM role for EC2
|
||||||
role = iam.Role(
|
role = iam.Role(
|
||||||
self, "IptvUpdaterRole",
|
self,
|
||||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com")
|
"IptvManagerRole",
|
||||||
|
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add SSM managed policy
|
# Add SSM managed policy
|
||||||
@@ -71,70 +88,65 @@ class IptvUpdaterStack(Stack):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add Cognito permissions to instance role
|
# Add EC2 describe permissions
|
||||||
role.add_managed_policy(
|
role.add_to_policy(
|
||||||
iam.ManagedPolicy.from_aws_managed_policy_name(
|
iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
|
||||||
"AmazonCognitoReadOnly"
|
)
|
||||||
|
|
||||||
|
# Add SSM SendCommand permissions
|
||||||
|
role.add_to_policy(
|
||||||
|
iam.PolicyStatement(
|
||||||
|
actions=["ssm:SendCommand"],
|
||||||
|
resources=[
|
||||||
|
# Allow on all EC2 instances
|
||||||
|
f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
|
||||||
|
# Required for the RunShellScript document
|
||||||
|
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# EC2 Instance
|
# Add Cognito permissions to instance role
|
||||||
instance = ec2.Instance(
|
role.add_managed_policy(
|
||||||
self, "IptvUpdaterInstance",
|
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
|
||||||
vpc=vpc,
|
|
||||||
instance_type=ec2.InstanceType.of(
|
|
||||||
ec2.InstanceClass.T2,
|
|
||||||
ec2.InstanceSize.MICRO
|
|
||||||
),
|
|
||||||
machine_image=ec2.AmazonLinuxImage(
|
|
||||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2
|
|
||||||
),
|
|
||||||
security_group=security_group,
|
|
||||||
key_pair=key_pair,
|
|
||||||
role=role
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create Elastic IP
|
|
||||||
eip = ec2.CfnEIP(
|
|
||||||
self, "IptvUpdaterEIP",
|
|
||||||
domain="vpc",
|
|
||||||
instance_id=instance.instance_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add Cognito User Pool
|
# Add Cognito User Pool
|
||||||
user_pool = cognito.UserPool(
|
user_pool = cognito.UserPool(
|
||||||
self, "IptvUpdaterUserPool",
|
self,
|
||||||
user_pool_name="iptv-updater-users",
|
"IptvManagerUserPool",
|
||||||
|
user_pool_name="iptv-manager-users",
|
||||||
self_sign_up_enabled=False, # Only admins can create users
|
self_sign_up_enabled=False, # Only admins can create users
|
||||||
password_policy=cognito.PasswordPolicy(
|
password_policy=cognito.PasswordPolicy(
|
||||||
min_length=8,
|
min_length=8,
|
||||||
require_lowercase=True,
|
require_lowercase=True,
|
||||||
require_digits=True,
|
require_digits=True,
|
||||||
require_symbols=True,
|
require_symbols=True,
|
||||||
require_uppercase=True
|
require_uppercase=True,
|
||||||
),
|
),
|
||||||
account_recovery=cognito.AccountRecovery.EMAIL_ONLY
|
account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
|
||||||
|
removal_policy=RemovalPolicy.DESTROY,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add App Client with the correct callback URL
|
# Add App Client with the correct callback URL
|
||||||
client = user_pool.add_client("IptvUpdaterClient",
|
client = user_pool.add_client(
|
||||||
|
"IptvManagerClient",
|
||||||
|
access_token_validity=Duration.minutes(60),
|
||||||
|
id_token_validity=Duration.minutes(60),
|
||||||
|
refresh_token_validity=Duration.days(1),
|
||||||
|
auth_flows=cognito.AuthFlow(user_password=True),
|
||||||
o_auth=cognito.OAuthSettings(
|
o_auth=cognito.OAuthSettings(
|
||||||
flows=cognito.OAuthFlows(
|
flows=cognito.OAuthFlows(implicit_code_grant=True)
|
||||||
authorization_code_grant=True
|
),
|
||||||
),
|
prevent_user_existence_errors=True,
|
||||||
scopes=[cognito.OAuthScope.OPENID],
|
generate_secret=True,
|
||||||
callback_urls=[
|
enable_token_revocation=True,
|
||||||
"http://localhost:8000/auth/callback", # For local testing
|
|
||||||
"https://*.amazonaws.com/auth/callback" # Will match EC2 public DNS
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add domain for hosted UI
|
# Add domain for hosted UI
|
||||||
domain = user_pool.add_domain("IptvUpdaterDomain",
|
domain = user_pool.add_domain(
|
||||||
cognito_domain=cognito.CognitoDomainOptions(
|
"IptvManagerDomain",
|
||||||
domain_prefix="iptv-updater"
|
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-manager"),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Read the userdata script with proper path resolution
|
# Read the userdata script with proper path resolution
|
||||||
@@ -144,20 +156,152 @@ class IptvUpdaterStack(Stack):
|
|||||||
|
|
||||||
# Creates a userdata object for Linux hosts
|
# Creates a userdata object for Linux hosts
|
||||||
userdata = ec2.UserData.for_linux()
|
userdata = ec2.UserData.for_linux()
|
||||||
|
|
||||||
|
# Add environment variables for acme.sh from parameters
|
||||||
|
userdata.add_commands(
|
||||||
|
f'export FREEDNS_User="{freedns_user}"',
|
||||||
|
f'export FREEDNS_Password="{freedns_password}"',
|
||||||
|
f'export DOMAIN_NAME="{domain_name}"',
|
||||||
|
f'export REPO_URL="{repo_url}"',
|
||||||
|
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
|
||||||
|
)
|
||||||
|
|
||||||
# Adds one or more commands to the userdata object.
|
# Adds one or more commands to the userdata object.
|
||||||
userdata.add_commands(
|
userdata.add_commands(
|
||||||
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment',
|
(
|
||||||
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment'
|
f'echo "COGNITO_USER_POOL_ID='
|
||||||
|
f'{user_pool.user_pool_id}" >> /etc/environment'
|
||||||
|
),
|
||||||
|
(
|
||||||
|
f'echo "COGNITO_CLIENT_ID='
|
||||||
|
f'{client.user_pool_client_id}" >> /etc/environment'
|
||||||
|
),
|
||||||
|
(
|
||||||
|
f'echo "COGNITO_CLIENT_SECRET='
|
||||||
|
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
|
||||||
|
),
|
||||||
|
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
|
||||||
)
|
)
|
||||||
userdata.add_commands(str(userdata_file, 'utf-8'))
|
userdata.add_commands(str(userdata_file, "utf-8"))
|
||||||
|
|
||||||
|
# Create RDS Security Group
|
||||||
|
rds_sg = ec2.SecurityGroup(
|
||||||
|
self,
|
||||||
|
"RdsSecurityGroup",
|
||||||
|
vpc=vpc,
|
||||||
|
description="Security group for RDS PostgreSQL",
|
||||||
|
)
|
||||||
|
rds_sg.add_ingress_rule(
|
||||||
|
security_group,
|
||||||
|
ec2.Port.tcp(5432),
|
||||||
|
"Allow PostgreSQL access from EC2 instance",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
|
||||||
|
db = rds.DatabaseInstance(
|
||||||
|
self,
|
||||||
|
"IptvManagerDB",
|
||||||
|
engine=rds.DatabaseInstanceEngine.postgres(
|
||||||
|
version=rds.PostgresEngineVersion.VER_13
|
||||||
|
),
|
||||||
|
instance_type=ec2.InstanceType.of(
|
||||||
|
ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
|
||||||
|
),
|
||||||
|
vpc=vpc,
|
||||||
|
vpc_subnets=ec2.SubnetSelection(
|
||||||
|
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED
|
||||||
|
),
|
||||||
|
security_groups=[rds_sg],
|
||||||
|
allocated_storage=10,
|
||||||
|
max_allocated_storage=10,
|
||||||
|
database_name="iptv_manager",
|
||||||
|
removal_policy=RemovalPolicy.DESTROY,
|
||||||
|
deletion_protection=False,
|
||||||
|
publicly_accessible=False, # Avoid public IPv4 charges
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add RDS permissions to instance role
|
||||||
|
role.add_managed_policy(
|
||||||
|
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store DB connection info in SSM Parameter Store
|
||||||
|
db_host_param = ssm.StringParameter(
|
||||||
|
self,
|
||||||
|
"DBHostParam",
|
||||||
|
parameter_name="/iptv-manager/DB_HOST",
|
||||||
|
string_value=db.db_instance_endpoint_address,
|
||||||
|
)
|
||||||
|
db_name_param = ssm.StringParameter(
|
||||||
|
self,
|
||||||
|
"DBNameParam",
|
||||||
|
parameter_name="/iptv-manager/DB_NAME",
|
||||||
|
string_value="iptv_manager",
|
||||||
|
)
|
||||||
|
db_user_param = ssm.StringParameter(
|
||||||
|
self,
|
||||||
|
"DBUserParam",
|
||||||
|
parameter_name="/iptv-manager/DB_USER",
|
||||||
|
string_value=db.secret.secret_value_from_json("username").to_string(),
|
||||||
|
)
|
||||||
|
db_pass_param = ssm.StringParameter(
|
||||||
|
self,
|
||||||
|
"DBPassParam",
|
||||||
|
parameter_name="/iptv-manager/DB_PASSWORD",
|
||||||
|
string_value=db.secret.secret_value_from_json("password").to_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add SSM read permissions to instance role
|
||||||
|
role.add_managed_policy(
|
||||||
|
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
|
||||||
|
)
|
||||||
|
|
||||||
|
# EC2 Instance (created after all dependencies are ready)
|
||||||
|
instance = ec2.Instance(
|
||||||
|
self,
|
||||||
|
"IptvManagerInstance",
|
||||||
|
vpc=vpc,
|
||||||
|
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
||||||
|
instance_type=ec2.InstanceType.of(
|
||||||
|
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
||||||
|
),
|
||||||
|
machine_image=ec2.AmazonLinuxImage(
|
||||||
|
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
||||||
|
),
|
||||||
|
security_group=security_group,
|
||||||
|
key_pair=key_pair,
|
||||||
|
role=role,
|
||||||
|
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
||||||
|
associate_public_ip_address=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure instance depends on SSM parameters being created
|
||||||
|
instance.node.add_dependency(db)
|
||||||
|
instance.node.add_dependency(db_host_param)
|
||||||
|
instance.node.add_dependency(db_name_param)
|
||||||
|
instance.node.add_dependency(db_user_param)
|
||||||
|
instance.node.add_dependency(db_pass_param)
|
||||||
|
|
||||||
|
# Option: 2: Create Elastic IP (not free tier compatible)
|
||||||
|
# eip = ec2.CfnEIP(
|
||||||
|
# self, "IptvManagerEIP",
|
||||||
|
# domain="vpc",
|
||||||
|
# instance_id=instance.instance_id
|
||||||
|
# )
|
||||||
|
|
||||||
# Update instance with userdata
|
# Update instance with userdata
|
||||||
instance.add_user_data(userdata.render())
|
instance.add_user_data(userdata.render())
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
|
CfnOutput(self, "DBEndpoint", value=db.db_instance_endpoint_address)
|
||||||
|
# Option: 1: Use EC2 instance public IP (free tier compatible)
|
||||||
|
CfnOutput(self, "InstancePublicIP", value=instance.instance_public_ip)
|
||||||
|
# Option: 2: Use EIP (not free tier compatible)
|
||||||
|
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
|
||||||
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
|
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
|
||||||
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
|
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
|
||||||
CfnOutput(self, "CognitoDomainUrl",
|
CfnOutput(
|
||||||
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com"
|
self,
|
||||||
|
"CognitoDomainUrl",
|
||||||
|
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
|
||||||
)
|
)
|
||||||
@@ -1,29 +1,79 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
yum update -y
|
# Update system and install required packages
|
||||||
yum install -y python3-pip git
|
dnf update -y
|
||||||
amazon-linux-extras install nginx1
|
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx postgresql15.x86_64 awscli
|
||||||
|
|
||||||
pip3 install --upgrade pip
|
# Start and enable crond service
|
||||||
pip3 install certbot certbot-nginx
|
systemctl start crond
|
||||||
|
systemctl enable crond
|
||||||
|
|
||||||
cd /home/ec2-user
|
cd /home/ec2-user
|
||||||
|
|
||||||
git clone https://git.fiorinis.com/Home/iptv-updater-aws.git
|
git clone ${REPO_URL}
|
||||||
cd iptv-updater-aws
|
cd iptv-manager-service
|
||||||
|
|
||||||
pip3 install -r requirements.txt
|
# Install Python packages with --ignore-installed to prevent conflicts with RPM packages
|
||||||
|
pip3 install --ignore-installed -r requirements.txt
|
||||||
|
|
||||||
|
# Retrieve DB credentials from SSM Parameter Store with retries
|
||||||
|
echo "Attempting to retrieve DB credentials from SSM..."
|
||||||
|
for i in {1..30}; do
|
||||||
|
DB_HOST=$(aws ssm get-parameter --name "/iptv-manager/DB_HOST" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
DB_NAME=$(aws ssm get-parameter --name "/iptv-manager/DB_NAME" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
DB_USER=$(aws ssm get-parameter --name "/iptv-manager/DB_USER" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
DB_PASSWORD=$(aws ssm get-parameter --name "/iptv-manager/DB_PASSWORD" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
|
||||||
|
if [ -n "$DB_HOST" ] && [ -n "$DB_NAME" ] && [ -n "$DB_USER" ] && [ -n "$DB_PASSWORD" ]; then
|
||||||
|
echo "Successfully retrieved all DB credentials"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Waiting for SSM parameters to be available... (attempt $i/30)"
|
||||||
|
sleep 5
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "$DB_HOST" ] || [ -z "$DB_NAME" ] || [ -z "$DB_USER" ] || [ -z "$DB_PASSWORD" ]; then
|
||||||
|
echo "ERROR: Failed to retrieve all required DB credentials after 30 attempts"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DB_HOST
|
||||||
|
export DB_NAME
|
||||||
|
export DB_USER
|
||||||
|
export DB_PASSWORD
|
||||||
|
|
||||||
|
# Set PGPASSWORD for psql to use
|
||||||
|
export PGPASSWORD=$DB_PASSWORD
|
||||||
|
|
||||||
|
# Wait for PostgreSQL to be ready
|
||||||
|
echo "Waiting for PostgreSQL to start..."
|
||||||
|
until psql -h $DB_HOST -U $DB_USER -d postgres -c '\q'; do
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
echo "PostgreSQL is ready."
|
||||||
|
|
||||||
|
# Create database if it does not exist
|
||||||
|
DB_EXISTS=$(psql -h $DB_HOST -U $DB_USER -d postgres -tc "SELECT 1 FROM pg_database WHERE datname = '$DB_NAME';")
|
||||||
|
if [ -z "$DB_EXISTS" ]; then
|
||||||
|
echo "Creating database $DB_NAME..."
|
||||||
|
psql -h $DB_HOST -U $DB_USER -d postgres -c "CREATE DATABASE $DB_NAME;"
|
||||||
|
echo "Database $DB_NAME created."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
alembic upgrade head
|
||||||
|
|
||||||
# Create systemd service file
|
# Create systemd service file
|
||||||
cat << 'EOF' > /etc/systemd/system/iptv-updater.service
|
cat << 'EOF' > /etc/systemd/system/iptv-manager.service
|
||||||
[Unit]
|
[Unit]
|
||||||
Description=IPTV Updater Service
|
Description=IPTV Manager Service
|
||||||
After=network.target
|
After=network.target
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
Type=simple
|
Type=simple
|
||||||
User=ec2-user
|
User=ec2-user
|
||||||
WorkingDirectory=/home/ec2-user/iptv-updater-aws
|
WorkingDirectory=/home/ec2-user/iptv-manager-service
|
||||||
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
|
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
|
||||||
EnvironmentFile=/etc/environment
|
EnvironmentFile=/etc/environment
|
||||||
Restart=always
|
Restart=always
|
||||||
@@ -32,17 +82,42 @@ Restart=always
|
|||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
# Ensure root has a crontab before installing acme.sh
|
||||||
|
crontab -u root -l >/dev/null 2>&1 || (echo "" | crontab -u root -)
|
||||||
|
|
||||||
|
# Install and configure acme.sh
|
||||||
|
curl https://get.acme.sh | sh -s email="${LETSENCRYPT_EMAIL}"
|
||||||
|
|
||||||
|
# Configure acme.sh to use DNS API for FreeDNS
|
||||||
|
. "/.acme.sh/acme.sh.env"
|
||||||
|
"/.acme.sh"/acme.sh --issue --dns dns_freedns -d ${DOMAIN_NAME} -d *.${DOMAIN_NAME}
|
||||||
|
sudo mkdir -p /etc/nginx/ssl
|
||||||
|
"/.acme.sh"/acme.sh --install-cert -d ${DOMAIN_NAME} -d *.${DOMAIN_NAME} \
|
||||||
|
--key-file /etc/nginx/ssl/${DOMAIN_NAME}.pem \
|
||||||
|
--fullchain-file /etc/nginx/ssl/cert.pem \
|
||||||
|
--reloadcmd "service nginx force-reload"
|
||||||
|
|
||||||
# Create nginx config
|
# Create nginx config
|
||||||
cat << 'EOF' > /etc/nginx/conf.d/iptvUpdater.conf
|
cat << EOF > /etc/nginx/conf.d/iptvManager.conf
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
server_name $HOSTNAME;
|
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
||||||
|
return 301 https://\$host\$request_uri;
|
||||||
|
}
|
||||||
|
|
||||||
|
server {
|
||||||
|
listen 443 ssl;
|
||||||
|
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
||||||
|
|
||||||
|
ssl_certificate /etc/nginx/ssl/cert.pem;
|
||||||
|
ssl_certificate_key /etc/nginx/ssl/${DOMAIN_NAME}.pem;
|
||||||
|
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://127.0.0.1:8000;
|
proxy_pass http://127.0.0.1:8000;
|
||||||
proxy_set_header Host $host;
|
proxy_set_header Host \$host;
|
||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP \$remote_addr;
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto \$scheme;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EOF
|
EOF
|
||||||
@@ -50,5 +125,5 @@ EOF
|
|||||||
# Start nginx service
|
# Start nginx service
|
||||||
systemctl enable nginx
|
systemctl enable nginx
|
||||||
systemctl start nginx
|
systemctl start nginx
|
||||||
systemctl enable iptv-updater
|
systemctl enable iptv-manager
|
||||||
systemctl start iptv-updater
|
systemctl start iptv-manager
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Install dependencies and deploy infrastructure
|
|
||||||
npm install -g aws-cdk
|
|
||||||
python3 -m pip install -r requirements.txt
|
|
||||||
31
pyproject.toml
Normal file
31
pyproject.toml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
[tool.ruff]
|
||||||
|
line-length = 88
|
||||||
|
exclude = [
|
||||||
|
"alembic/versions/*.py", # Auto-generated Alembic migration files
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
]
|
||||||
|
ignore = []
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/**/*.py" = [
|
||||||
|
"F811", # redefinition of unused name
|
||||||
|
"F401", # unused import
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["app"]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
docstring-code-format = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = "--cov=app --cov-report=term-missing --cov-fail-under=70"
|
||||||
|
testpaths = ["tests"]
|
||||||
28
pytest.ini
Normal file
28
pytest.ini
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_functions = test_*
|
||||||
|
asyncio_mode = auto
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning:botocore.auth
|
||||||
|
ignore:The 'app' shortcut is now deprecated:DeprecationWarning:httpx._client
|
||||||
|
|
||||||
|
# Coverage configuration
|
||||||
|
addopts =
|
||||||
|
--cov=app
|
||||||
|
--cov-report=term-missing
|
||||||
|
|
||||||
|
# Test environment variables
|
||||||
|
env =
|
||||||
|
MOCK_AUTH=true
|
||||||
|
DB_USER=test_user
|
||||||
|
DB_PASSWORD=test_password
|
||||||
|
DB_HOST=localhost
|
||||||
|
DB_NAME=iptv_manager_test
|
||||||
|
|
||||||
|
# Test markers
|
||||||
|
markers =
|
||||||
|
slow: mark tests as slow running
|
||||||
|
integration: integration tests
|
||||||
|
unit: unit tests
|
||||||
|
db: tests requiring database
|
||||||
@@ -8,3 +8,15 @@ requests==2.31.0
|
|||||||
passlib[bcrypt]==1.7.4
|
passlib[bcrypt]==1.7.4
|
||||||
boto3==1.28.0
|
boto3==1.28.0
|
||||||
starlette>=0.27.0
|
starlette>=0.27.0
|
||||||
|
pyjwt==2.7.0
|
||||||
|
sqlalchemy==2.0.23
|
||||||
|
psycopg2-binary==2.9.9
|
||||||
|
alembic==1.16.1
|
||||||
|
pytest==8.1.1
|
||||||
|
pytest-asyncio==0.23.6
|
||||||
|
pytest-mock==3.12.0
|
||||||
|
pytest-cov==4.1.0
|
||||||
|
pytest-env==1.1.1
|
||||||
|
httpx==0.27.0
|
||||||
|
pre-commit
|
||||||
|
apscheduler==3.10.4
|
||||||
36
scripts/create_cognito_user.sh
Executable file
36
scripts/create_cognito_user.sh
Executable file
@@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ "$#" -lt 3 ]; then
|
||||||
|
echo "Usage: $0 USER_POOL_ID USERNAME PASSWORD [--admin]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
USER_POOL_ID=$1
|
||||||
|
USERNAME=$2
|
||||||
|
PASSWORD=$3
|
||||||
|
ADMIN_FLAG=${4:-""}
|
||||||
|
|
||||||
|
# Create user with temporary password
|
||||||
|
CREATE_CMD="aws cognito-idp admin-create-user --no-cli-pager \
|
||||||
|
--user-pool-id \"$USER_POOL_ID\" \
|
||||||
|
--username \"$USERNAME\" \
|
||||||
|
--temporary-password \"TempPass123!\" \
|
||||||
|
--output json > /dev/null 2>&1"
|
||||||
|
|
||||||
|
if [ "$ADMIN_FLAG" == "--admin" ]; then
|
||||||
|
CREATE_CMD+=" --user-attributes Name=zoneinfo,Value=admin"
|
||||||
|
fi
|
||||||
|
|
||||||
|
eval "$CREATE_CMD"
|
||||||
|
|
||||||
|
# Set permanent password
|
||||||
|
aws cognito-idp admin-set-user-password --no-cli-pager \
|
||||||
|
--user-pool-id "$USER_POOL_ID" \
|
||||||
|
--username "$USERNAME" \
|
||||||
|
--password "$PASSWORD" \
|
||||||
|
--permanent \
|
||||||
|
--output json > /dev/null 2>&1
|
||||||
|
|
||||||
|
echo "User $USERNAME created successfully"
|
||||||
18
scripts/delete_cognito_user.sh
Executable file
18
scripts/delete_cognito_user.sh
Executable file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ "$#" -ne 2 ]; then
|
||||||
|
echo "Usage: $0 USER_POOL_ID USERNAME"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
USER_POOL_ID=$1
|
||||||
|
USERNAME=$2
|
||||||
|
|
||||||
|
aws cognito-idp admin-delete-user --no-cli-pager \
|
||||||
|
--user-pool-id "$USER_POOL_ID" \
|
||||||
|
--username "$USERNAME" \
|
||||||
|
--output json > /dev/null 2>&1
|
||||||
|
|
||||||
|
echo "User $USERNAME deleted successfully"
|
||||||
43
scripts/deploy.sh
Executable file
43
scripts/deploy.sh
Executable file
@@ -0,0 +1,43 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Load environment variables from .env file if it exists
|
||||||
|
if [ -f ${PWD}/.env ]; then
|
||||||
|
# Use set -a to automatically export all variables
|
||||||
|
set -a
|
||||||
|
source ${PWD}/.env
|
||||||
|
set +a
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if required environment variables are set
|
||||||
|
if [ -z "$FREEDNS_User" ] ||
|
||||||
|
[ -z "$FREEDNS_Password" ] ||
|
||||||
|
[ -z "$DOMAIN_NAME" ] ||
|
||||||
|
[ -z "$SSH_PUBLIC_KEY" ] ||
|
||||||
|
[ -z "$REPO_URL" ] ||
|
||||||
|
[ -z "$LETSENCRYPT_EMAIL" ]; then
|
||||||
|
echo "Error: FREEDNS_User, FREEDNS_Password, DOMAIN_NAME, SSH_PUBLIC_KEY, REPO_URL, and LETSENCRYPT_EMAIL must be set as environment variables or in a .env file."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Deploy infrastructure
|
||||||
|
cdk deploy --app="python3 ${PWD}/app.py"
|
||||||
|
|
||||||
|
# Update application on running instances
|
||||||
|
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||||
|
--region us-east-2 \
|
||||||
|
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||||
|
"Name=instance-state-name,Values=running" \
|
||||||
|
--query "Reservations[].Instances[].InstanceId" \
|
||||||
|
--output text)
|
||||||
|
|
||||||
|
for INSTANCE_ID in $INSTANCE_IDS; do
|
||||||
|
echo "Updating application on instance: $INSTANCE_ID"
|
||||||
|
aws ssm send-command \
|
||||||
|
--instance-ids "$INSTANCE_ID" \
|
||||||
|
--document-name "AWS-RunShellScript" \
|
||||||
|
--parameters '{"commands":["cd /home/ec2-user/iptv-manager-service && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-manager"]}' \
|
||||||
|
--no-cli-pager \
|
||||||
|
--no-paginate
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Deployment and instance update complete"
|
||||||
23
scripts/destroy.sh
Executable file
23
scripts/destroy.sh
Executable file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Load environment variables from .env file if it exists
|
||||||
|
if [ -f ${PWD}/.env ]; then
|
||||||
|
# Use set -a to automatically export all variables
|
||||||
|
set -a
|
||||||
|
source ${PWD}/.env
|
||||||
|
set +a
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if required environment variables are set
|
||||||
|
if [ -z "$FREEDNS_User" ] ||
|
||||||
|
[ -z "$FREEDNS_Password" ] ||
|
||||||
|
[ -z "$DOMAIN_NAME" ] ||
|
||||||
|
[ -z "$SSH_PUBLIC_KEY" ] ||
|
||||||
|
[ -z "$REPO_URL" ] ||
|
||||||
|
[ -z "$LETSENCRYPT_EMAIL" ]; then
|
||||||
|
echo "Error: FREEDNS_User, FREEDNS_Password, DOMAIN_NAME, SSH_PUBLIC_KEY, REPO_URL, and LETSENCRYPT_EMAIL must be set as environment variables or in a .env file."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Destroy infrastructure
|
||||||
|
cdk destroy --app="python3 ${PWD}/app.py" --force
|
||||||
13
scripts/install.sh
Executable file
13
scripts/install.sh
Executable file
@@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Install dependencies and deploy infrastructure
|
||||||
|
npm install -g aws-cdk
|
||||||
|
python3 -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Install and configure pre-commit hooks
|
||||||
|
pre-commit install
|
||||||
|
pre-commit install-hooks
|
||||||
|
pre-commit autoupdate
|
||||||
|
|
||||||
|
# Verify pytest setup
|
||||||
|
python3 -m pytest
|
||||||
28
scripts/start_local_dev.sh
Executable file
28
scripts/start_local_dev.sh
Executable file
@@ -0,0 +1,28 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Start PostgreSQL
|
||||||
|
docker-compose -f docker/docker-compose-db.yml up -d
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
export MOCK_AUTH=true
|
||||||
|
export DB_HOST=localhost
|
||||||
|
export DB_USER=postgres
|
||||||
|
export DB_PASSWORD=postgres
|
||||||
|
export DB_NAME=iptv_manager
|
||||||
|
|
||||||
|
echo "Ensuring database $DB_NAME exists using conditional DDL..."
|
||||||
|
PGPASSWORD=$DB_PASSWORD docker exec -i postgres psql -U $DB_USER <<< "SELECT 'CREATE DATABASE $DB_NAME' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec"
|
||||||
|
echo "Database $DB_NAME check complete."
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
alembic upgrade head
|
||||||
|
|
||||||
|
# Start FastAPI
|
||||||
|
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
|
||||||
|
echo $! > iptv-manager.pid
|
||||||
|
|
||||||
|
echo "Services started:"
|
||||||
|
echo "- PostgreSQL running on localhost:5432"
|
||||||
|
echo "- FastAPI running on http://127.0.0.1:8000"
|
||||||
|
echo "- Mock auth enabled (use token: testuser)"
|
||||||
19
scripts/stop_local_dev.sh
Executable file
19
scripts/stop_local_dev.sh
Executable file
@@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Stop FastAPI
|
||||||
|
if [ -f iptv-manager.pid ]; then
|
||||||
|
kill $(cat iptv-manager.pid)
|
||||||
|
rm iptv-manager.pid
|
||||||
|
echo "Stopped FastAPI"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Clean up mock auth and database environment variables
|
||||||
|
unset MOCK_AUTH
|
||||||
|
unset DB_USER
|
||||||
|
unset DB_PASSWORD
|
||||||
|
unset DB_HOST
|
||||||
|
unset DB_NAME
|
||||||
|
|
||||||
|
# Stop PostgreSQL
|
||||||
|
docker-compose -f docker/docker-compose-db.yml down
|
||||||
|
echo "Stopped PostgreSQL"
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/auth/__init__.py
Normal file
0
tests/auth/__init__.py
Normal file
169
tests/auth/test_cognito.py
Normal file
169
tests/auth/test_cognito.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
# Test constants
|
||||||
|
TEST_CLIENT_ID = "test_client_id"
|
||||||
|
TEST_CLIENT_SECRET = "test_client_secret"
|
||||||
|
|
||||||
|
# Patch constants before importing the module
|
||||||
|
with (
|
||||||
|
patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
|
||||||
|
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
|
||||||
|
):
|
||||||
|
from app.auth.cognito import get_user_from_token, initiate_auth
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.utils.constants import USER_ROLE_ATTRIBUTE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_cognito_client():
|
||||||
|
with patch("app.auth.cognito.cognito_client") as mock_client:
|
||||||
|
# Setup mock client and exceptions
|
||||||
|
mock_client.exceptions = MagicMock()
|
||||||
|
mock_client.exceptions.NotAuthorizedException = type(
|
||||||
|
"NotAuthorizedException", (Exception,), {}
|
||||||
|
)
|
||||||
|
mock_client.exceptions.UserNotFoundException = type(
|
||||||
|
"UserNotFoundException", (Exception,), {}
|
||||||
|
)
|
||||||
|
yield mock_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_initiate_auth_success(mock_cognito_client):
|
||||||
|
# Mock successful authentication response
|
||||||
|
mock_cognito_client.initiate_auth.return_value = {
|
||||||
|
"AuthenticationResult": {
|
||||||
|
"AccessToken": "mock_access_token",
|
||||||
|
"IdToken": "mock_id_token",
|
||||||
|
"RefreshToken": "mock_refresh_token",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = initiate_auth("test_user", "test_pass")
|
||||||
|
assert result == {
|
||||||
|
"AccessToken": "mock_access_token",
|
||||||
|
"IdToken": "mock_id_token",
|
||||||
|
"RefreshToken": "mock_refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_initiate_auth_with_secret_hash(mock_cognito_client):
|
||||||
|
with patch(
|
||||||
|
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
|
||||||
|
) as mock_hash:
|
||||||
|
mock_cognito_client.initiate_auth.return_value = {
|
||||||
|
"AuthenticationResult": {"AccessToken": "token"}
|
||||||
|
}
|
||||||
|
|
||||||
|
initiate_auth("test_user", "test_pass")
|
||||||
|
|
||||||
|
# Verify calculate_secret_hash was called
|
||||||
|
mock_hash.assert_called_once_with(
|
||||||
|
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify SECRET_HASH was included in auth params
|
||||||
|
call_args = mock_cognito_client.initiate_auth.call_args[1]
|
||||||
|
assert "SECRET_HASH" in call_args["AuthParameters"]
|
||||||
|
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
|
||||||
|
|
||||||
|
|
||||||
|
def test_initiate_auth_not_authorized(mock_cognito_client):
|
||||||
|
mock_cognito_client.initiate_auth.side_effect = (
|
||||||
|
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
initiate_auth("invalid_user", "wrong_pass")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert exc_info.value.detail == "Invalid username or password"
|
||||||
|
|
||||||
|
|
||||||
|
def test_initiate_auth_user_not_found(mock_cognito_client):
|
||||||
|
mock_cognito_client.initiate_auth.side_effect = (
|
||||||
|
mock_cognito_client.exceptions.UserNotFoundException()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
initiate_auth("nonexistent_user", "any_pass")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert exc_info.value.detail == "User not found"
|
||||||
|
|
||||||
|
|
||||||
|
def test_initiate_auth_generic_error(mock_cognito_client):
|
||||||
|
mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
initiate_auth("test_user", "test_pass")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert "An error occurred during authentication" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_from_token_success(mock_cognito_client):
|
||||||
|
mock_response = {
|
||||||
|
"Username": "test_user",
|
||||||
|
"UserAttributes": [
|
||||||
|
{"Name": "sub", "Value": "123"},
|
||||||
|
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
mock_cognito_client.get_user.return_value = mock_response
|
||||||
|
|
||||||
|
result = get_user_from_token("valid_token")
|
||||||
|
|
||||||
|
assert isinstance(result, CognitoUser)
|
||||||
|
assert result.username == "test_user"
|
||||||
|
assert set(result.roles) == {"admin", "user"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_from_token_no_roles(mock_cognito_client):
|
||||||
|
mock_response = {
|
||||||
|
"Username": "test_user",
|
||||||
|
"UserAttributes": [{"Name": "sub", "Value": "123"}],
|
||||||
|
}
|
||||||
|
mock_cognito_client.get_user.return_value = mock_response
|
||||||
|
|
||||||
|
result = get_user_from_token("valid_token")
|
||||||
|
|
||||||
|
assert isinstance(result, CognitoUser)
|
||||||
|
assert result.username == "test_user"
|
||||||
|
assert result.roles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_from_token_invalid_token(mock_cognito_client):
|
||||||
|
mock_cognito_client.get_user.side_effect = (
|
||||||
|
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_user_from_token("invalid_token")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert exc_info.value.detail == "Invalid or expired token."
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_from_token_user_not_found(mock_cognito_client):
|
||||||
|
mock_cognito_client.get_user.side_effect = (
|
||||||
|
mock_cognito_client.exceptions.UserNotFoundException()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_user_from_token("token_for_nonexistent_user")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert exc_info.value.detail == "User not found or invalid token."
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_from_token_generic_error(mock_cognito_client):
|
||||||
|
mock_cognito_client.get_user.side_effect = Exception("Some error")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
get_user_from_token("test_token")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert "Token verification failed" in exc_info.value.detail
|
||||||
205
tests/auth/test_dependencies.py
Normal file
205
tests/auth/test_dependencies.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
|
||||||
|
# Mock user for testing
|
||||||
|
TEST_USER = CognitoUser(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["admin", "user"],
|
||||||
|
groups=["test_group"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the underlying get_user_from_token function
|
||||||
|
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||||
|
if token == "valid_token":
|
||||||
|
return TEST_USER
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
|
# Mock endpoint for testing the require_roles decorator
|
||||||
|
@require_roles("admin")
|
||||||
|
def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||||
|
return {"message": "Success", "user": user.username}
|
||||||
|
|
||||||
|
|
||||||
|
# Patch the get_user_from_token function for testing
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_auth(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test get_current_user dependency
|
||||||
|
def test_get_current_user_success():
|
||||||
|
user = get_current_user("valid_token")
|
||||||
|
assert user == TEST_USER
|
||||||
|
assert user.username == "testuser"
|
||||||
|
assert user.roles == ["admin", "user"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_user_invalid_token():
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
get_current_user("invalid_token")
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# Test require_roles decorator
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_roles_success():
|
||||||
|
# Create test user with required role
|
||||||
|
user = CognitoUser(
|
||||||
|
username="testuser", email="test@example.com", roles=["admin"], groups=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await mock_protected_endpoint(user=user)
|
||||||
|
assert result == {"message": "Success", "user": "testuser"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_roles_missing_role():
|
||||||
|
# Create test user without required role
|
||||||
|
user = CognitoUser(
|
||||||
|
username="testuser", email="test@example.com", roles=["user"], groups=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await mock_protected_endpoint(user=user)
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
assert (
|
||||||
|
exc.value.detail
|
||||||
|
== "You do not have the required roles to access this endpoint."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_roles_no_roles():
|
||||||
|
# Create test user with no roles
|
||||||
|
user = CognitoUser(
|
||||||
|
username="testuser", email="test@example.com", roles=[], groups=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await mock_protected_endpoint(user=user)
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_roles_multiple_roles():
|
||||||
|
# Test requiring multiple roles
|
||||||
|
@require_roles("admin", "super_user")
|
||||||
|
def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||||
|
return {"message": "Success"}
|
||||||
|
|
||||||
|
# User with all required roles
|
||||||
|
user_with_roles = CognitoUser(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["admin", "super_user", "user"],
|
||||||
|
groups=[],
|
||||||
|
)
|
||||||
|
result = await mock_multi_role_endpoint(user=user_with_roles)
|
||||||
|
assert result == {"message": "Success"}
|
||||||
|
|
||||||
|
# User missing one required role
|
||||||
|
user_missing_role = CognitoUser(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["admin", "user"],
|
||||||
|
groups=[],
|
||||||
|
)
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await mock_multi_role_endpoint(user=user_missing_role)
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth2_scheme_configuration():
|
||||||
|
# Verify that we have a properly configured OAuth2PasswordBearer instance
|
||||||
|
assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
|
||||||
|
|
||||||
|
# Create a mock request with no Authorization header
|
||||||
|
mock_request = Request(
|
||||||
|
scope={
|
||||||
|
"type": "http",
|
||||||
|
"headers": [],
|
||||||
|
"method": "GET",
|
||||||
|
"scheme": "http",
|
||||||
|
"path": "/",
|
||||||
|
"query_string": b"",
|
||||||
|
"client": ("127.0.0.1", 8000),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the scheme raises 401 when no token is provided
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
await oauth2_scheme(mock_request)
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert exc.value.detail == "Not authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_auth_import(monkeypatch):
|
||||||
|
# Save original env var value
|
||||||
|
original_value = os.environ.get("MOCK_AUTH")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set MOCK_AUTH to true
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||||
|
|
||||||
|
# Reload the dependencies module to trigger the import condition
|
||||||
|
import app.auth.dependencies
|
||||||
|
|
||||||
|
importlib.reload(app.auth.dependencies)
|
||||||
|
|
||||||
|
# Verify that mock_get_user_from_token was imported
|
||||||
|
from app.auth.dependencies import get_user_from_token
|
||||||
|
|
||||||
|
assert get_user_from_token.__module__ == "app.auth.mock_auth"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original env var
|
||||||
|
if original_value is None:
|
||||||
|
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||||
|
else:
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||||
|
|
||||||
|
# Reload again to restore original state
|
||||||
|
importlib.reload(app.auth.dependencies)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cognito_auth_import(monkeypatch):
|
||||||
|
"""Test that cognito auth is imported when MOCK_AUTH=false (covers line 14)"""
|
||||||
|
# Save original env var value
|
||||||
|
original_value = os.environ.get("MOCK_AUTH")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set MOCK_AUTH to false
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", "false")
|
||||||
|
|
||||||
|
# Reload the dependencies module to trigger the import condition
|
||||||
|
import app.auth.dependencies
|
||||||
|
|
||||||
|
importlib.reload(app.auth.dependencies)
|
||||||
|
|
||||||
|
# Verify that get_user_from_token was imported from app.auth.cognito
|
||||||
|
from app.auth.dependencies import get_user_from_token
|
||||||
|
|
||||||
|
assert get_user_from_token.__module__ == "app.auth.cognito"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original env var
|
||||||
|
if original_value is None:
|
||||||
|
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||||
|
else:
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||||
|
|
||||||
|
# Reload again to restore original state
|
||||||
|
importlib.reload(app.auth.dependencies)
|
||||||
41
tests/auth/test_mock_auth.py
Normal file
41
tests/auth/test_mock_auth.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_get_user_from_token_success():
|
||||||
|
"""Test successful token validation returns expected user"""
|
||||||
|
user = mock_get_user_from_token("testuser")
|
||||||
|
assert isinstance(user, CognitoUser)
|
||||||
|
assert user.username == "testuser"
|
||||||
|
assert user.roles == ["admin"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_get_user_from_token_invalid():
|
||||||
|
"""Test invalid token raises expected exception"""
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
mock_get_user_from_token("invalid_token")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_initiate_auth():
|
||||||
|
"""Test mock authentication returns expected token response"""
|
||||||
|
result = mock_initiate_auth("any_user", "any_password")
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["AccessToken"] == "testuser"
|
||||||
|
assert result["ExpiresIn"] == 3600
|
||||||
|
assert result["TokenType"] == "Bearer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_initiate_auth_different_credentials():
|
||||||
|
"""Test mock authentication works with any credentials"""
|
||||||
|
result1 = mock_initiate_auth("user1", "pass1")
|
||||||
|
result2 = mock_initiate_auth("user2", "pass2")
|
||||||
|
|
||||||
|
# Should return same mock token regardless of credentials
|
||||||
|
assert result1 == result2
|
||||||
0
tests/iptv/__init__.py
Normal file
0
tests/iptv/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
144
tests/models/test_db.py
Normal file
144
tests/models/test_db.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID
|
||||||
|
|
||||||
|
# --- Test SQLiteUUID Type ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_none():
|
||||||
|
"""Test SQLiteUUID.process_bind_param with None returns None"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
assert uuid_type.process_bind_param(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_valid_uuid():
|
||||||
|
"""Test SQLiteUUID.process_bind_param with valid UUID returns string"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid = uuid.uuid4()
|
||||||
|
assert uuid_type.process_bind_param(test_uuid, None) == str(test_uuid)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_valid_string():
|
||||||
|
"""Test SQLiteUUID.process_bind_param with valid UUID string returns string"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid_str = "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
assert uuid_type.process_bind_param(test_uuid_str, None) == test_uuid_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_invalid_string():
|
||||||
|
"""Test SQLiteUUID.process_bind_param raises ValueError for invalid UUID"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
with pytest.raises(ValueError, match="Invalid UUID string format"):
|
||||||
|
uuid_type.process_bind_param("invalid-uuid", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_result_value_none():
|
||||||
|
"""Test SQLiteUUID.process_result_value with None returns None"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
assert uuid_type.process_result_value(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_result_value_valid_string():
|
||||||
|
"""Test SQLiteUUID.process_result_value converts string to UUID"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid = uuid.uuid4()
|
||||||
|
result = uuid_type.process_result_value(str(test_uuid), None)
|
||||||
|
assert isinstance(result, uuid.UUID)
|
||||||
|
assert result == test_uuid
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_result_value_uuid_object():
    """A value that is already a uuid.UUID is returned untouched."""
    type_instance = SQLiteUUID()
    original = uuid.uuid4()
    returned = type_instance.process_result_value(original, None)
    assert isinstance(returned, uuid.UUID)
    # Identity check: no copy should be made.
    assert returned is original
def test_sqliteuuid_compare_values_none():
    """compare_values treats None as equal only to None."""
    type_instance = SQLiteUUID()
    assert type_instance.compare_values(None, None) is True
    assert type_instance.compare_values(None, uuid.uuid4()) is False
    assert type_instance.compare_values(uuid.uuid4(), None) is False
def test_sqliteuuid_compare_values_uuid():
    """compare_values matches an identical UUID and rejects a different one."""
    type_instance = SQLiteUUID()
    sample = uuid.uuid4()
    assert type_instance.compare_values(sample, sample) is True
    assert type_instance.compare_values(sample, uuid.uuid4()) is False
def test_sqlite_uuid_comparison():
    """Exercise SQLiteUUID.compare_values across equal, string, None and distinct cases."""
    type_instance = SQLiteUUID()

    first = uuid.uuid4()
    same_as_first = uuid.UUID(str(first))
    other = uuid.uuid4()

    # Two distinct UUID objects with the same value compare equal.
    assert type_instance.compare_values(first, same_as_first) is True

    # A UUID compares equal to its canonical string form.
    assert type_instance.compare_values(first, str(first)) is True

    # None handling: equal only when both sides are None.
    assert type_instance.compare_values(None, None) is True
    assert type_instance.compare_values(first, None) is False
    assert type_instance.compare_values(None, first) is False

    # Distinct UUIDs are unequal.
    assert type_instance.compare_values(first, other) is False
def test_sqlite_uuid_binding():
    """Exercise SQLiteUUID.process_bind_param for UUID, string, None and bad input."""
    type_instance = SQLiteUUID()

    # UUID objects are serialized to their canonical string.
    as_uuid = uuid.uuid4()
    assert type_instance.process_bind_param(as_uuid, None) == str(as_uuid)

    # Valid UUID strings pass through unchanged.
    as_str = str(uuid.uuid4())
    assert type_instance.process_bind_param(as_str, None) == as_str

    # None is preserved.
    assert type_instance.process_bind_param(None, None) is None

    # Malformed strings are rejected.
    with pytest.raises(ValueError):
        type_instance.process_bind_param("invalid-uuid", None)
# --- Test UUID Column Type Configuration ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_uuid_column_type_default():
    """In the test environment, UUID columns use the SQLite shim type."""
    assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID)
@patch.dict(os.environ, {"MOCK_AUTH": "false"})
def test_uuid_column_type_postgres():
    """Test UUID_COLUMN_TYPE uses Postgres UUID when MOCK_AUTH=false"""
    # Need to re-import to get the patched environment
    from importlib import reload

    from app import models

    # NOTE(review): reloading models.db mutates global module state and is not
    # undone when the test finishes — later tests in the same session may see
    # the Postgres-flavoured module. Consider reloading again in a finally
    # block (or an autouse fixture) to restore the SQLite configuration.
    reload(models.db)
    from sqlalchemy.dialects.postgresql import UUID as PostgresUUID

    # Re-import after the reload so the freshly rebound attribute is read.
    from app.models.db import UUID_COLUMN_TYPE

    assert isinstance(UUID_COLUMN_TYPE, PostgresUUID)
0
tests/routers/__init__.py
Normal file
0
tests/routers/__init__.py
Normal file
43
tests/routers/mocks.py
Normal file
43
tests/routers/mocks.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class MockScheduler:
    """Base mock APScheduler instance"""

    # NOTE(review): these Mock objects are class attributes, so their call
    # history (e.g. add_job.call_count) is shared by every MockScheduler
    # instance and accumulates across tests. If per-test isolation is needed,
    # reset them between tests or create them per instance in __init__.
    running = True
    start = Mock()
    shutdown = Mock()
    add_job = Mock()
    remove_job = Mock()
    get_job = Mock(return_value=None)

    def __init__(self, running=True):
        # Shadow the class-level default with a per-instance flag.
        self.running = running
def create_trigger_mock(triggered_ref: dict) -> callable:
    """Build a zero-argument callback that flips ``triggered_ref["value"]`` to True."""

    def _record_trigger():
        # Mutate the shared dict so the caller can observe the invocation.
        triggered_ref["value"] = True

    return _record_trigger
async def mock_get_scheduler(
    request: Request, scheduler_class=MockScheduler, running=True, **kwargs
) -> StreamScheduler:
    """Mock dependency for get_scheduler with customization options."""
    wrapper = StreamScheduler()
    fake_scheduler = scheduler_class(running=running)

    # Graft any extra attributes/methods the test asked for onto the mock.
    for attr_name, attr_value in kwargs.items():
        setattr(fake_scheduler, attr_name, attr_value)

    wrapper.scheduler = fake_scheduler
    return wrapper
101
tests/routers/test_auth.py
Normal file
101
tests/routers/test_auth.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_successful_auth():
    """Cognito-style auth result containing access, id and refresh tokens."""
    tokens = {
        "AccessToken": "mock_access_token",
        "IdToken": "mock_id_token",
    }
    tokens["RefreshToken"] = "mock_refresh_token"
    return tokens
@pytest.fixture
def mock_successful_auth_no_refresh():
    """Auth result lacking a RefreshToken (e.g. a refresh-token flow)."""
    return {
        "AccessToken": "mock_access_token",
        "IdToken": "mock_id_token",
    }
def test_signin_success(mock_successful_auth):
    """Signin returns 200 and surfaces all three tokens when auth succeeds."""
    with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
        payload = {"username": "testuser", "password": "testpass"}
        response = client.post("/auth/signin", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["access_token"] == "mock_access_token"
    assert body["id_token"] == "mock_id_token"
    assert body["refresh_token"] == "mock_refresh_token"
    assert body["token_type"] == "Bearer"
def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
    """Signin succeeds with refresh_token reported as null when auth omits it."""
    with patch(
        "app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
    ):
        payload = {"username": "testuser", "password": "testpass"}
        response = client.post("/auth/signin", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["access_token"] == "mock_access_token"
    assert body["id_token"] == "mock_id_token"
    assert body["refresh_token"] is None
    assert body["token_type"] == "Bearer"
def test_signin_invalid_input():
    """Schema-invalid signin payloads are rejected with 422."""
    incomplete_payloads = [
        {"username": "testuser"},  # missing password
        {"password": "testpass"},  # missing username
        {},  # empty payload
    ]
    for payload in incomplete_payloads:
        response = client.post("/auth/signin", json=payload)
        assert response.status_code == 422
def test_signin_auth_failure():
    """An auth-layer 401 propagates through the signin endpoint."""
    failure = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid username or password",
    )
    with patch("app.routers.auth.initiate_auth", side_effect=failure):
        response = client.post(
            "/auth/signin", json={"username": "testuser", "password": "wrongpass"}
        )

    assert response.status_code == 401
    assert response.json()["detail"] == "Invalid username or password"
def test_signin_user_not_found():
    """An auth-layer 404 for an unknown user propagates through signin."""
    missing_user = HTTPException(
        status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
    )
    with patch("app.routers.auth.initiate_auth", side_effect=missing_user):
        response = client.post(
            "/auth/signin", json={"username": "nonexistent", "password": "testpass"}
        )

    assert response.status_code == 404
    assert response.json()["detail"] == "User not found"
1588
tests/routers/test_channels.py
Normal file
1588
tests/routers/test_channels.py
Normal file
File diff suppressed because it is too large
Load Diff
461
tests/routers/test_groups.py
Normal file
461
tests/routers/test_groups.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
# Import mocks and fixtures
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import (
|
||||||
|
MockChannelDB,
|
||||||
|
MockGroup,
|
||||||
|
create_mock_priorities_and_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Test Cases For Group Creation ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_group_success(db_session: Session, admin_user_client):
    """An admin can create a group; response and DB row reflect the payload."""
    response = admin_user_client.post(
        "/groups/", json={"name": "Test Group", "sort_order": 1}
    )
    assert response.status_code == status.HTTP_201_CREATED

    body = response.json()
    assert body["name"] == "Test Group"
    assert body["sort_order"] == 1
    for generated_field in ("id", "created_at", "updated_at"):
        assert generated_field in body

    # The group must also be persisted.
    persisted = (
        db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first()
    )
    assert persisted is not None
    assert persisted.sort_order == 1
def test_create_group_duplicate(db_session: Session, admin_user_client):
|
||||||
|
# Create initial group
|
||||||
|
initial_group = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Duplicate Group",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(initial_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Attempt to create duplicate
|
||||||
|
response = admin_user_client.post(
|
||||||
|
"/groups/", json={"name": "Duplicate Group", "sort_order": 2}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
assert "already exists" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_group_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client
|
||||||
|
):
|
||||||
|
response = non_admin_user_client.post(
|
||||||
|
"/groups/", json={"name": "Forbidden Group", "sort_order": 1}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Get Group ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_group_success(db_session: Session, admin_user_client):
|
||||||
|
# Create a group first
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Get Me Group",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/groups/{test_group.id}")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == str(test_group.id)
|
||||||
|
assert data["name"] == "Get Me Group"
|
||||||
|
assert data["sort_order"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_group_not_found(db_session: Session, admin_user_client):
    """Fetching an unknown group id yields 404 with a clear detail message."""
    missing_id = uuid.uuid4()
    response = admin_user_client.get(f"/groups/{missing_id}")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]
|
|
||||||
|
# --- Test Cases For Update Group ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_success(db_session: Session, admin_user_client):
|
||||||
|
# Create initial group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Update Me",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
update_data = {"name": "Updated Name", "sort_order": 2}
|
||||||
|
response = admin_user_client.put(f"/groups/{group_id}", json=update_data)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated Name"
|
||||||
|
assert data["sort_order"] == 2
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||||
|
assert db_group.name == "Updated Name"
|
||||||
|
assert db_group.sort_order == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_conflict(db_session: Session, admin_user_client):
|
||||||
|
# Create two groups
|
||||||
|
group1 = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Group One",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
group2 = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Group Two",
|
||||||
|
sort_order=2,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add_all([group1, group2])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Try to rename group2 to conflict with group1
|
||||||
|
response = admin_user_client.put(f"/groups/{group2.id}", json={"name": "Group One"})
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
assert "already exists" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_not_found(db_session: Session, admin_user_client):
    """Updating an unknown group id yields 404."""
    missing_id = uuid.uuid4()
    response = admin_user_client.put(
        f"/groups/{missing_id}", json={"name": "Non-existent"}
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]
def test_update_group_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client, admin_user_client
|
||||||
|
):
|
||||||
|
# Create group with admin
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Admin Created",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Attempt update with non-admin
|
||||||
|
response = non_admin_user_client.put(
|
||||||
|
f"/groups/{group_id}", json={"name": "Non-Admin Update"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Delete Group ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_groups_success(db_session, admin_user_client):
|
||||||
|
"""Test reset groups endpoint"""
|
||||||
|
# Create test data
|
||||||
|
group1_id = create_mock_priorities_and_group(db_session, [], "Group A")
|
||||||
|
group2_id = create_mock_priorities_and_group(db_session, [], "Group B")
|
||||||
|
|
||||||
|
# Add channel to group2
|
||||||
|
channel_data = [
|
||||||
|
{
|
||||||
|
"group-title": "Group A",
|
||||||
|
"tvg_id": "channel1.tv",
|
||||||
|
"name": "Channel One",
|
||||||
|
"url": ["http://test.com", "http://example.com"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
admin_user_client.post("/channels/bulk-upload", json=channel_data)
|
||||||
|
|
||||||
|
# Reset groups
|
||||||
|
response = admin_user_client.delete("/groups")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["deleted"] == 1 # Only group2 should be deleted
|
||||||
|
assert response.json()["skipped"] == 1 # group1 has channels
|
||||||
|
|
||||||
|
# Verify group2 deleted, group1 remains
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group1_id).first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
assert db_session.query(MockGroup).filter(MockGroup.id == group2_id).first() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_groups_forbidden_for_non_admin(db_session, non_admin_user_client):
|
||||||
|
"""Test reset groups requires admin role"""
|
||||||
|
response = non_admin_user_client.delete("/groups")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_success(db_session: Session, admin_user_client):
|
||||||
|
# Create group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Delete Me",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Verify exists before delete
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
response = admin_user_client.delete(f"/groups/{group_id}")
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_with_channels_fails(db_session: Session, admin_user_client):
|
||||||
|
# Create group with channel
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Group With Channels",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
|
||||||
|
# Create channel in this group
|
||||||
|
test_channel = MockChannelDB(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
tvg_id="channel1.tv",
|
||||||
|
name="Channel 1",
|
||||||
|
group_id=group_id,
|
||||||
|
tvg_name="Channel1",
|
||||||
|
tvg_logo="logo.png",
|
||||||
|
)
|
||||||
|
db_session.add(test_channel)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.delete(f"/groups/{group_id}")
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert "existing channels" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify group still exists
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_not_found(db_session: Session, admin_user_client):
    """Deleting an unknown group id yields 404."""
    missing_id = uuid.uuid4()
    response = admin_user_client.delete(f"/groups/{missing_id}")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]
def test_delete_group_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client, admin_user_client
|
||||||
|
):
|
||||||
|
# Create group with admin
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Admin Created",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Attempt delete with non-admin
|
||||||
|
response = non_admin_user_client.delete(f"/groups/{group_id}")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify group still exists
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For List Groups ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_groups_empty(db_session: Session, admin_user_client):
    """Listing groups on an empty database returns an empty JSON array."""
    response = admin_user_client.get("/groups/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == []
def test_list_groups_with_data(db_session: Session, admin_user_client):
    """Listing returns every group, ordered by sort_order ascending."""
    seeded = [
        MockGroup(
            id=uuid.uuid4(),
            name=f"Group {index}",
            sort_order=index,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        for index in range(3)
    ]
    db_session.add_all(seeded)
    db_session.commit()

    response = admin_user_client.get("/groups/")
    assert response.status_code == status.HTTP_200_OK

    listed = response.json()
    assert len(listed) == 3
    # Should be sorted by sort_order.
    assert [entry["sort_order"] for entry in listed] == [0, 1, 2]
|
||||||
|
# --- Test Cases For Sort Order Updates ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_sort_order_success(db_session: Session, admin_user_client):
|
||||||
|
# Create group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Sort Me",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5})
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["sort_order"] == 5
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||||
|
assert db_group.sort_order == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_sort_order_not_found(db_session: Session, admin_user_client):
|
||||||
|
"""Test that updating sort order for non-existent group returns 404"""
|
||||||
|
random_uuid = uuid.uuid4()
|
||||||
|
response = admin_user_client.put(
|
||||||
|
f"/groups/{random_uuid}/sort", json={"sort_order": 5}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Group not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client):
    """Reordering via POST /groups/reorder updates every group's sort_order.

    Seeds three groups with sort_order 0..2, posts a reversed ordering, and
    verifies both the response payload and the persisted rows.
    """
    # Seed three groups with ascending sort_order.
    groups = [
        MockGroup(
            id=uuid.uuid4(),
            name=f"Group {i}",
            sort_order=i,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        for i in range(3)
    ]
    db_session.add_all(groups)
    db_session.commit()

    # Bulk update sort orders (reverse order).
    bulk_data = {
        "groups": [
            {"group_id": str(groups[0].id), "sort_order": 2},
            {"group_id": str(groups[1].id), "sort_order": 1},
            {"group_id": str(groups[2].id), "sort_order": 0},
        ]
    }
    response = admin_user_client.post("/groups/reorder", json=bulk_data)
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert len(data) == 3

    # Index the returned group payloads by id for order-independent lookup.
    returned_groups_map = {item["id"]: item for item in data}

    # Verify each group has its expected new sort_order.
    assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2
    assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1
    assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0

    # Verify the new ordering was persisted.
    db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all()
    assert db_groups[0].sort_order == 2
    assert db_groups[1].sort_order == 1
    assert db_groups[2].sort_order == 0
def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client):
|
||||||
|
# Create one group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Valid Group",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Try to update with invalid group
|
||||||
|
bulk_data = {
|
||||||
|
"groups": [
|
||||||
|
{"group_id": str(group_id), "sort_order": 2},
|
||||||
|
{"group_id": str(uuid.uuid4()), "sort_order": 1}, # Invalid group
|
||||||
|
]
|
||||||
|
}
|
||||||
|
response = admin_user_client.post("/groups/reorder", json=bulk_data)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify original sort order unchanged
|
||||||
|
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||||
|
assert db_group.sort_order == 1
|
||||||
261
tests/routers/test_playlist.py
Normal file
261
tests/routers/test_playlist.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
|
||||||
|
# Import the router we're testing
|
||||||
|
from app.routers.playlist import (
|
||||||
|
ProcessStatus,
|
||||||
|
ValidationProcessResponse,
|
||||||
|
ValidationResultResponse,
|
||||||
|
router,
|
||||||
|
validation_processes,
|
||||||
|
)
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
# Import mocks and fixtures
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import MockChannelDB
|
||||||
|
|
||||||
|
# --- Test Fixtures ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_stream_manager():
    """Patch StreamManager in the playlist router for the duration of a test."""
    with patch("app.routers.playlist.StreamManager") as patched:
        yield patched
||||||
|
# --- Test Cases For Stream Validation ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_stream_validation_success(
|
||||||
|
db_session: Session, admin_user_client, mock_stream_manager
|
||||||
|
):
|
||||||
|
"""Test starting a stream validation process"""
|
||||||
|
mock_instance = mock_stream_manager.return_value
|
||||||
|
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
|
||||||
|
|
||||||
|
response = admin_user_client.post(
|
||||||
|
"/playlist/validate-streams", json={"channel_id": "test-channel"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
|
data = response.json()
|
||||||
|
assert "process_id" in data
|
||||||
|
assert data["status"] == ProcessStatus.PENDING
|
||||||
|
assert data["message"] == "Validation process started"
|
||||||
|
|
||||||
|
# Verify process was added to tracking
|
||||||
|
process_id = data["process_id"]
|
||||||
|
assert process_id in validation_processes
|
||||||
|
# In test environment, background tasks run synchronously so status may be COMPLETED
|
||||||
|
assert validation_processes[process_id]["status"] in [
|
||||||
|
ProcessStatus.PENDING,
|
||||||
|
ProcessStatus.COMPLETED,
|
||||||
|
]
|
||||||
|
assert validation_processes[process_id]["channel_id"] == "test-channel"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_pending(db_session: Session, admin_user_client):
    """A pending process reports PENDING with no streams and no error."""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }

    response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
    assert response.status_code == status.HTTP_200_OK

    body = response.json()
    assert body["process_id"] == process_id
    assert body["status"] == ProcessStatus.PENDING
    assert body["working_streams"] is None
    assert body["error"] is None
def test_get_validation_status_completed(db_session: Session, admin_user_client):
|
||||||
|
"""Test checking status of completed validation"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.COMPLETED,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
"result": {
|
||||||
|
"working_streams": [
|
||||||
|
{"channel_id": "test-channel", "stream_url": "http://valid.stream.url"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["process_id"] == process_id
|
||||||
|
assert data["status"] == ProcessStatus.COMPLETED
|
||||||
|
assert len(data["working_streams"]) == 1
|
||||||
|
assert data["working_streams"][0]["channel_id"] == "test-channel"
|
||||||
|
assert data["working_streams"][0]["stream_url"] == "http://valid.stream.url"
|
||||||
|
assert data["error"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_completed_with_error(
    db_session: Session, admin_user_client
):
    """Test checking status of completed validation with error"""
    process_id = str(uuid.uuid4())
    # A process can complete with an error message instead of results.
    validation_processes[process_id] = {
        "status": ProcessStatus.COMPLETED,
        "channel_id": "test-channel",
        "error": "No working streams found for channel test-channel",
    }

    resp = admin_user_client.get(f"/playlist/validate-streams/{process_id}")

    assert resp.status_code == status.HTTP_200_OK
    payload = resp.json()
    assert payload["process_id"] == process_id
    assert payload["status"] == ProcessStatus.COMPLETED
    assert payload["working_streams"] is None
    assert payload["error"] == "No working streams found for channel test-channel"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_failed(db_session: Session, admin_user_client):
    """Test checking status of failed validation"""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.FAILED,
        "channel_id": "test-channel",
        "error": "Validation error occurred",
    }

    resp = admin_user_client.get(f"/playlist/validate-streams/{process_id}")

    assert resp.status_code == status.HTTP_200_OK
    payload = resp.json()
    assert payload["process_id"] == process_id
    assert payload["status"] == ProcessStatus.FAILED
    # A failed process reports only its error, never stream results.
    assert payload["working_streams"] is None
    assert payload["error"] == "Validation error occurred"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_not_found(db_session: Session, admin_user_client):
    """Test checking status of non-existent process"""
    # A fresh UUID is never present in the process registry.
    missing_id = str(uuid.uuid4())
    resp = admin_user_client.get(f"/playlist/validate-streams/{missing_id}")

    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "Process not found" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_success(mock_stream_manager, db_session):
    """Test the background validation task success case"""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }

    # The mocked stream manager reports a single working URL.
    manager = mock_stream_manager.return_value
    manager.validate_and_select_stream.return_value = "http://valid.stream.url"

    from app.routers.playlist import run_stream_validation

    run_stream_validation(process_id, "test-channel", db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.COMPLETED
    streams = entry["result"]["working_streams"]
    assert len(streams) == 1
    assert streams[0].channel_id == "test-channel"
    assert streams[0].stream_url == "http://valid.stream.url"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_failure(mock_stream_manager, db_session):
    """Test the background validation task failure case"""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }

    # No stream validates successfully: the manager returns None.
    mock_stream_manager.return_value.validate_and_select_stream.return_value = None

    from app.routers.playlist import run_stream_validation

    run_stream_validation(process_id, "test-channel", db_session)

    entry = validation_processes[process_id]
    # The task still completes, but records an error instead of results.
    assert entry["status"] == ProcessStatus.COMPLETED
    assert "error" in entry
    assert "No working streams found" in entry["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_exception(mock_stream_manager, db_session):
    """Test the background validation task exception case"""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }

    # Make the stream manager blow up to exercise the task's error path.
    manager = mock_stream_manager.return_value
    manager.validate_and_select_stream.side_effect = Exception("Test error")

    from app.routers.playlist import run_stream_validation

    run_stream_validation(process_id, "test-channel", db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.FAILED
    assert "error" in entry
    assert "Test error" in entry["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_stream_validation_no_channel_id(
    db_session: Session, admin_user_client, mock_stream_manager
):
    """Test starting validation without channel_id"""
    resp = admin_user_client.post("/playlist/validate-streams", json={})

    assert resp.status_code == status.HTTP_202_ACCEPTED
    payload = resp.json()
    assert "process_id" in payload
    assert payload["status"] == ProcessStatus.PENDING

    # Verify the process was added to the in-memory tracking registry.
    process_id = payload["process_id"]
    assert process_id in validation_processes
    tracked = validation_processes[process_id]
    # The background task may or may not have run by now.
    assert tracked["status"] in [ProcessStatus.PENDING, ProcessStatus.COMPLETED]
    assert tracked["channel_id"] is None
    assert "not yet implemented" in tracked.get("error", "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session):
    """Test background validation without channel_id"""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {"status": ProcessStatus.PENDING}

    from app.routers.playlist import run_stream_validation

    # Passing channel_id=None exercises the "all channels" branch.
    run_stream_validation(process_id, None, db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.COMPLETED
    assert "error" in entry
    assert "not yet implemented" in entry["error"]
|
||||||
241
tests/routers/test_priorities.py
Normal file
241
tests/routers/test_priorities.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.routers.priorities import router as priorities_router
|
||||||
|
|
||||||
|
# Import fixtures and mocks
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import (
|
||||||
|
MockChannelDB,
|
||||||
|
MockChannelURL,
|
||||||
|
MockGroup,
|
||||||
|
MockPriority,
|
||||||
|
create_mock_priorities_and_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Test Cases For Priority Creation ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_priority_success(db_session: Session, admin_user_client):
    """Create a priority as admin and verify the response and DB row."""
    payload = {"id": 100, "description": "Test Priority"}
    resp = admin_user_client.post("/priorities/", json=payload)
    assert resp.status_code == status.HTTP_201_CREATED
    body = resp.json()
    assert body["id"] == 100
    assert body["description"] == "Test Priority"

    # Verify the row was actually persisted.
    stored = db_session.get(MockPriority, 100)
    assert stored is not None
    assert stored.description == "Test Priority"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_priority_duplicate(db_session: Session, admin_user_client):
    """Creating a priority with an already-used ID returns 409 Conflict."""
    payload = {"id": 100, "description": "Original Priority"}
    first = admin_user_client.post("/priorities/", json=payload)
    assert first.status_code == status.HTTP_201_CREATED

    # Re-posting the same ID must be rejected.
    second = admin_user_client.post("/priorities/", json=payload)
    assert second.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in second.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not create priorities."""
    resp = non_admin_user_client.post(
        "/priorities/", json={"id": 100, "description": "Test Priority"}
    )
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For List Priorities ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_priorities_empty(db_session: Session, admin_user_client):
    """Listing priorities with an empty table returns an empty JSON array."""
    resp = admin_user_client.get("/priorities/")
    assert resp.status_code == status.HTTP_200_OK
    assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_priorities_with_data(db_session: Session, admin_user_client):
    """Listing returns every stored priority, in ID order."""
    seed = [(100, "High"), (200, "Medium"), (300, "Low")]
    db_session.add_all(
        [MockPriority(id=pid, description=desc) for pid, desc in seed]
    )
    db_session.commit()

    resp = admin_user_client.get("/priorities/")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert len(body) == 3
    # Rows come back ordered by ID, matching the seed order.
    for row, (pid, desc) in zip(body, seed):
        assert row["id"] == pid
        assert row["description"] == desc
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_priorities_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not list priorities."""
    resp = non_admin_user_client.get("/priorities/")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Get Priority ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_priority_success(db_session: Session, admin_user_client):
    """Fetch an existing priority by ID."""
    db_session.add(MockPriority(id=100, description="Test Priority"))
    db_session.commit()

    resp = admin_user_client.get("/priorities/100")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["id"] == 100
    assert body["description"] == "Test Priority"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_priority_not_found(db_session: Session, admin_user_client):
    """Fetching a missing priority ID returns 404."""
    resp = admin_user_client.get("/priorities/999")
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "Priority not found" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not fetch priorities."""
    resp = non_admin_user_client.get("/priorities/100")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Delete Priority ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_priorities_success(db_session, admin_user_client):
    """Test reset priorities endpoint"""
    # Seed three priorities.
    for pid, desc in [(100, "High"), (200, "Medium"), (300, "Low")]:
        db_session.add(MockPriority(id=pid, description=desc))
    db_session.commit()

    # Create a channel whose URL uses priority 100, so it becomes "in use".
    create_mock_priorities_and_group(db_session, [], "Test Group")
    channel_payload = [
        {
            "group-title": "Test Group",
            "tvg_id": "test.tv",
            "name": "Test Channel",
            "urls": ["http://test.com"],
        }
    ]
    admin_user_client.post("/channels/bulk-upload", json=channel_payload)

    # Delete all priorities: unused ones go, in-use ones are skipped.
    resp = admin_user_client.delete("/priorities")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["deleted"] == 2  # Medium and Low priorities
    assert body["skipped"] == 1  # High priority is in use

    # Only the in-use priority 100 should remain in the table.
    remaining = db_session.query(MockPriority).all()
    assert len(remaining) == 1
    assert remaining[0].id == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_priorities_forbidden_for_non_admin(db_session, non_admin_user_client):
    """Test reset priorities requires admin role"""
    resp = non_admin_user_client.delete("/priorities")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_success(db_session: Session, admin_user_client):
    """Delete an unused priority and confirm it is removed from the DB."""
    db_session.add(MockPriority(id=100, description="To Delete"))
    db_session.commit()

    resp = admin_user_client.delete("/priorities/100")
    assert resp.status_code == status.HTTP_204_NO_CONTENT

    # The row must be gone after the delete.
    assert db_session.get(MockPriority, 100) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_not_found(db_session: Session, admin_user_client):
    """Deleting a missing priority ID returns 404."""
    resp = admin_user_client.delete("/priorities/999")
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "Priority not found" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_in_use(db_session: Session, admin_user_client):
    """Deleting a priority referenced by a channel URL returns 409."""
    # Seed a priority plus a group to hang a channel off.
    priority = MockPriority(id=100, description="In Use")
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Group With Channels",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add_all([priority, test_group])
    db_session.commit()

    # A channel must exist before a URL can reference it.
    channel = MockChannelDB(
        name="Test Channel",
        tvg_id="test.tv",
        tvg_name="Test",
        tvg_logo="test.png",
        group_id=group_id,
    )
    db_session.add(channel)
    db_session.commit()

    # Attach a URL that uses priority 100, making it "in use".
    db_session.add(
        MockChannelURL(
            url="http://test.com",
            priority_id=100,
            in_use=True,
            channel_id=channel.id,
        )
    )
    db_session.commit()

    resp = admin_user_client.delete("/priorities/100")
    assert resp.status_code == status.HTTP_409_CONFLICT
    assert "in use by channel URLs" in resp.json()["detail"]

    # The priority must survive the rejected delete.
    assert db_session.get(MockPriority, 100) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not delete priorities."""
    resp = non_admin_user_client.delete("/priorities/100")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
|
||||||
287
tests/routers/test_scheduler.py
Normal file
287
tests/routers/test_scheduler.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
from app.routers.scheduler import get_scheduler
|
||||||
|
from app.routers.scheduler import router as scheduler_router
|
||||||
|
from app.utils.database import get_db
|
||||||
|
from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import mock_get_db
|
||||||
|
|
||||||
|
# Scheduler Health Check Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_success(admin_user_client, monkeypatch):
    """
    Test case for successful scheduler health check when accessed by an admin user.
    It mocks the scheduler to be running and have a next scheduled job.
    """
    # Expected next-run timestamp reported by the mocked job.
    next_run = datetime.now(timezone.utc)

    mock_job = Mock()
    mock_job.next_run_time = next_run

    # Only the daily validation job ID resolves to our mocked job.
    def fake_get_job(job_id):
        return mock_job if job_id == "daily_stream_validation" else None

    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(
            request,
            running=True,
            get_job=Mock(side_effect=fake_get_job),
        )

    # Wire the scheduler router and dependency overrides into the test app.
    admin_user_client.app.include_router(scheduler_router)
    admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = admin_user_client.get("/scheduler/health")

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["status"] == "running"
    assert payload["next_run"] == str(next_run)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_stopped(admin_user_client, monkeypatch):
    """
    Test case for scheduler health check when the scheduler is in a stopped state.
    Ensures the API returns the correct status and no next run time.
    """

    # Dependency override that yields a stopped scheduler.
    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(request, running=False)

    admin_user_client.app.include_router(scheduler_router)
    admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = admin_user_client.get("/scheduler/health")

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["status"] == "stopped"
    # A stopped scheduler has no upcoming run.
    assert payload["next_run"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
    """
    Test case to ensure that non-admin users are forbidden from accessing
    the scheduler health endpoint.
    """

    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(request, running=False)

    non_admin_user_client.app.include_router(scheduler_router)
    non_admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = non_admin_user_client.get("/scheduler/health")

    # Role check must reject the request before the scheduler is consulted.
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_check_exception(admin_user_client, monkeypatch):
    """
    Test case for handling exceptions during the scheduler health check.
    Ensures the API returns a 500 Internal Server Error when an exception occurs.
    """

    # `get_job` raising simulates an internal scheduler fault.
    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(
            request, running=True, get_job=Mock(side_effect=Exception("Test exception"))
        )

    admin_user_client.app.include_router(scheduler_router)
    admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = admin_user_client.get("/scheduler/health")

    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert "Failed to check scheduler health" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# Scheduler Trigger Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_trigger_validation_success(admin_user_client, monkeypatch):
    """
    Test case for successful manual triggering
    of stream validation by an admin user.
    It verifies that the trigger method is called and
    the API returns a 202 Accepted status.
    """
    # Mutable flag flipped by the trigger mock so we can assert it ran.
    triggered_ref = {"value": False}

    # NOTE: a previous version also built an unused `MockScheduler` instance
    # here; it was dead code (the override below builds its own scheduler
    # via `mock_get_scheduler`) and has been removed.
    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        scheduler = await mock_get_scheduler(
            request,
            running=True,
        )

        # Replace the actual trigger method with our mock to track calls.
        scheduler.trigger_manual_validation = create_trigger_mock(
            triggered_ref=triggered_ref
        )

        return scheduler

    # Include the scheduler router in the test application.
    admin_user_client.app.include_router(scheduler_router)

    # Override dependencies for the test.
    admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    # Make the request to trigger stream validation.
    response = admin_user_client.post("/scheduler/trigger")

    # Assert the response status code, message, and that the trigger was called.
    assert response.status_code == status.HTTP_202_ACCEPTED
    assert response.json()["message"] == "Stream validation triggered"
    assert triggered_ref["value"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
    """
    Test case to ensure that non-admin users are
    forbidden from manually triggering stream validation.
    """

    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(request, running=True)

    non_admin_user_client.app.include_router(scheduler_router)
    non_admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = non_admin_user_client.post("/scheduler/trigger")

    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_initialized_in_app_state(admin_user_client):
    """
    Test case for when the scheduler is initialized in the app state but its internal
    scheduler attribute is not set, which should still allow health check.
    """
    # Install a real StreamScheduler on the app and exercise the real
    # get_scheduler dependency (only get_db is overridden).
    admin_user_client.app.state.scheduler = StreamScheduler()
    admin_user_client.app.include_router(scheduler_router)
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = admin_user_client.get("/scheduler/health")

    assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_not_initialized_in_app_state(admin_user_client):
    """
    Test case for when the scheduler is not properly initialized in the app state.
    This simulates a scenario where the internal scheduler attribute is missing,
    leading to a 500 Internal Server Error on health check.
    """
    scheduler = StreamScheduler()
    # Removing the internal attribute simulates an uninitialized scheduler.
    del scheduler.scheduler

    admin_user_client.app.state.scheduler = scheduler
    admin_user_client.app.include_router(scheduler_router)
    # Override only get_db, allowing the real get_scheduler to be tested.
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db

    response = admin_user_client.get("/scheduler/health")

    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert "Scheduler not initialized" in response.json()["detail"]
|
||||||
81
tests/test_main.py
Normal file
81
tests/test_main.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.main import app, lifespan
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """Test client for FastAPI app"""
    test_client = TestClient(app)
    return test_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_root_endpoint(client):
    """Test root endpoint returns expected message"""
    resp = client.get("/")
    assert resp.status_code == 200
    assert resp.json() == {"message": "IPTV Manager API"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_schema_generation(client):
    """Test OpenAPI schema is properly generated"""
    # First fetch generates the schema and should include the Bearer scheme.
    resp = client.get("/openapi.json")
    assert resp.status_code == 200
    schema = resp.json()
    assert schema["openapi"] == "3.1.0"
    assert "securitySchemes" in schema["components"]
    assert "Bearer" in schema["components"]["securitySchemes"]

    # When get_openapi returns a bare schema, the app must still
    # initialize an empty components/schemas section.
    with patch("app.main.get_openapi", return_value={"info": {}}):
        app.openapi_schema = None  # drop the cached schema
        resp = client.get("/openapi.json")
        assert resp.status_code == 200
        schema = resp.json()
        assert "components" in schema
        assert "schemas" in schema["components"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_schema_caching(mocker):
    """Test OpenAPI schema caching behavior"""
    app.openapi_schema = None  # drop any previously cached schema

    fake_schema = {"test": "schema"}
    mocker.patch("app.main.get_openapi", return_value=fake_schema)

    # First call computes the schema via get_openapi and caches it.
    assert app.openapi() == fake_schema
    assert app.openapi_schema == fake_schema

    # Second call must come from the cache without touching get_openapi.
    with patch("app.main.get_openapi") as mock_get_openapi:
        assert app.openapi() == fake_schema
        mock_get_openapi.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_lifespan_init_db(mocker):
    """The lifespan context manager must initialize the database exactly once."""
    init_db_spy = mocker.patch("app.main.init_db")
    async with lifespan(app):
        pass  # nothing to do inside; only startup behavior is under test
    init_db_spy.assert_called_once()
|
|
||||||
|
def test_router_inclusion():
    """Every expected router prefix must appear among the app's routes."""
    paths = {route.path for route in app.routes}
    assert "/" in paths
    # Each feature router contributes at least one path under its prefix.
    for prefix in ("/auth", "/channels", "/playlist", "/priorities"):
        assert any(path.startswith(prefix) for path in paths)
||||||
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
82
tests/utils/auth_test_fixtures.py
Normal file
82
tests/utils/auth_test_fixtures.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.routers.channels import router as channels_router
|
||||||
|
from app.routers.groups import router as groups_router
|
||||||
|
from app.routers.playlist import router as playlist_router
|
||||||
|
from app.routers.priorities import router as priorities_router
|
||||||
|
from app.utils.database import get_db
|
||||||
|
from tests.utils.db_mocks import (
|
||||||
|
MockBase,
|
||||||
|
MockChannelDB,
|
||||||
|
MockChannelURL,
|
||||||
|
MockPriority,
|
||||||
|
engine_mock,
|
||||||
|
mock_get_db,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import session_mock as TestingSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def mock_get_current_user_admin():
    """Return a canned, enabled CognitoUser carrying the admin role."""
    admin_attrs = {
        "username": "testadmin",
        "email": "testadmin@example.com",
        "roles": ["admin"],
        "user_status": "CONFIRMED",
        "enabled": True,
    }
    return CognitoUser(**admin_attrs)
|
||||||
|
|
||||||
|
def mock_get_current_user_non_admin():
    """Return a canned, enabled CognitoUser with a non-admin role."""
    user_attrs = {
        "username": "testuser",
        "email": "testuser@example.com",
        "roles": ["user"],  # any role other than admin works here
        "user_status": "CONFIRMED",
        "enabled": True,
    }
    return CognitoUser(**user_attrs)
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def db_session():
    """Yield a fresh session backed by tables created just for this test."""
    # Build the schema per test so tests never see each other's rows.
    MockBase.metadata.create_all(bind=engine_mock)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()
        # Tear the schema back down to keep the shared engine clean.
        MockBase.metadata.drop_all(bind=engine_mock)
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def admin_user_client(db_session: Session):
    """Yields a TestClient configured with an admin user."""
    application = FastAPI()
    # Mount every feature router the tests exercise.
    for router in (channels_router, priorities_router, playlist_router, groups_router):
        application.include_router(router)
    # Swap the real DB and auth dependencies for test doubles.
    application.dependency_overrides[get_db] = mock_get_db
    application.dependency_overrides[get_current_user] = mock_get_current_user_admin
    with TestClient(application) as test_client:
        yield test_client
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def non_admin_user_client(db_session: Session):
    """Yields a TestClient configured with a non-admin user."""
    application = FastAPI()
    # Mount every feature router the tests exercise.
    for router in (channels_router, priorities_router, playlist_router, groups_router):
        application.include_router(router)
    # Swap the real DB and auth dependencies for test doubles.
    application.dependency_overrides[get_db] = mock_get_db
    application.dependency_overrides[get_current_user] = mock_get_current_user_non_admin
    with TestClient(application) as test_client:
        yield test_client
||||||
149
tests/utils/db_mocks.py
Normal file
149
tests/utils/db_mocks.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import (
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
UniqueConstraint,
|
||||||
|
create_engine,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
|
||||||
|
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
|
||||||
|
|
||||||
|
# Create a mock-specific Base class for testing
# A separate declarative base keeps these mock tables out of the
# application's real metadata, so init_db never touches them.
MockBase = declarative_base()
||||||
|
|
||||||
|
|
||||||
|
# Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase):
    """Test mirror of the app's priority model (id + human description)."""

    __tablename__ = "priorities"
    # Plain integer PK: tests assign small fixed priority ids explicitly.
    id = Column(Integer, primary_key=True)
    description = Column(String, nullable=False)
|
||||||
|
|
||||||
|
class MockGroup(MockBase):
    """Test mirror of the app's group model: a named, orderable channel container."""

    __tablename__ = "groups"
    id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
    name = Column(String, nullable=False, unique=True)
    sort_order = Column(Integer, nullable=False, default=0)
    # Timestamps use timezone-aware UTC; the lambdas defer evaluation to
    # insert/update time rather than module import time.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
    channels = relationship("MockChannelDB", back_populates="group")
||||||
|
|
||||||
|
|
||||||
|
class MockChannelDB(MockBase):
    """Test mirror of the app's channel model; unique per (group_id, name)."""

    __tablename__ = "channels"
    id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
    tvg_name = Column(String)
    # A channel name must be unique within its group.
    __table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
    tvg_logo = Column(String)
    # Timezone-aware UTC timestamps, evaluated lazily at insert/update time.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
    group = relationship("MockGroup", back_populates="channels")
    # Deleting a channel removes its URLs (delete-orphan cascade).
    urls = relationship(
        "MockChannelURL", back_populates="channel", cascade="all, delete-orphan"
    )
||||||
|
|
||||||
|
|
||||||
|
class MockChannelURL(MockBase):
    """Test mirror of a channel stream URL with an in-use flag and priority."""

    __tablename__ = "channels_urls"
    id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
    # ondelete="CASCADE" mirrors the DB-level cascade on the real table.
    channel_id = Column(
        UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
    )
    url = Column(String, nullable=False)
    in_use = Column(Boolean, default=False, nullable=False)
    priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
    # Timezone-aware UTC timestamps, evaluated lazily at insert/update time.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
    channel = relationship("MockChannelDB", back_populates="urls")
||||||
|
|
||||||
|
|
||||||
|
def create_mock_priorities_and_group(db_session, priorities, group_name):
    """Create mock priorities and group for testing purposes.

    Args:
        db_session: SQLAlchemy session object
        priorities: List of (id, description) tuples for priorities to create
        group_name: Name for the new mock group

    Returns:
        UUID: The ID of the created group
    """
    new_group = MockGroup(name=group_name)
    # Priorities first, then the group — same insert order as a single add_all.
    rows = [
        MockPriority(id=priority_id, description=description)
        for priority_id, description in priorities
    ]
    rows.append(new_group)

    db_session.add_all(rows)
    db_session.commit()

    return new_group.id
||||||
|
|
||||||
|
|
||||||
|
# Create test engine
# In-memory SQLite shared by all sessions: StaticPool pins one connection
# (otherwise each connection would see a different empty in-memory DB) and
# check_same_thread=False lets TestClient worker threads reuse it.
engine_mock = create_engine(
    "sqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)

# Create test session
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
|
|
||||||
|
|
||||||
|
# Mock the actual database functions
def mock_get_db():
    """Dependency override: yield a session bound to the in-memory engine."""
    session = session_mock()
    try:
        yield session
    finally:
        # Always release the session, even if the request handler raised.
        session.close()
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Fixture for mocking environment variables (applied to every test)."""
    test_env = {
        "MOCK_AUTH": "true",
        "DB_USER": "testuser",
        "DB_PASSWORD": "testpass",
        "DB_HOST": "localhost",
        "DB_NAME": "testdb",
        "AWS_REGION": "us-east-1",
    }
    for key, value in test_env.items():
        monkeypatch.setenv(key, value)
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_ssm():
    """Fixture for mocking the boto3 SSM client."""
    with patch("boto3.client") as client_factory:
        ssm_stub = MagicMock()
        # Every parameter lookup resolves to the same canned value.
        ssm_stub.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
        client_factory.return_value = ssm_stub
        yield ssm_stub
|
||||||
309
tests/utils/test_check_streams.py
Normal file
309
tests/utils/test_check_streams.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, Mock, mock_open, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||||
|
|
||||||
|
from app.utils.check_streams import StreamValidator, main
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def validator():
    """Create a StreamValidator instance for testing.

    Uses a 1-second timeout so network-related code paths fail fast.
    """
    return StreamValidator(timeout=1)
||||||
|
|
||||||
|
|
||||||
|
def test_validator_init():
    """Test StreamValidator initialization with default and custom values"""
    # Defaults: 10-second timeout and a browser-like user agent.
    default_validator = StreamValidator()
    assert default_validator.timeout == 10
    assert "Mozilla" in default_validator.session.headers["User-Agent"]

    # Custom timeout and user agent must be honored verbatim.
    agent = "CustomAgent/1.0"
    custom_validator = StreamValidator(timeout=5, user_agent=agent)
    assert custom_validator.timeout == 5
    assert custom_validator.session.headers["User-Agent"] == agent
|
|
||||||
|
|
||||||
|
def test_is_valid_content_type(validator):
    """Test content type validation"""
    accepted = [
        "video/mp4",
        "video/mp2t",
        "application/vnd.apple.mpegurl",
        "application/dash+xml",
        "video/webm",
        "application/octet-stream",
        "application/x-mpegURL",
        "video/mp4; charset=utf-8",  # Test with additional parameters
    ]
    rejected = [
        "text/html",
        "application/json",
        "image/jpeg",
        "",
    ]

    assert all(validator._is_valid_content_type(ct) for ct in accepted)
    assert not any(validator._is_valid_content_type(ct) for ct in rejected)

    # A missing Content-Type header (None) is never valid.
    assert not validator._is_valid_content_type(None)
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "status_code,content_type,should_succeed",
    [
        (200, "video/mp4", True),
        (206, "video/mp4", True),  # Partial content
        (404, "video/mp4", False),
        (500, "video/mp4", False),
        (200, "text/html", False),
        (200, "application/json", False),
    ],
)
def test_validate_stream_response_handling(status_code, content_type, should_succeed):
    """Test stream validation with different response scenarios"""
    fake_response = MagicMock()
    fake_response.status_code = status_code
    fake_response.headers = {"Content-Type": content_type}
    fake_response.iter_content.return_value = iter([b"some content"])

    fake_session = MagicMock()
    # validate_stream uses the response as a context manager.
    fake_session.get.return_value.__enter__.return_value = fake_response

    with patch("requests.Session", return_value=fake_session):
        ok, _message = StreamValidator().validate_stream("http://example.com/stream")

    assert ok == should_succeed
    fake_session.get.assert_called_once()
|
|
||||||
|
|
||||||
|
def test_validate_stream_connection_error():
    """Test stream validation with connection error"""
    fake_session = MagicMock()
    fake_session.get.side_effect = ConnectionError("Connection failed")

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    # A connection failure reports invalid with a descriptive message.
    assert not ok
    assert "Connection Error" in message
|
|
||||||
|
|
||||||
|
def test_validate_stream_timeout():
    """Test stream validation with timeout"""
    fake_session = MagicMock()
    fake_session.get.side_effect = Timeout("Request timed out")

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    # A timeout reports invalid and mentions the timeout in the message.
    assert not ok
    assert "timeout" in message.lower()
|
|
||||||
|
|
||||||
|
def test_validate_stream_http_error():
    """Test stream validation with HTTP error"""
    fake_session = MagicMock()
    fake_session.get.side_effect = HTTPError("HTTP Error occurred")

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    # An HTTP error reports invalid with a matching message.
    assert not ok
    assert "HTTP Error" in message
|
|
||||||
|
|
||||||
|
def test_validate_stream_request_exception():
    """Test stream validation with general request exception"""
    fake_session = MagicMock()
    fake_session.get.side_effect = RequestException("Request failed")

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    # A generic requests failure reports invalid with a matching message.
    assert not ok
    assert "Request Exception" in message
|
|
||||||
|
|
||||||
|
def test_validate_stream_content_read_error():
    """Test stream validation when content reading fails"""
    # Headers and status look fine; the failure happens while reading the body.
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.headers = {"Content-Type": "video/mp4"}
    fake_response.iter_content.side_effect = ConnectionError("Read failed")

    fake_session = MagicMock()
    fake_session.get.return_value.__enter__.return_value = fake_response

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "Connection failed during content read" in message
|
|
||||||
|
|
||||||
|
def test_validate_stream_general_exception():
    """Test validate_stream with an unexpected exception"""
    fake_session = MagicMock()
    fake_session.get.side_effect = Exception("Unexpected error")

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    # Even unexpected errors are caught and reported as validation failures.
    assert not ok
    assert "Validation error" in message
|
|
||||||
|
|
||||||
|
def test_parse_playlist(validator, tmp_path):
    """Test playlist file parsing: URLs are kept, EXT tags and blanks skipped."""
    playlist_content = """
#EXTM3U
#EXTINF:-1,Channel 1
http://example.com/stream1
#EXTINF:-1,Channel 2
http://example.com/stream2

http://example.com/stream3
"""
    target = tmp_path / "test_playlist.m3u"
    target.write_text(playlist_content)

    parsed = validator.parse_playlist(str(target))

    expected = [
        "http://example.com/stream1",
        "http://example.com/stream2",
        "http://example.com/stream3",
    ]
    assert len(parsed) == 3
    assert parsed == expected
|
|
||||||
|
|
||||||
|
def test_parse_playlist_error(validator):
    """Test playlist parsing with non-existent file"""
    # A missing file must propagate an exception rather than return [].
    missing_path = "nonexistent_file.m3u"
    with pytest.raises(Exception):
        validator.parse_playlist(missing_path)
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_urls(mock_validator_class, mock_logging, tmp_path, capsys):
    """Test main function with direct URLs"""
    checker = Mock()
    checker.validate_stream.return_value = (True, "Stream is valid")
    mock_validator_class.return_value = checker

    urls = ["http://example.com/stream1", "http://example.com/stream2"]
    with patch("sys.argv", ["script"] + urls):
        main()

    # Each URL from argv must be validated exactly once.
    assert checker.validate_stream.call_count == len(urls)
    for url in urls:
        checker.validate_stream.assert_any_call(url)
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_playlist(mock_validator_class, mock_logging, tmp_path):
    """Test main function with a playlist file"""
    # Write a real playlist file so main() sees an existing path.
    playlist_file = tmp_path / "test.m3u"
    playlist_file.write_text("http://example.com/stream1\nhttp://example.com/stream2")

    checker = Mock()
    checker.parse_playlist.return_value = [
        "http://example.com/stream1",
        "http://example.com/stream2",
    ]
    checker.validate_stream.return_value = (True, "Stream is valid")
    mock_validator_class.return_value = checker

    with patch("sys.argv", ["script", str(playlist_file)]):
        main()

    # The playlist is parsed once and each parsed URL is validated.
    checker.parse_playlist.assert_called_once_with(str(playlist_file))
    assert checker.validate_stream.call_count == 2
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_dead_streams(mock_validator_class, mock_logging, tmp_path):
    """Test main function handling dead streams"""
    checker = Mock()
    checker.validate_stream.return_value = (False, "Stream is dead")
    mock_validator_class.return_value = checker

    dead_urls = ["http://example.com/dead1", "http://example.com/dead2"]
    file_mock = mock_open()
    with patch("sys.argv", ["script"] + dead_urls), patch("builtins.open", file_mock):
        main()

    # All dead URLs are written newline-separated in a single write call.
    file_mock().write.assert_called_once_with("\n".join(dead_urls))
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
@patch("os.path.isfile")
def test_main_with_playlist_error(
    mock_isfile, mock_validator_class, mock_logging, tmp_path
):
    """Test main function handling playlist parsing errors"""
    checker = Mock()
    mock_validator_class.return_value = checker

    # First playlist blows up; the second yields a single stream URL.
    checker.parse_playlist.side_effect = [
        Exception("Failed to parse playlist"),
        ["http://example.com/stream1"],
    ]
    checker.validate_stream.return_value = (True, "Stream is valid")

    # Only our two fake paths are treated as existing files.
    known_files = {"/invalid.m3u", "/valid.m3u"}
    mock_isfile.side_effect = lambda path: path in known_files

    with patch("sys.argv", ["script", "/invalid.m3u", "/valid.m3u"]):
        main()

    # The failure must be logged with the file name and reason ...
    mock_logging.error.assert_called_with(
        "Failed to process file /invalid.m3u: Failed to parse playlist"
    )

    # ... and processing must continue with the valid playlist.
    checker.parse_playlist.assert_called_with("/valid.m3u")
    assert (
        checker.validate_stream.call_count == 1
    )  # only the URL from the valid playlist
||||||
69
tests/utils/test_database.py
Normal file
69
tests/utils/test_database.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.database import get_db, get_db_credentials
|
||||||
|
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_db_credentials_env(mock_env):
    """Test getting DB credentials from environment variables"""
    # mock_env seeds DB_USER/DB_PASSWORD/DB_HOST/DB_NAME; the connection
    # string must be assembled from exactly those values.
    expected = "postgresql://testuser:testpass@localhost/testdb"
    assert get_db_credentials() == expected
|
|
||||||
|
|
||||||
|
def test_get_db_credentials_ssm(mock_ssm, monkeypatch):
    """Test getting DB credentials from SSM.

    Fix: the original removed MOCK_AUTH with os.environ.pop, permanently
    mutating the process environment; monkeypatch.delenv is undone
    automatically after the test, so later tests are not affected.
    """
    # Without MOCK_AUTH the code must fall back to fetching from SSM.
    monkeypatch.delenv("MOCK_AUTH", raising=False)

    conn_str = get_db_credentials()

    expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
    assert expected_conn in conn_str
    mock_ssm.get_parameter.assert_called()
|
|
||||||
|
|
||||||
|
def test_get_db_credentials_ssm_exception(mock_ssm, monkeypatch):
    """Test SSM credential fetching failure raises RuntimeError.

    Fix: use monkeypatch.delenv instead of os.environ.pop so the environment
    change is rolled back after the test; pytest.raises(match=...) folds the
    message check into the raise assertion.
    """
    monkeypatch.delenv("MOCK_AUTH", raising=False)
    mock_ssm.get_parameter.side_effect = Exception("SSM timeout")

    with pytest.raises(
        RuntimeError, match="Failed to fetch DB credentials from SSM: SSM timeout"
    ):
        get_db_credentials()
|
|
||||||
|
|
||||||
|
def test_session_creation():
    """Test database session creation"""
    session = session_mock()
    try:
        # The factory must hand back a real SQLAlchemy Session.
        assert isinstance(session, Session)
    finally:
        session.close()
|
|
||||||
|
|
||||||
|
def test_get_db_generator():
    """Test get_db dependency generator.

    Fix: the original wrapped the second next() in try/except StopIteration
    with a bare pass, which passed silently even if the generator did NOT
    stop. pytest.raises makes exhaustion an explicit assertion.
    """
    db_gen = get_db()
    db = next(db_gen)
    assert isinstance(db, Session)
    # The generator must finish (closing the session) on the second advance.
    with pytest.raises(StopIteration):
        next(db_gen)
|
|
||||||
|
|
||||||
|
def test_init_db(mocker, mock_env):
    """Test database initialization creates tables"""
    create_all_spy = mocker.patch("app.models.Base.metadata.create_all")

    # Point credential lookup at an in-memory SQLite URL so no real DB is hit.
    mocker.patch(
        "app.utils.database.get_db_credentials",
        return_value="sqlite:///:memory:",
    )

    # Import after patching so init_db sees the mocked credentials.
    from app.utils.database import engine, init_db

    init_db()

    # Tables must be created against the module-level engine.
    create_all_spy.assert_called_once_with(bind=engine)
Reference in New Issue
Block a user