Added pytest configuration and first 4 unit tests
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m4s

This commit is contained in:
2025-05-27 17:37:05 -05:00
parent 4b1a7e9bea
commit cebbb9c1a8
14 changed files with 205 additions and 15 deletions

11
.vscode/settings.json vendored
View File

@@ -1,17 +1,23 @@
{ {
"cSpell.words": [ "cSpell.words": [
"addopts",
"adminpassword", "adminpassword",
"altinstall", "altinstall",
"asyncio",
"autoflush", "autoflush",
"autouse",
"awscliv", "awscliv",
"boto", "boto",
"botocore",
"BURSTABLE", "BURSTABLE",
"cabletv", "cabletv",
"certbot", "certbot",
"certifi", "certifi",
"delenv",
"devel", "devel",
"dotenv", "dotenv",
"fastapi", "fastapi",
"filterwarnings",
"fiorinis", "fiorinis",
"freedns", "freedns",
"fullchain", "fullchain",
@@ -19,8 +25,10 @@
"iptv", "iptv",
"LETSENCRYPT", "LETSENCRYPT",
"nohup", "nohup",
"ondelete",
"onupdate", "onupdate",
"passlib", "passlib",
"poolclass",
"psycopg", "psycopg",
"pycache", "pycache",
"pyjwt", "pyjwt",
@@ -34,6 +42,9 @@
"sqlalchemy", "sqlalchemy",
"starlette", "starlette",
"stefano", "stefano",
"testdb",
"testpass",
"testpaths",
"uvicorn", "uvicorn",
"venv" "venv"
] ]

View File

@@ -1,9 +1,19 @@
from fastapi.concurrency import asynccontextmanager
from app.routers import channels, auth, playlist, priorities from app.routers import channels, auth, playlist, priorities
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from app.utils.database import init_db
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize database tables on startup
init_db()
yield
app = FastAPI( app = FastAPI(
lifespan=lifespan,
title="IPTV Updater API", title="IPTV Updater API",
description="API for IPTV Updater service", description="API for IPTV Updater service",
version="1.0.0", version="1.0.0",

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timezone
import uuid import uuid
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
Base = declarative_base() Base = declarative_base()

View File

@@ -1,15 +1,14 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ConfigDict
class PriorityBase(BaseModel): class PriorityBase(BaseModel):
"""Base Pydantic model for priorities""" """Base Pydantic model for priorities"""
id: int id: int
description: str description: str
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class PriorityCreate(PriorityBase): class PriorityCreate(PriorityBase):
"""Pydantic model for creating priorities""" """Pydantic model for creating priorities"""
@@ -32,8 +31,7 @@ class ChannelURLBase(ChannelURLCreate):
updated_at: datetime updated_at: datetime
priority_id: int priority_id: int
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class ChannelURLResponse(ChannelURLBase): class ChannelURLResponse(ChannelURLBase):
"""Pydantic model for channel URL responses""" """Pydantic model for channel URL responses"""
@@ -74,5 +72,4 @@ class ChannelResponse(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True

View File

@@ -1,12 +1,12 @@
import os import os
import boto3 import boto3
from app.models import Base
from .constants import AWS_REGION from .constants import AWS_REGION
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from functools import lru_cache from functools import lru_cache
@lru_cache(maxsize=1)
def get_db_credentials(): def get_db_credentials():
"""Fetch and cache DB credentials from environment or SSM Parameter Store""" """Fetch and cache DB credentials from environment or SSM Parameter Store"""
if os.getenv("MOCK_AUTH", "").lower() == "true": if os.getenv("MOCK_AUTH", "").lower() == "true":
@@ -25,14 +25,14 @@ def get_db_credentials():
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}") raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
# Initialize engine and session maker
engine = create_engine(get_db_credentials()) engine = create_engine(get_db_credentials())
# Create all tables
from app.models import Base
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
"""Initialize database by creating all tables"""
Base.metadata.create_all(bind=engine)
def get_db(): def get_db():
"""Dependency for getting database session""" """Dependency for getting database session"""
db = SessionLocal() db = SessionLocal()

19
pytest.ini Normal file
View File

@@ -0,0 +1,19 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_functions = test_*
asyncio_mode = auto
filterwarnings =
ignore::DeprecationWarning:botocore.auth
# Coverage configuration
addopts =
--cov=app
--cov-report=term-missing
# Test markers
markers =
slow: mark tests as slow running
integration: integration tests
unit: unit tests
db: tests requiring database

View File

@@ -12,3 +12,6 @@ pyjwt==2.7.0
sqlalchemy==2.0.23 sqlalchemy==2.0.23
psycopg2-binary==2.9.9 psycopg2-binary==2.9.9
alembic==1.16.1 alembic==1.16.1
pytest==8.1.1
pytest-asyncio==0.23.6
pytest-mock==3.12.0

0
tests/__init__.py Normal file
View File

0
tests/auth/__init__.py Normal file
View File

0
tests/iptv/__init__.py Normal file
View File

0
tests/models/__init__.py Normal file
View File

View File

0
tests/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,150 @@
import os
import pytest
import uuid
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
from app.utils.database import get_db_credentials, get_db
# Create a mock-specific Base class for testing
MockBase = declarative_base()
class SQLiteUUID(TypeDecorator):
"""Enables UUID support for SQLite."""
impl = TEXT
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
return str(value)
def process_result_value(self, value, dialect):
if value is None:
return value
return uuid.UUID(value)
# Model classes for testing - prefix with Mock to avoid pytest collection
class MockPriority(MockBase):
__tablename__ = "priorities"
id = Column(Integer, primary_key=True)
description = Column(String, nullable=False)
class MockChannelDB(MockBase):
__tablename__ = "channels"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
tvg_id = Column(String, nullable=False)
name = Column(String, nullable=False)
group_title = Column(String, nullable=False)
tvg_name = Column(String)
__table_args__ = (
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
)
tvg_logo = Column(String)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
class MockChannelURL(MockBase):
__tablename__ = "channels_urls"
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
url = Column(String, nullable=False)
in_use = Column(Boolean, default=False, nullable=False)
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# Create test engine
TEST_ENGINE = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool
)
# Create test session
TEST_SESSION = sessionmaker(autocommit=False, autoflush=False, bind=TEST_ENGINE)
# Mock the actual database functions
def mock_get_db():
db = TEST_SESSION()
try:
yield db
finally:
db.close()
@pytest.fixture(scope="session", autouse=True)
def create_test_tables():
"""Create test database tables for all tests"""
MockBase.metadata.create_all(bind=TEST_ENGINE)
@pytest.fixture(autouse=True)
def mock_db_engine(monkeypatch):
"""Fixture to mock database engine and session for each test"""
# First mock get_db_credentials to prevent real connection attempts
def mock_credentials():
return "sqlite:///:memory:"
monkeypatch.setattr('app.utils.database.get_db_credentials', mock_credentials)
# Then patch the actual database functions
monkeypatch.setattr('app.utils.database.engine', TEST_ENGINE)
monkeypatch.setattr('app.utils.database.SessionLocal', TEST_SESSION)
monkeypatch.setattr('app.utils.database.get_db', mock_get_db)
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
"""Fixture for mocking environment variables"""
# Clear any existing env vars first
monkeypatch.delenv("MOCK_AUTH", raising=False)
monkeypatch.delenv("DB_USER", raising=False)
monkeypatch.delenv("DB_PASSWORD", raising=False)
monkeypatch.delenv("DB_HOST", raising=False)
monkeypatch.delenv("DB_NAME", raising=False)
# Set mock values
monkeypatch.setenv("MOCK_AUTH", "true")
monkeypatch.setenv("DB_USER", "testuser")
monkeypatch.setenv("DB_PASSWORD", "testpass")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_NAME", "testdb")
monkeypatch.setenv("AWS_REGION", "us-east-1") # Mock AWS region
@pytest.fixture
def mock_ssm():
"""Fixture for mocking boto3 SSM client"""
with patch('boto3.client') as mock_client:
mock_ssm = MagicMock()
mock_client.return_value = mock_ssm
mock_ssm.get_parameter.return_value = {
'Parameter': {'Value': 'mocked_value'}
}
yield mock_ssm
def test_get_db_credentials_env(mock_env):
"""Test getting DB credentials from environment variables"""
conn_str = get_db_credentials()
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
def test_get_db_credentials_ssm(mock_ssm):
"""Test getting DB credentials from SSM"""
os.environ.pop("MOCK_AUTH", None)
conn_str = get_db_credentials()
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
mock_ssm.get_parameter.assert_called()
def test_session_creation():
"""Test database session creation"""
session = TEST_SESSION()
assert isinstance(session, Session)
session.close()
def test_get_db_generator():
"""Test get_db dependency generator"""
db_gen = get_db()
db = next(db_gen)
assert isinstance(db, Session)
try:
next(db_gen) # Should raise StopIteration
except StopIteration:
pass