Compare commits

..

17 Commits

Author SHA1 Message Date
a42d4c30a6 Started (incomplete) implementation of stream verification scheduler and endpoints
All checks were successful
AWS Deploy on Push / build (push) Successful in 5m18s
2025-06-17 17:12:39 -05:00
abb467749b Implemented bulk upload by passing a json structure. Added delete all channels, groups and priorities
All checks were successful
AWS Deploy on Push / build (push) Successful in 2m17s
2025-06-12 18:49:20 -05:00
b8ac25e301 Introduced groups and added all related endpoints
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m39s
2025-06-10 23:02:46 -05:00
729eabf27f Updated documentation
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-29 18:46:40 -05:00
34c446bcfa Make sure DB credentials are available when running userdata (fix-2)
All checks were successful
AWS Deploy on Push / build (push) Successful in 10m26s
2025-05-29 17:52:53 -05:00
d4cc74ea8c Make sure DB credentials are available when running userdata (fix-1)
Some checks failed
AWS Deploy on Push / build (push) Failing after 39s
2025-05-29 17:48:23 -05:00
21b73b6843 Make sure DB credentials are available when running userdata
Some checks failed
AWS Deploy on Push / build (push) Failing after 41s
2025-05-29 17:43:08 -05:00
e743daf9f7 Moved creation of the instance after database creation
All checks were successful
AWS Deploy on Push / build (push) Successful in 6m51s
2025-05-29 17:16:08 -05:00
b0d98551b8 Fixed install of postgres client on Amazon Linux 2023
All checks were successful
AWS Deploy on Push / build (push) Successful in 7m55s
2025-05-29 16:37:42 -05:00
eaab1ef998 Changed project name to be IPTV Manager Service
All checks were successful
AWS Deploy on Push / build (push) Successful in 8m29s
2025-05-29 16:09:52 -05:00
e25f8c1ecd Run unit test upon committing new code
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m3s
2025-05-28 23:41:12 -05:00
95bf0f9701 Created unit tests for check_streams
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m5s
2025-05-28 23:31:04 -05:00
f7a1c20066 Created unit tests for playlist.py
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
2025-05-28 22:58:29 -05:00
bf6f156fec Created unit tests for priorities.py router
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m6s
2025-05-28 22:31:31 -05:00
7e25ec6755 Test refactoring
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m9s
2025-05-28 22:22:20 -05:00
6d506122d9 Add pre-commit commands to install script
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m7s
2025-05-28 22:05:13 -05:00
02913c7385 Linted and formatted all files 2025-05-28 21:52:39 -05:00
59 changed files with 5394 additions and 1220 deletions

View File

@@ -1,10 +1,15 @@
# Environment variables
# Scheduler configuration
STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax)
STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all)
# For use with Docker Compose to run application locally # For use with Docker Compose to run application locally
MOCK_AUTH=true/false MOCK_AUTH=true/false
DB_USER=MyDBUser DB_USER=MyDBUser
DB_PASSWORD=MyDBPassword DB_PASSWORD=MyDBPassword
DB_HOST=MyDBHost DB_HOST=MyDBHost
DB_NAME=iptv_updater DB_NAME=iptv_manager
FREEDNS_User=MyFreeDNSUsername FREEDNS_User=MyFreeDNSUsername
FREEDNS_Password=MyFreeDNSPassword FREEDNS_Password=MyFreeDNSPassword

View File

@@ -58,7 +58,7 @@ jobs:
run: | run: |
INSTANCE_IDS=$(aws ec2 describe-instances \ INSTANCE_IDS=$(aws ec2 describe-instances \
--region us-east-2 \ --region us-east-2 \
--filters "Name=tag:Name,Values=IptvUpdaterStack/IptvUpdaterInstance" \ --filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
"Name=instance-state-name,Values=running" \ "Name=instance-state-name,Values=running" \
--query "Reservations[].Instances[].InstanceId" \ --query "Reservations[].Instances[].InstanceId" \
--output text) --output text)
@@ -69,11 +69,11 @@ jobs:
--instance-ids "$INSTANCE_ID" \ --instance-ids "$INSTANCE_ID" \
--document-name "AWS-RunShellScript" \ --document-name "AWS-RunShellScript" \
--parameters 'commands=[ --parameters 'commands=[
"cd /home/ec2-user/iptv-updater-aws", "cd /home/ec2-user/iptv-manager-service",
"git pull", "git pull",
"pip3 install -r requirements.txt", "pip3 install -r requirements.txt",
"alembic upgrade head", "alembic upgrade head",
"sudo systemctl restart iptv-updater" "sudo systemctl restart iptv-manager"
]' ]'
done done

View File

@@ -1,7 +1,16 @@
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2 rev: v0.11.12
hooks: hooks:
- id: ruff - id: ruff
args: [--fix, --exit-non-zero-on-fix] args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format - id: ruff-format
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: pytest
language: system
pass_filenames: false
always_run: true

22
.vscode/settings.json vendored
View File

@@ -1,4 +1,6 @@
{ {
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"editor.formatOnSave": true, "editor.formatOnSave": true,
"editor.defaultFormatter": "charliermarsh.ruff", "editor.defaultFormatter": "charliermarsh.ruff",
"ruff.importStrategy": "fromEnvironment", "ruff.importStrategy": "fromEnvironment",
@@ -7,14 +9,18 @@
"addopts", "addopts",
"adminpassword", "adminpassword",
"altinstall", "altinstall",
"apscheduler",
"asyncio", "asyncio",
"autoflush", "autoflush",
"autoupdate",
"autouse", "autouse",
"awscli",
"awscliv", "awscliv",
"boto", "boto",
"botocore", "botocore",
"BURSTABLE", "BURSTABLE",
"cabletv", "cabletv",
"capsys",
"CDUF", "CDUF",
"cduflogo", "cduflogo",
"cdulogo", "cdulogo",
@@ -29,31 +35,44 @@
"cluflogo", "cluflogo",
"clulogo", "clulogo",
"cpulogo", "cpulogo",
"crond",
"cronie",
"cuflgo", "cuflgo",
"CUNF", "CUNF",
"cunflogo", "cunflogo",
"cuulogo", "cuulogo",
"datname",
"deadstreams",
"delenv", "delenv",
"delogo", "delogo",
"devel", "devel",
"dflogo", "dflogo",
"dmlogo", "dmlogo",
"dotenv", "dotenv",
"EXTINF",
"EXTM",
"fastapi", "fastapi",
"filterwarnings", "filterwarnings",
"fiorinis", "fiorinis",
"freedns", "freedns",
"fullchain", "fullchain",
"gitea", "gitea",
"httpx",
"iptv", "iptv",
"isort", "isort",
"KHTML",
"lclogo", "lclogo",
"LETSENCRYPT", "LETSENCRYPT",
"levelname",
"mpegurl",
"nohup", "nohup",
"nopriority",
"ondelete", "ondelete",
"onupdate", "onupdate",
"passlib", "passlib",
"PGPASSWORD",
"poolclass", "poolclass",
"psql",
"psycopg", "psycopg",
"pycache", "pycache",
"pycodestyle", "pycodestyle",
@@ -68,14 +87,17 @@
"ruru", "ruru",
"sessionmaker", "sessionmaker",
"sqlalchemy", "sqlalchemy",
"sqliteuuid",
"starlette", "starlette",
"stefano", "stefano",
"testadmin", "testadmin",
"testdb", "testdb",
"testpass", "testpass",
"testpaths", "testpaths",
"testuser",
"uflogo", "uflogo",
"umlogo", "umlogo",
"usefixtures",
"uvicorn", "uvicorn",
"venv", "venv",
"wrongpass" "wrongpass"

226
README.md
View File

@@ -1,165 +1,149 @@
# IPTV Updater AWS # IPTV Manager Service
An automated IPTV playlist and EPG updater service deployed on AWS infrastructure using CDK. A FastAPI-based service for managing IPTV playlists and channel priorities. The application provides secure endpoints for user authentication, channel management, and playlist generation.
## Overview ## ✨ Features
This project provides a service for automatically updating IPTV playlists and Electronic Program Guide (EPG) data. It runs on AWS infrastructure with: - **JWT Authentication**: Secure login using AWS Cognito
- **Channel Management**: CRUD operations for IPTV channels
- **Playlist Generation**: Create M3U playlists with channel priorities
- **Stream Monitoring**: Background checks for channel availability
- **Priority Management**: Set channel priorities for playlist ordering
- EC2 instance for hosting the application ## 🛠️ Technology Stack
- RDS PostgreSQL database for data storage
- Amazon Cognito for user authentication
- HTTPS support via Let's Encrypt
- Domain management via FreeDNS
## Prerequisites - **Backend**: Python 3.11, FastAPI
- **Database**: PostgreSQL (SQLAlchemy ORM)
- **Authentication**: AWS Cognito
- **Infrastructure**: AWS CDK (API Gateway, Lambda, RDS)
- **Testing**: Pytest with 85%+ coverage
- **CI/CD**: Pre-commit hooks, Alembic migrations
- AWS CLI installed and configured ## 🚀 Getting Started
- Python 3.12 or later
- Node.js v22.15 or later for AWS CDK
- Docker and Docker Compose for local development
## Local Development ### Prerequisites
1. Clone the repository: - Python 3.11+
- Docker
- AWS CLI (for deployment)
### Installation
```bash ```bash
git clone <repo-url> # Clone repository
cd iptv-updater-aws git clone https://github.com/your-repo/iptv-manager-service.git
cd iptv-manager-service
# Setup environment
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
cp .env.example .env # Update with your values
# Run installation script
./scripts/install.sh
``` ```
2. Copy the example environment file: ### Running Locally
```bash
cp .env.example .env
```
3. Add your configuration to `.env`:
```
FREEDNS_User=your_freedns_username
FREEDNS_Password=your_freedns_password
DOMAIN_NAME=your.domain.name
SSH_PUBLIC_KEY=your_ssh_public_key
REPO_URL=repository_url
LETSENCRYPT_EMAIL=your_email
```
4. Start the local development environment:
```bash ```bash
# Start development environment
./scripts/start_local_dev.sh ./scripts/start_local_dev.sh
```
5. Stop the local environment: # Stop development environment
```bash
./scripts/stop_local_dev.sh ./scripts/stop_local_dev.sh
``` ```
## Deployment ## ☁️ AWS Deployment
### Initial Deployment The infrastructure is defined in CDK. Use the provided scripts:
1. Ensure your AWS credentials are configured:
```bash
aws configure
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
3. Deploy the infrastructure:
```bash ```bash
# Deploy AWS infrastructure
./scripts/deploy.sh ./scripts/deploy.sh
# Destroy AWS infrastructure
./scripts/destroy.sh
# Create Cognito test user
./scripts/create_cognito_user.sh
# Delete Cognito user
./scripts/delete_cognito_user.sh
``` ```
The deployment script will: Key AWS components:
- Create/update the CloudFormation stack using CDK - API Gateway
- Configure the EC2 instance with required software - Lambda functions
- Set up HTTPS using Let's Encrypt - RDS PostgreSQL
- Configure the domain using FreeDNS - Cognito User Pool
### Continuous Deployment ## 🤖 Continuous Integration/Deployment
The project includes a Gitea workflow (`.gitea/workflows/aws_deploy_on_push.yml`) that automatically: This project includes a Gitea Actions workflow (`.gitea/workflows/deploy.yml`) for automated deployment to AWS. The workflow is fully compatible with GitHub Actions and can be easily adapted by:
- Deploys infrastructure changes 1. Placing the workflow file in the `.github/workflows/` directory
- Updates the application on EC2 instances 2. Setting up the required secrets in your CI/CD environment:
- Restarts the service - `AWS_ACCESS_KEY_ID`
- `AWS_SECRET_ACCESS_KEY`
- `AWS_DEFAULT_REGION`
## Infrastructure The workflow automatically deploys the infrastructure and application when changes are pushed to the main branch.
The AWS infrastructure is defined in `infrastructure/stack.py` and includes: ## 📚 API Documentation
- VPC with public subnets Access interactive docs at:
- EC2 t2.micro instance (Free Tier eligible)
- RDS PostgreSQL database (db.t3.micro)
- Security groups for EC2 and RDS
- Elastic IP for the EC2 instance
- Cognito User Pool for authentication
- IAM roles and policies for EC2 instance access
## User Management - Swagger UI: `http://localhost:8000/docs`
- ReDoc: `http://localhost:8000/redoc`
### Creating Users ### Key Endpoints
To create a new user in Cognito: | Endpoint | Method | Description |
| ------------- | ------ | --------------------- |
| `/auth/login` | POST | User authentication |
| `/channels` | GET | List all channels |
| `/playlist` | GET | Generate M3U playlist |
| `/priorities` | POST | Set channel priority |
## 🧪 Testing
Run the full test suite:
```bash ```bash
./scripts/create_cognito_user.sh <user_pool_id> <username> <password> --admin <= optional for defining an admin user pytest
``` ```
### Deleting Users Test coverage includes:
To delete a user from Cognito: - Authentication workflows
- Channel CRUD operations
- Playlist generation logic
- Stream monitoring
- Database operations
```bash ## 📂 Project Structure
./scripts/delete_cognito_user.sh <user_pool_id> <username>
```txt
iptv-manager-service/
├── app/ # Core application
│ ├── auth/ # Cognito authentication
│ ├── iptv/ # Playlist logic
│ ├── models/ # Database models
│ ├── routers/ # API endpoints
│ ├── utils/ # Helper functions
│ └── main.py # App entry point
├── infrastructure/ # AWS CDK stack
├── docker/ # Docker configs
├── scripts/ # Deployment scripts
├── tests/ # Comprehensive tests
├── alembic/ # Database migrations
├── .gitea/ # Gitea CI/CD workflows
│ └── workflows/
└── ... # Config files
``` ```
## Architecture ## 📝 License
The application is structured as follows: This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
```bash
app/
├── auth/ # Authentication modules
├── iptv/ # IPTV and EPG processing
├── models/ # Database models
└── utils/ # Utility functions
infrastructure/ # AWS CDK infrastructure code
docker/ # Docker configuration for local development
scripts/ # Utility scripts for deployment and management
```
## Environment Variables
The following environment variables are required:
| Variable | Description |
|----------|-------------|
| FREEDNS_User | FreeDNS username |
| FREEDNS_Password | FreeDNS password |
| DOMAIN_NAME | Your domain name |
| SSH_PUBLIC_KEY | SSH public key for EC2 access |
| REPO_URL | Repository URL |
| LETSENCRYPT_EMAIL | Email for Let's Encrypt certificates |
## Security Notes
- The EC2 instance has appropriate IAM permissions for:
- EC2 instance discovery
- SSM command execution
- RDS access
- Cognito user management
- All database credentials are stored in AWS Secrets Manager
- HTTPS is enforced using Let's Encrypt certificates
- Access is restricted through Security Groups

View File

@@ -1,12 +1,10 @@
import os
from logging.config import fileConfig from logging.config import fileConfig
from sqlalchemy import engine_from_config from sqlalchemy import engine_from_config, pool
from sqlalchemy import pool
from alembic import context from alembic import context
from app.utils.database import get_db_credentials
from app.models.db import Base from app.models.db import Base
from app.utils.database import get_db_credentials
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
@@ -17,12 +15,13 @@ config = context.config
if config.config_file_name is not None: if config.config_file_name is not None:
fileConfig(config.config_file_name) fileConfig(config.config_file_name)
# Setup target metadata for autogenerate support # add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata target_metadata = Base.metadata
# Override sqlalchemy.url with dynamic credentials # Override sqlalchemy.url with dynamic credentials
if not context.is_offline_mode(): if not context.is_offline_mode():
config.set_main_option('sqlalchemy.url', get_db_credentials()) config.set_main_option("sqlalchemy.url", get_db_credentials())
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
@@ -68,9 +67,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View File

@@ -1,59 +0,0 @@
"""Add priority and in_use fields
Revision ID: 036879e47172
Revises:
Create Date: 2025-05-26 19:21:32.285656
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '036879e47172'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
# 1. Create priorities table if not exists
if not op.get_bind().engine.dialect.has_table(op.get_bind(), 'priorities'):
op.create_table('priorities',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
# 2. Insert default priorities (skip if already exists)
op.execute("""
INSERT INTO priorities (id, description)
VALUES (100, 'High'), (200, 'Medium'), (300, 'Low')
ON CONFLICT (id) DO NOTHING
""")
# Add new columns with temporary nullable=True
op.add_column('channels_urls', sa.Column('in_use', sa.Boolean(), nullable=True))
op.add_column('channels_urls', sa.Column('priority_id', sa.Integer(), nullable=True))
# Set default values
op.execute("UPDATE channels_urls SET in_use = false, priority_id = 100")
# Convert to NOT NULL
op.alter_column('channels_urls', 'in_use', nullable=False)
op.alter_column('channels_urls', 'priority_id', nullable=False)
op.create_foreign_key(None, 'channels_urls', 'priorities', ['priority_id'], ['id'])
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('channels_urls_priority_id_fkey', 'channels_urls', type_='foreignkey')
op.drop_column('channels_urls', 'priority_id')
op.drop_column('channels_urls', 'in_use')
op.drop_table('priorities')
# ### end Alembic commands ###

View File

@@ -0,0 +1,110 @@
"""add groups table and migrate group_title data
Revision ID: 0a455608256f
Revises: 95b61a92455a
Create Date: 2025-06-10 09:22:11.820035
"""
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0a455608256f'
down_revision: Union[str, None] = '95b61a92455a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('groups',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
)
# Create temporary table for group mapping
group_mapping = op.create_table(
'group_mapping',
sa.Column('group_title', sa.String(), nullable=False),
sa.Column('group_id', sa.UUID(), nullable=False)
)
# Get existing group titles and create groups
conn = op.get_bind()
distinct_groups = conn.execute(
sa.text("SELECT DISTINCT group_title FROM channels")
).fetchall()
for group in distinct_groups:
group_title = group[0]
group_id = str(uuid.uuid4())
conn.execute(
sa.text(
"INSERT INTO groups (id, name, sort_order) "
"VALUES (:id, :name, 0)"
).bindparams(id=group_id, name=group_title)
)
conn.execute(
group_mapping.insert().values(
group_title=group_title,
group_id=group_id
)
)
# Add group_id column (nullable first)
op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True))
# Update channels with group_ids
conn.execute(
sa.text(
"UPDATE channels c SET group_id = gm.group_id "
"FROM group_mapping gm WHERE c.group_title = gm.group_title"
)
)
# Now make group_id non-nullable and add constraints
op.alter_column('channels', 'group_id', nullable=False)
op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique')
op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name'])
op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id'])
# Clean up and drop group_title
op.drop_table('group_mapping')
op.drop_column('channels', 'group_title')
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True))
# Restore group_title values from groups table
conn = op.get_bind()
conn.execute(
sa.text(
"UPDATE channels c SET group_title = g.name "
"FROM groups g WHERE c.group_id = g.id"
)
)
# Now make group_title non-nullable
op.alter_column('channels', 'group_title', nullable=False)
# Drop constraints and columns
op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey')
op.drop_constraint('uix_group_id_name', 'channels', type_='unique')
op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name'])
op.drop_column('channels', 'group_id')
op.drop_table('groups')
# ### end Alembic commands ###

View File

@@ -0,0 +1,79 @@
"""create initial tables
Revision ID: 95b61a92455a
Revises:
Create Date: 2025-05-29 14:42:16.239587
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '95b61a92455a'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('channels',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tvg_id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('group_title', sa.String(), nullable=False),
sa.Column('tvg_name', sa.String(), nullable=True),
sa.Column('tvg_logo', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('group_title', 'name', name='uix_group_title_name')
)
op.create_table('priorities',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('channels_urls',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('channel_id', sa.UUID(), nullable=False),
sa.Column('url', sa.String(), nullable=False),
sa.Column('in_use', sa.Boolean(), nullable=False),
sa.Column('priority_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['channel_id'], ['channels.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['priority_id'], ['priorities.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
# Seed initial priorities
op.bulk_insert(
sa.Table(
'priorities',
sa.MetaData(),
sa.Column('id', sa.Integer),
sa.Column('description', sa.String),
),
[
{'id': 100, 'description': 'High'},
{'id': 200, 'description': 'Medium'},
{'id': 300, 'description': 'Low'},
]
)
def downgrade() -> None:
"""Downgrade schema."""
# Remove seeded priorities
op.execute("DELETE FROM priorities WHERE id IN (100, 200, 300);")
# Drop tables
op.drop_table('channels_urls')
op.drop_table('priorities')
op.drop_table('channels')

18
app.py
View File

@@ -1,7 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import aws_cdk as cdk import aws_cdk as cdk
from infrastructure.stack import IptvUpdaterStack
from infrastructure.stack import IptvManagerStack
app = cdk.App() app = cdk.App()
@@ -19,21 +21,25 @@ required_vars = {
"DOMAIN_NAME": domain_name, "DOMAIN_NAME": domain_name,
"SSH_PUBLIC_KEY": ssh_public_key, "SSH_PUBLIC_KEY": ssh_public_key,
"REPO_URL": repo_url, "REPO_URL": repo_url,
"LETSENCRYPT_EMAIL": letsencrypt_email "LETSENCRYPT_EMAIL": letsencrypt_email,
} }
# Check for missing required variables # Check for missing required variables
missing_vars = [k for k, v in required_vars.items() if not v] missing_vars = [k for k, v in required_vars.items() if not v]
if missing_vars: if missing_vars:
raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}") raise ValueError(
f"Missing required environment variables: {', '.join(missing_vars)}"
)
IptvUpdaterStack(app, "IptvUpdaterStack", IptvManagerStack(
app,
"IptvManagerStack",
freedns_user=freedns_user, freedns_user=freedns_user,
freedns_password=freedns_password, freedns_password=freedns_password,
domain_name=domain_name, domain_name=domain_name,
ssh_public_key=ssh_public_key, ssh_public_key=ssh_public_key,
repo_url=repo_url, repo_url=repo_url,
letsencrypt_email=letsencrypt_email letsencrypt_email=letsencrypt_email,
) )
app.synth() app.synth()

View File

@@ -1,9 +1,14 @@
import boto3 import boto3
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
from app.utils.auth import calculate_secret_hash from app.utils.auth import calculate_secret_hash
from app.utils.constants import (AWS_REGION, COGNITO_CLIENT_ID, from app.utils.constants import (
COGNITO_CLIENT_SECRET, USER_ROLE_ATTRIBUTE) AWS_REGION,
COGNITO_CLIENT_ID,
COGNITO_CLIENT_SECRET,
USER_ROLE_ATTRIBUTE,
)
cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION) cognito_client = boto3.client("cognito-idp", region_name=AWS_REGION)
@@ -12,43 +17,41 @@ def initiate_auth(username: str, password: str) -> dict:
""" """
Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH. Initiate AUTH flow with Cognito using USER_PASSWORD_AUTH.
""" """
auth_params = { auth_params = {"USERNAME": username, "PASSWORD": password}
"USERNAME": username,
"PASSWORD": password
}
# If a client secret is required, add SECRET_HASH # If a client secret is required, add SECRET_HASH
if COGNITO_CLIENT_SECRET: if COGNITO_CLIENT_SECRET:
auth_params["SECRET_HASH"] = calculate_secret_hash( auth_params["SECRET_HASH"] = calculate_secret_hash(
username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET) username, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET
)
try: try:
response = cognito_client.initiate_auth( response = cognito_client.initiate_auth(
AuthFlow="USER_PASSWORD_AUTH", AuthFlow="USER_PASSWORD_AUTH",
AuthParameters=auth_params, AuthParameters=auth_params,
ClientId=COGNITO_CLIENT_ID ClientId=COGNITO_CLIENT_ID,
) )
return response["AuthenticationResult"] return response["AuthenticationResult"]
except cognito_client.exceptions.NotAuthorizedException: except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password" detail="Invalid username or password",
) )
except cognito_client.exceptions.UserNotFoundException: except cognito_client.exceptions.UserNotFoundException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An error occurred during authentication: {str(e)}" detail=f"An error occurred during authentication: {str(e)}",
) )
def get_user_from_token(access_token: str) -> CognitoUser: def get_user_from_token(access_token: str) -> CognitoUser:
""" """
Verify the token by calling GetUser in Cognito and retrieve user attributes including roles. Verify the token by calling GetUser in Cognito and
retrieve user attributes including roles.
""" """
try: try:
user_response = cognito_client.get_user(AccessToken=access_token) user_response = cognito_client.get_user(AccessToken=access_token)
@@ -59,23 +62,21 @@ def get_user_from_token(access_token: str) -> CognitoUser:
for attr in attributes: for attr in attributes:
if attr["Name"] == USER_ROLE_ATTRIBUTE: if attr["Name"] == USER_ROLE_ATTRIBUTE:
# Assume roles are stored as a comma-separated string # Assume roles are stored as a comma-separated string
user_roles = [r.strip() user_roles = [r.strip() for r in attr["Value"].split(",") if r.strip()]
for r in attr["Value"].split(",") if r.strip()]
break break
return CognitoUser(username=username, roles=user_roles) return CognitoUser(username=username, roles=user_roles)
except cognito_client.exceptions.NotAuthorizedException: except cognito_client.exceptions.NotAuthorizedException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token."
detail="Invalid or expired token."
) )
except cognito_client.exceptions.UserNotFoundException: except cognito_client.exceptions.UserNotFoundException:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or invalid token." detail="User not found or invalid token.",
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token verification failed: {str(e)}" detail=f"Token verification failed: {str(e)}",
) )

View File

