Files
iptv-manager-service/tests/utils/test_database.py
Stefano cebbb9c1a8
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m4s
Added pytest configuration and first 4 unit tests
2025-05-27 17:37:05 -05:00

150 lines
5.4 KiB
Python

import os
import pytest
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from app.utils.database import get_db_credentials, get_db
# Declarative base used only by the mock models in this test module.
# declarative_base() creates a fresh base with its own MetaData, so
# MockBase.metadata.create_all() below creates only these mock tables.
MockBase = declarative_base()
class SQLiteUUID(TypeDecorator):
    """Let SQLite columns hold ``uuid.UUID`` values by storing them as TEXT.

    Outgoing values are converted with ``str()``; incoming text is parsed back
    with ``uuid.UUID()``. ``None`` (SQL NULL) passes through unchanged in both
    directions.
    """

    impl = TEXT
    # The conversion has no per-instance state, so SQLAlchemy may cache it.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Python -> database.
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        # Database -> Python.
        return None if value is None else uuid.UUID(value)
# Test-only ORM models. Their names start with "Mock" instead of "Test" so that
# pytest does not try to collect them as test classes.
class MockPriority(MockBase):
    """Row of the ``priorities`` table (target of ``channels_urls.priority_id``)."""

    __tablename__ = "priorities"

    id = Column(Integer, primary_key=True)
    description = Column(String, nullable=False)
class MockChannelDB(MockBase):
    """SQLite stand-in for the ``channels`` table.

    Rows are unique per ``(group_title, name)`` pair, enforced by the
    ``uix_group_title_name`` constraint.
    """

    __tablename__ = "channels"

    # Columns are declared in the same order as before; SQLAlchemy uses
    # declaration order for the CREATE TABLE column order.
    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_title = Column(String, nullable=False)
    tvg_name = Column(String)
    tvg_logo = Column(String)

    # Timestamps default to the current UTC time on the Python side;
    # updated_at is also refreshed by SQLAlchemy's onupdate hook.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    __table_args__ = (
        UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
    )
class MockChannelURL(MockBase):
    """SQLite stand-in for ``channels_urls``: a URL row tied to a channel and a priority."""

    __tablename__ = "channels_urls"

    id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
    # NOTE(review): SQLite only enforces ON DELETE CASCADE when the connection
    # has "PRAGMA foreign_keys=ON"; nothing visible here enables it.
    channel_id = Column(
        SQLiteUUID(),
        ForeignKey('channels.id', ondelete='CASCADE'),
        nullable=False,
    )
    url = Column(String, nullable=False)
    in_use = Column(Boolean, nullable=False, default=False)
    priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)

    # Python-side UTC timestamps; updated_at is refreshed via onupdate.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
# In-memory SQLite engine shared by every test. StaticPool keeps a single
# connection, so all sessions see the same in-memory database;
# check_same_thread=False allows that connection to be used from other threads.
TEST_ENGINE = create_engine(
    "sqlite:///:memory:",
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine (autocommit and autoflush disabled).
TEST_SESSION = sessionmaker(bind=TEST_ENGINE, autoflush=False, autocommit=False)
# Replacement for app.utils.database.get_db, installed by the mock_db_engine
# fixture below.
def mock_get_db():
    """Yield one TEST_SESSION session and close it when the generator ends.

    The Session context manager calls ``close()`` on exit, which covers normal
    exhaustion, ``generator.close()`` and exceptions thrown into the generator.
    """
    with TEST_SESSION() as db:
        yield db
@pytest.fixture(scope="session", autouse=True)
def create_test_tables():
    """Create all MockBase tables on TEST_ENGINE, once per test session."""
    mock_metadata = MockBase.metadata
    mock_metadata.create_all(bind=TEST_ENGINE)
@pytest.fixture(autouse=True)
def mock_db_engine(monkeypatch):
    """Swap app.utils.database's DB objects for the in-memory test ones.

    Each attribute is replaced through ``monkeypatch`` and restored after the
    test. Only code that looks these names up on the module at call time sees
    the replacements. ``get_db_credentials`` and ``get_db`` imported at the
    top of this file are already bound to the originals and are NOT affected.
    """
    replacements = (
        ('app.utils.database.get_db_credentials', lambda: "sqlite:///:memory:"),
        ('app.utils.database.engine', TEST_ENGINE),
        ('app.utils.database.SessionLocal', TEST_SESSION),
        ('app.utils.database.get_db', mock_get_db),
    )
    for dotted_target, stand_in in replacements:
        monkeypatch.setattr(dotted_target, stand_in)
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Give every test a fixed set of auth/database environment variables.

    Each variable is first removed and then set to a known test value. All
    changes go through ``monkeypatch``, so the environment is restored after
    the test.
    """
    fixed_env = {
        "MOCK_AUTH": "true",
        "DB_USER": "testuser",
        "DB_PASSWORD": "testpass",
        "DB_HOST": "localhost",
        "DB_NAME": "testdb",
    }
    for var_name in fixed_env:
        monkeypatch.delenv(var_name, raising=False)
    for var_name, var_value in fixed_env.items():
        monkeypatch.setenv(var_name, var_value)
    # Mock AWS region (presumably so boto3 client creation needs no real config).
    monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture
def mock_ssm():
    """Patch ``boto3.client`` to return a fake SSM client and yield that fake.

    Every ``get_parameter`` call on the fake returns a parameter whose value
    is ``'mocked_value'``. The patch is removed when the test finishes.
    """
    fake_ssm = MagicMock()
    fake_ssm.get_parameter.return_value = {'Parameter': {'Value': 'mocked_value'}}
    # NOTE(review): effective only if the code under test calls boto3.client
    # via the boto3 module attribute; confirm in app.utils.database.
    with patch('boto3.client', return_value=fake_ssm):
        yield fake_ssm
def test_get_db_credentials_env(mock_env):
    """The connection string is built from the DB_* values set by mock_env."""
    expected_url = "postgresql://testuser:testpass@localhost/testdb"
    assert get_db_credentials() == expected_url
def test_get_db_credentials_ssm(mock_ssm, monkeypatch):
    """Test getting DB credentials from SSM when MOCK_AUTH is unset.

    MOCK_AUTH is removed through ``monkeypatch`` rather than by mutating
    ``os.environ`` directly. The removal is then recorded and undone after
    this test, instead of depending on the autouse ``mock_env`` fixture's
    cleanup to repair the environment.
    """
    monkeypatch.delenv("MOCK_AUTH", raising=False)
    conn_str = get_db_credentials()
    # The mock_ssm fixture makes every SSM parameter resolve to 'mocked_value'.
    assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
    mock_ssm.get_parameter.assert_called()
def test_session_creation():
    """TEST_SESSION produces SQLAlchemy ``Session`` objects.

    The session is opened as a context manager, so it is closed even if the
    assertion fails.
    """
    with TEST_SESSION() as session:
        assert isinstance(session, Session)
def test_get_db_generator():
    """get_db yields exactly one Session and then stops.

    ``get_db`` here is the function imported at the top of this module, so
    this exercises app.utils.database.get_db itself. The autouse patch of the
    module attribute does not rebind this name. The test fails if the
    generator yields a second time instead of finishing. Advancing past the
    single yield also runs whatever cleanup follows it (presumably closing the
    session).
    """
    db_gen = get_db()
    db = next(db_gen)
    assert isinstance(db, Session)
    with pytest.raises(StopIteration):
        next(db_gen)