146 lines
4.5 KiB
Python
146 lines
4.5 KiB
Python
import uuid
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy import (
|
|
Boolean,
|
|
Column,
|
|
DateTime,
|
|
ForeignKey,
|
|
Integer,
|
|
String,
|
|
UniqueConstraint,
|
|
create_engine,
|
|
)
|
|
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db
|
|
from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID
|
|
|
|
# Create a mock-specific Base class for testing.
# declarative_base() builds a fresh MetaData, so the mock tables below are
# registered in MockBase.metadata, separate from any application models.
# NOTE(review): nothing in this module calls MockBase.metadata.create_all;
# presumably individual tests create the tables on engine_mock — confirm.
MockBase = declarative_base()
|
|
|
|
|
|
# ORM models used only by the tests. Their names start with "Mock" rather
# than "Test" so that pytest does not try to collect them as test classes.
class MockPriority(MockBase):
    """Mock counterpart of the ``priorities`` table.

    Each row has an integer primary key and a required description.
    """

    __tablename__ = "priorities"

    id = Column(Integer(), primary_key=True)
    description = Column(String(), nullable=False)
|
|
|
|
|
|
class MockGroup(MockBase):
    """Mock counterpart of the ``groups`` table.

    New rows receive a ``uuid4`` primary key, a ``sort_order`` of 0 and
    UTC-based timestamps. A group's channels are reachable via ``channels``.
    """

    __tablename__ = "groups"

    id = Column(UUID_COLUMN_TYPE, default=uuid.uuid4, primary_key=True)
    name = Column(String, unique=True, nullable=False)
    sort_order = Column(Integer, default=0, nullable=False)
    # created_at is filled on INSERT. updated_at is filled on INSERT and
    # regenerated whenever SQLAlchemy emits an UPDATE without an explicit value.
    # NOTE(review): DateTime is timezone-naive, so SQLite drops the UTC tzinfo
    # on round-trip — confirm tests do not compare these to aware datetimes.
    created_at = Column(DateTime, default=lambda: datetime.now(tz=timezone.utc))
    updated_at = Column(
        DateTime,
        onupdate=lambda: datetime.now(tz=timezone.utc),
        default=lambda: datetime.now(tz=timezone.utc),
    )

    # One-to-many side; the inverse attribute is MockChannelDB.group.
    channels = relationship("MockChannelDB", back_populates="group")
|
|
|
|
|
|
class MockChannelDB(MockBase):
    """Mock counterpart of the ``channels`` table.

    Every channel must reference a ``MockGroup`` through ``group_id``, and a
    given channel name may appear at most once within the same group.
    """

    __tablename__ = "channels"
    # Composite uniqueness: (group_id, name) must be unique across rows.
    __table_args__ = (
        UniqueConstraint("group_id", "name", name="uix_group_id_name"),
    )

    id = Column(UUID_COLUMN_TYPE, default=uuid.uuid4, primary_key=True)
    tvg_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False)
    tvg_name = Column(String)
    tvg_logo = Column(String)
    created_at = Column(DateTime, default=lambda: datetime.now(tz=timezone.utc))
    updated_at = Column(
        DateTime,
        onupdate=lambda: datetime.now(tz=timezone.utc),
        default=lambda: datetime.now(tz=timezone.utc),
    )

    # Many-to-one side; the inverse attribute is MockGroup.channels.
    group = relationship("MockGroup", back_populates="channels")
|
|
|
|
|
|
class MockChannelURL(MockBase):
    """Mock counterpart of the ``channels_urls`` table.

    Each row stores one URL for a channel, a required priority reference,
    and an ``in_use`` flag that defaults to ``False``.
    """

    __tablename__ = "channels_urls"

    id = Column(UUID_COLUMN_TYPE, default=uuid.uuid4, primary_key=True)
    # ON DELETE CASCADE is declared in the schema only. SQLite enforces it
    # only when PRAGMA foreign_keys=ON, which this module does not set.
    channel_id = Column(
        UUID_COLUMN_TYPE,
        ForeignKey("channels.id", ondelete="CASCADE"),
        nullable=False,
    )
    url = Column(String, nullable=False)
    in_use = Column(Boolean, nullable=False, default=False)
    priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
    created_at = Column(DateTime, default=lambda: datetime.now(tz=timezone.utc))
    updated_at = Column(
        DateTime,
        onupdate=lambda: datetime.now(tz=timezone.utc),
        default=lambda: datetime.now(tz=timezone.utc),
    )
|
|
|
|
|
|
def create_mock_priorities_and_group(db_session, priorities, group_name):
    """Insert the given priorities plus one new group, then commit.

    Args:
        db_session: SQLAlchemy session that receives the new rows.
        priorities: Iterable of ``(id, description)`` pairs; one
            ``MockPriority`` is created per pair.
        group_name: Name given to the newly created ``MockGroup``.

    Returns:
        UUID: Primary key of the group that was just created.
    """
    pending = [MockPriority(id=pid, description=desc) for pid, desc in priorities]
    new_group = MockGroup(name=group_name)
    pending.append(new_group)

    db_session.add_all(pending)
    db_session.commit()

    # With SQLAlchemy's default expire_on_commit, reading ``id`` here reloads
    # the instance through the still-open session.
    return new_group.id
|
|
|
|
|
|
# In-memory SQLite engine for the tests. StaticPool reuses one single
# connection, so every session shares the same in-memory database.
# check_same_thread=False allows that connection to be used from threads
# other than the one that opened it.
engine_mock = create_engine(
    "sqlite:///:memory:",
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the in-memory engine above.
session_mock = sessionmaker(
    bind=engine_mock,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
|
|
# Mock the actual database functions
def mock_get_db():
    """Yield a session created by ``session_mock`` and close it afterwards.

    Presumably used to override the application's ``get_db`` dependency;
    confirm against the callers. The session is closed when the generator
    is finalized, including when an exception is thrown into it.
    """
    # Session.__exit__ calls close(), matching an explicit try/finally.
    with session_mock() as session:
        yield session
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Set fake auth, database and AWS environment variables for every test.

    The fixture is autouse, so it applies to all tests. ``monkeypatch`` restores
    the previous environment when each test finishes.
    """
    fake_environment = {
        "MOCK_AUTH": "true",
        "DB_USER": "testuser",
        "DB_PASSWORD": "testpass",
        "DB_HOST": "localhost",
        "DB_NAME": "testdb",
        "AWS_REGION": "us-east-1",
    }
    for var_name, var_value in fake_environment.items():
        monkeypatch.setenv(var_name, var_value)
|
|
|
|
|
|
@pytest.fixture
def mock_ssm():
    """Patch ``boto3.client`` to return a fake SSM client, and yield it.

    Every ``get_parameter`` call on the fake returns
    ``{"Parameter": {"Value": "mocked_value"}}``. The patch remains active
    until the requesting test finishes.
    """
    fake_client = MagicMock()
    fake_client.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
    # Extra keyword arguments to patch() configure the replacement mock, so
    # boto3.client(...) returns fake_client.
    with patch("boto3.client", return_value=fake_client):
        yield fake_client
|