Compare commits
15 Commits
6d506122d9
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| a42d4c30a6 | |||
| abb467749b | |||
| b8ac25e301 | |||
| 729eabf27f | |||
| 34c446bcfa | |||
| d4cc74ea8c | |||
| 21b73b6843 | |||
| e743daf9f7 | |||
| b0d98551b8 | |||
| eaab1ef998 | |||
| e25f8c1ecd | |||
| 95bf0f9701 | |||
| f7a1c20066 | |||
| bf6f156fec | |||
| 7e25ec6755 |
@@ -1,10 +1,15 @@
|
|||||||
|
|
||||||
|
# Environment variables
|
||||||
|
# Scheduler configuration
|
||||||
|
STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax)
|
||||||
|
STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all)
|
||||||
|
|
||||||
# For use with Docker Compose to run application locally
|
# For use with Docker Compose to run application locally
|
||||||
MOCK_AUTH=true/false
|
MOCK_AUTH=true/false
|
||||||
DB_USER=MyDBUser
|
DB_USER=MyDBUser
|
||||||
DB_PASSWORD=MyDBPassword
|
DB_PASSWORD=MyDBPassword
|
||||||
DB_HOST=MyDBHost
|
DB_HOST=MyDBHost
|
||||||
DB_NAME=iptv_updater
|
DB_NAME=iptv_manager
|
||||||
|
|
||||||
FREEDNS_User=MyFreeDNSUsername
|
FREEDNS_User=MyFreeDNSUsername
|
||||||
FREEDNS_Password=MyFreeDNSPassword
|
FREEDNS_Password=MyFreeDNSPassword
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||||
--region us-east-2 \
|
--region us-east-2 \
|
||||||
--filters "Name=tag:Name,Values=IptvUpdaterStack/IptvUpdaterInstance" \
|
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||||
"Name=instance-state-name,Values=running" \
|
"Name=instance-state-name,Values=running" \
|
||||||
--query "Reservations[].Instances[].InstanceId" \
|
--query "Reservations[].Instances[].InstanceId" \
|
||||||
--output text)
|
--output text)
|
||||||
@@ -69,11 +69,11 @@ jobs:
|
|||||||
--instance-ids "$INSTANCE_ID" \
|
--instance-ids "$INSTANCE_ID" \
|
||||||
--document-name "AWS-RunShellScript" \
|
--document-name "AWS-RunShellScript" \
|
||||||
--parameters 'commands=[
|
--parameters 'commands=[
|
||||||
"cd /home/ec2-user/iptv-updater-aws",
|
"cd /home/ec2-user/iptv-manager-service",
|
||||||
"git pull",
|
"git pull",
|
||||||
"pip3 install -r requirements.txt",
|
"pip3 install -r requirements.txt",
|
||||||
"alembic upgrade head",
|
"alembic upgrade head",
|
||||||
"sudo systemctl restart iptv-updater"
|
"sudo systemctl restart iptv-manager"
|
||||||
]'
|
]'
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,16 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.11.2
|
rev: v0.11.12
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix, --exit-non-zero-on-fix]
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: pytest-check
|
||||||
|
name: pytest-check
|
||||||
|
entry: pytest
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
always_run: true
|
||||||
19
.vscode/settings.json
vendored
19
.vscode/settings.json
vendored
@@ -1,4 +1,6 @@
|
|||||||
{
|
{
|
||||||
|
"python.terminal.activateEnvironment": true,
|
||||||
|
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||||
"editor.formatOnSave": true,
|
"editor.formatOnSave": true,
|
||||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||||
"ruff.importStrategy": "fromEnvironment",
|
"ruff.importStrategy": "fromEnvironment",
|
||||||
@@ -7,14 +9,18 @@
|
|||||||
"addopts",
|
"addopts",
|
||||||
"adminpassword",
|
"adminpassword",
|
||||||
"altinstall",
|
"altinstall",
|
||||||
|
"apscheduler",
|
||||||
"asyncio",
|
"asyncio",
|
||||||
"autoflush",
|
"autoflush",
|
||||||
|
"autoupdate",
|
||||||
"autouse",
|
"autouse",
|
||||||
|
"awscli",
|
||||||
"awscliv",
|
"awscliv",
|
||||||
"boto",
|
"boto",
|
||||||
"botocore",
|
"botocore",
|
||||||
"BURSTABLE",
|
"BURSTABLE",
|
||||||
"cabletv",
|
"cabletv",
|
||||||
|
"capsys",
|
||||||
"CDUF",
|
"CDUF",
|
||||||
"cduflogo",
|
"cduflogo",
|
||||||
"cdulogo",
|
"cdulogo",
|
||||||
@@ -29,10 +35,14 @@
|
|||||||
"cluflogo",
|
"cluflogo",
|
||||||
"clulogo",
|
"clulogo",
|
||||||
"cpulogo",
|
"cpulogo",
|
||||||
|
"crond",
|
||||||
|
"cronie",
|
||||||
"cuflgo",
|
"cuflgo",
|
||||||
"CUNF",
|
"CUNF",
|
||||||
"cunflogo",
|
"cunflogo",
|
||||||
"cuulogo",
|
"cuulogo",
|
||||||
|
"datname",
|
||||||
|
"deadstreams",
|
||||||
"delenv",
|
"delenv",
|
||||||
"delogo",
|
"delogo",
|
||||||
"devel",
|
"devel",
|
||||||
@@ -40,22 +50,29 @@
|
|||||||
"dmlogo",
|
"dmlogo",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"EXTINF",
|
"EXTINF",
|
||||||
|
"EXTM",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"filterwarnings",
|
"filterwarnings",
|
||||||
"fiorinis",
|
"fiorinis",
|
||||||
"freedns",
|
"freedns",
|
||||||
"fullchain",
|
"fullchain",
|
||||||
"gitea",
|
"gitea",
|
||||||
|
"httpx",
|
||||||
"iptv",
|
"iptv",
|
||||||
"isort",
|
"isort",
|
||||||
"KHTML",
|
"KHTML",
|
||||||
"lclogo",
|
"lclogo",
|
||||||
"LETSENCRYPT",
|
"LETSENCRYPT",
|
||||||
|
"levelname",
|
||||||
|
"mpegurl",
|
||||||
"nohup",
|
"nohup",
|
||||||
|
"nopriority",
|
||||||
"ondelete",
|
"ondelete",
|
||||||
"onupdate",
|
"onupdate",
|
||||||
"passlib",
|
"passlib",
|
||||||
|
"PGPASSWORD",
|
||||||
"poolclass",
|
"poolclass",
|
||||||
|
"psql",
|
||||||
"psycopg",
|
"psycopg",
|
||||||
"pycache",
|
"pycache",
|
||||||
"pycodestyle",
|
"pycodestyle",
|
||||||
@@ -70,12 +87,14 @@
|
|||||||
"ruru",
|
"ruru",
|
||||||
"sessionmaker",
|
"sessionmaker",
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
|
"sqliteuuid",
|
||||||
"starlette",
|
"starlette",
|
||||||
"stefano",
|
"stefano",
|
||||||
"testadmin",
|
"testadmin",
|
||||||
"testdb",
|
"testdb",
|
||||||
"testpass",
|
"testpass",
|
||||||
"testpaths",
|
"testpaths",
|
||||||
|
"testuser",
|
||||||
"uflogo",
|
"uflogo",
|
||||||
"umlogo",
|
"umlogo",
|
||||||
"usefixtures",
|
"usefixtures",
|
||||||
|
|||||||
226
README.md
226
README.md
@@ -1,165 +1,149 @@
|
|||||||
# IPTV Updater AWS
|
# IPTV Manager Service
|
||||||
|
|
||||||
An automated IPTV playlist and EPG updater service deployed on AWS infrastructure using CDK.
|
A FastAPI-based service for managing IPTV playlists and channel priorities. The application provides secure endpoints for user authentication, channel management, and playlist generation.
|
||||||
|
|
||||||
## Overview
|
## ✨ Features
|
||||||
|
|
||||||
This project provides a service for automatically updating IPTV playlists and Electronic Program Guide (EPG) data. It runs on AWS infrastructure with:
|
- **JWT Authentication**: Secure login using AWS Cognito
|
||||||
|
- **Channel Management**: CRUD operations for IPTV channels
|
||||||
|
- **Playlist Generation**: Create M3U playlists with channel priorities
|
||||||
|
- **Stream Monitoring**: Background checks for channel availability
|
||||||
|
- **Priority Management**: Set channel priorities for playlist ordering
|
||||||
|
|
||||||
- EC2 instance for hosting the application
|
## 🛠️ Technology Stack
|
||||||
- RDS PostgreSQL database for data storage
|
|
||||||
- Amazon Cognito for user authentication
|
|
||||||
- HTTPS support via Let's Encrypt
|
|
||||||
- Domain management via FreeDNS
|
|
||||||
|
|
||||||
## Prerequisites
|
- **Backend**: Python 3.11, FastAPI
|
||||||
|
- **Database**: PostgreSQL (SQLAlchemy ORM)
|
||||||
|
- **Authentication**: AWS Cognito
|
||||||
|
- **Infrastructure**: AWS CDK (API Gateway, Lambda, RDS)
|
||||||
|
- **Testing**: Pytest with 85%+ coverage
|
||||||
|
- **CI/CD**: Pre-commit hooks, Alembic migrations
|
||||||
|
|
||||||
- AWS CLI installed and configured
|
## 🚀 Getting Started
|
||||||
- Python 3.12 or later
|
|
||||||
- Node.js v22.15 or later for AWS CDK
|
|
||||||
- Docker and Docker Compose for local development
|
|
||||||
|
|
||||||
## Local Development
|
### Prerequisites
|
||||||
|
|
||||||
1. Clone the repository:
|
- Python 3.11+
|
||||||
|
- Docker
|
||||||
|
- AWS CLI (for deployment)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone <repo-url>
|
# Clone repository
|
||||||
cd iptv-updater-aws
|
git clone https://github.com/your-repo/iptv-manager-service.git
|
||||||
|
cd iptv-manager-service
|
||||||
|
|
||||||
|
# Setup environment
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
cp .env.example .env # Update with your values
|
||||||
|
|
||||||
|
# Run installation script
|
||||||
|
./scripts/install.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Copy the example environment file:
|
### Running Locally
|
||||||
|
|
||||||
```bash
|
|
||||||
cp .env.example .env
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Add your configuration to `.env`:
|
|
||||||
|
|
||||||
```
|
|
||||||
FREEDNS_User=your_freedns_username
|
|
||||||
FREEDNS_Password=your_freedns_password
|
|
||||||
DOMAIN_NAME=your.domain.name
|
|
||||||
SSH_PUBLIC_KEY=your_ssh_public_key
|
|
||||||
REPO_URL=repository_url
|
|
||||||
LETSENCRYPT_EMAIL=your_email
|
|
||||||
```
|
|
||||||
|
|
||||||
4. Start the local development environment:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Start development environment
|
||||||
./scripts/start_local_dev.sh
|
./scripts/start_local_dev.sh
|
||||||
```
|
|
||||||
|
|
||||||
5. Stop the local environment:
|
# Stop development environment
|
||||||
|
|
||||||
```bash
|
|
||||||
./scripts/stop_local_dev.sh
|
./scripts/stop_local_dev.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
## Deployment
|
## ☁️ AWS Deployment
|
||||||
|
|
||||||
### Initial Deployment
|
The infrastructure is defined in CDK. Use the provided scripts:
|
||||||
|
|
||||||
1. Ensure your AWS credentials are configured:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
aws configure
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Install dependencies:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Deploy the infrastructure:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Deploy AWS infrastructure
|
||||||
./scripts/deploy.sh
|
./scripts/deploy.sh
|
||||||
|
|
||||||
|
# Destroy AWS infrastructure
|
||||||
|
./scripts/destroy.sh
|
||||||
|
|
||||||
|
# Create Cognito test user
|
||||||
|
./scripts/create_cognito_user.sh
|
||||||
|
|
||||||
|
# Delete Cognito user
|
||||||
|
./scripts/delete_cognito_user.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
The deployment script will:
|
Key AWS components:
|
||||||
|
|
||||||
- Create/update the CloudFormation stack using CDK
|
- API Gateway
|
||||||
- Configure the EC2 instance with required software
|
- Lambda functions
|
||||||
- Set up HTTPS using Let's Encrypt
|
- RDS PostgreSQL
|
||||||
- Configure the domain using FreeDNS
|
- Cognito User Pool
|
||||||
|
|
||||||
### Continuous Deployment
|
## 🤖 Continuous Integration/Deployment
|
||||||
|
|
||||||
The project includes a Gitea workflow (`.gitea/workflows/aws_deploy_on_push.yml`) that automatically:
|
This project includes a Gitea Actions workflow (`.gitea/workflows/deploy.yml`) for automated deployment to AWS. The workflow is fully compatible with GitHub Actions and can be easily adapted by:
|
||||||
|
|
||||||
- Deploys infrastructure changes
|
1. Placing the workflow file in the `.github/workflows/` directory
|
||||||
- Updates the application on EC2 instances
|
2. Setting up the required secrets in your CI/CD environment:
|
||||||
- Restarts the service
|
- `AWS_ACCESS_KEY_ID`
|
||||||
|
- `AWS_SECRET_ACCESS_KEY`
|
||||||
|
- `AWS_DEFAULT_REGION`
|
||||||
|
|
||||||
## Infrastructure
|
The workflow automatically deploys the infrastructure and application when changes are pushed to the main branch.
|
||||||
|
|
||||||
The AWS infrastructure is defined in `infrastructure/stack.py` and includes:
|
## 📚 API Documentation
|
||||||
|
|
||||||
- VPC with public subnets
|
Access interactive docs at:
|
||||||
- EC2 t2.micro instance (Free Tier eligible)
|
|
||||||
- RDS PostgreSQL database (db.t3.micro)
|
|
||||||
- Security groups for EC2 and RDS
|
|
||||||
- Elastic IP for the EC2 instance
|
|
||||||
- Cognito User Pool for authentication
|
|
||||||
- IAM roles and policies for EC2 instance access
|
|
||||||
|
|
||||||
## User Management
|
- Swagger UI: `http://localhost:8000/docs`
|
||||||
|
- ReDoc: `http://localhost:8000/redoc`
|
||||||
|
|
||||||
### Creating Users
|
### Key Endpoints
|
||||||
|
|
||||||
To create a new user in Cognito:
|
| Endpoint | Method | Description |
|
||||||
|
| ------------- | ------ | --------------------- |
|
||||||
|
| `/auth/login` | POST | User authentication |
|
||||||
|
| `/channels` | GET | List all channels |
|
||||||
|
| `/playlist` | GET | Generate M3U playlist |
|
||||||
|
| `/priorities` | POST | Set channel priority |
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
Run the full test suite:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/create_cognito_user.sh <user_pool_id> <username> <password> --admin <= optional for defining an admin user
|
pytest
|
||||||
```
|
```
|
||||||
|
|
||||||
### Deleting Users
|
Test coverage includes:
|
||||||
|
|
||||||
To delete a user from Cognito:
|
- Authentication workflows
|
||||||
|
- Channel CRUD operations
|
||||||
|
- Playlist generation logic
|
||||||
|
- Stream monitoring
|
||||||
|
- Database operations
|
||||||
|
|
||||||
```bash
|
## 📂 Project Structure
|
||||||
./scripts/delete_cognito_user.sh <user_pool_id> <username>
|
|
||||||
|
```txt
|
||||||
|
iptv-manager-service/
|
||||||
|
├── app/ # Core application
|
||||||
|
│ ├── auth/ # Cognito authentication
|
||||||
|
│ ├── iptv/ # Playlist logic
|
||||||
|
│ ├── models/ # Database models
|
||||||
|
│ ├── routers/ # API endpoints
|
||||||
|
│ ├── utils/ # Helper functions
|
||||||
|
│ └── main.py # App entry point
|
||||||
|
├── infrastructure/ # AWS CDK stack
|
||||||
|
├── docker/ # Docker configs
|
||||||
|
├── scripts/ # Deployment scripts
|
||||||
|
├── tests/ # Comprehensive tests
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
├── .gitea/ # Gitea CI/CD workflows
|
||||||
|
│ └── workflows/
|
||||||
|
└── ... # Config files
|
||||||
```
|
```
|
||||||
|
|
||||||
## Architecture
|
## 📝 License
|
||||||
|
|
||||||
The application is structured as follows:
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||||
|
|
||||||
```bash
|
|
||||||
app/
|
|
||||||
├── auth/ # Authentication modules
|
|
||||||
├── iptv/ # IPTV and EPG processing
|
|
||||||
├── models/ # Database models
|
|
||||||
└── utils/ # Utility functions
|
|
||||||
|
|
||||||
infrastructure/ # AWS CDK infrastructure code
|
|
||||||
docker/ # Docker configuration for local development
|
|
||||||
scripts/ # Utility scripts for deployment and management
|
|
||||||
```
|
|
||||||
|
|
||||||
## Environment Variables
|
|
||||||
|
|
||||||
The following environment variables are required:
|
|
||||||
|
|
||||||
| Variable | Description |
|
|
||||||
|----------|-------------|
|
|
||||||
| FREEDNS_User | FreeDNS username |
|
|
||||||
| FREEDNS_Password | FreeDNS password |
|
|
||||||
| DOMAIN_NAME | Your domain name |
|
|
||||||
| SSH_PUBLIC_KEY | SSH public key for EC2 access |
|
|
||||||
| REPO_URL | Repository URL |
|
|
||||||
| LETSENCRYPT_EMAIL | Email for Let's Encrypt certificates |
|
|
||||||
|
|
||||||
## Security Notes
|
|
||||||
|
|
||||||
- The EC2 instance has appropriate IAM permissions for:
|
|
||||||
- EC2 instance discovery
|
|
||||||
- SSM command execution
|
|
||||||
- RDS access
|
|
||||||
- Cognito user management
|
|
||||||
- All database credentials are stored in AWS Secrets Manager
|
|
||||||
- HTTPS is enforced using Let's Encrypt certificates
|
|
||||||
- Access is restricted through Security Groups
|
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ config = context.config
|
|||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
fileConfig(config.config_file_name)
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
# Setup target metadata for autogenerate support
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
target_metadata = Base.metadata
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
# Override sqlalchemy.url with dynamic credentials
|
# Override sqlalchemy.url with dynamic credentials
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
"""Add priority and in_use fields
|
|
||||||
|
|
||||||
Revision ID: 036879e47172
|
|
||||||
Revises:
|
|
||||||
Create Date: 2025-05-26 19:21:32.285656
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '036879e47172'
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
# 1. Create priorities table if not exists
|
|
||||||
if not op.get_bind().engine.dialect.has_table(op.get_bind(), 'priorities'):
|
|
||||||
op.create_table('priorities',
|
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('description', sa.String(), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Insert default priorities (skip if already exists)
|
|
||||||
op.execute("""
|
|
||||||
INSERT INTO priorities (id, description)
|
|
||||||
VALUES (100, 'High'), (200, 'Medium'), (300, 'Low')
|
|
||||||
ON CONFLICT (id) DO NOTHING
|
|
||||||
""")
|
|
||||||
# Add new columns with temporary nullable=True
|
|
||||||
op.add_column('channels_urls', sa.Column('in_use', sa.Boolean(), nullable=True))
|
|
||||||
op.add_column('channels_urls', sa.Column('priority_id', sa.Integer(), nullable=True))
|
|
||||||
|
|
||||||
# Set default values
|
|
||||||
op.execute("UPDATE channels_urls SET in_use = false, priority_id = 100")
|
|
||||||
|
|
||||||
# Convert to NOT NULL
|
|
||||||
op.alter_column('channels_urls', 'in_use', nullable=False)
|
|
||||||
op.alter_column('channels_urls', 'priority_id', nullable=False)
|
|
||||||
op.create_foreign_key(None, 'channels_urls', 'priorities', ['priority_id'], ['id'])
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_constraint('channels_urls_priority_id_fkey', 'channels_urls', type_='foreignkey')
|
|
||||||
op.drop_column('channels_urls', 'priority_id')
|
|
||||||
op.drop_column('channels_urls', 'in_use')
|
|
||||||
op.drop_table('priorities')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
"""add groups table and migrate group_title data
|
||||||
|
|
||||||
|
Revision ID: 0a455608256f
|
||||||
|
Revises: 95b61a92455a
|
||||||
|
Create Date: 2025-06-10 09:22:11.820035
|
||||||
|
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '0a455608256f'
|
||||||
|
down_revision: Union[str, None] = '95b61a92455a'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('groups',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('name')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create temporary table for group mapping
|
||||||
|
group_mapping = op.create_table(
|
||||||
|
'group_mapping',
|
||||||
|
sa.Column('group_title', sa.String(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.UUID(), nullable=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get existing group titles and create groups
|
||||||
|
conn = op.get_bind()
|
||||||
|
distinct_groups = conn.execute(
|
||||||
|
sa.text("SELECT DISTINCT group_title FROM channels")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
for group in distinct_groups:
|
||||||
|
group_title = group[0]
|
||||||
|
group_id = str(uuid.uuid4())
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"INSERT INTO groups (id, name, sort_order) "
|
||||||
|
"VALUES (:id, :name, 0)"
|
||||||
|
).bindparams(id=group_id, name=group_title)
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
group_mapping.insert().values(
|
||||||
|
group_title=group_title,
|
||||||
|
group_id=group_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add group_id column (nullable first)
|
||||||
|
op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True))
|
||||||
|
|
||||||
|
# Update channels with group_ids
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"UPDATE channels c SET group_id = gm.group_id "
|
||||||
|
"FROM group_mapping gm WHERE c.group_title = gm.group_title"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now make group_id non-nullable and add constraints
|
||||||
|
op.alter_column('channels', 'group_id', nullable=False)
|
||||||
|
op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique')
|
||||||
|
op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name'])
|
||||||
|
op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id'])
|
||||||
|
|
||||||
|
# Clean up and drop group_title
|
||||||
|
op.drop_table('group_mapping')
|
||||||
|
op.drop_column('channels', 'group_title')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||||
|
|
||||||
|
# Restore group_title values from groups table
|
||||||
|
conn = op.get_bind()
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"UPDATE channels c SET group_title = g.name "
|
||||||
|
"FROM groups g WHERE c.group_id = g.id"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now make group_title non-nullable
|
||||||
|
op.alter_column('channels', 'group_title', nullable=False)
|
||||||
|
|
||||||
|
# Drop constraints and columns
|
||||||
|
op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey')
|
||||||
|
op.drop_constraint('uix_group_id_name', 'channels', type_='unique')
|
||||||
|
op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name'])
|
||||||
|
op.drop_column('channels', 'group_id')
|
||||||
|
op.drop_table('groups')
|
||||||
|
# ### end Alembic commands ###
|
||||||
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""create initial tables
|
||||||
|
|
||||||
|
Revision ID: 95b61a92455a
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-05-29 14:42:16.239587
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '95b61a92455a'
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('channels',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('tvg_id', sa.String(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('group_title', sa.String(), nullable=False),
|
||||||
|
sa.Column('tvg_name', sa.String(), nullable=True),
|
||||||
|
sa.Column('tvg_logo', sa.String(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('group_title', 'name', name='uix_group_title_name')
|
||||||
|
)
|
||||||
|
op.create_table('priorities',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('description', sa.String(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_table('channels_urls',
|
||||||
|
sa.Column('id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('channel_id', sa.UUID(), nullable=False),
|
||||||
|
sa.Column('url', sa.String(), nullable=False),
|
||||||
|
sa.Column('in_use', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('priority_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['channel_id'], ['channels.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['priority_id'], ['priorities.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
# Seed initial priorities
|
||||||
|
op.bulk_insert(
|
||||||
|
sa.Table(
|
||||||
|
'priorities',
|
||||||
|
sa.MetaData(),
|
||||||
|
sa.Column('id', sa.Integer),
|
||||||
|
sa.Column('description', sa.String),
|
||||||
|
),
|
||||||
|
[
|
||||||
|
{'id': 100, 'description': 'High'},
|
||||||
|
{'id': 200, 'description': 'Medium'},
|
||||||
|
{'id': 300, 'description': 'Low'},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# Remove seeded priorities
|
||||||
|
op.execute("DELETE FROM priorities WHERE id IN (100, 200, 300);")
|
||||||
|
|
||||||
|
# Drop tables
|
||||||
|
op.drop_table('channels_urls')
|
||||||
|
op.drop_table('priorities')
|
||||||
|
op.drop_table('channels')
|
||||||
6
app.py
6
app.py
@@ -3,7 +3,7 @@ import os
|
|||||||
|
|
||||||
import aws_cdk as cdk
|
import aws_cdk as cdk
|
||||||
|
|
||||||
from infrastructure.stack import IptvUpdaterStack
|
from infrastructure.stack import IptvManagerStack
|
||||||
|
|
||||||
app = cdk.App()
|
app = cdk.App()
|
||||||
|
|
||||||
@@ -31,9 +31,9 @@ if missing_vars:
|
|||||||
f"Missing required environment variables: {', '.join(missing_vars)}"
|
f"Missing required environment variables: {', '.join(missing_vars)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
IptvUpdaterStack(
|
IptvManagerStack(
|
||||||
app,
|
app,
|
||||||
"IptvUpdaterStack",
|
"IptvManagerStack",
|
||||||
freedns_user=freedns_user,
|
freedns_user=freedns_user,
|
||||||
freedns_password=freedns_password,
|
freedns_password=freedns_password,
|
||||||
domain_name=domain_name,
|
domain_name=domain_name,
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ def require_roles(*required_roles: str) -> Callable:
|
|||||||
|
|
||||||
def decorator(endpoint: Callable) -> Callable:
|
def decorator(endpoint: Callable) -> Callable:
|
||||||
@wraps(endpoint)
|
@wraps(endpoint)
|
||||||
def wrapper(*args, user: CognitoUser = Depends(get_current_user), **kwargs):
|
async def wrapper(
|
||||||
|
*args, user: CognitoUser = Depends(get_current_user), **kwargs
|
||||||
|
):
|
||||||
user_roles = set(user.roles or [])
|
user_roles = set(user.roles or [])
|
||||||
needed_roles = set(required_roles)
|
needed_roles = set(required_roles)
|
||||||
if not needed_roles.issubset(user_roles):
|
if not needed_roles.issubset(user_roles):
|
||||||
|
|||||||
110
app/iptv/scheduler.py
Normal file
110
app/iptv/scheduler.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.iptv.stream_manager import StreamManager
|
||||||
|
from app.models.db import ChannelDB
|
||||||
|
from app.utils.database import get_db_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamScheduler:
|
||||||
|
"""Scheduler service for periodic stream validation tasks."""
|
||||||
|
|
||||||
|
def __init__(self, app: Optional[FastAPI] = None):
|
||||||
|
"""
|
||||||
|
Initialize the scheduler with optional FastAPI app integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: Optional FastAPI app instance for lifecycle integration
|
||||||
|
"""
|
||||||
|
self.scheduler = BackgroundScheduler()
|
||||||
|
self.app = app
|
||||||
|
self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10"))
|
||||||
|
self.schedule_time = os.getenv(
|
||||||
|
"STREAM_VALIDATION_SCHEDULE", "0 3 * * *"
|
||||||
|
) # Default 3 AM daily
|
||||||
|
logger.info(f"Scheduler initialized with app: {app is not None}")
|
||||||
|
|
||||||
|
def validate_streams_batch(self, db_session: Optional[Session] = None) -> None:
|
||||||
|
"""
|
||||||
|
Validate streams and update their status.
|
||||||
|
When batch_size=0, validates all channels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Optional SQLAlchemy session
|
||||||
|
"""
|
||||||
|
db = db_session if db_session else get_db_session()
|
||||||
|
try:
|
||||||
|
manager = StreamManager(db)
|
||||||
|
|
||||||
|
# Get channels to validate
|
||||||
|
query = db.query(ChannelDB)
|
||||||
|
if self.batch_size > 0:
|
||||||
|
query = query.limit(self.batch_size)
|
||||||
|
channels = query.all()
|
||||||
|
|
||||||
|
for channel in channels:
|
||||||
|
try:
|
||||||
|
logger.info(f"Validating streams for channel {channel.id}")
|
||||||
|
manager.validate_and_select_stream(str(channel.id))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error validating channel {channel.id}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Completed stream validation of {len(channels)} channels")
|
||||||
|
finally:
|
||||||
|
if db_session is None:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""Start the scheduler and add jobs."""
|
||||||
|
if not self.scheduler.running:
|
||||||
|
# Add the scheduled job
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self.validate_streams_batch,
|
||||||
|
trigger=CronTrigger.from_crontab(self.schedule_time),
|
||||||
|
id="daily_stream_validation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the scheduler
|
||||||
|
self.scheduler.start()
|
||||||
|
logger.info(
|
||||||
|
f"Stream scheduler started with daily validation job. "
|
||||||
|
f"Running: {self.scheduler.running}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register shutdown handler if FastAPI app is provided
|
||||||
|
if self.app:
|
||||||
|
logger.info(
|
||||||
|
f"Registering scheduler with FastAPI "
|
||||||
|
f"app: {hasattr(self.app, 'state')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@self.app.on_event("shutdown")
|
||||||
|
def shutdown_scheduler():
|
||||||
|
self.shutdown()
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""Shutdown the scheduler gracefully."""
|
||||||
|
if self.scheduler.running:
|
||||||
|
self.scheduler.shutdown()
|
||||||
|
logger.info("Stream scheduler stopped")
|
||||||
|
|
||||||
|
def trigger_manual_validation(self) -> None:
|
||||||
|
"""Trigger manual validation of streams."""
|
||||||
|
logger.info("Manually triggering stream validation")
|
||||||
|
self.validate_streams_batch()
|
||||||
|
|
||||||
|
|
||||||
|
def init_scheduler(app: FastAPI) -> StreamScheduler:
|
||||||
|
"""Initialize and start the scheduler with FastAPI integration."""
|
||||||
|
scheduler = StreamScheduler(app)
|
||||||
|
scheduler.start()
|
||||||
|
return scheduler
|
||||||
151
app/iptv/stream_manager.py
Normal file
151
app/iptv/stream_manager.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models.db import ChannelURL
|
||||||
|
from app.utils.check_streams import StreamValidator
|
||||||
|
from app.utils.database import get_db_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamManager:
|
||||||
|
"""Service for managing and validating channel streams."""
|
||||||
|
|
||||||
|
def __init__(self, db_session: Optional[Session] = None):
|
||||||
|
"""
|
||||||
|
Initialize StreamManager with optional database session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Optional SQLAlchemy session. If None, will create a new one.
|
||||||
|
"""
|
||||||
|
self.db = db_session if db_session else get_db_session()
|
||||||
|
self.validator = StreamValidator()
|
||||||
|
|
||||||
|
def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]:
|
||||||
|
"""
|
||||||
|
Get all streams for a channel ordered by priority (lowest first),
|
||||||
|
with same-priority streams randomized.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: UUID of the channel to get streams for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ChannelURL objects ordered by priority
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get all streams for channel ordered by priority
|
||||||
|
streams = (
|
||||||
|
self.db.query(ChannelURL)
|
||||||
|
.filter(ChannelURL.channel_id == channel_id)
|
||||||
|
.order_by(ChannelURL.priority_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Group streams by priority and randomize same-priority streams
|
||||||
|
grouped = {}
|
||||||
|
for stream in streams:
|
||||||
|
if stream.priority_id not in grouped:
|
||||||
|
grouped[stream.priority_id] = []
|
||||||
|
grouped[stream.priority_id].append(stream)
|
||||||
|
|
||||||
|
# Randomize same-priority streams and flatten
|
||||||
|
randomized_streams = []
|
||||||
|
for priority in sorted(grouped.keys()):
|
||||||
|
random.shuffle(grouped[priority])
|
||||||
|
randomized_streams.extend(grouped[priority])
|
||||||
|
|
||||||
|
return randomized_streams
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting streams for channel {channel_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def validate_and_select_stream(self, channel_id: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Find and validate a working stream for the given channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: UUID of the channel to find a stream for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
URL of the first working stream found, or None if none found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
streams = self.get_streams_for_channel(channel_id)
|
||||||
|
if not streams:
|
||||||
|
logger.warning(f"No streams found for channel {channel_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
working_stream = None
|
||||||
|
|
||||||
|
for stream in streams:
|
||||||
|
logger.info(f"Validating stream {stream.url} for channel {channel_id}")
|
||||||
|
is_valid, _ = self.validator.validate_stream(stream.url)
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
working_stream = stream
|
||||||
|
break
|
||||||
|
|
||||||
|
if working_stream:
|
||||||
|
self._update_stream_status(working_stream, streams)
|
||||||
|
return working_stream.url
|
||||||
|
else:
|
||||||
|
logger.warning(f"No valid streams found for channel {channel_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error validating streams for channel {channel_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _update_stream_status(
|
||||||
|
self, working_stream: ChannelURL, all_streams: list[ChannelURL]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update in_use status for streams (True for working stream, False for others).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
working_stream: The stream that was validated as working
|
||||||
|
all_streams: All streams for the channel
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for stream in all_streams:
|
||||||
|
stream.in_use = stream.id == working_stream.id
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(
|
||||||
|
f"Updated stream status - set in_use=True for {working_stream.url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
logger.error(f"Error updating stream status: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Close database session when StreamManager is destroyed."""
|
||||||
|
if hasattr(self, "db"):
|
||||||
|
self.db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_working_stream(
|
||||||
|
channel_id: str, db_session: Optional[Session] = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Convenience function to get a working stream for a channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: UUID of the channel to get a stream for
|
||||||
|
db_session: Optional SQLAlchemy session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
URL of the first working stream found, or None if none found
|
||||||
|
"""
|
||||||
|
manager = StreamManager(db_session)
|
||||||
|
try:
|
||||||
|
return manager.validate_and_select_stream(channel_id)
|
||||||
|
finally:
|
||||||
|
if db_session is None: # Only close if we created the session
|
||||||
|
manager.__del__()
|
||||||
20
app/main.py
20
app/main.py
@@ -2,7 +2,8 @@ from fastapi import FastAPI
|
|||||||
from fastapi.concurrency import asynccontextmanager
|
from fastapi.concurrency import asynccontextmanager
|
||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
from app.routers import auth, channels, playlist, priorities
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
from app.routers import auth, channels, groups, playlist, priorities, scheduler
|
||||||
from app.utils.database import init_db
|
from app.utils.database import init_db
|
||||||
|
|
||||||
|
|
||||||
@@ -10,13 +11,22 @@ from app.utils.database import init_db
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Initialize database tables on startup
|
# Initialize database tables on startup
|
||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
|
# Initialize and start the stream scheduler
|
||||||
|
scheduler = StreamScheduler(app)
|
||||||
|
app.state.scheduler = scheduler # Store scheduler in app state
|
||||||
|
scheduler.start()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# Shutdown scheduler on app shutdown
|
||||||
|
scheduler.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
title="IPTV Updater API",
|
title="IPTV Manager API",
|
||||||
description="API for IPTV Updater service",
|
description="API for IPTV Manager service",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,7 +70,7 @@ app.openapi = custom_openapi
|
|||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
return {"message": "IPTV Updater API"}
|
return {"message": "IPTV Manager API"}
|
||||||
|
|
||||||
|
|
||||||
# Include routers
|
# Include routers
|
||||||
@@ -68,3 +78,5 @@ app.include_router(auth.router)
|
|||||||
app.include_router(channels.router)
|
app.include_router(channels.router)
|
||||||
app.include_router(playlist.router)
|
app.include_router(playlist.router)
|
||||||
app.include_router(priorities.router)
|
app.include_router(priorities.router)
|
||||||
|
app.include_router(groups.router)
|
||||||
|
app.include_router(scheduler.router)
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from .db import Base, ChannelDB, ChannelURL
|
from .db import Base, ChannelDB, ChannelURL, Group, Priority
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
ChannelCreate,
|
ChannelCreate,
|
||||||
ChannelResponse,
|
ChannelResponse,
|
||||||
ChannelUpdate,
|
ChannelUpdate,
|
||||||
ChannelURLCreate,
|
ChannelURLCreate,
|
||||||
ChannelURLResponse,
|
ChannelURLResponse,
|
||||||
|
GroupCreate,
|
||||||
|
GroupResponse,
|
||||||
|
GroupUpdate,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -16,4 +19,9 @@ __all__ = [
|
|||||||
"ChannelURL",
|
"ChannelURL",
|
||||||
"ChannelURLCreate",
|
"ChannelURLCreate",
|
||||||
"ChannelURLResponse",
|
"ChannelURLResponse",
|
||||||
|
"Group",
|
||||||
|
"Priority",
|
||||||
|
"GroupCreate",
|
||||||
|
"GroupResponse",
|
||||||
|
"GroupUpdate",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,18 +1,60 @@
|
|||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
TEXT,
|
||||||
Boolean,
|
Boolean,
|
||||||
Column,
|
Column,
|
||||||
DateTime,
|
DateTime,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
String,
|
String,
|
||||||
|
TypeDecorator,
|
||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import declarative_base, relationship
|
from sqlalchemy.orm import declarative_base, relationship
|
||||||
|
|
||||||
|
|
||||||
|
# Custom UUID type for SQLite compatibility
|
||||||
|
class SQLiteUUID(TypeDecorator):
|
||||||
|
"""Enables UUID support for SQLite with proper comparison handling."""
|
||||||
|
|
||||||
|
impl = TEXT
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
if isinstance(value, uuid.UUID):
|
||||||
|
return str(value)
|
||||||
|
try:
|
||||||
|
# Validate string format by attempting to create UUID
|
||||||
|
uuid.UUID(value)
|
||||||
|
return value
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
raise ValueError(f"Invalid UUID string format: {value}")
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
if isinstance(value, uuid.UUID):
|
||||||
|
return value
|
||||||
|
return uuid.UUID(value)
|
||||||
|
|
||||||
|
def compare_values(self, x, y):
|
||||||
|
if x is None or y is None:
|
||||||
|
return x == y
|
||||||
|
return str(x) == str(y)
|
||||||
|
|
||||||
|
|
||||||
|
# Determine which UUID type to use based on environment
|
||||||
|
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||||
|
UUID_COLUMN_TYPE = SQLiteUUID()
|
||||||
|
else:
|
||||||
|
UUID_COLUMN_TYPE = UUID(as_uuid=True)
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
@@ -25,20 +67,37 @@ class Priority(Base):
|
|||||||
description = Column(String, nullable=False)
|
description = Column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Group(Base):
|
||||||
|
"""SQLAlchemy model for channel groups"""
|
||||||
|
|
||||||
|
__tablename__ = "groups"
|
||||||
|
|
||||||
|
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||||
|
name = Column(String, nullable=False, unique=True)
|
||||||
|
sort_order = Column(Integer, nullable=False, default=0)
|
||||||
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at = Column(
|
||||||
|
DateTime,
|
||||||
|
default=lambda: datetime.now(timezone.utc),
|
||||||
|
onupdate=lambda: datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationship with Channel
|
||||||
|
channels = relationship("ChannelDB", back_populates="group")
|
||||||
|
|
||||||
|
|
||||||
class ChannelDB(Base):
|
class ChannelDB(Base):
|
||||||
"""SQLAlchemy model for IPTV channels"""
|
"""SQLAlchemy model for IPTV channels"""
|
||||||
|
|
||||||
__tablename__ = "channels"
|
__tablename__ = "channels"
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||||
tvg_id = Column(String, nullable=False)
|
tvg_id = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
group_title = Column(String, nullable=False)
|
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
|
||||||
tvg_name = Column(String)
|
tvg_name = Column(String)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
|
||||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
|
||||||
)
|
|
||||||
tvg_logo = Column(String)
|
tvg_logo = Column(String)
|
||||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(
|
updated_at = Column(
|
||||||
@@ -47,10 +106,11 @@ class ChannelDB(Base):
|
|||||||
onupdate=lambda: datetime.now(timezone.utc),
|
onupdate=lambda: datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationship with ChannelURL
|
# Relationships
|
||||||
urls = relationship(
|
urls = relationship(
|
||||||
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
group = relationship("Group", back_populates="channels")
|
||||||
|
|
||||||
|
|
||||||
class ChannelURL(Base):
|
class ChannelURL(Base):
|
||||||
@@ -58,9 +118,9 @@ class ChannelURL(Base):
|
|||||||
|
|
||||||
__tablename__ = "channels_urls"
|
__tablename__ = "channels_urls"
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||||
channel_id = Column(
|
channel_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID_COLUMN_TYPE,
|
||||||
ForeignKey("channels.id", ondelete="CASCADE"),
|
ForeignKey("channels.id", ondelete="CASCADE"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -53,12 +53,54 @@ class ChannelURLResponse(ChannelURLBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# New Group Schemas
|
||||||
|
class GroupCreate(BaseModel):
|
||||||
|
"""Pydantic model for creating groups"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
sort_order: int = Field(default=0, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupUpdate(BaseModel):
|
||||||
|
"""Pydantic model for updating groups"""
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
sort_order: Optional[int] = Field(None, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupResponse(BaseModel):
|
||||||
|
"""Pydantic model for group responses"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
name: str
|
||||||
|
sort_order: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupSortUpdate(BaseModel):
|
||||||
|
"""Pydantic model for updating a single group's sort order"""
|
||||||
|
|
||||||
|
sort_order: int = Field(ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupBulkSort(BaseModel):
|
||||||
|
"""Pydantic model for bulk updating group sort orders"""
|
||||||
|
|
||||||
|
groups: list[dict] = Field(
|
||||||
|
description="List of dicts with group_id and new sort_order",
|
||||||
|
json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChannelCreate(BaseModel):
|
class ChannelCreate(BaseModel):
|
||||||
"""Pydantic model for creating channels"""
|
"""Pydantic model for creating channels"""
|
||||||
|
|
||||||
urls: list[ChannelURLCreate] # List of URL objects with priority
|
urls: list[ChannelURLCreate] # List of URL objects with priority
|
||||||
name: str
|
name: str
|
||||||
group_title: str
|
group_id: UUID
|
||||||
tvg_id: str
|
tvg_id: str
|
||||||
tvg_logo: str
|
tvg_logo: str
|
||||||
tvg_name: str
|
tvg_name: str
|
||||||
@@ -76,7 +118,7 @@ class ChannelUpdate(BaseModel):
|
|||||||
"""Pydantic model for updating channels (all fields optional)"""
|
"""Pydantic model for updating channels (all fields optional)"""
|
||||||
|
|
||||||
name: Optional[str] = Field(None, min_length=1)
|
name: Optional[str] = Field(None, min_length=1)
|
||||||
group_title: Optional[str] = Field(None, min_length=1)
|
group_id: Optional[UUID] = None
|
||||||
tvg_id: Optional[str] = Field(None, min_length=1)
|
tvg_id: Optional[str] = Field(None, min_length=1)
|
||||||
tvg_logo: Optional[str] = None
|
tvg_logo: Optional[str] = None
|
||||||
tvg_name: Optional[str] = Field(None, min_length=1)
|
tvg_name: Optional[str] = Field(None, min_length=1)
|
||||||
@@ -87,7 +129,7 @@ class ChannelResponse(BaseModel):
|
|||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
name: str
|
name: str
|
||||||
group_title: str
|
group_id: UUID
|
||||||
tvg_id: str
|
tvg_id: str
|
||||||
tvg_logo: str
|
tvg_logo: str
|
||||||
tvg_name: str
|
tvg_name: str
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ from app.models import (
|
|||||||
ChannelURL,
|
ChannelURL,
|
||||||
ChannelURLCreate,
|
ChannelURLCreate,
|
||||||
ChannelURLResponse,
|
ChannelURLResponse,
|
||||||
|
Group,
|
||||||
|
Priority, # Added Priority import
|
||||||
)
|
)
|
||||||
from app.models.auth import CognitoUser
|
from app.models.auth import CognitoUser
|
||||||
from app.models.schemas import ChannelURLUpdate
|
from app.models.schemas import ChannelURLUpdate
|
||||||
@@ -29,12 +31,20 @@ def create_channel(
|
|||||||
user: CognitoUser = Depends(get_current_user),
|
user: CognitoUser = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Create a new channel"""
|
"""Create a new channel"""
|
||||||
# Check for duplicate channel (same group_title + name)
|
# Check if group exists
|
||||||
|
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Group not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for duplicate channel (same group_id + name)
|
||||||
existing_channel = (
|
existing_channel = (
|
||||||
db.query(ChannelDB)
|
db.query(ChannelDB)
|
||||||
.filter(
|
.filter(
|
||||||
and_(
|
and_(
|
||||||
ChannelDB.group_title == channel.group_title,
|
ChannelDB.group_id == channel.group_id,
|
||||||
ChannelDB.name == channel.name,
|
ChannelDB.name == channel.name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -44,7 +54,7 @@ def create_channel(
|
|||||||
if existing_channel:
|
if existing_channel:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
detail="Channel with same group_title and name already exists",
|
detail="Channel with same group_id and name already exists",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create channel without URLs first
|
# Create channel without URLs first
|
||||||
@@ -96,20 +106,27 @@ def update_channel(
|
|||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only check for duplicates if name or group_title are being updated
|
# Only check for duplicates if name or group_id are being updated
|
||||||
if channel.name is not None or channel.group_title is not None:
|
if channel.name is not None or channel.group_id is not None:
|
||||||
name = channel.name if channel.name is not None else db_channel.name
|
name = channel.name if channel.name is not None else db_channel.name
|
||||||
group_title = (
|
group_id = (
|
||||||
channel.group_title
|
channel.group_id if channel.group_id is not None else db_channel.group_id
|
||||||
if channel.group_title is not None
|
|
||||||
else db_channel.group_title
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if new group exists
|
||||||
|
if channel.group_id is not None:
|
||||||
|
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Group not found",
|
||||||
|
)
|
||||||
|
|
||||||
existing_channel = (
|
existing_channel = (
|
||||||
db.query(ChannelDB)
|
db.query(ChannelDB)
|
||||||
.filter(
|
.filter(
|
||||||
and_(
|
and_(
|
||||||
ChannelDB.group_title == group_title,
|
ChannelDB.group_id == group_id,
|
||||||
ChannelDB.name == name,
|
ChannelDB.name == name,
|
||||||
ChannelDB.id != channel_id,
|
ChannelDB.id != channel_id,
|
||||||
)
|
)
|
||||||
@@ -120,7 +137,7 @@ def update_channel(
|
|||||||
if existing_channel:
|
if existing_channel:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
detail="Channel with same group_title and name already exists",
|
detail="Channel with same group_id and name already exists",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update only provided fields
|
# Update only provided fields
|
||||||
@@ -133,6 +150,41 @@ def update_channel(
|
|||||||
return db_channel
|
return db_channel
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||||
|
@require_roles("admin")
|
||||||
|
def delete_channels(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Delete all channels"""
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
count = db.query(ChannelDB).count()
|
||||||
|
|
||||||
|
# First delete all channels
|
||||||
|
db.query(ChannelDB).delete()
|
||||||
|
|
||||||
|
# Then delete any URLs that are now orphaned (no channel references)
|
||||||
|
db.query(ChannelURL).filter(
|
||||||
|
~ChannelURL.channel_id.in_(db.query(ChannelDB.id))
|
||||||
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
|
# Then delete any groups that are now empty
|
||||||
|
db.query(Group).filter(~Group.id.in_(db.query(ChannelDB.group_id))).delete(
|
||||||
|
synchronize_session=False
|
||||||
|
)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting channels: {e}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to delete channels",
|
||||||
|
)
|
||||||
|
return {"deleted": count}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
@require_roles("admin")
|
@require_roles("admin")
|
||||||
def delete_channel(
|
def delete_channel(
|
||||||
@@ -163,9 +215,193 @@ def list_channels(
|
|||||||
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
|
||||||
|
# New endpoint to get channels by group
|
||||||
|
@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse])
|
||||||
|
def get_channels_by_group(
|
||||||
|
group_id: UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get all channels for a specific group"""
|
||||||
|
group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
return db.query(ChannelDB).filter(ChannelDB.group_id == group_id).all()
|
||||||
|
|
||||||
|
|
||||||
|
# New endpoint to update a channel's group
|
||||||
|
@router.put("/{channel_id}/group", response_model=ChannelResponse)
|
||||||
|
@require_roles("admin")
|
||||||
|
def update_channel_group(
|
||||||
|
channel_id: UUID,
|
||||||
|
group_id: UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Update a channel's group"""
|
||||||
|
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||||
|
if not channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for duplicate channel name in new group
|
||||||
|
existing_channel = (
|
||||||
|
db.query(ChannelDB)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
ChannelDB.group_id == group_id,
|
||||||
|
ChannelDB.name == channel.name,
|
||||||
|
ChannelDB.id != channel_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_channel:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Channel with same name already exists in target group",
|
||||||
|
)
|
||||||
|
|
||||||
|
channel.group_id = group_id
|
||||||
|
db.commit()
|
||||||
|
db.refresh(channel)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
|
||||||
|
# Bulk Upload and Reset Endpoints
|
||||||
|
@router.post("/bulk-upload", status_code=status.HTTP_200_OK)
|
||||||
|
@require_roles("admin")
|
||||||
|
def bulk_upload_channels(
|
||||||
|
channels: list[dict],
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Bulk upload channels from JSON array"""
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
# Fetch all priorities from the database, ordered by id
|
||||||
|
priorities = db.query(Priority).order_by(Priority.id).all()
|
||||||
|
priority_map = {i: p.id for i, p in enumerate(priorities)}
|
||||||
|
|
||||||
|
# Get the highest priority_id (which corresponds to the lowest priority level)
|
||||||
|
max_priority_id = None
|
||||||
|
if priorities:
|
||||||
|
max_priority_id = db.query(Priority.id).order_by(Priority.id.desc()).first()[0]
|
||||||
|
|
||||||
|
for channel_data in channels:
|
||||||
|
try:
|
||||||
|
# Get or create group
|
||||||
|
group_name = channel_data.get("group-title")
|
||||||
|
if not group_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
group = db.query(Group).filter(Group.name == group_name).first()
|
||||||
|
if not group:
|
||||||
|
group = Group(name=group_name)
|
||||||
|
db.add(group)
|
||||||
|
db.flush() # Use flush to make the group available in the session
|
||||||
|
db.refresh(group)
|
||||||
|
|
||||||
|
# Prepare channel data
|
||||||
|
urls = channel_data.get("urls", [])
|
||||||
|
if not isinstance(urls, list):
|
||||||
|
urls = [urls]
|
||||||
|
|
||||||
|
# Assign priorities dynamically based on fetched priorities
|
||||||
|
url_objects = []
|
||||||
|
for i, url in enumerate(urls): # Process all URLs
|
||||||
|
priority_id = priority_map.get(i)
|
||||||
|
if priority_id is None:
|
||||||
|
# If index is out of bounds,
|
||||||
|
# assign the highest priority_id (lowest priority)
|
||||||
|
if max_priority_id is not None:
|
||||||
|
priority_id = max_priority_id
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: No priorities defined in database. "
|
||||||
|
f"Skipping URL {url}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
url_objects.append({"url": url, "priority_id": priority_id})
|
||||||
|
|
||||||
|
# Create channel object with required fields
|
||||||
|
channel_obj = ChannelDB(
|
||||||
|
tvg_id=channel_data.get("tvg-id", ""),
|
||||||
|
name=channel_data.get("name", ""),
|
||||||
|
group_id=group.id,
|
||||||
|
tvg_name=channel_data.get("tvg-name", ""),
|
||||||
|
tvg_logo=channel_data.get("tvg-logo", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upsert channel
|
||||||
|
existing_channel = (
|
||||||
|
db.query(ChannelDB)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
ChannelDB.group_id == group.id,
|
||||||
|
ChannelDB.name == channel_obj.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_channel:
|
||||||
|
# Update existing
|
||||||
|
existing_channel.tvg_id = channel_obj.tvg_id
|
||||||
|
existing_channel.tvg_name = channel_obj.tvg_name
|
||||||
|
existing_channel.tvg_logo = channel_obj.tvg_logo
|
||||||
|
|
||||||
|
# Clear and recreate URLs
|
||||||
|
db.query(ChannelURL).filter(
|
||||||
|
ChannelURL.channel_id == existing_channel.id
|
||||||
|
).delete()
|
||||||
|
|
||||||
|
for url in url_objects:
|
||||||
|
db_url = ChannelURL(
|
||||||
|
channel_id=existing_channel.id,
|
||||||
|
url=url["url"],
|
||||||
|
priority_id=url["priority_id"],
|
||||||
|
in_use=False,
|
||||||
|
)
|
||||||
|
db.add(db_url)
|
||||||
|
else:
|
||||||
|
# Create new
|
||||||
|
db.add(channel_obj)
|
||||||
|
db.flush() # Flush to get the new channel's ID
|
||||||
|
db.refresh(channel_obj)
|
||||||
|
|
||||||
|
# Add URLs for new channel
|
||||||
|
for url in url_objects:
|
||||||
|
db_url = ChannelURL(
|
||||||
|
channel_id=channel_obj.id,
|
||||||
|
url=url["url"],
|
||||||
|
priority_id=url["priority_id"],
|
||||||
|
in_use=False,
|
||||||
|
)
|
||||||
|
db.add(db_url)
|
||||||
|
|
||||||
|
db.commit() # Commit all changes for this channel atomically
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing channel: {channel_data.get('name', 'Unknown')}")
|
||||||
|
print(f"Exception details: {e}")
|
||||||
|
db.rollback() # Rollback the entire transaction for the failed channel
|
||||||
|
continue
|
||||||
|
|
||||||
|
return {"processed": processed}
|
||||||
|
|
||||||
|
|
||||||
# URL Management Endpoints
|
# URL Management Endpoints
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/{channel_id}/urls",
|
"/{channel_id}/urls",
|
||||||
response_model=ChannelURLResponse,
|
response_model=ChannelURLResponse,
|
||||||
|
|||||||
191
app/routers/groups.py
Normal file
191
app/routers/groups.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, require_roles
|
||||||
|
from app.models import Group
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.models.schemas import (
|
||||||
|
GroupBulkSort,
|
||||||
|
GroupCreate,
|
||||||
|
GroupResponse,
|
||||||
|
GroupSortUpdate,
|
||||||
|
GroupUpdate,
|
||||||
|
)
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
@require_roles("admin")
|
||||||
|
def create_group(
|
||||||
|
group: GroupCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Create a new channel group"""
|
||||||
|
# Check for duplicate group name
|
||||||
|
existing_group = db.query(Group).filter(Group.name == group.name).first()
|
||||||
|
if existing_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Group with this name already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_group = Group(**group.model_dump())
|
||||||
|
db.add(db_group)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_group)
|
||||||
|
return db_group
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{group_id}", response_model=GroupResponse)
|
||||||
|
def get_group(group_id: UUID, db: Session = Depends(get_db)):
|
||||||
|
"""Get a group by id"""
|
||||||
|
group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{group_id}", response_model=GroupResponse)
|
||||||
|
@require_roles("admin")
|
||||||
|
def update_group(
|
||||||
|
group_id: UUID,
|
||||||
|
group: GroupUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Update a group's name or sort order"""
|
||||||
|
db_group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not db_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for duplicate name if name is being updated
|
||||||
|
if group.name is not None and group.name != db_group.name:
|
||||||
|
existing_group = db.query(Group).filter(Group.name == group.name).first()
|
||||||
|
if existing_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Group with this name already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update only provided fields
|
||||||
|
update_data = group.model_dump(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(db_group, key, value)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_group)
|
||||||
|
return db_group
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||||
|
@require_roles("admin")
|
||||||
|
def delete_groups(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Delete all groups that have no channels (skip groups with channels)"""
|
||||||
|
groups = db.query(Group).all()
|
||||||
|
deleted = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
for group in groups:
|
||||||
|
if not group.channels:
|
||||||
|
db.delete(group)
|
||||||
|
deleted += 1
|
||||||
|
else:
|
||||||
|
skipped += 1
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
return {"deleted": deleted, "skipped": skipped}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@require_roles("admin")
|
||||||
|
def delete_group(
|
||||||
|
group_id: UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Delete a group (only if it has no channels)"""
|
||||||
|
group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if group has any channels
|
||||||
|
if group.channels:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot delete group with existing channels",
|
||||||
|
)
|
||||||
|
|
||||||
|
db.delete(group)
|
||||||
|
db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=list[GroupResponse])
|
||||||
|
def list_groups(db: Session = Depends(get_db)):
|
||||||
|
"""List all groups sorted by sort_order"""
|
||||||
|
return db.query(Group).order_by(Group.sort_order).all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{group_id}/sort", response_model=GroupResponse)
|
||||||
|
@require_roles("admin")
|
||||||
|
def update_group_sort_order(
|
||||||
|
group_id: UUID,
|
||||||
|
sort_update: GroupSortUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Update a single group's sort order"""
|
||||||
|
db_group = db.query(Group).filter(Group.id == group_id).first()
|
||||||
|
if not db_group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
db_group.sort_order = sort_update.sort_order
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_group)
|
||||||
|
return db_group
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reorder", response_model=list[GroupResponse])
|
||||||
|
@require_roles("admin")
|
||||||
|
def bulk_update_sort_orders(
|
||||||
|
bulk_sort: GroupBulkSort,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Bulk update group sort orders"""
|
||||||
|
groups_to_update = []
|
||||||
|
|
||||||
|
for group_data in bulk_sort.groups:
|
||||||
|
group_id = group_data["group_id"]
|
||||||
|
sort_order = group_data["sort_order"]
|
||||||
|
|
||||||
|
group = db.query(Group).filter(Group.id == str(group_id)).first()
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Group with id {group_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.sort_order = sort_order
|
||||||
|
groups_to_update.append(group)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# Return all groups in their new order
|
||||||
|
return db.query(Group).order_by(Group.sort_order).all()
|
||||||
@@ -1,15 +1,156 @@
|
|||||||
from fastapi import APIRouter, Depends
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.auth.dependencies import get_current_user
|
from app.auth.dependencies import get_current_user
|
||||||
|
from app.iptv.stream_manager import StreamManager
|
||||||
from app.models.auth import CognitoUser
|
from app.models.auth import CognitoUser
|
||||||
|
from app.utils.database import get_db_session
|
||||||
|
|
||||||
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# In-memory store for validation processes
|
||||||
|
validation_processes: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/protected", summary="Protected endpoint for authenticated users")
|
class ProcessStatus(str, Enum):
|
||||||
async def protected_route(user: CognitoUser = Depends(get_current_user)):
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class StreamValidationRequest(BaseModel):
|
||||||
|
"""Request model for stream validation endpoint"""
|
||||||
|
|
||||||
|
channel_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ValidatedStream(BaseModel):
|
||||||
|
"""Model for a validated working stream"""
|
||||||
|
|
||||||
|
channel_id: str
|
||||||
|
stream_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationProcessResponse(BaseModel):
|
||||||
|
"""Response model for validation process initiation"""
|
||||||
|
|
||||||
|
process_id: str
|
||||||
|
status: ProcessStatus
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationResultResponse(BaseModel):
|
||||||
|
"""Response model for validation results"""
|
||||||
|
|
||||||
|
process_id: str
|
||||||
|
status: ProcessStatus
|
||||||
|
working_streams: Optional[list[ValidatedStream]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session):
|
||||||
|
"""Background task to validate streams"""
|
||||||
|
try:
|
||||||
|
validation_processes[process_id]["status"] = ProcessStatus.IN_PROGRESS
|
||||||
|
manager = StreamManager(db)
|
||||||
|
|
||||||
|
if channel_id:
|
||||||
|
stream_url = manager.validate_and_select_stream(channel_id)
|
||||||
|
if stream_url:
|
||||||
|
validation_processes[process_id]["result"] = {
|
||||||
|
"working_streams": [
|
||||||
|
ValidatedStream(channel_id=channel_id, stream_url=stream_url)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
validation_processes[process_id]["error"] = (
|
||||||
|
f"No working streams found for channel {channel_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# TODO: Implement validation for all channels
|
||||||
|
validation_processes[process_id]["error"] = (
|
||||||
|
"Validation of all channels not yet implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_processes[process_id]["status"] = ProcessStatus.COMPLETED
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error validating streams: {str(e)}")
|
||||||
|
validation_processes[process_id]["status"] = ProcessStatus.FAILED
|
||||||
|
validation_processes[process_id]["error"] = str(e)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/validate-streams",
|
||||||
|
summary="Start stream validation process",
|
||||||
|
response_model=ValidationProcessResponse,
|
||||||
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
|
responses={202: {"description": "Validation process started successfully"}},
|
||||||
|
)
|
||||||
|
async def start_stream_validation(
|
||||||
|
request: StreamValidationRequest,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db_session),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Protected endpoint that requires authentication for all users.
|
Start asynchronous validation of streams.
|
||||||
If the user is authenticated, returns success message.
|
|
||||||
|
- Returns immediately with a process ID
|
||||||
|
- Use GET /validate-streams/{process_id} to check status
|
||||||
"""
|
"""
|
||||||
return {"message": f"Hello {user.username}, you have access to support resources!"}
|
process_id = str(uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.PENDING,
|
||||||
|
"channel_id": request.channel_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
background_tasks.add_task(run_stream_validation, process_id, request.channel_id, db)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"process_id": process_id,
|
||||||
|
"status": ProcessStatus.PENDING,
|
||||||
|
"message": "Validation process started",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/validate-streams/{process_id}",
|
||||||
|
summary="Check validation process status",
|
||||||
|
response_model=ValidationResultResponse,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Process status and results"},
|
||||||
|
404: {"description": "Process not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_validation_status(
|
||||||
|
process_id: str, user: CognitoUser = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check status of a stream validation process.
|
||||||
|
|
||||||
|
Returns current status and results if completed.
|
||||||
|
"""
|
||||||
|
if process_id not in validation_processes:
|
||||||
|
raise HTTPException(status_code=404, detail="Process not found")
|
||||||
|
|
||||||
|
process = validation_processes[process_id]
|
||||||
|
response = {"process_id": process_id, "status": process["status"]}
|
||||||
|
|
||||||
|
if process["status"] == ProcessStatus.COMPLETED:
|
||||||
|
if "error" in process:
|
||||||
|
response["error"] = process["error"]
|
||||||
|
else:
|
||||||
|
response["working_streams"] = process["result"]["working_streams"]
|
||||||
|
elif process["status"] == ProcessStatus.FAILED:
|
||||||
|
response["error"] = process["error"]
|
||||||
|
|
||||||
|
return response
|
||||||
|
|||||||
@@ -59,6 +59,34 @@ def get_priority(
|
|||||||
return priority
|
return priority
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||||
|
@require_roles("admin")
|
||||||
|
def delete_priorities(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Delete all priorities not in use by channel URLs"""
|
||||||
|
from app.models.db import ChannelURL
|
||||||
|
|
||||||
|
priorities = db.query(Priority).all()
|
||||||
|
deleted = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
for priority in priorities:
|
||||||
|
in_use = db.scalar(
|
||||||
|
select(ChannelURL).where(ChannelURL.priority_id == priority.id).limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not in_use:
|
||||||
|
db.delete(priority)
|
||||||
|
deleted += 1
|
||||||
|
else:
|
||||||
|
skipped += 1
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
return {"deleted": deleted, "skipped": skipped}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
@require_roles("admin")
|
@require_roles("admin")
|
||||||
def delete_priority(
|
def delete_priority(
|
||||||
|
|||||||
57
app/routers/scheduler.py
Normal file
57
app/routers/scheduler.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user, require_roles
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/scheduler",
|
||||||
|
tags=["scheduler"],
|
||||||
|
responses={404: {"description": "Not found"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
"""Get the scheduler instance from the app state."""
|
||||||
|
if not hasattr(request.app.state.scheduler, "scheduler"):
|
||||||
|
raise HTTPException(status_code=500, detail="Scheduler not initialized")
|
||||||
|
return request.app.state.scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
@require_roles("admin")
|
||||||
|
def scheduler_health(
|
||||||
|
scheduler: StreamScheduler = Depends(get_scheduler),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Check scheduler health status (admin only)."""
|
||||||
|
try:
|
||||||
|
job = scheduler.scheduler.get_job("daily_stream_validation")
|
||||||
|
next_run = str(job.next_run_time) if job and job.next_run_time else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "running" if scheduler.scheduler.running else "stopped",
|
||||||
|
"next_run": next_run,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to check scheduler health: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/trigger")
|
||||||
|
@require_roles("admin")
|
||||||
|
def trigger_validation(
|
||||||
|
scheduler: StreamScheduler = Depends(get_scheduler),
|
||||||
|
user: CognitoUser = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Manually trigger stream validation (admin only)."""
|
||||||
|
scheduler.trigger_manual_validation()
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202, content={"message": "Stream validation triggered"}
|
||||||
|
)
|
||||||
@@ -66,6 +66,8 @@ class StreamValidator:
|
|||||||
"application/octet-stream",
|
"application/octet-stream",
|
||||||
"application/x-mpegURL",
|
"application/x-mpegURL",
|
||||||
]
|
]
|
||||||
|
if content_type is None:
|
||||||
|
return False
|
||||||
return any(ct in content_type for ct in valid_types)
|
return any(ct in content_type for ct in valid_types)
|
||||||
|
|
||||||
def parse_playlist(self, file_path):
|
def parse_playlist(self, file_path):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
from requests import Session
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@@ -19,16 +20,16 @@ def get_db_credentials():
|
|||||||
|
|
||||||
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
||||||
try:
|
try:
|
||||||
host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
|
host = ssm.get_parameter(Name="/iptv-manager/DB_HOST", WithDecryption=True)[
|
||||||
"Parameter"
|
"Parameter"
|
||||||
]["Value"]
|
]["Value"]
|
||||||
user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
|
user = ssm.get_parameter(Name="/iptv-manager/DB_USER", WithDecryption=True)[
|
||||||
"Parameter"
|
"Parameter"
|
||||||
]["Value"]
|
]["Value"]
|
||||||
password = ssm.get_parameter(
|
password = ssm.get_parameter(
|
||||||
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
|
Name="/iptv-manager/DB_PASSWORD", WithDecryption=True
|
||||||
)["Parameter"]["Value"]
|
)["Parameter"]["Value"]
|
||||||
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
|
dbname = ssm.get_parameter(Name="/iptv-manager/DB_NAME", WithDecryption=True)[
|
||||||
"Parameter"
|
"Parameter"
|
||||||
]["Value"]
|
]["Value"]
|
||||||
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
||||||
@@ -53,3 +54,8 @@ def get_db():
|
|||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_session() -> Session:
|
||||||
|
"""Get a direct database session (non-generator version)"""
|
||||||
|
return SessionLocal()
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ version: '3.8'
|
|||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:13
|
image: postgres:13
|
||||||
|
container_name: postgres
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: iptv_updater
|
POSTGRES_DB: iptv_manager
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5432:5432"
|
||||||
volumes:
|
volumes:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: iptv_updater
|
POSTGRES_DB: iptv_manager
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5432:5432"
|
||||||
volumes:
|
volumes:
|
||||||
@@ -20,7 +20,7 @@ services:
|
|||||||
DB_USER: postgres
|
DB_USER: postgres
|
||||||
DB_PASSWORD: postgres
|
DB_PASSWORD: postgres
|
||||||
DB_HOST: postgres
|
DB_HOST: postgres
|
||||||
DB_NAME: iptv_updater
|
DB_NAME: iptv_manager
|
||||||
MOCK_AUTH: "true"
|
MOCK_AUTH: "true"
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from aws_cdk import aws_ssm as ssm
|
|||||||
from constructs import Construct
|
from constructs import Construct
|
||||||
|
|
||||||
|
|
||||||
class IptvUpdaterStack(Stack):
|
class IptvManagerStack(Stack):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scope: Construct,
|
scope: Construct,
|
||||||
@@ -27,7 +27,7 @@ class IptvUpdaterStack(Stack):
|
|||||||
# Create VPC
|
# Create VPC
|
||||||
vpc = ec2.Vpc(
|
vpc = ec2.Vpc(
|
||||||
self,
|
self,
|
||||||
"IptvUpdaterVPC",
|
"IptvManagerVPC",
|
||||||
max_azs=2, # Need at least 2 AZs for RDS subnet group
|
max_azs=2, # Need at least 2 AZs for RDS subnet group
|
||||||
nat_gateways=0, # No NAT Gateway to stay in free tier
|
nat_gateways=0, # No NAT Gateway to stay in free tier
|
||||||
subnet_configuration=[
|
subnet_configuration=[
|
||||||
@@ -44,7 +44,7 @@ class IptvUpdaterStack(Stack):
|
|||||||
|
|
||||||
# Security Group
|
# Security Group
|
||||||
security_group = ec2.SecurityGroup(
|
security_group = ec2.SecurityGroup(
|
||||||
self, "IptvUpdaterSG", vpc=vpc, allow_all_outbound=True
|
self, "IptvManagerSG", vpc=vpc, allow_all_outbound=True
|
||||||
)
|
)
|
||||||
|
|
||||||
security_group.add_ingress_rule(
|
security_group.add_ingress_rule(
|
||||||
@@ -66,18 +66,18 @@ class IptvUpdaterStack(Stack):
|
|||||||
"Allow PostgreSQL traffic for tunneling",
|
"Allow PostgreSQL traffic for tunneling",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Key pair for IPTV Updater instance
|
# Key pair for IPTV Manager instance
|
||||||
key_pair = ec2.KeyPair(
|
key_pair = ec2.KeyPair(
|
||||||
self,
|
self,
|
||||||
"IptvUpdaterKeyPair",
|
"IptvManagerKeyPair",
|
||||||
key_pair_name="iptv-updater-key",
|
key_pair_name="iptv-manager-key",
|
||||||
public_key_material=ssh_public_key,
|
public_key_material=ssh_public_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create IAM role for EC2
|
# Create IAM role for EC2
|
||||||
role = iam.Role(
|
role = iam.Role(
|
||||||
self,
|
self,
|
||||||
"IptvUpdaterRole",
|
"IptvManagerRole",
|
||||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
|
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -111,37 +111,11 @@ class IptvUpdaterStack(Stack):
|
|||||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
|
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
|
||||||
)
|
)
|
||||||
|
|
||||||
# EC2 Instance
|
|
||||||
instance = ec2.Instance(
|
|
||||||
self,
|
|
||||||
"IptvUpdaterInstance",
|
|
||||||
vpc=vpc,
|
|
||||||
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
|
||||||
instance_type=ec2.InstanceType.of(
|
|
||||||
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
|
||||||
),
|
|
||||||
machine_image=ec2.AmazonLinuxImage(
|
|
||||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
|
||||||
),
|
|
||||||
security_group=security_group,
|
|
||||||
key_pair=key_pair,
|
|
||||||
role=role,
|
|
||||||
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
|
||||||
associate_public_ip_address=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Option: 2: Create Elastic IP (not free tier compatible)
|
|
||||||
# eip = ec2.CfnEIP(
|
|
||||||
# self, "IptvUpdaterEIP",
|
|
||||||
# domain="vpc",
|
|
||||||
# instance_id=instance.instance_id
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Add Cognito User Pool
|
# Add Cognito User Pool
|
||||||
user_pool = cognito.UserPool(
|
user_pool = cognito.UserPool(
|
||||||
self,
|
self,
|
||||||
"IptvUpdaterUserPool",
|
"IptvManagerUserPool",
|
||||||
user_pool_name="iptv-updater-users",
|
user_pool_name="iptv-manager-users",
|
||||||
self_sign_up_enabled=False, # Only admins can create users
|
self_sign_up_enabled=False, # Only admins can create users
|
||||||
password_policy=cognito.PasswordPolicy(
|
password_policy=cognito.PasswordPolicy(
|
||||||
min_length=8,
|
min_length=8,
|
||||||
@@ -156,7 +130,7 @@ class IptvUpdaterStack(Stack):
|
|||||||
|
|
||||||
# Add App Client with the correct callback URL
|
# Add App Client with the correct callback URL
|
||||||
client = user_pool.add_client(
|
client = user_pool.add_client(
|
||||||
"IptvUpdaterClient",
|
"IptvManagerClient",
|
||||||
access_token_validity=Duration.minutes(60),
|
access_token_validity=Duration.minutes(60),
|
||||||
id_token_validity=Duration.minutes(60),
|
id_token_validity=Duration.minutes(60),
|
||||||
refresh_token_validity=Duration.days(1),
|
refresh_token_validity=Duration.days(1),
|
||||||
@@ -171,8 +145,8 @@ class IptvUpdaterStack(Stack):
|
|||||||
|
|
||||||
# Add domain for hosted UI
|
# Add domain for hosted UI
|
||||||
domain = user_pool.add_domain(
|
domain = user_pool.add_domain(
|
||||||
"IptvUpdaterDomain",
|
"IptvManagerDomain",
|
||||||
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-updater"),
|
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-manager"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Read the userdata script with proper path resolution
|
# Read the userdata script with proper path resolution
|
||||||
@@ -226,7 +200,7 @@ class IptvUpdaterStack(Stack):
|
|||||||
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
|
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
|
||||||
db = rds.DatabaseInstance(
|
db = rds.DatabaseInstance(
|
||||||
self,
|
self,
|
||||||
"IptvUpdaterDB",
|
"IptvManagerDB",
|
||||||
engine=rds.DatabaseInstanceEngine.postgres(
|
engine=rds.DatabaseInstanceEngine.postgres(
|
||||||
version=rds.PostgresEngineVersion.VER_13
|
version=rds.PostgresEngineVersion.VER_13
|
||||||
),
|
),
|
||||||
@@ -240,7 +214,7 @@ class IptvUpdaterStack(Stack):
|
|||||||
security_groups=[rds_sg],
|
security_groups=[rds_sg],
|
||||||
allocated_storage=10,
|
allocated_storage=10,
|
||||||
max_allocated_storage=10,
|
max_allocated_storage=10,
|
||||||
database_name="iptv_updater",
|
database_name="iptv_manager",
|
||||||
removal_policy=RemovalPolicy.DESTROY,
|
removal_policy=RemovalPolicy.DESTROY,
|
||||||
deletion_protection=False,
|
deletion_protection=False,
|
||||||
publicly_accessible=False, # Avoid public IPv4 charges
|
publicly_accessible=False, # Avoid public IPv4 charges
|
||||||
@@ -252,28 +226,28 @@ class IptvUpdaterStack(Stack):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store DB connection info in SSM Parameter Store
|
# Store DB connection info in SSM Parameter Store
|
||||||
ssm.StringParameter(
|
db_host_param = ssm.StringParameter(
|
||||||
self,
|
self,
|
||||||
"DBHostParam",
|
"DBHostParam",
|
||||||
parameter_name="/iptv-updater/DB_HOST",
|
parameter_name="/iptv-manager/DB_HOST",
|
||||||
string_value=db.db_instance_endpoint_address,
|
string_value=db.db_instance_endpoint_address,
|
||||||
)
|
)
|
||||||
ssm.StringParameter(
|
db_name_param = ssm.StringParameter(
|
||||||
self,
|
self,
|
||||||
"DBNameParam",
|
"DBNameParam",
|
||||||
parameter_name="/iptv-updater/DB_NAME",
|
parameter_name="/iptv-manager/DB_NAME",
|
||||||
string_value="iptv_updater",
|
string_value="iptv_manager",
|
||||||
)
|
)
|
||||||
ssm.StringParameter(
|
db_user_param = ssm.StringParameter(
|
||||||
self,
|
self,
|
||||||
"DBUserParam",
|
"DBUserParam",
|
||||||
parameter_name="/iptv-updater/DB_USER",
|
parameter_name="/iptv-manager/DB_USER",
|
||||||
string_value=db.secret.secret_value_from_json("username").to_string(),
|
string_value=db.secret.secret_value_from_json("username").to_string(),
|
||||||
)
|
)
|
||||||
ssm.StringParameter(
|
db_pass_param = ssm.StringParameter(
|
||||||
self,
|
self,
|
||||||
"DBPassParam",
|
"DBPassParam",
|
||||||
parameter_name="/iptv-updater/DB_PASSWORD",
|
parameter_name="/iptv-manager/DB_PASSWORD",
|
||||||
string_value=db.secret.secret_value_from_json("password").to_string(),
|
string_value=db.secret.secret_value_from_json("password").to_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -282,6 +256,39 @@ class IptvUpdaterStack(Stack):
|
|||||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
|
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# EC2 Instance (created after all dependencies are ready)
|
||||||
|
instance = ec2.Instance(
|
||||||
|
self,
|
||||||
|
"IptvManagerInstance",
|
||||||
|
vpc=vpc,
|
||||||
|
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
||||||
|
instance_type=ec2.InstanceType.of(
|
||||||
|
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
||||||
|
),
|
||||||
|
machine_image=ec2.AmazonLinuxImage(
|
||||||
|
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
||||||
|
),
|
||||||
|
security_group=security_group,
|
||||||
|
key_pair=key_pair,
|
||||||
|
role=role,
|
||||||
|
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
||||||
|
associate_public_ip_address=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure instance depends on SSM parameters being created
|
||||||
|
instance.node.add_dependency(db)
|
||||||
|
instance.node.add_dependency(db_host_param)
|
||||||
|
instance.node.add_dependency(db_name_param)
|
||||||
|
instance.node.add_dependency(db_user_param)
|
||||||
|
instance.node.add_dependency(db_pass_param)
|
||||||
|
|
||||||
|
# Option: 2: Create Elastic IP (not free tier compatible)
|
||||||
|
# eip = ec2.CfnEIP(
|
||||||
|
# self, "IptvManagerEIP",
|
||||||
|
# domain="vpc",
|
||||||
|
# instance_id=instance.instance_id
|
||||||
|
# )
|
||||||
|
|
||||||
# Update instance with userdata
|
# Update instance with userdata
|
||||||
instance.add_user_data(userdata.render())
|
instance.add_user_data(userdata.render())
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
# Update system and install required packages
|
# Update system and install required packages
|
||||||
dnf update -y
|
dnf update -y
|
||||||
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx
|
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx postgresql15.x86_64 awscli
|
||||||
|
|
||||||
# Start and enable crond service
|
# Start and enable crond service
|
||||||
systemctl start crond
|
systemctl start crond
|
||||||
@@ -11,27 +11,69 @@ systemctl enable crond
|
|||||||
cd /home/ec2-user
|
cd /home/ec2-user
|
||||||
|
|
||||||
git clone ${REPO_URL}
|
git clone ${REPO_URL}
|
||||||
cd iptv-updater-aws
|
cd iptv-manager-service
|
||||||
|
|
||||||
# Install Python packages with --ignore-installed to prevent conflicts with RPM packages
|
# Install Python packages with --ignore-installed to prevent conflicts with RPM packages
|
||||||
pip3 install --ignore-installed -r requirements.txt
|
pip3 install --ignore-installed -r requirements.txt
|
||||||
|
|
||||||
|
# Retrieve DB credentials from SSM Parameter Store with retries
|
||||||
|
echo "Attempting to retrieve DB credentials from SSM..."
|
||||||
|
for i in {1..30}; do
|
||||||
|
DB_HOST=$(aws ssm get-parameter --name "/iptv-manager/DB_HOST" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
DB_NAME=$(aws ssm get-parameter --name "/iptv-manager/DB_NAME" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
DB_USER=$(aws ssm get-parameter --name "/iptv-manager/DB_USER" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
DB_PASSWORD=$(aws ssm get-parameter --name "/iptv-manager/DB_PASSWORD" --query "Parameter.Value" --output text 2>/dev/null)
|
||||||
|
|
||||||
|
if [ -n "$DB_HOST" ] && [ -n "$DB_NAME" ] && [ -n "$DB_USER" ] && [ -n "$DB_PASSWORD" ]; then
|
||||||
|
echo "Successfully retrieved all DB credentials"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Waiting for SSM parameters to be available... (attempt $i/30)"
|
||||||
|
sleep 5
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "$DB_HOST" ] || [ -z "$DB_NAME" ] || [ -z "$DB_USER" ] || [ -z "$DB_PASSWORD" ]; then
|
||||||
|
echo "ERROR: Failed to retrieve all required DB credentials after 30 attempts"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DB_HOST
|
||||||
|
export DB_NAME
|
||||||
|
export DB_USER
|
||||||
|
export DB_PASSWORD
|
||||||
|
|
||||||
|
# Set PGPASSWORD for psql to use
|
||||||
|
export PGPASSWORD=$DB_PASSWORD
|
||||||
|
|
||||||
|
# Wait for PostgreSQL to be ready
|
||||||
|
echo "Waiting for PostgreSQL to start..."
|
||||||
|
until psql -h $DB_HOST -U $DB_USER -d postgres -c '\q'; do
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
echo "PostgreSQL is ready."
|
||||||
|
|
||||||
|
# Create database if it does not exist
|
||||||
|
DB_EXISTS=$(psql -h $DB_HOST -U $DB_USER -d postgres -tc "SELECT 1 FROM pg_database WHERE datname = '$DB_NAME';")
|
||||||
|
if [ -z "$DB_EXISTS" ]; then
|
||||||
|
echo "Creating database $DB_NAME..."
|
||||||
|
psql -h $DB_HOST -U $DB_USER -d postgres -c "CREATE DATABASE $DB_NAME;"
|
||||||
|
echo "Database $DB_NAME created."
|
||||||
|
fi
|
||||||
|
|
||||||
# Run database migrations
|
# Run database migrations
|
||||||
alembic upgrade head
|
alembic upgrade head
|
||||||
|
|
||||||
# Seed initial priorities
|
|
||||||
python3 -c "from app.utils.database import SessionLocal; from app.models.db import Priority; db = SessionLocal(); db.add_all([Priority(id=100, description='High'), Priority(id=200, description='Medium'), Priority(id=300, description='Low')]); db.commit()"
|
|
||||||
|
|
||||||
# Create systemd service file
|
# Create systemd service file
|
||||||
cat << 'EOF' > /etc/systemd/system/iptv-updater.service
|
cat << 'EOF' > /etc/systemd/system/iptv-manager.service
|
||||||
[Unit]
|
[Unit]
|
||||||
Description=IPTV Updater Service
|
Description=IPTV Manager Service
|
||||||
After=network.target
|
After=network.target
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
Type=simple
|
Type=simple
|
||||||
User=ec2-user
|
User=ec2-user
|
||||||
WorkingDirectory=/home/ec2-user/iptv-updater-aws
|
WorkingDirectory=/home/ec2-user/iptv-manager-service
|
||||||
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
|
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
|
||||||
EnvironmentFile=/etc/environment
|
EnvironmentFile=/etc/environment
|
||||||
Restart=always
|
Restart=always
|
||||||
@@ -56,7 +98,7 @@ sudo mkdir -p /etc/nginx/ssl
|
|||||||
--reloadcmd "service nginx force-reload"
|
--reloadcmd "service nginx force-reload"
|
||||||
|
|
||||||
# Create nginx config
|
# Create nginx config
|
||||||
cat << EOF > /etc/nginx/conf.d/iptvUpdater.conf
|
cat << EOF > /etc/nginx/conf.d/iptvManager.conf
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
||||||
@@ -83,5 +125,5 @@ EOF
|
|||||||
# Start nginx service
|
# Start nginx service
|
||||||
systemctl enable nginx
|
systemctl enable nginx
|
||||||
systemctl start nginx
|
systemctl start nginx
|
||||||
systemctl enable iptv-updater
|
systemctl enable iptv-manager
|
||||||
systemctl start iptv-updater
|
systemctl start iptv-manager
|
||||||
@@ -24,4 +24,8 @@ ignore = []
|
|||||||
known-first-party = ["app"]
|
known-first-party = ["app"]
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = "--cov=app --cov-report=term-missing --cov-fail-under=70"
|
||||||
|
testpaths = ["tests"]
|
||||||
@@ -5,12 +5,21 @@ python_functions = test_*
|
|||||||
asyncio_mode = auto
|
asyncio_mode = auto
|
||||||
filterwarnings =
|
filterwarnings =
|
||||||
ignore::DeprecationWarning:botocore.auth
|
ignore::DeprecationWarning:botocore.auth
|
||||||
|
ignore:The 'app' shortcut is now deprecated:DeprecationWarning:httpx._client
|
||||||
|
|
||||||
# Coverage configuration
|
# Coverage configuration
|
||||||
addopts =
|
addopts =
|
||||||
--cov=app
|
--cov=app
|
||||||
--cov-report=term-missing
|
--cov-report=term-missing
|
||||||
|
|
||||||
|
# Test environment variables
|
||||||
|
env =
|
||||||
|
MOCK_AUTH=true
|
||||||
|
DB_USER=test_user
|
||||||
|
DB_PASSWORD=test_password
|
||||||
|
DB_HOST=localhost
|
||||||
|
DB_NAME=iptv_manager_test
|
||||||
|
|
||||||
# Test markers
|
# Test markers
|
||||||
markers =
|
markers =
|
||||||
slow: mark tests as slow running
|
slow: mark tests as slow running
|
||||||
|
|||||||
@@ -14,4 +14,9 @@ psycopg2-binary==2.9.9
|
|||||||
alembic==1.16.1
|
alembic==1.16.1
|
||||||
pytest==8.1.1
|
pytest==8.1.1
|
||||||
pytest-asyncio==0.23.6
|
pytest-asyncio==0.23.6
|
||||||
pytest-mock==3.12.0
|
pytest-mock==3.12.0
|
||||||
|
pytest-cov==4.1.0
|
||||||
|
pytest-env==1.1.1
|
||||||
|
httpx==0.27.0
|
||||||
|
pre-commit
|
||||||
|
apscheduler==3.10.4
|
||||||
@@ -25,7 +25,7 @@ cdk deploy --app="python3 ${PWD}/app.py"
|
|||||||
# Update application on running instances
|
# Update application on running instances
|
||||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||||
--region us-east-2 \
|
--region us-east-2 \
|
||||||
--filters "Name=tag:Name,Values=IptvUpdaterStack/IptvUpdaterInstance" \
|
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||||
"Name=instance-state-name,Values=running" \
|
"Name=instance-state-name,Values=running" \
|
||||||
--query "Reservations[].Instances[].InstanceId" \
|
--query "Reservations[].Instances[].InstanceId" \
|
||||||
--output text)
|
--output text)
|
||||||
@@ -35,7 +35,7 @@ for INSTANCE_ID in $INSTANCE_IDS; do
|
|||||||
aws ssm send-command \
|
aws ssm send-command \
|
||||||
--instance-ids "$INSTANCE_ID" \
|
--instance-ids "$INSTANCE_ID" \
|
||||||
--document-name "AWS-RunShellScript" \
|
--document-name "AWS-RunShellScript" \
|
||||||
--parameters '{"commands":["cd /home/ec2-user/iptv-updater-aws && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-updater"]}' \
|
--parameters '{"commands":["cd /home/ec2-user/iptv-manager-service && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-manager"]}' \
|
||||||
--no-cli-pager \
|
--no-cli-pager \
|
||||||
--no-paginate
|
--no-paginate
|
||||||
done
|
done
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ python3 -m pip install -r requirements.txt
|
|||||||
# Install and configure pre-commit hooks
|
# Install and configure pre-commit hooks
|
||||||
pre-commit install
|
pre-commit install
|
||||||
pre-commit install-hooks
|
pre-commit install-hooks
|
||||||
|
pre-commit autoupdate
|
||||||
|
|
||||||
# Initialize and run database migrations
|
# Verify pytest setup
|
||||||
alembic upgrade head
|
python3 -m pytest
|
||||||
|
|
||||||
# Seed initial data
|
|
||||||
python3 -c "from app.utils.database import SessionLocal; from app.models.db import Priority; db = SessionLocal(); db.add_all([Priority(id=100, description='High'), Priority(id=200, description='Medium'), Priority(id=300, description='Low')]); db.commit()"
|
|
||||||
@@ -1,21 +1,26 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
# Start PostgreSQL
|
# Start PostgreSQL
|
||||||
docker-compose -f docker/docker-compose-db.yml up -d
|
docker-compose -f docker/docker-compose-db.yml up -d
|
||||||
|
|
||||||
# Set mock auth and database environment variables
|
# Set environment variables
|
||||||
export MOCK_AUTH=true
|
export MOCK_AUTH=true
|
||||||
|
export DB_HOST=localhost
|
||||||
export DB_USER=postgres
|
export DB_USER=postgres
|
||||||
export DB_PASSWORD=postgres
|
export DB_PASSWORD=postgres
|
||||||
export DB_HOST=localhost
|
export DB_NAME=iptv_manager
|
||||||
export DB_NAME=iptv_updater
|
|
||||||
|
echo "Ensuring database $DB_NAME exists using conditional DDL..."
|
||||||
|
PGPASSWORD=$DB_PASSWORD docker exec -i postgres psql -U $DB_USER <<< "SELECT 'CREATE DATABASE $DB_NAME' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec"
|
||||||
|
echo "Database $DB_NAME check complete."
|
||||||
|
|
||||||
# Run database migrations
|
# Run database migrations
|
||||||
alembic upgrade head
|
alembic upgrade head
|
||||||
|
|
||||||
# Start FastAPI
|
# Start FastAPI
|
||||||
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
|
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
|
||||||
echo $! > iptv-updater.pid
|
echo $! > iptv-manager.pid
|
||||||
|
|
||||||
echo "Services started:"
|
echo "Services started:"
|
||||||
echo "- PostgreSQL running on localhost:5432"
|
echo "- PostgreSQL running on localhost:5432"
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Stop FastAPI
|
# Stop FastAPI
|
||||||
if [ -f iptv-updater.pid ]; then
|
if [ -f iptv-manager.pid ]; then
|
||||||
kill $(cat iptv-updater.pid)
|
kill $(cat iptv-manager.pid)
|
||||||
rm iptv-updater.pid
|
rm iptv-manager.pid
|
||||||
echo "Stopped FastAPI"
|
echo "Stopped FastAPI"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
|
|||||||
|
|
||||||
# Mock endpoint for testing the require_roles decorator
|
# Mock endpoint for testing the require_roles decorator
|
||||||
@require_roles("admin")
|
@require_roles("admin")
|
||||||
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||||
return {"message": "Success", "user": user.username}
|
return {"message": "Success", "user": user.username}
|
||||||
|
|
||||||
|
|
||||||
@@ -96,7 +96,7 @@ async def test_require_roles_no_roles():
|
|||||||
async def test_require_roles_multiple_roles():
|
async def test_require_roles_multiple_roles():
|
||||||
# Test requiring multiple roles
|
# Test requiring multiple roles
|
||||||
@require_roles("admin", "super_user")
|
@require_roles("admin", "super_user")
|
||||||
async def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
|
def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||||
return {"message": "Success"}
|
return {"message": "Success"}
|
||||||
|
|
||||||
# User with all required roles
|
# User with all required roles
|
||||||
@@ -173,3 +173,33 @@ def test_mock_auth_import(monkeypatch):
|
|||||||
|
|
||||||
# Reload again to restore original state
|
# Reload again to restore original state
|
||||||
importlib.reload(app.auth.dependencies)
|
importlib.reload(app.auth.dependencies)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cognito_auth_import(monkeypatch):
|
||||||
|
"""Test that cognito auth is imported when MOCK_AUTH=false (covers line 14)"""
|
||||||
|
# Save original env var value
|
||||||
|
original_value = os.environ.get("MOCK_AUTH")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set MOCK_AUTH to false
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", "false")
|
||||||
|
|
||||||
|
# Reload the dependencies module to trigger the import condition
|
||||||
|
import app.auth.dependencies
|
||||||
|
|
||||||
|
importlib.reload(app.auth.dependencies)
|
||||||
|
|
||||||
|
# Verify that get_user_from_token was imported from app.auth.cognito
|
||||||
|
from app.auth.dependencies import get_user_from_token
|
||||||
|
|
||||||
|
assert get_user_from_token.__module__ == "app.auth.cognito"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original env var
|
||||||
|
if original_value is None:
|
||||||
|
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||||
|
else:
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||||
|
|
||||||
|
# Reload again to restore original state
|
||||||
|
importlib.reload(app.auth.dependencies)
|
||||||
|
|||||||
144
tests/models/test_db.py
Normal file
144
tests/models/test_db.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID
|
||||||
|
|
||||||
|
# --- Test SQLiteUUID Type ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_none():
|
||||||
|
"""Test SQLiteUUID.process_bind_param with None returns None"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
assert uuid_type.process_bind_param(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_valid_uuid():
|
||||||
|
"""Test SQLiteUUID.process_bind_param with valid UUID returns string"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid = uuid.uuid4()
|
||||||
|
assert uuid_type.process_bind_param(test_uuid, None) == str(test_uuid)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_valid_string():
|
||||||
|
"""Test SQLiteUUID.process_bind_param with valid UUID string returns string"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid_str = "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
assert uuid_type.process_bind_param(test_uuid_str, None) == test_uuid_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_bind_param_invalid_string():
|
||||||
|
"""Test SQLiteUUID.process_bind_param raises ValueError for invalid UUID"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
with pytest.raises(ValueError, match="Invalid UUID string format"):
|
||||||
|
uuid_type.process_bind_param("invalid-uuid", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_result_value_none():
|
||||||
|
"""Test SQLiteUUID.process_result_value with None returns None"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
assert uuid_type.process_result_value(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_result_value_valid_string():
|
||||||
|
"""Test SQLiteUUID.process_result_value converts string to UUID"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid = uuid.uuid4()
|
||||||
|
result = uuid_type.process_result_value(str(test_uuid), None)
|
||||||
|
assert isinstance(result, uuid.UUID)
|
||||||
|
assert result == test_uuid
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_process_result_value_uuid_object():
|
||||||
|
"""Test SQLiteUUID.process_result_value: UUID object returns itself."""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid = uuid.uuid4()
|
||||||
|
result = uuid_type.process_result_value(test_uuid, None)
|
||||||
|
assert isinstance(result, uuid.UUID)
|
||||||
|
assert result is test_uuid # Ensure it's the same object, not a new one
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_compare_values_none():
|
||||||
|
"""Test SQLiteUUID.compare_values handles None values"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
assert uuid_type.compare_values(None, None) is True
|
||||||
|
assert uuid_type.compare_values(None, uuid.uuid4()) is False
|
||||||
|
assert uuid_type.compare_values(uuid.uuid4(), None) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqliteuuid_compare_values_uuid():
|
||||||
|
"""Test SQLiteUUID.compare_values compares UUIDs as strings"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
test_uuid = uuid.uuid4()
|
||||||
|
assert uuid_type.compare_values(test_uuid, test_uuid) is True
|
||||||
|
assert uuid_type.compare_values(test_uuid, uuid.uuid4()) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_uuid_comparison():
|
||||||
|
"""Test SQLiteUUID comparison functionality (moved from db_mocks.py)"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
|
||||||
|
# Test equal UUIDs
|
||||||
|
uuid1 = uuid.uuid4()
|
||||||
|
uuid2 = uuid.UUID(str(uuid1))
|
||||||
|
assert uuid_type.compare_values(uuid1, uuid2) is True
|
||||||
|
|
||||||
|
# Test UUID vs string
|
||||||
|
assert uuid_type.compare_values(uuid1, str(uuid1)) is True
|
||||||
|
|
||||||
|
# Test None comparisons
|
||||||
|
assert uuid_type.compare_values(None, None) is True
|
||||||
|
assert uuid_type.compare_values(uuid1, None) is False
|
||||||
|
assert uuid_type.compare_values(None, uuid1) is False
|
||||||
|
|
||||||
|
# Test different UUIDs
|
||||||
|
uuid3 = uuid.uuid4()
|
||||||
|
assert uuid_type.compare_values(uuid1, uuid3) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_uuid_binding():
|
||||||
|
"""Test SQLiteUUID binding parameter handling (moved from db_mocks.py)"""
|
||||||
|
uuid_type = SQLiteUUID()
|
||||||
|
|
||||||
|
# Test UUID object binding
|
||||||
|
uuid_obj = uuid.uuid4()
|
||||||
|
assert uuid_type.process_bind_param(uuid_obj, None) == str(uuid_obj)
|
||||||
|
|
||||||
|
# Test valid UUID string binding
|
||||||
|
uuid_str = str(uuid.uuid4())
|
||||||
|
assert uuid_type.process_bind_param(uuid_str, None) == uuid_str
|
||||||
|
|
||||||
|
# Test None handling
|
||||||
|
assert uuid_type.process_bind_param(None, None) is None
|
||||||
|
|
||||||
|
# Test invalid UUID string
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
uuid_type.process_bind_param("invalid-uuid", None)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test UUID Column Type Configuration ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_uuid_column_type_default():
|
||||||
|
"""Test UUID_COLUMN_TYPE uses SQLiteUUID in test environment"""
|
||||||
|
assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"MOCK_AUTH": "false"})
|
||||||
|
def test_uuid_column_type_postgres():
|
||||||
|
"""Test UUID_COLUMN_TYPE uses Postgres UUID when MOCK_AUTH=false"""
|
||||||
|
# Need to re-import to get the patched environment
|
||||||
|
from importlib import reload
|
||||||
|
|
||||||
|
from app import models
|
||||||
|
|
||||||
|
reload(models.db)
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
|
||||||
|
|
||||||
|
from app.models.db import UUID_COLUMN_TYPE
|
||||||
|
|
||||||
|
assert isinstance(UUID_COLUMN_TYPE, PostgresUUID)
|
||||||
43
tests/routers/mocks.py
Normal file
43
tests/routers/mocks.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class MockScheduler:
|
||||||
|
"""Base mock APScheduler instance"""
|
||||||
|
|
||||||
|
running = True
|
||||||
|
start = Mock()
|
||||||
|
shutdown = Mock()
|
||||||
|
add_job = Mock()
|
||||||
|
remove_job = Mock()
|
||||||
|
get_job = Mock(return_value=None)
|
||||||
|
|
||||||
|
def __init__(self, running=True):
|
||||||
|
self.running = running
|
||||||
|
|
||||||
|
|
||||||
|
def create_trigger_mock(triggered_ref: dict) -> callable:
|
||||||
|
"""Create a mock trigger function that updates a reference when called"""
|
||||||
|
|
||||||
|
def trigger_mock():
|
||||||
|
triggered_ref["value"] = True
|
||||||
|
|
||||||
|
return trigger_mock
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_get_scheduler(
|
||||||
|
request: Request, scheduler_class=MockScheduler, running=True, **kwargs
|
||||||
|
) -> StreamScheduler:
|
||||||
|
"""Mock dependency for get_scheduler with customization options"""
|
||||||
|
scheduler = StreamScheduler()
|
||||||
|
mock_scheduler = scheduler_class(running=running)
|
||||||
|
|
||||||
|
# Apply any additional attributes/methods
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(mock_scheduler, key, value)
|
||||||
|
|
||||||
|
scheduler.scheduler = mock_scheduler
|
||||||
|
return scheduler
|
||||||
File diff suppressed because it is too large
Load Diff
461
tests/routers/test_groups.py
Normal file
461
tests/routers/test_groups.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
# Import mocks and fixtures
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import (
|
||||||
|
MockChannelDB,
|
||||||
|
MockGroup,
|
||||||
|
create_mock_priorities_and_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Test Cases For Group Creation ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_group_success(db_session: Session, admin_user_client):
|
||||||
|
group_data = {"name": "Test Group", "sort_order": 1}
|
||||||
|
response = admin_user_client.post("/groups/", json=group_data)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Test Group"
|
||||||
|
assert data["sort_order"] == 1
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
assert "updated_at" in data
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_group = (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first()
|
||||||
|
)
|
||||||
|
assert db_group is not None
|
||||||
|
assert db_group.sort_order == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_group_duplicate(db_session: Session, admin_user_client):
|
||||||
|
# Create initial group
|
||||||
|
initial_group = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Duplicate Group",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(initial_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Attempt to create duplicate
|
||||||
|
response = admin_user_client.post(
|
||||||
|
"/groups/", json={"name": "Duplicate Group", "sort_order": 2}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
assert "already exists" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_group_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client
|
||||||
|
):
|
||||||
|
response = non_admin_user_client.post(
|
||||||
|
"/groups/", json={"name": "Forbidden Group", "sort_order": 1}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Get Group ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_group_success(db_session: Session, admin_user_client):
|
||||||
|
# Create a group first
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Get Me Group",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/groups/{test_group.id}")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == str(test_group.id)
|
||||||
|
assert data["name"] == "Get Me Group"
|
||||||
|
assert data["sort_order"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_group_not_found(db_session: Session, admin_user_client):
|
||||||
|
random_uuid = uuid.uuid4()
|
||||||
|
response = admin_user_client.get(f"/groups/{random_uuid}")
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Group not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Update Group ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_success(db_session: Session, admin_user_client):
|
||||||
|
# Create initial group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Update Me",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
update_data = {"name": "Updated Name", "sort_order": 2}
|
||||||
|
response = admin_user_client.put(f"/groups/{group_id}", json=update_data)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Updated Name"
|
||||||
|
assert data["sort_order"] == 2
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||||
|
assert db_group.name == "Updated Name"
|
||||||
|
assert db_group.sort_order == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_conflict(db_session: Session, admin_user_client):
|
||||||
|
# Create two groups
|
||||||
|
group1 = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Group One",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
group2 = MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Group Two",
|
||||||
|
sort_order=2,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add_all([group1, group2])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Try to rename group2 to conflict with group1
|
||||||
|
response = admin_user_client.put(f"/groups/{group2.id}", json={"name": "Group One"})
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
assert "already exists" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_not_found(db_session: Session, admin_user_client):
|
||||||
|
random_uuid = uuid.uuid4()
|
||||||
|
response = admin_user_client.put(
|
||||||
|
f"/groups/{random_uuid}", json={"name": "Non-existent"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Group not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client, admin_user_client
|
||||||
|
):
|
||||||
|
# Create group with admin
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Admin Created",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Attempt update with non-admin
|
||||||
|
response = non_admin_user_client.put(
|
||||||
|
f"/groups/{group_id}", json={"name": "Non-Admin Update"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Delete Group ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_groups_success(db_session, admin_user_client):
|
||||||
|
"""Test reset groups endpoint"""
|
||||||
|
# Create test data
|
||||||
|
group1_id = create_mock_priorities_and_group(db_session, [], "Group A")
|
||||||
|
group2_id = create_mock_priorities_and_group(db_session, [], "Group B")
|
||||||
|
|
||||||
|
# Add channel to group2
|
||||||
|
channel_data = [
|
||||||
|
{
|
||||||
|
"group-title": "Group A",
|
||||||
|
"tvg_id": "channel1.tv",
|
||||||
|
"name": "Channel One",
|
||||||
|
"url": ["http://test.com", "http://example.com"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
admin_user_client.post("/channels/bulk-upload", json=channel_data)
|
||||||
|
|
||||||
|
# Reset groups
|
||||||
|
response = admin_user_client.delete("/groups")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["deleted"] == 1 # Only group2 should be deleted
|
||||||
|
assert response.json()["skipped"] == 1 # group1 has channels
|
||||||
|
|
||||||
|
# Verify group2 deleted, group1 remains
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group1_id).first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
assert db_session.query(MockGroup).filter(MockGroup.id == group2_id).first() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_groups_forbidden_for_non_admin(db_session, non_admin_user_client):
|
||||||
|
"""Test reset groups requires admin role"""
|
||||||
|
response = non_admin_user_client.delete("/groups")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_success(db_session: Session, admin_user_client):
|
||||||
|
# Create group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Delete Me",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Verify exists before delete
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
response = admin_user_client.delete(f"/groups/{group_id}")
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# Verify deleted
|
||||||
|
assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_with_channels_fails(db_session: Session, admin_user_client):
|
||||||
|
# Create group with channel
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Group With Channels",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
|
||||||
|
# Create channel in this group
|
||||||
|
test_channel = MockChannelDB(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
tvg_id="channel1.tv",
|
||||||
|
name="Channel 1",
|
||||||
|
group_id=group_id,
|
||||||
|
tvg_name="Channel1",
|
||||||
|
tvg_logo="logo.png",
|
||||||
|
)
|
||||||
|
db_session.add(test_channel)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.delete(f"/groups/{group_id}")
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert "existing channels" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify group still exists
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_not_found(db_session: Session, admin_user_client):
|
||||||
|
random_uuid = uuid.uuid4()
|
||||||
|
response = admin_user_client.delete(f"/groups/{random_uuid}")
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Group not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_group_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client, admin_user_client
|
||||||
|
):
|
||||||
|
# Create group with admin
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Admin Created",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Attempt delete with non-admin
|
||||||
|
response = non_admin_user_client.delete(f"/groups/{group_id}")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify group still exists
|
||||||
|
assert (
|
||||||
|
db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For List Groups ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_groups_empty(db_session: Session, admin_user_client):
|
||||||
|
response = admin_user_client.get("/groups/")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_groups_with_data(db_session: Session, admin_user_client):
|
||||||
|
# Create some groups
|
||||||
|
groups = [
|
||||||
|
MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=f"Group {i}",
|
||||||
|
sort_order=i,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
db_session.add_all(groups)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.get("/groups/")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) == 3
|
||||||
|
assert data[0]["sort_order"] == 0 # Should be sorted by sort_order
|
||||||
|
assert data[1]["sort_order"] == 1
|
||||||
|
assert data[2]["sort_order"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Sort Order Updates ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_sort_order_success(db_session: Session, admin_user_client):
|
||||||
|
# Create group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Sort Me",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5})
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["sort_order"] == 5
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||||
|
assert db_group.sort_order == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_group_sort_order_not_found(db_session: Session, admin_user_client):
|
||||||
|
"""Test that updating sort order for non-existent group returns 404"""
|
||||||
|
random_uuid = uuid.uuid4()
|
||||||
|
response = admin_user_client.put(
|
||||||
|
f"/groups/{random_uuid}/sort", json={"sort_order": 5}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Group not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client):
|
||||||
|
# Create groups
|
||||||
|
groups = [
|
||||||
|
MockGroup(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=f"Group {i}",
|
||||||
|
sort_order=i,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
print(groups)
|
||||||
|
db_session.add_all(groups)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Bulk update sort orders (reverse order)
|
||||||
|
bulk_data = {
|
||||||
|
"groups": [
|
||||||
|
{"group_id": str(groups[0].id), "sort_order": 2},
|
||||||
|
{"group_id": str(groups[1].id), "sort_order": 1},
|
||||||
|
{"group_id": str(groups[2].id), "sort_order": 0},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
response = admin_user_client.post("/groups/reorder", json=bulk_data)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) == 3
|
||||||
|
|
||||||
|
# Create a dictionary for easy lookup of returned group data by ID
|
||||||
|
returned_groups_map = {item["id"]: item for item in data}
|
||||||
|
|
||||||
|
# Verify each group has its expected new sort_order
|
||||||
|
assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2
|
||||||
|
assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1
|
||||||
|
assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all()
|
||||||
|
assert db_groups[0].sort_order == 2
|
||||||
|
assert db_groups[1].sort_order == 1
|
||||||
|
assert db_groups[2].sort_order == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client):
|
||||||
|
# Create one group
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Valid Group",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(test_group)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Try to update with invalid group
|
||||||
|
bulk_data = {
|
||||||
|
"groups": [
|
||||||
|
{"group_id": str(group_id), "sort_order": 2},
|
||||||
|
{"group_id": str(uuid.uuid4()), "sort_order": 1}, # Invalid group
|
||||||
|
]
|
||||||
|
}
|
||||||
|
response = admin_user_client.post("/groups/reorder", json=bulk_data)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify original sort order unchanged
|
||||||
|
db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
|
||||||
|
assert db_group.sort_order == 1
|
||||||
261
tests/routers/test_playlist.py
Normal file
261
tests/routers/test_playlist.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
|
||||||
|
# Import the router we're testing
|
||||||
|
from app.routers.playlist import (
|
||||||
|
ProcessStatus,
|
||||||
|
ValidationProcessResponse,
|
||||||
|
ValidationResultResponse,
|
||||||
|
router,
|
||||||
|
validation_processes,
|
||||||
|
)
|
||||||
|
from app.utils.database import get_db
|
||||||
|
|
||||||
|
# Import mocks and fixtures
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import MockChannelDB
|
||||||
|
|
||||||
|
# --- Test Fixtures ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_stream_manager():
|
||||||
|
with patch("app.routers.playlist.StreamManager") as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Stream Validation ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_stream_validation_success(
|
||||||
|
db_session: Session, admin_user_client, mock_stream_manager
|
||||||
|
):
|
||||||
|
"""Test starting a stream validation process"""
|
||||||
|
mock_instance = mock_stream_manager.return_value
|
||||||
|
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
|
||||||
|
|
||||||
|
response = admin_user_client.post(
|
||||||
|
"/playlist/validate-streams", json={"channel_id": "test-channel"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
|
data = response.json()
|
||||||
|
assert "process_id" in data
|
||||||
|
assert data["status"] == ProcessStatus.PENDING
|
||||||
|
assert data["message"] == "Validation process started"
|
||||||
|
|
||||||
|
# Verify process was added to tracking
|
||||||
|
process_id = data["process_id"]
|
||||||
|
assert process_id in validation_processes
|
||||||
|
# In test environment, background tasks run synchronously so status may be COMPLETED
|
||||||
|
assert validation_processes[process_id]["status"] in [
|
||||||
|
ProcessStatus.PENDING,
|
||||||
|
ProcessStatus.COMPLETED,
|
||||||
|
]
|
||||||
|
assert validation_processes[process_id]["channel_id"] == "test-channel"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_pending(db_session: Session, admin_user_client):
|
||||||
|
"""Test checking status of pending validation"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.PENDING,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["process_id"] == process_id
|
||||||
|
assert data["status"] == ProcessStatus.PENDING
|
||||||
|
assert data["working_streams"] is None
|
||||||
|
assert data["error"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_completed(db_session: Session, admin_user_client):
|
||||||
|
"""Test checking status of completed validation"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.COMPLETED,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
"result": {
|
||||||
|
"working_streams": [
|
||||||
|
{"channel_id": "test-channel", "stream_url": "http://valid.stream.url"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["process_id"] == process_id
|
||||||
|
assert data["status"] == ProcessStatus.COMPLETED
|
||||||
|
assert len(data["working_streams"]) == 1
|
||||||
|
assert data["working_streams"][0]["channel_id"] == "test-channel"
|
||||||
|
assert data["working_streams"][0]["stream_url"] == "http://valid.stream.url"
|
||||||
|
assert data["error"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_completed_with_error(
|
||||||
|
db_session: Session, admin_user_client
|
||||||
|
):
|
||||||
|
"""Test checking status of completed validation with error"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.COMPLETED,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
"error": "No working streams found for channel test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["process_id"] == process_id
|
||||||
|
assert data["status"] == ProcessStatus.COMPLETED
|
||||||
|
assert data["working_streams"] is None
|
||||||
|
assert data["error"] == "No working streams found for channel test-channel"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_failed(db_session: Session, admin_user_client):
|
||||||
|
"""Test checking status of failed validation"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.FAILED,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
"error": "Validation error occurred",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["process_id"] == process_id
|
||||||
|
assert data["status"] == ProcessStatus.FAILED
|
||||||
|
assert data["working_streams"] is None
|
||||||
|
assert data["error"] == "Validation error occurred"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_validation_status_not_found(db_session: Session, admin_user_client):
|
||||||
|
"""Test checking status of non-existent process"""
|
||||||
|
random_uuid = str(uuid.uuid4())
|
||||||
|
response = admin_user_client.get(f"/playlist/validate-streams/{random_uuid}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Process not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_success(mock_stream_manager, db_session):
|
||||||
|
"""Test the background validation task success case"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.PENDING,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_instance = mock_stream_manager.return_value
|
||||||
|
mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
|
||||||
|
|
||||||
|
from app.routers.playlist import run_stream_validation
|
||||||
|
|
||||||
|
run_stream_validation(process_id, "test-channel", db_session)
|
||||||
|
|
||||||
|
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
|
||||||
|
assert len(validation_processes[process_id]["result"]["working_streams"]) == 1
|
||||||
|
assert (
|
||||||
|
validation_processes[process_id]["result"]["working_streams"][0].channel_id
|
||||||
|
== "test-channel"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
validation_processes[process_id]["result"]["working_streams"][0].stream_url
|
||||||
|
== "http://valid.stream.url"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_failure(mock_stream_manager, db_session):
|
||||||
|
"""Test the background validation task failure case"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.PENDING,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_instance = mock_stream_manager.return_value
|
||||||
|
mock_instance.validate_and_select_stream.return_value = None
|
||||||
|
|
||||||
|
from app.routers.playlist import run_stream_validation
|
||||||
|
|
||||||
|
run_stream_validation(process_id, "test-channel", db_session)
|
||||||
|
|
||||||
|
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
|
||||||
|
assert "error" in validation_processes[process_id]
|
||||||
|
assert "No working streams found" in validation_processes[process_id]["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_exception(mock_stream_manager, db_session):
|
||||||
|
"""Test the background validation task exception case"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {
|
||||||
|
"status": ProcessStatus.PENDING,
|
||||||
|
"channel_id": "test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_instance = mock_stream_manager.return_value
|
||||||
|
mock_instance.validate_and_select_stream.side_effect = Exception("Test error")
|
||||||
|
|
||||||
|
from app.routers.playlist import run_stream_validation
|
||||||
|
|
||||||
|
run_stream_validation(process_id, "test-channel", db_session)
|
||||||
|
|
||||||
|
assert validation_processes[process_id]["status"] == ProcessStatus.FAILED
|
||||||
|
assert "error" in validation_processes[process_id]
|
||||||
|
assert "Test error" in validation_processes[process_id]["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_stream_validation_no_channel_id(
|
||||||
|
db_session: Session, admin_user_client, mock_stream_manager
|
||||||
|
):
|
||||||
|
"""Test starting validation without channel_id"""
|
||||||
|
response = admin_user_client.post("/playlist/validate-streams", json={})
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
|
data = response.json()
|
||||||
|
assert "process_id" in data
|
||||||
|
assert data["status"] == ProcessStatus.PENDING
|
||||||
|
|
||||||
|
# Verify process was added to tracking
|
||||||
|
process_id = data["process_id"]
|
||||||
|
assert process_id in validation_processes
|
||||||
|
assert validation_processes[process_id]["status"] in [
|
||||||
|
ProcessStatus.PENDING,
|
||||||
|
ProcessStatus.COMPLETED,
|
||||||
|
]
|
||||||
|
assert validation_processes[process_id]["channel_id"] is None
|
||||||
|
assert "not yet implemented" in validation_processes[process_id].get("error", "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session):
|
||||||
|
"""Test background validation without channel_id"""
|
||||||
|
process_id = str(uuid.uuid4())
|
||||||
|
validation_processes[process_id] = {"status": ProcessStatus.PENDING}
|
||||||
|
|
||||||
|
from app.routers.playlist import run_stream_validation
|
||||||
|
|
||||||
|
run_stream_validation(process_id, None, db_session)
|
||||||
|
|
||||||
|
assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED
|
||||||
|
assert "error" in validation_processes[process_id]
|
||||||
|
assert "not yet implemented" in validation_processes[process_id]["error"]
|
||||||
241
tests/routers/test_priorities.py
Normal file
241
tests/routers/test_priorities.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.routers.priorities import router as priorities_router
|
||||||
|
|
||||||
|
# Import fixtures and mocks
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import (
|
||||||
|
MockChannelDB,
|
||||||
|
MockChannelURL,
|
||||||
|
MockGroup,
|
||||||
|
MockPriority,
|
||||||
|
create_mock_priorities_and_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Test Cases For Priority Creation ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_priority_success(db_session: Session, admin_user_client):
|
||||||
|
priority_data = {"id": 100, "description": "Test Priority"}
|
||||||
|
response = admin_user_client.post("/priorities/", json=priority_data)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == 100
|
||||||
|
assert data["description"] == "Test Priority"
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
db_priority = db_session.get(MockPriority, 100)
|
||||||
|
assert db_priority is not None
|
||||||
|
assert db_priority.description == "Test Priority"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_priority_duplicate(db_session: Session, admin_user_client):
|
||||||
|
# Create initial priority
|
||||||
|
priority_data = {"id": 100, "description": "Original Priority"}
|
||||||
|
response1 = admin_user_client.post("/priorities/", json=priority_data)
|
||||||
|
assert response1.status_code == status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
# Attempt to create with same ID
|
||||||
|
response2 = admin_user_client.post("/priorities/", json=priority_data)
|
||||||
|
assert response2.status_code == status.HTTP_409_CONFLICT
|
||||||
|
assert "already exists" in response2.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_priority_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client
|
||||||
|
):
|
||||||
|
priority_data = {"id": 100, "description": "Test Priority"}
|
||||||
|
response = non_admin_user_client.post("/priorities/", json=priority_data)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For List Priorities ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_priorities_empty(db_session: Session, admin_user_client):
|
||||||
|
response = admin_user_client.get("/priorities/")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_priorities_with_data(db_session: Session, admin_user_client):
|
||||||
|
# Create some test priorities
|
||||||
|
priorities = [
|
||||||
|
MockPriority(id=100, description="High"),
|
||||||
|
MockPriority(id=200, description="Medium"),
|
||||||
|
MockPriority(id=300, description="Low"),
|
||||||
|
]
|
||||||
|
for priority in priorities:
|
||||||
|
db_session.add(priority)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.get("/priorities/")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) == 3
|
||||||
|
assert data[0]["id"] == 100
|
||||||
|
assert data[0]["description"] == "High"
|
||||||
|
assert data[1]["id"] == 200
|
||||||
|
assert data[1]["description"] == "Medium"
|
||||||
|
assert data[2]["id"] == 300
|
||||||
|
assert data[2]["description"] == "Low"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_priorities_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client
|
||||||
|
):
|
||||||
|
response = non_admin_user_client.get("/priorities/")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Get Priority ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_priority_success(db_session: Session, admin_user_client):
|
||||||
|
# Create a test priority
|
||||||
|
priority = MockPriority(id=100, description="Test Priority")
|
||||||
|
db_session.add(priority)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.get("/priorities/100")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == 100
|
||||||
|
assert data["description"] == "Test Priority"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_priority_not_found(db_session: Session, admin_user_client):
|
||||||
|
response = admin_user_client.get("/priorities/999")
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Priority not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_priority_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client
|
||||||
|
):
|
||||||
|
response = non_admin_user_client.get("/priorities/100")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Cases For Delete Priority ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_priorities_success(db_session, admin_user_client):
|
||||||
|
"""Test reset priorities endpoint"""
|
||||||
|
# Create test data
|
||||||
|
priorities = [(100, "High"), (200, "Medium"), (300, "Low")]
|
||||||
|
for id, desc in priorities:
|
||||||
|
db_session.add(MockPriority(id=id, description=desc))
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Create channel using priority 100
|
||||||
|
create_mock_priorities_and_group(db_session, [], "Test Group")
|
||||||
|
channel_data = [
|
||||||
|
{
|
||||||
|
"group-title": "Test Group",
|
||||||
|
"tvg_id": "test.tv",
|
||||||
|
"name": "Test Channel",
|
||||||
|
"urls": ["http://test.com"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
admin_user_client.post("/channels/bulk-upload", json=channel_data)
|
||||||
|
|
||||||
|
# Delete all priorities
|
||||||
|
response = admin_user_client.delete("/priorities")
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["deleted"] == 2 # Medium and Low priorities
|
||||||
|
assert response.json()["skipped"] == 1 # High priority is in use
|
||||||
|
|
||||||
|
# Verify only priority 100 remains
|
||||||
|
priorities = db_session.query(MockPriority).all()
|
||||||
|
assert len(priorities) == 1
|
||||||
|
assert priorities[0].id == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_priorities_forbidden_for_non_admin(db_session, non_admin_user_client):
|
||||||
|
"""Test reset priorities requires admin role"""
|
||||||
|
response = non_admin_user_client.delete("/priorities")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_success(db_session: Session, admin_user_client):
|
||||||
|
# Create a test priority
|
||||||
|
priority = MockPriority(id=100, description="To Delete")
|
||||||
|
db_session.add(priority)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.delete("/priorities/100")
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# Verify priority is gone from DB
|
||||||
|
db_priority = db_session.get(MockPriority, 100)
|
||||||
|
assert db_priority is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_not_found(db_session: Session, admin_user_client):
|
||||||
|
response = admin_user_client.delete("/priorities/999")
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "Priority not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_in_use(db_session: Session, admin_user_client):
|
||||||
|
# Create a priority and a channel URL using it
|
||||||
|
priority = MockPriority(id=100, description="In Use")
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
test_group = MockGroup(
|
||||||
|
id=group_id,
|
||||||
|
name="Group With Channels",
|
||||||
|
sort_order=1,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add_all([priority, test_group])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Create a channel first
|
||||||
|
channel = MockChannelDB(
|
||||||
|
name="Test Channel",
|
||||||
|
tvg_id="test.tv",
|
||||||
|
tvg_name="Test",
|
||||||
|
tvg_logo="test.png",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
db_session.add(channel)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Create URL associated with the channel and priority
|
||||||
|
channel_url = MockChannelURL(
|
||||||
|
url="http://test.com",
|
||||||
|
priority_id=100,
|
||||||
|
in_use=True,
|
||||||
|
channel_id=channel.id, # Add the channel_id
|
||||||
|
)
|
||||||
|
db_session.add(channel_url)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
response = admin_user_client.delete("/priorities/100")
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
assert "in use by channel URLs" in response.json()["detail"]
|
||||||
|
|
||||||
|
# Verify priority still exists
|
||||||
|
db_priority = db_session.get(MockPriority, 100)
|
||||||
|
assert db_priority is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_priority_forbidden_for_non_admin(
|
||||||
|
db_session: Session, non_admin_user_client
|
||||||
|
):
|
||||||
|
response = non_admin_user_client.delete("/priorities/100")
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
287
tests/routers/test_scheduler.py
Normal file
287
tests/routers/test_scheduler.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
|
from app.iptv.scheduler import StreamScheduler
|
||||||
|
from app.routers.scheduler import get_scheduler
|
||||||
|
from app.routers.scheduler import router as scheduler_router
|
||||||
|
from app.utils.database import get_db
|
||||||
|
from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler
|
||||||
|
from tests.utils.auth_test_fixtures import (
|
||||||
|
admin_user_client,
|
||||||
|
db_session,
|
||||||
|
non_admin_user_client,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import mock_get_db
|
||||||
|
|
||||||
|
# Scheduler Health Check Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_success(admin_user_client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test case for successful scheduler health check when accessed by an admin user.
|
||||||
|
It mocks the scheduler to be running and have a next scheduled job.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Define the expected next run time for the scheduler job.
|
||||||
|
next_run = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Create a mock job object that simulates an APScheduler job.
|
||||||
|
mock_job = Mock()
|
||||||
|
mock_job.next_run_time = next_run
|
||||||
|
|
||||||
|
# Mock the `get_job` method to return our mock_job for a specific ID.
|
||||||
|
def mock_get_job(job_id):
|
||||||
|
if job_id == "daily_stream_validation":
|
||||||
|
return mock_job
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create a custom mock for `get_scheduler` dependency.
|
||||||
|
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
return await mock_get_scheduler(
|
||||||
|
request,
|
||||||
|
running=True,
|
||||||
|
get_job=Mock(side_effect=mock_get_job), # Use the custom mock_get_job
|
||||||
|
)
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override dependencies for the test.
|
||||||
|
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||||
|
custom_mock_get_scheduler
|
||||||
|
)
|
||||||
|
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to the scheduler health endpoint.
|
||||||
|
response = admin_user_client.get("/scheduler/health")
|
||||||
|
|
||||||
|
# Assert the response status code and content.
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "running"
|
||||||
|
assert data["next_run"] == str(next_run)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_stopped(admin_user_client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test case for scheduler health check when the scheduler is in a stopped state.
|
||||||
|
Ensures the API returns the correct status and no next run time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a custom mock for `get_scheduler` dependency,
|
||||||
|
# simulating a stopped scheduler.
|
||||||
|
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
return await mock_get_scheduler(
|
||||||
|
request,
|
||||||
|
running=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override dependencies for the test.
|
||||||
|
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||||
|
custom_mock_get_scheduler
|
||||||
|
)
|
||||||
|
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to the scheduler health endpoint.
|
||||||
|
response = admin_user_client.get("/scheduler/health")
|
||||||
|
|
||||||
|
# Assert the response status code and content.
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "stopped"
|
||||||
|
assert data["next_run"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test case to ensure that non-admin users are forbidden from accessing
|
||||||
|
the scheduler health endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a custom mock for `get_scheduler` dependency.
|
||||||
|
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
return await mock_get_scheduler(
|
||||||
|
request,
|
||||||
|
running=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
non_admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override dependencies for the test.
|
||||||
|
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||||
|
custom_mock_get_scheduler
|
||||||
|
)
|
||||||
|
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to the scheduler health endpoint.
|
||||||
|
response = non_admin_user_client.get("/scheduler/health")
|
||||||
|
|
||||||
|
# Assert the response status code and error detail.
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_health_check_exception(admin_user_client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test case for handling exceptions during the scheduler health check.
|
||||||
|
Ensures the API returns a 500 Internal Server Error when an exception occurs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a custom mock for `get_scheduler` dependency that raises an exception.
|
||||||
|
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
return await mock_get_scheduler(
|
||||||
|
request, running=True, get_job=Mock(side_effect=Exception("Test exception"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override dependencies for the test.
|
||||||
|
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||||
|
custom_mock_get_scheduler
|
||||||
|
)
|
||||||
|
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to the scheduler health endpoint.
|
||||||
|
response = admin_user_client.get("/scheduler/health")
|
||||||
|
|
||||||
|
# Assert the response status code and error detail.
|
||||||
|
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert "Failed to check scheduler health" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# Scheduler Trigger Tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_trigger_validation_success(admin_user_client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test case for successful manual triggering
|
||||||
|
of stream validation by an admin user.
|
||||||
|
It verifies that the trigger method is called and
|
||||||
|
the API returns a 202 Accepted status.
|
||||||
|
"""
|
||||||
|
# Use a mutable reference to check if the trigger method was called.
|
||||||
|
triggered_ref = {"value": False}
|
||||||
|
|
||||||
|
# Initialize a custom mock scheduler.
|
||||||
|
custom_scheduler = MockScheduler(running=True)
|
||||||
|
custom_scheduler.get_job = Mock(return_value=None)
|
||||||
|
|
||||||
|
# Create a custom mock for `get_scheduler` dependency,
|
||||||
|
# overriding `trigger_manual_validation`.
|
||||||
|
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
scheduler = await mock_get_scheduler(
|
||||||
|
request,
|
||||||
|
running=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the actual trigger method with our mock to track calls.
|
||||||
|
scheduler.trigger_manual_validation = create_trigger_mock(
|
||||||
|
triggered_ref=triggered_ref
|
||||||
|
)
|
||||||
|
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override dependencies for the test.
|
||||||
|
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||||
|
custom_mock_get_scheduler
|
||||||
|
)
|
||||||
|
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to trigger stream validation.
|
||||||
|
response = admin_user_client.post("/scheduler/trigger")
|
||||||
|
|
||||||
|
# Assert the response status code, message, and that the trigger was called.
|
||||||
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
|
assert response.json()["message"] == "Stream validation triggered"
|
||||||
|
assert triggered_ref["value"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test case to ensure that non-admin users are
|
||||||
|
forbidden from manually triggering stream validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a custom mock for `get_scheduler` dependency.
|
||||||
|
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||||
|
return await mock_get_scheduler(
|
||||||
|
request,
|
||||||
|
running=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
non_admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override dependencies for the test.
|
||||||
|
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||||
|
custom_mock_get_scheduler
|
||||||
|
)
|
||||||
|
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to trigger stream validation.
|
||||||
|
response = non_admin_user_client.post("/scheduler/trigger")
|
||||||
|
|
||||||
|
# Assert the response status code and error detail.
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
assert "required roles" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_initialized_in_app_state(admin_user_client):
|
||||||
|
"""
|
||||||
|
Test case for when the scheduler is initialized in the app state but its internal
|
||||||
|
scheduler attribute is not set, which should still allow health check.
|
||||||
|
"""
|
||||||
|
scheduler = StreamScheduler()
|
||||||
|
|
||||||
|
# Set the scheduler instance in the test client's app state.
|
||||||
|
admin_user_client.app.state.scheduler = scheduler
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override only get_db, allowing the real get_scheduler to be tested.
|
||||||
|
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to the scheduler health endpoint.
|
||||||
|
response = admin_user_client.get("/scheduler/health")
|
||||||
|
|
||||||
|
# Assert the response status code.
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_not_initialized_in_app_state(admin_user_client):
|
||||||
|
"""
|
||||||
|
Test case for when the scheduler is not properly initialized in the app state.
|
||||||
|
This simulates a scenario where the internal scheduler attribute is missing,
|
||||||
|
leading to a 500 Internal Server Error on health check.
|
||||||
|
"""
|
||||||
|
scheduler = StreamScheduler()
|
||||||
|
del (
|
||||||
|
scheduler.scheduler
|
||||||
|
) # Simulate uninitialized scheduler by deleting the attribute
|
||||||
|
|
||||||
|
# Set the scheduler instance in the test client's app state.
|
||||||
|
admin_user_client.app.state.scheduler = scheduler
|
||||||
|
|
||||||
|
# Include the scheduler router in the test application.
|
||||||
|
admin_user_client.app.include_router(scheduler_router)
|
||||||
|
|
||||||
|
# Override only get_db, allowing the real get_scheduler to be tested.
|
||||||
|
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
|
||||||
|
# Make the request to the scheduler health endpoint.
|
||||||
|
response = admin_user_client.get("/scheduler/health")
|
||||||
|
|
||||||
|
# Assert the response status code and error detail.
|
||||||
|
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert "Scheduler not initialized" in response.json()["detail"]
|
||||||
@@ -16,7 +16,7 @@ def test_root_endpoint(client):
|
|||||||
"""Test root endpoint returns expected message"""
|
"""Test root endpoint returns expected message"""
|
||||||
response = client.get("/")
|
response = client.get("/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"message": "IPTV Updater API"}
|
assert response.json() == {"message": "IPTV Manager API"}
|
||||||
|
|
||||||
|
|
||||||
def test_openapi_schema_generation(client):
|
def test_openapi_schema_generation(client):
|
||||||
|
|||||||
82
tests/utils/auth_test_fixtures.py
Normal file
82
tests/utils/auth_test_fixtures.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth.dependencies import get_current_user
|
||||||
|
from app.models.auth import CognitoUser
|
||||||
|
from app.routers.channels import router as channels_router
|
||||||
|
from app.routers.groups import router as groups_router
|
||||||
|
from app.routers.playlist import router as playlist_router
|
||||||
|
from app.routers.priorities import router as priorities_router
|
||||||
|
from app.utils.database import get_db
|
||||||
|
from tests.utils.db_mocks import (
|
||||||
|
MockBase,
|
||||||
|
MockChannelDB,
|
||||||
|
MockChannelURL,
|
||||||
|
MockPriority,
|
||||||
|
engine_mock,
|
||||||
|
mock_get_db,
|
||||||
|
)
|
||||||
|
from tests.utils.db_mocks import session_mock as TestingSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def mock_get_current_user_admin():
|
||||||
|
return CognitoUser(
|
||||||
|
username="testadmin",
|
||||||
|
email="testadmin@example.com",
|
||||||
|
roles=["admin"],
|
||||||
|
user_status="CONFIRMED",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_get_current_user_non_admin():
|
||||||
|
return CognitoUser(
|
||||||
|
username="testuser",
|
||||||
|
email="testuser@example.com",
|
||||||
|
roles=["user"], # Or any role other than admin
|
||||||
|
user_status="CONFIRMED",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def db_session():
|
||||||
|
# Create tables for each test function
|
||||||
|
MockBase.metadata.create_all(bind=engine_mock)
|
||||||
|
db = TestingSessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
# Drop tables after each test function
|
||||||
|
MockBase.metadata.drop_all(bind=engine_mock)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def admin_user_client(db_session: Session):
|
||||||
|
"""Yields a TestClient configured with an admin user."""
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(channels_router)
|
||||||
|
test_app.include_router(priorities_router)
|
||||||
|
test_app.include_router(playlist_router)
|
||||||
|
test_app.include_router(groups_router)
|
||||||
|
test_app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
test_app.dependency_overrides[get_current_user] = mock_get_current_user_admin
|
||||||
|
with TestClient(test_app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def non_admin_user_client(db_session: Session):
|
||||||
|
"""Yields a TestClient configured with a non-admin user."""
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(channels_router)
|
||||||
|
test_app.include_router(priorities_router)
|
||||||
|
test_app.include_router(playlist_router)
|
||||||
|
test_app.include_router(groups_router)
|
||||||
|
test_app.dependency_overrides[get_db] = mock_get_db
|
||||||
|
test_app.dependency_overrides[get_current_user] = mock_get_current_user_non_admin
|
||||||
|
with TestClient(test_app) as test_client:
|
||||||
|
yield test_client
|
||||||
@@ -4,41 +4,25 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
TEXT,
|
|
||||||
Boolean,
|
Boolean,
|
||||||
Column,
|
Column,
|
||||||
DateTime,
|
DateTime,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
String,
|
String,
|
||||||
TypeDecorator,
|
|
||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
create_engine,
|
create_engine,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
|
||||||
|
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
|
||||||
|
|
||||||
# Create a mock-specific Base class for testing
|
# Create a mock-specific Base class for testing
|
||||||
MockBase = declarative_base()
|
MockBase = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class SQLiteUUID(TypeDecorator):
|
|
||||||
"""Enables UUID support for SQLite."""
|
|
||||||
|
|
||||||
impl = TEXT
|
|
||||||
cache_ok = True
|
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
|
||||||
if value is None:
|
|
||||||
return value
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
def process_result_value(self, value, dialect):
|
|
||||||
if value is None:
|
|
||||||
return value
|
|
||||||
return uuid.UUID(value)
|
|
||||||
|
|
||||||
|
|
||||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||||
class MockPriority(MockBase):
|
class MockPriority(MockBase):
|
||||||
__tablename__ = "priorities"
|
__tablename__ = "priorities"
|
||||||
@@ -46,16 +30,28 @@ class MockPriority(MockBase):
|
|||||||
description = Column(String, nullable=False)
|
description = Column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MockGroup(MockBase):
|
||||||
|
__tablename__ = "groups"
|
||||||
|
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||||
|
name = Column(String, nullable=False, unique=True)
|
||||||
|
sort_order = Column(Integer, nullable=False, default=0)
|
||||||
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at = Column(
|
||||||
|
DateTime,
|
||||||
|
default=lambda: datetime.now(timezone.utc),
|
||||||
|
onupdate=lambda: datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
channels = relationship("MockChannelDB", back_populates="group")
|
||||||
|
|
||||||
|
|
||||||
class MockChannelDB(MockBase):
|
class MockChannelDB(MockBase):
|
||||||
__tablename__ = "channels"
|
__tablename__ = "channels"
|
||||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||||
tvg_id = Column(String, nullable=False)
|
tvg_id = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
group_title = Column(String, nullable=False)
|
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
|
||||||
tvg_name = Column(String)
|
tvg_name = Column(String)
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
|
||||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
|
||||||
)
|
|
||||||
tvg_logo = Column(String)
|
tvg_logo = Column(String)
|
||||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
updated_at = Column(
|
updated_at = Column(
|
||||||
@@ -63,13 +59,17 @@ class MockChannelDB(MockBase):
|
|||||||
default=lambda: datetime.now(timezone.utc),
|
default=lambda: datetime.now(timezone.utc),
|
||||||
onupdate=lambda: datetime.now(timezone.utc),
|
onupdate=lambda: datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
group = relationship("MockGroup", back_populates="channels")
|
||||||
|
urls = relationship(
|
||||||
|
"MockChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockChannelURL(MockBase):
|
class MockChannelURL(MockBase):
|
||||||
__tablename__ = "channels_urls"
|
__tablename__ = "channels_urls"
|
||||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||||
channel_id = Column(
|
channel_id = Column(
|
||||||
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
||||||
)
|
)
|
||||||
url = Column(String, nullable=False)
|
url = Column(String, nullable=False)
|
||||||
in_use = Column(Boolean, default=False, nullable=False)
|
in_use = Column(Boolean, default=False, nullable=False)
|
||||||
@@ -80,6 +80,32 @@ class MockChannelURL(MockBase):
|
|||||||
default=lambda: datetime.now(timezone.utc),
|
default=lambda: datetime.now(timezone.utc),
|
||||||
onupdate=lambda: datetime.now(timezone.utc),
|
onupdate=lambda: datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
channel = relationship("MockChannelDB", back_populates="urls")
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_priorities_and_group(db_session, priorities, group_name):
|
||||||
|
"""Create mock priorities and group for testing purposes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: SQLAlchemy session object
|
||||||
|
priorities: List of (id, description) tuples for priorities to create
|
||||||
|
group_name: Name for the new mock group
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UUID: The ID of the created group
|
||||||
|
"""
|
||||||
|
# Create priorities
|
||||||
|
priority_objects = [
|
||||||
|
MockPriority(id=priority_id, description=description)
|
||||||
|
for priority_id, description in priorities
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create group
|
||||||
|
group = MockGroup(name=group_name)
|
||||||
|
db_session.add_all(priority_objects + [group])
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
return group.id
|
||||||
|
|
||||||
|
|
||||||
# Create test engine
|
# Create test engine
|
||||||
|
|||||||
309
tests/utils/test_check_streams.py
Normal file
309
tests/utils/test_check_streams.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, Mock, mock_open, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||||
|
|
||||||
|
from app.utils.check_streams import StreamValidator, main
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def validator():
|
||||||
|
"""Create a StreamValidator instance for testing"""
|
||||||
|
return StreamValidator(timeout=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_init():
|
||||||
|
"""Test StreamValidator initialization with default and custom values"""
|
||||||
|
# Test with default user agent
|
||||||
|
validator = StreamValidator()
|
||||||
|
assert validator.timeout == 10
|
||||||
|
assert "Mozilla" in validator.session.headers["User-Agent"]
|
||||||
|
|
||||||
|
# Test with custom values
|
||||||
|
custom_agent = "CustomAgent/1.0"
|
||||||
|
validator = StreamValidator(timeout=5, user_agent=custom_agent)
|
||||||
|
assert validator.timeout == 5
|
||||||
|
assert validator.session.headers["User-Agent"] == custom_agent
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_valid_content_type(validator):
|
||||||
|
"""Test content type validation"""
|
||||||
|
valid_types = [
|
||||||
|
"video/mp4",
|
||||||
|
"video/mp2t",
|
||||||
|
"application/vnd.apple.mpegurl",
|
||||||
|
"application/dash+xml",
|
||||||
|
"video/webm",
|
||||||
|
"application/octet-stream",
|
||||||
|
"application/x-mpegURL",
|
||||||
|
"video/mp4; charset=utf-8", # Test with additional parameters
|
||||||
|
]
|
||||||
|
|
||||||
|
invalid_types = [
|
||||||
|
"text/html",
|
||||||
|
"application/json",
|
||||||
|
"image/jpeg",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
for content_type in valid_types:
|
||||||
|
assert validator._is_valid_content_type(content_type)
|
||||||
|
|
||||||
|
for content_type in invalid_types:
|
||||||
|
assert not validator._is_valid_content_type(content_type)
|
||||||
|
|
||||||
|
# Test None case explicitly
|
||||||
|
assert not validator._is_valid_content_type(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code,content_type,should_succeed",
|
||||||
|
[
|
||||||
|
(200, "video/mp4", True),
|
||||||
|
(206, "video/mp4", True), # Partial content
|
||||||
|
(404, "video/mp4", False),
|
||||||
|
(500, "video/mp4", False),
|
||||||
|
(200, "text/html", False),
|
||||||
|
(200, "application/json", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_validate_stream_response_handling(status_code, content_type, should_succeed):
|
||||||
|
"""Test stream validation with different response scenarios"""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = status_code
|
||||||
|
mock_response.headers = {"Content-Type": content_type}
|
||||||
|
mock_response.iter_content.return_value = iter([b"some content"])
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.return_value.__enter__.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert valid == should_succeed
|
||||||
|
mock_session.get.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_stream_connection_error():
|
||||||
|
"""Test stream validation with connection error"""
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.side_effect = ConnectionError("Connection failed")
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "Connection Error" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_stream_timeout():
|
||||||
|
"""Test stream validation with timeout"""
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.side_effect = Timeout("Request timed out")
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "timeout" in message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_stream_http_error():
|
||||||
|
"""Test stream validation with HTTP error"""
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.side_effect = HTTPError("HTTP Error occurred")
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "HTTP Error" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_stream_request_exception():
|
||||||
|
"""Test stream validation with general request exception"""
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.side_effect = RequestException("Request failed")
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "Request Exception" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_stream_content_read_error():
|
||||||
|
"""Test stream validation when content reading fails"""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.headers = {"Content-Type": "video/mp4"}
|
||||||
|
mock_response.iter_content.side_effect = ConnectionError("Read failed")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.return_value.__enter__.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "Connection failed during content read" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_stream_general_exception():
|
||||||
|
"""Test validate_stream with an unexpected exception"""
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.side_effect = Exception("Unexpected error")
|
||||||
|
|
||||||
|
with patch("requests.Session", return_value=mock_session):
|
||||||
|
validator = StreamValidator()
|
||||||
|
valid, message = validator.validate_stream("http://example.com/stream")
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "Validation error" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_playlist(validator, tmp_path):
|
||||||
|
"""Test playlist file parsing"""
|
||||||
|
playlist_content = """
|
||||||
|
#EXTM3U
|
||||||
|
#EXTINF:-1,Channel 1
|
||||||
|
http://example.com/stream1
|
||||||
|
#EXTINF:-1,Channel 2
|
||||||
|
http://example.com/stream2
|
||||||
|
|
||||||
|
http://example.com/stream3
|
||||||
|
"""
|
||||||
|
playlist_file = tmp_path / "test_playlist.m3u"
|
||||||
|
playlist_file.write_text(playlist_content)
|
||||||
|
|
||||||
|
urls = validator.parse_playlist(str(playlist_file))
|
||||||
|
assert len(urls) == 3
|
||||||
|
assert urls == [
|
||||||
|
"http://example.com/stream1",
|
||||||
|
"http://example.com/stream2",
|
||||||
|
"http://example.com/stream3",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_playlist_error(validator):
|
||||||
|
"""Test playlist parsing with non-existent file"""
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
validator.parse_playlist("nonexistent_file.m3u")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
|
||||||
|
@patch("app.utils.check_streams.StreamValidator")
|
||||||
|
def test_main_with_urls(mock_validator_class, mock_logging, tmp_path, capsys):
|
||||||
|
"""Test main function with direct URLs"""
|
||||||
|
# Setup mock validator
|
||||||
|
mock_validator = Mock()
|
||||||
|
mock_validator_class.return_value = mock_validator
|
||||||
|
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||||
|
|
||||||
|
# Setup test arguments
|
||||||
|
test_args = ["script", "http://example.com/stream1", "http://example.com/stream2"]
|
||||||
|
with patch("sys.argv", test_args):
|
||||||
|
main()
|
||||||
|
|
||||||
|
# Verify validator was called correctly
|
||||||
|
assert mock_validator.validate_stream.call_count == 2
|
||||||
|
mock_validator.validate_stream.assert_any_call("http://example.com/stream1")
|
||||||
|
mock_validator.validate_stream.assert_any_call("http://example.com/stream2")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
|
||||||
|
@patch("app.utils.check_streams.StreamValidator")
|
||||||
|
def test_main_with_playlist(mock_validator_class, mock_logging, tmp_path):
|
||||||
|
"""Test main function with a playlist file"""
|
||||||
|
# Create test playlist
|
||||||
|
playlist_content = "http://example.com/stream1\nhttp://example.com/stream2"
|
||||||
|
playlist_file = tmp_path / "test.m3u"
|
||||||
|
playlist_file.write_text(playlist_content)
|
||||||
|
|
||||||
|
# Setup mock validator
|
||||||
|
mock_validator = Mock()
|
||||||
|
mock_validator_class.return_value = mock_validator
|
||||||
|
mock_validator.parse_playlist.return_value = [
|
||||||
|
"http://example.com/stream1",
|
||||||
|
"http://example.com/stream2",
|
||||||
|
]
|
||||||
|
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||||
|
|
||||||
|
# Setup test arguments
|
||||||
|
test_args = ["script", str(playlist_file)]
|
||||||
|
with patch("sys.argv", test_args):
|
||||||
|
main()
|
||||||
|
|
||||||
|
# Verify validator was called correctly
|
||||||
|
mock_validator.parse_playlist.assert_called_once_with(str(playlist_file))
|
||||||
|
assert mock_validator.validate_stream.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
|
||||||
|
@patch("app.utils.check_streams.StreamValidator")
|
||||||
|
def test_main_with_dead_streams(mock_validator_class, mock_logging, tmp_path):
|
||||||
|
"""Test main function handling dead streams"""
|
||||||
|
# Setup mock validator
|
||||||
|
mock_validator = Mock()
|
||||||
|
mock_validator_class.return_value = mock_validator
|
||||||
|
mock_validator.validate_stream.return_value = (False, "Stream is dead")
|
||||||
|
|
||||||
|
# Setup test arguments
|
||||||
|
test_args = ["script", "http://example.com/dead1", "http://example.com/dead2"]
|
||||||
|
|
||||||
|
# Mock file operations
|
||||||
|
mock_file = mock_open()
|
||||||
|
with patch("sys.argv", test_args), patch("builtins.open", mock_file):
|
||||||
|
main()
|
||||||
|
|
||||||
|
# Verify dead streams were written to file
|
||||||
|
mock_file().write.assert_called_once_with(
|
||||||
|
"http://example.com/dead1\nhttp://example.com/dead2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("app.utils.check_streams.logging")
|
||||||
|
@patch("app.utils.check_streams.StreamValidator")
|
||||||
|
@patch("os.path.isfile")
|
||||||
|
def test_main_with_playlist_error(
|
||||||
|
mock_isfile, mock_validator_class, mock_logging, tmp_path
|
||||||
|
):
|
||||||
|
"""Test main function handling playlist parsing errors"""
|
||||||
|
# Setup mock validator
|
||||||
|
mock_validator = Mock()
|
||||||
|
mock_validator_class.return_value = mock_validator
|
||||||
|
|
||||||
|
# Configure mock validator behavior
|
||||||
|
error_msg = "Failed to parse playlist"
|
||||||
|
mock_validator.parse_playlist.side_effect = [
|
||||||
|
Exception(error_msg), # First call fails
|
||||||
|
["http://example.com/stream1"], # Second call succeeds
|
||||||
|
]
|
||||||
|
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||||
|
|
||||||
|
# Configure isfile mock to return True for our test files
|
||||||
|
mock_isfile.side_effect = lambda x: x in ["/invalid.m3u", "/valid.m3u"]
|
||||||
|
|
||||||
|
# Setup test arguments
|
||||||
|
test_args = ["script", "/invalid.m3u", "/valid.m3u"]
|
||||||
|
with patch("sys.argv", test_args):
|
||||||
|
main()
|
||||||
|
|
||||||
|
# Verify error was logged correctly
|
||||||
|
mock_logging.error.assert_called_with(
|
||||||
|
"Failed to process file /invalid.m3u: Failed to parse playlist"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify processing continued with valid playlist
|
||||||
|
mock_validator.parse_playlist.assert_called_with("/valid.m3u")
|
||||||
|
assert (
|
||||||
|
mock_validator.validate_stream.call_count == 1
|
||||||
|
) # Called for the URL from valid playlist
|
||||||
Reference in New Issue
Block a user