From a42d4c30a691ee23922fbcf18094dc9b256f833d Mon Sep 17 00:00:00 2001 From: Stefano Date: Tue, 17 Jun 2025 17:12:39 -0500 Subject: [PATCH] Started (incomplete) implementation of stream verification scheduler and endpoints --- .env.example | 5 + .vscode/settings.json | 1 + app/auth/dependencies.py | 4 +- app/iptv/scheduler.py | 110 ++++++++++++ app/iptv/stream_manager.py | 151 +++++++++++++++++ app/main.py | 13 +- app/routers/playlist.py | 153 ++++++++++++++++- app/routers/scheduler.py | 57 +++++++ app/utils/database.py | 6 + requirements.txt | 3 +- tests/auth/test_dependencies.py | 4 +- tests/routers/mocks.py | 43 +++++ tests/routers/test_playlist.py | 262 ++++++++++++++++++++++++++--- tests/routers/test_scheduler.py | 287 ++++++++++++++++++++++++++++++++ 14 files changed, 1066 insertions(+), 33 deletions(-) create mode 100644 app/iptv/scheduler.py create mode 100644 app/iptv/stream_manager.py create mode 100644 app/routers/scheduler.py create mode 100644 tests/routers/mocks.py create mode 100644 tests/routers/test_scheduler.py diff --git a/.env.example b/.env.example index b012966..0545b26 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,9 @@ +# Environment variables +# Scheduler configuration +STREAM_VALIDATION_SCHEDULE=0 3 * * * # Daily at 3 AM (cron syntax) +STREAM_VALIDATION_BATCH_SIZE=10 # Number of channels per batch (0=all) + # For use with Docker Compose to run application locally MOCK_AUTH=true/false DB_USER=MyDBUser diff --git a/.vscode/settings.json b/.vscode/settings.json index 69daef7..fec9e8c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,6 +9,7 @@ "addopts", "adminpassword", "altinstall", + "apscheduler", "asyncio", "autoflush", "autoupdate", diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py index debd019..d1ec3c5 100644 --- a/app/auth/dependencies.py +++ b/app/auth/dependencies.py @@ -32,7 +32,9 @@ def require_roles(*required_roles: str) -> Callable: def decorator(endpoint: Callable) -> Callable: @wraps(endpoint) - def wrapper(*args, user: CognitoUser = Depends(get_current_user), **kwargs): + async def wrapper( + *args, user: CognitoUser = Depends(get_current_user), **kwargs + ): user_roles = set(user.roles or []) needed_roles = set(required_roles) if not needed_roles.issubset(user_roles): diff --git a/app/iptv/scheduler.py b/app/iptv/scheduler.py new file mode 100644 index 0000000..da902cb --- /dev/null +++ b/app/iptv/scheduler.py @@ -0,0 +1,110 @@ +import logging +import os +from typing import Optional + +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.triggers.cron import CronTrigger +from fastapi import FastAPI +from sqlalchemy.orm import Session + +from app.iptv.stream_manager import StreamManager +from app.models.db import ChannelDB +from app.utils.database import get_db_session + +logger = logging.getLogger(__name__) + + +class StreamScheduler: + """Scheduler service for periodic stream validation tasks.""" + + def __init__(self, app: Optional[FastAPI] = None): + """ + Initialize the scheduler with optional FastAPI app integration. + + Args: + app: Optional FastAPI app instance for lifecycle integration + """ + self.scheduler = BackgroundScheduler() + self.app = app + self.batch_size = int(os.getenv("STREAM_VALIDATION_BATCH_SIZE", "10")) + self.schedule_time = os.getenv( + "STREAM_VALIDATION_SCHEDULE", "0 3 * * *" + ) # Default 3 AM daily + logger.info(f"Scheduler initialized with app: {app is not None}") + + def validate_streams_batch(self, db_session: Optional[Session] = None) -> None: + """ + Validate streams and update their status. + When batch_size=0, validates all channels. + + Args: + db_session: Optional SQLAlchemy session + """ + db = db_session if db_session else get_db_session() + try: + manager = StreamManager(db) + + # Get channels to validate + query = db.query(ChannelDB) + if self.batch_size > 0: + query = query.limit(self.batch_size) + channels = query.all() + + for channel in channels: + try: + logger.info(f"Validating streams for channel {channel.id}") + manager.validate_and_select_stream(str(channel.id)) + except Exception as e: + logger.error(f"Error validating channel {channel.id}: {str(e)}") + continue + + logger.info(f"Completed stream validation of {len(channels)} channels") + finally: + if db_session is None: + db.close() + + def start(self) -> None: + """Start the scheduler and add jobs.""" + if not self.scheduler.running: + # Add the scheduled job + self.scheduler.add_job( + self.validate_streams_batch, + trigger=CronTrigger.from_crontab(self.schedule_time), + id="daily_stream_validation", + ) + + # Start the scheduler + self.scheduler.start() + logger.info( + f"Stream scheduler started with daily validation job. " + f"Running: {self.scheduler.running}" + ) + + # Register shutdown handler if FastAPI app is provided + if self.app: + logger.info( + f"Registering scheduler with FastAPI " + f"app: {hasattr(self.app, 'state')}" + ) + + @self.app.on_event("shutdown") + def shutdown_scheduler(): + self.shutdown() + + def shutdown(self) -> None: + """Shutdown the scheduler gracefully.""" + if self.scheduler.running: + self.scheduler.shutdown() + logger.info("Stream scheduler stopped") + + def trigger_manual_validation(self) -> None: + """Trigger manual validation of streams.""" + logger.info("Manually triggering stream validation") + self.validate_streams_batch() + + +def init_scheduler(app: FastAPI) -> StreamScheduler: + """Initialize and start the scheduler with FastAPI integration.""" + scheduler = StreamScheduler(app) + scheduler.start() + return scheduler diff --git a/app/iptv/stream_manager.py b/app/iptv/stream_manager.py new file mode 100644 index 0000000..bcdc548 --- /dev/null +++ b/app/iptv/stream_manager.py @@ -0,0 +1,151 @@ +import logging +import random +from typing import Optional + +from sqlalchemy.orm import Session + +from app.models.db import ChannelURL +from app.utils.check_streams import StreamValidator +from app.utils.database import get_db_session + +logger = logging.getLogger(__name__) + + +class StreamManager: + """Service for managing and validating channel streams.""" + + def __init__(self, db_session: Optional[Session] = None): + """ + Initialize StreamManager with optional database session. + + Args: + db_session: Optional SQLAlchemy session. If None, will create a new one. + """ + self.db = db_session if db_session else get_db_session() + self.validator = StreamValidator() + + def get_streams_for_channel(self, channel_id: str) -> list[ChannelURL]: + """ + Get all streams for a channel ordered by priority (lowest first), + with same-priority streams randomized. + + Args: + channel_id: UUID of the channel to get streams for + + Returns: + List of ChannelURL objects ordered by priority + """ + try: + # Get all streams for channel ordered by priority + streams = ( + self.db.query(ChannelURL) + .filter(ChannelURL.channel_id == channel_id) + .order_by(ChannelURL.priority_id) + .all() + ) + + # Group streams by priority and randomize same-priority streams + grouped = {} + for stream in streams: + if stream.priority_id not in grouped: + grouped[stream.priority_id] = [] + grouped[stream.priority_id].append(stream) + + # Randomize same-priority streams and flatten + randomized_streams = [] + for priority in sorted(grouped.keys()): + random.shuffle(grouped[priority]) + randomized_streams.extend(grouped[priority]) + + return randomized_streams + + except Exception as e: + logger.error(f"Error getting streams for channel {channel_id}: {str(e)}") + raise + + def validate_and_select_stream(self, channel_id: str) -> Optional[str]: + """ + Find and validate a working stream for the given channel. + + Args: + channel_id: UUID of the channel to find a stream for + + Returns: + URL of the first working stream found, or None if none found + """ + try: + streams = self.get_streams_for_channel(channel_id) + if not streams: + logger.warning(f"No streams found for channel {channel_id}") + return None + + working_stream = None + + for stream in streams: + logger.info(f"Validating stream {stream.url} for channel {channel_id}") + is_valid, _ = self.validator.validate_stream(stream.url) + + if is_valid: + working_stream = stream + break + + if working_stream: + self._update_stream_status(working_stream, streams) + return working_stream.url + else: + logger.warning(f"No valid streams found for channel {channel_id}") + return None + + except Exception as e: + logger.error(f"Error validating streams for channel {channel_id}: {str(e)}") + raise + + def _update_stream_status( + self, working_stream: ChannelURL, all_streams: list[ChannelURL] + ) -> None: + """ + Update in_use status for streams (True for working stream, False for others). + + Args: + working_stream: The stream that was validated as working + all_streams: All streams for the channel + """ + try: + for stream in all_streams: + stream.in_use = stream.id == working_stream.id + + self.db.commit() + logger.info( + f"Updated stream status - set in_use=True for {working_stream.url}" + ) + + except Exception as e: + self.db.rollback() + logger.error(f"Error updating stream status: {str(e)}") + raise + + def __del__(self): + """Close database session when StreamManager is destroyed.""" + if hasattr(self, "db"): + self.db.close() + + +def get_working_stream( + channel_id: str, db_session: Optional[Session] = None +) -> Optional[str]: + """ + Convenience function to get a working stream for a channel. + + Args: + channel_id: UUID of the channel to get a stream for + db_session: Optional SQLAlchemy session + + Returns: + URL of the first working stream found, or None if none found + """ + manager = StreamManager(db_session) + try: + return manager.validate_and_select_stream(channel_id) + finally: + if db_session is None: # Only close if we created the session + manager.__del__() diff --git a/app/main.py b/app/main.py index b03957e..6564bf4 100644 --- a/app/main.py +++ b/app/main.py @@ -2,7 +2,8 @@ from fastapi import FastAPI from fastapi.concurrency import asynccontextmanager from fastapi.openapi.utils import get_openapi -from app.routers import auth, channels, groups, playlist, priorities +from app.iptv.scheduler import StreamScheduler +from app.routers import auth, channels, groups, playlist, priorities, scheduler from app.utils.database import init_db @@ -10,8 +11,17 @@ from app.utils.database import init_db async def lifespan(app: FastAPI): # Initialize database tables on startup init_db() + + # Initialize and start the stream scheduler + scheduler = StreamScheduler(app) + app.state.scheduler = scheduler # Store scheduler in app state + scheduler.start() + yield + # Shutdown scheduler on app shutdown + scheduler.shutdown() + app = FastAPI( lifespan=lifespan, @@ -69,3 +79,4 @@ app.include_router(channels.router) app.include_router(playlist.router) app.include_router(priorities.router) app.include_router(groups.router) +app.include_router(scheduler.router) diff --git a/app/routers/playlist.py b/app/routers/playlist.py index c38e80a..29312c6 100644 --- a/app/routers/playlist.py +++ b/app/routers/playlist.py @@ -1,15 +1,156 @@ -from fastapi import APIRouter, Depends +import logging +from enum import Enum +from typing import Optional +from uuid import uuid4 + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status +from pydantic import BaseModel +from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user +from app.iptv.stream_manager import StreamManager from app.models.auth import CognitoUser +from app.utils.database import get_db_session router = APIRouter(prefix="/playlist", tags=["playlist"]) +logger = logging.getLogger(__name__) + +# In-memory store for validation processes +validation_processes: dict[str, dict] = {} -@router.get("/protected", summary="Protected endpoint for authenticated users") -async def protected_route(user: CognitoUser = Depends(get_current_user)): +class ProcessStatus(str, Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + + +class StreamValidationRequest(BaseModel): + """Request model for stream validation endpoint""" + + channel_id: Optional[str] = None + + +class ValidatedStream(BaseModel): + """Model for a validated working stream""" + + channel_id: str + stream_url: str + + +class ValidationProcessResponse(BaseModel): + """Response model for validation process initiation""" + + process_id: str + status: ProcessStatus + message: str + + +class ValidationResultResponse(BaseModel): + """Response model for validation results""" + + process_id: str + status: ProcessStatus + working_streams: Optional[list[ValidatedStream]] = None + error: Optional[str] = None + + +def run_stream_validation(process_id: str, channel_id: Optional[str], db: Session): + """Background task to validate streams""" + try: + validation_processes[process_id]["status"] = ProcessStatus.IN_PROGRESS + manager = StreamManager(db) + + if channel_id: + stream_url = manager.validate_and_select_stream(channel_id) + if stream_url: + validation_processes[process_id]["result"] = { + "working_streams": [ + ValidatedStream(channel_id=channel_id, stream_url=stream_url) + ] + } + else: + validation_processes[process_id]["error"] = ( + f"No working streams found for channel {channel_id}" + ) + else: + # TODO: Implement validation for all channels + validation_processes[process_id]["error"] = ( + "Validation of all channels not yet implemented" + ) + + validation_processes[process_id]["status"] = ProcessStatus.COMPLETED + + except Exception as e: + logger.error(f"Error validating streams: {str(e)}") + validation_processes[process_id]["status"] = ProcessStatus.FAILED + validation_processes[process_id]["error"] = str(e) + + +@router.post( + "/validate-streams", + summary="Start stream validation process", + response_model=ValidationProcessResponse, + status_code=status.HTTP_202_ACCEPTED, + responses={202: {"description": "Validation process started successfully"}}, +) +async def start_stream_validation( + request: StreamValidationRequest, + background_tasks: BackgroundTasks, + user: CognitoUser = Depends(get_current_user), + db: Session = Depends(get_db_session), +): """ - Protected endpoint that requires authentication for all users. - If the user is authenticated, returns success message. + Start asynchronous validation of streams. + + - Returns immediately with a process ID + - Use GET /validate-streams/{process_id} to check status """ - return {"message": f"Hello {user.username}, you have access to support resources!"} + process_id = str(uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.PENDING, + "channel_id": request.channel_id, + } + + background_tasks.add_task(run_stream_validation, process_id, request.channel_id, db) + + return { + "process_id": process_id, + "status": ProcessStatus.PENDING, + "message": "Validation process started", + } + + +@router.get( + "/validate-streams/{process_id}", + summary="Check validation process status", + response_model=ValidationResultResponse, + responses={ + 200: {"description": "Process status and results"}, + 404: {"description": "Process not found"}, + }, +) +async def get_validation_status( + process_id: str, user: CognitoUser = Depends(get_current_user) +): + """ + Check status of a stream validation process. + + Returns current status and results if completed. + """ + if process_id not in validation_processes: + raise HTTPException(status_code=404, detail="Process not found") + + process = validation_processes[process_id] + response = {"process_id": process_id, "status": process["status"]} + + if process["status"] == ProcessStatus.COMPLETED: + if "error" in process: + response["error"] = process["error"] + else: + response["working_streams"] = process["result"]["working_streams"] + elif process["status"] == ProcessStatus.FAILED: + response["error"] = process["error"] + + return response diff --git a/app/routers/scheduler.py b/app/routers/scheduler.py new file mode 100644 index 0000000..1ec8485 --- /dev/null +++ b/app/routers/scheduler.py @@ -0,0 +1,57 @@ +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse +from sqlalchemy.orm import Session + +from app.auth.dependencies import get_current_user, require_roles +from app.iptv.scheduler import StreamScheduler +from app.models.auth import CognitoUser +from app.utils.database import get_db + +router = APIRouter( + prefix="/scheduler", + tags=["scheduler"], + responses={404: {"description": "Not found"}}, +) + + +async def get_scheduler(request: Request) -> StreamScheduler: + """Get the scheduler instance from the app state.""" + if not hasattr(request.app.state.scheduler, "scheduler"): + raise HTTPException(status_code=500, detail="Scheduler not initialized") + return request.app.state.scheduler + + +@router.get("/health") +@require_roles("admin") +def scheduler_health( + scheduler: StreamScheduler = Depends(get_scheduler), + user: CognitoUser = Depends(get_current_user), + db: Session = Depends(get_db), +): + """Check scheduler health status (admin only).""" + try: + job = scheduler.scheduler.get_job("daily_stream_validation") + next_run = str(job.next_run_time) if job and job.next_run_time else None + + return { + "status": "running" if scheduler.scheduler.running else "stopped", + "next_run": next_run, + } + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to check scheduler health: {str(e)}" + ) + + +@router.post("/trigger") +@require_roles("admin") +def trigger_validation( + scheduler: StreamScheduler = Depends(get_scheduler), + user: CognitoUser = Depends(get_current_user), + db: Session = Depends(get_db), +): + """Manually trigger stream validation (admin only).""" + scheduler.trigger_manual_validation() + return JSONResponse( + status_code=202, content={"message": "Stream validation triggered"} + ) diff --git a/app/utils/database.py b/app/utils/database.py index c5dea92..67975b0 100644 --- a/app/utils/database.py +++ b/app/utils/database.py @@ -1,6 +1,7 @@ import os import boto3 +from requests import Session from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -53,3 +54,8 @@ def get_db(): yield db finally: db.close() + + +def get_db_session() -> Session: + """Get a direct database session (non-generator version)""" + return SessionLocal() diff --git a/requirements.txt b/requirements.txt index 390d76c..c8b2922 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ pytest-mock==3.12.0 pytest-cov==4.1.0 pytest-env==1.1.1 httpx==0.27.0 -pre-commit \ No newline at end of file +pre-commit +apscheduler==3.10.4 \ No newline at end of file diff --git a/tests/auth/test_dependencies.py b/tests/auth/test_dependencies.py index 17f033e..2a214c8 100644 --- a/tests/auth/test_dependencies.py +++ b/tests/auth/test_dependencies.py @@ -26,7 +26,7 @@ def mock_get_user_from_token(token: str) -> CognitoUser: # Mock endpoint for testing the require_roles decorator @require_roles("admin") -async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)): +def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)): return {"message": "Success", "user": user.username} @@ -96,7 +96,7 @@ async def test_require_roles_no_roles(): async def test_require_roles_multiple_roles(): # Test requiring multiple roles @require_roles("admin", "super_user") - async def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)): + def mock_multi_role_endpoint(user: CognitoUser = Depends(get_current_user)): return {"message": "Success"} # User with all required roles diff --git a/tests/routers/mocks.py b/tests/routers/mocks.py new file mode 100644 index 0000000..bc239dc --- /dev/null +++ b/tests/routers/mocks.py @@ -0,0 +1,43 @@ +from unittest.mock import Mock + +from fastapi import Request + +from app.iptv.scheduler import StreamScheduler + + +class MockScheduler: + """Base mock APScheduler instance""" + + running = True + start = Mock() + shutdown = Mock() + add_job = Mock() + remove_job = Mock() + get_job = Mock(return_value=None) + + def __init__(self, running=True): + self.running = running + + +def create_trigger_mock(triggered_ref: dict) -> callable: + """Create a mock trigger function that updates a reference when called""" + + def trigger_mock(): + triggered_ref["value"] = True + + return trigger_mock + + +async def mock_get_scheduler( + request: Request, scheduler_class=MockScheduler, running=True, **kwargs +) -> StreamScheduler: + """Mock dependency for get_scheduler with customization options""" + scheduler = StreamScheduler() + mock_scheduler = scheduler_class(running=running) + + # Apply any additional attributes/methods + for key, value in kwargs.items(): + setattr(mock_scheduler, key, value) + + scheduler.scheduler = mock_scheduler + return scheduler diff --git a/tests/routers/test_playlist.py b/tests/routers/test_playlist.py index ac05124..33877fa 100644 --- a/tests/routers/test_playlist.py +++ b/tests/routers/test_playlist.py @@ -1,43 +1,261 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + import pytest from fastapi import status +from sqlalchemy.orm import Session +from app.auth.dependencies import get_current_user + +# Import the router we're testing +from app.routers.playlist import ( + ProcessStatus, + ValidationProcessResponse, + ValidationResultResponse, + router, + validation_processes, +) +from app.utils.database import get_db + +# Import mocks and fixtures from tests.utils.auth_test_fixtures import ( admin_user_client, db_session, non_admin_user_client, ) +from tests.utils.db_mocks import MockChannelDB + +# --- Test Fixtures --- -def test_protected_route_admin_access(db_session, admin_user_client): - """Test that admin users can access the protected route""" - response = admin_user_client.get("/playlist/protected") +@pytest.fixture +def mock_stream_manager(): + with patch("app.routers.playlist.StreamManager") as mock: + yield mock + + +# --- Test Cases For Stream Validation --- + + +def test_start_stream_validation_success( + db_session: Session, admin_user_client, mock_stream_manager +): + """Test starting a stream validation process""" + mock_instance = mock_stream_manager.return_value + mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url" + + response = admin_user_client.post( + "/playlist/validate-streams", json={"channel_id": "test-channel"} + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + data = response.json() + assert "process_id" in data + assert data["status"] == ProcessStatus.PENDING + assert data["message"] == "Validation process started" + + # Verify process was added to tracking + process_id = data["process_id"] + assert process_id in validation_processes + # In test environment, background tasks run synchronously so status may be COMPLETED + assert validation_processes[process_id]["status"] in [ + ProcessStatus.PENDING, + ProcessStatus.COMPLETED, + ] + assert validation_processes[process_id]["channel_id"] == "test-channel" + + +def test_get_validation_status_pending(db_session: Session, admin_user_client): + """Test checking status of pending validation""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.PENDING, + "channel_id": "test-channel", + } + + response = admin_user_client.get(f"/playlist/validate-streams/{process_id}") + assert response.status_code == status.HTTP_200_OK data = response.json() - assert "access to support resources" in data["message"] - assert "testadmin" in data["message"] + assert data["process_id"] == process_id + assert data["status"] == ProcessStatus.PENDING + assert data["working_streams"] is None + assert data["error"] is None -def test_protected_route_non_admin_access(db_session, non_admin_user_client): - """Test that non-admin users can access the protected route - (just requires authentication)""" - response = non_admin_user_client.get("/playlist/protected") +def test_get_validation_status_completed(db_session: Session, admin_user_client): + """Test checking status of completed validation""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.COMPLETED, + "channel_id": "test-channel", + "result": { + "working_streams": [ + {"channel_id": "test-channel", "stream_url": "http://valid.stream.url"} + ] + }, + } + + response = admin_user_client.get(f"/playlist/validate-streams/{process_id}") + assert response.status_code == status.HTTP_200_OK data = response.json() - assert "access to support resources" in data["message"] - assert "testuser" in data["message"] + assert data["process_id"] == process_id + assert data["status"] == ProcessStatus.COMPLETED + assert len(data["working_streams"]) == 1 + assert data["working_streams"][0]["channel_id"] == "test-channel" + assert data["working_streams"][0]["stream_url"] == "http://valid.stream.url" + assert data["error"] is None -def test_protected_route_no_auth(): - """Test that unauthenticated users cannot access the protected route""" - from fastapi import FastAPI - from fastapi.testclient import TestClient +def test_get_validation_status_completed_with_error( + db_session: Session, admin_user_client +): + """Test checking status of completed validation with error""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.COMPLETED, + "channel_id": "test-channel", + "error": "No working streams found for channel test-channel", + } - from app.routers.playlist import router as playlist_router + response = admin_user_client.get(f"/playlist/validate-streams/{process_id}") - app = FastAPI() - app.include_router(playlist_router) - client = TestClient(app) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["process_id"] == process_id + assert data["status"] == ProcessStatus.COMPLETED + assert data["working_streams"] is None + assert data["error"] == "No working streams found for channel test-channel" - response = client.get("/playlist/protected") - assert response.status_code == status.HTTP_401_UNAUTHORIZED - assert "Not authenticated" in response.json()["detail"] + +def test_get_validation_status_failed(db_session: Session, admin_user_client): + """Test checking status of failed validation""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.FAILED, + "channel_id": "test-channel", + "error": "Validation error occurred", + } + + response = admin_user_client.get(f"/playlist/validate-streams/{process_id}") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["process_id"] == process_id + assert data["status"] == ProcessStatus.FAILED + assert data["working_streams"] is None + assert data["error"] == "Validation error occurred" + + +def test_get_validation_status_not_found(db_session: Session, admin_user_client): + """Test checking status of non-existent process""" + random_uuid = str(uuid.uuid4()) + response = admin_user_client.get(f"/playlist/validate-streams/{random_uuid}") + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Process not found" in response.json()["detail"] + + +def test_run_stream_validation_success(mock_stream_manager, db_session): + """Test the background validation task success case""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.PENDING, + "channel_id": "test-channel", + } + + mock_instance = mock_stream_manager.return_value + mock_instance.validate_and_select_stream.return_value = "http://valid.stream.url" + + from app.routers.playlist import run_stream_validation + + run_stream_validation(process_id, "test-channel", db_session) + + assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED + assert len(validation_processes[process_id]["result"]["working_streams"]) == 1 + assert ( + validation_processes[process_id]["result"]["working_streams"][0].channel_id + == "test-channel" + ) + assert ( + validation_processes[process_id]["result"]["working_streams"][0].stream_url + == "http://valid.stream.url" + ) + + +def test_run_stream_validation_failure(mock_stream_manager, db_session): + """Test the background validation task failure case""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.PENDING, + "channel_id": "test-channel", + } + + mock_instance = mock_stream_manager.return_value + mock_instance.validate_and_select_stream.return_value = None + + from app.routers.playlist import run_stream_validation + + run_stream_validation(process_id, "test-channel", db_session) + + assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED + assert "error" in validation_processes[process_id] + assert "No working streams found" in validation_processes[process_id]["error"] + + +def test_run_stream_validation_exception(mock_stream_manager, db_session): + """Test the background validation task exception case""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = { + "status": ProcessStatus.PENDING, + "channel_id": "test-channel", + } + + mock_instance = mock_stream_manager.return_value + mock_instance.validate_and_select_stream.side_effect = Exception("Test error") + + from app.routers.playlist import run_stream_validation + + run_stream_validation(process_id, "test-channel", db_session) + + assert validation_processes[process_id]["status"] == ProcessStatus.FAILED + assert "error" in validation_processes[process_id] + assert "Test error" in validation_processes[process_id]["error"] + + +def test_start_stream_validation_no_channel_id( + db_session: Session, admin_user_client, mock_stream_manager +): + """Test starting validation without channel_id""" + response = admin_user_client.post("/playlist/validate-streams", json={}) + + assert response.status_code == status.HTTP_202_ACCEPTED + data = response.json() + assert "process_id" in data + assert data["status"] == ProcessStatus.PENDING + + # Verify process was added to tracking + process_id = data["process_id"] + assert process_id in validation_processes + assert validation_processes[process_id]["status"] in [ + ProcessStatus.PENDING, + ProcessStatus.COMPLETED, + ] + assert validation_processes[process_id]["channel_id"] is None + assert "not yet implemented" in validation_processes[process_id].get("error", "") + + +def test_run_stream_validation_no_channel_id(mock_stream_manager, db_session): + """Test background validation without channel_id""" + process_id = str(uuid.uuid4()) + validation_processes[process_id] = {"status": ProcessStatus.PENDING} + + from app.routers.playlist import run_stream_validation + + run_stream_validation(process_id, None, db_session) + + assert validation_processes[process_id]["status"] == ProcessStatus.COMPLETED + assert "error" in validation_processes[process_id] + assert "not yet implemented" in validation_processes[process_id]["error"] diff --git a/tests/routers/test_scheduler.py b/tests/routers/test_scheduler.py new file mode 100644 index 0000000..dd76974 --- /dev/null +++ b/tests/routers/test_scheduler.py @@ -0,0 +1,287 @@ +from datetime import datetime, timezone +from unittest.mock import Mock + +from fastapi import HTTPException, Request, status + +from app.iptv.scheduler import StreamScheduler +from app.routers.scheduler import get_scheduler +from app.routers.scheduler import router as scheduler_router +from app.utils.database import get_db +from tests.routers.mocks import MockScheduler, create_trigger_mock, mock_get_scheduler +from tests.utils.auth_test_fixtures import ( + admin_user_client, + db_session, + non_admin_user_client, +) +from tests.utils.db_mocks import mock_get_db + +# Scheduler Health Check Tests + + +def test_scheduler_health_success(admin_user_client, monkeypatch): + """ + Test case for successful scheduler health check when accessed by an admin user. + It mocks the scheduler to be running and have a next scheduled job. + """ + + # Define the expected next run time for the scheduler job. + next_run = datetime.now(timezone.utc) + + # Create a mock job object that simulates an APScheduler job. + mock_job = Mock() + mock_job.next_run_time = next_run + + # Mock the `get_job` method to return our mock_job for a specific ID. + def mock_get_job(job_id): + if job_id == "daily_stream_validation": + return mock_job + return None + + # Create a custom mock for `get_scheduler` dependency. + async def custom_mock_get_scheduler(request: Request) -> StreamScheduler: + return await mock_get_scheduler( + request, + running=True, + get_job=Mock(side_effect=mock_get_job), # Use the custom mock_get_job + ) + + # Include the scheduler router in the test application. + admin_user_client.app.include_router(scheduler_router) + + # Override dependencies for the test. + admin_user_client.app.dependency_overrides[get_scheduler] = ( + custom_mock_get_scheduler + ) + admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to the scheduler health endpoint. + response = admin_user_client.get("/scheduler/health") + + # Assert the response status code and content. + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "running" + assert data["next_run"] == str(next_run) + + +def test_scheduler_health_stopped(admin_user_client, monkeypatch): + """ + Test case for scheduler health check when the scheduler is in a stopped state. + Ensures the API returns the correct status and no next run time. + """ + + # Create a custom mock for `get_scheduler` dependency, + # simulating a stopped scheduler. + async def custom_mock_get_scheduler(request: Request) -> StreamScheduler: + return await mock_get_scheduler( + request, + running=False, + ) + + # Include the scheduler router in the test application. + admin_user_client.app.include_router(scheduler_router) + + # Override dependencies for the test. + admin_user_client.app.dependency_overrides[get_scheduler] = ( + custom_mock_get_scheduler + ) + admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to the scheduler health endpoint. + response = admin_user_client.get("/scheduler/health") + + # Assert the response status code and content. + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "stopped" + assert data["next_run"] is None + + +def test_scheduler_health_forbidden_for_non_admin(non_admin_user_client, monkeypatch): + """ + Test case to ensure that non-admin users are forbidden from accessing + the scheduler health endpoint. + """ + + # Create a custom mock for `get_scheduler` dependency. + async def custom_mock_get_scheduler(request: Request) -> StreamScheduler: + return await mock_get_scheduler( + request, + running=False, + ) + + # Include the scheduler router in the test application. + non_admin_user_client.app.include_router(scheduler_router) + + # Override dependencies for the test. + non_admin_user_client.app.dependency_overrides[get_scheduler] = ( + custom_mock_get_scheduler + ) + non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to the scheduler health endpoint. + response = non_admin_user_client.get("/scheduler/health") + + # Assert the response status code and error detail. + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "required roles" in response.json()["detail"] + + +def test_scheduler_health_check_exception(admin_user_client, monkeypatch): + """ + Test case for handling exceptions during the scheduler health check. + Ensures the API returns a 500 Internal Server Error when an exception occurs. + """ + + # Create a custom mock for `get_scheduler` dependency that raises an exception. + async def custom_mock_get_scheduler(request: Request) -> StreamScheduler: + return await mock_get_scheduler( + request, running=True, get_job=Mock(side_effect=Exception("Test exception")) + ) + + # Include the scheduler router in the test application. + admin_user_client.app.include_router(scheduler_router) + + # Override dependencies for the test. + admin_user_client.app.dependency_overrides[get_scheduler] = ( + custom_mock_get_scheduler + ) + admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to the scheduler health endpoint. + response = admin_user_client.get("/scheduler/health") + + # Assert the response status code and error detail. + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Failed to check scheduler health" in response.json()["detail"] + + +# Scheduler Trigger Tests + + +def test_trigger_validation_success(admin_user_client, monkeypatch): + """ + Test case for successful manual triggering + of stream validation by an admin user. + It verifies that the trigger method is called and + the API returns a 202 Accepted status. + """ + # Use a mutable reference to check if the trigger method was called. + triggered_ref = {"value": False} + + # Initialize a custom mock scheduler. + custom_scheduler = MockScheduler(running=True) + custom_scheduler.get_job = Mock(return_value=None) + + # Create a custom mock for `get_scheduler` dependency, + # overriding `trigger_manual_validation`. + async def custom_mock_get_scheduler(request: Request) -> StreamScheduler: + scheduler = await mock_get_scheduler( + request, + running=True, + ) + + # Replace the actual trigger method with our mock to track calls. + scheduler.trigger_manual_validation = create_trigger_mock( + triggered_ref=triggered_ref + ) + + return scheduler + + # Include the scheduler router in the test application. + admin_user_client.app.include_router(scheduler_router) + + # Override dependencies for the test. + admin_user_client.app.dependency_overrides[get_scheduler] = ( + custom_mock_get_scheduler + ) + admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to trigger stream validation. + response = admin_user_client.post("/scheduler/trigger") + + # Assert the response status code, message, and that the trigger was called. + assert response.status_code == status.HTTP_202_ACCEPTED + assert response.json()["message"] == "Stream validation triggered" + assert triggered_ref["value"] is True + + +def test_trigger_validation_forbidden_for_non_admin(non_admin_user_client, monkeypatch): + """ + Test case to ensure that non-admin users are + forbidden from manually triggering stream validation. + """ + + # Create a custom mock for `get_scheduler` dependency. + async def custom_mock_get_scheduler(request: Request) -> StreamScheduler: + return await mock_get_scheduler( + request, + running=True, + ) + + # Include the scheduler router in the test application. + non_admin_user_client.app.include_router(scheduler_router) + + # Override dependencies for the test. + non_admin_user_client.app.dependency_overrides[get_scheduler] = ( + custom_mock_get_scheduler + ) + non_admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to trigger stream validation. + response = non_admin_user_client.post("/scheduler/trigger") + + # Assert the response status code and error detail. + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "required roles" in response.json()["detail"] + + +def test_scheduler_initialized_in_app_state(admin_user_client): + """ + Test case for when the scheduler is initialized in the app state but its internal + scheduler attribute is not set, which should still allow health check. + """ + scheduler = StreamScheduler() + + # Set the scheduler instance in the test client's app state. + admin_user_client.app.state.scheduler = scheduler + + # Include the scheduler router in the test application. + admin_user_client.app.include_router(scheduler_router) + + # Override only get_db, allowing the real get_scheduler to be tested. + admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to the scheduler health endpoint. + response = admin_user_client.get("/scheduler/health") + + # Assert the response status code. + assert response.status_code == status.HTTP_200_OK + + +def test_scheduler_not_initialized_in_app_state(admin_user_client): + """ + Test case for when the scheduler is not properly initialized in the app state. + This simulates a scenario where the internal scheduler attribute is missing, + leading to a 500 Internal Server Error on health check. + """ + scheduler = StreamScheduler() + del ( + scheduler.scheduler + ) # Simulate uninitialized scheduler by deleting the attribute + + # Set the scheduler instance in the test client's app state. + admin_user_client.app.state.scheduler = scheduler + + # Include the scheduler router in the test application. + admin_user_client.app.include_router(scheduler_router) + + # Override only get_db, allowing the real get_scheduler to be tested. + admin_user_client.app.dependency_overrides[get_db] = mock_get_db + + # Make the request to the scheduler health endpoint. + response = admin_user_client.get("/scheduler/health") + + # Assert the response status code and error detail. + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Scheduler not initialized" in response.json()["detail"]