Linted and formatted all files

This commit is contained in:
2025-05-28 21:52:39 -05:00
parent e46f13930d
commit 02913c7385
31 changed files with 1264 additions and 766 deletions

View File

@@ -1,17 +1,30 @@
import os
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import (
TEXT,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
TypeDecorator,
UniqueConstraint,
create_engine,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import StaticPool
# Create a mock-specific Base class for testing
MockBase = declarative_base()
class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite."""
impl = TEXT
cache_ok = True
@@ -25,12 +38,14 @@ class SQLiteUUID(TypeDecorator):
return value
return uuid.UUID(value)
# Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase):
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class MockChannelDB(MockBase):
__tablename__ = "channels"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
@@ -39,32 +54,45 @@ class MockChannelDB(MockBase):
group_title = Column(String, nullable=False)
tvg_name = Column(String)
__table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
class MockChannelURL(MockBase):
__tablename__ = "channels_urls"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
channel_id = Column(
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(
DateTime,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Create test engine
engine_mock = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool
poolclass=StaticPool,
)
# Create test session
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
# Mock the actual database functions
def mock_get_db():
db = session_mock()
@@ -73,6 +101,7 @@ def mock_get_db():
finally:
db.close()
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
"""Fixture for mocking environment variables"""
@@ -82,14 +111,13 @@ def mock_env(monkeypatch):
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_NAME", "testdb")
monkeypatch.setenv("AWS_REGION", "us-east-1")
@pytest.fixture
def mock_ssm():
"""Fixture for mocking boto3 SSM client"""
with patch('boto3.client') as mock_client:
with patch("boto3.client") as mock_client:
mock_ssm = MagicMock()
mock_client.return_value = mock_ssm
mock_ssm.get_parameter.return_value = {
'Parameter': {'Value': 'mocked_value'}
}
yield mock_ssm
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
yield mock_ssm

View File

@@ -1,43 +1,45 @@
import os
import pytest
from unittest.mock import patch
from sqlalchemy.orm import Session
from app.utils.database import get_db_credentials, get_db
from tests.utils.db_mocks import (
session_mock,
mock_get_db,
mock_env,
mock_ssm
)
from app.utils.database import get_db, get_db_credentials
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
def test_get_db_credentials_env(mock_env):
"""Test getting DB credentials from environment variables"""
conn_str = get_db_credentials()
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
def test_get_db_credentials_ssm(mock_ssm):
"""Test getting DB credentials from SSM"""
os.environ.pop("MOCK_AUTH", None)
conn_str = get_db_credentials()
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
assert expected_conn in conn_str
mock_ssm.get_parameter.assert_called()
def test_get_db_credentials_ssm_exception(mock_ssm):
"""Test SSM credential fetching failure raises RuntimeError"""
os.environ.pop("MOCK_AUTH", None)
mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
with pytest.raises(RuntimeError) as excinfo:
get_db_credentials()
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
def test_session_creation():
"""Test database session creation"""
session = session_mock()
assert isinstance(session, Session)
session.close()
def test_get_db_generator():
"""Test get_db dependency generator"""
db_gen = get_db()
@@ -48,18 +50,20 @@ def test_get_db_generator():
except StopIteration:
pass
def test_init_db(mocker, mock_env):
"""Test database initialization creates tables"""
mock_create_all = mocker.patch('app.models.Base.metadata.create_all')
mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
# Mock get_db_credentials to return SQLite test connection
mocker.patch(
'app.utils.database.get_db_credentials',
return_value="sqlite:///:memory:"
"app.utils.database.get_db_credentials",
return_value="sqlite:///:memory:",
)
from app.utils.database import init_db, engine
from app.utils.database import engine, init_db
init_db()
# Verify create_all was called with the engine
mock_create_all.assert_called_once_with(bind=engine)
mock_create_all.assert_called_once_with(bind=engine)