import os
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest
from sqlalchemy import (
    TEXT,
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Integer,
    String,
    TypeDecorator,
    UniqueConstraint,
    create_engine,
)
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.pool import StaticPool

from app.utils.database import get_db, get_db_credentials

# A separate declarative base for the mock models, so their tables are tracked
# in their own MetaData rather than alongside the application's models.
MockBase = declarative_base()


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class SQLiteUUID(TypeDecorator):
    """Store ``uuid.UUID`` values as TEXT so UUID columns work on SQLite.

    Values are written with ``str(value)`` and read back with ``uuid.UUID``;
    ``None`` passes through unchanged in both directions.
    """

    impl = TEXT
    # Safe to cache: this type carries no per-instance configuration.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert a Python value to its TEXT form before it is sent to the DB."""
        if value is None:
            return value
        return str(value)

    def process_result_value(self, value, dialect):
        """Convert a TEXT value loaded from the DB back into a ``uuid.UUID``."""
        if value is None:
            return value
        return uuid.UUID(value)


# Test models, prefixed with ``Mock`` so they are not mistaken for test classes
# (pytest collects ``Test*`` classes by default).
# NOTE(review): presumably these mirror the application's real models — confirm.
class MockPriority(MockBase):
    """Mock ``priorities`` table: an integer id plus a required description."""

    __tablename__ = "priorities"

    id = Column(Integer, primary_key=True)
    description = Column(String, nullable=False)


class MockChannelDB(MockBase):
    """Mock ``channels`` table; (group_title, name) must be unique."""

    __tablename__ = "channels"

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_title = Column(String, nullable=False)
    tvg_name = Column(String)
    tvg_logo = Column(String)
    created_at = Column(DateTime, default=_utc_now)
    updated_at = Column(DateTime, default=_utc_now, onupdate=_utc_now)

    __table_args__ = (
        UniqueConstraint("group_title", "name", name="uix_group_title_name"),
    )


class MockChannelURL(MockBase):
    """Mock ``channels_urls`` table linking a URL to a channel and a priority."""

    __tablename__ = "channels_urls"

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    # Rows are declared to be removed when the parent channel is deleted.
    channel_id = Column(
        SQLiteUUID(),
        ForeignKey("channels.id", ondelete="CASCADE"),
        nullable=False,
    )
    url = Column(String, nullable=False)
    in_use = Column(Boolean, default=False, nullable=False)
    priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
    created_at = Column(DateTime, default=_utc_now)
    updated_at = Column(DateTime, default=_utc_now, onupdate=_utc_now)


# Shared in-memory SQLite engine. StaticPool hands out one single connection,
# so every session sees the same in-memory database; check_same_thread=False
# allows that connection to be used from threads other than its creator.
engine_mock = create_engine(
    "sqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)

# Session factory bound to the in-memory test engine.
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)


def mock_get_db():
    """Yield a session bound to the in-memory test engine, closing it afterwards.

    Has the same generator shape as ``get_db``. It is not referenced by the
    tests visible here (presumably used as a dependency override — confirm).
    """
    db = session_mock()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Set the DB/AWS environment variables for every test (undone afterwards)."""
    monkeypatch.setenv("MOCK_AUTH", "true")
    monkeypatch.setenv("DB_USER", "testuser")
    monkeypatch.setenv("DB_PASSWORD", "testpass")
    monkeypatch.setenv("DB_HOST", "localhost")
    monkeypatch.setenv("DB_NAME", "testdb")
    monkeypatch.setenv("AWS_REGION", "us-east-1")


@pytest.fixture
def mock_ssm():
    """Patch ``boto3.client`` and yield the fake SSM client it returns.

    Every ``get_parameter`` call on the fake returns ``'mocked_value'``.
    """
    with patch("boto3.client") as mock_client:
        ssm_client = MagicMock()
        mock_client.return_value = ssm_client
        ssm_client.get_parameter.return_value = {
            "Parameter": {"Value": "mocked_value"}
        }
        yield ssm_client


def test_get_db_credentials_env(mock_env):
    """With MOCK_AUTH set, credentials are built from environment variables."""
    conn_str = get_db_credentials()
    assert conn_str == "postgresql://testuser:testpass@localhost/testdb"


def test_get_db_credentials_ssm(mock_ssm, monkeypatch):
    """Without MOCK_AUTH, credentials come from the (mocked) SSM client."""
    # monkeypatch scopes the removal to this test instead of mutating
    # os.environ for the rest of the session.
    monkeypatch.delenv("MOCK_AUTH", raising=False)
    conn_str = get_db_credentials()
    assert (
        "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
        in conn_str
    )
    mock_ssm.get_parameter.assert_called()


def test_session_creation():
    """The test session factory produces SQLAlchemy ``Session`` objects."""
    # The context manager closes the session even if the assertion fails.
    with session_mock() as session:
        assert isinstance(session, Session)


def test_get_db_generator():
    """``get_db`` yields exactly one ``Session`` and is then exhausted."""
    db_gen = get_db()
    try:
        db = next(db_gen)
        assert isinstance(db, Session)
        # A second yield would be a bug; the old try/except let it pass silently.
        with pytest.raises(StopIteration):
            next(db_gen)
    finally:
        # Ensure the generator's cleanup runs even if an assertion above fails.
        db_gen.close()