@@ -1,6 +1,6 @@
import os
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable
import os
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
@@ -13,10 +13,8 @@ if os.getenv("MOCK_AUTH", "").lower() == "true":
else: else:
from app.auth.cognito import get_user_from_token from app.auth.cognito import get_user_from_token
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin", scheme_name="Bearer")
tokenUrl="signin",
scheme_name="Bearer"
)
def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser: def get_current_user(token: str = Depends(oauth2_scheme)) -> CognitoUser:
""" """
@@ -34,13 +32,17 @@ def require_roles(*required_roles: str) -> Callable:
def decorator(endpoint: Callable) -> Callable: def decorator(endpoint: Callable) -> Callable:
@wraps(endpoint) @wraps(endpoint)
def wrapper(*args, user: CognitoUser = Depends(get_current_user), **kwargs): async def wrapper(
*args, user: CognitoUser = Depends(get_current_user), **kwargs
):
user_roles = set(user.roles or []) user_roles = set(user.roles or [])
needed_roles = set(required_roles) needed_roles = set(required_roles)
if not needed_roles.issubset(user_roles): if not needed_roles.issubset(user_roles):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have the required roles to access this endpoint.", detail=(
"You do not have the required roles to access this endpoint."
),
) )
return endpoint(*args, user=user, **kwargs) return endpoint(*args, user=user, **kwargs)

View File

@@ -1,12 +1,9 @@
from fastapi import HTTPException, status from fastapi import HTTPException, status
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
MOCK_USERS = { MOCK_USERS = {"testuser": {"username": "testuser", "roles": ["admin"]}}
"testuser": {
"username": "testuser",
"roles": ["admin"]
}
}
def mock_get_user_from_token(token: str) -> CognitoUser: def mock_get_user_from_token(token: str) -> CognitoUser:
""" """
@@ -17,16 +14,13 @@ def mock_get_user_from_token(token: str) -> CognitoUser:
return CognitoUser(**MOCK_USERS["testuser"]) return CognitoUser(**MOCK_USERS["testuser"])
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid mock token - use 'testuser'" detail="Invalid mock token - use 'testuser'",
) )
def mock_initiate_auth(username: str, password: str) -> dict: def mock_initiate_auth(username: str, password: str) -> dict:
""" """
Mock version of initiate_auth for local testing Mock version of initiate_auth for local testing
Accepts any username/password and returns a mock token Accepts any username/password and returns a mock token
""" """
return { return {"AccessToken": "testuser", "ExpiresIn": 3600, "TokenType": "Bearer"}
"AccessToken": "testuser",
"ExpiresIn": 3600,
"TokenType": "Bearer"
}

View File

@@ -1,39 +1,59 @@
import os import argparse
import re
import gzip import gzip
import json import json
import os
import re
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import requests import requests
import argparse from utils.constants import (
from utils.constants import IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='EPG Grabber') parser = argparse.ArgumentParser(description="EPG Grabber")
parser.add_argument('--playlist', parser.add_argument(
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), "--playlist",
help='Path to playlist file') default=os.path.join(
parser.add_argument('--output', os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg.xml'), ),
help='Path to output EPG XML file') help="Path to playlist file",
parser.add_argument('--epg-sources', )
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'epg_sources.json'), parser.add_argument(
help='Path to EPG sources JSON configuration file') "--output",
parser.add_argument('--save-as-gz', default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "epg.xml"),
action='store_true', help="Path to output EPG XML file",
default=True, )
help='Save an additional gzipped version of the EPG file') parser.add_argument(
"--epg-sources",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "epg_sources.json"
),
help="Path to EPG sources JSON configuration file",
)
parser.add_argument(
"--save-as-gz",
action="store_true",
default=True,
help="Save an additional gzipped version of the EPG file",
)
return parser.parse_args() return parser.parse_args()
def load_epg_sources(config_path): def load_epg_sources(config_path):
"""Load EPG sources from JSON configuration file""" """Load EPG sources from JSON configuration file"""
try: try:
with open(config_path, 'r', encoding='utf-8') as f: with open(config_path, encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
return config.get('epg_sources', []) return config.get("epg_sources", [])
except (FileNotFoundError, json.JSONDecodeError) as e: except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading EPG sources: {e}") print(f"Error loading EPG sources: {e}")
return [] return []
def get_tvg_ids(playlist_path): def get_tvg_ids(playlist_path):
""" """
Extracts unique tvg-id values from an M3U playlist file. Extracts unique tvg-id values from an M3U playlist file.
@@ -51,26 +71,27 @@ def get_tvg_ids(playlist_path):
# and ends with a double quote. # and ends with a double quote.
tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"') tvg_id_pattern = re.compile(r'tvg-id="([^"]*)"')
with open(playlist_path, 'r', encoding='utf-8') as file: with open(playlist_path, encoding="utf-8") as file:
for line in file: for line in file:
if line.startswith('#EXTINF'): if line.startswith("#EXTINF"):
# Search for the tvg-id pattern in the line # Search for the tvg-id pattern in the line
match = tvg_id_pattern.search(line) match = tvg_id_pattern.search(line)
if match: if match:
# Extract the captured group (the value inside the quotes) # Extract the captured group (the value inside the quotes)
tvg_id = match.group(1) tvg_id = match.group(1)
if tvg_id: # Ensure the extracted id is not empty if tvg_id: # Ensure the extracted id is not empty
unique_tvg_ids.add(tvg_id) unique_tvg_ids.add(tvg_id)
return list(unique_tvg_ids) return list(unique_tvg_ids)
def fetch_and_extract_xml(url): def fetch_and_extract_xml(url):
response = requests.get(url) response = requests.get(url)
if response.status_code != 200: if response.status_code != 200:
print(f"Failed to fetch {url}") print(f"Failed to fetch {url}")
return None return None
if url.endswith('.gz'): if url.endswith(".gz"):
try: try:
decompressed_data = gzip.decompress(response.content) decompressed_data = gzip.decompress(response.content)
return ET.fromstring(decompressed_data) return ET.fromstring(decompressed_data)
@@ -84,44 +105,48 @@ def fetch_and_extract_xml(url):
print(f"Failed to parse XML from {url}: {e}") print(f"Failed to parse XML from {url}: {e}")
return None return None
def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True): def filter_and_build_epg(urls, tvg_ids, output_file, save_as_gz=True):
root = ET.Element('tv') root = ET.Element("tv")
for url in urls: for url in urls:
epg_data = fetch_and_extract_xml(url) epg_data = fetch_and_extract_xml(url)
if epg_data is None: if epg_data is None:
continue continue
for channel in epg_data.findall('channel'): for channel in epg_data.findall("channel"):
tvg_id = channel.get('id') tvg_id = channel.get("id")
if tvg_id in tvg_ids: if tvg_id in tvg_ids:
root.append(channel) root.append(channel)
for programme in epg_data.findall('programme'): for programme in epg_data.findall("programme"):
tvg_id = programme.get('channel') tvg_id = programme.get("channel")
if tvg_id in tvg_ids: if tvg_id in tvg_ids:
root.append(programme) root.append(programme)
tree = ET.ElementTree(root) tree = ET.ElementTree(root)
tree.write(output_file, encoding='utf-8', xml_declaration=True) tree.write(output_file, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file}") print(f"New EPG saved to {output_file}")
if save_as_gz: if save_as_gz:
output_file_gz = output_file + '.gz' output_file_gz = output_file + ".gz"
with gzip.open(output_file_gz, 'wb') as f: with gzip.open(output_file_gz, "wb") as f:
tree.write(f, encoding='utf-8', xml_declaration=True) tree.write(f, encoding="utf-8", xml_declaration=True)
print(f"New EPG saved to {output_file_gz}") print(f"New EPG saved to {output_file_gz}")
def upload_epg(file_path): def upload_epg(file_path):
"""Uploads gzipped EPG file to IPTV server using HTTP Basic Auth""" """Uploads gzipped EPG file to IPTV server using HTTP Basic Auth"""
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
response = requests.post( response = requests.post(
IPTV_SERVER_URL + '/admin/epg', IPTV_SERVER_URL + "/admin/epg",
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), auth=requests.auth.HTTPBasicAuth(
files={'file': (os.path.basename(file_path), f)} IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
) )
if response.status_code == 200: if response.status_code == 200:
print("EPG successfully uploaded to server") print("EPG successfully uploaded to server")
else: else:
@@ -129,6 +154,7 @@ def upload_epg(file_path):
except Exception as e: except Exception as e:
print(f"Upload error: {str(e)}") print(f"Upload error: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
args = parse_arguments() args = parse_arguments()
playlist_file = args.playlist playlist_file = args.playlist
@@ -144,4 +170,4 @@ if __name__ == "__main__":
filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz) filter_and_build_epg(urls, tvg_ids, output_file, args.save_as_gz)
if args.save_as_gz: if args.save_as_gz:
upload_epg(output_file + '.gz') upload_epg(output_file + ".gz")

View File

@@ -1,26 +1,45 @@
import os
import argparse import argparse
import json import json
import logging import logging
import requests import os
from pathlib import Path
from datetime import datetime from datetime import datetime
import requests
from utils.check_streams import StreamValidator from utils.check_streams import StreamValidator
from utils.constants import EPG_URL, IPTV_SERVER_ADMIN_PASSWORD, IPTV_SERVER_ADMIN_USER, IPTV_SERVER_URL from utils.constants import (
EPG_URL,
IPTV_SERVER_ADMIN_PASSWORD,
IPTV_SERVER_ADMIN_USER,
IPTV_SERVER_URL,
)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='IPTV playlist generator') parser = argparse.ArgumentParser(description="IPTV playlist generator")
parser.add_argument('--output', parser.add_argument(
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'playlist.m3u8'), "--output",
help='Path to output playlist file') default=os.path.join(
parser.add_argument('--channels', os.path.dirname(os.path.dirname(__file__)), "playlist.m3u8"
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'channels.json'), ),
help='Path to channels definition JSON file') help="Path to output playlist file",
parser.add_argument('--dead-channels-log', )
default=os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dead_channels.log'), parser.add_argument(
help='Path to log file to store a list of dead channels') "--channels",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "channels.json"
),
help="Path to channels definition JSON file",
)
parser.add_argument(
"--dead-channels-log",
default=os.path.join(
os.path.dirname(os.path.dirname(__file__)), "dead_channels.log"
),
help="Path to log file to store a list of dead channels",
)
return parser.parse_args() return parser.parse_args()
def find_working_stream(validator, urls): def find_working_stream(validator, urls):
"""Test all URLs and return the first working one""" """Test all URLs and return the first working one"""
for url in urls: for url in urls:
@@ -29,48 +48,55 @@ def find_working_stream(validator, urls):
return url return url
return None return None
def create_playlist(channels_file, output_file): def create_playlist(channels_file, output_file):
# Read channels from JSON file # Read channels from JSON file
with open(channels_file, 'r', encoding='utf-8') as f: with open(channels_file, encoding="utf-8") as f:
channels = json.load(f) channels = json.load(f)
# Initialize validator # Initialize validator
validator = StreamValidator(timeout=45) validator = StreamValidator(timeout=45)
# Prepare M3U8 header # Prepare M3U8 header
m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n' m3u8_content = f'#EXTM3U url-tvg="{EPG_URL}"\n'
for channel in channels: for channel in channels:
if 'urls' in channel: # Check if channel has URLs if "urls" in channel: # Check if channel has URLs
# Find first working stream # Find first working stream
working_url = find_working_stream(validator, channel['urls']) working_url = find_working_stream(validator, channel["urls"])
if working_url: if working_url:
# Add channel to playlist # Add channel to playlist
m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" ' m3u8_content += f'#EXTINF:-1 tvg-id="{channel.get("tvg-id", "")}" '
m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" ' m3u8_content += f'tvg-name="{channel.get("tvg-name", "")}" '
m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" ' m3u8_content += f'tvg-logo="{channel.get("tvg-logo", "")}" '
m3u8_content += f'group-title="{channel.get("group-title", "")}", ' m3u8_content += f'group-title="{channel.get("group-title", "")}", '
m3u8_content += f'{channel.get("name", "")}\n' m3u8_content += f"{channel.get('name', '')}\n"
m3u8_content += f'{working_url}\n' m3u8_content += f"{working_url}\n"
else: else:
# Log dead channel # Log dead channel
logging.info(f'Dead channel: {channel.get("name", "Unknown")} - No working streams found') logging.info(
f"Dead channel: {channel.get('name', 'Unknown')} - "
"No working streams found"
)
# Write playlist file # Write playlist file
with open(output_file, 'w', encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
f.write(m3u8_content) f.write(m3u8_content)
def upload_playlist(file_path): def upload_playlist(file_path):
"""Uploads playlist file to IPTV server using HTTP Basic Auth""" """Uploads playlist file to IPTV server using HTTP Basic Auth"""
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
response = requests.post( response = requests.post(
IPTV_SERVER_URL + '/admin/playlist', IPTV_SERVER_URL + "/admin/playlist",
auth=requests.auth.HTTPBasicAuth(IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD), auth=requests.auth.HTTPBasicAuth(
files={'file': (os.path.basename(file_path), f)} IPTV_SERVER_ADMIN_USER, IPTV_SERVER_ADMIN_PASSWORD
),
files={"file": (os.path.basename(file_path), f)},
) )
if response.status_code == 200: if response.status_code == 200:
print("Playlist successfully uploaded to server") print("Playlist successfully uploaded to server")
else: else:
@@ -78,6 +104,7 @@ def upload_playlist(file_path):
except Exception as e: except Exception as e:
print(f"Upload error: {str(e)}") print(f"Upload error: {str(e)}")
def main(): def main():
args = parse_arguments() args = parse_arguments()
channels_file = args.channels channels_file = args.channels
@@ -85,24 +112,25 @@ def main():
dead_channels_log_file = args.dead_channels_log dead_channels_log_file = args.dead_channels_log
# Clear previous log file # Clear previous log file
with open(dead_channels_log_file, 'w') as f: with open(dead_channels_log_file, "w") as f:
f.write(f'Log created on {datetime.now()}\n') f.write(f"Log created on {datetime.now()}\n")
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
filename=dead_channels_log_file, filename=dead_channels_log_file,
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(message)s', format="%(asctime)s - %(message)s",
datefmt='%Y-%m-%d %H:%M:%S' datefmt="%Y-%m-%d %H:%M:%S",
) )
# Create playlist # Create playlist
create_playlist(channels_file, output_file) create_playlist(channels_file, output_file)
#upload playlist to server # upload playlist to server
upload_playlist(output_file) upload_playlist(output_file)
print("Playlist creation completed!") print("Playlist creation completed!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

110
app/iptv/scheduler.py Normal file
View File

@@ -0,0 +1,110 @@
import logging
import os
from typing import Optional
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from sqlalchemy.orm import Session
from app.iptv.stream_manager import StreamManager
from app.models.db import ChannelDB
from app.utils.database import get_db_session
logger = logging.getLogger(__name__)
class StreamScheduler:
    """Scheduler service for periodic stream validation tasks.

    Wraps an APScheduler ``BackgroundScheduler`` that runs
    ``validate_streams_batch`` on a cron schedule (default: 3 AM daily,
    configurable via the ``STREAM_VALIDATION_SCHEDULE`` env var).
    """

    def __init__(self, app: Optional[FastAPI] = None):
        """
        Initialize the scheduler with optional FastAPI app integration.

        Args:
            app: Optional FastAPI app instance for lifecycle integration.
                When provided, a shutdown handler is registered in start().
        """
        self.scheduler = BackgroundScheduler()
        self.app = app
        # batch_size == 0 means "validate all channels" (see validate_streams_batch).
        self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10"))
        self.schedule_time = os.getenv(
            "STREAM_VALIDATION_SCHEDULE", "0 3 * * *"
        )  # Default 3 AM daily
        # Lazy %-style args are the logging-module convention: the message is
        # only formatted if the record is actually emitted.
        logger.info("Scheduler initialized with app: %s", app is not None)

    def validate_streams_batch(self, db_session: Optional[Session] = None) -> None:
        """
        Validate streams and update their status.

        When ``batch_size`` is 0, validates all channels; otherwise only the
        first ``batch_size`` channels returned by the query are validated.

        Args:
            db_session: Optional SQLAlchemy session. When omitted, a new
                session is created by this method and closed before returning.
        """
        owns_session = db_session is None
        db = get_db_session() if owns_session else db_session
        try:
            manager = StreamManager(db)

            # Select the channels to validate for this run.
            query = db.query(ChannelDB)
            if self.batch_size > 0:
                query = query.limit(self.batch_size)
            channels = query.all()

            for channel in channels:
                try:
                    logger.info("Validating streams for channel %s", channel.id)
                    manager.validate_and_select_stream(str(channel.id))
                except Exception as e:
                    # One bad channel must not abort the whole batch.
                    logger.error("Error validating channel %s: %s", channel.id, e)
                    continue

            logger.info("Completed stream validation of %d channels", len(channels))
        finally:
            # Only close sessions this method created itself.
            if owns_session:
                db.close()

    def start(self) -> None:
        """Start the scheduler and register the daily validation job."""
        if not self.scheduler.running:
            # replace_existing makes start() idempotent if a job with this id
            # is already registered (e.g. after a scheduler restart).
            self.scheduler.add_job(
                self.validate_streams_batch,
                trigger=CronTrigger.from_crontab(self.schedule_time),
                id="daily_stream_validation",
                replace_existing=True,
            )
            self.scheduler.start()
            logger.info(
                "Stream scheduler started with daily validation job. Running: %s",
                self.scheduler.running,
            )

            # Register shutdown handler if a FastAPI app is provided.
            # NOTE(review): on_event is deprecated in recent FastAPI in favor
            # of lifespan handlers; kept for backward compatibility with the
            # existing integration.
            if self.app:
                logger.info(
                    "Registering scheduler with FastAPI app: %s",
                    hasattr(self.app, "state"),
                )

                @self.app.on_event("shutdown")
                def shutdown_scheduler():
                    self.shutdown()

    def shutdown(self) -> None:
        """Shutdown the scheduler gracefully (no-op if not running)."""
        if self.scheduler.running:
            self.scheduler.shutdown()
            logger.info("Stream scheduler stopped")

    def trigger_manual_validation(self) -> None:
        """Trigger an immediate, synchronous validation of streams."""
        logger.info("Manually triggering stream validation")
        self.validate_streams_batch()
def init_scheduler(app: FastAPI) -> StreamScheduler:
    """Create a StreamScheduler bound to *app*, start it, and return it.

    Convenience factory used at application startup so callers do not have
    to manage construction and start() separately.
    """
    stream_scheduler = StreamScheduler(app)
    stream_scheduler.start()
    return stream_scheduler

151
app/iptv/stream_manager.py Normal file
View File

@@ -0,0 +1,151 @@
import logging
import random
from typing import Optional
from sqlalchemy.orm import Session
from app.models.db import ChannelURL
from app.utils.check_streams import StreamValidator
from app.utils.database import get_db_session
logger = logging.getLogger(__name__)
class StreamManager:
    """Service for managing and validating channel streams.

    If constructed without a session, the manager creates and owns one and
    will close it on ``close()``/garbage collection. An injected session is
    never closed by the manager (fixes the previous unconditional close in
    ``__del__``, which tore down sessions still owned by the caller).
    """

    def __init__(self, db_session: Optional[Session] = None):
        """
        Initialize StreamManager with optional database session.

        Args:
            db_session: Optional SQLAlchemy session. If None, a new one is
                created and owned by this instance.
        """
        # Track ownership so we never close a session the caller injected.
        self._owns_session = db_session is None
        self.db = db_session if db_session is not None else get_db_session()
        self.validator = StreamValidator()

    def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]:
        """
        Get all streams for a channel ordered by priority (lowest first),
        with same-priority streams randomized.

        Args:
            channel_id: UUID of the channel to get streams for

        Returns:
            List of ChannelURL objects ordered by priority

        Raises:
            Exception: re-raises any database error after logging it.
        """
        try:
            streams = (
                self.db.query(ChannelURL)
                .filter(ChannelURL.channel_id == channel_id)
                .order_by(ChannelURL.priority_id)
                .all()
            )

            # Group streams by priority tier; setdefault avoids the manual
            # "if key not in dict" dance.
            grouped: dict = {}
            for stream in streams:
                grouped.setdefault(stream.priority_id, []).append(stream)

            # Shuffle within each tier, then flatten in ascending priority order.
            randomized_streams: list[ChannelURL] = []
            for priority in sorted(grouped):
                random.shuffle(grouped[priority])
                randomized_streams.extend(grouped[priority])

            return randomized_streams
        except Exception as e:
            logger.error("Error getting streams for channel %s: %s", channel_id, e)
            raise

    def validate_and_select_stream(self, channel_id: str) -> Optional[str]:
        """
        Find and validate a working stream for the given channel.

        Args:
            channel_id: UUID of the channel to find a stream for

        Returns:
            URL of the first working stream found, or None if none found
        """
        try:
            streams = self.get_streams_for_channel(channel_id)
            if not streams:
                logger.warning("No streams found for channel %s", channel_id)
                return None

            working_stream = None
            for stream in streams:
                logger.info(
                    "Validating stream %s for channel %s", stream.url, channel_id
                )
                is_valid, _ = self.validator.validate_stream(stream.url)
                if is_valid:
                    working_stream = stream
                    break

            if working_stream is None:
                logger.warning("No valid streams found for channel %s", channel_id)
                return None

            self._update_stream_status(working_stream, streams)
            return working_stream.url
        except Exception as e:
            logger.error(
                "Error validating streams for channel %s: %s", channel_id, e
            )
            raise

    def _update_stream_status(
        self, working_stream: ChannelURL, all_streams: list[ChannelURL]
    ) -> None:
        """
        Update in_use status for streams (True for working stream, False for others).

        Args:
            working_stream: The stream that was validated as working
            all_streams: All streams for the channel
        """
        try:
            for stream in all_streams:
                stream.in_use = stream.id == working_stream.id
            self.db.commit()
            logger.info(
                "Updated stream status - set in_use=True for %s", working_stream.url
            )
        except Exception as e:
            # Keep the session usable after a failed commit.
            self.db.rollback()
            logger.error("Error updating stream status: %s", e)
            raise

    def close(self) -> None:
        """Close the database session if this manager created it."""
        if getattr(self, "_owns_session", False) and hasattr(self, "db"):
            self.db.close()

    def __del__(self):
        """Best-effort cleanup when the manager is garbage-collected."""
        # __del__ can run during interpreter teardown; never raise from here.
        try:
            self.close()
        except Exception:
            pass
def get_working_stream(
    channel_id: str, db_session: Optional[Session] = None
) -> Optional[str]:
    """
    Convenience function to get a working stream for a channel.

    Args:
        channel_id: UUID of the channel to get a stream for
        db_session: Optional SQLAlchemy session; when omitted, a temporary
            session is created for this call and closed before returning.

    Returns:
        URL of the first working stream found, or None if none found
    """
    manager = StreamManager(db_session)
    try:
        return manager.validate_and_select_stream(channel_id)
    finally:
        if db_session is None:  # Only close the session we implicitly created.
            # Close the session directly; never invoke __del__ by hand --
            # that dunder is reserved for the garbage collector.
            manager.db.close()

View File

@@ -1,24 +1,36 @@
from fastapi.concurrency import asynccontextmanager
from app.routers import channels, auth, playlist, priorities
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.concurrency import asynccontextmanager
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from app.iptv.scheduler import StreamScheduler
from app.routers import auth, channels, groups, playlist, priorities, scheduler
from app.utils.database import init_db from app.utils.database import init_db
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Initialize database tables on startup # Initialize database tables on startup
init_db() init_db()
# Initialize and start the stream scheduler
scheduler = StreamScheduler(app)
app.state.scheduler = scheduler # Store scheduler in app state
scheduler.start()
yield yield
# Shutdown scheduler on app shutdown
scheduler.shutdown()
app = FastAPI( app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
title="IPTV Updater API", title="IPTV Manager API",
description="API for IPTV Updater service", description="API for IPTV Manager service",
version="1.0.0", version="1.0.0",
) )
def custom_openapi(): def custom_openapi():
if app.openapi_schema: if app.openapi_schema:
return app.openapi_schema return app.openapi_schema
@@ -40,11 +52,7 @@ def custom_openapi():
# Add security scheme component # Add security scheme component
openapi_schema["components"]["securitySchemes"] = { openapi_schema["components"]["securitySchemes"] = {
"Bearer": { "Bearer": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT"
}
} }
# Add global security requirement # Add global security requirement
@@ -56,14 +64,19 @@ def custom_openapi():
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema
app.openapi = custom_openapi app.openapi = custom_openapi
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "IPTV Updater API"} return {"message": "IPTV Manager API"}
# Include routers # Include routers
app.include_router(auth.router) app.include_router(auth.router)
app.include_router(channels.router) app.include_router(channels.router)
app.include_router(playlist.router) app.include_router(playlist.router)
app.include_router(priorities.router) app.include_router(priorities.router)
app.include_router(groups.router)
app.include_router(scheduler.router)

View File

@@ -1,4 +1,27 @@
from .db import Base, ChannelDB, ChannelURL from .db import Base, ChannelDB, ChannelURL, Group, Priority
from .schemas import ChannelCreate, ChannelUpdate, ChannelResponse, ChannelURLCreate, ChannelURLResponse from .schemas import (
ChannelCreate,
ChannelResponse,
ChannelUpdate,
ChannelURLCreate,
ChannelURLResponse,
GroupCreate,
GroupResponse,
GroupUpdate,
)
__all__ = ["Base", "ChannelDB", "ChannelCreate", "ChannelUpdate", "ChannelResponse", "ChannelURL", "ChannelURLCreate", "ChannelURLResponse"] __all__ = [
"Base",
"ChannelDB",
"ChannelCreate",
"ChannelUpdate",
"ChannelResponse",
"ChannelURL",
"ChannelURLCreate",
"ChannelURLResponse",
"Group",
"Priority",
"GroupCreate",
"GroupResponse",
"GroupUpdate",
]

