Added pytest configuration and first 4 unit tests
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m4s
All checks were successful
AWS Deploy on Push / build (push) Successful in 1m4s
This commit is contained in:
11
.vscode/settings.json
vendored
11
.vscode/settings.json
vendored
@@ -1,17 +1,23 @@
|
|||||||
{
|
{
|
||||||
"cSpell.words": [
|
"cSpell.words": [
|
||||||
|
"addopts",
|
||||||
"adminpassword",
|
"adminpassword",
|
||||||
"altinstall",
|
"altinstall",
|
||||||
|
"asyncio",
|
||||||
"autoflush",
|
"autoflush",
|
||||||
|
"autouse",
|
||||||
"awscliv",
|
"awscliv",
|
||||||
"boto",
|
"boto",
|
||||||
|
"botocore",
|
||||||
"BURSTABLE",
|
"BURSTABLE",
|
||||||
"cabletv",
|
"cabletv",
|
||||||
"certbot",
|
"certbot",
|
||||||
"certifi",
|
"certifi",
|
||||||
|
"delenv",
|
||||||
"devel",
|
"devel",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
|
"filterwarnings",
|
||||||
"fiorinis",
|
"fiorinis",
|
||||||
"freedns",
|
"freedns",
|
||||||
"fullchain",
|
"fullchain",
|
||||||
@@ -19,8 +25,10 @@
|
|||||||
"iptv",
|
"iptv",
|
||||||
"LETSENCRYPT",
|
"LETSENCRYPT",
|
||||||
"nohup",
|
"nohup",
|
||||||
|
"ondelete",
|
||||||
"onupdate",
|
"onupdate",
|
||||||
"passlib",
|
"passlib",
|
||||||
|
"poolclass",
|
||||||
"psycopg",
|
"psycopg",
|
||||||
"pycache",
|
"pycache",
|
||||||
"pyjwt",
|
"pyjwt",
|
||||||
@@ -34,6 +42,9 @@
|
|||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"starlette",
|
"starlette",
|
||||||
"stefano",
|
"stefano",
|
||||||
|
"testdb",
|
||||||
|
"testpass",
|
||||||
|
"testpaths",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"venv"
|
"venv"
|
||||||
]
|
]
|
||||||
|
|||||||
10
app/main.py
10
app/main.py
@@ -1,9 +1,19 @@
|
|||||||
|
|
||||||
|
from fastapi.concurrency import asynccontextmanager
|
||||||
from app.routers import channels, auth, playlist, priorities
|
from app.routers import channels, auth, playlist, priorities
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
|
from app.utils.database import init_db
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Initialize database tables on startup
|
||||||
|
init_db()
|
||||||
|
yield
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
|
lifespan=lifespan,
|
||||||
title="IPTV Updater API",
|
title="IPTV Updater API",
|
||||||
description="API for IPTV Updater service",
|
description="API for IPTV Updater service",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from datetime import datetime, timezone
|
|||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
from sqlalchemy import Column, String, JSON, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import declarative_base
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
|
||||||
class PriorityBase(BaseModel):
|
class PriorityBase(BaseModel):
|
||||||
"""Base Pydantic model for priorities"""
|
"""Base Pydantic model for priorities"""
|
||||||
id: int
|
id: int
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
class PriorityCreate(PriorityBase):
|
class PriorityCreate(PriorityBase):
|
||||||
"""Pydantic model for creating priorities"""
|
"""Pydantic model for creating priorities"""
|
||||||
@@ -32,8 +31,7 @@ class ChannelURLBase(ChannelURLCreate):
|
|||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
priority_id: int
|
priority_id: int
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
class ChannelURLResponse(ChannelURLBase):
|
class ChannelURLResponse(ChannelURLBase):
|
||||||
"""Pydantic model for channel URL responses"""
|
"""Pydantic model for channel URL responses"""
|
||||||
@@ -74,5 +72,4 @@ class ChannelResponse(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import boto3
|
import boto3
|
||||||
|
from app.models import Base
|
||||||
from .constants import AWS_REGION
|
from .constants import AWS_REGION
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def get_db_credentials():
|
def get_db_credentials():
|
||||||
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
|
"""Fetch and cache DB credentials from environment or SSM Parameter Store"""
|
||||||
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
if os.getenv("MOCK_AUTH", "").lower() == "true":
|
||||||
@@ -25,14 +25,14 @@ def get_db_credentials():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
|
raise RuntimeError(f"Failed to fetch DB credentials from SSM: {str(e)}")
|
||||||
|
|
||||||
|
# Initialize engine and session maker
|
||||||
engine = create_engine(get_db_credentials())
|
engine = create_engine(get_db_credentials())
|
||||||
|
|
||||||
# Create all tables
|
|
||||||
from app.models import Base
|
|
||||||
Base.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
"""Initialize database by creating all tables"""
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
"""Dependency for getting database session"""
|
"""Dependency for getting database session"""
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
|||||||
19
pytest.ini
Normal file
19
pytest.ini
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_functions = test_*
|
||||||
|
asyncio_mode = auto
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning:botocore.auth
|
||||||
|
|
||||||
|
# Coverage configuration
|
||||||
|
addopts =
|
||||||
|
--cov=app
|
||||||
|
--cov-report=term-missing
|
||||||
|
|
||||||
|
# Test markers
|
||||||
|
markers =
|
||||||
|
slow: mark tests as slow running
|
||||||
|
integration: integration tests
|
||||||
|
unit: unit tests
|
||||||
|
db: tests requiring database
|
||||||
@@ -12,3 +12,6 @@ pyjwt==2.7.0
|
|||||||
sqlalchemy==2.0.23
|
sqlalchemy==2.0.23
|
||||||
psycopg2-binary==2.9.9
|
psycopg2-binary==2.9.9
|
||||||
alembic==1.16.1
|
alembic==1.16.1
|
||||||
|
pytest==8.1.1
|
||||||
|
pytest-asyncio==0.23.6
|
||||||
|
pytest-mock==3.12.0
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/auth/__init__.py
Normal file
0
tests/auth/__init__.py
Normal file
0
tests/iptv/__init__.py
Normal file
0
tests/iptv/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
0
tests/routers/__init__.py
Normal file
0
tests/routers/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
150
tests/utils/test_database.py
Normal file
150
tests/utils/test_database.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker, declarative_base
|
||||||
|
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||||
|
from app.utils.database import get_db_credentials, get_db
|
||||||
|
|
||||||
|
# Create a mock-specific Base class for testing
|
||||||
|
MockBase = declarative_base()
|
||||||
|
|
||||||
|
class SQLiteUUID(TypeDecorator):
|
||||||
|
"""Enables UUID support for SQLite."""
|
||||||
|
impl = TEXT
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid.UUID(value)
|
||||||
|
|
||||||
|
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||||
|
class MockPriority(MockBase):
|
||||||
|
__tablename__ = "priorities"
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
description = Column(String, nullable=False)
|
||||||
|
|
||||||
|
class MockChannelDB(MockBase):
|
||||||
|
__tablename__ = "channels"
|
||||||
|
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||||
|
tvg_id = Column(String, nullable=False)
|
||||||
|
name = Column(String, nullable=False)
|
||||||
|
group_title = Column(String, nullable=False)
|
||||||
|
tvg_name = Column(String)
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||||
|
)
|
||||||
|
tvg_logo = Column(String)
|
||||||
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
class MockChannelURL(MockBase):
|
||||||
|
__tablename__ = "channels_urls"
|
||||||
|
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||||
|
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||||
|
url = Column(String, nullable=False)
|
||||||
|
in_use = Column(Boolean, default=False, nullable=False)
|
||||||
|
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||||
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
# Create test engine
|
||||||
|
TEST_ENGINE = create_engine(
|
||||||
|
"sqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create test session
|
||||||
|
TEST_SESSION = sessionmaker(autocommit=False, autoflush=False, bind=TEST_ENGINE)
|
||||||
|
|
||||||
|
# Mock the actual database functions
|
||||||
|
def mock_get_db():
|
||||||
|
db = TEST_SESSION()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def create_test_tables():
|
||||||
|
"""Create test database tables for all tests"""
|
||||||
|
MockBase.metadata.create_all(bind=TEST_ENGINE)
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_db_engine(monkeypatch):
|
||||||
|
"""Fixture to mock database engine and session for each test"""
|
||||||
|
# First mock get_db_credentials to prevent real connection attempts
|
||||||
|
def mock_credentials():
|
||||||
|
return "sqlite:///:memory:"
|
||||||
|
monkeypatch.setattr('app.utils.database.get_db_credentials', mock_credentials)
|
||||||
|
|
||||||
|
# Then patch the actual database functions
|
||||||
|
monkeypatch.setattr('app.utils.database.engine', TEST_ENGINE)
|
||||||
|
monkeypatch.setattr('app.utils.database.SessionLocal', TEST_SESSION)
|
||||||
|
monkeypatch.setattr('app.utils.database.get_db', mock_get_db)
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_env(monkeypatch):
|
||||||
|
"""Fixture for mocking environment variables"""
|
||||||
|
# Clear any existing env vars first
|
||||||
|
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||||
|
monkeypatch.delenv("DB_USER", raising=False)
|
||||||
|
monkeypatch.delenv("DB_PASSWORD", raising=False)
|
||||||
|
monkeypatch.delenv("DB_HOST", raising=False)
|
||||||
|
monkeypatch.delenv("DB_NAME", raising=False)
|
||||||
|
|
||||||
|
# Set mock values
|
||||||
|
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||||
|
monkeypatch.setenv("DB_USER", "testuser")
|
||||||
|
monkeypatch.setenv("DB_PASSWORD", "testpass")
|
||||||
|
monkeypatch.setenv("DB_HOST", "localhost")
|
||||||
|
monkeypatch.setenv("DB_NAME", "testdb")
|
||||||
|
monkeypatch.setenv("AWS_REGION", "us-east-1") # Mock AWS region
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ssm():
|
||||||
|
"""Fixture for mocking boto3 SSM client"""
|
||||||
|
with patch('boto3.client') as mock_client:
|
||||||
|
mock_ssm = MagicMock()
|
||||||
|
mock_client.return_value = mock_ssm
|
||||||
|
mock_ssm.get_parameter.return_value = {
|
||||||
|
'Parameter': {'Value': 'mocked_value'}
|
||||||
|
}
|
||||||
|
yield mock_ssm
|
||||||
|
|
||||||
|
def test_get_db_credentials_env(mock_env):
|
||||||
|
"""Test getting DB credentials from environment variables"""
|
||||||
|
conn_str = get_db_credentials()
|
||||||
|
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
|
||||||
|
|
||||||
|
def test_get_db_credentials_ssm(mock_ssm):
|
||||||
|
"""Test getting DB credentials from SSM"""
|
||||||
|
os.environ.pop("MOCK_AUTH", None)
|
||||||
|
conn_str = get_db_credentials()
|
||||||
|
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
|
||||||
|
mock_ssm.get_parameter.assert_called()
|
||||||
|
|
||||||
|
def test_session_creation():
|
||||||
|
"""Test database session creation"""
|
||||||
|
session = TEST_SESSION()
|
||||||
|
assert isinstance(session, Session)
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def test_get_db_generator():
|
||||||
|
"""Test get_db dependency generator"""
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
assert isinstance(db, Session)
|
||||||
|
try:
|
||||||
|
next(db_gen) # Should raise StopIteration
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user