"""Tests for ``app.utils.database`` run against an in-memory SQLite engine."""
import os
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest
from sqlalchemy import (
    TEXT,
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Integer,
    String,
    TypeDecorator,
    UniqueConstraint,
    create_engine,
)
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.pool import StaticPool

from app.utils.database import get_db, get_db_credentials

# Create a mock-specific Base class for testing, so the mock tables below are
# registered on their own metadata rather than on the application's.
MockBase = declarative_base()


class SQLiteUUID(TypeDecorator):
    """Enables UUID support for SQLite by storing UUIDs as TEXT."""

    impl = TEXT
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert the Python value to its ``str()`` form; pass ``None`` through."""
        if value is None:
            return value
        return str(value)

    def process_result_value(self, value, dialect):
        """Parse the stored text back into a ``uuid.UUID``; pass ``None`` through."""
        if value is None:
            return value
        return uuid.UUID(value)


# Model classes for testing - prefixed with Mock to avoid pytest collection.
class MockPriority(MockBase):
    """Test stand-in for the ``priorities`` table."""

    __tablename__ = "priorities"

    id = Column(Integer, primary_key=True)
    description = Column(String, nullable=False)


class MockChannelDB(MockBase):
    """Test stand-in for the ``channels`` table."""

    __tablename__ = "channels"

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_title = Column(String, nullable=False)
    tvg_name = Column(String)
    tvg_logo = Column(String)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # A channel name must be unique within its group.
    __table_args__ = (
        UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
    )


class MockChannelURL(MockBase):
    """Test stand-in for the ``channels_urls`` table."""

    __tablename__ = "channels_urls"

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    channel_id = Column(
        SQLiteUUID(),
        ForeignKey('channels.id', ondelete='CASCADE'),
        nullable=False,
    )
    url = Column(String, nullable=False)
    in_use = Column(Boolean, default=False, nullable=False)
    priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )


# Create test engine. StaticPool + check_same_thread=False keep a single
# connection, so every session sees the same in-memory database.
TEST_ENGINE = create_engine(
    "sqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)

# Create test session factory bound to the in-memory engine.
TEST_SESSION = sessionmaker(autocommit=False, autoflush=False, bind=TEST_ENGINE)


def mock_get_db():
    """Yield a session from TEST_SESSION and always close it afterwards."""
    db = TEST_SESSION()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture(scope="session", autouse=True)
def create_test_tables():
    """Create the mock tables once per test session and drop them at the end."""
    MockBase.metadata.create_all(bind=TEST_ENGINE)
    yield
    MockBase.metadata.drop_all(bind=TEST_ENGINE)


@pytest.fixture(autouse=True)
def mock_db_engine(monkeypatch):
    """Point ``app.utils.database``'s engine/session/get_db at the test engine.

    NOTE(review): these patches replace attributes on the module object. Names
    this test file imported directly (``get_db``, ``get_db_credentials``) are
    not affected. The tests below therefore exercise the real functions, which
    see the patched ``SessionLocal`` only if they look it up as a module
    global — confirm against ``app/utils/database.py``.
    """
    # First mock get_db_credentials to prevent real connection attempts.
    def mock_credentials():
        return "sqlite:///:memory:"

    monkeypatch.setattr('app.utils.database.get_db_credentials', mock_credentials)

    # Then patch the actual database objects/functions.
    monkeypatch.setattr('app.utils.database.engine', TEST_ENGINE)
    monkeypatch.setattr('app.utils.database.SessionLocal', TEST_SESSION)
    monkeypatch.setattr('app.utils.database.get_db', mock_get_db)


@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Set deterministic DB/auth/AWS environment variables for every test.

    ``monkeypatch.setenv`` overwrites any existing value and restores the
    original at teardown, so no separate clearing step is needed.
    """
    test_env = {
        "MOCK_AUTH": "true",
        "DB_USER": "testuser",
        "DB_PASSWORD": "testpass",
        "DB_HOST": "localhost",
        "DB_NAME": "testdb",
        "AWS_REGION": "us-east-1",  # Mock AWS region
    }
    for key, value in test_env.items():
        monkeypatch.setenv(key, value)


@pytest.fixture
def mock_ssm():
    """Patch ``boto3.client`` and yield the fake SSM client it returns.

    Every ``get_parameter`` call on the fake returns ``'mocked_value'``.
    """
    with patch('boto3.client') as mock_client:
        ssm_client = MagicMock()
        mock_client.return_value = ssm_client
        ssm_client.get_parameter.return_value = {
            'Parameter': {'Value': 'mocked_value'}
        }
        yield ssm_client


def test_get_db_credentials_env(mock_env):
    """Test getting DB credentials from environment variables."""
    conn_str = get_db_credentials()
    assert conn_str == "postgresql://testuser:testpass@localhost/testdb"


def test_get_db_credentials_ssm(mock_ssm, monkeypatch):
    """Test getting DB credentials from SSM when MOCK_AUTH is unset."""
    # Go through monkeypatch (not os.environ directly) so the change is
    # tracked and undone together with mock_env's setenv calls.
    monkeypatch.delenv("MOCK_AUTH", raising=False)
    assert "MOCK_AUTH" not in os.environ

    conn_str = get_db_credentials()

    assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
    mock_ssm.get_parameter.assert_called()


def test_session_creation():
    """Test database session creation."""
    session = TEST_SESSION()
    try:
        assert isinstance(session, Session)
    finally:
        # Close even if the assertion fails so the shared connection is released.
        session.close()


def test_get_db_generator():
    """Test that the get_db dependency yields one Session and then finishes."""
    db_gen = get_db()
    db = next(db_gen)
    assert isinstance(db, Session)
    # The generator must be exhausted after its single yield. Unlike a bare
    # try/except, pytest.raises fails the test if a second value is produced.
    with pytest.raises(StopIteration):
        next(db_gen)