View File

@@ -1,20 +1,26 @@
from typing import List, Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class SigninRequest(BaseModel): class SigninRequest(BaseModel):
"""Request model for the signin endpoint.""" """Request model for the signin endpoint."""
username: str = Field(..., description="The user's username") username: str = Field(..., description="The user's username")
password: str = Field(..., description="The user's password") password: str = Field(..., description="The user's password")
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
"""Response model for successful authentication.""" """Response model for successful authentication."""
access_token: str = Field(..., description="Access JWT token from Cognito") access_token: str = Field(..., description="Access JWT token from Cognito")
id_token: str = Field(..., description="ID JWT token from Cognito") id_token: str = Field(..., description="ID JWT token from Cognito")
refresh_token: Optional[str] = Field( refresh_token: Optional[str] = Field(None, description="Refresh token from Cognito")
None, description="Refresh token from Cognito")
token_type: str = Field(..., description="Type of the token returned") token_type: str = Field(..., description="Type of the token returned")
class CognitoUser(BaseModel): class CognitoUser(BaseModel):
"""Model representing the user returned from token verification.""" """Model representing the user returned from token verification."""
username: str username: str
roles: List[str] roles: list[str]

View File

@@ -1,51 +1,139 @@
from datetime import datetime, timezone import os
import uuid import uuid
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer from datetime import datetime, timezone
from sqlalchemy import (
TEXT,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
TypeDecorator,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm import relationship
# Custom UUID type for SQLite compatibility
class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite with proper comparison handling."""
impl = TEXT
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid.UUID):
return str(value)
try:
# Validate string format by attempting to create UUID
uuid.UUID(value)
return value
except (ValueError, AttributeError):
raise ValueError(f"Invalid UUID string format: {value}")
def process_result_value(self, value, dialect):
if value is None:
return value
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(value)
def compare_values(self, x, y):
if x is None or y is None:
return x == y
return str(x) == str(y)
# Determine which UUID type to use based on environment
if os.getenv("MOCK_AUTH", "").lower() == "true":
UUID_COLUMN_TYPE = SQLiteUUID()
else:
UUID_COLUMN_TYPE = UUID(as_uuid=True)
Base = declarative_base() Base = declarative_base()
class Priority(Base): class Priority(Base):
"""SQLAlchemy model for channel URL priorities""" """SQLAlchemy model for channel URL priorities"""
__tablename__ = "priorities" __tablename__ = "priorities"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
description = Column(String, nullable=False) description = Column(String, nullable=False)
class Group(Base):
"""SQLAlchemy model for channel groups"""
__tablename__ = "groups"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=False, unique=True)
sort_order = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationship with Channel
channels = relationship("ChannelDB", back_populates="group")
class ChannelDB(Base): class ChannelDB(Base):
"""SQLAlchemy model for IPTV channels""" """SQLAlchemy model for IPTV channels"""
__tablename__ = "channels" __tablename__ = "channels"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
tvg_id = Column(String, nullable=False) tvg_id = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
group_title = Column(String, nullable=False) group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
tvg_name = Column(String) tvg_name = Column(String)
__table_args__ = ( __table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
)
tvg_logo = Column(String) tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
# Relationship with ChannelURL default=lambda: datetime.now(timezone.utc),
urls = relationship("ChannelURL", back_populates="channel", cascade="all, delete-orphan") onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
urls = relationship(
"ChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
group = relationship("Group", back_populates="channels")
class ChannelURL(Base): class ChannelURL(Base):
"""SQLAlchemy model for channel URLs""" """SQLAlchemy model for channel URLs"""
__tablename__ = "channels_urls" __tablename__ = "channels_urls"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
channel_id = Column(UUID(as_uuid=True), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) channel_id = Column(
UUID_COLUMN_TYPE,
ForeignKey("channels.id", ondelete="CASCADE"),
nullable=False,
)
url = Column(String, nullable=False) url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False) in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships # Relationships
channel = relationship("ChannelDB", back_populates="urls") channel = relationship("ChannelDB", back_populates="urls")
priority = relationship("Priority") priority = relationship("Priority")

View File

@@ -1,30 +1,43 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class PriorityBase(BaseModel): class PriorityBase(BaseModel):
"""Base Pydantic model for priorities""" """Base Pydantic model for priorities"""
id: int id: int
description: str description: str
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class PriorityCreate(PriorityBase): class PriorityCreate(PriorityBase):
"""Pydantic model for creating priorities""" """Pydantic model for creating priorities"""
pass pass
class PriorityResponse(PriorityBase): class PriorityResponse(PriorityBase):
"""Pydantic model for priority responses""" """Pydantic model for priority responses"""
pass pass
class ChannelURLCreate(BaseModel): class ChannelURLCreate(BaseModel):
"""Pydantic model for creating channel URLs""" """Pydantic model for creating channel URLs"""
url: str url: str
priority_id: int = Field(default=100, ge=100, le=300) # Default to High, validate range priority_id: int = Field(
default=100, ge=100, le=300
) # Default to High, validate range
class ChannelURLBase(ChannelURLCreate): class ChannelURLBase(ChannelURLCreate):
"""Base Pydantic model for channel URL responses""" """Base Pydantic model for channel URL responses"""
id: UUID id: UUID
in_use: bool in_use: bool
created_at: datetime created_at: datetime
@@ -33,43 +46,95 @@ class ChannelURLBase(ChannelURLCreate):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class ChannelURLResponse(ChannelURLBase): class ChannelURLResponse(ChannelURLBase):
"""Pydantic model for channel URL responses""" """Pydantic model for channel URL responses"""
pass pass
# New Group Schemas
class GroupCreate(BaseModel):
    """Pydantic model for creating groups"""
    # Display name of the group; uniqueness is enforced by the API layer.
    name: str
    # Position used when ordering groups; non-negative, defaults to 0.
    sort_order: int = Field(default=0, ge=0)
class GroupUpdate(BaseModel):
    """Pydantic model for updating groups"""
    # Both fields are optional: fields omitted from the request body are
    # left unchanged by the update endpoint (exclude_unset semantics).
    name: Optional[str] = None
    sort_order: Optional[int] = Field(None, ge=0)
class GroupResponse(BaseModel):
    """Pydantic model for group responses"""
    id: UUID
    name: str
    sort_order: int
    created_at: datetime
    updated_at: datetime
    # Allow construction directly from SQLAlchemy ORM instances.
    model_config = ConfigDict(from_attributes=True)
class GroupSortUpdate(BaseModel):
    """Pydantic model for updating a single group's sort order"""
    # New non-negative position for the group.
    sort_order: int = Field(ge=0)
class GroupBulkSort(BaseModel):
    """Pydantic model for bulk updating group sort orders"""
    # NOTE(review): entries are untyped dicts, so the presence of
    # group_id/sort_order is NOT validated here -- consumers must check
    # the keys themselves or malformed entries will surface at runtime.
    groups: list[dict] = Field(
        description="List of dicts with group_id and new sort_order",
        json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]},
    )
class ChannelCreate(BaseModel): class ChannelCreate(BaseModel):
"""Pydantic model for creating channels""" """Pydantic model for creating channels"""
urls: List[ChannelURLCreate] # List of URL objects with priority
urls: list[ChannelURLCreate] # List of URL objects with priority
name: str name: str
group_title: str group_id: UUID
tvg_id: str tvg_id: str
tvg_logo: str tvg_logo: str
tvg_name: str tvg_name: str
class ChannelURLUpdate(BaseModel): class ChannelURLUpdate(BaseModel):
"""Pydantic model for updating channel URLs""" """Pydantic model for updating channel URLs"""
url: Optional[str] = None url: Optional[str] = None
in_use: Optional[bool] = None in_use: Optional[bool] = None
priority_id: Optional[int] = Field(default=None, ge=100, le=300) priority_id: Optional[int] = Field(default=None, ge=100, le=300)
class ChannelUpdate(BaseModel): class ChannelUpdate(BaseModel):
"""Pydantic model for updating channels (all fields optional)""" """Pydantic model for updating channels (all fields optional)"""
name: Optional[str] = Field(None, min_length=1) name: Optional[str] = Field(None, min_length=1)
group_title: Optional[str] = Field(None, min_length=1) group_id: Optional[UUID] = None
tvg_id: Optional[str] = Field(None, min_length=1) tvg_id: Optional[str] = Field(None, min_length=1)
tvg_logo: Optional[str] = None tvg_logo: Optional[str] = None
tvg_name: Optional[str] = Field(None, min_length=1) tvg_name: Optional[str] = Field(None, min_length=1)
class ChannelResponse(BaseModel): class ChannelResponse(BaseModel):
"""Pydantic model for channel responses""" """Pydantic model for channel responses"""
id: UUID id: UUID
name: str name: str
group_title: str group_id: UUID
tvg_id: str tvg_id: str
tvg_logo: str tvg_logo: str
tvg_name: str tvg_name: str
urls: List[ChannelURLResponse] # List of URL objects without channel_id urls: list[ChannelURLResponse] # List of URL objects without channel_id
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@@ -1,16 +1,16 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.auth.cognito import initiate_auth from app.auth.cognito import initiate_auth
from app.models.auth import SigninRequest, TokenResponse from app.models.auth import SigninRequest, TokenResponse
router = APIRouter( router = APIRouter(prefix="/auth", tags=["authentication"])
prefix="/auth",
tags=["authentication"]
)
@router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint") @router.post("/signin", response_model=TokenResponse, summary="Signin Endpoint")
def signin(credentials: SigninRequest): def signin(credentials: SigninRequest):
""" """
Sign-in endpoint to authenticate the user with AWS Cognito using username and password. Sign-in endpoint to authenticate the user with AWS Cognito
using username and password.
On success, returns JWT tokens (access_token, id_token, refresh_token). On success, returns JWT tokens (access_token, id_token, refresh_token).
""" """
auth_result = initiate_auth(credentials.username, credentials.password) auth_result = initiate_auth(credentials.username, credentials.password)
@@ -19,4 +19,4 @@ def signin(credentials: SigninRequest):
id_token=auth_result["IdToken"], id_token=auth_result["IdToken"],
refresh_token=auth_result.get("RefreshToken"), refresh_token=auth_result.get("RefreshToken"),
token_type="Bearer", token_type="Bearer",
) )

View File

