Files
iptv-manager-service/tests/utils/test_database.py
Stefano fb5215b92a
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m8s
Cleanup test database unit tests
2025-05-27 17:57:28 -05:00

124 lines
4.3 KiB
Python

import os
import pytest
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from app.utils.database import get_db_credentials, get_db
# Separate declarative base for the test-only models below. declarative_base()
# creates its own MetaData, so these mock tables are registered independently
# of any application model metadata.
MockBase = declarative_base()
class SQLiteUUID(TypeDecorator):
    """Column type that lets SQLite hold ``uuid.UUID`` values as TEXT.

    UUIDs are stringified when bound as parameters and parsed back into
    ``uuid.UUID`` when rows are loaded; ``None`` passes through untouched in
    both directions.
    """

    impl = TEXT
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Return the TEXT form of *value* (or ``None``) for the database."""
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        """Rebuild a ``uuid.UUID`` from the stored TEXT (or return ``None``)."""
        return None if value is None else uuid.UUID(value)
# Test-only ORM models. They are named Mock* rather than Test* so pytest's
# default class collection pattern (``Test*``) does not pick them up.
class MockPriority(MockBase):
    """Test double for the ``priorities`` lookup table."""

    __tablename__ = "priorities"

    id = Column(Integer, primary_key=True)
    description = Column(String, nullable=False)
class MockChannelDB(MockBase):
    """Test double for the ``channels`` table.

    ``(group_title, name)`` is unique (constraint ``uix_group_title_name``).
    ``created_at``/``updated_at`` default to the current UTC time, and
    ``updated_at`` is also refreshed via ``onupdate`` when the row is updated.
    """

    __tablename__ = "channels"
    __table_args__ = (
        UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
    )

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_title = Column(String, nullable=False)
    tvg_name = Column(String)
    tvg_logo = Column(String)
    # NOTE(review): DateTime is not declared timezone=True, so these aware UTC
    # defaults may be stored naive; confirm this matches the real model.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
class MockChannelURL(MockBase):
    """Test double for the ``channels_urls`` table.

    Each row points at a channel (FK declared with ``ON DELETE CASCADE``;
    SQLite only enforces it when the ``foreign_keys`` pragma is on) and at a
    priority. ``in_use`` defaults to ``False``.
    """

    __tablename__ = "channels_urls"

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    channel_id = Column(
        SQLiteUUID(),
        ForeignKey('channels.id', ondelete='CASCADE'),
        nullable=False,
    )
    url = Column(String, nullable=False)
    in_use = Column(Boolean, default=False, nullable=False)
    priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
# In-memory SQLite engine for the tests. StaticPool reuses a single connection,
# so every session sees the same in-memory database; check_same_thread=False
# allows that connection to be used from threads other than its creator.
# NOTE(review): MockBase.metadata.create_all(engine_mock) is not called in the
# visible code, so the mock tables are never created on this engine.
engine_mock = create_engine(
    "sqlite:///:memory:",
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the in-memory engine (no autoflush).
session_mock = sessionmaker(bind=engine_mock, autoflush=False, autocommit=False)
# Generator-style stand-in for a DB dependency, backed by session_mock.
def mock_get_db():
    """Yield one ``session_mock()`` session and close it when the generator ends.

    NOTE(review): no test visible in this file uses this helper.
    """
    session = session_mock()
    try:
        yield session
    finally:
        session.close()
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Set fake auth/DB/AWS environment variables for every test.

    Applied automatically (``autouse=True``); ``monkeypatch`` undoes the
    changes when each test finishes.
    """
    fake_environment = {
        "MOCK_AUTH": "true",
        "DB_USER": "testuser",
        "DB_PASSWORD": "testpass",
        "DB_HOST": "localhost",
        "DB_NAME": "testdb",
        "AWS_REGION": "us-east-1",
    }
    for var_name, var_value in fake_environment.items():
        monkeypatch.setenv(var_name, var_value)
@pytest.fixture
def mock_ssm():
    """Patch ``boto3.client`` and yield the fake client it returns.

    The fake's ``get_parameter`` always answers with the value
    ``'mocked_value'``.
    """
    fake_client = MagicMock()
    fake_client.get_parameter.return_value = {
        'Parameter': {'Value': 'mocked_value'}
    }
    with patch('boto3.client') as client_factory:
        client_factory.return_value = fake_client
        yield fake_client
def test_get_db_credentials_env(mock_env):
    """With mock_env's MOCK_AUTH and DB_* variables set, the URL is built from them."""
    expected_url = "postgresql://testuser:testpass@localhost/testdb"
    assert get_db_credentials() == expected_url
def test_get_db_credentials_ssm(mock_ssm, monkeypatch):
    """Without MOCK_AUTH, credentials are fetched through the (mocked) SSM client.

    MOCK_AUTH is removed via ``monkeypatch.delenv`` rather than by mutating
    ``os.environ`` directly, so pytest tracks and undoes the removal together
    with the rest of the environment setup.
    """
    monkeypatch.delenv("MOCK_AUTH", raising=False)
    conn_str = get_db_credentials()
    # The mocked get_parameter returns 'mocked_value' for every parameter.
    assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
    mock_ssm.get_parameter.assert_called()
def test_session_creation():
    """``session_mock()`` produces a SQLAlchemy ``Session``.

    The session is closed in ``finally`` so it is released even when the
    assertion fails.
    """
    session = session_mock()
    try:
        assert isinstance(session, Session)
    finally:
        session.close()
def test_get_db_generator():
    """``get_db`` yields a single ``Session`` and then finishes."""
    db_gen = get_db()
    db = next(db_gen)
    assert isinstance(db, Session)
    # Resuming must end the generator (running get_db's own cleanup, presumably).
    # pytest.raises makes this an assertion: a bare try/except StopIteration
    # would also pass silently if the generator yielded a second value.
    with pytest.raises(StopIteration):
        next(db_gen)