From cebbb9c1a8e8468cb6255c14bfd0557b5c6bd7aa Mon Sep 17 00:00:00 2001 From: Stefano Date: Tue, 27 May 2025 17:37:05 -0500 Subject: [PATCH] Added pytest configuration and first 4 unit tests --- .vscode/settings.json | 11 +++ app/main.py | 10 +++ app/models/db.py | 2 +- app/models/schemas.py | 11 +-- app/utils/database.py | 12 +-- pytest.ini | 19 +++++ requirements.txt | 5 +- tests/__init__.py | 0 tests/auth/__init__.py | 0 tests/iptv/__init__.py | 0 tests/models/__init__.py | 0 tests/routers/__init__.py | 0 tests/utils/__init__.py | 0 tests/utils/test_database.py | 150 +++++++++++++++++++++++++++++++++++ 14 files changed, 205 insertions(+), 15 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/__init__.py create mode 100644 tests/auth/__init__.py create mode 100644 tests/iptv/__init__.py create mode 100644 tests/models/__init__.py create mode 100644 tests/routers/__init__.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_database.py diff --git a/.vscode/settings.json b/.vscode/settings.json index f97cfee..1712ada 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,17 +1,23 @@ { "cSpell.words": [ + "addopts", "adminpassword", "altinstall", + "asyncio", "autoflush", + "autouse", "awscliv", "boto", + "botocore", "BURSTABLE", "cabletv", "certbot", "certifi", + "delenv", "devel", "dotenv", "fastapi", + "filterwarnings", "fiorinis", "freedns", "fullchain", @@ -19,8 +25,10 @@ "iptv", "LETSENCRYPT", "nohup", + "ondelete", "onupdate", "passlib", + "poolclass", "psycopg", "pycache", "pyjwt", @@ -34,6 +42,9 @@ "sqlalchemy", "starlette", "stefano", + "testdb", + "testpass", + "testpaths", "uvicorn", "venv" ] diff --git a/app/main.py b/app/main.py index a78a346..a4aa9a8 100644 --- a/app/main.py +++ b/app/main.py @@ -1,9 +1,19 @@ +from fastapi.concurrency import asynccontextmanager from app.routers import channels, auth, playlist, priorities from fastapi import FastAPI from fastapi.openapi.utils import get_openapi +from app.utils.database import init_db + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Initialize database tables on startup + init_db() + yield + app = FastAPI( + lifespan=lifespan, title="IPTV Updater API", description="API for IPTV Updater service", version="1.0.0", diff --git a/app/models/db.py b/app/models/db.py index 1d69efd..46369c6 100644 --- a/app/models/db.py +++ b/app/models/db.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone import uuid from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base from sqlalchemy.orm import relationship Base = declarative_base() diff --git a/app/models/schemas.py b/app/models/schemas.py index 4393d3e..534ed68 100644 --- a/app/models/schemas.py +++ b/app/models/schemas.py @@ -1,15 +1,14 @@ from datetime import datetime from typing import List, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict class PriorityBase(BaseModel): """Base Pydantic model for priorities""" id: int description: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class PriorityCreate(PriorityBase): """Pydantic model for creating priorities""" @@ -32,8 +31,7 @@ class ChannelURLBase(ChannelURLCreate): updated_at: datetime priority_id: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ChannelURLResponse(ChannelURLBase): """Pydantic model for channel URL responses""" @@ -74,5 +72,4 @@ class ChannelResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True \ No newline at end of file + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/app/utils/database.py b/app/utils/database.py index c453dfd..273dd36 100644 --- a/app/utils/database.py +++ b/app/utils/database.py @@ -1,12 +1,12 @@ import os import boto3 +from app.models import Base from .constants import AWS_REGION from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from functools import lru_cache -@lru_cache(maxsize=1) def get_db_credentials(): """Fetch and cache DB credentials from environment or SSM Parameter Store""" if os.getenv("MOCK_AUTH", "").lower() == "true": @@ -25,14 +25,14 @@ def get_db_credentials(): except Exception as e: raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}") +# Initialize engine and session maker engine = create_engine(get_db_credentials()) - -# Create all tables -from app.models import Base -Base.metadata.create_all(bind=engine) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +def init_db(): + """Initialize database by creating all tables""" + Base.metadata.create_all(bind=engine) + def get_db(): """Dependency for getting database session""" db = SessionLocal() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..8c64b37 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,19 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_functions = test_* +asyncio_mode = auto +filterwarnings = + ignore::DeprecationWarning:botocore.auth + +# Coverage configuration +addopts = + --cov=app + --cov-report=term-missing + +# Test markers +markers = + slow: mark tests as slow running + integration: integration tests + unit: unit tests + db: tests requiring database \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 60815b9..f702cc0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,7 @@ starlette>=0.27.0 pyjwt==2.7.0 sqlalchemy==2.0.23 psycopg2-binary==2.9.9 -alembic==1.16.1 \ No newline at end of file +alembic==1.16.1 +pytest==8.1.1 +pytest-asyncio==0.23.6 +pytest-mock==3.12.0 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/auth/__init__.py b/tests/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/iptv/__init__.py b/tests/iptv/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/routers/__init__.py b/tests/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_database.py b/tests/utils/test_database.py new file mode 100644 index 0000000..dccdf21 --- /dev/null +++ b/tests/utils/test_database.py @@ -0,0 +1,150 @@ +import os +import pytest +import uuid +from datetime import datetime, timezone +from unittest.mock import patch, MagicMock +from sqlalchemy.pool import StaticPool +from sqlalchemy.orm import Session, sessionmaker, declarative_base +from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer +from app.utils.database import get_db_credentials, get_db + +# Create a mock-specific Base class for testing +MockBase = declarative_base() + +class SQLiteUUID(TypeDecorator): + """Enables UUID support for SQLite.""" + impl = TEXT + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + return str(value) + + def process_result_value(self, value, dialect): + if value is None: + return value + return uuid.UUID(value) + +# Model classes for testing - prefix with Mock to avoid pytest collection +class MockPriority(MockBase): + __tablename__ = "priorities" + id = Column(Integer, primary_key=True) + description = Column(String, nullable=False) + +class MockChannelDB(MockBase): + __tablename__ = "channels" + id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) + tvg_id = Column(String, nullable=False) + name = Column(String, nullable=False) + group_title = Column(String, nullable=False) + tvg_name = Column(String) + __table_args__ = ( + UniqueConstraint('group_title', 'name', name='uix_group_title_name'), + ) + tvg_logo = Column(String) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + +class MockChannelURL(MockBase): + __tablename__ = "channels_urls" + id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) + channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False) + url = Column(String, nullable=False) + in_use = Column(Boolean, default=False, nullable=False) + priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + +# Create test engine +TEST_ENGINE = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool +) + +# Create test session +TEST_SESSION = sessionmaker(autocommit=False, autoflush=False, bind=TEST_ENGINE) + +# Mock the actual database functions +def mock_get_db(): + db = TEST_SESSION() + try: + yield db + finally: + db.close() + +@pytest.fixture(scope="session", autouse=True) +def create_test_tables(): + """Create test database tables for all tests""" + MockBase.metadata.create_all(bind=TEST_ENGINE) + +@pytest.fixture(autouse=True) +def mock_db_engine(monkeypatch): + """Fixture to mock database engine and session for each test""" + # First mock get_db_credentials to prevent real connection attempts + def mock_credentials(): + return "sqlite:///:memory:" + monkeypatch.setattr('app.utils.database.get_db_credentials', mock_credentials) + + # Then patch the actual database functions + monkeypatch.setattr('app.utils.database.engine', TEST_ENGINE) + monkeypatch.setattr('app.utils.database.SessionLocal', TEST_SESSION) + monkeypatch.setattr('app.utils.database.get_db', mock_get_db) + +@pytest.fixture(autouse=True) +def mock_env(monkeypatch): + """Fixture for mocking environment variables""" + # Clear any existing env vars first + monkeypatch.delenv("MOCK_AUTH", raising=False) + monkeypatch.delenv("DB_USER", raising=False) + monkeypatch.delenv("DB_PASSWORD", raising=False) + monkeypatch.delenv("DB_HOST", raising=False) + monkeypatch.delenv("DB_NAME", raising=False) + + # Set mock values + monkeypatch.setenv("MOCK_AUTH", "true") + monkeypatch.setenv("DB_USER", "testuser") + monkeypatch.setenv("DB_PASSWORD", "testpass") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_NAME", "testdb") + monkeypatch.setenv("AWS_REGION", "us-east-1") # Mock AWS region + +@pytest.fixture +def mock_ssm(): + """Fixture for mocking boto3 SSM client""" + with patch('boto3.client') as mock_client: + mock_ssm = MagicMock() + mock_client.return_value = mock_ssm + mock_ssm.get_parameter.return_value = { + 'Parameter': {'Value': 'mocked_value'} + } + yield mock_ssm + +def test_get_db_credentials_env(mock_env): + """Test getting DB credentials from environment variables""" + conn_str = get_db_credentials() + assert conn_str == "postgresql://testuser:testpass@localhost/testdb" + +def test_get_db_credentials_ssm(mock_ssm): + """Test getting DB credentials from SSM""" + os.environ.pop("MOCK_AUTH", None) + conn_str = get_db_credentials() + assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str + mock_ssm.get_parameter.assert_called() + +def test_session_creation(): + """Test database session creation""" + session = TEST_SESSION() + assert isinstance(session, Session) + session.close() + +def test_get_db_generator(): + """Test get_db dependency generator""" + db_gen = get_db() + db = next(db_gen) + assert isinstance(db, Session) + try: + next(db_gen) # Should raise StopIteration + except StopIteration: + pass \ No newline at end of file