@@ -1,52 +1,64 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from uuid import UUID from uuid import UUID
from sqlalchemy import and_
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import ( from app.models import (
ChannelDB,
ChannelURL,
ChannelCreate, ChannelCreate,
ChannelUpdate, ChannelDB,
ChannelResponse, ChannelResponse,
ChannelUpdate,
ChannelURL,
ChannelURLCreate, ChannelURLCreate,
ChannelURLResponse, ChannelURLResponse,
Group,
Priority, # Added Priority import
) )
from app.models.auth import CognitoUser
from app.models.schemas import ChannelURLUpdate from app.models.schemas import ChannelURLUpdate
from app.utils.database import get_db from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter( router = APIRouter(prefix="/channels", tags=["channels"])
prefix="/channels",
tags=["channels"]
)
@router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ChannelResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin") @require_roles("admin")
def create_channel( def create_channel(
channel: ChannelCreate, channel: ChannelCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Create a new channel""" """Create a new channel"""
# Check for duplicate channel (same group_title + name) # Check if group exists
existing_channel = db.query(ChannelDB).filter( group = db.query(Group).filter(Group.id == channel.group_id).first()
and_( if not group:
ChannelDB.group_title == channel.group_title, raise HTTPException(
ChannelDB.name == channel.name status_code=status.HTTP_404_NOT_FOUND,
detail="Group not found",
) )
).first()
# Check for duplicate channel (same group_id + name)
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_id == channel.group_id,
ChannelDB.name == channel.name,
)
)
.first()
)
if existing_channel: if existing_channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists" detail="Channel with same group_id and name already exists",
) )
# Create channel without URLs first # Create channel without URLs first
channel_data = channel.model_dump(exclude={'urls'}) channel_data = channel.model_dump(exclude={"urls"})
urls = channel.urls urls = channel.urls
db_channel = ChannelDB(**channel_data) db_channel = ChannelDB(**channel_data)
db.add(db_channel) db.add(db_channel)
@@ -59,130 +71,368 @@ def create_channel(
channel_id=db_channel.id, channel_id=db_channel.id,
url=url.url, url=url.url,
priority_id=url.priority_id, priority_id=url.priority_id,
in_use=False in_use=False,
) )
db.add(db_url) db.add(db_url)
db.commit() db.commit()
db.refresh(db_channel) db.refresh(db_channel)
return db_channel return db_channel
@router.get("/{channel_id}", response_model=ChannelResponse) @router.get("/{channel_id}", response_model=ChannelResponse)
def get_channel( def get_channel(channel_id: UUID, db: Session = Depends(get_db)):
channel_id: UUID,
db: Session = Depends(get_db)
):
"""Get a channel by id""" """Get a channel by id"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
return channel return channel
@router.put("/{channel_id}", response_model=ChannelResponse) @router.put("/{channel_id}", response_model=ChannelResponse)
@require_roles("admin") @require_roles("admin")
def update_channel( def update_channel(
channel_id: UUID, channel_id: UUID,
channel: ChannelUpdate, channel: ChannelUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Update a channel""" """Update a channel"""
db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not db_channel: if not db_channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
# Only check for duplicates if name or group_title are being updated # Only check for duplicates if name or group_id are being updated
if channel.name is not None or channel.group_title is not None: if channel.name is not None or channel.group_id is not None:
name = channel.name if channel.name is not None else db_channel.name name = channel.name if channel.name is not None else db_channel.name
group_title = channel.group_title if channel.group_title is not None else db_channel.group_title group_id = (
channel.group_id if channel.group_id is not None else db_channel.group_id
existing_channel = db.query(ChannelDB).filter( )
and_(
ChannelDB.group_title == group_title, # Check if new group exists
ChannelDB.name == name, if channel.group_id is not None:
ChannelDB.id != channel_id group = db.query(Group).filter(Group.id == channel.group_id).first()
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Group not found",
)
existing_channel = (
db.query(ChannelDB)
.filter(
and_(
ChannelDB.group_id == group_id,
ChannelDB.name == name,
ChannelDB.id != channel_id,
)
) )
).first() .first()
)
if existing_channel: if existing_channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="Channel with same group_title and name already exists" detail="Channel with same group_id and name already exists",
) )
# Update only provided fields # Update only provided fields
update_data = channel.model_dump(exclude_unset=True) update_data = channel.model_dump(exclude_unset=True)
for key, value in update_data.items(): for key, value in update_data.items():
setattr(db_channel, key, value) setattr(db_channel, key, value)
db.commit() db.commit()
db.refresh(db_channel) db.refresh(db_channel)
return db_channel return db_channel
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_channels(
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete all channels, then sweep dependent rows.

    Returns ``{"deleted": <count>}`` with the number of channels removed.
    On any failure the transaction is rolled back and a 500 is raised.

    NOTE(review): because every channel is deleted first, the final group
    filter (~Group.id.in_(empty set)) matches ALL groups, so this endpoint
    effectively deletes every group as well -- confirm that is intended.
    """
    count = 0
    try:
        count = db.query(ChannelDB).count()
        # First delete all channels
        db.query(ChannelDB).delete()
        # Then delete any URLs that are now orphaned (no channel references)
        db.query(ChannelURL).filter(
            ~ChannelURL.channel_id.in_(db.query(ChannelDB.id))
        ).delete(synchronize_session=False)
        # Then delete any groups that are now empty
        db.query(Group).filter(~Group.id.in_(db.query(ChannelDB.group_id))).delete(
            synchronize_session=False
        )
        db.commit()
    except Exception as e:
        # Roll back the whole sweep and surface a generic 500 to the client.
        print(f"Error deleting channels: {e}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete channels",
        )
    return {"deleted": count}
@router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{channel_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin") @require_roles("admin")
def delete_channel( def delete_channel(
channel_id: UUID, channel_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Delete a channel""" """Delete a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
db.delete(channel) db.delete(channel)
db.commit() db.commit()
return None return None
@router.get("/", response_model=List[ChannelResponse])
@router.get("/", response_model=list[ChannelResponse])
@require_roles("admin") @require_roles("admin")
def list_channels( def list_channels(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""List all channels with pagination""" """List all channels with pagination"""
return db.query(ChannelDB).offset(skip).limit(limit).all() return db.query(ChannelDB).offset(skip).limit(limit).all()
# URL Management Endpoints
@router.post("/{channel_id}/urls", response_model=ChannelURLResponse, status_code=status.HTTP_201_CREATED) # New endpoint to get channels by group
@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse])
def get_channels_by_group(
    group_id: UUID,
    db: Session = Depends(get_db),
):
    """Return every channel belonging to the given group.

    Raises 404 when the group itself does not exist.
    """
    target = db.query(Group).filter(Group.id == group_id).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    members = db.query(ChannelDB).filter(ChannelDB.group_id == group_id)
    return members.all()
# New endpoint to update a channel's group
@router.put("/{channel_id}/group", response_model=ChannelResponse)
@require_roles("admin")
def update_channel_group(
    channel_id: UUID,
    group_id: UUID,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Move a channel into a different group.

    Raises 404 if the channel or the target group is missing, and 409 if
    the target group already contains another channel with the same name.
    """
    db_channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
    if db_channel is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
        )
    target_group = db.query(Group).filter(Group.id == group_id).first()
    if target_group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    # A channel name may appear only once per group, so reject the move
    # if the destination already holds a different channel with this name.
    duplicate = (
        db.query(ChannelDB)
        .filter(
            and_(
                ChannelDB.group_id == group_id,
                ChannelDB.name == db_channel.name,
                ChannelDB.id != channel_id,
            )
        )
        .first()
    )
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Channel with same name already exists in target group",
        )
    db_channel.group_id = group_id
    db.commit()
    db.refresh(db_channel)
    return db_channel
# Bulk Upload and Reset Endpoints
@router.post("/bulk-upload", status_code=status.HTTP_200_OK)
@require_roles("admin")
def bulk_upload_channels(
    channels: list[dict],
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Bulk upload channels from a JSON array.

    Each entry is a dict with M3U-style keys ("group-title", "name",
    "tvg-id", "tvg-name", "tvg-logo", "urls"). Groups are created on
    demand, channels are upserted by (group_id, name), and each channel's
    URL list is replaced wholesale. Entries that fail are rolled back
    individually and skipped; the response reports how many succeeded.
    """
    processed = 0

    # Fetch all priorities once, ordered by id; URL index i maps to the
    # i-th priority id. Because the list is sorted ascending, its last
    # element already carries the highest id (lowest priority level), so
    # no extra query is needed for the overflow fallback.
    priorities = db.query(Priority).order_by(Priority.id).all()
    priority_map = {i: p.id for i, p in enumerate(priorities)}
    max_priority_id = priorities[-1].id if priorities else None

    for channel_data in channels:
        try:
            # Get or create the group; entries without a group are skipped.
            group_name = channel_data.get("group-title")
            if not group_name:
                continue
            group = db.query(Group).filter(Group.name == group_name).first()
            if not group:
                group = Group(name=group_name)
                db.add(group)
                db.flush()  # Use flush to make the group available in the session
                db.refresh(group)

            # Normalize "urls" to a list (a bare value becomes a 1-list).
            urls = channel_data.get("urls", [])
            if not isinstance(urls, list):
                urls = [urls]

            # Assign priorities positionally; URLs beyond the number of
            # defined priorities all receive the lowest priority. With no
            # priorities defined at all, URLs are dropped with a warning.
            url_objects = []
            for i, url in enumerate(urls):
                priority_id = priority_map.get(i, max_priority_id)
                if priority_id is None:
                    print(
                        f"Warning: No priorities defined in database. "
                        f"Skipping URL {url}"
                    )
                    continue
                url_objects.append({"url": url, "priority_id": priority_id})

            # Build the candidate channel row from the M3U fields.
            channel_obj = ChannelDB(
                tvg_id=channel_data.get("tvg-id", ""),
                name=channel_data.get("name", ""),
                group_id=group.id,
                tvg_name=channel_data.get("tvg-name", ""),
                tvg_logo=channel_data.get("tvg-logo", ""),
            )

            # Upsert on (group_id, name).
            existing_channel = (
                db.query(ChannelDB)
                .filter(
                    and_(
                        ChannelDB.group_id == group.id,
                        ChannelDB.name == channel_obj.name,
                    )
                )
                .first()
            )
            if existing_channel:
                # Refresh metadata and rebuild the URL set from scratch.
                existing_channel.tvg_id = channel_obj.tvg_id
                existing_channel.tvg_name = channel_obj.tvg_name
                existing_channel.tvg_logo = channel_obj.tvg_logo
                db.query(ChannelURL).filter(
                    ChannelURL.channel_id == existing_channel.id
                ).delete()
                target_id = existing_channel.id
            else:
                db.add(channel_obj)
                db.flush()  # Flush to get the new channel's ID
                db.refresh(channel_obj)
                target_id = channel_obj.id

            # Single insert loop serves both the update and create paths.
            for url in url_objects:
                db.add(
                    ChannelURL(
                        channel_id=target_id,
                        url=url["url"],
                        priority_id=url["priority_id"],
                        in_use=False,
                    )
                )

            db.commit()  # Commit all changes for this channel atomically
            processed += 1
        except Exception as e:
            # Best-effort batch: log, roll back this entry, move on.
            print(f"Error processing channel: {channel_data.get('name', 'Unknown')}")
            print(f"Exception details: {e}")
            db.rollback()
            continue

    return {"processed": processed}
# URL Management Endpoints
@router.post(
"/{channel_id}/urls",
response_model=ChannelURLResponse,
status_code=status.HTTP_201_CREATED,
)
@require_roles("admin") @require_roles("admin")
def add_channel_url( def add_channel_url(
channel_id: UUID, channel_id: UUID,
url: ChannelURLCreate, url: ChannelURLCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Add a new URL to a channel""" """Add a new URL to a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
db_url = ChannelURL( db_url = ChannelURL(
channel_id=channel_id, channel_id=channel_id,
url=url.url, url=url.url,
priority_id=url.priority_id, priority_id=url.priority_id,
in_use=False # Default to not in use in_use=False, # Default to not in use
) )
db.add(db_url) db.add(db_url)
db.commit() db.commit()
db.refresh(db_url) db.refresh(db_url)
return db_url return db_url
@router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse) @router.put("/{channel_id}/urls/{url_id}", response_model=ChannelURLResponse)
@require_roles("admin") @require_roles("admin")
def update_channel_url( def update_channel_url(
@@ -190,72 +440,69 @@ def update_channel_url(
url_id: UUID, url_id: UUID,
url_update: ChannelURLUpdate, url_update: ChannelURLUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Update a channel URL (url, in_use, or priority_id)""" """Update a channel URL (url, in_use, or priority_id)"""
db_url = db.query(ChannelURL).filter( db_url = (
and_( db.query(ChannelURL)
ChannelURL.id == url_id, .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
ChannelURL.channel_id == channel_id .first()
) )
).first()
if not db_url: if not db_url:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
detail="URL not found"
) )
if url_update.url is not None: if url_update.url is not None:
db_url.url = url_update.url db_url.url = url_update.url
if url_update.in_use is not None: if url_update.in_use is not None:
db_url.in_use = url_update.in_use db_url.in_use = url_update.in_use
if url_update.priority_id is not None: if url_update.priority_id is not None:
db_url.priority_id = url_update.priority_id db_url.priority_id = url_update.priority_id
db.commit() db.commit()
db.refresh(db_url) db.refresh(db_url)
return db_url return db_url
@router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{channel_id}/urls/{url_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin") @require_roles("admin")
def delete_channel_url( def delete_channel_url(
channel_id: UUID, channel_id: UUID,
url_id: UUID, url_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Delete a URL from a channel""" """Delete a URL from a channel"""
url = db.query(ChannelURL).filter( url = (
and_( db.query(ChannelURL)
ChannelURL.id == url_id, .filter(and_(ChannelURL.id == url_id, ChannelURL.channel_id == channel_id))
ChannelURL.channel_id == channel_id .first()
) )
).first()
if not url: if not url:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="URL not found"
detail="URL not found"
) )
db.delete(url) db.delete(url)
db.commit() db.commit()
return None return None
@router.get("/{channel_id}/urls", response_model=List[ChannelURLResponse])
@router.get("/{channel_id}/urls", response_model=list[ChannelURLResponse])
@require_roles("admin") @require_roles("admin")
def list_channel_urls( def list_channel_urls(
channel_id: UUID, channel_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""List all URLs for a channel""" """List all URLs for a channel"""
channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first()
if not channel: if not channel:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found"
detail="Channel not found"
) )
return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all() return db.query(ChannelURL).filter(ChannelURL.channel_id == channel_id).all()

191
app/routers/groups.py Normal file
View File

@@ -0,0 +1,191 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.models import Group
from app.models.auth import CognitoUser
from app.models.schemas import (
GroupBulkSort,
GroupCreate,
GroupResponse,
GroupSortUpdate,
GroupUpdate,
)
from app.utils.database import get_db
router = APIRouter(prefix="/groups", tags=["groups"])
@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin")
def create_group(
    group: GroupCreate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Create a new channel group.

    Group names must be unique; a clash yields a 409 response.
    """
    duplicate = db.query(Group).filter(Group.name == group.name).first()
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Group with this name already exists",
        )
    new_group = Group(**group.model_dump())
    db.add(new_group)
    db.commit()
    db.refresh(new_group)
    return new_group
@router.get("/{group_id}", response_model=GroupResponse)
def get_group(group_id: UUID, db: Session = Depends(get_db)):
    """Fetch a single group by its id (404 when absent)."""
    result = db.query(Group).filter(Group.id == group_id).first()
    if result is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    return result
@router.put("/{group_id}", response_model=GroupResponse)
@require_roles("admin")
def update_group(
    group_id: UUID,
    group: GroupUpdate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Update a group's name and/or sort order.

    Raises 404 if the group is missing and 409 when renaming to a name
    already taken by another group. Only fields present in the request
    body are applied.
    """
    db_group = db.query(Group).filter(Group.id == group_id).first()
    if db_group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    wants_rename = group.name is not None and group.name != db_group.name
    if wants_rename:
        clash = db.query(Group).filter(Group.name == group.name).first()
        if clash is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Group with this name already exists",
            )
    # Apply only the fields the client actually sent.
    for field, value in group.model_dump(exclude_unset=True).items():
        setattr(db_group, field, value)
    db.commit()
    db.refresh(db_group)
    return db_group
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_groups(
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete every group without channels; groups that still own
    channels are kept. Returns counts of deleted and skipped groups.
    """
    deleted, skipped = 0, 0
    for candidate in db.query(Group).all():
        if candidate.channels:
            skipped += 1
        else:
            db.delete(candidate)
            deleted += 1
    db.commit()
    return {"deleted": deleted, "skipped": skipped}
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin")
def delete_group(
    group_id: UUID,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Delete a single group.

    Raises 404 when the group does not exist, and 400 when it still has
    channels attached.
    """
    target = db.query(Group).filter(Group.id == group_id).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    if target.channels:
        # Refuse to orphan channels -- the group must be emptied first.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot delete group with existing channels",
        )
    db.delete(target)
    db.commit()
    return None
@router.get("/", response_model=list[GroupResponse])
def list_groups(db: Session = Depends(get_db)):
    """Return all groups, ordered by their sort_order."""
    ordered = db.query(Group).order_by(Group.sort_order)
    return ordered.all()
@router.put("/{group_id}/sort", response_model=GroupResponse)
@require_roles("admin")
def update_group_sort_order(
    group_id: UUID,
    sort_update: GroupSortUpdate,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Set the sort order for one group (404 when the group is missing)."""
    target = db.query(Group).filter(Group.id == group_id).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Group not found"
        )
    target.sort_order = sort_update.sort_order
    db.commit()
    db.refresh(target)
    return target
@router.post("/reorder", response_model=list[GroupResponse])
@require_roles("admin")
def bulk_update_sort_orders(
    bulk_sort: GroupBulkSort,
    db: Session = Depends(get_db),
    user: CognitoUser = Depends(get_current_user),
):
    """Bulk update group sort orders.

    Each entry in ``bulk_sort.groups`` must carry ``group_id`` and
    ``sort_order``. A malformed entry yields a 422, an unknown id a 404;
    on success all groups are returned in their new order.
    """
    for group_data in bulk_sort.groups:
        # GroupBulkSort only guarantees a list of dicts, so validate the
        # required keys explicitly instead of letting a KeyError surface
        # as a 500 to the client.
        if "group_id" not in group_data or "sort_order" not in group_data:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="Each entry must include group_id and sort_order",
            )
        group_id = group_data["group_id"]
        group = db.query(Group).filter(Group.id == str(group_id)).first()
        if not group:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Group with id {group_id} not found",
            )
        group.sort_order = group_data["sort_order"]
    db.commit()
    # Return all groups in their new order
    return db.query(Group).order_by(Group.sort_order).all()

View File

@@ -1,17 +1,156 @@
from fastapi import APIRouter, Depends import logging
from enum import Enum
from typing import Optional
from uuid import uuid4
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user from app.auth.dependencies import get_current_user
from app.iptv.stream_manager import StreamManager
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
from app.utils.database import get_db_session
router = APIRouter( router = APIRouter(prefix="/playlist", tags=["playlist"])
prefix="/playlist", logger = logging.getLogger(__name__)
tags=["playlist"]
# In-memory store for validation processes
validation_processes: dict[str, dict] = {}
class ProcessStatus(str, Enum):
    """Lifecycle states of a background stream-validation process."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
class StreamValidationRequest(BaseModel):
    """Request model for stream validation endpoint"""
    # Validate a single channel when set; None requests all channels
    # (all-channel validation is not implemented yet in the worker).
    channel_id: Optional[str] = None
class ValidatedStream(BaseModel):
    """Model for a validated working stream"""
    # Channel the stream belongs to, and the URL that passed validation.
    channel_id: str
    stream_url: str
class ValidationProcessResponse(BaseModel):
    """Response model for validation process initiation"""
    # Identifier used to poll GET /playlist/validate-streams/{process_id}.
    process_id: str
    status: ProcessStatus
    message: str
class ValidationResultResponse(BaseModel):
    """Response model for validation results"""
    process_id: str
    status: ProcessStatus
    # Present only when validation produced working streams.
    working_streams: Optional[list[ValidatedStream]] = None
    # Present when the process failed or found no working streams.
    error: Optional[str] = None
def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session):
    """Background task that validates streams and records the outcome.

    Progress and results are written into the module-level
    ``validation_processes`` entry keyed by ``process_id``.
    """
    entry = validation_processes[process_id]
    try:
        entry["status"] = ProcessStatus.IN_PROGRESS
        manager = StreamManager(db)
        if not channel_id:
            # TODO: Implement validation for all channels
            entry["error"] = "Validation of all channels not yet implemented"
        else:
            stream_url = manager.validate_and_select_stream(channel_id)
            if stream_url:
                working = [
                    ValidatedStream(channel_id=channel_id, stream_url=stream_url)
                ]
                entry["result"] = {"working_streams": working}
            else:
                entry["error"] = (
                    f"No working streams found for channel {channel_id}"
                )
        entry["status"] = ProcessStatus.COMPLETED
    except Exception as e:
        logger.error(f"Error validating streams: {str(e)}")
        entry["status"] = ProcessStatus.FAILED
        entry["error"] = str(e)
@router.post(
    "/validate-streams",
    summary="Start stream validation process",
    response_model=ValidationProcessResponse,
    status_code=status.HTTP_202_ACCEPTED,
    responses={202: {"description": "Validation process started successfully"}},
)
async def start_stream_validation(
    request: StreamValidationRequest,
    background_tasks: BackgroundTasks,
    user: CognitoUser = Depends(get_current_user),
    db: Session = Depends(get_db_session),
):
    """
    Start asynchronous validation of streams.

    - Returns immediately with a process ID
    - Use GET /validate-streams/{process_id} to check status
    """
    # NOTE(review): the session from get_db_session is handed to the
    # background task and never explicitly closed — confirm session
    # lifecycle is handled inside the task or by pool recycling.
    process_id = str(uuid4())
    validation_processes[process_id] = {
        "status": ProcessStatus.PENDING,
        "channel_id": request.channel_id,
    }
    background_tasks.add_task(
        run_stream_validation, process_id, request.channel_id, db
    )
    return {
        "process_id": process_id,
        "status": ProcessStatus.PENDING,
        "message": "Validation process started",
    }
@router.get(
    "/validate-streams/{process_id}",
    summary="Check validation process status",
    response_model=ValidationResultResponse,
    responses={
        200: {"description": "Process status and results"},
        404: {"description": "Process not found"},
    },
)
async def get_validation_status(
    process_id: str, user: CognitoUser = Depends(get_current_user)
):
    """
    Check status of a stream validation process.

    Returns current status and results if completed.
    """
    process = validation_processes.get(process_id)
    if process is None:
        raise HTTPException(status_code=404, detail="Process not found")

    current_status = process["status"]
    response = {"process_id": process_id, "status": current_status}
    if current_status == ProcessStatus.COMPLETED:
        # A completed run carries either an error message or results.
        if "error" in process:
            response["error"] = process["error"]
        else:
            response["working_streams"] = process["result"]["working_streams"]
    elif current_status == ProcessStatus.FAILED:
        response["error"] = process["error"]
    return response

View File

@@ -1,25 +1,22 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import select, delete
from typing import List
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
from app.models.db import Priority from app.models.db import Priority
from app.models.schemas import PriorityCreate, PriorityResponse from app.models.schemas import PriorityCreate, PriorityResponse
from app.utils.database import get_db from app.utils.database import get_db
from app.auth.dependencies import get_current_user, require_roles
from app.models.auth import CognitoUser
router = APIRouter( router = APIRouter(prefix="/priorities", tags=["priorities"])
prefix="/priorities",
tags=["priorities"]
)
@router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=PriorityResponse, status_code=status.HTTP_201_CREATED)
@require_roles("admin") @require_roles("admin")
def create_priority( def create_priority(
priority: PriorityCreate, priority: PriorityCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Create a new priority""" """Create a new priority"""
# Check if priority with this ID already exists # Check if priority with this ID already exists
@@ -27,71 +24,97 @@ def create_priority(
if existing: if existing:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail=f"Priority with ID {priority.id} already exists" detail=f"Priority with ID {priority.id} already exists",
) )
db_priority = Priority(**priority.model_dump()) db_priority = Priority(**priority.model_dump())
db.add(db_priority) db.add(db_priority)
db.commit() db.commit()
db.refresh(db_priority) db.refresh(db_priority)
return db_priority return db_priority
@router.get("/", response_model=List[PriorityResponse])
@router.get("/", response_model=list[PriorityResponse])
@require_roles("admin") @require_roles("admin")
def list_priorities( def list_priorities(
db: Session = Depends(get_db), db: Session = Depends(get_db), user: CognitoUser = Depends(get_current_user)
user: CognitoUser = Depends(get_current_user)
): ):
"""List all priorities""" """List all priorities"""
return db.query(Priority).all() return db.query(Priority).all()
@router.get("/{priority_id}", response_model=PriorityResponse) @router.get("/{priority_id}", response_model=PriorityResponse)
@require_roles("admin") @require_roles("admin")
def get_priority( def get_priority(
priority_id: int, priority_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Get a priority by id""" """Get a priority by id"""
priority = db.get(Priority, priority_id) priority = db.get(Priority, priority_id)
if not priority: if not priority:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
detail="Priority not found"
) )
return priority return priority
@router.delete("/", status_code=status.HTTP_200_OK)
@require_roles("admin")
def delete_priorities(
db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user),
):
"""Delete all priorities not in use by channel URLs"""
from app.models.db import ChannelURL
priorities = db.query(Priority).all()
deleted = 0
skipped = 0
for priority in priorities:
in_use = db.scalar(
select(ChannelURL).where(ChannelURL.priority_id == priority.id).limit(1)
)
if not in_use:
db.delete(priority)
deleted += 1
else:
skipped += 1
db.commit()
return {"deleted": deleted, "skipped": skipped}
@router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{priority_id}", status_code=status.HTTP_204_NO_CONTENT)
@require_roles("admin") @require_roles("admin")
def delete_priority( def delete_priority(
priority_id: int, priority_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: CognitoUser = Depends(get_current_user) user: CognitoUser = Depends(get_current_user),
): ):
"""Delete a priority (if not in use)""" """Delete a priority (if not in use)"""
from app.models.db import ChannelURL from app.models.db import ChannelURL
# Check if priority exists # Check if priority exists
priority = db.get(Priority, priority_id) priority = db.get(Priority, priority_id)
if not priority: if not priority:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Priority not found"
detail="Priority not found"
) )
# Check if priority is in use # Check if priority is in use
in_use = db.scalar( in_use = db.scalar(
select(ChannelURL) select(ChannelURL).where(ChannelURL.priority_id == priority_id).limit(1)
.where(ChannelURL.priority_id == priority_id)
.limit(1)
) )
if in_use: if in_use:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete priority that is in use by channel URLs" detail="Cannot delete priority that is in use by channel URLs",
) )
db.execute(delete(Priority).where(Priority.id == priority_id)) db.execute(delete(Priority).where(Priority.id == priority_id))
db.commit() db.commit()
return None return None

57
app/routers/scheduler.py Normal file
View File

@@ -0,0 +1,57 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user, require_roles
from app.iptv.scheduler import StreamScheduler
from app.models.auth import CognitoUser
from app.utils.database import get_db
router = APIRouter(
prefix="/scheduler",
tags=["scheduler"],
responses={404: {"description": "Not found"}},
)
async def get_scheduler(request: Request) -> StreamScheduler:
    """Get the scheduler instance from the app state.

    Raises:
        HTTPException: 500 when no scheduler was attached to ``app.state``
            (e.g. the startup hook did not run) or the attached object has
            no underlying APScheduler instance.
    """
    # The original read request.app.state.scheduler before checking it
    # existed, so a missing attribute raised an unhandled AttributeError
    # instead of the intended 500 response.
    scheduler = getattr(request.app.state, "scheduler", None)
    if scheduler is None or not hasattr(scheduler, "scheduler"):
        raise HTTPException(status_code=500, detail="Scheduler not initialized")
    return scheduler
@router.get("/health")
@require_roles("admin")
def scheduler_health(
scheduler: StreamScheduler = Depends(get_scheduler),
user: CognitoUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Check scheduler health status (admin only)."""
try:
job = scheduler.scheduler.get_job("daily_stream_validation")
next_run = str(job.next_run_time) if job and job.next_run_time else None
return {
"status": "running" if scheduler.scheduler.running else "stopped",
"next_run": next_run,
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to check scheduler health: {str(e)}"
)
@router.post("/trigger")
@require_roles("admin")
def trigger_validation(
scheduler: StreamScheduler = Depends(get_scheduler),
user: CognitoUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Manually trigger stream validation (admin only)."""
scheduler.trigger_manual_validation()
return JSONResponse(
status_code=202, content={"message": "Stream validation triggered"}
)

View File

@@ -2,11 +2,13 @@ import base64
import hashlib import hashlib
import hmac import hmac
def calculate_secret_hash(username: str, client_id: str, client_secret: str) -> str:
    """
    Calculate the Cognito SECRET_HASH using HMAC SHA256 for secret-enabled clients.
    """
    # Cognito requires HMAC-SHA256 over (username + client_id), keyed with
    # the app client secret, base64-encoded.
    message = (username + client_id).encode("utf-8")
    digest = hmac.new(client_secret.encode("utf-8"), message, hashlib.sha256)
    return base64.b64encode(digest.digest()).decode()

View File

@@ -1,41 +1,50 @@
import os
import argparse import argparse
import requests
import logging import logging
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError import os
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
class StreamValidator: class StreamValidator:
def __init__(self, timeout=10, user_agent=None): def __init__(self, timeout=10, user_agent=None):
self.timeout = timeout self.timeout = timeout
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({ self.session.headers.update(
'User-Agent': user_agent or 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36' {
}) "User-Agent": user_agent
or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
)
}
)
def validate_stream(self, url): def validate_stream(self, url):
"""Validate a media stream URL with multiple fallback checks""" """Validate a media stream URL with multiple fallback checks"""
try: try:
headers = {'Range': 'bytes=0-1024'} headers = {"Range": "bytes=0-1024"}
with self.session.get( with self.session.get(
url, url,
headers=headers, headers=headers,
timeout=self.timeout, timeout=self.timeout,
stream=True, stream=True,
allow_redirects=True allow_redirects=True,
) as response: ) as response:
if response.status_code not in [200, 206]: if response.status_code not in [200, 206]:
return False, f"Invalid status code: {response.status_code}" return False, f"Invalid status code: {response.status_code}"
content_type = response.headers.get('Content-Type', '') content_type = response.headers.get("Content-Type", "")
if not self._is_valid_content_type(content_type): if not self._is_valid_content_type(content_type):
return False, f"Invalid content type: {content_type}" return False, f"Invalid content type: {content_type}"
try: try:
next(response.iter_content(chunk_size=1024)) next(response.iter_content(chunk_size=1024))
return True, "Stream is valid" return True, "Stream is valid"
except (ConnectionError, Timeout): except (ConnectionError, Timeout):
return False, "Connection failed during content read" return False, "Connection failed during content read"
except HTTPError as e: except HTTPError as e:
return False, f"HTTP Error: {str(e)}" return False, f"HTTP Error: {str(e)}"
except ConnectionError as e: except ConnectionError as e:
@@ -49,56 +58,59 @@ class StreamValidator:
def _is_valid_content_type(self, content_type): def _is_valid_content_type(self, content_type):
valid_types = [ valid_types = [
'video/mp2t', 'application/vnd.apple.mpegurl', "video/mp2t",
'application/dash+xml', 'video/mp4', "application/vnd.apple.mpegurl",
'video/webm', 'application/octet-stream', "application/dash+xml",
'application/x-mpegURL' "video/mp4",
"video/webm",
"application/octet-stream",
"application/x-mpegURL",
] ]
if content_type is None:
return False
return any(ct in content_type for ct in valid_types) return any(ct in content_type for ct in valid_types)
def parse_playlist(self, file_path): def parse_playlist(self, file_path):
"""Extract stream URLs from M3U playlist file""" """Extract stream URLs from M3U playlist file"""
urls = [] urls = []
try: try:
with open(file_path, 'r') as f: with open(file_path) as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if line and not line.startswith('#'): if line and not line.startswith("#"):
urls.append(line) urls.append(line)
except Exception as e: except Exception as e:
logging.error(f"Error reading playlist file: {str(e)}") logging.error(f"Error reading playlist file: {str(e)}")
raise raise
return urls return urls
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Validate streaming URLs from command line arguments or playlist files', description=(
formatter_class=argparse.ArgumentDefaultsHelpFormatter "Validate streaming URLs from command line arguments or playlist files"
),
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument( parser.add_argument(
'sources', "sources", nargs="+", help="List of URLs or file paths containing stream URLs"
nargs='+',
help='List of URLs or file paths containing stream URLs'
) )
parser.add_argument( parser.add_argument(
'--timeout', "--timeout", type=int, default=20, help="Timeout in seconds for stream checks"
type=int,
default=20,
help='Timeout in seconds for stream checks'
) )
parser.add_argument( parser.add_argument(
'--output', "--output",
default='deadstreams.txt', default="deadstreams.txt",
help='Output file name for inactive streams' help="Output file name for inactive streams",
) )
args = parser.parse_args() args = parser.parse_args()
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler('stream_check.log'), logging.StreamHandler()] handlers=[logging.FileHandler("stream_check.log"), logging.StreamHandler()],
) )
validator = StreamValidator(timeout=args.timeout) validator = StreamValidator(timeout=args.timeout)
@@ -127,9 +139,10 @@ def main():
# Save dead streams to file # Save dead streams to file
if dead_streams: if dead_streams:
with open(args.output, 'w') as f: with open(args.output, "w") as f:
f.write('\n'.join(dead_streams)) f.write("\n".join(dead_streams))
logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.") logging.info(f"Found {len(dead_streams)} dead streams. Saved to {args.output}.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -21,4 +21,4 @@ IPTV_SERVER_ADMIN_USER = os.getenv("IPTV_SERVER_ADMIN_USER", "admin")
IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword") IPTV_SERVER_ADMIN_PASSWORD = os.getenv("IPTV_SERVER_ADMIN_PASSWORD", "adminpassword")
# URL for the EPG XML file to place in the playlist's header # URL for the EPG XML file to place in the playlist's header
EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz") EPG_URL = os.getenv("EPG_URL", "https://example.com/epg.xml.gz")

View File

@@ -1,11 +1,14 @@
import os import os
import boto3 import boto3
from app.models import Base from requests import Session
from .constants import AWS_REGION
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from functools import lru_cache
from app.models import Base
from .constants import AWS_REGION
def get_db_credentials(): def get_db_credentials():
"""Fetch and cache DB credentials from environment or SSM Parameter Store""" """Fetch and cache DB credentials from environment or SSM Parameter Store"""
@@ -14,29 +17,45 @@ def get_db_credentials():
f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}" f"postgresql://{os.getenv('DB_USER')}:{os.getenv('DB_PASSWORD')}"
f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}" f"@{os.getenv('DB_HOST')}/{os.getenv('DB_NAME')}"
) )
ssm = boto3.client('ssm', region_name=AWS_REGION) ssm = boto3.client("ssm", region_name=AWS_REGION)
try: try:
host = ssm.get_parameter(Name='/iptv-updater/DB_HOST', WithDecryption=True)['Parameter']['Value'] host = ssm.get_parameter(Name="/iptv-manager/DB_HOST", WithDecryption=True)[
user = ssm.get_parameter(Name='/iptv-updater/DB_USER', WithDecryption=True)['Parameter']['Value'] "Parameter"
password = ssm.get_parameter(Name='/iptv-updater/DB_PASSWORD', WithDecryption=True)['Parameter']['Value'] ]["Value"]
dbname = ssm.get_parameter(Name='/iptv-updater/DB_NAME', WithDecryption=True)['Parameter']['Value'] user = ssm.get_parameter(Name="/iptv-manager/DB_USER", WithDecryption=True)[
"Parameter"
]["Value"]
password = ssm.get_parameter(
Name="/iptv-manager/DB_PASSWORD", WithDecryption=True
)["Parameter"]["Value"]
dbname = ssm.get_parameter(Name="/iptv-manager/DB_NAME", WithDecryption=True)[
"Parameter"
]["Value"]
return f"postgresql://{user}:{password}@{host}/{dbname}" return f"postgresql://{user}:{password}@{host}/{dbname}"
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}") raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
# Initialize engine and session maker # Initialize engine and session maker
engine = create_engine(get_db_credentials()) engine = create_engine(get_db_credentials())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db(): def init_db():
"""Initialize database by creating all tables""" """Initialize database by creating all tables"""
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
def get_db():
    """Dependency for getting database session"""
    # Yield a session for the request and guarantee it is closed
    # afterwards, even when the request handler raises.
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def get_db_session():
    """Get a direct database session (non-generator version).

    The caller owns the returned session and must call ``.close()`` on it.

    NOTE(review): the original annotated the return type as ``Session``,
    which at module scope resolves to ``requests.Session`` (imported at the
    top of this file) — the wrong class; ``SessionLocal()`` produces a
    ``sqlalchemy.orm.Session``. The annotation is dropped here rather than
    fixed because the import block lies outside this edit.
    """
    return SessionLocal()

View File

@@ -3,10 +3,11 @@ version: '3.8'
services: services:
postgres: postgres:
image: postgres:13 image: postgres:13
container_name: postgres
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: iptv_updater POSTGRES_DB: iptv_manager
ports: ports:
- "5432:5432" - "5432:5432"
volumes: volumes:

View File

@@ -6,7 +6,7 @@ services:
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: iptv_updater POSTGRES_DB: iptv_manager
ports: ports:
- "5432:5432" - "5432:5432"
volumes: volumes:
@@ -20,7 +20,7 @@ services:
DB_USER: postgres DB_USER: postgres
DB_PASSWORD: postgres DB_PASSWORD: postgres
DB_HOST: postgres DB_HOST: postgres
DB_NAME: iptv_updater DB_NAME: iptv_manager
MOCK_AUTH: "true" MOCK_AUTH: "true"
ports: ports:
- "8000:8000" - "8000:8000"

View File

