Added more unit tests for routers
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m10s
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m10s
This commit is contained in:
95
tests/utils/db_mocks.py
Normal file
95
tests/utils/db_mocks.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.orm import Session, sessionmaker, declarative_base
|
||||
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||
import pytest
|
||||
|
||||
# Create a mock-specific Base class for testing
|
||||
MockBase = declarative_base()
|
||||
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite."""
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||
class MockPriority(MockBase):
|
||||
__tablename__ = "priorities"
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
class MockChannelDB(MockBase):
|
||||
__tablename__ = "channels"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
tvg_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
group_title = Column(String, nullable=False)
|
||||
tvg_name = Column(String)
|
||||
__table_args__ = (
|
||||
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||
)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
class MockChannelURL(MockBase):
|
||||
__tablename__ = "channels_urls"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Create test engine
|
||||
engine_mock = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool
|
||||
)
|
||||
|
||||
# Create test session
|
||||
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
|
||||
|
||||
# Mock the actual database functions
|
||||
def mock_get_db():
|
||||
db = session_mock()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Fixture for mocking environment variables"""
|
||||
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||
monkeypatch.setenv("DB_USER", "testuser")
|
||||
monkeypatch.setenv("DB_PASSWORD", "testpass")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_NAME", "testdb")
|
||||
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssm():
|
||||
"""Fixture for mocking boto3 SSM client"""
|
||||
with patch('boto3.client') as mock_client:
|
||||
mock_ssm = MagicMock()
|
||||
mock_client.return_value = mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {
|
||||
'Parameter': {'Value': 'mocked_value'}
|
||||
}
|
||||
yield mock_ssm
|
||||
@@ -1,100 +1,15 @@
|
||||
import os
|
||||
import pytest
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.orm import Session, sessionmaker, declarative_base
|
||||
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||
from unittest.mock import patch
|
||||
from sqlalchemy.orm import Session
|
||||
from app.utils.database import get_db_credentials, get_db
|
||||
|
||||
# Create a mock-specific Base class for testing
|
||||
MockBase = declarative_base()
|
||||
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite."""
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||
class MockPriority(MockBase):
|
||||
__tablename__ = "priorities"
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
class MockChannelDB(MockBase):
|
||||
__tablename__ = "channels"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
tvg_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
group_title = Column(String, nullable=False)
|
||||
tvg_name = Column(String)
|
||||
__table_args__ = (
|
||||
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||
)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
class MockChannelURL(MockBase):
|
||||
__tablename__ = "channels_urls"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Create test engine
|
||||
engine_mock = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool
|
||||
from tests.utils.db_mocks import (
|
||||
session_mock,
|
||||
mock_get_db,
|
||||
mock_env,
|
||||
mock_ssm
|
||||
)
|
||||
|
||||
# Create test session
|
||||
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
|
||||
|
||||
# Mock the actual database functions
|
||||
def mock_get_db():
|
||||
db = session_mock()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Fixture for mocking environment variables"""
|
||||
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||
monkeypatch.setenv("DB_USER", "testuser")
|
||||
monkeypatch.setenv("DB_PASSWORD", "testpass")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_NAME", "testdb")
|
||||
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssm():
|
||||
"""Fixture for mocking boto3 SSM client"""
|
||||
with patch('boto3.client') as mock_client:
|
||||
mock_ssm = MagicMock()
|
||||
mock_client.return_value = mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {
|
||||
'Parameter': {'Value': 'mocked_value'}
|
||||
}
|
||||
yield mock_ssm
|
||||
|
||||
def test_get_db_credentials_env(mock_env):
|
||||
"""Test getting DB credentials from environment variables"""
|
||||
conn_str = get_db_credentials()
|
||||
|
||||
Reference in New Issue
Block a user