Compare commits
15 Commits
6d506122d9
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| a42d4c30a6 | |||
| abb467749b | |||
| b8ac25e301 | |||
| 729eabf27f | |||
| 34c446bcfa | |||
| d4cc74ea8c | |||
| 21b73b6843 | |||
| e743daf9f7 | |||
| b0d98551b8 | |||
| eaab1ef998 | |||
| e25f8c1ecd | |||
| 95bf0f9701 | |||
| f7a1c20066 | |||
| bf6f156fec | |||
| 7e25ec6755 |
@@ -1,10 +1,15 @@
|
||||
|
||||
# Environment variables
|
||||
# Scheduler configuration
|
||||
STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax)
|
||||
STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all)
|
||||
|
||||
# For use with Docker Compose to run application locally
|
||||
MOCK_AUTH=true/false
|
||||
DB_USER=MyDBUser
|
||||
DB_PASSWORD=MyDBPassword
|
||||
DB_HOST=MyDBHost
|
||||
DB_NAME=iptv_updater
|
||||
DB_NAME=iptv_manager
|
||||
|
||||
FREEDNS_User=MyFreeDNSUsername
|
||||
FREEDNS_Password=MyFreeDNSPassword
|
||||
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
run: |
|
||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||
--region us-east-2 \
|
||||
--filters "Name=tag:Name,Values=IptvUpdaterStack/IptvUpdaterInstance" \
|
||||
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||
"Name=instance-state-name,Values=running" \
|
||||
--query "Reservations[].Instances[].InstanceId" \
|
||||
--output text)
|
||||
@@ -69,11 +69,11 @@ jobs:
|
||||
--instance-ids "$INSTANCE_ID" \
|
||||
--document-name "AWS-RunShellScript" \
|
||||
--parameters 'commands=[
|
||||
"cd /home/ec2-user/iptv-updater-aws",
|
||||
"cd /home/ec2-user/iptv-manager-service",
|
||||
"git pull",
|
||||
"pip3 install -r requirements.txt",
|
||||
"alembic upgrade head",
|
||||
"sudo systemctl restart iptv-updater"
|
||||
"sudo systemctl restart iptv-manager"
|
||||
]'
|
||||
done
|
||||
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.2
|
||||
rev: v0.11.12
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest-check
|
||||
name: pytest-check
|
||||
entry: pytest
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
19
.vscode/settings.json
vendored
19
.vscode/settings.json
vendored
@@ -1,4 +1,6 @@
|
||||
{
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||
"editor.formatOnSave": true,
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"ruff.importStrategy": "fromEnvironment",
|
||||
@@ -7,14 +9,18 @@
|
||||
"addopts",
|
||||
"adminpassword",
|
||||
"altinstall",
|
||||
"apscheduler",
|
||||
"asyncio",
|
||||
"autoflush",
|
||||
"autoupdate",
|
||||
"autouse",
|
||||
"awscli",
|
||||
"awscliv",
|
||||
"boto",
|
||||
"botocore",
|
||||
"BURSTABLE",
|
||||
"cabletv",
|
||||
"capsys",
|
||||
"CDUF",
|
||||
"cduflogo",
|
||||
"cdulogo",
|
||||
@@ -29,10 +35,14 @@
|
||||
"cluflogo",
|
||||
"clulogo",
|
||||
"cpulogo",
|
||||
"crond",
|
||||
"cronie",
|
||||
"cuflgo",
|
||||
"CUNF",
|
||||
"cunflogo",
|
||||
"cuulogo",
|
||||
"datname",
|
||||
"deadstreams",
|
||||
"delenv",
|
||||
"delogo",
|
||||
"devel",
|
||||
@@ -40,22 +50,29 @@
|
||||
"dmlogo",
|
||||
"dotenv",
|
||||
"EXTINF",
|
||||
"EXTM",
|
||||
"fastapi",
|
||||
"filterwarnings",
|
||||
"fiorinis",
|
||||
"freedns",
|
||||
"fullchain",
|
||||
"gitea",
|
||||
"httpx",
|
||||
"iptv",
|
||||
"isort",
|
||||
"KHTML",
|
||||
"lclogo",
|
||||
"LETSENCRYPT",
|
||||
"levelname",
|
||||
"mpegurl",
|
||||
"nohup",
|
||||
"nopriority",
|
||||
"ondelete",
|
||||
"onupdate",
|
||||
"passlib",
|
||||
"PGPASSWORD",
|
||||
"poolclass",
|
||||
"psql",
|
||||
"psycopg",
|
||||
"pycache",
|
||||
"pycodestyle",
|
||||
@@ -70,12 +87,14 @@
|
||||
"ruru",
|
||||
"sessionmaker",
|
||||
"sqlalchemy",
|
||||
"sqliteuuid",
|
||||
"starlette",
|
||||
"stefano",
|
||||
"testadmin",
|
||||
"testdb",
|
||||
"testpass",
|
||||
"testpaths",
|
||||
"testuser",
|
||||
"uflogo",
|
||||
"umlogo",
|
||||
"usefixtures",
|
||||
|
||||
226
README.md
226
README.md
@@ -1,165 +1,149 @@
|
||||
# IPTV Updater AWS
|
||||
# IPTV Manager Service
|
||||
|
||||
An automated IPTV playlist and EPG updater service deployed on AWS infrastructure using CDK.
|
||||
A FastAPI-based service for managing IPTV playlists and channel priorities. The application provides secure endpoints for user authentication, channel management, and playlist generation.
|
||||
|
||||
## Overview
|
||||
## ✨ Features
|
||||
|
||||
This project provides a service for automatically updating IPTV playlists and Electronic Program Guide (EPG) data. It runs on AWS infrastructure with:
|
||||
- **JWT Authentication**: Secure login using AWS Cognito
|
||||
- **Channel Management**: CRUD operations for IPTV channels
|
||||
- **Playlist Generation**: Create M3U playlists with channel priorities
|
||||
- **Stream Monitoring**: Background checks for channel availability
|
||||
- **Priority Management**: Set channel priorities for playlist ordering
|
||||
|
||||
- EC2 instance for hosting the application
|
||||
- RDS PostgreSQL database for data storage
|
||||
- Amazon Cognito for user authentication
|
||||
- HTTPS support via Let's Encrypt
|
||||
- Domain management via FreeDNS
|
||||
## 🛠️ Technology Stack
|
||||
|
||||
## Prerequisites
|
||||
- **Backend**: Python 3.11, FastAPI
|
||||
- **Database**: PostgreSQL (SQLAlchemy ORM)
|
||||
- **Authentication**: AWS Cognito
|
||||
- **Infrastructure**: AWS CDK (API Gateway, Lambda, RDS)
|
||||
- **Testing**: Pytest with 85%+ coverage
|
||||
- **CI/CD**: Pre-commit hooks, Alembic migrations
|
||||
|
||||
- AWS CLI installed and configured
|
||||
- Python 3.12 or later
|
||||
- Node.js v22.15 or later for AWS CDK
|
||||
- Docker and Docker Compose for local development
|
||||
## 🚀 Getting Started
|
||||
|
||||
## Local Development
|
||||
### Prerequisites
|
||||
|
||||
1. Clone the repository:
|
||||
- Python 3.11+
|
||||
- Docker
|
||||
- AWS CLI (for deployment)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
git clone <repo-url>
|
||||
cd iptv-updater-aws
|
||||
# Clone repository
|
||||
git clone https://github.com/your-repo/iptv-manager-service.git
|
||||
cd iptv-manager-service
|
||||
|
||||
# Setup environment
|
||||
python -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
cp .env.example .env # Update with your values
|
||||
|
||||
# Run installation script
|
||||
./scripts/install.sh
|
||||
```
|
||||
|
||||
2. Copy the example environment file:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
3. Add your configuration to `.env`:
|
||||
|
||||
```
|
||||
FREEDNS_User=your_freedns_username
|
||||
FREEDNS_Password=your_freedns_password
|
||||
DOMAIN_NAME=your.domain.name
|
||||
SSH_PUBLIC_KEY=your_ssh_public_key
|
||||
REPO_URL=repository_url
|
||||
LETSENCRYPT_EMAIL=your_email
|
||||
```
|
||||
|
||||
4. Start the local development environment:
|
||||
### Running Locally
|
||||
|
||||
```bash
|
||||
# Start development environment
|
||||
./scripts/start_local_dev.sh
|
||||
```
|
||||
|
||||
5. Stop the local environment:
|
||||
|
||||
```bash
|
||||
# Stop development environment
|
||||
./scripts/stop_local_dev.sh
|
||||
```
|
||||
|
||||
## Deployment
|
||||
## ☁️ AWS Deployment
|
||||
|
||||
### Initial Deployment
|
||||
|
||||
1. Ensure your AWS credentials are configured:
|
||||
|
||||
```bash
|
||||
aws configure
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. Deploy the infrastructure:
|
||||
The infrastructure is defined in CDK. Use the provided scripts:
|
||||
|
||||
```bash
|
||||
# Deploy AWS infrastructure
|
||||
./scripts/deploy.sh
|
||||
|
||||
# Destroy AWS infrastructure
|
||||
./scripts/destroy.sh
|
||||
|
||||
# Create Cognito test user
|
||||
./scripts/create_cognito_user.sh
|
||||
|
||||
# Delete Cognito user
|
||||
./scripts/delete_cognito_user.sh
|
||||
```
|
||||
|
||||
The deployment script will:
|
||||
Key AWS components:
|
||||
|
||||
- Create/update the CloudFormation stack using CDK
|
||||
- Configure the EC2 instance with required software
|
||||
- Set up HTTPS using Let's Encrypt
|
||||
- Configure the domain using FreeDNS
|
||||
- API Gateway
|
||||
- Lambda functions
|
||||
- RDS PostgreSQL
|
||||
- Cognito User Pool
|
||||
|
||||
### Continuous Deployment
|
||||
## 🤖 Continuous Integration/Deployment
|
||||
|
||||
The project includes a Gitea workflow (`.gitea/workflows/aws_deploy_on_push.yml`) that automatically:
|
||||
This project includes a Gitea Actions workflow (`.gitea/workflows/deploy.yml`) for automated deployment to AWS. The workflow is fully compatible with GitHub Actions and can be easily adapted by:
|
||||
|
||||
- Deploys infrastructure changes
|
||||
- Updates the application on EC2 instances
|
||||
- Restarts the service
|
||||
1. Placing the workflow file in the `.github/workflows/` directory
|
||||
2. Setting up the required secrets in your CI/CD environment:
|
||||
- `AWS_ACCESS_KEY_ID`
|
||||
- `AWS_SECRET_ACCESS_KEY`
|
||||
- `AWS_DEFAULT_REGION`
|
||||
|
||||
## Infrastructure
|
||||
The workflow automatically deploys the infrastructure and application when changes are pushed to the main branch.
|
||||
|
||||
The AWS infrastructure is defined in `infrastructure/stack.py` and includes:
|
||||
## 📚 API Documentation
|
||||
|
||||
- VPC with public subnets
|
||||
- EC2 t2.micro instance (Free Tier eligible)
|
||||
- RDS PostgreSQL database (db.t3.micro)
|
||||
- Security groups for EC2 and RDS
|
||||
- Elastic IP for the EC2 instance
|
||||
- Cognito User Pool for authentication
|
||||
- IAM roles and policies for EC2 instance access
|
||||
Access interactive docs at:
|
||||
|
||||
## User Management
|
||||
- Swagger UI: `http://localhost:8000/docs`
|
||||
- ReDoc: `http://localhost:8000/redoc`
|
||||
|
||||
### Creating Users
|
||||
### Key Endpoints
|
||||
|
||||
To create a new user in Cognito:
|
||||
| Endpoint | Method | Description |
|
||||
| ------------- | ------ | --------------------- |
|
||||
| `/auth/login` | POST | User authentication |
|
||||
| `/channels` | GET | List all channels |
|
||||
| `/playlist` | GET | Generate M3U playlist |
|
||||
| `/priorities` | POST | Set channel priority |
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
Run the full test suite:
|
||||
|
||||
```bash
|
||||
./scripts/create_cognito_user.sh <user_pool_id> <username> <password> --admin <= optional for defining an admin user
|
||||
pytest
|
||||
```
|
||||
|
||||
### Deleting Users
|
||||
Test coverage includes:
|
||||
|
||||
To delete a user from Cognito:
|
||||
- Authentication workflows
|
||||
- Channel CRUD operations
|
||||
- Playlist generation logic
|
||||
- Stream monitoring
|
||||
- Database operations
|
||||
|
||||
```bash
|
||||
./scripts/delete_cognito_user.sh <user_pool_id> <username>
|
||||
## 📂 Project Structure
|
||||
|
||||
```txt
|
||||
iptv-manager-service/
|
||||
├── app/ # Core application
|
||||
│ ├── auth/ # Cognito authentication
|
||||
│ ├── iptv/ # Playlist logic
|
||||
│ ├── models/ # Database models
|
||||
│ ├── routers/ # API endpoints
|
||||
│ ├── utils/ # Helper functions
|
||||
│ └── main.py # App entry point
|
||||
├── infrastructure/ # AWS CDK stack
|
||||
├── docker/ # Docker configs
|
||||
├── scripts/ # Deployment scripts
|
||||
├── tests/ # Comprehensive tests
|
||||
├── alembic/ # Database migrations
|
||||
├── .gitea/ # Gitea CI/CD workflows
|
||||
│ └── workflows/
|
||||
└── ... # Config files
|
||||
```
|
||||
|
||||
## Architecture
|
||||
## 📝 License
|
||||
|
||||
The application is structured as follows:
|
||||
|
||||
```bash
|
||||
app/
|
||||
├── auth/ # Authentication modules
|
||||
├── iptv/ # IPTV and EPG processing
|
||||
├── models/ # Database models
|
||||
└── utils/ # Utility functions
|
||||
|
||||
infrastructure/ # AWS CDK infrastructure code
|
||||
docker/ # Docker configuration for local development
|
||||
scripts/ # Utility scripts for deployment and management
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
The following environment variables are required:
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| FREEDNS_User | FreeDNS username |
|
||||
| FREEDNS_Password | FreeDNS password |
|
||||
| DOMAIN_NAME | Your domain name |
|
||||
| SSH_PUBLIC_KEY | SSH public key for EC2 access |
|
||||
| REPO_URL | Repository URL |
|
||||
| LETSENCRYPT_EMAIL | Email for Let's Encrypt certificates |
|
||||
|
||||
## Security Notes
|
||||
|
||||
- The EC2 instance has appropriate IAM permissions for:
|
||||
- EC2 instance discovery
|
||||
- SSM command execution
|
||||
- RDS access
|
||||
- Cognito user management
|
||||
- All database credentials are stored in AWS Secrets Manager
|
||||
- HTTPS is enforced using Let's Encrypt certificates
|
||||
- Access is restricted through Security Groups
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
@@ -15,7 +15,8 @@ config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Setup target metadata for autogenerate support
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# Override sqlalchemy.url with dynamic credentials
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Add priority and in_use fields
|
||||
|
||||
Revision ID: 036879e47172
|
||||
Revises:
|
||||
Create Date: 2025-05-26 19:21:32.285656
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '036879e47172'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# 1. Create priorities table if not exists
|
||||
if not op.get_bind().engine.dialect.has_table(op.get_bind(), 'priorities'):
|
||||
op.create_table('priorities',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# 2. Insert default priorities (skip if already exists)
|
||||
op.execute("""
|
||||
INSERT INTO priorities (id, description)
|
||||
VALUES (100, 'High'), (200, 'Medium'), (300, 'Low')
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
""")
|
||||
# Add new columns with temporary nullable=True
|
||||
op.add_column('channels_urls', sa.Column('in_use', sa.Boolean(), nullable=True))
|
||||
op.add_column('channels_urls', sa.Column('priority_id', sa.Integer(), nullable=True))
|
||||
|
||||
# Set default values
|
||||
op.execute("UPDATE channels_urls SET in_use = false, priority_id = 100")
|
||||
|
||||
# Convert to NOT NULL
|
||||
op.alter_column('channels_urls', 'in_use', nullable=False)
|
||||
op.alter_column('channels_urls', 'priority_id', nullable=False)
|
||||
op.create_foreign_key(None, 'channels_urls', 'priorities', ['priority_id'], ['id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint('channels_urls_priority_id_fkey', 'channels_urls', type_='foreignkey')
|
||||
op.drop_column('channels_urls', 'priority_id')
|
||||
op.drop_column('channels_urls', 'in_use')
|
||||
op.drop_table('priorities')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,110 @@
|
||||
"""add groups table and migrate group_title data
|
||||
|
||||
Revision ID: 0a455608256f
|
||||
Revises: 95b61a92455a
|
||||
Create Date: 2025-06-10 09:22:11.820035
|
||||
|
||||
"""
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0a455608256f'
|
||||
down_revision: Union[str, None] = '95b61a92455a'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('groups',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
|
||||
# Create temporary table for group mapping
|
||||
group_mapping = op.create_table(
|
||||
'group_mapping',
|
||||
sa.Column('group_title', sa.String(), nullable=False),
|
||||
sa.Column('group_id', sa.UUID(), nullable=False)
|
||||
)
|
||||
|
||||
# Get existing group titles and create groups
|
||||
conn = op.get_bind()
|
||||
distinct_groups = conn.execute(
|
||||
sa.text("SELECT DISTINCT group_title FROM channels")
|
||||
).fetchall()
|
||||
|
||||
for group in distinct_groups:
|
||||
group_title = group[0]
|
||||
group_id = str(uuid.uuid4())
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"INSERT INTO groups (id, name, sort_order) "
|
||||
"VALUES (:id, :name, 0)"
|
||||
).bindparams(id=group_id, name=group_title)
|
||||
)
|
||||
conn.execute(
|
||||
group_mapping.insert().values(
|
||||
group_title=group_title,
|
||||
group_id=group_id
|
||||
)
|
||||
)
|
||||
|
||||
# Add group_id column (nullable first)
|
||||
op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True))
|
||||
|
||||
# Update channels with group_ids
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE channels c SET group_id = gm.group_id "
|
||||
"FROM group_mapping gm WHERE c.group_title = gm.group_title"
|
||||
)
|
||||
)
|
||||
|
||||
# Now make group_id non-nullable and add constraints
|
||||
op.alter_column('channels', 'group_id', nullable=False)
|
||||
op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique')
|
||||
op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name'])
|
||||
op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id'])
|
||||
|
||||
# Clean up and drop group_title
|
||||
op.drop_table('group_mapping')
|
||||
op.drop_column('channels', 'group_title')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
|
||||
# Restore group_title values from groups table
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE channels c SET group_title = g.name "
|
||||
"FROM groups g WHERE c.group_id = g.id"
|
||||
)
|
||||
)
|
||||
|
||||
# Now make group_title non-nullable
|
||||
op.alter_column('channels', 'group_title', nullable=False)
|
||||
|
||||
# Drop constraints and columns
|
||||
op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey')
|
||||
op.drop_constraint('uix_group_id_name', 'channels', type_='unique')
|
||||
op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name'])
|
||||
op.drop_column('channels', 'group_id')
|
||||
op.drop_table('groups')
|
||||
# ### end Alembic commands ###
|
||||
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
79
alembic/versions/95b61a92455a_create_initial_tables.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""create initial tables
|
||||
|
||||
Revision ID: 95b61a92455a
|
||||
Revises:
|
||||
Create Date: 2025-05-29 14:42:16.239587
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '95b61a92455a'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('channels',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('tvg_id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('group_title', sa.String(), nullable=False),
|
||||
sa.Column('tvg_name', sa.String(), nullable=True),
|
||||
sa.Column('tvg_logo', sa.String(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('group_title', 'name', name='uix_group_title_name')
|
||||
)
|
||||
op.create_table('priorities',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('channels_urls',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('channel_id', sa.UUID(), nullable=False),
|
||||
sa.Column('url', sa.String(), nullable=False),
|
||||
sa.Column('in_use', sa.Boolean(), nullable=False),
|
||||
sa.Column('priority_id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['channel_id'], ['channels.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['priority_id'], ['priorities.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
# Seed initial priorities
|
||||
op.bulk_insert(
|
||||
sa.Table(
|
||||
'priorities',
|
||||
sa.MetaData(),
|
||||
sa.Column('id', sa.Integer),
|
||||
sa.Column('description', sa.String),
|
||||
),
|
||||
[
|
||||
{'id': 100, 'description': 'High'},
|
||||
{'id': 200, 'description': 'Medium'},
|
||||
{'id': 300, 'description': 'Low'},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# Remove seeded priorities
|
||||
op.execute("DELETE FROM priorities WHERE id IN (100, 200, 300);")
|
||||
|
||||
# Drop tables
|
||||
op.drop_table('channels_urls')
|
||||
op.drop_table('priorities')
|
||||
op.drop_table('channels')
|
||||
6
app.py
6
app.py
@@ -3,7 +3,7 @@ import os
|
||||
|
||||
import aws_cdk as cdk
|
||||
|
||||
from infrastructure.stack import IptvUpdaterStack
|
||||
from infrastructure.stack import IptvManagerStack
|
||||
|
||||
app = cdk.App()
|
||||
|
||||
@@ -31,9 +31,9 @@ if missing_vars:
|
||||
f"Missing required environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
IptvUpdaterStack(
|
||||
IptvManagerStack(
|
||||
app,
|
||||
"IptvUpdaterStack",
|
||||
"IptvManagerStack",
|
||||
freedns_user=freedns_user,
|
||||
freedns_password=freedns_password,
|
||||
domain_name=domain_name,
|
||||
|
||||
@@ -32,7 +32,9 @@ def require_roles(*required_roles: str) -> Callable:
|
||||
|
||||
def decorator(endpoint: Callable) -> Callable:
|
||||
@wraps(endpoint)
|
||||
def wrapper(*args, user: CognitoUser = Depends(get_current_user), **kwargs):
|
||||
async def wrapper(
|
||||
*args, user: CognitoUser = Depends(get_current_user), **kwargs
|
||||
):
|
||||
user_roles = set(user.roles or [])
|
||||
needed_roles = set(required_roles)
|
||||
if not needed_roles.issubset(user_roles):
|
||||
|
||||
110
app/iptv/scheduler.py
Normal file
110
app/iptv/scheduler.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.iptv.stream_manager import StreamManager
|
||||
from app.models.db import ChannelDB
|
||||
from app.utils.database import get_db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamScheduler:
|
||||
"""Scheduler service for periodic stream validation tasks."""
|
||||
|
||||
def __init__(self, app: Optional[FastAPI] = None):
|
||||
"""
|
||||
Initialize the scheduler with optional FastAPI app integration.
|
||||
|
||||
Args:
|
||||
app: Optional FastAPI app instance for lifecycle integration
|
||||
"""
|
||||
self.scheduler = BackgroundScheduler()
|
||||
self.app = app
|
||||
self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10"))
|
||||
self.schedule_time = os.getenv(
|
||||
"STREAM_VALIDATION_SCHEDULE", "0 3 * * *"
|
||||
) # Default 3 AM daily
|
||||
logger.info(f"Scheduler initialized with app: {app is not None}")
|
||||
|
||||
def validate_streams_batch(self, db_session: Optional[Session] = None) -> None:
|
||||
"""
|
||||
Validate streams and update their status.
|
||||
When batch_size=0, validates all channels.
|
||||
|
||||
Args:
|
||||
db_session: Optional SQLAlchemy session
|
||||
"""
|
||||
db = db_session if db_session else get_db_session()
|
||||
try:
|
||||
manager = StreamManager(db)
|
||||
|
||||
# Get channels to validate
|
||||
query = db.query(ChannelDB)
|
||||
if self.batch_size > 0:
|
||||
query = query.limit(self.batch_size)
|
||||
channels = query.all()
|
||||
|
||||
for channel in channels:
|
||||
try:
|
||||
logger.info(f"Validating streams for channel {channel.id}")
|
||||
manager.validate_and_select_stream(str(channel.id))
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating channel {channel.id}: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info(f"Completed stream validation of {len(channels)} channels")
|
||||
finally:
|
||||
if db_session is None:
|
||||
db.close()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler and add jobs."""
|
||||
if not self.scheduler.running:
|
||||
# Add the scheduled job
|
||||
self.scheduler.add_job(
|
||||
self.validate_streams_batch,
|
||||
trigger=CronTrigger.from_crontab(self.schedule_time),
|
||||
id="daily_stream_validation",
|
||||
)
|
||||
|
||||
# Start the scheduler
|
||||
self.scheduler.start()
|
||||
logger.info(
|
||||
f"Stream scheduler started with daily validation job. "
|
||||
f"Running: {self.scheduler.running}"
|
||||
)
|
||||
|
||||
# Register shutdown handler if FastAPI app is provided
|
||||
if self.app:
|
||||
logger.info(
|
||||
f"Registering scheduler with FastAPI "
|
||||
f"app: {hasattr(self.app, 'state')}"
|
||||
)
|
||||
|
||||
@self.app.on_event("shutdown")
|
||||
def shutdown_scheduler():
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the scheduler gracefully."""
|
||||
if self.scheduler.running:
|
||||
self.scheduler.shutdown()
|
||||
logger.info("Stream scheduler stopped")
|
||||
|
||||
def trigger_manual_validation(self) -> None:
|
||||
"""Trigger manual validation of streams."""
|
||||
logger.info("Manually triggering stream validation")
|
||||
self.validate_streams_batch()
|
||||
|
||||
|
||||
def init_scheduler(app: FastAPI) -> StreamScheduler:
|
||||
"""Initialize and start the scheduler with FastAPI integration."""
|
||||
scheduler = StreamScheduler(app)
|
||||
scheduler.start()
|
||||
return scheduler
|
||||
151
app/iptv/stream_manager.py
Normal file
151
app/iptv/stream_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.db import ChannelURL
|
||||
from app.utils.check_streams import StreamValidator
|
||||
from app.utils.database import get_db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamManager:
|
||||
"""Service for managing and validating channel streams."""
|
||||
|
||||
def __init__(self, db_session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize StreamManager with optional database session.
|
||||
|
||||
Args:
|
||||
db_session: Optional SQLAlchemy session. If None, will create a new one.
|
||||
"""
|
||||
self.db = db_session if db_session else get_db_session()
|
||||
self.validator = StreamValidator()
|
||||
|
||||
def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]:
|
||||
"""
|
||||
Get all streams for a channel ordered by priority (lowest first),
|
||||
with same-priority streams randomized.
|
||||
|
||||
Args:
|
||||
channel_id: UUID of the channel to get streams for
|
||||
|
||||
Returns:
|
||||
List of ChannelURL objects ordered by priority
|
||||
"""
|
||||
try:
|
||||
# Get all streams for channel ordered by priority
|
||||
streams = (
|
||||
self.db.query(ChannelURL)
|
||||
.filter(ChannelURL.channel_id == channel_id)
|
||||
.order_by(ChannelURL.priority_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Group streams by priority and randomize same-priority streams
|
||||
grouped = {}
|
||||
for stream in streams:
|
||||
if stream.priority_id not in grouped:
|
||||
grouped[stream.priority_id] = []
|
||||
grouped[stream.priority_id].append(stream)
|
||||
|
||||
# Randomize same-priority streams and flatten
|
||||
randomized_streams = []
|
||||
for priority in sorted(grouped.keys()):
|
||||
random.shuffle(grouped[priority])
|
||||
randomized_streams.extend(grouped[priority])
|
||||
|
||||
return randomized_streams
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting streams for channel {channel_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def validate_and_select_stream(self, channel_id: str) -> Optional[str]:
|
||||
"""
|
||||
Find and validate a working stream for the given channel.
|
||||
|
||||
Args:
|
||||
channel_id: UUID of the channel to find a stream for
|
||||
|
||||
Returns:
|
||||
URL of the first working stream found, or None if none found
|
||||
"""
|
||||
try:
|
||||
streams = self.get_streams_for_channel(channel_id)
|
||||
if not streams:
|
||||
logger.warning(f"No streams found for channel {channel_id}")
|
||||
return None
|
||||
|
||||
working_stream = None
|
||||
|
||||
for stream in streams:
|
||||
logger.info(f"Validating stream {stream.url} for channel {channel_id}")
|
||||
is_valid, _ = self.validator.validate_stream(stream.url)
|
||||
|
||||
if is_valid:
|
||||
working_stream = stream
|
||||
break
|
||||
|
||||
if working_stream:
|
||||
self._update_stream_status(working_stream, streams)
|
||||
return working_stream.url
|
||||
else:
|
||||
logger.warning(f"No valid streams found for channel {channel_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating streams for channel {channel_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def _update_stream_status(
|
||||
self, working_stream: ChannelURL, all_streams: list[ChannelURL]
|
||||
) -> None:
|
||||
"""
|
||||
Update in_use status for streams (True for working stream, False for others).
|
||||
|
||||
Args:
|
||||
working_stream: The stream that was validated as working
|
||||
all_streams: All streams for the channel
|
||||
"""
|
||||
try:
|
||||
for stream in all_streams:
|
||||
stream.in_use = stream.id == working_stream.id
|
||||
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Updated stream status - set in_use=True for {working_stream.url}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error updating stream status: {str(e)}")
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
"""Close database session when StreamManager is destroyed."""
|
||||
if hasattr(self, "db"):
|
||||
self.db.close()
|
||||
|
||||
|
||||
def get_working_stream(
|
||||
channel_id: str, db_session: Optional[Session] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Convenience function to get a working stream for a channel.
|
||||
|
||||
Args:
|
||||
channel_id: UUID of the channel to get a stream for
|
||||
db_session: Optional SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
URL of the first working stream found, or None if none found
|
||||
"""
|
||||
manager = StreamManager(db_session)
|
||||
try:
|
||||
return manager.validate_and_select_stream(channel_id)
|
||||
finally:
|
||||
if db_session is None: # Only close if we created the session
|
||||
manager.__del__()
|
||||
20
app/main.py
20
app/main.py
@@ -2,7 +2,8 @@ from fastapi import FastAPI
|
||||
from fastapi.concurrency import asynccontextmanager
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from app.routers import auth, channels, playlist, priorities
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
from app.routers import auth, channels, groups, playlist, priorities, scheduler
|
||||
from app.utils.database import init_db
|
||||
|
||||
|
||||
@@ -10,13 +11,22 @@ from app.utils.database import init_db
|
||||
async def lifespan(app: FastAPI):
|
||||
# Initialize database tables on startup
|
||||
init_db()
|
||||
|
||||
# Initialize and start the stream scheduler
|
||||
scheduler = StreamScheduler(app)
|
||||
app.state.scheduler = scheduler # Store scheduler in app state
|
||||
scheduler.start()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown scheduler on app shutdown
|
||||
scheduler.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="IPTV Updater API",
|
||||
description="API for IPTV Updater service",
|
||||
title="IPTV Manager API",
|
||||
description="API for IPTV Manager service",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
@@ -60,7 +70,7 @@ app.openapi = custom_openapi
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "IPTV Updater API"}
|
||||
return {"message": "IPTV Manager API"}
|
||||
|
||||
|
||||
# Include routers
|
||||
@@ -68,3 +78,5 @@ app.include_router(auth.router)
|
||||
app.include_router(channels.router)
|
||||
app.include_router(playlist.router)
|
||||
app.include_router(priorities.router)
|
||||
app.include_router(groups.router)
|
||||
app.include_router(scheduler.router)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from .db import Base, ChannelDB, ChannelURL
|
||||
from .db import Base, ChannelDB, ChannelURL, Group, Priority
|
||||
from .schemas import (
|
||||
ChannelCreate,
|
||||
ChannelResponse,
|
||||
ChannelUpdate,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
GroupCreate,
|
||||
GroupResponse,
|
||||
GroupUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -16,4 +19,9 @@ __all__ = [
|
||||
"ChannelURL",
|
||||
"ChannelURLCreate",
|
||||
"ChannelURLResponse",
|
||||
"Group",
|
||||
"Priority",
|
||||
"GroupCreate",
|
||||
"GroupResponse",
|
||||
"GroupUpdate",
|
||||
]
|
||||
|
||||
@@ -1,18 +1,60 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
TEXT,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
|
||||
# Custom UUID type for SQLite compatibility
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite with proper comparison handling."""
|
||||
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid.UUID):
|
||||
return str(value)
|
||||
try:
|
||||
# Validate string format by attempting to create UUID
|
||||
uuid.UUID(value)
|
||||
return value
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"Invalid UUID string format: {value}")
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
def compare_values(self, x, y):
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
return str(x) == str(y)
|
||||
|
||||
|
||||
# Determine which UUID type to use based on environment
|
||||
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||
UUID_COLUMN_TYPE = SQLiteUUID()
|
||||
else:
|
||||
UUID_COLUMN_TYPE = UUID(as_uuid=True)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -25,20 +67,37 @@ class Priority(Base):
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
"""SQLAlchemy model for channel groups"""
|
||||
|
||||
__tablename__ = "groups"
|
||||
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
sort_order = Column(Integer, nullable=False, default=0)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationship with Channel
|
||||
channels = relationship("ChannelDB", back_populates="group")
|
||||
|
||||
|
||||
class ChannelDB(Base):
|
||||
"""SQLAlchemy model for IPTV channels"""
|
||||
|
||||
__tablename__ = "channels"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
tvg_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
group_title = Column(String, nullable=False)
|
||||
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
|
||||
tvg_name = Column(String)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
@@ -47,10 +106,11 @@ class ChannelDB(Base):
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relationship with ChannelURL
|
||||
# Relationships
|
||||
urls = relationship(
|
||||
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||
)
|
||||
group = relationship("Group", back_populates="channels")
|
||||
|
||||
|
||||
class ChannelURL(Base):
|
||||
@@ -58,9 +118,9 @@ class ChannelURL(Base):
|
||||
|
||||
__tablename__ = "channels_urls"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
UUID_COLUMN_TYPE,
|
||||
ForeignKey("channels.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
@@ -53,12 +53,54 @@ class ChannelURLResponse(ChannelURLBase):
|
||||
pass
|
||||
|
||||
|
||||
# New Group Schemas
|
||||
class GroupCreate(BaseModel):
|
||||
"""Pydantic model for creating groups"""
|
||||
|
||||
name: str
|
||||
sort_order: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class GroupUpdate(BaseModel):
|
||||
"""Pydantic model for updating groups"""
|
||||
|
||||
name: Optional[str] = None
|
||||
sort_order: Optional[int] = Field(None, ge=0)
|
||||
|
||||
|
||||
class GroupResponse(BaseModel):
|
||||
"""Pydantic model for group responses"""
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
sort_order: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GroupSortUpdate(BaseModel):
|
||||
"""Pydantic model for updating a single group's sort order"""
|
||||
|
||||
sort_order: int = Field(ge=0)
|
||||
|
||||
|
||||
class GroupBulkSort(BaseModel):
|
||||
"""Pydantic model for bulk updating group sort orders"""
|
||||
|
||||
groups: list[dict] = Field(
|
||||
description="List of dicts with group_id and new sort_order",
|
||||
json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]},
|
||||
)
|
||||
|
||||
|
||||
class ChannelCreate(BaseModel):
|
||||
"""Pydantic model for creating channels"""
|
||||
|
||||
urls: list[ChannelURLCreate] # List of URL objects with priority
|
||||
name: str
|
||||
group_title: str
|
||||
group_id: UUID
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
@@ -76,7 +118,7 @@ class ChannelUpdate(BaseModel):
|
||||
"""Pydantic model for updating channels (all fields optional)"""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
group_title: Optional[str] = Field(None, min_length=1)
|
||||
group_id: Optional[UUID] = None
|
||||
tvg_id: Optional[str] = Field(None, min_length=1)
|
||||
tvg_logo: Optional[str] = None
|
||||
tvg_name: Optional[str] = Field(None, min_length=1)
|
||||
@@ -87,7 +129,7 @@ class ChannelResponse(BaseModel):
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
group_title: str
|
||||
group_id: UUID
|
||||
tvg_id: str
|
||||
tvg_logo: str
|
||||
tvg_name: str
|
||||
|
||||
@@ -13,6 +13,8 @@ from app.models import (
|
||||
ChannelURL,
|
||||
ChannelURLCreate,
|
||||
ChannelURLResponse,
|
||||
Group,
|
||||
Priority, # Added Priority import
|
||||
)
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.schemas import ChannelURLUpdate
|
||||
@@ -29,12 +31,20 @@ def create_channel(
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new channel"""
|
||||
# Check for duplicate channel (same group_title + name)
|
||||
# Check if group exists
|
||||
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Group not found",
|
||||
)
|
||||
|
||||
# Check for duplicate channel (same group_id + name)
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_title == channel.group_title,
|
||||
ChannelDB.group_id == channel.group_id,
|
||||
ChannelDB.name == channel.name,
|
||||
)
|
||||
)
|
||||
@@ -44,7 +54,7 @@ def create_channel(
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_title and name already exists",
|
||||
detail="Channel with same group_id and name already exists",
|
||||
)
|
||||
|
||||
# Create channel without URLs first
|
||||
@@ -96,20 +106,27 @@ def update_channel(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
# Only check for duplicates if name or group_title are being updated
|
||||
if channel.name is not None or channel.group_title is not None:
|
||||
# Only check for duplicates if name or group_id are being updated
|
||||
if channel.name is not None or channel.group_id is not None:
|
||||
name = channel.name if channel.name is not None else db_channel.name
|
||||
group_title = (
|
||||
channel.group_title
|
||||
if channel.group_title is not None
|
||||
else db_channel.group_title
|
||||
group_id = (
|
||||
channel.group_id if channel.group_id is not None else db_channel.group_id
|
||||
)
|
||||
|
||||
# Check if new group exists
|
||||
if channel.group_id is not None:
|
||||
group = db.query(Group).filter(Group.id == channel.group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Group not found",
|
||||
)
|
||||
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_title == group_title,
|
||||
ChannelDB.group_id == group_id,
|
||||
ChannelDB.name == name,
|
||||
ChannelDB.id != channel_id,
|
||||
)
|
||||
@@ -120,7 +137,7 @@ def update_channel(
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same group_title and name already exists",
|
||||
detail="Channel with same group_id and name already exists",
|
||||
)
|
||||
|
||||
# Update only provided fields
|
||||
@@ -133,6 +150,41 @@ def update_channel(
|
||||
return db_channel
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def delete_channels(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all channels"""
|
||||
count = 0
|
||||
try:
|
||||
count = db.query(ChannelDB).count()
|
||||
|
||||
# First delete all channels
|
||||
db.query(ChannelDB).delete()
|
||||
|
||||
# Then delete any URLs that are now orphaned (no channel references)
|
||||
db.query(ChannelURL).filter(
|
||||
~ChannelURL.channel_id.in_(db.query(ChannelDB.id))
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Then delete any groups that are now empty
|
||||
db.query(Group).filter(~Group.id.in_(db.query(ChannelDB.group_id))).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
print(f"Error deleting channels: {e}")
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete channels",
|
||||
)
|
||||
return {"deleted": count}
|
||||
|
||||
|
||||
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_channel(
|
||||
@@ -163,9 +215,193 @@ def list_channels(
|
||||
return db.query(ChannelDB).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
# New endpoint to get channels by group
|
||||
@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse])
|
||||
def get_channels_by_group(
|
||||
group_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all channels for a specific group"""
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
return db.query(ChannelDB).filter(ChannelDB.group_id == group_id).all()
|
||||
|
||||
|
||||
# New endpoint to update a channel's group
|
||||
@router.put("/{channel_id}/group", response_model=ChannelResponse)
|
||||
@require_roles("admin")
|
||||
def update_channel_group(
|
||||
channel_id: UUID,
|
||||
group_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a channel's group"""
|
||||
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
|
||||
)
|
||||
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
# Check for duplicate channel name in new group
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_id == group_id,
|
||||
ChannelDB.name == channel.name,
|
||||
ChannelDB.id != channel_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Channel with same name already exists in target group",
|
||||
)
|
||||
|
||||
channel.group_id = group_id
|
||||
db.commit()
|
||||
db.refresh(channel)
|
||||
return channel
|
||||
|
||||
|
||||
# Bulk Upload and Reset Endpoints
|
||||
@router.post("/bulk-upload", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def bulk_upload_channels(
|
||||
channels: list[dict],
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Bulk upload channels from JSON array"""
|
||||
processed = 0
|
||||
|
||||
# Fetch all priorities from the database, ordered by id
|
||||
priorities = db.query(Priority).order_by(Priority.id).all()
|
||||
priority_map = {i: p.id for i, p in enumerate(priorities)}
|
||||
|
||||
# Get the highest priority_id (which corresponds to the lowest priority level)
|
||||
max_priority_id = None
|
||||
if priorities:
|
||||
max_priority_id = db.query(Priority.id).order_by(Priority.id.desc()).first()[0]
|
||||
|
||||
for channel_data in channels:
|
||||
try:
|
||||
# Get or create group
|
||||
group_name = channel_data.get("group-title")
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
group = db.query(Group).filter(Group.name == group_name).first()
|
||||
if not group:
|
||||
group = Group(name=group_name)
|
||||
db.add(group)
|
||||
db.flush() # Use flush to make the group available in the session
|
||||
db.refresh(group)
|
||||
|
||||
# Prepare channel data
|
||||
urls = channel_data.get("urls", [])
|
||||
if not isinstance(urls, list):
|
||||
urls = [urls]
|
||||
|
||||
# Assign priorities dynamically based on fetched priorities
|
||||
url_objects = []
|
||||
for i, url in enumerate(urls): # Process all URLs
|
||||
priority_id = priority_map.get(i)
|
||||
if priority_id is None:
|
||||
# If index is out of bounds,
|
||||
# assign the highest priority_id (lowest priority)
|
||||
if max_priority_id is not None:
|
||||
priority_id = max_priority_id
|
||||
else:
|
||||
print(
|
||||
f"Warning: No priorities defined in database. "
|
||||
f"Skipping URL {url}"
|
||||
)
|
||||
continue
|
||||
url_objects.append({"url": url, "priority_id": priority_id})
|
||||
|
||||
# Create channel object with required fields
|
||||
channel_obj = ChannelDB(
|
||||
tvg_id=channel_data.get("tvg-id", ""),
|
||||
name=channel_data.get("name", ""),
|
||||
group_id=group.id,
|
||||
tvg_name=channel_data.get("tvg-name", ""),
|
||||
tvg_logo=channel_data.get("tvg-logo", ""),
|
||||
)
|
||||
|
||||
# Upsert channel
|
||||
existing_channel = (
|
||||
db.query(ChannelDB)
|
||||
.filter(
|
||||
and_(
|
||||
ChannelDB.group_id == group.id,
|
||||
ChannelDB.name == channel_obj.name,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_channel:
|
||||
# Update existing
|
||||
existing_channel.tvg_id = channel_obj.tvg_id
|
||||
existing_channel.tvg_name = channel_obj.tvg_name
|
||||
existing_channel.tvg_logo = channel_obj.tvg_logo
|
||||
|
||||
# Clear and recreate URLs
|
||||
db.query(ChannelURL).filter(
|
||||
ChannelURL.channel_id == existing_channel.id
|
||||
).delete()
|
||||
|
||||
for url in url_objects:
|
||||
db_url = ChannelURL(
|
||||
channel_id=existing_channel.id,
|
||||
url=url["url"],
|
||||
priority_id=url["priority_id"],
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
else:
|
||||
# Create new
|
||||
db.add(channel_obj)
|
||||
db.flush() # Flush to get the new channel's ID
|
||||
db.refresh(channel_obj)
|
||||
|
||||
# Add URLs for new channel
|
||||
for url in url_objects:
|
||||
db_url = ChannelURL(
|
||||
channel_id=channel_obj.id,
|
||||
url=url["url"],
|
||||
priority_id=url["priority_id"],
|
||||
in_use=False,
|
||||
)
|
||||
db.add(db_url)
|
||||
|
||||
db.commit() # Commit all changes for this channel atomically
|
||||
processed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing channel: {channel_data.get('name', 'Unknown')}")
|
||||
print(f"Exception details: {e}")
|
||||
db.rollback() # Rollback the entire transaction for the failed channel
|
||||
continue
|
||||
|
||||
return {"processed": processed}
|
||||
|
||||
|
||||
# URL Management Endpoints
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{channel_id}/urls",
|
||||
response_model=ChannelURLResponse,
|
||||
|
||||
191
app/routers/groups.py
Normal file
191
app/routers/groups.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.models import Group
|
||||
from app.models.auth import CognitoUser
|
||||
from app.models.schemas import (
|
||||
GroupBulkSort,
|
||||
GroupCreate,
|
||||
GroupResponse,
|
||||
GroupSortUpdate,
|
||||
GroupUpdate,
|
||||
)
|
||||
from app.utils.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED)
|
||||
@require_roles("admin")
|
||||
def create_group(
|
||||
group: GroupCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new channel group"""
|
||||
# Check for duplicate group name
|
||||
existing_group = db.query(Group).filter(Group.name == group.name).first()
|
||||
if existing_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Group with this name already exists",
|
||||
)
|
||||
|
||||
db_group = Group(**group.model_dump())
|
||||
db.add(db_group)
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@router.get("/{group_id}", response_model=GroupResponse)
|
||||
def get_group(group_id: UUID, db: Session = Depends(get_db)):
|
||||
"""Get a group by id"""
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
return group
|
||||
|
||||
|
||||
@router.put("/{group_id}", response_model=GroupResponse)
|
||||
@require_roles("admin")
|
||||
def update_group(
|
||||
group_id: UUID,
|
||||
group: GroupUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a group's name or sort order"""
|
||||
db_group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not db_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
# Check for duplicate name if name is being updated
|
||||
if group.name is not None and group.name != db_group.name:
|
||||
existing_group = db.query(Group).filter(Group.name == group.name).first()
|
||||
if existing_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Group with this name already exists",
|
||||
)
|
||||
|
||||
# Update only provided fields
|
||||
update_data = group.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_group, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def delete_groups(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all groups that have no channels (skip groups with channels)"""
|
||||
groups = db.query(Group).all()
|
||||
deleted = 0
|
||||
skipped = 0
|
||||
|
||||
for group in groups:
|
||||
if not group.channels:
|
||||
db.delete(group)
|
||||
deleted += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
db.commit()
|
||||
return {"deleted": deleted, "skipped": skipped}
|
||||
|
||||
|
||||
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_group(
|
||||
group_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a group (only if it has no channels)"""
|
||||
group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
# Check if group has any channels
|
||||
if group.channels:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete group with existing channels",
|
||||
)
|
||||
|
||||
db.delete(group)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/", response_model=list[GroupResponse])
|
||||
def list_groups(db: Session = Depends(get_db)):
|
||||
"""List all groups sorted by sort_order"""
|
||||
return db.query(Group).order_by(Group.sort_order).all()
|
||||
|
||||
|
||||
@router.put("/{group_id}/sort", response_model=GroupResponse)
|
||||
@require_roles("admin")
|
||||
def update_group_sort_order(
|
||||
group_id: UUID,
|
||||
sort_update: GroupSortUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Update a single group's sort order"""
|
||||
db_group = db.query(Group).filter(Group.id == group_id).first()
|
||||
if not db_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
|
||||
)
|
||||
|
||||
db_group.sort_order = sort_update.sort_order
|
||||
db.commit()
|
||||
db.refresh(db_group)
|
||||
return db_group
|
||||
|
||||
|
||||
@router.post("/reorder", response_model=list[GroupResponse])
|
||||
@require_roles("admin")
|
||||
def bulk_update_sort_orders(
|
||||
bulk_sort: GroupBulkSort,
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Bulk update group sort orders"""
|
||||
groups_to_update = []
|
||||
|
||||
for group_data in bulk_sort.groups:
|
||||
group_id = group_data["group_id"]
|
||||
sort_order = group_data["sort_order"]
|
||||
|
||||
group = db.query(Group).filter(Group.id == str(group_id)).first()
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group with id {group_id} not found",
|
||||
)
|
||||
|
||||
group.sort_order = sort_order
|
||||
groups_to_update.append(group)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Return all groups in their new order
|
||||
return db.query(Group).order_by(Group.sort_order).all()
|
||||
@@ -1,15 +1,156 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.iptv.stream_manager import StreamManager
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.database import get_db_session
|
||||
|
||||
router = APIRouter(prefix="/playlist", tags=["playlist"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# In-memory store for validation processes
|
||||
validation_processes: dict[str, dict] = {}
|
||||
|
||||
|
||||
@router.get("/protected", summary="Protected endpoint for authenticated users")
|
||||
async def protected_route(user: CognitoUser = Depends(get_current_user)):
|
||||
class ProcessStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class StreamValidationRequest(BaseModel):
|
||||
"""Request model for stream validation endpoint"""
|
||||
|
||||
channel_id: Optional[str] = None
|
||||
|
||||
|
||||
class ValidatedStream(BaseModel):
|
||||
"""Model for a validated working stream"""
|
||||
|
||||
channel_id: str
|
||||
stream_url: str
|
||||
|
||||
|
||||
class ValidationProcessResponse(BaseModel):
|
||||
"""Response model for validation process initiation"""
|
||||
|
||||
process_id: str
|
||||
status: ProcessStatus
|
||||
message: str
|
||||
|
||||
|
||||
class ValidationResultResponse(BaseModel):
|
||||
"""Response model for validation results"""
|
||||
|
||||
process_id: str
|
||||
status: ProcessStatus
|
||||
working_streams: Optional[list[ValidatedStream]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session):
|
||||
"""Background task to validate streams"""
|
||||
try:
|
||||
validation_processes[process_id]["status"] = ProcessStatus.IN_PROGRESS
|
||||
manager = StreamManager(db)
|
||||
|
||||
if channel_id:
|
||||
stream_url = manager.validate_and_select_stream(channel_id)
|
||||
if stream_url:
|
||||
validation_processes[process_id]["result"] = {
|
||||
"working_streams": [
|
||||
ValidatedStream(channel_id=channel_id, stream_url=stream_url)
|
||||
]
|
||||
}
|
||||
else:
|
||||
validation_processes[process_id]["error"] = (
|
||||
f"No working streams found for channel {channel_id}"
|
||||
)
|
||||
else:
|
||||
# TODO: Implement validation for all channels
|
||||
validation_processes[process_id]["error"] = (
|
||||
"Validation of all channels not yet implemented"
|
||||
)
|
||||
|
||||
validation_processes[process_id]["status"] = ProcessStatus.COMPLETED
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating streams: {str(e)}")
|
||||
validation_processes[process_id]["status"] = ProcessStatus.FAILED
|
||||
validation_processes[process_id]["error"] = str(e)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/validate-streams",
|
||||
summary="Start stream validation process",
|
||||
response_model=ValidationProcessResponse,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
responses={202: {"description": "Validation process started successfully"}},
|
||||
)
|
||||
async def start_stream_validation(
|
||||
request: StreamValidationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db_session),
|
||||
):
|
||||
"""
|
||||
Protected endpoint that requires authentication for all users.
|
||||
If the user is authenticated, returns success message.
|
||||
Start asynchronous validation of streams.
|
||||
|
||||
- Returns immediately with a process ID
|
||||
- Use GET /validate-streams/{process_id} to check status
|
||||
"""
|
||||
return {"message": f"Hello {user.username}, you have access to support resources!"}
|
||||
process_id = str(uuid4())
|
||||
validation_processes[process_id] = {
|
||||
"status": ProcessStatus.PENDING,
|
||||
"channel_id": request.channel_id,
|
||||
}
|
||||
|
||||
background_tasks.add_task(run_stream_validation, process_id, request.channel_id, db)
|
||||
|
||||
return {
|
||||
"process_id": process_id,
|
||||
"status": ProcessStatus.PENDING,
|
||||
"message": "Validation process started",
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/validate-streams/{process_id}",
|
||||
summary="Check validation process status",
|
||||
response_model=ValidationResultResponse,
|
||||
responses={
|
||||
200: {"description": "Process status and results"},
|
||||
404: {"description": "Process not found"},
|
||||
},
|
||||
)
|
||||
async def get_validation_status(
|
||||
process_id: str, user: CognitoUser = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Check status of a stream validation process.
|
||||
|
||||
Returns current status and results if completed.
|
||||
"""
|
||||
if process_id not in validation_processes:
|
||||
raise HTTPException(status_code=404, detail="Process not found")
|
||||
|
||||
process = validation_processes[process_id]
|
||||
response = {"process_id": process_id, "status": process["status"]}
|
||||
|
||||
if process["status"] == ProcessStatus.COMPLETED:
|
||||
if "error" in process:
|
||||
response["error"] = process["error"]
|
||||
else:
|
||||
response["working_streams"] = process["result"]["working_streams"]
|
||||
elif process["status"] == ProcessStatus.FAILED:
|
||||
response["error"] = process["error"]
|
||||
|
||||
return response
|
||||
|
||||
@@ -59,6 +59,34 @@ def get_priority(
|
||||
return priority
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_200_OK)
|
||||
@require_roles("admin")
|
||||
def delete_priorities(
|
||||
db: Session = Depends(get_db),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all priorities not in use by channel URLs"""
|
||||
from app.models.db import ChannelURL
|
||||
|
||||
priorities = db.query(Priority).all()
|
||||
deleted = 0
|
||||
skipped = 0
|
||||
|
||||
for priority in priorities:
|
||||
in_use = db.scalar(
|
||||
select(ChannelURL).where(ChannelURL.priority_id == priority.id).limit(1)
|
||||
)
|
||||
|
||||
if not in_use:
|
||||
db.delete(priority)
|
||||
deleted += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
db.commit()
|
||||
return {"deleted": deleted, "skipped": skipped}
|
||||
|
||||
|
||||
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@require_roles("admin")
|
||||
def delete_priority(
|
||||
|
||||
57
app/routers/scheduler.py
Normal file
57
app/routers/scheduler.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_roles
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.database import get_db
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/scheduler",
|
||||
tags=["scheduler"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
async def get_scheduler(request: Request) -> StreamScheduler:
|
||||
"""Get the scheduler instance from the app state."""
|
||||
if not hasattr(request.app.state.scheduler, "scheduler"):
|
||||
raise HTTPException(status_code=500, detail="Scheduler not initialized")
|
||||
return request.app.state.scheduler
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
@require_roles("admin")
|
||||
def scheduler_health(
|
||||
scheduler: StreamScheduler = Depends(get_scheduler),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Check scheduler health status (admin only)."""
|
||||
try:
|
||||
job = scheduler.scheduler.get_job("daily_stream_validation")
|
||||
next_run = str(job.next_run_time) if job and job.next_run_time else None
|
||||
|
||||
return {
|
||||
"status": "running" if scheduler.scheduler.running else "stopped",
|
||||
"next_run": next_run,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to check scheduler health: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/trigger")
|
||||
@require_roles("admin")
|
||||
def trigger_validation(
|
||||
scheduler: StreamScheduler = Depends(get_scheduler),
|
||||
user: CognitoUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Manually trigger stream validation (admin only)."""
|
||||
scheduler.trigger_manual_validation()
|
||||
return JSONResponse(
|
||||
status_code=202, content={"message": "Stream validation triggered"}
|
||||
)
|
||||
@@ -66,6 +66,8 @@ class StreamValidator:
|
||||
"application/octet-stream",
|
||||
"application/x-mpegURL",
|
||||
]
|
||||
if content_type is None:
|
||||
return False
|
||||
return any(ct in content_type for ct in valid_types)
|
||||
|
||||
def parse_playlist(self, file_path):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
import boto3
|
||||
from requests import Session
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@@ -19,16 +20,16 @@ def get_db_credentials():
|
||||
|
||||
ssm = boto3.client("ssm", region_name=AWS_REGION)
|
||||
try:
|
||||
host = ssm.get_parameter(Name="/iptv-updater/DB_HOST", WithDecryption=True)[
|
||||
host = ssm.get_parameter(Name="/iptv-manager/DB_HOST", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
user = ssm.get_parameter(Name="/iptv-updater/DB_USER", WithDecryption=True)[
|
||||
user = ssm.get_parameter(Name="/iptv-manager/DB_USER", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
password = ssm.get_parameter(
|
||||
Name="/iptv-updater/DB_PASSWORD", WithDecryption=True
|
||||
Name="/iptv-manager/DB_PASSWORD", WithDecryption=True
|
||||
)["Parameter"]["Value"]
|
||||
dbname = ssm.get_parameter(Name="/iptv-updater/DB_NAME", WithDecryption=True)[
|
||||
dbname = ssm.get_parameter(Name="/iptv-manager/DB_NAME", WithDecryption=True)[
|
||||
"Parameter"
|
||||
]["Value"]
|
||||
return f"postgresql://{user}:{password}@{host}/{dbname}"
|
||||
@@ -53,3 +54,8 @@ def get_db():
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_db_session() -> Session:
|
||||
"""Get a direct database session (non-generator version)"""
|
||||
return SessionLocal()
|
||||
|
||||
@@ -3,10 +3,11 @@ version: '3.8'
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:13
|
||||
container_name: postgres
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: iptv_updater
|
||||
POSTGRES_DB: iptv_manager
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
|
||||
@@ -6,7 +6,7 @@ services:
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: iptv_updater
|
||||
POSTGRES_DB: iptv_manager
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
@@ -20,7 +20,7 @@ services:
|
||||
DB_USER: postgres
|
||||
DB_PASSWORD: postgres
|
||||
DB_HOST: postgres
|
||||
DB_NAME: iptv_updater
|
||||
DB_NAME: iptv_manager
|
||||
MOCK_AUTH: "true"
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
@@ -9,7 +9,7 @@ from aws_cdk import aws_ssm as ssm
|
||||
from constructs import Construct
|
||||
|
||||
|
||||
class IptvUpdaterStack(Stack):
|
||||
class IptvManagerStack(Stack):
|
||||
def __init__(
|
||||
self,
|
||||
scope: Construct,
|
||||
@@ -27,7 +27,7 @@ class IptvUpdaterStack(Stack):
|
||||
# Create VPC
|
||||
vpc = ec2.Vpc(
|
||||
self,
|
||||
"IptvUpdaterVPC",
|
||||
"IptvManagerVPC",
|
||||
max_azs=2, # Need at least 2 AZs for RDS subnet group
|
||||
nat_gateways=0, # No NAT Gateway to stay in free tier
|
||||
subnet_configuration=[
|
||||
@@ -44,7 +44,7 @@ class IptvUpdaterStack(Stack):
|
||||
|
||||
# Security Group
|
||||
security_group = ec2.SecurityGroup(
|
||||
self, "IptvUpdaterSG", vpc=vpc, allow_all_outbound=True
|
||||
self, "IptvManagerSG", vpc=vpc, allow_all_outbound=True
|
||||
)
|
||||
|
||||
security_group.add_ingress_rule(
|
||||
@@ -66,18 +66,18 @@ class IptvUpdaterStack(Stack):
|
||||
"Allow PostgreSQL traffic for tunneling",
|
||||
)
|
||||
|
||||
# Key pair for IPTV Updater instance
|
||||
# Key pair for IPTV Manager instance
|
||||
key_pair = ec2.KeyPair(
|
||||
self,
|
||||
"IptvUpdaterKeyPair",
|
||||
key_pair_name="iptv-updater-key",
|
||||
"IptvManagerKeyPair",
|
||||
key_pair_name="iptv-manager-key",
|
||||
public_key_material=ssh_public_key,
|
||||
)
|
||||
|
||||
# Create IAM role for EC2
|
||||
role = iam.Role(
|
||||
self,
|
||||
"IptvUpdaterRole",
|
||||
"IptvManagerRole",
|
||||
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
|
||||
)
|
||||
|
||||
@@ -111,37 +111,11 @@ class IptvUpdaterStack(Stack):
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
|
||||
)
|
||||
|
||||
# EC2 Instance
|
||||
instance = ec2.Instance(
|
||||
self,
|
||||
"IptvUpdaterInstance",
|
||||
vpc=vpc,
|
||||
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
||||
),
|
||||
machine_image=ec2.AmazonLinuxImage(
|
||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
||||
),
|
||||
security_group=security_group,
|
||||
key_pair=key_pair,
|
||||
role=role,
|
||||
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
||||
associate_public_ip_address=True,
|
||||
)
|
||||
|
||||
# Option: 2: Create Elastic IP (not free tier compatible)
|
||||
# eip = ec2.CfnEIP(
|
||||
# self, "IptvUpdaterEIP",
|
||||
# domain="vpc",
|
||||
# instance_id=instance.instance_id
|
||||
# )
|
||||
|
||||
# Add Cognito User Pool
|
||||
user_pool = cognito.UserPool(
|
||||
self,
|
||||
"IptvUpdaterUserPool",
|
||||
user_pool_name="iptv-updater-users",
|
||||
"IptvManagerUserPool",
|
||||
user_pool_name="iptv-manager-users",
|
||||
self_sign_up_enabled=False, # Only admins can create users
|
||||
password_policy=cognito.PasswordPolicy(
|
||||
min_length=8,
|
||||
@@ -156,7 +130,7 @@ class IptvUpdaterStack(Stack):
|
||||
|
||||
# Add App Client with the correct callback URL
|
||||
client = user_pool.add_client(
|
||||
"IptvUpdaterClient",
|
||||
"IptvManagerClient",
|
||||
access_token_validity=Duration.minutes(60),
|
||||
id_token_validity=Duration.minutes(60),
|
||||
refresh_token_validity=Duration.days(1),
|
||||
@@ -171,8 +145,8 @@ class IptvUpdaterStack(Stack):
|
||||
|
||||
# Add domain for hosted UI
|
||||
domain = user_pool.add_domain(
|
||||
"IptvUpdaterDomain",
|
||||
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-updater"),
|
||||
"IptvManagerDomain",
|
||||
cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-manager"),
|
||||
)
|
||||
|
||||
# Read the userdata script with proper path resolution
|
||||
@@ -226,7 +200,7 @@ class IptvUpdaterStack(Stack):
|
||||
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
|
||||
db = rds.DatabaseInstance(
|
||||
self,
|
||||
"IptvUpdaterDB",
|
||||
"IptvManagerDB",
|
||||
engine=rds.DatabaseInstanceEngine.postgres(
|
||||
version=rds.PostgresEngineVersion.VER_13
|
||||
),
|
||||
@@ -240,7 +214,7 @@ class IptvUpdaterStack(Stack):
|
||||
security_groups=[rds_sg],
|
||||
allocated_storage=10,
|
||||
max_allocated_storage=10,
|
||||
database_name="iptv_updater",
|
||||
database_name="iptv_manager",
|
||||
removal_policy=RemovalPolicy.DESTROY,
|
||||
deletion_protection=False,
|
||||
publicly_accessible=False, # Avoid public IPv4 charges
|
||||
@@ -252,28 +226,28 @@ class IptvUpdaterStack(Stack):
|
||||
)
|
||||
|
||||
# Store DB connection info in SSM Parameter Store
|
||||
ssm.StringParameter(
|
||||
db_host_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBHostParam",
|
||||
parameter_name="/iptv-updater/DB_HOST",
|
||||
parameter_name="/iptv-manager/DB_HOST",
|
||||
string_value=db.db_instance_endpoint_address,
|
||||
)
|
||||
ssm.StringParameter(
|
||||
db_name_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBNameParam",
|
||||
parameter_name="/iptv-updater/DB_NAME",
|
||||
string_value="iptv_updater",
|
||||
parameter_name="/iptv-manager/DB_NAME",
|
||||
string_value="iptv_manager",
|
||||
)
|
||||
ssm.StringParameter(
|
||||
db_user_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBUserParam",
|
||||
parameter_name="/iptv-updater/DB_USER",
|
||||
parameter_name="/iptv-manager/DB_USER",
|
||||
string_value=db.secret.secret_value_from_json("username").to_string(),
|
||||
)
|
||||
ssm.StringParameter(
|
||||
db_pass_param = ssm.StringParameter(
|
||||
self,
|
||||
"DBPassParam",
|
||||
parameter_name="/iptv-updater/DB_PASSWORD",
|
||||
parameter_name="/iptv-manager/DB_PASSWORD",
|
||||
string_value=db.secret.secret_value_from_json("password").to_string(),
|
||||
)
|
||||
|
||||
@@ -282,6 +256,39 @@ class IptvUpdaterStack(Stack):
|
||||
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
|
||||
)
|
||||
|
||||
# EC2 Instance (created after all dependencies are ready)
|
||||
instance = ec2.Instance(
|
||||
self,
|
||||
"IptvManagerInstance",
|
||||
vpc=vpc,
|
||||
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
|
||||
instance_type=ec2.InstanceType.of(
|
||||
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
|
||||
),
|
||||
machine_image=ec2.AmazonLinuxImage(
|
||||
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
|
||||
),
|
||||
security_group=security_group,
|
||||
key_pair=key_pair,
|
||||
role=role,
|
||||
# Option: 1: Enable auto-assign public IP (free tier compatible)
|
||||
associate_public_ip_address=True,
|
||||
)
|
||||
|
||||
# Ensure instance depends on SSM parameters being created
|
||||
instance.node.add_dependency(db)
|
||||
instance.node.add_dependency(db_host_param)
|
||||
instance.node.add_dependency(db_name_param)
|
||||
instance.node.add_dependency(db_user_param)
|
||||
instance.node.add_dependency(db_pass_param)
|
||||
|
||||
# Option: 2: Create Elastic IP (not free tier compatible)
|
||||
# eip = ec2.CfnEIP(
|
||||
# self, "IptvManagerEIP",
|
||||
# domain="vpc",
|
||||
# instance_id=instance.instance_id
|
||||
# )
|
||||
|
||||
# Update instance with userdata
|
||||
instance.add_user_data(userdata.render())
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# Update system and install required packages
|
||||
dnf update -y
|
||||
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx
|
||||
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx postgresql15.x86_64 awscli
|
||||
|
||||
# Start and enable crond service
|
||||
systemctl start crond
|
||||
@@ -11,27 +11,69 @@ systemctl enable crond
|
||||
cd /home/ec2-user
|
||||
|
||||
git clone ${REPO_URL}
|
||||
cd iptv-updater-aws
|
||||
cd iptv-manager-service
|
||||
|
||||
# Install Python packages with --ignore-installed to prevent conflicts with RPM packages
|
||||
pip3 install --ignore-installed -r requirements.txt
|
||||
|
||||
# Retrieve DB credentials from SSM Parameter Store with retries
|
||||
echo "Attempting to retrieve DB credentials from SSM..."
|
||||
for i in {1..30}; do
|
||||
DB_HOST=$(aws ssm get-parameter --name "/iptv-manager/DB_HOST" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
DB_NAME=$(aws ssm get-parameter --name "/iptv-manager/DB_NAME" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
DB_USER=$(aws ssm get-parameter --name "/iptv-manager/DB_USER" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
DB_PASSWORD=$(aws ssm get-parameter --name "/iptv-manager/DB_PASSWORD" --query "Parameter.Value" --output text 2>/dev/null)
|
||||
|
||||
if [ -n "$DB_HOST" ] && [ -n "$DB_NAME" ] && [ -n "$DB_USER" ] && [ -n "$DB_PASSWORD" ]; then
|
||||
echo "Successfully retrieved all DB credentials"
|
||||
break
|
||||
fi
|
||||
|
||||
echo "Waiting for SSM parameters to be available... (attempt $i/30)"
|
||||
sleep 5
|
||||
done
|
||||
|
||||
if [ -z "$DB_HOST" ] || [ -z "$DB_NAME" ] || [ -z "$DB_USER" ] || [ -z "$DB_PASSWORD" ]; then
|
||||
echo "ERROR: Failed to retrieve all required DB credentials after 30 attempts"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DB_HOST
|
||||
export DB_NAME
|
||||
export DB_USER
|
||||
export DB_PASSWORD
|
||||
|
||||
# Set PGPASSWORD for psql to use
|
||||
export PGPASSWORD=$DB_PASSWORD
|
||||
|
||||
# Wait for PostgreSQL to be ready
|
||||
echo "Waiting for PostgreSQL to start..."
|
||||
until psql -h $DB_HOST -U $DB_USER -d postgres -c '\q'; do
|
||||
sleep 1
|
||||
done
|
||||
echo "PostgreSQL is ready."
|
||||
|
||||
# Create database if it does not exist
|
||||
DB_EXISTS=$(psql -h $DB_HOST -U $DB_USER -d postgres -tc "SELECT 1 FROM pg_database WHERE datname = '$DB_NAME';")
|
||||
if [ -z "$DB_EXISTS" ]; then
|
||||
echo "Creating database $DB_NAME..."
|
||||
psql -h $DB_HOST -U $DB_USER -d postgres -c "CREATE DATABASE $DB_NAME;"
|
||||
echo "Database $DB_NAME created."
|
||||
fi
|
||||
|
||||
# Run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Seed initial priorities
|
||||
python3 -c "from app.utils.database import SessionLocal; from app.models.db import Priority; db = SessionLocal(); db.add_all([Priority(id=100, description='High'), Priority(id=200, description='Medium'), Priority(id=300, description='Low')]); db.commit()"
|
||||
|
||||
# Create systemd service file
|
||||
cat << 'EOF' > /etc/systemd/system/iptv-updater.service
|
||||
cat << 'EOF' > /etc/systemd/system/iptv-manager.service
|
||||
[Unit]
|
||||
Description=IPTV Updater Service
|
||||
Description=IPTV Manager Service
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=ec2-user
|
||||
WorkingDirectory=/home/ec2-user/iptv-updater-aws
|
||||
WorkingDirectory=/home/ec2-user/iptv-manager-service
|
||||
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
|
||||
EnvironmentFile=/etc/environment
|
||||
Restart=always
|
||||
@@ -56,7 +98,7 @@ sudo mkdir -p /etc/nginx/ssl
|
||||
--reloadcmd "service nginx force-reload"
|
||||
|
||||
# Create nginx config
|
||||
cat << EOF > /etc/nginx/conf.d/iptvUpdater.conf
|
||||
cat << EOF > /etc/nginx/conf.d/iptvManager.conf
|
||||
server {
|
||||
listen 80;
|
||||
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
|
||||
@@ -83,5 +125,5 @@ EOF
|
||||
# Start nginx service
|
||||
systemctl enable nginx
|
||||
systemctl start nginx
|
||||
systemctl enable iptv-updater
|
||||
systemctl start iptv-updater
|
||||
systemctl enable iptv-manager
|
||||
systemctl start iptv-manager
|
||||
@@ -25,3 +25,7 @@ known-first-party = ["app"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov=app --cov-report=term-missing --cov-fail-under=70"
|
||||
testpaths = ["tests"]
|
||||
@@ -5,12 +5,21 @@ python_functions = test_*
|
||||
asyncio_mode = auto
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning:botocore.auth
|
||||
ignore:The 'app' shortcut is now deprecated:DeprecationWarning:httpx._client
|
||||
|
||||
# Coverage configuration
|
||||
addopts =
|
||||
--cov=app
|
||||
--cov-report=term-missing
|
||||
|
||||
# Test environment variables
|
||||
env =
|
||||
MOCK_AUTH=true
|
||||
DB_USER=test_user
|
||||
DB_PASSWORD=test_password
|
||||
DB_HOST=localhost
|
||||
DB_NAME=iptv_manager_test
|
||||
|
||||
# Test markers
|
||||
markers =
|
||||
slow: mark tests as slow running
|
||||
|
||||
@@ -15,3 +15,8 @@ alembic==1.16.1
|
||||
pytest==8.1.1
|
||||
pytest-asyncio==0.23.6
|
||||
pytest-mock==3.12.0
|
||||
pytest-cov==4.1.0
|
||||
pytest-env==1.1.1
|
||||
httpx==0.27.0
|
||||
pre-commit
|
||||
apscheduler==3.10.4
|
||||
@@ -25,7 +25,7 @@ cdk deploy --app="python3 ${PWD}/app.py"
|
||||
# Update application on running instances
|
||||
INSTANCE_IDS=$(aws ec2 describe-instances \
|
||||
--region us-east-2 \
|
||||
--filters "Name=tag:Name,Values=IptvUpdaterStack/IptvUpdaterInstance" \
|
||||
--filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
|
||||
"Name=instance-state-name,Values=running" \
|
||||
--query "Reservations[].Instances[].InstanceId" \
|
||||
--output text)
|
||||
@@ -35,7 +35,7 @@ for INSTANCE_ID in $INSTANCE_IDS; do
|
||||
aws ssm send-command \
|
||||
--instance-ids "$INSTANCE_ID" \
|
||||
--document-name "AWS-RunShellScript" \
|
||||
--parameters '{"commands":["cd /home/ec2-user/iptv-updater-aws && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-updater"]}' \
|
||||
--parameters '{"commands":["cd /home/ec2-user/iptv-manager-service && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-manager"]}' \
|
||||
--no-cli-pager \
|
||||
--no-paginate
|
||||
done
|
||||
|
||||
@@ -7,9 +7,7 @@ python3 -m pip install -r requirements.txt
|
||||
# Install and configure pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit install-hooks
|
||||
pre-commit autoupdate
|
||||
|
||||
# Initialize and run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Seed initial data
|
||||
python3 -c "from app.utils.database import SessionLocal; from app.models.db import Priority; db = SessionLocal(); db.add_all([Priority(id=100, description='High'), Priority(id=200, description='Medium'), Priority(id=300, description='Low')]); db.commit()"
|
||||
# Verify pytest setup
|
||||
python3 -m pytest
|
||||
@@ -1,21 +1,26 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Start PostgreSQL
|
||||
docker-compose -f docker/docker-compose-db.yml up -d
|
||||
|
||||
# Set mock auth and database environment variables
|
||||
# Set environment variables
|
||||
export MOCK_AUTH=true
|
||||
export DB_HOST=localhost
|
||||
export DB_USER=postgres
|
||||
export DB_PASSWORD=postgres
|
||||
export DB_HOST=localhost
|
||||
export DB_NAME=iptv_updater
|
||||
export DB_NAME=iptv_manager
|
||||
|
||||
echo "Ensuring database $DB_NAME exists using conditional DDL..."
|
||||
PGPASSWORD=$DB_PASSWORD docker exec -i postgres psql -U $DB_USER <<< "SELECT 'CREATE DATABASE $DB_NAME' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec"
|
||||
echo "Database $DB_NAME check complete."
|
||||
|
||||
# Run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Start FastAPI
|
||||
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
|
||||
echo $! > iptv-updater.pid
|
||||
echo $! > iptv-manager.pid
|
||||
|
||||
echo "Services started:"
|
||||
echo "- PostgreSQL running on localhost:5432"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Stop FastAPI
|
||||
if [ -f iptv-updater.pid ]; then
|
||||
kill $(cat iptv-updater.pid)
|
||||
rm iptv-updater.pid
|
||||
if [ -f iptv-manager.pid ]; then
|
||||
kill $(cat iptv-manager.pid)
|
||||
rm iptv-manager.pid
|
||||
echo "Stopped FastAPI"
|
||||
fi
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
|
||||
# Mock endpoint for testing the require_roles decorator
|
||||
@require_roles("admin")
|
||||
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
return {"message": "Success", "user": user.username}
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ async def test_require_roles_no_roles():
|
||||
async def test_require_roles_multiple_roles():
|
||||
# Test requiring multiple roles
|
||||
@require_roles("admin", "super_user")
|
||||
async def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
return {"message": "Success"}
|
||||
|
||||
# User with all required roles
|
||||
@@ -173,3 +173,33 @@ def test_mock_auth_import(monkeypatch):
|
||||
|
||||
# Reload again to restore original state
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
|
||||
def test_cognito_auth_import(monkeypatch):
|
||||
"""Test that cognito auth is imported when MOCK_AUTH=false (covers line 14)"""
|
||||
# Save original env var value
|
||||
original_value = os.environ.get("MOCK_AUTH")
|
||||
|
||||
try:
|
||||
# Set MOCK_AUTH to false
|
||||
monkeypatch.setenv("MOCK_AUTH", "false")
|
||||
|
||||
# Reload the dependencies module to trigger the import condition
|
||||
import app.auth.dependencies
|
||||
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
# Verify that get_user_from_token was imported from app.auth.cognito
|
||||
from app.auth.dependencies import get_user_from_token
|
||||
|
||||
assert get_user_from_token.__module__ == "app.auth.cognito"
|
||||
|
||||
finally:
|
||||
# Restore original env var
|
||||
if original_value is None:
|
||||
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||
|
||||
# Reload again to restore original state
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
144
tests/models/test_db.py
Normal file
144
tests/models/test_db.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID
|
||||
|
||||
# --- Test SQLiteUUID Type ---
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_none():
|
||||
"""Test SQLiteUUID.process_bind_param with None returns None"""
|
||||
uuid_type = SQLiteUUID()
|
||||
assert uuid_type.process_bind_param(None, None) is None
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_valid_uuid():
|
||||
"""Test SQLiteUUID.process_bind_param with valid UUID returns string"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
assert uuid_type.process_bind_param(test_uuid, None) == str(test_uuid)
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_valid_string():
|
||||
"""Test SQLiteUUID.process_bind_param with valid UUID string returns string"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid_str = "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert uuid_type.process_bind_param(test_uuid_str, None) == test_uuid_str
|
||||
|
||||
|
||||
def test_sqliteuuid_process_bind_param_invalid_string():
|
||||
"""Test SQLiteUUID.process_bind_param raises ValueError for invalid UUID"""
|
||||
uuid_type = SQLiteUUID()
|
||||
with pytest.raises(ValueError, match="Invalid UUID string format"):
|
||||
uuid_type.process_bind_param("invalid-uuid", None)
|
||||
|
||||
|
||||
def test_sqliteuuid_process_result_value_none():
|
||||
"""Test SQLiteUUID.process_result_value with None returns None"""
|
||||
uuid_type = SQLiteUUID()
|
||||
assert uuid_type.process_result_value(None, None) is None
|
||||
|
||||
|
||||
def test_sqliteuuid_process_result_value_valid_string():
|
||||
"""Test SQLiteUUID.process_result_value converts string to UUID"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
result = uuid_type.process_result_value(str(test_uuid), None)
|
||||
assert isinstance(result, uuid.UUID)
|
||||
assert result == test_uuid
|
||||
|
||||
|
||||
def test_sqliteuuid_process_result_value_uuid_object():
|
||||
"""Test SQLiteUUID.process_result_value: UUID object returns itself."""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
result = uuid_type.process_result_value(test_uuid, None)
|
||||
assert isinstance(result, uuid.UUID)
|
||||
assert result is test_uuid # Ensure it's the same object, not a new one
|
||||
|
||||
|
||||
def test_sqliteuuid_compare_values_none():
|
||||
"""Test SQLiteUUID.compare_values handles None values"""
|
||||
uuid_type = SQLiteUUID()
|
||||
assert uuid_type.compare_values(None, None) is True
|
||||
assert uuid_type.compare_values(None, uuid.uuid4()) is False
|
||||
assert uuid_type.compare_values(uuid.uuid4(), None) is False
|
||||
|
||||
|
||||
def test_sqliteuuid_compare_values_uuid():
|
||||
"""Test SQLiteUUID.compare_values compares UUIDs as strings"""
|
||||
uuid_type = SQLiteUUID()
|
||||
test_uuid = uuid.uuid4()
|
||||
assert uuid_type.compare_values(test_uuid, test_uuid) is True
|
||||
assert uuid_type.compare_values(test_uuid, uuid.uuid4()) is False
|
||||
|
||||
|
||||
def test_sqlite_uuid_comparison():
|
||||
"""Test SQLiteUUID comparison functionality (moved from db_mocks.py)"""
|
||||
uuid_type = SQLiteUUID()
|
||||
|
||||
# Test equal UUIDs
|
||||
uuid1 = uuid.uuid4()
|
||||
uuid2 = uuid.UUID(str(uuid1))
|
||||
assert uuid_type.compare_values(uuid1, uuid2) is True
|
||||
|
||||
# Test UUID vs string
|
||||
assert uuid_type.compare_values(uuid1, str(uuid1)) is True
|
||||
|
||||
# Test None comparisons
|
||||
assert uuid_type.compare_values(None, None) is True
|
||||
assert uuid_type.compare_values(uuid1, None) is False
|
||||
assert uuid_type.compare_values(None, uuid1) is False
|
||||
|
||||
# Test different UUIDs
|
||||
uuid3 = uuid.uuid4()
|
||||
assert uuid_type.compare_values(uuid1, uuid3) is False
|
||||
|
||||
|
||||
def test_sqlite_uuid_binding():
|
||||
"""Test SQLiteUUID binding parameter handling (moved from db_mocks.py)"""
|
||||
uuid_type = SQLiteUUID()
|
||||
|
||||
# Test UUID object binding
|
||||
uuid_obj = uuid.uuid4()
|
||||
assert uuid_type.process_bind_param(uuid_obj, None) == str(uuid_obj)
|
||||
|
||||
# Test valid UUID string binding
|
||||
uuid_str = str(uuid.uuid4())
|
||||
assert uuid_type.process_bind_param(uuid_str, None) == uuid_str
|
||||
|
||||
# Test None handling
|
||||
assert uuid_type.process_bind_param(None, None) is None
|
||||
|
||||
# Test invalid UUID string
|
||||
with pytest.raises(ValueError):
|
||||
uuid_type.process_bind_param("invalid-uuid", None)
|
||||
|
||||
|
||||
# --- Test UUID Column Type Configuration ---
|
||||
|
||||
|
||||
def test_uuid_column_type_default():
|
||||
"""Test UUID_COLUMN_TYPE uses SQLiteUUID in test environment"""
|
||||
assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"MOCK_AUTH": "false"})
|
||||
def test_uuid_column_type_postgres():
|
||||
"""Test UUID_COLUMN_TYPE uses Postgres UUID when MOCK_AUTH=false"""
|
||||
# Need to re-import to get the patched environment
|
||||
from importlib import reload
|
||||
|
||||
from app import models
|
||||
|
||||
reload(models.db)
|
||||
from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
|
||||
|
||||
from app.models.db import UUID_COLUMN_TYPE
|
||||
|
||||
assert isinstance(UUID_COLUMN_TYPE, PostgresUUID)
|
||||
43
tests/routers/mocks.py
Normal file
43
tests/routers/mocks.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
|
||||
|
||||
class MockScheduler:
|
||||
"""Base mock APScheduler instance"""
|
||||
|
||||
running = True
|
||||
start = Mock()
|
||||
shutdown = Mock()
|
||||
add_job = Mock()
|
||||
remove_job = Mock()
|
||||
get_job = Mock(return_value=None)
|
||||
|
||||
def __init__(self, running=True):
|
||||
self.running = running
|
||||
|
||||
|
||||
def create_trigger_mock(triggered_ref: dict) -> callable:
|
||||
"""Create a mock trigger function that updates a reference when called"""
|
||||
|
||||
def trigger_mock():
|
||||
triggered_ref["value"] = True
|
||||
|
||||
return trigger_mock
|
||||
|
||||
|
||||
async def mock_get_scheduler(
|
||||
request: Request, scheduler_class=MockScheduler, running=True, **kwargs
|
||||
) -> StreamScheduler:
|
||||
"""Mock dependency for get_scheduler with customization options"""
|
||||
scheduler = StreamScheduler()
|
||||
mock_scheduler = scheduler_class(running=running)
|
||||
|
||||
# Apply any additional attributes/methods
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_scheduler, key, value)
|
||||
|
||||
scheduler.scheduler = mock_scheduler
|
||||
return scheduler
|
||||
File diff suppressed because it is too large
Load Diff
461
tests/routers/test_groups.py
Normal file
461
tests/routers/test_groups.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.utils.database import get_db
|
||||
|
||||
# Import mocks and fixtures
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import (
|
||||
MockChannelDB,
|
||||
MockGroup,
|
||||
create_mock_priorities_and_group,
|
||||
)
|
||||
|
||||
# --- Test Cases For Group Creation ---


def test_create_group_success(db_session: Session, admin_user_client):
    """POST /groups/ creates a group and persists it to the database."""
    payload = {"name": "Test Group", "sort_order": 1}
    response = admin_user_client.post("/groups/", json=payload)
    assert response.status_code == status.HTTP_201_CREATED

    body = response.json()
    assert body["name"] == "Test Group"
    assert body["sort_order"] == 1
    for field in ("id", "created_at", "updated_at"):
        assert field in body

    # The group must also be visible in the database.
    persisted = (
        db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first()
    )
    assert persisted is not None
    assert persisted.sort_order == 1


def test_create_group_duplicate(db_session: Session, admin_user_client):
    """Creating a group whose name already exists returns 409."""
    existing = MockGroup(
        id=uuid.uuid4(),
        name="Duplicate Group",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(existing)
    db_session.commit()

    response = admin_user_client.post(
        "/groups/", json={"name": "Duplicate Group", "sort_order": 2}
    )
    assert response.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in response.json()["detail"]


def test_create_group_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not create groups."""
    response = non_admin_user_client.post(
        "/groups/", json={"name": "Forbidden Group", "sort_order": 1}
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For Get Group ---


def test_get_group_success(db_session: Session, admin_user_client):
    """GET /groups/{id} returns the stored group."""
    now = datetime.now(timezone.utc)
    stored = MockGroup(
        id=uuid.uuid4(),
        name="Get Me Group",
        sort_order=1,
        created_at=now,
        updated_at=now,
    )
    db_session.add(stored)
    db_session.commit()

    response = admin_user_client.get(f"/groups/{stored.id}")
    assert response.status_code == status.HTTP_200_OK
    body = response.json()
    assert body["id"] == str(stored.id)
    assert body["name"] == "Get Me Group"
    assert body["sort_order"] == 1


def test_get_group_not_found(db_session: Session, admin_user_client):
    """GET /groups/{id} with an unknown id returns 404."""
    missing_id = uuid.uuid4()
    response = admin_user_client.get(f"/groups/{missing_id}")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]
# --- Test Cases For Update Group ---


def _persist_group(db, name, sort_order):
    """Create, commit and return a MockGroup with fresh UTC timestamps."""
    group = MockGroup(
        id=uuid.uuid4(),
        name=name,
        sort_order=sort_order,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db.add(group)
    db.commit()
    return group


def test_update_group_success(db_session: Session, admin_user_client):
    """PUT /groups/{id} updates both name and sort_order."""
    group = _persist_group(db_session, "Update Me", 1)

    response = admin_user_client.put(
        f"/groups/{group.id}", json={"name": "Updated Name", "sort_order": 2}
    )
    assert response.status_code == status.HTTP_200_OK
    body = response.json()
    assert body["name"] == "Updated Name"
    assert body["sort_order"] == 2

    # The change must be persisted.
    refreshed = db_session.query(MockGroup).filter(MockGroup.id == group.id).first()
    assert refreshed.name == "Updated Name"
    assert refreshed.sort_order == 2


def test_update_group_conflict(db_session: Session, admin_user_client):
    """Renaming a group to another group's name returns 409."""
    _persist_group(db_session, "Group One", 1)
    second = _persist_group(db_session, "Group Two", 2)

    response = admin_user_client.put(
        f"/groups/{second.id}", json={"name": "Group One"}
    )
    assert response.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in response.json()["detail"]


def test_update_group_not_found(db_session: Session, admin_user_client):
    """Updating a non-existent group returns 404."""
    response = admin_user_client.put(
        f"/groups/{uuid.uuid4()}", json={"name": "Non-existent"}
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]


def test_update_group_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client, admin_user_client
):
    """Non-admin users may not update groups."""
    group = _persist_group(db_session, "Admin Created", 1)

    response = non_admin_user_client.put(
        f"/groups/{group.id}", json={"name": "Non-Admin Update"}
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For Delete Group ---


def test_delete_all_groups_success(db_session, admin_user_client):
    """DELETE /groups removes empty groups and skips groups with channels."""
    # Create two empty groups.
    group1_id = create_mock_priorities_and_group(db_session, [], "Group A")
    group2_id = create_mock_priorities_and_group(db_session, [], "Group B")

    # Attach a channel to "Group A" (group1) so it cannot be deleted.
    # NOTE(review): this payload uses a "url" key while the priorities tests
    # use "urls" — confirm which key /channels/bulk-upload actually expects.
    channel_data = [
        {
            "group-title": "Group A",
            "tvg_id": "channel1.tv",
            "name": "Channel One",
            "url": ["http://test.com", "http://example.com"],
        }
    ]
    admin_user_client.post("/channels/bulk-upload", json=channel_data)

    # Reset groups: only the empty group should be removed.
    response = admin_user_client.delete("/groups")
    assert response.status_code == status.HTTP_200_OK
    assert response.json()["deleted"] == 1  # group2 ("Group B") was empty
    assert response.json()["skipped"] == 1  # group1 ("Group A") has a channel

    # group1 must remain; group2 must be gone.
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group1_id).first()
        is not None
    )
    assert db_session.query(MockGroup).filter(MockGroup.id == group2_id).first() is None


def test_delete_all_groups_forbidden_for_non_admin(db_session, non_admin_user_client):
    """DELETE /groups requires the admin role."""
    response = non_admin_user_client.delete("/groups")
    assert response.status_code == status.HTTP_403_FORBIDDEN


def test_delete_group_success(db_session: Session, admin_user_client):
    """DELETE /groups/{id} removes an empty group and returns 204."""
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Delete Me",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()

    # Sanity check: the group exists before the delete.
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
    )

    response = admin_user_client.delete(f"/groups/{group_id}")
    assert response.status_code == status.HTTP_204_NO_CONTENT

    # And it is gone afterwards.
    assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None


def test_delete_group_with_channels_fails(db_session: Session, admin_user_client):
    """Deleting a group that still owns channels returns 400 and keeps the group."""
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Group With Channels",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)

    # A channel referencing the group blocks deletion.
    test_channel = MockChannelDB(
        id=uuid.uuid4(),
        tvg_id="channel1.tv",
        name="Channel 1",
        group_id=group_id,
        tvg_name="Channel1",
        tvg_logo="logo.png",
    )
    db_session.add(test_channel)
    db_session.commit()

    response = admin_user_client.delete(f"/groups/{group_id}")
    assert response.status_code == status.HTTP_400_BAD_REQUEST
    assert "existing channels" in response.json()["detail"]

    # The group must survive the failed delete.
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
    )


def test_delete_group_not_found(db_session: Session, admin_user_client):
    """Deleting a non-existent group returns 404."""
    random_uuid = uuid.uuid4()
    response = admin_user_client.delete(f"/groups/{random_uuid}")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]


def test_delete_group_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client, admin_user_client
):
    """Non-admin users may not delete groups; the group must remain."""
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Admin Created",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()

    response = non_admin_user_client.delete(f"/groups/{group_id}")
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]

    # The group must still exist.
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
    )
# --- Test Cases For List Groups ---


def test_list_groups_empty(db_session: Session, admin_user_client):
    """GET /groups/ returns an empty list when no groups exist."""
    response = admin_user_client.get("/groups/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == []


def test_list_groups_with_data(db_session: Session, admin_user_client):
    """GET /groups/ returns all groups ordered by sort_order."""
    now = datetime.now(timezone.utc)
    db_session.add_all(
        [
            MockGroup(
                id=uuid.uuid4(),
                name=f"Group {i}",
                sort_order=i,
                created_at=now,
                updated_at=now,
            )
            for i in range(3)
        ]
    )
    db_session.commit()

    response = admin_user_client.get("/groups/")
    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert len(payload) == 3
    # The endpoint sorts by sort_order.
    assert [g["sort_order"] for g in payload] == [0, 1, 2]
# --- Test Cases For Sort Order Updates ---


def test_update_group_sort_order_success(db_session: Session, admin_user_client):
    """PUT /groups/{id}/sort updates a single group's sort_order."""
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Sort Me",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()

    response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5})
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert data["sort_order"] == 5

    # Verify the new order was persisted.
    db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
    assert db_group.sort_order == 5


def test_update_group_sort_order_not_found(db_session: Session, admin_user_client):
    """Updating sort order for a non-existent group returns 404."""
    random_uuid = uuid.uuid4()
    response = admin_user_client.put(
        f"/groups/{random_uuid}/sort", json={"sort_order": 5}
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]


def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client):
    """POST /groups/reorder applies sort orders to several groups at once."""
    groups = [
        MockGroup(
            id=uuid.uuid4(),
            name=f"Group {i}",
            sort_order=i,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        for i in range(3)
    ]
    # (removed leftover debug print of the group list)
    db_session.add_all(groups)
    db_session.commit()

    # Reverse the ordering via the bulk endpoint.
    bulk_data = {
        "groups": [
            {"group_id": str(groups[0].id), "sort_order": 2},
            {"group_id": str(groups[1].id), "sort_order": 1},
            {"group_id": str(groups[2].id), "sort_order": 0},
        ]
    }
    response = admin_user_client.post("/groups/reorder", json=bulk_data)
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert len(data) == 3

    # Look up each returned group by id and check its new sort_order.
    returned_groups_map = {item["id"]: item for item in data}
    assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2
    assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1
    assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0

    # The reversed order must be persisted.
    db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all()
    assert db_groups[0].sort_order == 2
    assert db_groups[1].sort_order == 1
    assert db_groups[2].sort_order == 0


def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client):
    """A bulk reorder containing an unknown group id fails with 404 and changes nothing."""
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Valid Group",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()

    bulk_data = {
        "groups": [
            {"group_id": str(group_id), "sort_order": 2},
            {"group_id": str(uuid.uuid4()), "sort_order": 1},  # unknown group
        ]
    }
    response = admin_user_client.post("/groups/reorder", json=bulk_data)
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "not found" in response.json()["detail"]

    # The valid group's sort order must be unchanged.
    db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
    assert db_group.sort_order == 1
# === New file: tests/routers/test_playlist.py (261 lines) ===
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
|
||||
# Import the router we're testing
|
||||
from app.routers.playlist import (
|
||||
ProcessStatus,
|
||||
ValidationProcessResponse,
|
||||
ValidationResultResponse,
|
||||
router,
|
||||
validation_processes,
|
||||
)
|
||||
from app.utils.database import get_db
|
||||
|
||||
# Import mocks and fixtures
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import MockChannelDB
|
||||
|
||||
# --- Test Fixtures ---


@pytest.fixture
def mock_stream_manager():
    """Patch the StreamManager used by the playlist router for each test."""
    patcher = patch("app.routers.playlist.StreamManager")
    mocked_class = patcher.start()
    try:
        yield mocked_class
    finally:
        patcher.stop()
# --- Test Cases For Stream Validation ---


def test_start_stream_validation_success(
    db_session: Session, admin_user_client, mock_stream_manager
):
    """POST /playlist/validate-streams starts a tracked validation process."""
    mock_stream_manager.return_value.validate_and_select_stream.return_value = (
        "http://valid.stream.url"
    )

    response = admin_user_client.post(
        "/playlist/validate-streams", json={"channel_id": "test-channel"}
    )

    assert response.status_code == status.HTTP_202_ACCEPTED
    payload = response.json()
    assert "process_id" in payload
    assert payload["status"] == ProcessStatus.PENDING
    assert payload["message"] == "Validation process started"

    # The process must be registered in the tracking dict.
    process_id = payload["process_id"]
    assert process_id in validation_processes
    # The TestClient runs background tasks synchronously, so the process may
    # already have finished by the time we inspect it.
    assert validation_processes[process_id]["status"] in [
        ProcessStatus.PENDING,
        ProcessStatus.COMPLETED,
    ]
    assert validation_processes[process_id]["channel_id"] == "test-channel"
def _track_process(**fields):
    """Register a fake validation process and return its generated id."""
    process_id = str(uuid.uuid4())
    validation_processes[process_id] = dict(fields)
    return process_id


def test_get_validation_status_pending(db_session: Session, admin_user_client):
    """A pending process reports PENDING with no result and no error."""
    process_id = _track_process(
        status=ProcessStatus.PENDING, channel_id="test-channel"
    )

    response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["process_id"] == process_id
    assert payload["status"] == ProcessStatus.PENDING
    assert payload["working_streams"] is None
    assert payload["error"] is None


def test_get_validation_status_completed(db_session: Session, admin_user_client):
    """A completed process reports its working streams."""
    process_id = _track_process(
        status=ProcessStatus.COMPLETED,
        channel_id="test-channel",
        result={
            "working_streams": [
                {"channel_id": "test-channel", "stream_url": "http://valid.stream.url"}
            ]
        },
    )

    response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["process_id"] == process_id
    assert payload["status"] == ProcessStatus.COMPLETED
    assert len(payload["working_streams"]) == 1
    assert payload["working_streams"][0]["channel_id"] == "test-channel"
    assert payload["working_streams"][0]["stream_url"] == "http://valid.stream.url"
    assert payload["error"] is None


def test_get_validation_status_completed_with_error(
    db_session: Session, admin_user_client
):
    """A completed process that found nothing reports its error message."""
    process_id = _track_process(
        status=ProcessStatus.COMPLETED,
        channel_id="test-channel",
        error="No working streams found for channel test-channel",
    )

    response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["process_id"] == process_id
    assert payload["status"] == ProcessStatus.COMPLETED
    assert payload["working_streams"] is None
    assert payload["error"] == "No working streams found for channel test-channel"


def test_get_validation_status_failed(db_session: Session, admin_user_client):
    """A failed process reports FAILED and the recorded error."""
    process_id = _track_process(
        status=ProcessStatus.FAILED,
        channel_id="test-channel",
        error="Validation error occurred",
    )

    response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["process_id"] == process_id
    assert payload["status"] == ProcessStatus.FAILED
    assert payload["working_streams"] is None
    assert payload["error"] == "Validation error occurred"


def test_get_validation_status_not_found(db_session: Session, admin_user_client):
    """Querying an unknown process id returns 404."""
    random_uuid = str(uuid.uuid4())
    response = admin_user_client.get(f"/playlist/validate-streams/{random_uuid}")

    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Process not found" in response.json()["detail"]
def test_run_stream_validation_success(mock_stream_manager, db_session):
    """The background task records COMPLETED with the selected stream."""
    from app.routers.playlist import run_stream_validation

    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }
    mock_stream_manager.return_value.validate_and_select_stream.return_value = (
        "http://valid.stream.url"
    )

    run_stream_validation(process_id, "test-channel", db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.COMPLETED
    streams = entry["result"]["working_streams"]
    assert len(streams) == 1
    assert streams[0].channel_id == "test-channel"
    assert streams[0].stream_url == "http://valid.stream.url"


def test_run_stream_validation_failure(mock_stream_manager, db_session):
    """When no stream validates, the task completes with an error message."""
    from app.routers.playlist import run_stream_validation

    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }
    mock_stream_manager.return_value.validate_and_select_stream.return_value = None

    run_stream_validation(process_id, "test-channel", db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.COMPLETED
    assert "error" in entry
    assert "No working streams found" in entry["error"]


def test_run_stream_validation_exception(mock_stream_manager, db_session):
    """An exception inside validation marks the process FAILED."""
    from app.routers.playlist import run_stream_validation

    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }
    mock_stream_manager.return_value.validate_and_select_stream.side_effect = Exception(
        "Test error"
    )

    run_stream_validation(process_id, "test-channel", db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.FAILED
    assert "error" in entry
    assert "Test error" in entry["error"]


def test_start_stream_validation_no_channel_id(
    db_session: Session, admin_user_client, mock_stream_manager
):
    """Starting validation without a channel_id is accepted but unimplemented."""
    response = admin_user_client.post("/playlist/validate-streams", json={})

    assert response.status_code == status.HTTP_202_ACCEPTED
    payload = response.json()
    assert "process_id" in payload
    assert payload["status"] == ProcessStatus.PENDING

    # The process is tracked; background tasks run synchronously in tests,
    # so the status may already be COMPLETED.
    process_id = payload["process_id"]
    assert process_id in validation_processes
    assert validation_processes[process_id]["status"] in [
        ProcessStatus.PENDING,
        ProcessStatus.COMPLETED,
    ]
    assert validation_processes[process_id]["channel_id"] is None
    assert "not yet implemented" in validation_processes[process_id].get("error", "")


def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session):
    """The background task without a channel_id completes with an explanatory error."""
    from app.routers.playlist import run_stream_validation

    process_id = str(uuid.uuid4())
    validation_processes[process_id] = {"status": ProcessStatus.PENDING}

    run_stream_validation(process_id, None, db_session)

    entry = validation_processes[process_id]
    assert entry["status"] == ProcessStatus.COMPLETED
    assert "error" in entry
    assert "not yet implemented" in entry["error"]
# === New file: tests/routers/test_priorities.py (241 lines) ===
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.routers.priorities import router as priorities_router
|
||||
|
||||
# Import fixtures and mocks
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import (
|
||||
MockChannelDB,
|
||||
MockChannelURL,
|
||||
MockGroup,
|
||||
MockPriority,
|
||||
create_mock_priorities_and_group,
|
||||
)
|
||||
|
||||
# --- Test Cases For Priority Creation ---


def test_create_priority_success(db_session: Session, admin_user_client):
    """POST /priorities/ creates a priority and persists it."""
    response = admin_user_client.post(
        "/priorities/", json={"id": 100, "description": "Test Priority"}
    )
    assert response.status_code == status.HTTP_201_CREATED
    payload = response.json()
    assert payload["id"] == 100
    assert payload["description"] == "Test Priority"

    # The priority must also be in the database.
    persisted = db_session.get(MockPriority, 100)
    assert persisted is not None
    assert persisted.description == "Test Priority"


def test_create_priority_duplicate(db_session: Session, admin_user_client):
    """Creating a priority with an existing id returns 409."""
    body = {"id": 100, "description": "Original Priority"}
    first = admin_user_client.post("/priorities/", json=body)
    assert first.status_code == status.HTTP_201_CREATED

    second = admin_user_client.post("/priorities/", json=body)
    assert second.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in second.json()["detail"]


def test_create_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not create priorities."""
    response = non_admin_user_client.post(
        "/priorities/", json={"id": 100, "description": "Test Priority"}
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For List Priorities ---


def test_list_priorities_empty(db_session: Session, admin_user_client):
    """GET /priorities/ returns an empty list when none exist."""
    response = admin_user_client.get("/priorities/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == []


def test_list_priorities_with_data(db_session: Session, admin_user_client):
    """GET /priorities/ returns every stored priority."""
    expected = [(100, "High"), (200, "Medium"), (300, "Low")]
    db_session.add_all(
        [MockPriority(id=pid, description=desc) for pid, desc in expected]
    )
    db_session.commit()

    response = admin_user_client.get("/priorities/")
    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert len(payload) == 3
    for item, (pid, desc) in zip(payload, expected):
        assert item["id"] == pid
        assert item["description"] == desc


def test_list_priorities_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not list priorities."""
    response = non_admin_user_client.get("/priorities/")
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For Get Priority ---


def test_get_priority_success(db_session: Session, admin_user_client):
    """GET /priorities/{id} returns the stored priority."""
    db_session.add(MockPriority(id=100, description="Test Priority"))
    db_session.commit()

    response = admin_user_client.get("/priorities/100")
    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["id"] == 100
    assert payload["description"] == "Test Priority"


def test_get_priority_not_found(db_session: Session, admin_user_client):
    """GET /priorities/{id} with an unknown id returns 404."""
    response = admin_user_client.get("/priorities/999")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Priority not found" in response.json()["detail"]


def test_get_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not read priorities."""
    response = non_admin_user_client.get("/priorities/100")
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For Delete Priority ---


def test_delete_all_priorities_success(db_session, admin_user_client):
    """DELETE /priorities removes unused priorities and skips ones in use."""
    for pid, desc in [(100, "High"), (200, "Medium"), (300, "Low")]:
        db_session.add(MockPriority(id=pid, description=desc))
    db_session.commit()

    # Bulk-upload a channel so that priority 100 becomes "in use".
    create_mock_priorities_and_group(db_session, [], "Test Group")
    admin_user_client.post(
        "/channels/bulk-upload",
        json=[
            {
                "group-title": "Test Group",
                "tvg_id": "test.tv",
                "name": "Test Channel",
                "urls": ["http://test.com"],
            }
        ],
    )

    response = admin_user_client.delete("/priorities")
    assert response.status_code == status.HTTP_200_OK
    assert response.json()["deleted"] == 2  # Medium and Low were unused
    assert response.json()["skipped"] == 1  # High (100) is in use

    # Only the in-use priority survives.
    remaining = db_session.query(MockPriority).all()
    assert len(remaining) == 1
    assert remaining[0].id == 100


def test_reset_priorities_forbidden_for_non_admin(db_session, non_admin_user_client):
    """DELETE /priorities requires the admin role."""
    response = non_admin_user_client.delete("/priorities")
    assert response.status_code == status.HTTP_403_FORBIDDEN


def test_delete_priority_success(db_session: Session, admin_user_client):
    """DELETE /priorities/{id} removes an unused priority."""
    db_session.add(MockPriority(id=100, description="To Delete"))
    db_session.commit()

    response = admin_user_client.delete("/priorities/100")
    assert response.status_code == status.HTTP_204_NO_CONTENT

    # Gone from the database.
    assert db_session.get(MockPriority, 100) is None


def test_delete_priority_not_found(db_session: Session, admin_user_client):
    """Deleting a non-existent priority returns 404."""
    response = admin_user_client.delete("/priorities/999")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Priority not found" in response.json()["detail"]


def test_delete_priority_in_use(db_session: Session, admin_user_client):
    """Deleting a priority referenced by a channel URL returns 409 and keeps it."""
    now = datetime.now(timezone.utc)
    group = MockGroup(
        id=uuid.uuid4(),
        name="Group With Channels",
        sort_order=1,
        created_at=now,
        updated_at=now,
    )
    db_session.add_all([MockPriority(id=100, description="In Use"), group])
    db_session.commit()

    channel = MockChannelDB(
        name="Test Channel",
        tvg_id="test.tv",
        tvg_name="Test",
        tvg_logo="test.png",
        group_id=group.id,
    )
    db_session.add(channel)
    db_session.commit()

    # This URL ties the channel to priority 100, blocking deletion.
    db_session.add(
        MockChannelURL(
            url="http://test.com",
            priority_id=100,
            in_use=True,
            channel_id=channel.id,
        )
    )
    db_session.commit()

    response = admin_user_client.delete("/priorities/100")
    assert response.status_code == status.HTTP_409_CONFLICT
    assert "in use by channel URLs" in response.json()["detail"]

    # The priority must survive the failed delete.
    assert db_session.get(MockPriority, 100) is not None


def test_delete_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users may not delete priorities."""
    response = non_admin_user_client.delete("/priorities/100")
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
|
||||
287
tests/routers/test_scheduler.py
Normal file
287
tests/routers/test_scheduler.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from app.iptv.scheduler import StreamScheduler
|
||||
from app.routers.scheduler import get_scheduler
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.utils.database import get_db
|
||||
from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler
|
||||
from tests.utils.auth_test_fixtures import (
|
||||
admin_user_client,
|
||||
db_session,
|
||||
non_admin_user_client,
|
||||
)
|
||||
from tests.utils.db_mocks import mock_get_db
|
||||
|
||||
# Scheduler Health Check Tests
|
||||
|
||||
|
||||
def test_scheduler_health_success(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for successful scheduler health check when accessed by an admin user.
|
||||
It mocks the scheduler to be running and have a next scheduled job.
|
||||
"""
|
||||
|
||||
# Define the expected next run time for the scheduler job.
|
||||
next_run = datetime.now(timezone.utc)
|
||||
|
||||
# Create a mock job object that simulates an APScheduler job.
|
||||
mock_job = Mock()
|
||||
mock_job.next_run_time = next_run
|
||||
|
||||
# Mock the `get_job` method to return our mock_job for a specific ID.
|
||||
def mock_get_job(job_id):
|
||||
if job_id == "daily_stream_validation":
|
||||
return mock_job
|
||||
return None
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=True,
|
||||
get_job=Mock(side_effect=mock_get_job), # Use the custom mock_get_job
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and content.
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "running"
|
||||
assert data["next_run"] == str(next_run)
|
||||
|
||||
|
||||
def test_scheduler_health_stopped(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for scheduler health check when the scheduler is in a stopped state.
|
||||
Ensures the API returns the correct status and no next run time.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency,
|
||||
# simulating a stopped scheduler.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=False,
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and content.
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "stopped"
|
||||
assert data["next_run"] is None
|
||||
|
||||
|
||||
def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case to ensure that non-admin users are forbidden from accessing
|
||||
the scheduler health endpoint.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=False,
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
non_admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = non_admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_scheduler_health_check_exception(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for handling exceptions during the scheduler health check.
|
||||
Ensures the API returns a 500 Internal Server Error when an exception occurs.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency that raises an exception.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request, running=True, get_job=Mock(side_effect=Exception("Test exception"))
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Failed to check scheduler health" in response.json()["detail"]
|
||||
|
||||
|
||||
# Scheduler Trigger Tests
|
||||
|
||||
|
||||
def test_trigger_validation_success(admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case for successful manual triggering
|
||||
of stream validation by an admin user.
|
||||
It verifies that the trigger method is called and
|
||||
the API returns a 202 Accepted status.
|
||||
"""
|
||||
# Use a mutable reference to check if the trigger method was called.
|
||||
triggered_ref = {"value": False}
|
||||
|
||||
# Initialize a custom mock scheduler.
|
||||
custom_scheduler = MockScheduler(running=True)
|
||||
custom_scheduler.get_job = Mock(return_value=None)
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency,
|
||||
# overriding `trigger_manual_validation`.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
scheduler = await mock_get_scheduler(
|
||||
request,
|
||||
running=True,
|
||||
)
|
||||
|
||||
# Replace the actual trigger method with our mock to track calls.
|
||||
scheduler.trigger_manual_validation = create_trigger_mock(
|
||||
triggered_ref=triggered_ref
|
||||
)
|
||||
|
||||
return scheduler
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to trigger stream validation.
|
||||
response = admin_user_client.post("/scheduler/trigger")
|
||||
|
||||
# Assert the response status code, message, and that the trigger was called.
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
assert response.json()["message"] == "Stream validation triggered"
|
||||
assert triggered_ref["value"] is True
|
||||
|
||||
|
||||
def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
|
||||
"""
|
||||
Test case to ensure that non-admin users are
|
||||
forbidden from manually triggering stream validation.
|
||||
"""
|
||||
|
||||
# Create a custom mock for `get_scheduler` dependency.
|
||||
async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
|
||||
return await mock_get_scheduler(
|
||||
request,
|
||||
running=True,
|
||||
)
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
non_admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override dependencies for the test.
|
||||
non_admin_user_client.app.dependency_overrides[get_scheduler] = (
|
||||
custom_mock_get_scheduler
|
||||
)
|
||||
non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to trigger stream validation.
|
||||
response = non_admin_user_client.post("/scheduler/trigger")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "required roles" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_scheduler_initialized_in_app_state(admin_user_client):
|
||||
"""
|
||||
Test case for when the scheduler is initialized in the app state but its internal
|
||||
scheduler attribute is not set, which should still allow health check.
|
||||
"""
|
||||
scheduler = StreamScheduler()
|
||||
|
||||
# Set the scheduler instance in the test client's app state.
|
||||
admin_user_client.app.state.scheduler = scheduler
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override only get_db, allowing the real get_scheduler to be tested.
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code.
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
def test_scheduler_not_initialized_in_app_state(admin_user_client):
|
||||
"""
|
||||
Test case for when the scheduler is not properly initialized in the app state.
|
||||
This simulates a scenario where the internal scheduler attribute is missing,
|
||||
leading to a 500 Internal Server Error on health check.
|
||||
"""
|
||||
scheduler = StreamScheduler()
|
||||
del (
|
||||
scheduler.scheduler
|
||||
) # Simulate uninitialized scheduler by deleting the attribute
|
||||
|
||||
# Set the scheduler instance in the test client's app state.
|
||||
admin_user_client.app.state.scheduler = scheduler
|
||||
|
||||
# Include the scheduler router in the test application.
|
||||
admin_user_client.app.include_router(scheduler_router)
|
||||
|
||||
# Override only get_db, allowing the real get_scheduler to be tested.
|
||||
admin_user_client.app.dependency_overrides[get_db] = mock_get_db
|
||||
|
||||
# Make the request to the scheduler health endpoint.
|
||||
response = admin_user_client.get("/scheduler/health")
|
||||
|
||||
# Assert the response status code and error detail.
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Scheduler not initialized" in response.json()["detail"]
|
||||
@@ -16,7 +16,7 @@ def test_root_endpoint(client):
|
||||
"""Test root endpoint returns expected message"""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "IPTV Updater API"}
|
||||
assert response.json() == {"message": "IPTV Manager API"}
|
||||
|
||||
|
||||
def test_openapi_schema_generation(client):
|
||||
|
||||
82
tests/utils/auth_test_fixtures.py
Normal file
82
tests/utils/auth_test_fixtures.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.models.auth import CognitoUser
|
||||
from app.routers.channels import router as channels_router
|
||||
from app.routers.groups import router as groups_router
|
||||
from app.routers.playlist import router as playlist_router
|
||||
from app.routers.priorities import router as priorities_router
|
||||
from app.utils.database import get_db
|
||||
from tests.utils.db_mocks import (
|
||||
MockBase,
|
||||
MockChannelDB,
|
||||
MockChannelURL,
|
||||
MockPriority,
|
||||
engine_mock,
|
||||
mock_get_db,
|
||||
)
|
||||
from tests.utils.db_mocks import session_mock as TestingSessionLocal
|
||||
|
||||
|
||||
def mock_get_current_user_admin():
|
||||
return CognitoUser(
|
||||
username="testadmin",
|
||||
email="testadmin@example.com",
|
||||
roles=["admin"],
|
||||
user_status="CONFIRMED",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
def mock_get_current_user_non_admin():
|
||||
return CognitoUser(
|
||||
username="testuser",
|
||||
email="testuser@example.com",
|
||||
roles=["user"], # Or any role other than admin
|
||||
user_status="CONFIRMED",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
# Create tables for each test function
|
||||
MockBase.metadata.create_all(bind=engine_mock)
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
# Drop tables after each test function
|
||||
MockBase.metadata.drop_all(bind=engine_mock)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def admin_user_client(db_session: Session):
|
||||
"""Yields a TestClient configured with an admin user."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(channels_router)
|
||||
test_app.include_router(priorities_router)
|
||||
test_app.include_router(playlist_router)
|
||||
test_app.include_router(groups_router)
|
||||
test_app.dependency_overrides[get_db] = mock_get_db
|
||||
test_app.dependency_overrides[get_current_user] = mock_get_current_user_admin
|
||||
with TestClient(test_app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def non_admin_user_client(db_session: Session):
|
||||
"""Yields a TestClient configured with a non-admin user."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(channels_router)
|
||||
test_app.include_router(priorities_router)
|
||||
test_app.include_router(playlist_router)
|
||||
test_app.include_router(groups_router)
|
||||
test_app.dependency_overrides[get_db] = mock_get_db
|
||||
test_app.dependency_overrides[get_current_user] = mock_get_current_user_non_admin
|
||||
with TestClient(test_app) as test_client:
|
||||
yield test_client
|
||||
@@ -4,41 +4,25 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import (
|
||||
TEXT,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
UniqueConstraint,
|
||||
create_engine,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
|
||||
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
|
||||
|
||||
# Create a mock-specific Base class for testing
|
||||
MockBase = declarative_base()
|
||||
|
||||
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite."""
|
||||
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
|
||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||
class MockPriority(MockBase):
|
||||
__tablename__ = "priorities"
|
||||
@@ -46,16 +30,28 @@ class MockPriority(MockBase):
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class MockGroup(MockBase):
|
||||
__tablename__ = "groups"
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
sort_order = Column(Integer, nullable=False, default=0)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
channels = relationship("MockChannelDB", back_populates="group")
|
||||
|
||||
|
||||
class MockChannelDB(MockBase):
|
||||
__tablename__ = "channels"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
tvg_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
group_title = Column(String, nullable=False)
|
||||
group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
|
||||
tvg_name = Column(String)
|
||||
__table_args__ = (
|
||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
@@ -63,13 +59,17 @@ class MockChannelDB(MockBase):
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
group = relationship("MockGroup", back_populates="channels")
|
||||
urls = relationship(
|
||||
"MockChannelURL", back_populates="channel", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class MockChannelURL(MockBase):
|
||||
__tablename__ = "channels_urls"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(
|
||||
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
||||
UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
@@ -80,6 +80,32 @@ class MockChannelURL(MockBase):
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
channel = relationship("MockChannelDB", back_populates="urls")
|
||||
|
||||
|
||||
def create_mock_priorities_and_group(db_session, priorities, group_name):
|
||||
"""Create mock priorities and group for testing purposes.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session object
|
||||
priorities: List of (id, description) tuples for priorities to create
|
||||
group_name: Name for the new mock group
|
||||
|
||||
Returns:
|
||||
UUID: The ID of the created group
|
||||
"""
|
||||
# Create priorities
|
||||
priority_objects = [
|
||||
MockPriority(id=priority_id, description=description)
|
||||
for priority_id, description in priorities
|
||||
]
|
||||
|
||||
# Create group
|
||||
group = MockGroup(name=group_name)
|
||||
db_session.add_all(priority_objects + [group])
|
||||
db_session.commit()
|
||||
|
||||
return group.id
|
||||
|
||||
|
||||
# Create test engine
|
||||
|
||||
309
tests/utils/test_check_streams.py
Normal file
309
tests/utils/test_check_streams.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
|
||||
|
||||
from app.utils.check_streams import StreamValidator, main
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validator():
|
||||
"""Create a StreamValidator instance for testing"""
|
||||
return StreamValidator(timeout=1)
|
||||
|
||||
|
||||
def test_validator_init():
|
||||
"""Test StreamValidator initialization with default and custom values"""
|
||||
# Test with default user agent
|
||||
validator = StreamValidator()
|
||||
assert validator.timeout == 10
|
||||
assert "Mozilla" in validator.session.headers["User-Agent"]
|
||||
|
||||
# Test with custom values
|
||||
custom_agent = "CustomAgent/1.0"
|
||||
validator = StreamValidator(timeout=5, user_agent=custom_agent)
|
||||
assert validator.timeout == 5
|
||||
assert validator.session.headers["User-Agent"] == custom_agent
|
||||
|
||||
|
||||
def test_is_valid_content_type(validator):
|
||||
"""Test content type validation"""
|
||||
valid_types = [
|
||||
"video/mp4",
|
||||
"video/mp2t",
|
||||
"application/vnd.apple.mpegurl",
|
||||
"application/dash+xml",
|
||||
"video/webm",
|
||||
"application/octet-stream",
|
||||
"application/x-mpegURL",
|
||||
"video/mp4; charset=utf-8", # Test with additional parameters
|
||||
]
|
||||
|
||||
invalid_types = [
|
||||
"text/html",
|
||||
"application/json",
|
||||
"image/jpeg",
|
||||
"",
|
||||
]
|
||||
|
||||
for content_type in valid_types:
|
||||
assert validator._is_valid_content_type(content_type)
|
||||
|
||||
for content_type in invalid_types:
|
||||
assert not validator._is_valid_content_type(content_type)
|
||||
|
||||
# Test None case explicitly
|
||||
assert not validator._is_valid_content_type(None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,content_type,should_succeed",
|
||||
[
|
||||
(200, "video/mp4", True),
|
||||
(206, "video/mp4", True), # Partial content
|
||||
(404, "video/mp4", False),
|
||||
(500, "video/mp4", False),
|
||||
(200, "text/html", False),
|
||||
(200, "application/json", False),
|
||||
],
|
||||
)
|
||||
def test_validate_stream_response_handling(status_code, content_type, should_succeed):
|
||||
"""Test stream validation with different response scenarios"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.headers = {"Content-Type": content_type}
|
||||
mock_response.iter_content.return_value = iter([b"some content"])
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__enter__.return_value = mock_response
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert valid == should_succeed
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_stream_connection_error():
|
||||
"""Test stream validation with connection error"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Connection Error" in message
|
||||
|
||||
|
||||
def test_validate_stream_timeout():
|
||||
"""Test stream validation with timeout"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = Timeout("Request timed out")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "timeout" in message.lower()
|
||||
|
||||
|
||||
def test_validate_stream_http_error():
|
||||
"""Test stream validation with HTTP error"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = HTTPError("HTTP Error occurred")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "HTTP Error" in message
|
||||
|
||||
|
||||
def test_validate_stream_request_exception():
|
||||
"""Test stream validation with general request exception"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = RequestException("Request failed")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Request Exception" in message
|
||||
|
||||
|
||||
def test_validate_stream_content_read_error():
|
||||
"""Test stream validation when content reading fails"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "video/mp4"}
|
||||
mock_response.iter_content.side_effect = ConnectionError("Read failed")
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__enter__.return_value = mock_response
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Connection failed during content read" in message
|
||||
|
||||
|
||||
def test_validate_stream_general_exception():
|
||||
"""Test validate_stream with an unexpected exception"""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.side_effect = Exception("Unexpected error")
|
||||
|
||||
with patch("requests.Session", return_value=mock_session):
|
||||
validator = StreamValidator()
|
||||
valid, message = validator.validate_stream("http://example.com/stream")
|
||||
|
||||
assert not valid
|
||||
assert "Validation error" in message
|
||||
|
||||
|
||||
def test_parse_playlist(validator, tmp_path):
|
||||
"""Test playlist file parsing"""
|
||||
playlist_content = """
|
||||
#EXTM3U
|
||||
#EXTINF:-1,Channel 1
|
||||
http://example.com/stream1
|
||||
#EXTINF:-1,Channel 2
|
||||
http://example.com/stream2
|
||||
|
||||
http://example.com/stream3
|
||||
"""
|
||||
playlist_file = tmp_path / "test_playlist.m3u"
|
||||
playlist_file.write_text(playlist_content)
|
||||
|
||||
urls = validator.parse_playlist(str(playlist_file))
|
||||
assert len(urls) == 3
|
||||
assert urls == [
|
||||
"http://example.com/stream1",
|
||||
"http://example.com/stream2",
|
||||
"http://example.com/stream3",
|
||||
]
|
||||
|
||||
|
||||
def test_parse_playlist_error(validator):
|
||||
"""Test playlist parsing with non-existent file"""
|
||||
with pytest.raises(Exception):
|
||||
validator.parse_playlist("nonexistent_file.m3u")
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
def test_main_with_urls(mock_validator_class, mock_logging, tmp_path, capsys):
|
||||
"""Test main function with direct URLs"""
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", "http://example.com/stream1", "http://example.com/stream2"]
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
# Verify validator was called correctly
|
||||
assert mock_validator.validate_stream.call_count == 2
|
||||
mock_validator.validate_stream.assert_any_call("http://example.com/stream1")
|
||||
mock_validator.validate_stream.assert_any_call("http://example.com/stream2")
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
def test_main_with_playlist(mock_validator_class, mock_logging, tmp_path):
|
||||
"""Test main function with a playlist file"""
|
||||
# Create test playlist
|
||||
playlist_content = "http://example.com/stream1\nhttp://example.com/stream2"
|
||||
playlist_file = tmp_path / "test.m3u"
|
||||
playlist_file.write_text(playlist_content)
|
||||
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
mock_validator.parse_playlist.return_value = [
|
||||
"http://example.com/stream1",
|
||||
"http://example.com/stream2",
|
||||
]
|
||||
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", str(playlist_file)]
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
# Verify validator was called correctly
|
||||
mock_validator.parse_playlist.assert_called_once_with(str(playlist_file))
|
||||
assert mock_validator.validate_stream.call_count == 2
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
def test_main_with_dead_streams(mock_validator_class, mock_logging, tmp_path):
|
||||
"""Test main function handling dead streams"""
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
mock_validator.validate_stream.return_value = (False, "Stream is dead")
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", "http://example.com/dead1", "http://example.com/dead2"]
|
||||
|
||||
# Mock file operations
|
||||
mock_file = mock_open()
|
||||
with patch("sys.argv", test_args), patch("builtins.open", mock_file):
|
||||
main()
|
||||
|
||||
# Verify dead streams were written to file
|
||||
mock_file().write.assert_called_once_with(
|
||||
"http://example.com/dead1\nhttp://example.com/dead2"
|
||||
)
|
||||
|
||||
|
||||
@patch("app.utils.check_streams.logging")
|
||||
@patch("app.utils.check_streams.StreamValidator")
|
||||
@patch("os.path.isfile")
|
||||
def test_main_with_playlist_error(
|
||||
mock_isfile, mock_validator_class, mock_logging, tmp_path
|
||||
):
|
||||
"""Test main function handling playlist parsing errors"""
|
||||
# Setup mock validator
|
||||
mock_validator = Mock()
|
||||
mock_validator_class.return_value = mock_validator
|
||||
|
||||
# Configure mock validator behavior
|
||||
error_msg = "Failed to parse playlist"
|
||||
mock_validator.parse_playlist.side_effect = [
|
||||
Exception(error_msg), # First call fails
|
||||
["http://example.com/stream1"], # Second call succeeds
|
||||
]
|
||||
mock_validator.validate_stream.return_value = (True, "Stream is valid")
|
||||
|
||||
# Configure isfile mock to return True for our test files
|
||||
mock_isfile.side_effect = lambda x: x in ["/invalid.m3u", "/valid.m3u"]
|
||||
|
||||
# Setup test arguments
|
||||
test_args = ["script", "/invalid.m3u", "/valid.m3u"]
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
# Verify error was logged correctly
|
||||
mock_logging.error.assert_called_with(
|
||||
"Failed to process file /invalid.m3u: Failed to parse playlist"
|
||||
)
|
||||
|
||||
# Verify processing continued with valid playlist
|
||||
mock_validator.parse_playlist.assert_called_with("/valid.m3u")
|
||||
assert (
|
||||
mock_validator.validate_stream.call_count == 1
|
||||
) # Called for the URL from valid playlist
|
||||
Reference in New Issue
Block a user