@@ -1,94 +1,84 @@
import os import os
from aws_cdk import (
Duration, from aws_cdk import CfnOutput, Duration, RemovalPolicy, Stack
RemovalPolicy, from aws_cdk import aws_cognito as cognito
Stack, from aws_cdk import aws_ec2 as ec2
aws_ec2 as ec2, from aws_cdk import aws_iam as iam
aws_iam as iam, from aws_cdk import aws_rds as rds
aws_cognito as cognito, from aws_cdk import aws_ssm as ssm
aws_rds as rds,
aws_ssm as ssm,
CfnOutput
)
from constructs import Construct from constructs import Construct
class IptvUpdaterStack(Stack):
class IptvManagerStack(Stack):
def __init__( def __init__(
self, self,
scope: Construct, scope: Construct,
construct_id: str, construct_id: str,
freedns_user: str, freedns_user: str,
freedns_password: str, freedns_password: str,
domain_name: str, domain_name: str,
ssh_public_key: str, ssh_public_key: str,
repo_url: str, repo_url: str,
letsencrypt_email: str, letsencrypt_email: str,
**kwargs **kwargs,
) -> None: ) -> None:
super().__init__(scope, construct_id, **kwargs) super().__init__(scope, construct_id, **kwargs)
# Create VPC # Create VPC
vpc = ec2.Vpc(self, "IptvUpdaterVPC", vpc = ec2.Vpc(
self,
"IptvManagerVPC",
max_azs=2, # Need at least 2 AZs for RDS subnet group max_azs=2, # Need at least 2 AZs for RDS subnet group
nat_gateways=0, # No NAT Gateway to stay in free tier nat_gateways=0, # No NAT Gateway to stay in free tier
subnet_configuration=[ subnet_configuration=[
ec2.SubnetConfiguration( ec2.SubnetConfiguration(
name="public", name="public", subnet_type=ec2.SubnetType.PUBLIC, cidr_mask=24
subnet_type=ec2.SubnetType.PUBLIC,
cidr_mask=24
), ),
ec2.SubnetConfiguration( ec2.SubnetConfiguration(
name="private", name="private",
subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
cidr_mask=24 cidr_mask=24,
) ),
] ],
) )
# Security Group # Security Group
security_group = ec2.SecurityGroup( security_group = ec2.SecurityGroup(
self, "IptvUpdaterSG", self, "IptvManagerSG", vpc=vpc, allow_all_outbound=True
vpc=vpc,
allow_all_outbound=True
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(443), "Allow HTTPS traffic"
ec2.Port.tcp(443),
"Allow HTTPS traffic"
)
security_group.add_ingress_rule(
ec2.Peer.any_ipv4(),
ec2.Port.tcp(80),
"Allow HTTP traffic"
) )
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Peer.any_ipv4(), ec2.Port.tcp(80), "Allow HTTP traffic"
ec2.Port.tcp(22), )
"Allow SSH traffic"
security_group.add_ingress_rule(
ec2.Peer.any_ipv4(), ec2.Port.tcp(22), "Allow SSH traffic"
) )
# Allow PostgreSQL port for tunneling restricted to developer IP # Allow PostgreSQL port for tunneling restricted to developer IP
security_group.add_ingress_rule( security_group.add_ingress_rule(
ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP ec2.Peer.ipv4("47.189.88.48/32"), # Developer IP
ec2.Port.tcp(5432), ec2.Port.tcp(5432),
"Allow PostgreSQL traffic for tunneling" "Allow PostgreSQL traffic for tunneling",
) )
# Key pair for IPTV Updater instance # Key pair for IPTV Manager instance
key_pair = ec2.KeyPair( key_pair = ec2.KeyPair(
self, self,
"IptvUpdaterKeyPair", "IptvManagerKeyPair",
key_pair_name="iptv-updater-key", key_pair_name="iptv-manager-key",
public_key_material=ssh_public_key public_key_material=ssh_public_key,
) )
# Create IAM role for EC2 # Create IAM role for EC2
role = iam.Role( role = iam.Role(
self, "IptvUpdaterRole", self,
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com") "IptvManagerRole",
assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
) )
# Add SSM managed policy # Add SSM managed policy
@@ -99,96 +89,66 @@ class IptvUpdaterStack(Stack):
) )
# Add EC2 describe permissions # Add EC2 describe permissions
role.add_to_policy(iam.PolicyStatement( role.add_to_policy(
actions=["ec2:DescribeInstances"], iam.PolicyStatement(actions=["ec2:DescribeInstances"], resources=["*"])
resources=["*"] )
))
# Add SSM SendCommand permissions # Add SSM SendCommand permissions
role.add_to_policy(iam.PolicyStatement( role.add_to_policy(
actions=["ssm:SendCommand"], iam.PolicyStatement(
resources=[ actions=["ssm:SendCommand"],
f"arn:aws:ec2:{self.region}:{self.account}:instance/*", # Allow on all EC2 instances resources=[
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript" # Required for the RunShellScript document # Allow on all EC2 instances
] f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
)) # Required for the RunShellScript document
f"arn:aws:ssm:{self.region}:{self.account}:document/AWS-RunShellScript",
# Add Cognito permissions to instance role ],
role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name(
"AmazonCognitoReadOnly"
) )
) )
# EC2 Instance # Add Cognito permissions to instance role
instance = ec2.Instance( role.add_managed_policy(
self, "IptvUpdaterInstance", iam.ManagedPolicy.from_aws_managed_policy_name("AmazonCognitoReadOnly")
vpc=vpc,
vpc_subnets=ec2.SubnetSelection(
subnet_type=ec2.SubnetType.PUBLIC
),
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T2,
ec2.InstanceSize.MICRO
),
machine_image=ec2.AmazonLinuxImage(
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
),
security_group=security_group,
key_pair=key_pair,
role=role,
# Option: 1: Enable auto-assign public IP (free tier compatible)
associate_public_ip_address=True
) )
# Option: 2: Create Elastic IP (not free tier compatible)
# eip = ec2.CfnEIP(
# self, "IptvUpdaterEIP",
# domain="vpc",
# instance_id=instance.instance_id
# )
# Add Cognito User Pool # Add Cognito User Pool
user_pool = cognito.UserPool( user_pool = cognito.UserPool(
self, "IptvUpdaterUserPool", self,
user_pool_name="iptv-updater-users", "IptvManagerUserPool",
user_pool_name="iptv-manager-users",
self_sign_up_enabled=False, # Only admins can create users self_sign_up_enabled=False, # Only admins can create users
password_policy=cognito.PasswordPolicy( password_policy=cognito.PasswordPolicy(
min_length=8, min_length=8,
require_lowercase=True, require_lowercase=True,
require_digits=True, require_digits=True,
require_symbols=True, require_symbols=True,
require_uppercase=True require_uppercase=True,
), ),
account_recovery=cognito.AccountRecovery.EMAIL_ONLY, account_recovery=cognito.AccountRecovery.EMAIL_ONLY,
removal_policy=RemovalPolicy.DESTROY removal_policy=RemovalPolicy.DESTROY,
) )
# Add App Client with the correct callback URL # Add App Client with the correct callback URL
client = user_pool.add_client("IptvUpdaterClient", client = user_pool.add_client(
"IptvManagerClient",
access_token_validity=Duration.minutes(60), access_token_validity=Duration.minutes(60),
id_token_validity=Duration.minutes(60), id_token_validity=Duration.minutes(60),
refresh_token_validity=Duration.days(1), refresh_token_validity=Duration.days(1),
auth_flows=cognito.AuthFlow( auth_flows=cognito.AuthFlow(user_password=True),
user_password=True
),
o_auth=cognito.OAuthSettings( o_auth=cognito.OAuthSettings(
flows=cognito.OAuthFlows( flows=cognito.OAuthFlows(implicit_code_grant=True)
implicit_code_grant=True
)
), ),
prevent_user_existence_errors=True, prevent_user_existence_errors=True,
generate_secret=True, generate_secret=True,
enable_token_revocation=True enable_token_revocation=True,
) )
# Add domain for hosted UI # Add domain for hosted UI
domain = user_pool.add_domain("IptvUpdaterDomain", domain = user_pool.add_domain(
cognito_domain=cognito.CognitoDomainOptions( "IptvManagerDomain",
domain_prefix="iptv-updater" cognito_domain=cognito.CognitoDomainOptions(domain_prefix="iptv-manager"),
)
) )
# Read the userdata script with proper path resolution # Read the userdata script with proper path resolution
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
userdata_path = os.path.join(script_dir, "userdata.sh") userdata_path = os.path.join(script_dir, "userdata.sh")
@@ -196,46 +156,56 @@ class IptvUpdaterStack(Stack):
# Creates a userdata object for Linux hosts # Creates a userdata object for Linux hosts
userdata = ec2.UserData.for_linux() userdata = ec2.UserData.for_linux()
# Add environment variables for acme.sh from parameters # Add environment variables for acme.sh from parameters
userdata.add_commands( userdata.add_commands(
f'export FREEDNS_User="{freedns_user}"', f'export FREEDNS_User="{freedns_user}"',
f'export FREEDNS_Password="{freedns_password}"', f'export FREEDNS_Password="{freedns_password}"',
f'export DOMAIN_NAME="{domain_name}"', f'export DOMAIN_NAME="{domain_name}"',
f'export REPO_URL="{repo_url}"', f'export REPO_URL="{repo_url}"',
f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"' f'export LETSENCRYPT_EMAIL="{letsencrypt_email}"',
) )
# Adds one or more commands to the userdata object. # Adds one or more commands to the userdata object.
userdata.add_commands( userdata.add_commands(
f'echo "COGNITO_USER_POOL_ID={user_pool.user_pool_id}" >> /etc/environment', (
f'echo "COGNITO_CLIENT_ID={client.user_pool_client_id}" >> /etc/environment', f'echo "COGNITO_USER_POOL_ID='
f'echo "COGNITO_CLIENT_SECRET={client.user_pool_client_secret.to_string()}" >> /etc/environment', f'{user_pool.user_pool_id}" >> /etc/environment'
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment' ),
(
f'echo "COGNITO_CLIENT_ID='
f'{client.user_pool_client_id}" >> /etc/environment'
),
(
f'echo "COGNITO_CLIENT_SECRET='
f'{client.user_pool_client_secret.to_string()}" >> /etc/environment'
),
f'echo "DOMAIN_NAME={domain_name}" >> /etc/environment',
) )
userdata.add_commands(str(userdata_file, 'utf-8')) userdata.add_commands(str(userdata_file, "utf-8"))
# Create RDS Security Group # Create RDS Security Group
rds_sg = ec2.SecurityGroup( rds_sg = ec2.SecurityGroup(
self, "RdsSecurityGroup", self,
"RdsSecurityGroup",
vpc=vpc, vpc=vpc,
description="Security group for RDS PostgreSQL" description="Security group for RDS PostgreSQL",
) )
rds_sg.add_ingress_rule( rds_sg.add_ingress_rule(
security_group, security_group,
ec2.Port.tcp(5432), ec2.Port.tcp(5432),
"Allow PostgreSQL access from EC2 instance" "Allow PostgreSQL access from EC2 instance",
) )
# Create RDS PostgreSQL instance (free tier compatible - db.t3.micro) # Create RDS PostgreSQL instance (free tier compatible - db.t3.micro)
db = rds.DatabaseInstance( db = rds.DatabaseInstance(
self, "IptvUpdaterDB", self,
"IptvManagerDB",
engine=rds.DatabaseInstanceEngine.postgres( engine=rds.DatabaseInstanceEngine.postgres(
version=rds.PostgresEngineVersion.VER_13 version=rds.PostgresEngineVersion.VER_13
), ),
instance_type=ec2.InstanceType.of( instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T3, ec2.InstanceClass.T3, ec2.InstanceSize.MICRO
ec2.InstanceSize.MICRO
), ),
vpc=vpc, vpc=vpc,
vpc_subnets=ec2.SubnetSelection( vpc_subnets=ec2.SubnetSelection(
@@ -244,44 +214,81 @@ class IptvUpdaterStack(Stack):
security_groups=[rds_sg], security_groups=[rds_sg],
allocated_storage=10, allocated_storage=10,
max_allocated_storage=10, max_allocated_storage=10,
database_name="iptv_updater", database_name="iptv_manager",
removal_policy=RemovalPolicy.DESTROY, removal_policy=RemovalPolicy.DESTROY,
deletion_protection=False, deletion_protection=False,
publicly_accessible=False # Avoid public IPv4 charges publicly_accessible=False, # Avoid public IPv4 charges
) )
# Add RDS permissions to instance role # Add RDS permissions to instance role
role.add_managed_policy( role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name( iam.ManagedPolicy.from_aws_managed_policy_name("AmazonRDSFullAccess")
"AmazonRDSFullAccess"
)
) )
# Store DB connection info in SSM Parameter Store # Store DB connection info in SSM Parameter Store
ssm.StringParameter(self, "DBHostParam", db_host_param = ssm.StringParameter(
parameter_name="/iptv-updater/DB_HOST", self,
string_value=db.db_instance_endpoint_address "DBHostParam",
parameter_name="/iptv-manager/DB_HOST",
string_value=db.db_instance_endpoint_address,
) )
ssm.StringParameter(self, "DBNameParam", db_name_param = ssm.StringParameter(
parameter_name="/iptv-updater/DB_NAME", self,
string_value="iptv_updater" "DBNameParam",
parameter_name="/iptv-manager/DB_NAME",
string_value="iptv_manager",
) )
ssm.StringParameter(self, "DBUserParam", db_user_param = ssm.StringParameter(
parameter_name="/iptv-updater/DB_USER", self,
string_value=db.secret.secret_value_from_json("username").to_string() "DBUserParam",
parameter_name="/iptv-manager/DB_USER",
string_value=db.secret.secret_value_from_json("username").to_string(),
) )
ssm.StringParameter(self, "DBPassParam", db_pass_param = ssm.StringParameter(
parameter_name="/iptv-updater/DB_PASSWORD", self,
string_value=db.secret.secret_value_from_json("password").to_string() "DBPassParam",
parameter_name="/iptv-manager/DB_PASSWORD",
string_value=db.secret.secret_value_from_json("password").to_string(),
) )
# Add SSM read permissions to instance role # Add SSM read permissions to instance role
role.add_managed_policy( role.add_managed_policy(
iam.ManagedPolicy.from_aws_managed_policy_name( iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSSMReadOnlyAccess")
"AmazonSSMReadOnlyAccess"
)
) )
# EC2 Instance (created after all dependencies are ready)
instance = ec2.Instance(
self,
"IptvManagerInstance",
vpc=vpc,
vpc_subnets=ec2.SubnetSelection(subnet_type=ec2.SubnetType.PUBLIC),
instance_type=ec2.InstanceType.of(
ec2.InstanceClass.T2, ec2.InstanceSize.MICRO
),
machine_image=ec2.AmazonLinuxImage(
generation=ec2.AmazonLinuxGeneration.AMAZON_LINUX_2023
),
security_group=security_group,
key_pair=key_pair,
role=role,
# Option: 1: Enable auto-assign public IP (free tier compatible)
associate_public_ip_address=True,
)
# Ensure instance depends on SSM parameters being created
instance.node.add_dependency(db)
instance.node.add_dependency(db_host_param)
instance.node.add_dependency(db_name_param)
instance.node.add_dependency(db_user_param)
instance.node.add_dependency(db_pass_param)
# Option: 2: Create Elastic IP (not free tier compatible)
# eip = ec2.CfnEIP(
# self, "IptvManagerEIP",
# domain="vpc",
# instance_id=instance.instance_id
# )
# Update instance with userdata # Update instance with userdata
instance.add_user_data(userdata.render()) instance.add_user_data(userdata.render())
@@ -293,6 +300,8 @@ class IptvUpdaterStack(Stack):
# CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip) # CfnOutput(self, "InstancePublicIP", value=eip.attr_public_ip)
CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id) CfnOutput(self, "UserPoolId", value=user_pool.user_pool_id)
CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id) CfnOutput(self, "UserPoolClientId", value=client.user_pool_client_id)
CfnOutput(self, "CognitoDomainUrl", CfnOutput(
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com" self,
) "CognitoDomainUrl",
value=f"https://{domain.domain_name}.auth.{self.region}.amazoncognito.com",
)

View File

@@ -2,7 +2,7 @@
# Update system and install required packages # Update system and install required packages
dnf update -y dnf update -y
dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx dnf install -y python3-pip git cronie nginx certbot python3-certbot-nginx postgresql15.x86_64 awscli
# Start and enable crond service # Start and enable crond service
systemctl start crond systemctl start crond
@@ -11,27 +11,69 @@ systemctl enable crond
cd /home/ec2-user cd /home/ec2-user
git clone ${REPO_URL} git clone ${REPO_URL}
cd iptv-updater-aws cd iptv-manager-service
# Install Python packages with --ignore-installed to prevent conflicts with RPM packages # Install Python packages with --ignore-installed to prevent conflicts with RPM packages
pip3 install --ignore-installed -r requirements.txt pip3 install --ignore-installed -r requirements.txt
# Retrieve DB credentials from SSM Parameter Store with retries
echo "Attempting to retrieve DB credentials from SSM..."
for i in {1..30}; do
DB_HOST=$(aws ssm get-parameter --name "/iptv-manager/DB_HOST" --query "Parameter.Value" --output text 2>/dev/null)
DB_NAME=$(aws ssm get-parameter --name "/iptv-manager/DB_NAME" --query "Parameter.Value" --output text 2>/dev/null)
DB_USER=$(aws ssm get-parameter --name "/iptv-manager/DB_USER" --query "Parameter.Value" --output text 2>/dev/null)
DB_PASSWORD=$(aws ssm get-parameter --name "/iptv-manager/DB_PASSWORD" --query "Parameter.Value" --output text 2>/dev/null)
if [ -n "$DB_HOST" ] && [ -n "$DB_NAME" ] && [ -n "$DB_USER" ] && [ -n "$DB_PASSWORD" ]; then
echo "Successfully retrieved all DB credentials"
break
fi
echo "Waiting for SSM parameters to be available... (attempt $i/30)"
sleep 5
done
if [ -z "$DB_HOST" ] || [ -z "$DB_NAME" ] || [ -z "$DB_USER" ] || [ -z "$DB_PASSWORD" ]; then
echo "ERROR: Failed to retrieve all required DB credentials after 30 attempts"
exit 1
fi
export DB_HOST
export DB_NAME
export DB_USER
export DB_PASSWORD
# Set PGPASSWORD for psql to use
export PGPASSWORD=$DB_PASSWORD
# Wait for PostgreSQL to be ready
echo "Waiting for PostgreSQL to start..."
until psql -h $DB_HOST -U $DB_USER -d postgres -c '\q'; do
sleep 1
done
echo "PostgreSQL is ready."
# Create database if it does not exist
DB_EXISTS=$(psql -h $DB_HOST -U $DB_USER -d postgres -tc "SELECT 1 FROM pg_database WHERE datname = '$DB_NAME';")
if [ -z "$DB_EXISTS" ]; then
echo "Creating database $DB_NAME..."
psql -h $DB_HOST -U $DB_USER -d postgres -c "CREATE DATABASE $DB_NAME;"
echo "Database $DB_NAME created."
fi
# Run database migrations # Run database migrations
alembic upgrade head alembic upgrade head
# Seed initial priorities
python3 -c "from app.utils.database import SessionLocal; from app.models.db import Priority; db = SessionLocal(); db.add_all([Priority(id=100, description='High'), Priority(id=200, description='Medium'), Priority(id=300, description='Low')]); db.commit()"
# Create systemd service file # Create systemd service file
cat << 'EOF' > /etc/systemd/system/iptv-updater.service cat << 'EOF' > /etc/systemd/system/iptv-manager.service
[Unit] [Unit]
Description=IPTV Updater Service Description=IPTV Manager Service
After=network.target After=network.target
[Service] [Service]
Type=simple Type=simple
User=ec2-user User=ec2-user
WorkingDirectory=/home/ec2-user/iptv-updater-aws WorkingDirectory=/home/ec2-user/iptv-manager-service
ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000 ExecStart=/usr/local/bin/uvicorn app.main:app --host 127.0.0.1 --port 8000
EnvironmentFile=/etc/environment EnvironmentFile=/etc/environment
Restart=always Restart=always
@@ -56,7 +98,7 @@ sudo mkdir -p /etc/nginx/ssl
--reloadcmd "service nginx force-reload" --reloadcmd "service nginx force-reload"
# Create nginx config # Create nginx config
cat << EOF > /etc/nginx/conf.d/iptvUpdater.conf cat << EOF > /etc/nginx/conf.d/iptvManager.conf
server { server {
listen 80; listen 80;
server_name ${DOMAIN_NAME} *.${DOMAIN_NAME}; server_name ${DOMAIN_NAME} *.${DOMAIN_NAME};
@@ -83,5 +125,5 @@ EOF
# Start nginx service # Start nginx service
systemctl enable nginx systemctl enable nginx
systemctl start nginx systemctl start nginx
systemctl enable iptv-updater systemctl enable iptv-manager
systemctl start iptv-updater systemctl start iptv-manager

View File

@@ -1,5 +1,10 @@
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 88
exclude = [
"alembic/versions/*.py", # Auto-generated Alembic migration files
]
[tool.ruff.lint]
select = [ select = [
"E", # pycodestyle errors "E", # pycodestyle errors
"F", # pyflakes "F", # pyflakes
@@ -9,8 +14,18 @@ select = [
] ]
ignore = [] ignore = []
[tool.ruff.isort] [tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"F811", # redefinition of unused name
"F401", # unused import
]
[tool.ruff.lint.isort]
known-first-party = ["app"] known-first-party = ["app"]
[tool.ruff.format] [tool.ruff.format]
docstring-code-format = true docstring-code-format = true
[tool.pytest.ini_options]
addopts = "--cov=app --cov-report=term-missing --cov-fail-under=70"
testpaths = ["tests"]

View File

@@ -5,12 +5,21 @@ python_functions = test_*
asyncio_mode = auto asyncio_mode = auto
filterwarnings = filterwarnings =
ignore::DeprecationWarning:botocore.auth ignore::DeprecationWarning:botocore.auth
ignore:The 'app' shortcut is now deprecated:DeprecationWarning:httpx._client
# Coverage configuration # Coverage configuration
addopts = addopts =
--cov=app --cov=app
--cov-report=term-missing --cov-report=term-missing
# Test environment variables
env =
MOCK_AUTH=true
DB_USER=test_user
DB_PASSWORD=test_password
DB_HOST=localhost
DB_NAME=iptv_manager_test
# Test markers # Test markers
markers = markers =
slow: mark tests as slow running slow: mark tests as slow running

View File

@@ -14,4 +14,9 @@ psycopg2-binary==2.9.9
alembic==1.16.1 alembic==1.16.1
pytest==8.1.1 pytest==8.1.1
pytest-asyncio==0.23.6 pytest-asyncio==0.23.6
pytest-mock==3.12.0 pytest-mock==3.12.0
pytest-cov==4.1.0
pytest-env==1.1.1
httpx==0.27.0
pre-commit
apscheduler==3.10.4

View File

@@ -25,7 +25,7 @@ cdk deploy --app="python3 ${PWD}/app.py"
# Update application on running instances # Update application on running instances
INSTANCE_IDS=$(aws ec2 describe-instances \ INSTANCE_IDS=$(aws ec2 describe-instances \
--region us-east-2 \ --region us-east-2 \
--filters "Name=tag:Name,Values=IptvUpdaterStack/IptvUpdaterInstance" \ --filters "Name=tag:Name,Values=IptvManagerStack/IptvManagerInstance" \
"Name=instance-state-name,Values=running" \ "Name=instance-state-name,Values=running" \
--query "Reservations[].Instances[].InstanceId" \ --query "Reservations[].Instances[].InstanceId" \
--output text) --output text)
@@ -35,7 +35,7 @@ for INSTANCE_ID in $INSTANCE_IDS; do
aws ssm send-command \ aws ssm send-command \
--instance-ids "$INSTANCE_ID" \ --instance-ids "$INSTANCE_ID" \
--document-name "AWS-RunShellScript" \ --document-name "AWS-RunShellScript" \
--parameters '{"commands":["cd /home/ec2-user/iptv-updater-aws && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-updater"]}' \ --parameters '{"commands":["cd /home/ec2-user/iptv-manager-service && git pull && pip3 install -r requirements.txt && alembic upgrade head && sudo systemctl restart iptv-manager"]}' \
--no-cli-pager \ --no-cli-pager \
--no-paginate --no-paginate
done done

View File

@@ -4,8 +4,10 @@
npm install -g aws-cdk npm install -g aws-cdk
python3 -m pip install -r requirements.txt python3 -m pip install -r requirements.txt
# Initialize and run database migrations # Install and configure pre-commit hooks
alembic upgrade head pre-commit install
pre-commit install-hooks
pre-commit autoupdate
# Seed initial data # Verify pytest setup
python3 -c "from app.utils.database import SessionLocal; from app.models.db import Priority; db = SessionLocal(); db.add_all([Priority(id=100, description='High'), Priority(id=200, description='Medium'), Priority(id=300, description='Low')]); db.commit()" python3 -m pytest

View File

@@ -1,21 +1,26 @@
#!/bin/bash #!/bin/bash
set -e
# Start PostgreSQL # Start PostgreSQL
docker-compose -f docker/docker-compose-db.yml up -d docker-compose -f docker/docker-compose-db.yml up -d
# Set mock auth and database environment variables # Set environment variables
export MOCK_AUTH=true export MOCK_AUTH=true
export DB_HOST=localhost
export DB_USER=postgres export DB_USER=postgres
export DB_PASSWORD=postgres export DB_PASSWORD=postgres
export DB_HOST=localhost export DB_NAME=iptv_manager
export DB_NAME=iptv_updater
echo "Ensuring database $DB_NAME exists using conditional DDL..."
PGPASSWORD=$DB_PASSWORD docker exec -i postgres psql -U $DB_USER <<< "SELECT 'CREATE DATABASE $DB_NAME' WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec"
echo "Database $DB_NAME check complete."
# Run database migrations # Run database migrations
alembic upgrade head alembic upgrade head
# Start FastAPI # Start FastAPI
nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 & nohup uvicorn app.main:app --host 127.0.0.1 --port 8000 > app.log 2>&1 &
echo $! > iptv-updater.pid echo $! > iptv-manager.pid
echo "Services started:" echo "Services started:"
echo "- PostgreSQL running on localhost:5432" echo "- PostgreSQL running on localhost:5432"

View File

@@ -1,9 +1,9 @@
#!/bin/bash #!/bin/bash
# Stop FastAPI # Stop FastAPI
if [ -f iptv-updater.pid ]; then if [ -f iptv-manager.pid ]; then
kill $(cat iptv-updater.pid) kill $(cat iptv-manager.pid)
rm iptv-updater.pid rm iptv-manager.pid
echo "Stopped FastAPI" echo "Stopped FastAPI"
fi fi

View File

@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest import pytest
from unittest.mock import patch, MagicMock
from fastapi import HTTPException, status from fastapi import HTTPException, status
# Test constants # Test constants
@@ -7,12 +8,15 @@ TEST_CLIENT_ID = "test_client_id"
TEST_CLIENT_SECRET = "test_client_secret" TEST_CLIENT_SECRET = "test_client_secret"
# Patch constants before importing the module # Patch constants before importing the module
with patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), \ with (
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET): patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
from app.auth.cognito import initiate_auth, get_user_from_token patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
):
from app.auth.cognito import get_user_from_token, initiate_auth
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
from app.utils.constants import USER_ROLE_ATTRIBUTE from app.utils.constants import USER_ROLE_ATTRIBUTE
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_cognito_client(): def mock_cognito_client():
with patch("app.auth.cognito.cognito_client") as mock_client: with patch("app.auth.cognito.cognito_client") as mock_client:
@@ -26,13 +30,14 @@ def mock_cognito_client():
) )
yield mock_client yield mock_client
def test_initiate_auth_success(mock_cognito_client): def test_initiate_auth_success(mock_cognito_client):
# Mock successful authentication response # Mock successful authentication response
mock_cognito_client.initiate_auth.return_value = { mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": { "AuthenticationResult": {
"AccessToken": "mock_access_token", "AccessToken": "mock_access_token",
"IdToken": "mock_id_token", "IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token" "RefreshToken": "mock_refresh_token",
} }
} }
@@ -40,104 +45,125 @@ def test_initiate_auth_success(mock_cognito_client):
assert result == { assert result == {
"AccessToken": "mock_access_token", "AccessToken": "mock_access_token",
"IdToken": "mock_id_token", "IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token" "RefreshToken": "mock_refresh_token",
} }
def test_initiate_auth_with_secret_hash(mock_cognito_client): def test_initiate_auth_with_secret_hash(mock_cognito_client):
with patch("app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash") as mock_hash: with patch(
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
) as mock_hash:
mock_cognito_client.initiate_auth.return_value = { mock_cognito_client.initiate_auth.return_value = {
"AuthenticationResult": {"AccessToken": "token"} "AuthenticationResult": {"AccessToken": "token"}
} }
result = initiate_auth("test_user", "test_pass") initiate_auth("test_user", "test_pass")
# Verify calculate_secret_hash was called # Verify calculate_secret_hash was called
mock_hash.assert_called_once_with("test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET) mock_hash.assert_called_once_with(
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
)
# Verify SECRET_HASH was included in auth params # Verify SECRET_HASH was included in auth params
call_args = mock_cognito_client.initiate_auth.call_args[1] call_args = mock_cognito_client.initiate_auth.call_args[1]
assert "SECRET_HASH" in call_args["AuthParameters"] assert "SECRET_HASH" in call_args["AuthParameters"]
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash" assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
def test_initiate_auth_not_authorized(mock_cognito_client): def test_initiate_auth_not_authorized(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.NotAuthorizedException() mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
initiate_auth("invalid_user", "wrong_pass") initiate_auth("invalid_user", "wrong_pass")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid username or password" assert exc_info.value.detail == "Invalid username or password"
def test_initiate_auth_user_not_found(mock_cognito_client): def test_initiate_auth_user_not_found(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.UserNotFoundException() mock_cognito_client.initiate_auth.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
initiate_auth("nonexistent_user", "any_pass") initiate_auth("nonexistent_user", "any_pass")
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
assert exc_info.value.detail == "User not found" assert exc_info.value.detail == "User not found"
def test_initiate_auth_generic_error(mock_cognito_client): def test_initiate_auth_generic_error(mock_cognito_client):
mock_cognito_client.initiate_auth.side_effect = Exception("Some error") mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
initiate_auth("test_user", "test_pass") initiate_auth("test_user", "test_pass")
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "An error occurred during authentication" in exc_info.value.detail assert "An error occurred during authentication" in exc_info.value.detail
def test_get_user_from_token_success(mock_cognito_client): def test_get_user_from_token_success(mock_cognito_client):
mock_response = { mock_response = {
"Username": "test_user", "Username": "test_user",
"UserAttributes": [ "UserAttributes": [
{"Name": "sub", "Value": "123"}, {"Name": "sub", "Value": "123"},
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"} {"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
] ],
} }
mock_cognito_client.get_user.return_value = mock_response mock_cognito_client.get_user.return_value = mock_response
result = get_user_from_token("valid_token") result = get_user_from_token("valid_token")
assert isinstance(result, CognitoUser) assert isinstance(result, CognitoUser)
assert result.username == "test_user" assert result.username == "test_user"
assert set(result.roles) == {"admin", "user"} assert set(result.roles) == {"admin", "user"}
def test_get_user_from_token_no_roles(mock_cognito_client): def test_get_user_from_token_no_roles(mock_cognito_client):
mock_response = { mock_response = {
"Username": "test_user", "Username": "test_user",
"UserAttributes": [{"Name": "sub", "Value": "123"}] "UserAttributes": [{"Name": "sub", "Value": "123"}],
} }
mock_cognito_client.get_user.return_value = mock_response mock_cognito_client.get_user.return_value = mock_response
result = get_user_from_token("valid_token") result = get_user_from_token("valid_token")
assert isinstance(result, CognitoUser) assert isinstance(result, CognitoUser)
assert result.username == "test_user" assert result.username == "test_user"
assert result.roles == [] assert result.roles == []
def test_get_user_from_token_invalid_token(mock_cognito_client): def test_get_user_from_token_invalid_token(mock_cognito_client):
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.NotAuthorizedException() mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.NotAuthorizedException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
get_user_from_token("invalid_token") get_user_from_token("invalid_token")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Invalid or expired token." assert exc_info.value.detail == "Invalid or expired token."
def test_get_user_from_token_user_not_found(mock_cognito_client): def test_get_user_from_token_user_not_found(mock_cognito_client):
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.UserNotFoundException() mock_cognito_client.get_user.side_effect = (
mock_cognito_client.exceptions.UserNotFoundException()
)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
get_user_from_token("token_for_nonexistent_user") get_user_from_token("token_for_nonexistent_user")
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "User not found or invalid token." assert exc_info.value.detail == "User not found or invalid token."
def test_get_user_from_token_generic_error(mock_cognito_client): def test_get_user_from_token_generic_error(mock_cognito_client):
mock_cognito_client.get_user.side_effect = Exception("Some error") mock_cognito_client.get_user.side_effect = Exception("Some error")
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
get_user_from_token("test_token") get_user_from_token("test_token")
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert "Token verification failed" in exc_info.value.detail assert "Token verification failed" in exc_info.value.detail

View File

@@ -1,9 +1,11 @@
import os
import pytest
import importlib import importlib
import os
import pytest
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi import HTTPException, Depends, Request
from app.auth.dependencies import get_current_user, require_roles, oauth2_scheme from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
# Mock user for testing # Mock user for testing
@@ -11,24 +13,30 @@ TEST_USER = CognitoUser(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
roles=["admin", "user"], roles=["admin", "user"],
groups=["test_group"] groups=["test_group"],
) )
# Mock the underlying get_user_from_token function # Mock the underlying get_user_from_token function
def mock_get_user_from_token(token: str) -> CognitoUser: def mock_get_user_from_token(token: str) -> CognitoUser:
if token == "valid_token": if token == "valid_token":
return TEST_USER return TEST_USER
raise HTTPException(status_code=401, detail="Invalid token") raise HTTPException(status_code=401, detail="Invalid token")
# Mock endpoint for testing the require_roles decorator # Mock endpoint for testing the require_roles decorator
@require_roles("admin") @require_roles("admin")
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)): def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
return {"message": "Success", "user": user.username} return {"message": "Success", "user": user.username}
# Patch the get_user_from_token function for testing # Patch the get_user_from_token function for testing
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_auth(monkeypatch): def mock_auth(monkeypatch):
monkeypatch.setattr("app.auth.dependencies.get_user_from_token", mock_get_user_from_token) monkeypatch.setattr(
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
)
# Test get_current_user dependency # Test get_current_user dependency
def test_get_current_user_success(): def test_get_current_user_success():
@@ -37,59 +45,58 @@ def test_get_current_user_success():
assert user.username == "testuser" assert user.username == "testuser"
assert user.roles == ["admin", "user"] assert user.roles == ["admin", "user"]
def test_get_current_user_invalid_token(): def test_get_current_user_invalid_token():
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
get_current_user("invalid_token") get_current_user("invalid_token")
assert exc.value.status_code == 401 assert exc.value.status_code == 401
# Test require_roles decorator # Test require_roles decorator
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_success(): async def test_require_roles_success():
# Create test user with required role # Create test user with required role
user = CognitoUser( user = CognitoUser(
username="testuser", username="testuser", email="test@example.com", roles=["admin"], groups=[]
email="test@example.com",
roles=["admin"],
groups=[]
) )
result = await mock_protected_endpoint(user=user) result = await mock_protected_endpoint(user=user)
assert result == {"message": "Success", "user": "testuser"} assert result == {"message": "Success", "user": "testuser"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_missing_role(): async def test_require_roles_missing_role():
# Create test user without required role # Create test user without required role
user = CognitoUser( user = CognitoUser(
username="testuser", username="testuser", email="test@example.com", roles=["user"], groups=[]
email="test@example.com",
roles=["user"],
groups=[]
) )
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user) await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
assert exc.value.detail == "You do not have the required roles to access this endpoint." assert (
exc.value.detail
== "You do not have the required roles to access this endpoint."
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_no_roles(): async def test_require_roles_no_roles():
# Create test user with no roles # Create test user with no roles
user = CognitoUser( user = CognitoUser(
username="testuser", username="testuser", email="test@example.com", roles=[], groups=[]
email="test@example.com",
roles=[],
groups=[]
) )
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await mock_protected_endpoint(user=user) await mock_protected_endpoint(user=user)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_roles_multiple_roles(): async def test_require_roles_multiple_roles():
# Test requiring multiple roles # Test requiring multiple roles
@require_roles("admin", "super_user") @require_roles("admin", "super_user")
async def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)): def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)):
return {"message": "Success"} return {"message": "Success"}
# User with all required roles # User with all required roles
@@ -97,7 +104,7 @@ async def test_require_roles_multiple_roles():
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
roles=["admin", "super_user", "user"], roles=["admin", "super_user", "user"],
groups=[] groups=[],
) )
result = await mock_multi_role_endpoint(user=user_with_roles) result = await mock_multi_role_endpoint(user=user_with_roles)
assert result == {"message": "Success"} assert result == {"message": "Success"}
@@ -107,56 +114,92 @@ async def test_require_roles_multiple_roles():
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
roles=["admin", "user"], roles=["admin", "user"],
groups=[] groups=[],
) )
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await mock_multi_role_endpoint(user=user_missing_role) await mock_multi_role_endpoint(user=user_missing_role)
assert exc.value.status_code == 403 assert exc.value.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth2_scheme_configuration(): async def test_oauth2_scheme_configuration():
# Verify that we have a properly configured OAuth2PasswordBearer instance # Verify that we have a properly configured OAuth2PasswordBearer instance
assert isinstance(oauth2_scheme, OAuth2PasswordBearer) assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
# Create a mock request with no Authorization header # Create a mock request with no Authorization header
mock_request = Request(scope={ mock_request = Request(
'type': 'http', scope={
'headers': [], "type": "http",
'method': 'GET', "headers": [],
'scheme': 'http', "method": "GET",
'path': '/', "scheme": "http",
'query_string': b'', "path": "/",
'client': ('127.0.0.1', 8000) "query_string": b"",
}) "client": ("127.0.0.1", 8000),
}
)
# Test that the scheme raises 401 when no token is provided # Test that the scheme raises 401 when no token is provided
with pytest.raises(HTTPException) as exc: with pytest.raises(HTTPException) as exc:
await oauth2_scheme(mock_request) await oauth2_scheme(mock_request)
assert exc.value.status_code == 401 assert exc.value.status_code == 401
assert exc.value.detail == "Not authenticated" assert exc.value.detail == "Not authenticated"
def test_mock_auth_import(monkeypatch): def test_mock_auth_import(monkeypatch):
# Save original env var value # Save original env var value
original_value = os.environ.get("MOCK_AUTH") original_value = os.environ.get("MOCK_AUTH")
try: try:
# Set MOCK_AUTH to true # Set MOCK_AUTH to true
monkeypatch.setenv("MOCK_AUTH", "true") monkeypatch.setenv("MOCK_AUTH", "true")
# Reload the dependencies module to trigger the import condition # Reload the dependencies module to trigger the import condition
import app.auth.dependencies import app.auth.dependencies
importlib.reload(app.auth.dependencies) importlib.reload(app.auth.dependencies)
# Verify that mock_get_user_from_token was imported # Verify that mock_get_user_from_token was imported
from app.auth.dependencies import get_user_from_token from app.auth.dependencies import get_user_from_token
assert get_user_from_token.__module__ == 'app.auth.mock_auth'
assert get_user_from_token.__module__ == "app.auth.mock_auth"
finally: finally:
# Restore original env var # Restore original env var
if original_value is None: if original_value is None:
monkeypatch.delenv("MOCK_AUTH", raising=False) monkeypatch.delenv("MOCK_AUTH", raising=False)
else: else:
monkeypatch.setenv("MOCK_AUTH", original_value) monkeypatch.setenv("MOCK_AUTH", original_value)
# Reload again to restore original state # Reload again to restore original state
importlib.reload(app.auth.dependencies) importlib.reload(app.auth.dependencies)
def test_cognito_auth_import(monkeypatch):
    """Verify app.auth.dependencies binds get_user_from_token to the
    Cognito implementation when MOCK_AUTH=false (covers line 14)."""
    saved = os.environ.get("MOCK_AUTH")
    try:
        monkeypatch.setenv("MOCK_AUTH", "false")
        # Re-evaluate the module-level import switch under the new env var.
        import app.auth.dependencies
        importlib.reload(app.auth.dependencies)
        from app.auth.dependencies import get_user_from_token
        assert get_user_from_token.__module__ == "app.auth.cognito"
    finally:
        # Put the environment back exactly as it was found...
        if saved is None:
            monkeypatch.delenv("MOCK_AUTH", raising=False)
        else:
            monkeypatch.setenv("MOCK_AUTH", saved)
        # ...and reload so later tests see the original binding.
        importlib.reload(app.auth.dependencies)

View File

@@ -1,8 +1,10 @@
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
from app.models.auth import CognitoUser from app.models.auth import CognitoUser
def test_mock_get_user_from_token_success(): def test_mock_get_user_from_token_success():
"""Test successful token validation returns expected user""" """Test successful token validation returns expected user"""
user = mock_get_user_from_token("testuser") user = mock_get_user_from_token("testuser")
@@ -10,27 +12,30 @@ def test_mock_get_user_from_token_success():
assert user.username == "testuser" assert user.username == "testuser"
assert user.roles == ["admin"] assert user.roles == ["admin"]
def test_mock_get_user_from_token_invalid(): def test_mock_get_user_from_token_invalid():
"""Test invalid token raises expected exception""" """Test invalid token raises expected exception"""
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
mock_get_user_from_token("invalid_token") mock_get_user_from_token("invalid_token")
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid mock token - use 'testuser'" assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
def test_mock_initiate_auth(): def test_mock_initiate_auth():
"""Test mock authentication returns expected token response""" """Test mock authentication returns expected token response"""
result = mock_initiate_auth("any_user", "any_password") result = mock_initiate_auth("any_user", "any_password")
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["AccessToken"] == "testuser" assert result["AccessToken"] == "testuser"
assert result["ExpiresIn"] == 3600 assert result["ExpiresIn"] == 3600
assert result["TokenType"] == "Bearer" assert result["TokenType"] == "Bearer"
def test_mock_initiate_auth_different_credentials(): def test_mock_initiate_auth_different_credentials():
"""Test mock authentication works with any credentials""" """Test mock authentication works with any credentials"""
result1 = mock_initiate_auth("user1", "pass1") result1 = mock_initiate_auth("user1", "pass1")
result2 = mock_initiate_auth("user2", "pass2") result2 = mock_initiate_auth("user2", "pass2")
# Should return same mock token regardless of credentials # Should return same mock token regardless of credentials
assert result1 == result2 assert result1 == result2

144
tests/models/test_db.py Normal file
View File

@@ -0,0 +1,144 @@
import os
import uuid
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID
# --- Test SQLiteUUID Type ---
def test_sqliteuuid_process_bind_param_none():
    """process_bind_param must pass None through untouched."""
    col_type = SQLiteUUID()
    assert col_type.process_bind_param(None, None) is None


def test_sqliteuuid_process_bind_param_valid_uuid():
    """A uuid.UUID instance is serialized to its canonical string form."""
    col_type = SQLiteUUID()
    value = uuid.uuid4()
    assert col_type.process_bind_param(value, None) == str(value)


def test_sqliteuuid_process_bind_param_valid_string():
    """A string already holding a valid UUID is returned unchanged."""
    col_type = SQLiteUUID()
    canonical = "550e8400-e29b-41d4-a716-446655440000"
    assert col_type.process_bind_param(canonical, None) == canonical


def test_sqliteuuid_process_bind_param_invalid_string():
    """A malformed UUID string is rejected with ValueError."""
    col_type = SQLiteUUID()
    with pytest.raises(ValueError, match="Invalid UUID string format"):
        col_type.process_bind_param("invalid-uuid", None)
def test_sqliteuuid_process_result_value_none():
    """process_result_value must pass None through untouched."""
    col_type = SQLiteUUID()
    assert col_type.process_result_value(None, None) is None


def test_sqliteuuid_process_result_value_valid_string():
    """A stored UUID string is rehydrated into an equal uuid.UUID object."""
    col_type = SQLiteUUID()
    expected = uuid.uuid4()
    loaded = col_type.process_result_value(str(expected), None)
    assert isinstance(loaded, uuid.UUID)
    assert loaded == expected


def test_sqliteuuid_process_result_value_uuid_object():
    """A value that is already a uuid.UUID comes back as the very same object."""
    col_type = SQLiteUUID()
    original = uuid.uuid4()
    loaded = col_type.process_result_value(original, None)
    assert isinstance(loaded, uuid.UUID)
    # Identity, not just equality: no needless copy should be made.
    assert loaded is original
def test_sqliteuuid_compare_values_none():
    """compare_values treats None as equal only to None."""
    col_type = SQLiteUUID()
    assert col_type.compare_values(None, None) is True
    assert col_type.compare_values(None, uuid.uuid4()) is False
    assert col_type.compare_values(uuid.uuid4(), None) is False


def test_sqliteuuid_compare_values_uuid():
    """compare_values matches identical UUIDs and rejects different ones."""
    col_type = SQLiteUUID()
    sample = uuid.uuid4()
    assert col_type.compare_values(sample, sample) is True
    assert col_type.compare_values(sample, uuid.uuid4()) is False
def test_sqlite_uuid_comparison():
    """Test SQLiteUUID comparison functionality (moved from db_mocks.py)"""
    uuid_type = SQLiteUUID()
    # Test equal UUIDs: two distinct UUID objects carrying the same value
    uuid1 = uuid.uuid4()
    uuid2 = uuid.UUID(str(uuid1))
    assert uuid_type.compare_values(uuid1, uuid2) is True
    # Test UUID vs string: comparison is by value, not by Python type
    assert uuid_type.compare_values(uuid1, str(uuid1)) is True
    # Test None comparisons: None matches only None
    assert uuid_type.compare_values(None, None) is True
    assert uuid_type.compare_values(uuid1, None) is False
    assert uuid_type.compare_values(None, uuid1) is False
    # Test different UUIDs
    uuid3 = uuid.uuid4()
    assert uuid_type.compare_values(uuid1, uuid3) is False


def test_sqlite_uuid_binding():
    """Test SQLiteUUID binding parameter handling (moved from db_mocks.py)"""
    uuid_type = SQLiteUUID()
    # Test UUID object binding: serialized to its canonical string form
    uuid_obj = uuid.uuid4()
    assert uuid_type.process_bind_param(uuid_obj, None) == str(uuid_obj)
    # Test valid UUID string binding: passed through unchanged
    uuid_str = str(uuid.uuid4())
    assert uuid_type.process_bind_param(uuid_str, None) == uuid_str
    # Test None handling
    assert uuid_type.process_bind_param(None, None) is None
    # Test invalid UUID string: rejected before reaching the database
    with pytest.raises(ValueError):
        uuid_type.process_bind_param("invalid-uuid", None)
# --- Test UUID Column Type Configuration ---
def test_uuid_column_type_default():
    """UUID_COLUMN_TYPE is the SQLite-compatible type in the test environment."""
    assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID)


def test_uuid_column_type_postgres():
    """UUID_COLUMN_TYPE switches to native Postgres UUID when MOCK_AUTH=false.

    UUID_COLUMN_TYPE is computed at import time, so the module must be
    reloaded under the patched environment — and reloaded again afterwards;
    the original version left app.models.db reloaded in its Postgres
    configuration, leaking state into every later test in the session.
    """
    from importlib import reload

    from sqlalchemy.dialects.postgresql import UUID as PostgresUUID

    from app import models

    try:
        with patch.dict(os.environ, {"MOCK_AUTH": "false"}):
            reload(models.db)
            # Re-import to pick up the value computed under the patched env.
            from app.models.db import UUID_COLUMN_TYPE
            assert isinstance(UUID_COLUMN_TYPE, PostgresUUID)
    finally:
        # By the time we get here patch.dict has already restored the
        # environment, so this reload re-evaluates the module in its
        # original (SQLite) configuration for subsequent tests.
        reload(models.db)

43
tests/routers/mocks.py Normal file
View File

@@ -0,0 +1,43 @@
from unittest.mock import Mock
from fastapi import Request
from app.iptv.scheduler import StreamScheduler
class MockScheduler:
    """Base mock APScheduler instance.

    Every instance gets its OWN Mock for each scheduler method.  The
    original version declared the Mocks as class attributes, which made a
    single Mock shared by ALL instances — call records from one test leaked
    into the assertions of another.  Creating them in __init__ isolates
    each test's scheduler.
    """

    def __init__(self, running=True):
        # Whether the fake scheduler reports itself as running.
        self.running = running
        self.start = Mock()
        self.shutdown = Mock()
        self.add_job = Mock()
        self.remove_job = Mock()
        # get_job mimics "no such job" by default.
        self.get_job = Mock(return_value=None)
def create_trigger_mock(triggered_ref: dict) -> callable:
    """Build a zero-argument callback that flags *triggered_ref* when invoked."""

    def _mark_triggered():
        # Record the invocation in the shared reference dict.
        triggered_ref["value"] = True

    return _mark_triggered
async def mock_get_scheduler(
    request: Request, scheduler_class=MockScheduler, running=True, **kwargs
) -> StreamScheduler:
    """Mock dependency for get_scheduler with customization options.

    Builds a real StreamScheduler whose underlying APScheduler instance is
    replaced by a *scheduler_class* fake; any extra keyword arguments are
    attached to that fake as attributes before it is wired in.
    """
    fake_backend = scheduler_class(running=running)
    # Attach caller-supplied attributes/methods to the fake backend.
    for attr_name, attr_value in kwargs.items():
        setattr(fake_backend, attr_name, attr_value)
    wrapper = StreamScheduler()
    wrapper.scheduler = fake_backend
    return wrapper

View File

@@ -1,34 +1,35 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException, status from fastapi import HTTPException, status
from fastapi.testclient import TestClient
from app.main import app from app.main import app
client = TestClient(app) client = TestClient(app)
@pytest.fixture @pytest.fixture
def mock_successful_auth(): def mock_successful_auth():
return { return {
"AccessToken": "mock_access_token", "AccessToken": "mock_access_token",
"IdToken": "mock_id_token", "IdToken": "mock_id_token",
"RefreshToken": "mock_refresh_token" "RefreshToken": "mock_refresh_token",
} }
@pytest.fixture @pytest.fixture
def mock_successful_auth_no_refresh(): def mock_successful_auth_no_refresh():
return { return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
"AccessToken": "mock_access_token",
"IdToken": "mock_id_token"
}
def test_signin_success(mock_successful_auth): def test_signin_success(mock_successful_auth):
"""Test successful signin with all tokens""" """Test successful signin with all tokens"""
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth): with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "testuser", "password": "testpass"}
json={"username": "testuser", "password": "testpass"}
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["access_token"] == "mock_access_token" assert data["access_token"] == "mock_access_token"
@@ -36,14 +37,16 @@ def test_signin_success(mock_successful_auth):
assert data["refresh_token"] == "mock_refresh_token" assert data["refresh_token"] == "mock_refresh_token"
assert data["token_type"] == "Bearer" assert data["token_type"] == "Bearer"
def test_signin_success_no_refresh(mock_successful_auth_no_refresh): def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
"""Test successful signin without refresh token""" """Test successful signin without refresh token"""
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth_no_refresh): with patch(
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
):
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "testuser", "password": "testpass"}
json={"username": "testuser", "password": "testpass"}
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["access_token"] == "mock_access_token" assert data["access_token"] == "mock_access_token"
@@ -51,57 +54,48 @@ def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
assert data["refresh_token"] is None assert data["refresh_token"] is None
assert data["token_type"] == "Bearer" assert data["token_type"] == "Bearer"
def test_signin_invalid_input(): def test_signin_invalid_input():
"""Test signin with invalid input format""" """Test signin with invalid input format"""
# Missing password # Missing password
response = client.post( response = client.post("/auth/signin", json={"username": "testuser"})
"/auth/signin",
json={"username": "testuser"}
)
assert response.status_code == 422 assert response.status_code == 422
# Missing username # Missing username
response = client.post( response = client.post("/auth/signin", json={"password": "testpass"})
"/auth/signin",
json={"password": "testpass"}
)
assert response.status_code == 422 assert response.status_code == 422
# Empty payload # Empty payload
response = client.post( response = client.post("/auth/signin", json={})
"/auth/signin",
json={}
)
assert response.status_code == 422 assert response.status_code == 422
def test_signin_auth_failure(): def test_signin_auth_failure():
"""Test signin with authentication failure""" """Test signin with authentication failure"""
with patch('app.routers.auth.initiate_auth') as mock_auth: with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException( mock_auth.side_effect = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password" detail="Invalid username or password",
) )
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "testuser", "password": "wrongpass"}
json={"username": "testuser", "password": "wrongpass"}
) )
assert response.status_code == 401 assert response.status_code == 401
data = response.json() data = response.json()
assert data["detail"] == "Invalid username or password" assert data["detail"] == "Invalid username or password"
def test_signin_user_not_found(): def test_signin_user_not_found():
"""Test signin with non-existent user""" """Test signin with non-existent user"""
with patch('app.routers.auth.initiate_auth') as mock_auth: with patch("app.routers.auth.initiate_auth") as mock_auth:
mock_auth.side_effect = HTTPException( mock_auth.side_effect = HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
response = client.post( response = client.post(
"/auth/signin", "/auth/signin", json={"username": "nonexistent", "password": "testpass"}
json={"username": "nonexistent", "password": "testpass"}
) )
assert response.status_code == 404 assert response.status_code == 404
data = response.json() data = response.json()
assert data["detail"] == "User not found" assert data["detail"] == "User not found"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,461 @@
import uuid
from datetime import datetime, timezone
from fastapi import status
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
from app.utils.database import get_db
# Import mocks and fixtures
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import (
MockChannelDB,
MockGroup,
create_mock_priorities_and_group,
)
# --- Test Cases For Group Creation ---
def test_create_group_success(db_session: Session, admin_user_client):
    """POST /groups/ as admin creates the group and persists it."""
    response = admin_user_client.post(
        "/groups/", json={"name": "Test Group", "sort_order": 1}
    )
    assert response.status_code == status.HTTP_201_CREATED
    payload = response.json()
    assert payload["name"] == "Test Group"
    assert payload["sort_order"] == 1
    # Server-generated bookkeeping fields must be present.
    for field in ("id", "created_at", "updated_at"):
        assert field in payload
    # The row must actually exist in the database.
    stored = (
        db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first()
    )
    assert stored is not None
    assert stored.sort_order == 1
def test_create_group_duplicate(db_session: Session, admin_user_client):
    """Creating a group whose name already exists must return 409 Conflict."""
    # Create initial group
    initial_group = MockGroup(
        id=uuid.uuid4(),
        name="Duplicate Group",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(initial_group)
    db_session.commit()
    # Attempt to create duplicate (different sort_order; only the name clashes)
    response = admin_user_client.post(
        "/groups/", json={"name": "Duplicate Group", "sort_order": 2}
    )
    assert response.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in response.json()["detail"]


def test_create_group_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Non-admin users must not be able to create groups (403)."""
    response = non_admin_user_client.post(
        "/groups/", json={"name": "Forbidden Group", "sort_order": 1}
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For Get Group ---
def test_get_group_success(db_session: Session, admin_user_client):
    """GET /groups/{id} returns the stored group's fields."""
    # Create a group first
    test_group = MockGroup(
        id=uuid.uuid4(),
        name="Get Me Group",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    response = admin_user_client.get(f"/groups/{test_group.id}")
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    # UUIDs are serialized as strings in the JSON payload.
    assert data["id"] == str(test_group.id)
    assert data["name"] == "Get Me Group"
    assert data["sort_order"] == 1


def test_get_group_not_found(db_session: Session, admin_user_client):
    """GET with an unknown group id returns 404."""
    random_uuid = uuid.uuid4()
    response = admin_user_client.get(f"/groups/{random_uuid}")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]
# --- Test Cases For Update Group ---
def test_update_group_success(db_session: Session, admin_user_client):
    """PUT /groups/{id} updates name and sort_order and persists the change."""
    # Create initial group
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Update Me",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    update_data = {"name": "Updated Name", "sort_order": 2}
    response = admin_user_client.put(f"/groups/{group_id}", json=update_data)
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert data["name"] == "Updated Name"
    assert data["sort_order"] == 2
    # Verify in DB
    db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
    assert db_group.name == "Updated Name"
    assert db_group.sort_order == 2


def test_update_group_conflict(db_session: Session, admin_user_client):
    """Renaming a group to another group's name must return 409 Conflict."""
    # Create two groups
    group1 = MockGroup(
        id=uuid.uuid4(),
        name="Group One",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    group2 = MockGroup(
        id=uuid.uuid4(),
        name="Group Two",
        sort_order=2,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add_all([group1, group2])
    db_session.commit()
    # Try to rename group2 to conflict with group1
    response = admin_user_client.put(f"/groups/{group2.id}", json={"name": "Group One"})
    assert response.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in response.json()["detail"]


def test_update_group_not_found(db_session: Session, admin_user_client):
    """Updating a non-existent group returns 404."""
    random_uuid = uuid.uuid4()
    response = admin_user_client.put(
        f"/groups/{random_uuid}", json={"name": "Non-existent"}
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]


def test_update_group_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client, admin_user_client
):
    """Non-admin users must not be able to update groups (403)."""
    # Create group with admin
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Admin Created",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    # Attempt update with non-admin
    response = non_admin_user_client.put(
        f"/groups/{group_id}", json={"name": "Non-Admin Update"}
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
# --- Test Cases For Delete Group ---
def test_delete_all_groups_success(db_session, admin_user_client):
    """Test reset groups endpoint"""
    # Create test data
    group1_id = create_mock_priorities_and_group(db_session, [], "Group A")
    group2_id = create_mock_priorities_and_group(db_session, [], "Group B")
    # Add a channel to group1 ("Group A") so it cannot be deleted
    channel_data = [
        {
            "group-title": "Group A",
            "tvg_id": "channel1.tv",
            "name": "Channel One",
            "url": ["http://test.com", "http://example.com"],
        }
    ]
    admin_user_client.post("/channels/bulk-upload", json=channel_data)
    # Reset groups
    response = admin_user_client.delete("/groups")
    assert response.status_code == status.HTTP_200_OK
    assert response.json()["deleted"] == 1  # Only group2 should be deleted
    assert response.json()["skipped"] == 1  # group1 has channels
    # Verify group2 deleted, group1 remains
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group1_id).first()
        is not None
    )
    assert db_session.query(MockGroup).filter(MockGroup.id == group2_id).first() is None


def test_delete_all_groups_forbidden_for_non_admin(db_session, non_admin_user_client):
    """Test reset groups requires admin role"""
    response = non_admin_user_client.delete("/groups")
    assert response.status_code == status.HTTP_403_FORBIDDEN


def test_delete_group_success(db_session: Session, admin_user_client):
    """Deleting an empty group succeeds with 204 and removes the row."""
    # Create group
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Delete Me",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    # Verify exists before delete
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
    )
    response = admin_user_client.delete(f"/groups/{group_id}")
    assert response.status_code == status.HTTP_204_NO_CONTENT
    # Verify deleted
    assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None


def test_delete_group_with_channels_fails(db_session: Session, admin_user_client):
    """A group that still owns channels cannot be deleted (400)."""
    # Create group with channel
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Group With Channels",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    # Create channel in this group
    test_channel = MockChannelDB(
        id=uuid.uuid4(),
        tvg_id="channel1.tv",
        name="Channel 1",
        group_id=group_id,
        tvg_name="Channel1",
        tvg_logo="logo.png",
    )
    db_session.add(test_channel)
    db_session.commit()
    response = admin_user_client.delete(f"/groups/{group_id}")
    assert response.status_code == status.HTTP_400_BAD_REQUEST
    assert "existing channels" in response.json()["detail"]
    # Verify group still exists
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
    )


def test_delete_group_not_found(db_session: Session, admin_user_client):
    """Deleting an unknown group id returns 404."""
    random_uuid = uuid.uuid4()
    response = admin_user_client.delete(f"/groups/{random_uuid}")
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]


def test_delete_group_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client, admin_user_client
):
    """Non-admin users must not be able to delete groups (403)."""
    # Create group with admin
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Admin Created",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    # Attempt delete with non-admin
    response = non_admin_user_client.delete(f"/groups/{group_id}")
    assert response.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in response.json()["detail"]
    # Verify group still exists
    assert (
        db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None
    )
# --- Test Cases For List Groups ---
def test_list_groups_empty(db_session: Session, admin_user_client):
    """Listing groups on an empty database returns an empty JSON array."""
    response = admin_user_client.get("/groups/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == []


def test_list_groups_with_data(db_session: Session, admin_user_client):
    """Listing groups returns every row ordered by sort_order."""
    # Create some groups
    groups = [
        MockGroup(
            id=uuid.uuid4(),
            name=f"Group {i}",
            sort_order=i,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        for i in range(3)
    ]
    db_session.add_all(groups)
    db_session.commit()
    response = admin_user_client.get("/groups/")
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert len(data) == 3
    assert data[0]["sort_order"] == 0  # Should be sorted by sort_order
    assert data[1]["sort_order"] == 1
    assert data[2]["sort_order"] == 2
# --- Test Cases For Sort Order Updates ---
def test_update_group_sort_order_success(db_session: Session, admin_user_client):
    """PUT /groups/{id}/sort updates a single group's sort_order."""
    # Create group
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Sort Me",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5})
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert data["sort_order"] == 5
    # Verify in DB
    db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
    assert db_group.sort_order == 5


def test_update_group_sort_order_not_found(db_session: Session, admin_user_client):
    """Test that updating sort order for non-existent group returns 404"""
    random_uuid = uuid.uuid4()
    response = admin_user_client.put(
        f"/groups/{random_uuid}/sort", json={"sort_order": 5}
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "Group not found" in response.json()["detail"]
def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client):
    """POST /groups/reorder applies every sort_order in one request.

    Fix: removed a stray debug ``print(groups)`` left in the test body.
    """
    # Create groups with sort_order 0..2
    groups = [
        MockGroup(
            id=uuid.uuid4(),
            name=f"Group {i}",
            sort_order=i,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        for i in range(3)
    ]
    db_session.add_all(groups)
    db_session.commit()
    # Bulk update sort orders (reverse order)
    bulk_data = {
        "groups": [
            {"group_id": str(groups[0].id), "sort_order": 2},
            {"group_id": str(groups[1].id), "sort_order": 1},
            {"group_id": str(groups[2].id), "sort_order": 0},
        ]
    }
    response = admin_user_client.post("/groups/reorder", json=bulk_data)
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert len(data) == 3
    # Create a dictionary for easy lookup of returned group data by ID
    returned_groups_map = {item["id"]: item for item in data}
    # Verify each group has its expected new sort_order
    assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2
    assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1
    assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0
    # Verify in DB (ordered query now returns the reversed sequence)
    db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all()
    assert db_groups[0].sort_order == 2
    assert db_groups[1].sort_order == 1
    assert db_groups[2].sort_order == 0
def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client):
    """A reorder request naming an unknown group fails atomically with 404."""
    # Create one group
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Valid Group",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add(test_group)
    db_session.commit()
    # Try to update with invalid group
    bulk_data = {
        "groups": [
            {"group_id": str(group_id), "sort_order": 2},
            {"group_id": str(uuid.uuid4()), "sort_order": 1},  # Invalid group
        ]
    }
    response = admin_user_client.post("/groups/reorder", json=bulk_data)
    assert response.status_code == status.HTTP_404_NOT_FOUND
    assert "not found" in response.json()["detail"]
    # Verify original sort order unchanged (no partial update happened)
    db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first()
    assert db_group.sort_order == 1

View File

@@ -0,0 +1,261 @@
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi import status
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
# Import the router we're testing
from app.routers.playlist import (
ProcessStatus,
ValidationProcessResponse,
ValidationResultResponse,
router,
validation_processes,
)
from app.utils.database import get_db
# Import mocks and fixtures
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import MockChannelDB
# --- Test Fixtures ---
@pytest.fixture
def mock_stream_manager():
    """Patch StreamManager inside the playlist router for the test's duration."""
    with patch("app.routers.playlist.StreamManager") as manager_mock:
        yield manager_mock
# --- Test Cases For Stream Validation ---
def test_start_stream_validation_success(
    db_session: Session, admin_user_client, mock_stream_manager
):
    """Test starting a stream validation process"""
    mock_instance = mock_stream_manager.return_value
    mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url"
    response = admin_user_client.post(
        "/playlist/validate-streams", json={"channel_id": "test-channel"}
    )
    assert response.status_code == status.HTTP_202_ACCEPTED
    data = response.json()
    assert "process_id" in data
    assert data["status"] == ProcessStatus.PENDING
    assert data["message"] == "Validation process started"
    # Verify process was added to tracking
    # NOTE(review): validation_processes is module-level shared state; this
    # entry is never removed, so it can leak into other tests — confirm a
    # cleanup fixture exists or add one.
    process_id = data["process_id"]
    assert process_id in validation_processes
    # In test environment, background tasks run synchronously so status may be COMPLETED
    assert validation_processes[process_id]["status"] in [
        ProcessStatus.PENDING,
        ProcessStatus.COMPLETED,
    ]
    assert validation_processes[process_id]["channel_id"] == "test-channel"
def test_get_validation_status_pending(db_session: Session, admin_user_client):
"""Test checking status of pending validation"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.PENDING,
"channel_id": "test-channel",
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.PENDING
assert data["working_streams"] is None
assert data["error"] is None
def test_get_validation_status_completed(db_session: Session, admin_user_client):
"""Test checking status of completed validation"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.COMPLETED,
"channel_id": "test-channel",
"result": {
"working_streams": [
{"channel_id": "test-channel", "stream_url": "http://valid.stream.url"}
]
},
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.COMPLETED
assert len(data["working_streams"]) == 1
assert data["working_streams"][0]["channel_id"] == "test-channel"
assert data["working_streams"][0]["stream_url"] == "http://valid.stream.url"
assert data["error"] is None
def test_get_validation_status_completed_with_error(
db_session: Session, admin_user_client
):
"""Test checking status of completed validation with error"""
process_id = str(uuid.uuid4())
validation_processes[process_id] = {
"status": ProcessStatus.COMPLETED,
"channel_id": "test-channel",
"error": "No working streams found for channel test-channel",
}
response = admin_user_client.get(f"/playlist/validate-streams/{process_id}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["process_id"] == process_id
assert data["status"] == ProcessStatus.COMPLETED
assert data["working_streams"] is None
assert data["error"] == "No working streams found for channel test-channel"
def test_get_validation_status_failed(db_session: Session, admin_user_client):
    """A failed validation reports FAILED status and its error, no streams."""
    pid = str(uuid.uuid4())
    validation_processes[pid] = {
        "channel_id": "test-channel",
        "status": ProcessStatus.FAILED,
        "error": "Validation error occurred",
    }
    resp = admin_user_client.get(f"/playlist/validate-streams/{pid}")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["process_id"] == pid
    assert body["status"] == ProcessStatus.FAILED
    assert body["working_streams"] is None
    assert body["error"] == "Validation error occurred"
def test_get_validation_status_not_found(db_session: Session, admin_user_client):
    """Looking up a process id that was never registered yields a 404."""
    unknown_id = str(uuid.uuid4())
    resp = admin_user_client.get(f"/playlist/validate-streams/{unknown_id}")
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "Process not found" in resp.json()["detail"]
def test_run_stream_validation_success(mock_stream_manager, db_session):
    """The background task records a working stream when validation succeeds."""
    pid = str(uuid.uuid4())
    validation_processes[pid] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }
    manager = mock_stream_manager.return_value
    manager.validate_and_select_stream.return_value = "http://valid.stream.url"
    from app.routers.playlist import run_stream_validation

    run_stream_validation(pid, "test-channel", db_session)
    entry = validation_processes[pid]
    assert entry["status"] == ProcessStatus.COMPLETED
    streams = entry["result"]["working_streams"]
    assert len(streams) == 1
    assert streams[0].channel_id == "test-channel"
    assert streams[0].stream_url == "http://valid.stream.url"
def test_run_stream_validation_failure(mock_stream_manager, db_session):
    """The background task completes with an error when no stream validates."""
    pid = str(uuid.uuid4())
    validation_processes[pid] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }
    # Simulate the manager finding no usable stream at all.
    mock_stream_manager.return_value.validate_and_select_stream.return_value = None
    from app.routers.playlist import run_stream_validation

    run_stream_validation(pid, "test-channel", db_session)
    entry = validation_processes[pid]
    assert entry["status"] == ProcessStatus.COMPLETED
    assert "error" in entry
    assert "No working streams found" in entry["error"]
def test_run_stream_validation_exception(mock_stream_manager, db_session):
    """An exception raised during validation marks the process as FAILED."""
    pid = str(uuid.uuid4())
    validation_processes[pid] = {
        "status": ProcessStatus.PENDING,
        "channel_id": "test-channel",
    }
    # Make the manager blow up mid-validation.
    manager = mock_stream_manager.return_value
    manager.validate_and_select_stream.side_effect = Exception("Test error")
    from app.routers.playlist import run_stream_validation

    run_stream_validation(pid, "test-channel", db_session)
    entry = validation_processes[pid]
    assert entry["status"] == ProcessStatus.FAILED
    assert "error" in entry
    assert "Test error" in entry["error"]
def test_start_stream_validation_no_channel_id(
    db_session: Session, admin_user_client, mock_stream_manager
):
    """Starting validation with no channel_id is accepted and tracked."""
    resp = admin_user_client.post("/playlist/validate-streams", json={})
    assert resp.status_code == status.HTTP_202_ACCEPTED
    body = resp.json()
    assert "process_id" in body
    assert body["status"] == ProcessStatus.PENDING
    # The process must be registered; the background task may already be done.
    pid = body["process_id"]
    assert pid in validation_processes
    entry = validation_processes[pid]
    assert entry["status"] in [
        ProcessStatus.PENDING,
        ProcessStatus.COMPLETED,
    ]
    assert entry["channel_id"] is None
    assert "not yet implemented" in entry.get("error", "")
def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session):
    """Running the background task without a channel_id yields the stub error."""
    pid = str(uuid.uuid4())
    validation_processes[pid] = {"status": ProcessStatus.PENDING}
    from app.routers.playlist import run_stream_validation

    run_stream_validation(pid, None, db_session)
    entry = validation_processes[pid]
    assert entry["status"] == ProcessStatus.COMPLETED
    assert "error" in entry
    assert "not yet implemented" in entry["error"]

View File

@@ -0,0 +1,241 @@
import uuid
from datetime import datetime, timezone
from fastapi import status
from sqlalchemy.orm import Session
from app.routers.priorities import router as priorities_router
# Import fixtures and mocks
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import (
MockChannelDB,
MockChannelURL,
MockGroup,
MockPriority,
create_mock_priorities_and_group,
)
# --- Test Cases For Priority Creation ---
def test_create_priority_success(db_session: Session, admin_user_client):
    """Admins can create a priority; it is echoed back and persisted."""
    payload = {"id": 100, "description": "Test Priority"}
    resp = admin_user_client.post("/priorities/", json=payload)
    assert resp.status_code == status.HTTP_201_CREATED
    body = resp.json()
    assert body["id"] == 100
    assert body["description"] == "Test Priority"
    # Confirm the row actually landed in the database.
    stored = db_session.get(MockPriority, 100)
    assert stored is not None
    assert stored.description == "Test Priority"
def test_create_priority_duplicate(db_session: Session, admin_user_client):
    """Posting the same priority id twice is rejected with 409 Conflict."""
    payload = {"id": 100, "description": "Original Priority"}
    first = admin_user_client.post("/priorities/", json=payload)
    assert first.status_code == status.HTTP_201_CREATED
    # Second create with an identical id must conflict.
    second = admin_user_client.post("/priorities/", json=payload)
    assert second.status_code == status.HTTP_409_CONFLICT
    assert "already exists" in second.json()["detail"]
def test_create_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Users without the admin role may not create priorities."""
    resp = non_admin_user_client.post(
        "/priorities/", json={"id": 100, "description": "Test Priority"}
    )
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
# --- Test Cases For List Priorities ---
def test_list_priorities_empty(db_session: Session, admin_user_client):
    """Listing priorities with an empty table returns an empty JSON array."""
    resp = admin_user_client.get("/priorities/")
    assert resp.status_code == status.HTTP_200_OK
    assert resp.json() == []
def test_list_priorities_with_data(db_session: Session, admin_user_client):
    """Listing returns every stored priority with id and description."""
    expected = [(100, "High"), (200, "Medium"), (300, "Low")]
    db_session.add_all(
        MockPriority(id=pid, description=desc) for pid, desc in expected
    )
    db_session.commit()
    resp = admin_user_client.get("/priorities/")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert len(body) == 3
    # Entries come back in the same order they were inserted.
    for item, (pid, desc) in zip(body, expected):
        assert item["id"] == pid
        assert item["description"] == desc
def test_list_priorities_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Users without the admin role may not list priorities."""
    resp = non_admin_user_client.get("/priorities/")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
# --- Test Cases For Get Priority ---
def test_get_priority_success(db_session: Session, admin_user_client):
    """A stored priority can be fetched by its id."""
    db_session.add(MockPriority(id=100, description="Test Priority"))
    db_session.commit()
    resp = admin_user_client.get("/priorities/100")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["id"] == 100
    assert body["description"] == "Test Priority"
def test_get_priority_not_found(db_session: Session, admin_user_client):
    """Fetching a priority id that does not exist yields a 404."""
    resp = admin_user_client.get("/priorities/999")
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "Priority not found" in resp.json()["detail"]
def test_get_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Users without the admin role may not fetch priorities."""
    resp = non_admin_user_client.get("/priorities/100")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
# --- Test Cases For Delete Priority ---
def test_delete_all_priorities_success(db_session, admin_user_client):
    """Bulk delete removes unused priorities but skips ones in use."""
    # Seed three priorities.
    for pid, desc in [(100, "High"), (200, "Medium"), (300, "Low")]:
        db_session.add(MockPriority(id=pid, description=desc))
    db_session.commit()
    # Upload a channel so that priority 100 becomes referenced.
    create_mock_priorities_and_group(db_session, [], "Test Group")
    admin_user_client.post(
        "/channels/bulk-upload",
        json=[
            {
                "group-title": "Test Group",
                "tvg_id": "test.tv",
                "name": "Test Channel",
                "urls": ["http://test.com"],
            }
        ],
    )
    resp = admin_user_client.delete("/priorities")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["deleted"] == 2  # Medium and Low priorities
    assert body["skipped"] == 1  # High priority is in use
    # Only the referenced priority should survive.
    remaining = db_session.query(MockPriority).all()
    assert len(remaining) == 1
    assert remaining[0].id == 100
def test_reset_priorities_forbidden_for_non_admin(db_session, non_admin_user_client):
    """Only admins may bulk-delete (reset) priorities."""
    resp = non_admin_user_client.delete("/priorities")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
def test_delete_priority_success(db_session: Session, admin_user_client):
    """Deleting an unused priority removes it from the database."""
    db_session.add(MockPriority(id=100, description="To Delete"))
    db_session.commit()
    resp = admin_user_client.delete("/priorities/100")
    assert resp.status_code == status.HTTP_204_NO_CONTENT
    # The row must no longer be present.
    assert db_session.get(MockPriority, 100) is None
def test_delete_priority_not_found(db_session: Session, admin_user_client):
    """Deleting a priority id that does not exist yields a 404."""
    resp = admin_user_client.delete("/priorities/999")
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "Priority not found" in resp.json()["detail"]
def test_delete_priority_in_use(db_session: Session, admin_user_client):
    """Deleting a priority referenced by a channel URL is rejected with 409.

    Builds the full chain priority -> group -> channel -> channel URL, since
    MockChannelURL requires both a channel_id and a priority_id.
    """
    # Create a priority and a channel URL using it
    priority = MockPriority(id=100, description="In Use")
    group_id = uuid.uuid4()
    test_group = MockGroup(
        id=group_id,
        name="Group With Channels",
        sort_order=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db_session.add_all([priority, test_group])
    db_session.commit()
    # Create a channel first
    channel = MockChannelDB(
        name="Test Channel",
        tvg_id="test.tv",
        tvg_name="Test",
        tvg_logo="test.png",
        group_id=group_id,
    )
    db_session.add(channel)
    # Commit so channel.id is populated before it is used below.
    db_session.commit()
    # Create URL associated with the channel and priority
    channel_url = MockChannelURL(
        url="http://test.com",
        priority_id=100,
        in_use=True,
        channel_id=channel.id,  # Add the channel_id
    )
    db_session.add(channel_url)
    db_session.commit()
    response = admin_user_client.delete("/priorities/100")
    assert response.status_code == status.HTTP_409_CONFLICT
    assert "in use by channel URLs" in response.json()["detail"]
    # Verify priority still exists
    db_priority = db_session.get(MockPriority, 100)
    assert db_priority is not None
def test_delete_priority_forbidden_for_non_admin(
    db_session: Session, non_admin_user_client
):
    """Users without the admin role may not delete priorities."""
    resp = non_admin_user_client.delete("/priorities/100")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]

View File

@@ -0,0 +1,287 @@
from datetime import datetime, timezone
from unittest.mock import Mock
from fastapi import HTTPException, Request, status
from app.iptv.scheduler import StreamScheduler
from app.routers.scheduler import get_scheduler
from app.routers.scheduler import router as scheduler_router
from app.utils.database import get_db
from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler
from tests.utils.auth_test_fixtures import (
admin_user_client,
db_session,
non_admin_user_client,
)
from tests.utils.db_mocks import mock_get_db
# Scheduler Health Check Tests
def test_scheduler_health_success(admin_user_client, monkeypatch):
    """Health check reports 'running' plus the next scheduled run time."""
    next_run = datetime.now(timezone.utc)
    # Fake APScheduler job exposing only the attribute the endpoint reads.
    mock_job = Mock()
    mock_job.next_run_time = next_run

    def fake_get_job(job_id):
        # Only the daily validation job exists on this scheduler.
        return mock_job if job_id == "daily_stream_validation" else None

    async def fake_scheduler_dep(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(
            request,
            running=True,
            get_job=Mock(side_effect=fake_get_job),
        )

    admin_user_client.app.include_router(scheduler_router)
    overrides = admin_user_client.app.dependency_overrides
    overrides[get_scheduler] = fake_scheduler_dep
    overrides[get_db] = mock_get_db
    resp = admin_user_client.get("/scheduler/health")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["status"] == "running"
    assert body["next_run"] == str(next_run)
def test_scheduler_health_stopped(admin_user_client, monkeypatch):
    """Health check reports 'stopped' and no next run when not running."""
    async def fake_scheduler_dep(request: Request) -> StreamScheduler:
        # Simulate a scheduler that is not running.
        return await mock_get_scheduler(request, running=False)

    admin_user_client.app.include_router(scheduler_router)
    overrides = admin_user_client.app.dependency_overrides
    overrides[get_scheduler] = fake_scheduler_dep
    overrides[get_db] = mock_get_db
    resp = admin_user_client.get("/scheduler/health")
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["status"] == "stopped"
    assert body["next_run"] is None
def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
    """Non-admin users are denied access to the scheduler health endpoint."""
    async def fake_scheduler_dep(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(request, running=False)

    non_admin_user_client.app.include_router(scheduler_router)
    overrides = non_admin_user_client.app.dependency_overrides
    overrides[get_scheduler] = fake_scheduler_dep
    overrides[get_db] = mock_get_db
    resp = non_admin_user_client.get("/scheduler/health")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
def test_scheduler_health_check_exception(admin_user_client, monkeypatch):
    """Failures inside the health check surface as a 500 with a clear detail."""
    async def fake_scheduler_dep(request: Request) -> StreamScheduler:
        # get_job raises, simulating an unhealthy scheduler backend.
        return await mock_get_scheduler(
            request, running=True, get_job=Mock(side_effect=Exception("Test exception"))
        )

    admin_user_client.app.include_router(scheduler_router)
    overrides = admin_user_client.app.dependency_overrides
    overrides[get_scheduler] = fake_scheduler_dep
    overrides[get_db] = mock_get_db
    resp = admin_user_client.get("/scheduler/health")
    assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert "Failed to check scheduler health" in resp.json()["detail"]
# Scheduler Trigger Tests
def test_trigger_validation_success(admin_user_client, monkeypatch):
    """
    Manually triggering stream validation as an admin returns 202 Accepted
    and actually invokes the scheduler's trigger method.

    Fix: the original created a ``MockScheduler`` (``custom_scheduler``) and
    stubbed its ``get_job``, but never used either — the dependency override
    below builds its own scheduler via ``mock_get_scheduler``. The dead
    locals are removed.
    """
    # Mutable flag flipped by the trigger mock so the call can be asserted.
    triggered_ref = {"value": False}

    async def custom_mock_get_scheduler(request: Request) -> StreamScheduler:
        scheduler = await mock_get_scheduler(
            request,
            running=True,
        )
        # Replace the real trigger method with a mock that records the call.
        scheduler.trigger_manual_validation = create_trigger_mock(
            triggered_ref=triggered_ref
        )
        return scheduler

    # Include the scheduler router and override its dependencies.
    admin_user_client.app.include_router(scheduler_router)
    admin_user_client.app.dependency_overrides[get_scheduler] = (
        custom_mock_get_scheduler
    )
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db
    # Trigger stream validation and verify the scheduler was invoked.
    response = admin_user_client.post("/scheduler/trigger")
    assert response.status_code == status.HTTP_202_ACCEPTED
    assert response.json()["message"] == "Stream validation triggered"
    assert triggered_ref["value"] is True
def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch):
    """Non-admin users may not manually trigger stream validation."""
    async def fake_scheduler_dep(request: Request) -> StreamScheduler:
        return await mock_get_scheduler(request, running=True)

    non_admin_user_client.app.include_router(scheduler_router)
    overrides = non_admin_user_client.app.dependency_overrides
    overrides[get_scheduler] = fake_scheduler_dep
    overrides[get_db] = mock_get_db
    resp = non_admin_user_client.post("/scheduler/trigger")
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "required roles" in resp.json()["detail"]
def test_scheduler_initialized_in_app_state(admin_user_client):
    """A freshly constructed StreamScheduler in app state passes health check."""
    # Store a real scheduler instance where the dependency will look it up.
    admin_user_client.app.state.scheduler = StreamScheduler()
    admin_user_client.app.include_router(scheduler_router)
    # Only the DB is mocked; the real get_scheduler dependency is exercised.
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db
    resp = admin_user_client.get("/scheduler/health")
    assert resp.status_code == status.HTTP_200_OK
def test_scheduler_not_initialized_in_app_state(admin_user_client):
    """A StreamScheduler missing its internal scheduler attribute yields a 500."""
    scheduler = StreamScheduler()
    # Remove the wrapped scheduler to simulate failed/absent initialization.
    del scheduler.scheduler
    admin_user_client.app.state.scheduler = scheduler
    admin_user_client.app.include_router(scheduler_router)
    # Only the DB is mocked; the real get_scheduler dependency is exercised.
    admin_user_client.app.dependency_overrides[get_db] = mock_get_db
    resp = admin_user_client.get("/scheduler/health")
    assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert "Scheduler not initialized" in resp.json()["detail"]

View File

@@ -1,18 +1,23 @@
from unittest.mock import patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.main import app, lifespan from app.main import app, lifespan
from unittest.mock import patch, MagicMock
@pytest.fixture @pytest.fixture
def client(): def client():
"""Test client for FastAPI app""" """Test client for FastAPI app"""
return TestClient(app) return TestClient(app)
def test_root_endpoint(client): def test_root_endpoint(client):
"""Test root endpoint returns expected message""" """Test root endpoint returns expected message"""
response = client.get("/") response = client.get("/")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"message": "IPTV Updater API"} assert response.json() == {"message": "IPTV Manager API"}
def test_openapi_schema_generation(client): def test_openapi_schema_generation(client):
"""Test OpenAPI schema is properly generated""" """Test OpenAPI schema is properly generated"""
@@ -23,7 +28,7 @@ def test_openapi_schema_generation(client):
assert schema["openapi"] == "3.1.0" assert schema["openapi"] == "3.1.0"
assert "securitySchemes" in schema["components"] assert "securitySchemes" in schema["components"]
assert "Bearer" in schema["components"]["securitySchemes"] assert "Bearer" in schema["components"]["securitySchemes"]
# Test empty components initialization # Test empty components initialization
with patch("app.main.get_openapi", return_value={"info": {}}): with patch("app.main.get_openapi", return_value={"info": {}}):
# Clear cached schema # Clear cached schema
@@ -35,26 +40,28 @@ def test_openapi_schema_generation(client):
assert "components" in schema assert "components" in schema
assert "schemas" in schema["components"] assert "schemas" in schema["components"]
def test_openapi_schema_caching(mocker): def test_openapi_schema_caching(mocker):
"""Test OpenAPI schema caching behavior""" """Test OpenAPI schema caching behavior"""
# Clear any existing schema # Clear any existing schema
app.openapi_schema = None app.openapi_schema = None
# Mock get_openapi to return test schema # Mock get_openapi to return test schema
mock_schema = {"test": "schema"} mock_schema = {"test": "schema"}
mocker.patch("app.main.get_openapi", return_value=mock_schema) mocker.patch("app.main.get_openapi", return_value=mock_schema)
# First call - should call get_openapi # First call - should call get_openapi
schema = app.openapi() schema = app.openapi()
assert schema == mock_schema assert schema == mock_schema
assert app.openapi_schema == mock_schema assert app.openapi_schema == mock_schema
# Second call - should return cached schema # Second call - should return cached schema
with patch("app.main.get_openapi") as mock_get_openapi: with patch("app.main.get_openapi") as mock_get_openapi:
schema = app.openapi() schema = app.openapi()
assert schema == mock_schema assert schema == mock_schema
mock_get_openapi.assert_not_called() mock_get_openapi.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_lifespan_init_db(mocker): async def test_lifespan_init_db(mocker):
"""Test lifespan manager initializes database""" """Test lifespan manager initializes database"""
@@ -63,6 +70,7 @@ async def test_lifespan_init_db(mocker):
pass # Just enter/exit context pass # Just enter/exit context
mock_init_db.assert_called_once() mock_init_db.assert_called_once()
def test_router_inclusion(): def test_router_inclusion():
"""Test all routers are properly included""" """Test all routers are properly included"""
route_paths = {route.path for route in app.routes} route_paths = {route.path for route in app.routes}
@@ -70,4 +78,4 @@ def test_router_inclusion():
assert any(path.startswith("/auth") for path in route_paths) assert any(path.startswith("/auth") for path in route_paths)
assert any(path.startswith("/channels") for path in route_paths) assert any(path.startswith("/channels") for path in route_paths)
assert any(path.startswith("/playlist") for path in route_paths) assert any(path.startswith("/playlist") for path in route_paths)
assert any(path.startswith("/priorities") for path in route_paths) assert any(path.startswith("/priorities") for path in route_paths)

View File

@@ -0,0 +1,82 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.auth.dependencies import get_current_user
from app.models.auth import CognitoUser
from app.routers.channels import router as channels_router
from app.routers.groups import router as groups_router
from app.routers.playlist import router as playlist_router
from app.routers.priorities import router as priorities_router
from app.utils.database import get_db
from tests.utils.db_mocks import (
MockBase,
MockChannelDB,
MockChannelURL,
MockPriority,
engine_mock,
mock_get_db,
)
from tests.utils.db_mocks import session_mock as TestingSessionLocal
def mock_get_current_user_admin():
    """Dependency override: a confirmed, enabled Cognito user with the admin role."""
    return CognitoUser(
        username="testadmin",
        email="testadmin@example.com",
        user_status="CONFIRMED",
        enabled=True,
        roles=["admin"],
    )
def mock_get_current_user_non_admin():
    """Dependency override: a confirmed, enabled user WITHOUT the admin role."""
    return CognitoUser(
        username="testuser",
        email="testuser@example.com",
        user_status="CONFIRMED",
        enabled=True,
        roles=["user"],  # any role other than admin works here
    )
@pytest.fixture(scope="function")
def db_session():
    """Yield a session backed by a freshly created mock schema per test."""
    # Recreate all mock tables so every test starts from a clean slate.
    MockBase.metadata.create_all(bind=engine_mock)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()
    # Tear the schema down once the test is done to keep tests isolated.
    MockBase.metadata.drop_all(bind=engine_mock)
@pytest.fixture(scope="function")
def admin_user_client(db_session: Session):
    """Yields a TestClient configured with an admin user."""
    test_app = FastAPI()
    # Mount every router under test, then swap in mocked DB and auth deps.
    for router in (channels_router, priorities_router, playlist_router, groups_router):
        test_app.include_router(router)
    test_app.dependency_overrides[get_db] = mock_get_db
    test_app.dependency_overrides[get_current_user] = mock_get_current_user_admin
    with TestClient(test_app) as test_client:
        yield test_client
@pytest.fixture(scope="function")
def non_admin_user_client(db_session: Session):
    """Yields a TestClient configured with a non-admin user."""
    test_app = FastAPI()
    # Mount every router under test, then swap in mocked DB and auth deps.
    for router in (channels_router, priorities_router, playlist_router, groups_router):
        test_app.include_router(router)
    test_app.dependency_overrides[get_db] = mock_get_db
    test_app.dependency_overrides[get_current_user] = mock_get_current_user_non_admin
    with TestClient(test_app) as test_client:
        yield test_client

View File

@@ -1,29 +1,27 @@
import os
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest.mock import patch, MagicMock from unittest.mock import MagicMock, patch
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
import pytest import pytest
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
UniqueConstraint,
create_engine,
)
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.pool import StaticPool
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
# Create a mock-specific Base class for testing # Create a mock-specific Base class for testing
MockBase = declarative_base() MockBase = declarative_base()
class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite."""
impl = TEXT
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
return str(value)
def process_result_value(self, value, dialect):
if value is None:
return value
return uuid.UUID(value)
# Model classes for testing - prefix with Mock to avoid pytest collection # Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase): class MockPriority(MockBase):
@@ -31,40 +29,96 @@ class MockPriority(MockBase):
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
description = Column(String, nullable=False) description = Column(String, nullable=False)
class MockGroup(MockBase):
__tablename__ = "groups"
id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=False, unique=True)
sort_order = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
channels = relationship("MockChannelDB", back_populates="group")
class MockChannelDB(MockBase): class MockChannelDB(MockBase):
__tablename__ = "channels" __tablename__ = "channels"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
tvg_id = Column(String, nullable=False) tvg_id = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
group_title = Column(String, nullable=False) group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
tvg_name = Column(String) tvg_name = Column(String)
__table_args__ = ( __table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),)
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
)
tvg_logo = Column(String) tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
group = relationship("MockGroup", back_populates="channels")
urls = relationship(
"MockChannelURL", back_populates="channel", cascade="all, delete-orphan"
)
class MockChannelURL(MockBase): class MockChannelURL(MockBase):
__tablename__ = "channels_urls" __tablename__ = "channels_urls"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4)
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) channel_id = Column(
UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
)
url = Column(String, nullable=False) url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False) in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
channel = relationship("MockChannelDB", back_populates="urls")
def create_mock_priorities_and_group(db_session, priorities, group_name):
    """Seed the test database with priority rows and a single group.

    Args:
        db_session: SQLAlchemy session bound to the in-memory test engine.
        priorities: Iterable of (id, description) pairs to insert as priorities.
        group_name: Name to give the newly created group.

    Returns:
        UUID: Primary key of the group that was created.
    """
    rows = [MockPriority(id=pid, description=desc) for pid, desc in priorities]
    new_group = MockGroup(name=group_name)
    rows.append(new_group)

    # A single add_all/commit keeps the fixture setup to one round trip.
    db_session.add_all(rows)
    db_session.commit()
    return new_group.id
# Create test engine # Create test engine
engine_mock = create_engine( engine_mock = create_engine(
"sqlite:///:memory:", "sqlite:///:memory:",
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
poolclass=StaticPool poolclass=StaticPool,
) )
# Create test session # Create test session
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock) session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
# Mock the actual database functions # Mock the actual database functions
def mock_get_db(): def mock_get_db():
db = session_mock() db = session_mock()
@@ -73,6 +127,7 @@ def mock_get_db():
finally: finally:
db.close() db.close()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_env(monkeypatch): def mock_env(monkeypatch):
"""Fixture for mocking environment variables""" """Fixture for mocking environment variables"""
@@ -82,14 +137,13 @@ def mock_env(monkeypatch):
monkeypatch.setenv("DB_HOST", "localhost") monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_NAME", "testdb") monkeypatch.setenv("DB_NAME", "testdb")
monkeypatch.setenv("AWS_REGION", "us-east-1") monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture @pytest.fixture
def mock_ssm(): def mock_ssm():
"""Fixture for mocking boto3 SSM client""" """Fixture for mocking boto3 SSM client"""
with patch('boto3.client') as mock_client: with patch("boto3.client") as mock_client:
mock_ssm = MagicMock() mock_ssm = MagicMock()
mock_client.return_value = mock_ssm mock_client.return_value = mock_ssm
mock_ssm.get_parameter.return_value = { mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
'Parameter': {'Value': 'mocked_value'} yield mock_ssm
}
yield mock_ssm

View File

@@ -0,0 +1,309 @@
import os
from unittest.mock import MagicMock, Mock, mock_open, patch
import pytest
import requests
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
from app.utils.check_streams import StreamValidator, main
@pytest.fixture
def validator():
    """Fixture: a StreamValidator with a 1-second timeout so network tests fail fast."""
    return StreamValidator(timeout=1)
def test_validator_init():
    """StreamValidator should honour both default and custom settings."""
    # Defaults: 10-second timeout and a Mozilla-style User-Agent on the session.
    default_validator = StreamValidator()
    assert default_validator.timeout == 10
    assert "Mozilla" in default_validator.session.headers["User-Agent"]

    # Custom timeout and user agent must be applied verbatim.
    agent = "CustomAgent/1.0"
    custom_validator = StreamValidator(timeout=5, user_agent=agent)
    assert custom_validator.timeout == 5
    assert custom_validator.session.headers["User-Agent"] == agent
def test_is_valid_content_type(validator):
    """_is_valid_content_type accepts streaming media types and rejects the rest."""
    accepted = [
        "video/mp4",
        "video/mp2t",
        "application/vnd.apple.mpegurl",
        "application/dash+xml",
        "video/webm",
        "application/octet-stream",
        "application/x-mpegURL",
        "video/mp4; charset=utf-8",  # parameters after the media type are tolerated
    ]
    rejected = [
        "text/html",
        "application/json",
        "image/jpeg",
        "",
    ]

    assert all(validator._is_valid_content_type(ct) for ct in accepted)
    assert not any(validator._is_valid_content_type(ct) for ct in rejected)
    # A missing Content-Type header (None) must also be treated as invalid.
    assert not validator._is_valid_content_type(None)
@pytest.mark.parametrize(
    "status_code,content_type,should_succeed",
    [
        (200, "video/mp4", True),
        (206, "video/mp4", True),  # Partial content counts as success
        (404, "video/mp4", False),
        (500, "video/mp4", False),
        (200, "text/html", False),
        (200, "application/json", False),
    ],
)
def test_validate_stream_response_handling(status_code, content_type, should_succeed):
    """The validate_stream verdict tracks both HTTP status and Content-Type."""
    fake_response = MagicMock()
    fake_response.status_code = status_code
    fake_response.headers = {"Content-Type": content_type}
    fake_response.iter_content.return_value = iter([b"some content"])

    fake_session = MagicMock()
    # validate_stream uses the response as a context manager, hence __enter__.
    fake_session.get.return_value.__enter__.return_value = fake_response

    with patch("requests.Session", return_value=fake_session):
        result, _message = StreamValidator().validate_stream("http://example.com/stream")

    assert result == should_succeed
    fake_session.get.assert_called_once()
def test_validate_stream_connection_error():
    """A ConnectionError from the session is reported as a connection failure."""
    failing_session = MagicMock()
    failing_session.get.side_effect = ConnectionError("Connection failed")

    with patch("requests.Session", return_value=failing_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "Connection Error" in message
def test_validate_stream_timeout():
    """A request Timeout is surfaced with 'timeout' in the failure message."""
    stalling_session = MagicMock()
    stalling_session.get.side_effect = Timeout("Request timed out")

    with patch("requests.Session", return_value=stalling_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "timeout" in message.lower()
def test_validate_stream_http_error():
    """An HTTPError raised by the session maps to an 'HTTP Error' message."""
    erroring_session = MagicMock()
    erroring_session.get.side_effect = HTTPError("HTTP Error occurred")

    with patch("requests.Session", return_value=erroring_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "HTTP Error" in message
def test_validate_stream_request_exception():
    """A generic RequestException maps to a 'Request Exception' message."""
    broken_session = MagicMock()
    broken_session.get.side_effect = RequestException("Request failed")

    with patch("requests.Session", return_value=broken_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "Request Exception" in message
def test_validate_stream_content_read_error():
    """A failure while reading the body (not the headers) is reported distinctly."""
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.headers = {"Content-Type": "video/mp4"}
    # The GET itself succeeds; iterating the content is what blows up.
    fake_response.iter_content.side_effect = ConnectionError("Read failed")

    fake_session = MagicMock()
    fake_session.get.return_value.__enter__.return_value = fake_response

    with patch("requests.Session", return_value=fake_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "Connection failed during content read" in message
def test_validate_stream_general_exception():
    """An unexpected (non-requests) exception falls into the generic error path."""
    exploding_session = MagicMock()
    exploding_session.get.side_effect = Exception("Unexpected error")

    with patch("requests.Session", return_value=exploding_session):
        ok, message = StreamValidator().validate_stream("http://example.com/stream")

    assert not ok
    assert "Validation error" in message
def test_parse_playlist(validator, tmp_path):
    """parse_playlist extracts every URL, with or without an #EXTINF header."""
    m3u_text = """
#EXTM3U
#EXTINF:-1,Channel 1
http://example.com/stream1
#EXTINF:-1,Channel 2
http://example.com/stream2
http://example.com/stream3
"""
    m3u_file = tmp_path / "test_playlist.m3u"
    m3u_file.write_text(m3u_text)

    parsed = validator.parse_playlist(str(m3u_file))

    # Directive lines are skipped; URLs come back in file order.
    assert len(parsed) == 3
    assert parsed == [
        "http://example.com/stream1",
        "http://example.com/stream2",
        "http://example.com/stream3",
    ]
def test_parse_playlist_error(validator):
    """Test playlist parsing with non-existent file"""
    # NOTE(review): Exception is broad — if parse_playlist raises a specific
    # type (e.g. FileNotFoundError), tightening this would catch regressions.
    with pytest.raises(Exception):
        validator.parse_playlist("nonexistent_file.m3u")
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_urls(mock_validator_class, mock_logging, tmp_path, capsys):
"""Test main function with direct URLs"""
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
mock_validator.validate_stream.return_value = (True, "Stream is valid")
# Setup test arguments
test_args = ["script", "http://example.com/stream1", "http://example.com/stream2"]
with patch("sys.argv", test_args):
main()
# Verify validator was called correctly
assert mock_validator.validate_stream.call_count == 2
mock_validator.validate_stream.assert_any_call("http://example.com/stream1")
mock_validator.validate_stream.assert_any_call("http://example.com/stream2")
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_playlist(mock_validator_class, mock_logging, tmp_path):
"""Test main function with a playlist file"""
# Create test playlist
playlist_content = "http://example.com/stream1\nhttp://example.com/stream2"
playlist_file = tmp_path / "test.m3u"
playlist_file.write_text(playlist_content)
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
mock_validator.parse_playlist.return_value = [
"http://example.com/stream1",
"http://example.com/stream2",
]
mock_validator.validate_stream.return_value = (True, "Stream is valid")
# Setup test arguments
test_args = ["script", str(playlist_file)]
with patch("sys.argv", test_args):
main()
# Verify validator was called correctly
mock_validator.parse_playlist.assert_called_once_with(str(playlist_file))
assert mock_validator.validate_stream.call_count == 2
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
def test_main_with_dead_streams(mock_validator_class, mock_logging, tmp_path):
"""Test main function handling dead streams"""
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
mock_validator.validate_stream.return_value = (False, "Stream is dead")
# Setup test arguments
test_args = ["script", "http://example.com/dead1", "http://example.com/dead2"]
# Mock file operations
mock_file = mock_open()
with patch("sys.argv", test_args), patch("builtins.open", mock_file):
main()
# Verify dead streams were written to file
mock_file().write.assert_called_once_with(
"http://example.com/dead1\nhttp://example.com/dead2"
)
@patch("app.utils.check_streams.logging")
@patch("app.utils.check_streams.StreamValidator")
@patch("os.path.isfile")
def test_main_with_playlist_error(
mock_isfile, mock_validator_class, mock_logging, tmp_path
):
"""Test main function handling playlist parsing errors"""
# Setup mock validator
mock_validator = Mock()
mock_validator_class.return_value = mock_validator
# Configure mock validator behavior
error_msg = "Failed to parse playlist"
mock_validator.parse_playlist.side_effect = [
Exception(error_msg), # First call fails
["http://example.com/stream1"], # Second call succeeds
]
mock_validator.validate_stream.return_value = (True, "Stream is valid")
# Configure isfile mock to return True for our test files
mock_isfile.side_effect = lambda x: x in ["/invalid.m3u", "/valid.m3u"]
# Setup test arguments
test_args = ["script", "/invalid.m3u", "/valid.m3u"]
with patch("sys.argv", test_args):
main()
# Verify error was logged correctly
mock_logging.error.assert_called_with(
"Failed to process file /invalid.m3u: Failed to parse playlist"
)
# Verify processing continued with valid playlist
mock_validator.parse_playlist.assert_called_with("/valid.m3u")
assert (
mock_validator.validate_stream.call_count == 1
) # Called for the URL from valid playlist

View File

@@ -1,43 +1,45 @@
import os import os
import pytest import pytest
from unittest.mock import patch
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.utils.database import get_db_credentials, get_db
from tests.utils.db_mocks import ( from app.utils.database import get_db, get_db_credentials
session_mock, from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
mock_get_db,
mock_env,
mock_ssm
)
def test_get_db_credentials_env(mock_env): def test_get_db_credentials_env(mock_env):
"""Test getting DB credentials from environment variables""" """Test getting DB credentials from environment variables"""
conn_str = get_db_credentials() conn_str = get_db_credentials()
assert conn_str == "postgresql://testuser:testpass@localhost/testdb" assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
def test_get_db_credentials_ssm(mock_ssm): def test_get_db_credentials_ssm(mock_ssm):
"""Test getting DB credentials from SSM""" """Test getting DB credentials from SSM"""
os.environ.pop("MOCK_AUTH", None) os.environ.pop("MOCK_AUTH", None)
conn_str = get_db_credentials() conn_str = get_db_credentials()
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
assert expected_conn in conn_str
mock_ssm.get_parameter.assert_called() mock_ssm.get_parameter.assert_called()
def test_get_db_credentials_ssm_exception(mock_ssm): def test_get_db_credentials_ssm_exception(mock_ssm):
"""Test SSM credential fetching failure raises RuntimeError""" """Test SSM credential fetching failure raises RuntimeError"""
os.environ.pop("MOCK_AUTH", None) os.environ.pop("MOCK_AUTH", None)
mock_ssm.get_parameter.side_effect = Exception("SSM timeout") mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
get_db_credentials() get_db_credentials()
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value) assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
def test_session_creation(): def test_session_creation():
"""Test database session creation""" """Test database session creation"""
session = session_mock() session = session_mock()
assert isinstance(session, Session) assert isinstance(session, Session)
session.close() session.close()
def test_get_db_generator(): def test_get_db_generator():
"""Test get_db dependency generator""" """Test get_db dependency generator"""
db_gen = get_db() db_gen = get_db()
@@ -48,18 +50,20 @@ def test_get_db_generator():
except StopIteration: except StopIteration:
pass pass
def test_init_db(mocker, mock_env): def test_init_db(mocker, mock_env):
"""Test database initialization creates tables""" """Test database initialization creates tables"""
mock_create_all = mocker.patch('app.models.Base.metadata.create_all') mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
# Mock get_db_credentials to return SQLite test connection # Mock get_db_credentials to return SQLite test connection
mocker.patch( mocker.patch(
'app.utils.database.get_db_credentials', "app.utils.database.get_db_credentials",
return_value="sqlite:///:memory:" return_value="sqlite:///:memory:",
) )
from app.utils.database import init_db, engine from app.utils.database import engine, init_db
init_db() init_db()
# Verify create_all was called with the engine # Verify create_all was called with the engine
mock_create_all.assert_called_once_with(bind=engine) mock_create_all.assert_called_once_with(bind=engine)