diff --git a/.vscode/settings.json b/.vscode/settings.json index e7f6f06..78efe5e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -85,6 +85,7 @@ "ruru", "sessionmaker", "sqlalchemy", + "sqliteuuid", "starlette", "stefano", "testadmin", diff --git a/alembic/versions/0a455608256f_add_groups_table_and_migrate_group_.py b/alembic/versions/0a455608256f_add_groups_table_and_migrate_group_.py new file mode 100644 index 0000000..6c2f83e --- /dev/null +++ b/alembic/versions/0a455608256f_add_groups_table_and_migrate_group_.py @@ -0,0 +1,110 @@ +"""add groups table and migrate group_title data + +Revision ID: 0a455608256f +Revises: 95b61a92455a +Create Date: 2025-06-10 09:22:11.820035 + +""" +import uuid +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0a455608256f' +down_revision: Union[str, None] = '95b61a92455a' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('groups', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('sort_order', sa.Integer(), nullable=False, server_default='0'), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name') + ) + + # Create temporary table for group mapping + group_mapping = op.create_table( + 'group_mapping', + sa.Column('group_title', sa.String(), nullable=False), + sa.Column('group_id', sa.UUID(), nullable=False) + ) + + # Get existing group titles and create groups + conn = op.get_bind() + distinct_groups = conn.execute( + sa.text("SELECT DISTINCT group_title FROM channels") + ).fetchall() + + for group in distinct_groups: + group_title = group[0] + group_id = str(uuid.uuid4()) + conn.execute( + sa.text( + "INSERT INTO groups (id, name, sort_order) " + "VALUES (:id, :name, 0)" + ).bindparams(id=group_id, name=group_title) + ) + conn.execute( + group_mapping.insert().values( + group_title=group_title, + group_id=group_id + ) + ) + + # Add group_id column (nullable first) + op.add_column('channels', sa.Column('group_id', sa.UUID(), nullable=True)) + + # Update channels with group_ids + conn.execute( + sa.text( + "UPDATE channels c SET group_id = gm.group_id " + "FROM group_mapping gm WHERE c.group_title = gm.group_title" + ) + ) + + # Now make group_id non-nullable and add constraints + op.alter_column('channels', 'group_id', nullable=False) + op.drop_constraint(op.f('uix_group_title_name'), 'channels', type_='unique') + op.create_unique_constraint('uix_group_id_name', 'channels', ['group_id', 'name']) + op.create_foreign_key('fk_channels_group_id', 'channels', 'groups', ['group_id'], ['id']) + + # Clean up and drop group_title + op.drop_table('group_mapping') + op.drop_column('channels', 'group_title') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('channels', sa.Column('group_title', sa.VARCHAR(), autoincrement=False, nullable=True)) + + # Restore group_title values from groups table + conn = op.get_bind() + conn.execute( + sa.text( + "UPDATE channels c SET group_title = g.name " + "FROM groups g WHERE c.group_id = g.id" + ) + ) + + # Now make group_title non-nullable + op.alter_column('channels', 'group_title', nullable=False) + + # Drop constraints and columns + op.drop_constraint('fk_channels_group_id', 'channels', type_='foreignkey') + op.drop_constraint('uix_group_id_name', 'channels', type_='unique') + op.create_unique_constraint(op.f('uix_group_title_name'), 'channels', ['group_title', 'name']) + op.drop_column('channels', 'group_id') + op.drop_table('groups') + # ### end Alembic commands ### diff --git a/app/main.py b/app/main.py index 1e79974..b03957e 100644 --- a/app/main.py +++ b/app/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI from fastapi.concurrency import asynccontextmanager from fastapi.openapi.utils import get_openapi -from app.routers import auth, channels, playlist, priorities +from app.routers import auth, channels, groups, playlist, priorities from app.utils.database import init_db @@ -68,3 +68,4 @@ app.include_router(auth.router) app.include_router(channels.router) app.include_router(playlist.router) app.include_router(priorities.router) +app.include_router(groups.router) diff --git a/app/models/__init__.py b/app/models/__init__.py index 29b82a9..d3a14da 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,10 +1,13 @@ -from .db import Base, ChannelDB, ChannelURL +from .db import Base, ChannelDB, ChannelURL, Group from .schemas import ( ChannelCreate, ChannelResponse, ChannelUpdate, ChannelURLCreate, ChannelURLResponse, + GroupCreate, + GroupResponse, + GroupUpdate, ) __all__ = [ @@ -16,4 +19,8 @@ __all__ = [ "ChannelURL", "ChannelURLCreate", "ChannelURLResponse", + "Group", + "GroupCreate", + "GroupResponse", + "GroupUpdate", ] diff --git a/app/models/db.py b/app/models/db.py index adf02ce..fb570c5 100644 --- a/app/models/db.py +++ b/app/models/db.py @@ -1,18 +1,58 @@ +import os import uuid from datetime import datetime, timezone from sqlalchemy import ( + TEXT, Boolean, Column, DateTime, ForeignKey, Integer, String, + TypeDecorator, UniqueConstraint, ) from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import declarative_base, relationship + +# Custom UUID type for SQLite compatibility +class SQLiteUUID(TypeDecorator): + """Enables UUID support for SQLite with proper comparison handling.""" + + impl = TEXT + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + if isinstance(value, uuid.UUID): + return str(value) + try: + # Validate string format by attempting to create UUID + uuid.UUID(value) + return value + except (ValueError, AttributeError): + raise ValueError(f"Invalid UUID string format: {value}") + + def process_result_value(self, value, dialect): + if value is None: + return value + return uuid.UUID(value) + + def compare_values(self, x, y): + if x is None or y is None: + return x == y + return str(x) == str(y) + + +# Determine which UUID type to use based on environment +if os.getenv("MOCK_AUTH", "").lower() == "true": + UUID_COLUMN_TYPE = SQLiteUUID() +else: + UUID_COLUMN_TYPE = UUID(as_uuid=True) + Base = declarative_base() @@ -25,20 +65,37 @@ class Priority(Base): description = Column(String, nullable=False) +class Group(Base): + """SQLAlchemy model for channel groups""" + + __tablename__ = "groups" + + id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4) + name = Column(String, nullable=False, unique=True) + sort_order = Column(Integer, nullable=False, default=0) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + + # Relationship with Channel + channels = relationship("ChannelDB", back_populates="group") + + class ChannelDB(Base): """SQLAlchemy model for IPTV channels""" __tablename__ = "channels" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4) tvg_id = Column(String, nullable=False) name = Column(String, nullable=False) - group_title = Column(String, nullable=False) + group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False) tvg_name = Column(String) - __table_args__ = ( - UniqueConstraint("group_title", "name", name="uix_group_title_name"), - ) + __table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),) tvg_logo = Column(String) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) updated_at = Column( @@ -47,10 +104,11 @@ class ChannelDB(Base): onupdate=lambda: datetime.now(timezone.utc), ) - # Relationship with ChannelURL + # Relationships urls = relationship( "ChannelURL", back_populates="channel", cascade="all, delete-orphan" ) + group = relationship("Group", back_populates="channels") class ChannelURL(Base): @@ -58,9 +116,9 @@ class ChannelURL(Base): __tablename__ = "channels_urls" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4) channel_id = Column( - UUID(as_uuid=True), + UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False, ) diff --git a/app/models/schemas.py b/app/models/schemas.py index 22e122c..3b73c84 100644 --- a/app/models/schemas.py +++ b/app/models/schemas.py @@ -53,12 +53,54 @@ class ChannelURLResponse(ChannelURLBase): pass +# New Group Schemas +class GroupCreate(BaseModel): + """Pydantic model for creating groups""" + + name: str + sort_order: int = Field(default=0, ge=0) + + +class GroupUpdate(BaseModel): + """Pydantic model for updating groups""" + + name: Optional[str] = None + sort_order: Optional[int] = Field(None, ge=0) + + +class GroupResponse(BaseModel): + """Pydantic model for group responses""" + + id: UUID + name: str + sort_order: int + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class GroupSortUpdate(BaseModel): + """Pydantic model for updating a single group's sort order""" + + sort_order: int = Field(ge=0) + + +class GroupBulkSort(BaseModel): + """Pydantic model for bulk updating group sort orders""" + + groups: list[dict] = Field( + description="List of dicts with group_id and new sort_order", + json_schema_extra={"example": [{"group_id": "uuid", "sort_order": 1}]}, + ) + + class ChannelCreate(BaseModel): """Pydantic model for creating channels""" urls: list[ChannelURLCreate] # List of URL objects with priority name: str - group_title: str + group_id: UUID tvg_id: str tvg_logo: str tvg_name: str @@ -76,7 +118,7 @@ class ChannelUpdate(BaseModel): """Pydantic model for updating channels (all fields optional)""" name: Optional[str] = Field(None, min_length=1) - group_title: Optional[str] = Field(None, min_length=1) + group_id: Optional[UUID] = None tvg_id: Optional[str] = Field(None, min_length=1) tvg_logo: Optional[str] = None tvg_name: Optional[str] = Field(None, min_length=1) @@ -87,7 +129,7 @@ class ChannelResponse(BaseModel): id: UUID name: str - group_title: str + group_id: UUID tvg_id: str tvg_logo: str tvg_name: str diff --git a/app/routers/channels.py b/app/routers/channels.py index 4c2118a..67c3189 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -13,6 +13,7 @@ from app.models import ( ChannelURL, ChannelURLCreate, ChannelURLResponse, + Group, ) from app.models.auth import CognitoUser from app.models.schemas import ChannelURLUpdate @@ -29,12 +30,20 @@ def create_channel( user: CognitoUser = Depends(get_current_user), ): """Create a new channel""" - # Check for duplicate channel (same group_title + name) + # Check if group exists + group = db.query(Group).filter(Group.id == channel.group_id).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Group not found", + ) + + # Check for duplicate channel (same group_id + name) existing_channel = ( db.query(ChannelDB) .filter( and_( - ChannelDB.group_title == channel.group_title, + ChannelDB.group_id == channel.group_id, ChannelDB.name == channel.name, ) ) @@ -44,7 +53,7 @@ def create_channel( if existing_channel: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail="Channel with same group_title and name already exists", + detail="Channel with same group_id and name already exists", ) # Create channel without URLs first @@ -96,20 +105,27 @@ def update_channel( status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" ) - # Only check for duplicates if name or group_title are being updated - if channel.name is not None or channel.group_title is not None: + # Only check for duplicates if name or group_id are being updated + if channel.name is not None or channel.group_id is not None: name = channel.name if channel.name is not None else db_channel.name - group_title = ( - channel.group_title - if channel.group_title is not None - else db_channel.group_title + group_id = ( + channel.group_id if channel.group_id is not None else db_channel.group_id ) + # Check if new group exists + if channel.group_id is not None: + group = db.query(Group).filter(Group.id == channel.group_id).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Group not found", + ) + existing_channel = ( db.query(ChannelDB) .filter( and_( - ChannelDB.group_title == group_title, + ChannelDB.group_id == group_id, ChannelDB.name == name, ChannelDB.id != channel_id, ) @@ -120,7 +136,7 @@ def update_channel( if existing_channel: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail="Channel with same group_title and name already exists", + detail="Channel with same group_id and name already exists", ) # Update only provided fields @@ -163,9 +179,69 @@ def list_channels( return db.query(ChannelDB).offset(skip).limit(limit).all() +# New endpoint to get channels by group +@router.get("/groups/{group_id}/channels", response_model=list[ChannelResponse]) +def get_channels_by_group( + group_id: UUID, + db: Session = Depends(get_db), +): + """Get all channels for a specific group""" + group = db.query(Group).filter(Group.id == group_id).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" + ) + return db.query(ChannelDB).filter(ChannelDB.group_id == group_id).all() + + +# New endpoint to update a channel's group +@router.put("/{channel_id}/group", response_model=ChannelResponse) +@require_roles("admin") +def update_channel_group( + channel_id: UUID, + group_id: UUID, + db: Session = Depends(get_db), + user: CognitoUser = Depends(get_current_user), +): + """Update a channel's group""" + channel = db.query(ChannelDB).filter(ChannelDB.id == channel_id).first() + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Channel not found" + ) + + group = db.query(Group).filter(Group.id == group_id).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" + ) + + # Check for duplicate channel name in new group + existing_channel = ( + db.query(ChannelDB) + .filter( + and_( + ChannelDB.group_id == group_id, + ChannelDB.name == channel.name, + ChannelDB.id != channel_id, + ) + ) + .first() + ) + + if existing_channel: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Channel with same name already exists in target group", + ) + + channel.group_id = group_id + db.commit() + db.refresh(channel) + return channel + + # URL Management Endpoints - - @router.post( "/{channel_id}/urls", response_model=ChannelURLResponse, diff --git a/app/routers/groups.py b/app/routers/groups.py new file mode 100644 index 0000000..14f6f0a --- /dev/null +++ b/app/routers/groups.py @@ -0,0 +1,169 @@ +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.auth.dependencies import get_current_user, require_roles +from app.models import Group +from app.models.auth import CognitoUser +from app.models.schemas import ( + GroupBulkSort, + GroupCreate, + GroupResponse, + GroupSortUpdate, + GroupUpdate, +) +from app.utils.database import get_db + +router = APIRouter(prefix="/groups", tags=["groups"]) + + +@router.post("/", response_model=GroupResponse, status_code=status.HTTP_201_CREATED) +@require_roles("admin") +def create_group( + group: GroupCreate, + db: Session = Depends(get_db), + user: CognitoUser = Depends(get_current_user), +): + """Create a new channel group""" + # Check for duplicate group name + existing_group = db.query(Group).filter(Group.name == group.name).first() + if existing_group: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Group with this name already exists", + ) + + db_group = Group(**group.model_dump()) + db.add(db_group) + db.commit() + db.refresh(db_group) + return db_group + + +@router.get("/{group_id}", response_model=GroupResponse) +def get_group(group_id: UUID, db: Session = Depends(get_db)): + """Get a group by id""" + group = db.query(Group).filter(Group.id == group_id).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" + ) + return group + + +@router.put("/{group_id}", response_model=GroupResponse) +@require_roles("admin") +def update_group( + group_id: UUID, + group: GroupUpdate, + db: Session = Depends(get_db), + user: CognitoUser = Depends(get_current_user), +): + """Update a group's name or sort order""" + db_group = db.query(Group).filter(Group.id == group_id).first() + if not db_group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" + ) + + # Check for duplicate name if name is being updated + if group.name is not None and group.name != db_group.name: + existing_group = db.query(Group).filter(Group.name == group.name).first() + if existing_group: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Group with this name already exists", + ) + + # Update only provided fields + update_data = group.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(db_group, key, value) + + db.commit() + db.refresh(db_group) + return db_group + + +@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT) +@require_roles("admin") +def delete_group( + group_id: UUID, + db: Session = Depends(get_db), + user: CognitoUser = Depends(get_current_user), +): + """Delete a group (only if it has no channels)""" + group = db.query(Group).filter(Group.id == group_id).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" + ) + + # Check if group has any channels + if group.channels: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot delete group with existing channels", + ) + + db.delete(group) + db.commit() + return None + + +@router.get("/", response_model=list[GroupResponse]) +def list_groups(db: Session = Depends(get_db)): + """List all groups sorted by sort_order""" + return db.query(Group).order_by(Group.sort_order).all() + + +@router.put("/{group_id}/sort", response_model=GroupResponse) +@require_roles("admin") +def update_group_sort_order( + group_id: UUID, + sort_update: GroupSortUpdate, + db: Session = Depends(get_db), + user: CognitoUser = Depends(get_current_user), +): + """Update a single group's sort order""" + db_group = db.query(Group).filter(Group.id == group_id).first() + if not db_group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Group not found" + ) + + db_group.sort_order = sort_update.sort_order + db.commit() + db.refresh(db_group) + return db_group + + +@router.post("/reorder", response_model=list[GroupResponse]) +@require_roles("admin") +def bulk_update_sort_orders( + bulk_sort: GroupBulkSort, + db: Session = Depends(get_db), + user: CognitoUser = Depends(get_current_user), +): + """Bulk update group sort orders""" + groups_to_update = [] + + for group_data in bulk_sort.groups: + group_id = group_data["group_id"] + sort_order = group_data["sort_order"] + + group = db.query(Group).filter(Group.id == str(group_id)).first() + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Group with id {group_id} not found", + ) + + group.sort_order = sort_order + groups_to_update.append(group) + + db.commit() + + # Return all groups in their new order + return db.query(Group).order_by(Group.sort_order).all() diff --git a/tests/auth/test_dependencies.py b/tests/auth/test_dependencies.py index 5d73400..17f033e 100644 --- a/tests/auth/test_dependencies.py +++ b/tests/auth/test_dependencies.py @@ -173,3 +173,33 @@ def test_mock_auth_import(monkeypatch): # Reload again to restore original state importlib.reload(app.auth.dependencies) + + +def test_cognito_auth_import(monkeypatch): + """Test that cognito auth is imported when MOCK_AUTH=false (covers line 14)""" + # Save original env var value + original_value = os.environ.get("MOCK_AUTH") + + try: + # Set MOCK_AUTH to false + monkeypatch.setenv("MOCK_AUTH", "false") + + # Reload the dependencies module to trigger the import condition + import app.auth.dependencies + + importlib.reload(app.auth.dependencies) + + # Verify that get_user_from_token was imported from app.auth.cognito + from app.auth.dependencies import get_user_from_token + + assert get_user_from_token.__module__ == "app.auth.cognito" + + finally: + # Restore original env var + if original_value is None: + monkeypatch.delenv("MOCK_AUTH", raising=False) + else: + monkeypatch.setenv("MOCK_AUTH", original_value) + + # Reload again to restore original state + importlib.reload(app.auth.dependencies) diff --git a/tests/models/test_db.py b/tests/models/test_db.py new file mode 100644 index 0000000..15e2cac --- /dev/null +++ b/tests/models/test_db.py @@ -0,0 +1,135 @@ +import os +import uuid +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from app.models.db import UUID_COLUMN_TYPE, Base, SQLiteUUID + +# --- Test SQLiteUUID Type --- + + +def test_sqliteuuid_process_bind_param_none(): + """Test SQLiteUUID.process_bind_param with None returns None""" + uuid_type = SQLiteUUID() + assert uuid_type.process_bind_param(None, None) is None + + +def test_sqliteuuid_process_bind_param_valid_uuid(): + """Test SQLiteUUID.process_bind_param with valid UUID returns string""" + uuid_type = SQLiteUUID() + test_uuid = uuid.uuid4() + assert uuid_type.process_bind_param(test_uuid, None) == str(test_uuid) + + +def test_sqliteuuid_process_bind_param_valid_string(): + """Test SQLiteUUID.process_bind_param with valid UUID string returns string""" + uuid_type = SQLiteUUID() + test_uuid_str = "550e8400-e29b-41d4-a716-446655440000" + assert uuid_type.process_bind_param(test_uuid_str, None) == test_uuid_str + + +def test_sqliteuuid_process_bind_param_invalid_string(): + """Test SQLiteUUID.process_bind_param raises ValueError for invalid UUID""" + uuid_type = SQLiteUUID() + with pytest.raises(ValueError, match="Invalid UUID string format"): + uuid_type.process_bind_param("invalid-uuid", None) + + +def test_sqliteuuid_process_result_value_none(): + """Test SQLiteUUID.process_result_value with None returns None""" + uuid_type = SQLiteUUID() + assert uuid_type.process_result_value(None, None) is None + + +def test_sqliteuuid_process_result_value_valid_string(): + """Test SQLiteUUID.process_result_value converts string to UUID""" + uuid_type = SQLiteUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_result_value(str(test_uuid), None) + assert isinstance(result, uuid.UUID) + assert result == test_uuid + + +def test_sqliteuuid_compare_values_none(): + """Test SQLiteUUID.compare_values handles None values""" + uuid_type = SQLiteUUID() + assert uuid_type.compare_values(None, None) is True + assert uuid_type.compare_values(None, uuid.uuid4()) is False + assert uuid_type.compare_values(uuid.uuid4(), None) is False + + +def test_sqliteuuid_compare_values_uuid(): + """Test SQLiteUUID.compare_values compares UUIDs as strings""" + uuid_type = SQLiteUUID() + test_uuid = uuid.uuid4() + assert uuid_type.compare_values(test_uuid, test_uuid) is True + assert uuid_type.compare_values(test_uuid, uuid.uuid4()) is False + + +def test_sqlite_uuid_comparison(): + """Test SQLiteUUID comparison functionality (moved from db_mocks.py)""" + uuid_type = SQLiteUUID() + + # Test equal UUIDs + uuid1 = uuid.uuid4() + uuid2 = uuid.UUID(str(uuid1)) + assert uuid_type.compare_values(uuid1, uuid2) is True + + # Test UUID vs string + assert uuid_type.compare_values(uuid1, str(uuid1)) is True + + # Test None comparisons + assert uuid_type.compare_values(None, None) is True + assert uuid_type.compare_values(uuid1, None) is False + assert uuid_type.compare_values(None, uuid1) is False + + # Test different UUIDs + uuid3 = uuid.uuid4() + assert uuid_type.compare_values(uuid1, uuid3) is False + + +def test_sqlite_uuid_binding(): + """Test SQLiteUUID binding parameter handling (moved from db_mocks.py)""" + uuid_type = SQLiteUUID() + + # Test UUID object binding + uuid_obj = uuid.uuid4() + assert uuid_type.process_bind_param(uuid_obj, None) == str(uuid_obj) + + # Test valid UUID string binding + uuid_str = str(uuid.uuid4()) + assert uuid_type.process_bind_param(uuid_str, None) == uuid_str + + # Test None handling + assert uuid_type.process_bind_param(None, None) is None + + # Test invalid UUID string + with pytest.raises(ValueError): + uuid_type.process_bind_param("invalid-uuid", None) + + +# --- Test UUID Column Type Configuration --- + + +def test_uuid_column_type_default(): + """Test UUID_COLUMN_TYPE uses SQLiteUUID in test environment""" + assert isinstance(UUID_COLUMN_TYPE, SQLiteUUID) + + +@patch.dict(os.environ, {"MOCK_AUTH": "false"}) +def test_uuid_column_type_postgres(): + """Test UUID_COLUMN_TYPE uses Postgres UUID when MOCK_AUTH=false""" + # Need to re-import to get the patched environment + from importlib import reload + + from app import models + + reload(models.db) + from sqlalchemy.dialects.postgresql import UUID as PostgresUUID + + from app.models.db import UUID_COLUMN_TYPE + + assert isinstance(UUID_COLUMN_TYPE, PostgresUUID) diff --git a/tests/routers/test_channels.py b/tests/routers/test_channels.py index db3864b..349ad20 100644 --- a/tests/routers/test_channels.py +++ b/tests/routers/test_channels.py @@ -1,13 +1,14 @@ import uuid +from datetime import datetime, timezone import pytest -from fastapi import FastAPI, status +from fastapi import status from fastapi.testclient import TestClient from sqlalchemy import String from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user -from app.routers.channels import router as channels_router +from app.main import app from app.utils.database import get_db # Import mocks and fixtures @@ -22,18 +23,16 @@ from tests.utils.db_mocks import ( MockBase, MockChannelDB, MockChannelURL, + MockGroup, MockPriority, + create_mock_priorities_and_group, engine_mock, mock_get_db, ) from tests.utils.db_mocks import session_mock as TestingSessionLocal -# Create a FastAPI instance for testing -app = FastAPI() - -# Override dependencies +# Override dependencies for testing app.dependency_overrides[get_db] = mock_get_db -app.include_router(channels_router) client = TestClient(app) @@ -42,26 +41,24 @@ client = TestClient(app) def test_create_channel_success(db_session: Session, admin_user_client: TestClient): - # Setup a priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create mock priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Test Group" + ) channel_data = { "tvg_id": "channel1.tv", "name": "Test Channel 1", - "group_title": "Test Group", + "group_id": str(group_id), "tvg_name": "TestChannel1", "tvg_logo": "logo.png", "urls": [{"url": "http://stream1.com/test", "priority_id": 100}], } - response = admin_user_client.post( - "/channels/", json=channel_data - ) # No headers needed now + response = admin_user_client.post("/channels/", json=channel_data) assert response.status_code == status.HTTP_201_CREATED data = response.json() assert data["name"] == "Test Channel 1" - assert data["group_title"] == "Test Group" + assert data["group_id"] == str(group_id) assert data["tvg_id"] == "channel1.tv" assert len(data["urls"]) == 1 assert data["urls"][0]["url"] == "http://stream1.com/test" @@ -74,12 +71,12 @@ def test_create_channel_success(db_session: Session, admin_user_client: TestClie .first() ) assert db_channel is not None - assert db_channel.group_title == "Test Group" + assert db_channel.group_id == group_id - # Query URLs using exact string comparison + # Query URLs db_urls = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.channel_id.cast(String()) == db_channel.id) + .filter(MockChannelURL.channel_id == db_channel.id) .all() ) @@ -88,16 +85,16 @@ def test_create_channel_success(db_session: Session, admin_user_client: TestClie def test_create_channel_duplicate(db_session: Session, admin_user_client: TestClient): - # Setup a priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create mock priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Duplicate Group" + ) # Create initial channel initial_channel_data = { "tvg_id": "channel_dup.tv", "name": "Duplicate Channel", - "group_title": "Duplicate Group", + "group_id": str(group_id), "tvg_name": "DuplicateChannelName", "tvg_logo": "duplicate_logo.png", "urls": [{"url": "http://stream_dup.com/test", "priority_id": 100}], @@ -114,15 +111,15 @@ def test_create_channel_duplicate(db_session: Session, admin_user_client: TestCl def test_create_channel_forbidden_for_non_admin( db_session: Session, non_admin_user_client: TestClient ): - # Setup a priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create mock priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Forbidden Group" + ) channel_data = { "tvg_id": "channel_forbidden.tv", "name": "Forbidden Channel", - "group_title": "Forbidden Group", + "group_id": str(group_id), "tvg_name": "ForbiddenChannelName", "tvg_logo": "forbidden_logo.png", "urls": [{"url": "http://stream_forbidden.com/test", "priority_id": 100}], @@ -132,20 +129,39 @@ def test_create_channel_forbidden_for_non_admin( assert "required roles" in response.json()["detail"] +def test_create_channel_group_not_found( + db_session: Session, admin_user_client: TestClient +): + """Test creating channel with non-existent group returns 404""" + # No group created in DB + channel_data = { + "tvg_id": "no_group.tv", + "name": "No Group Channel", + "group_id": str(uuid.uuid4()), # Random non-existent group ID + "tvg_name": "NoGroupChannel", + "tvg_logo": "no_group_logo.png", + "urls": [{"url": "http://no_group.com/stream", "priority_id": 100}], + } + response = admin_user_client.post("/channels/", json=channel_data) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + # --- Test Cases For Get Channel --- def test_get_channel_success(db_session: Session, admin_user_client: TestClient): - # Setup a priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priority and group using utility function + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Get Group" + ) # Create a channel first + channel_data_create = { "tvg_id": "get_me.tv", "name": "Get Me Channel", - "group_title": "Get Group", + "group_id": str(group_id), "tvg_name": "GetMeChannelName", "tvg_logo": "get_me_logo.png", "urls": [{"url": "http://get_me.com/stream", "priority_id": 100}], @@ -162,7 +178,7 @@ def test_get_channel_success(db_session: Session, admin_user_client: TestClient) data = get_response.json() assert data["id"] == created_channel_id assert data["name"] == "Get Me Channel" - assert data["group_title"] == "Get Group" + assert data["group_id"] == str(group_id) assert len(data["urls"]) == 1 app.dependency_overrides.pop(get_current_user, None) @@ -180,15 +196,15 @@ def test_get_channel_not_found(db_session: Session, admin_user_client: TestClien def test_update_channel_success(db_session: Session, admin_user_client: TestClient): - # Setup priority and create initial channel - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priority and group using utility function + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Update Group" + ) initial_channel_data = { "tvg_id": "update_me.tv", "name": "Update Me Channel", - "group_title": "Update Group", + "group_id": str(group_id), "tvg_name": "UpdateMeChannelName", "tvg_logo": "update_me_logo.png", "urls": [{"url": "http://update_me.com/stream", "priority_id": 100}], @@ -205,13 +221,13 @@ def test_update_channel_success(db_session: Session, admin_user_client: TestClie data = response.json() assert data["id"] == created_channel_id assert data["name"] == "Updated Channel Name" - assert data["group_title"] == "Update Group" + assert data["group_id"] == str(group_id) assert data["tvg_logo"] == "new_logo.png" # Verify in DB db_channel = ( db_session.query(MockChannelDB) - .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .filter(MockChannelDB.id == created_channel_id) .first() ) assert db_channel is not None @@ -220,16 +236,15 @@ def test_update_channel_success(db_session: Session, admin_user_client: TestClie def test_update_channel_conflict(db_session: Session, admin_user_client: TestClient): - # Setup priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priorities and groups using utility function + group1_id = create_mock_priorities_and_group(db_session, [(100, "High")], "Group A") + group2_id = create_mock_priorities_and_group(db_session, [], "Group B") # Create channel 1 channel1_data = { "tvg_id": "c1.tv", "name": "Channel One", - "group_title": "Group A", + "group_id": str(group1_id), "tvg_name": "C1Name", "tvg_logo": "c1logo.png", "urls": [{"url": "http://c1.com", "priority_id": 100}], @@ -240,7 +255,7 @@ def test_update_channel_conflict(db_session: Session, admin_user_client: TestCli channel2_data = { "tvg_id": "c2.tv", "name": "Channel Two", - "group_title": "Group B", + "group_id": str(group2_id), "tvg_name": "C2Name", "tvg_logo": "c2logo.png", "urls": [{"url": "http://c2.com", "priority_id": 100}], @@ -249,7 +264,7 @@ def test_update_channel_conflict(db_session: Session, admin_user_client: TestCli channel2_id = response_c2.json()["id"] # Attempt to update channel 2 to conflict with channel 1 - update_conflict_data = {"name": "Channel One", "group_title": "Group A"} + update_conflict_data = {"name": "Channel One", "group_id": str(group1_id)} response = admin_user_client.put( f"/channels/{channel2_id}", json=update_conflict_data ) @@ -265,19 +280,49 @@ def test_update_channel_not_found(db_session: Session, admin_user_client: TestCl assert "Channel not found" in response.json()["detail"] +def test_update_channel_group_not_found( + db_session: Session, admin_user_client: TestClient +): + """Test updating channel with non-existent group returns 404""" + # Create priority and group using utility function + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Original Group" + ) + + initial_channel_data = { + "tvg_id": "original.tv", + "name": "Original Channel", + "group_id": str(group_id), + "tvg_name": "OriginalName", + "tvg_logo": "original_logo.png", + "urls": [{"url": "http://original.com", "priority_id": 100}], + } + create_response = admin_user_client.post("/channels/", json=initial_channel_data) + created_channel_id = create_response.json()["id"] + + # Attempt to update with non-existent group + update_data = {"group_id": str(uuid.uuid4())} + response = admin_user_client.put( + f"/channels/{created_channel_id}", json=update_data + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + def test_update_channel_forbidden_for_non_admin( db_session: Session, non_admin_user_client: TestClient, admin_user_client: TestClient, ): - # Setup priority and create initial channel with admin - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priority and group using utility function + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Forbidden Update Group" + ) + initial_channel_data = { "tvg_id": "update_forbidden.tv", "name": "Update Forbidden", - "group_title": "Forbidden Update Group", + "group_id": str(group_id), "tvg_name": "UFName", "tvg_logo": "uflogo.png", "urls": [{"url": "http://update_forbidden.com", "priority_id": 100}], @@ -297,15 +342,15 @@ def test_update_channel_forbidden_for_non_admin( def test_delete_channel_success(db_session: Session, admin_user_client: TestClient): - # Setup priority and create initial channel - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Delete Group" + ) initial_channel_data = { "tvg_id": "delete_me.tv", "name": "Delete Me Channel", - "group_title": "Delete Group", + "group_id": str(group_id), # Use the ID of the created group "tvg_name": "DMName", "tvg_logo": "dmlogo.png", "urls": [{"url": "http://delete_me.com/stream", "priority_id": 100}], @@ -317,7 +362,7 @@ def test_delete_channel_success(db_session: Session, admin_user_client: TestClie # Verify it exists before delete db_channel_before_delete = ( db_session.query(MockChannelDB) - .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .filter(MockChannelDB.id == uuid.UUID(created_channel_id)) .first() ) assert db_channel_before_delete is not None @@ -328,7 +373,7 @@ def test_delete_channel_success(db_session: Session, admin_user_client: TestClie # Verify it's gone from DB db_channel_after_delete = ( db_session.query(MockChannelDB) - .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .filter(MockChannelDB.id == uuid.UUID(created_channel_id)) .first() ) assert db_channel_after_delete is None @@ -336,9 +381,7 @@ def test_delete_channel_success(db_session: Session, admin_user_client: TestClie # Also verify associated URLs are deleted (due to CASCADE in mock model) db_urls_after_delete = ( db_session.query(MockChannelURL) - .filter( - MockChannelURL.channel_id.cast(String()) == uuid.UUID(created_channel_id) - ) + .filter(MockChannelURL.channel_id == uuid.UUID(created_channel_id)) .all() ) assert len(db_urls_after_delete) == 0 @@ -356,14 +399,15 @@ def test_delete_channel_forbidden_for_non_admin( non_admin_user_client: TestClient, admin_user_client: TestClient, ): - # Setup priority and create initial channel with admin - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Forbidden Delete Group" + ) + initial_channel_data = { "tvg_id": "delete_forbidden.tv", "name": "Delete Forbidden", - "group_title": "Forbidden Delete Group", + "group_id": str(group_id), # Use the ID of the created group "tvg_name": "DFName", "tvg_logo": "dflogo.png", "urls": [{"url": "http://delete_forbidden.com", "priority_id": 100}], @@ -378,7 +422,7 @@ def test_delete_channel_forbidden_for_non_admin( # Ensure channel was not deleted db_channel_not_deleted = ( db_session.query(MockChannelDB) - .filter(MockChannelDB.id.cast(String()) == uuid.UUID(created_channel_id)) + .filter(MockChannelDB.id == uuid.UUID(created_channel_id)) .first() ) assert db_channel_not_deleted is not None @@ -393,20 +437,230 @@ def test_list_channels_empty(db_session: Session, admin_user_client: TestClient) assert response.json() == [] +def test_get_channels_by_group_success( + db_session: Session, admin_user_client: TestClient +): + """Test getting channels for an existing group""" + # Create priority and groups + group1_id = create_mock_priorities_and_group(db_session, [(100, "High")], "Group 1") + + group2_id = create_mock_priorities_and_group(db_session, [], "Group 2") + + # Create 2 channels in group1 and 1 in group2 + channels_group1 = [ + { + "tvg_id": f"g1c{i}.tv", + "name": f"Group1 Channel {i}", + "group_id": str(group1_id), + "tvg_name": f"G1C{i}", + "tvg_logo": f"g1c{i}_logo.png", + "urls": [{"url": f"http://g1c{i}.com", "priority_id": 100}], + } + for i in range(2) + ] + channel_group2 = { + "tvg_id": "g2c1.tv", + "name": "Group2 Channel 1", + "group_id": str(group2_id), + "tvg_name": "G2C1", + "tvg_logo": "g2c1_logo.png", + "urls": [{"url": "http://g2c1.com", "priority_id": 100}], + } + + # Create all channels + for channel_data in channels_group1 + [channel_group2]: + admin_user_client.post("/channels/", json=channel_data) + + # Get channels for group1 + response = admin_user_client.get(f"/channels/groups/{group1_id}/channels") + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 2 + assert all(channel["group_id"] == str(group1_id) for channel in data) + + +def test_get_channels_by_group_not_found( + db_session: Session, admin_user_client: TestClient +): + """Test getting channels for non-existent group returns 404""" + random_uuid = uuid.uuid4() + response = admin_user_client.get(f"/channels/groups/{random_uuid}/channels") + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + +def test_update_channel_group_success( + db_session: Session, admin_user_client: TestClient +): + """Test successfully updating a channel's group""" + # Create priority and group + group1_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Original Group" + ) + + channel_data = { + "tvg_id": "original.tv", + "name": "Original Channel", + "group_id": str(group1_id), + "tvg_name": "OriginalName", + "tvg_logo": "original_logo.png", + "urls": [{"url": "http://original.com", "priority_id": 100}], + } + create_response = admin_user_client.post("/channels/", json=channel_data) + channel_id = create_response.json()["id"] + + # Create target group + group2_id = create_mock_priorities_and_group(db_session, [], "Target Group") + + # Update channel's group + response = admin_user_client.put( + f"/channels/{channel_id}/group?group_id={group2_id}" + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == channel_id + assert data["group_id"] == str(group2_id) + + # Verify in DB + db_channel = ( + db_session.query(MockChannelDB) + .filter(MockChannelDB.id == uuid.UUID(channel_id)) + .first() + ) + assert db_channel.group_id == group2_id + + +def test_update_channel_group_channel_not_found( + db_session: Session, admin_user_client: TestClient +): + """Test updating non-existent channel's group returns 404""" + # Create priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Test Group" + ) + + # Attempt to update non-existent channel + random_uuid = uuid.uuid4() + response = admin_user_client.put( + f"/channels/{random_uuid}/group?group_id={group_id}" + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Channel not found" in response.json()["detail"] + + +def test_update_channel_group_group_not_found( + db_session: Session, admin_user_client: TestClient +): + """Test updating channel to non-existent group returns 404""" + # Create priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Original Group" + ) + + # Create a channel in the original group + channel_data = { + "tvg_id": "original.tv", + "name": "Original Channel", + "group_id": str(group_id), + "tvg_name": "OriginalName", + "tvg_logo": "original_logo.png", + "urls": [{"url": "http://original.com", "priority_id": 100}], + } + create_response = admin_user_client.post("/channels/", json=channel_data) + channel_id = create_response.json()["id"] + + # Attempt to update with non-existent group + response = admin_user_client.put( + f"/channels/{channel_id}/group?group_id={uuid.uuid4()}" + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + +def test_update_channel_group_duplicate_name( + db_session: Session, admin_user_client: TestClient +): + """Test updating channel to group with duplicate name returns 409""" + # Create priority and groups + group1_id = create_mock_priorities_and_group(db_session, [(100, "High")], "Group 1") + group2_id = create_mock_priorities_and_group(db_session, [], "Group 2") + + # Create channel in each group with same name + channel_name = "Duplicate Channel" + channel1_data = { + "tvg_id": "c1.tv", + "name": channel_name, + "group_id": str(group1_id), + "tvg_name": "C1", + "tvg_logo": "c1.png", + "urls": [{"url": "http://c1.com", "priority_id": 100}], + } + channel2_data = { + "tvg_id": "c2.tv", + "name": channel_name, + "group_id": str(group2_id), + "tvg_name": "C2", + "tvg_logo": "c2.png", + "urls": [{"url": "http://c2.com", "priority_id": 100}], + } + admin_user_client.post("/channels/", json=channel1_data) + create2_response = admin_user_client.post("/channels/", json=channel2_data) + channel2_id = create2_response.json()["id"] + + # Attempt to move channel2 to group1 (would create duplicate name) + response = admin_user_client.put( + f"/channels/{channel2_id}/group?group_id={group1_id}" + ) + assert response.status_code == status.HTTP_409_CONFLICT + assert "already exists" in response.json()["detail"] + + +def test_update_channel_group_forbidden_for_non_admin( + db_session: Session, + non_admin_user_client: TestClient, + admin_user_client: TestClient, +): + """Test updating channel's group as non-admin returns 403""" + # Create priority and groups + group1_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Original Group" + ) + group2_id = create_mock_priorities_and_group(db_session, [], "Target Group") + + channel_data = { + "tvg_id": "protected.tv", + "name": "Protected Channel", + "group_id": str(group1_id), + "tvg_name": "Protected", + "tvg_logo": "protected.png", + "urls": [{"url": "http://protected.com", "priority_id": 100}], + } + create_response = admin_user_client.post("/channels/", json=channel_data) + channel_id = create_response.json()["id"] + + # Attempt to update group as non-admin + response = non_admin_user_client.put( + f"/channels/{channel_id}/group?group_id={group2_id}" + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "required roles" in response.json()["detail"] + + def test_list_channels_with_data_and_pagination( db_session: Session, admin_user_client: TestClient ): - # Setup priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Create priority and group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "List Group" + ) # Create some channels for i in range(5): channel_data = { "tvg_id": f"list_c{i}.tv", "name": f"List Channel {i}", - "group_title": "List Group", + "group_id": str(group_id), "tvg_name": f"LCName{i}", "tvg_logo": f"lclogo{i}.png", "urls": [{"url": f"http://list_c{i}.com", "priority_id": 100}], @@ -454,16 +708,15 @@ def test_list_channels_forbidden_for_non_admin( def test_add_channel_url_success(db_session: Session, admin_user_client: TestClient): - # Setup priority and create a channel - priority1 = MockPriority(id=100, description="High") - priority2 = MockPriority(id=200, description="Medium") - db_session.add_all([priority1, priority2]) - db_session.commit() + # Setup priorities and create a group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High"), (200, "Medium")], "URL Group" + ) channel_data = { "tvg_id": "channel_for_url.tv", "name": "Channel For URL", - "group_title": "URL Group", + "group_id": str(group_id), "tvg_name": "CFUName", "tvg_logo": "cfulogo.png", "urls": [{"url": "http://initial.com/stream", "priority_id": 100}], @@ -485,7 +738,7 @@ def test_add_channel_url_success(db_session: Session, admin_user_client: TestCli # Verify in DB db_url = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(data["id"])) + .filter(MockChannelURL.id == uuid.UUID(data["id"])) .first() ) assert db_url is not None @@ -503,9 +756,7 @@ def test_add_channel_url_success(db_session: Session, admin_user_client: TestCli # for the count of URLs for the channel url_count = ( db_session.query(MockChannelURL) - .filter( - MockChannelURL.channel_id.cast(String()) == uuid.UUID(created_channel_id) - ) + .filter(MockChannelURL.channel_id == uuid.UUID(created_channel_id)) .count() ) assert url_count == 2 @@ -521,9 +772,7 @@ def test_add_channel_url_channel_not_found( db_session: Session, admin_user_client: TestClient ): # Setup priority - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + create_mock_priorities_and_group(db_session, [(100, "High")], "My Group") random_channel_uuid = uuid.uuid4() url_data = {"url": "http://stream_no_channel.com", "priority_id": 100} @@ -540,13 +789,14 @@ def test_add_channel_url_forbidden_for_non_admin( admin_user_client: TestClient, ): # Setup priority and create a channel with admin - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Forbidden Group" + ) + channel_data = { "tvg_id": "url_forbidden.tv", "name": "URL Forbidden", - "group_title": "URL Forbidden Group", + "group_id": str(group_id), # Use the ID of the created group "tvg_name": "UFName2", "tvg_logo": "uflogo2.png", "urls": [{"url": "http://url_forbidden.com", "priority_id": 100}], @@ -567,18 +817,14 @@ def test_add_channel_url_forbidden_for_non_admin( def test_update_channel_url_success(db_session: Session, admin_user_client: TestClient): # Setup priorities and create a channel with a URL - priority1 = MockPriority(id=100, description="High") - priority2 = MockPriority(id=200, description="Medium") - priority3 = MockPriority( - id=300, description="Low" - ) # New priority for update, Use valid priority ID - db_session.add_all([priority1, priority2, priority3]) - db_session.commit() + group_id = create_mock_priorities_and_group( + db_session, [(100, "High"), (200, "Medium"), (300, "Low")], "URL Update Group" + ) channel_data = { "tvg_id": "ch_update_url.tv", "name": "Channel Update URL", - "group_title": "URL Update Group", + "group_id": str(group_id), # Use the ID of the created group "tvg_name": "CUUName", "tvg_logo": "cuulogo.png", "urls": [{"url": "http://original_url.com/stream", "priority_id": 100}], @@ -606,7 +852,7 @@ def test_update_channel_url_success(db_session: Session, admin_user_client: Test # Verify in DB db_url = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)) + .filter(MockChannelURL.id == uuid.UUID(initial_url_id)) .first() ) assert db_url is not None @@ -619,14 +865,14 @@ def test_update_channel_url_partial_success( db_session: Session, admin_user_client: TestClient ): # Setup priorities and create a channel with a URL - priority1 = MockPriority(id=100, description="High") - db_session.add_all([priority1]) - db_session.commit() + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Partial Update Group" + ) channel_data = { "tvg_id": "ch_partial_update_url.tv", "name": "Channel Partial Update URL", - "group_title": "URL Partial Update Group", + "group_id": str(group_id), "tvg_name": "CPUName", "tvg_logo": "cpulogo.png", "urls": [{"url": "http://partial_original.com/stream", "priority_id": 100}], @@ -650,7 +896,7 @@ def test_update_channel_url_partial_success( # Verify in DB db_url = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)) + .filter(MockChannelURL.id == uuid.UUID(initial_url_id)) .first() ) assert db_url is not None @@ -662,14 +908,15 @@ def test_update_channel_url_partial_success( def test_update_channel_url_url_not_found( db_session: Session, admin_user_client: TestClient ): - # Setup priority and create a channel - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority and create a group + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Not Found Group" + ) + channel_data = { "tvg_id": "ch_url_not_found.tv", "name": "Channel URL Not Found", - "group_title": "URL Not Found Group", + "group_id": str(group_id), "tvg_name": "CUNFName", "tvg_logo": "cunflogo.png", "urls": [], @@ -691,15 +938,17 @@ def test_update_channel_url_channel_id_mismatch_is_url_not_found( ): # This tests if a URL ID exists but is not associated # with the given channel_id in the path - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + + # Setup priority and create a group + group1_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "Ch1 Group" + ) # Create channel 1 with a URL ch1_data = { "tvg_id": "ch1_url_mismatch.tv", "name": "CH1 URL Mismatch", - "group_title": "G1", + "group_id": str(group1_id), "tvg_name": "C1UMName", "tvg_logo": "c1umlogo.png", "urls": [{"url": "http://ch1.url", "priority_id": 100}], @@ -707,11 +956,16 @@ def test_update_channel_url_channel_id_mismatch_is_url_not_found( ch1_resp = admin_user_client.post("/channels/", json=ch1_data) url_id_from_ch1 = ch1_resp.json()["urls"][0]["id"] + # Create another group + group2_id = create_mock_priorities_and_group( + db_session, [(200, "Medium")], "Ch2 Group" + ) + # Create channel 2 ch2_data = { "tvg_id": "ch2_url_mismatch.tv", "name": "CH2 URL Mismatch", - "group_title": "G2", + "group_id": str(group2_id), "tvg_name": "C2UMName", "tvg_logo": "c2umlogo.png", "urls": [], @@ -733,14 +987,15 @@ def test_update_channel_url_forbidden_for_non_admin( non_admin_user_client: TestClient, admin_user_client: TestClient, ): - # Setup priority and create channel with URL using admin - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority, group and create channel with URL using admin + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Update Forbidden Group" + ) + channel_data = { "tvg_id": "ch_update_url_forbidden.tv", "name": "Channel Update URL Forbidden", - "group_title": "URL Update Forbidden Group", + "group_id": str(group_id), "tvg_name": "CUFName", "tvg_logo": "cuflgo.png", "urls": [{"url": "http://original_forbidden.com/stream", "priority_id": 100}], @@ -761,15 +1016,15 @@ def test_update_channel_url_forbidden_for_non_admin( def test_delete_channel_url_success(db_session: Session, admin_user_client: TestClient): - # Setup priority and create a channel with a URL - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority, group and create a channel with a URL + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Delete Group" + ) channel_data = { "tvg_id": "ch_delete_url.tv", "name": "Channel Delete URL", - "group_title": "URL Delete Group", + "group_id": str(group_id), "tvg_name": "CDUName", "tvg_logo": "cdulogo.png", "urls": [{"url": "http://delete_this_url.com/stream", "priority_id": 100}], @@ -781,7 +1036,7 @@ def test_delete_channel_url_success(db_session: Session, admin_user_client: Test # Verify URL exists before delete db_url_before = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_to_delete_id)) + .filter(MockChannelURL.id == url_to_delete_id) .first() ) assert db_url_before is not None @@ -794,7 +1049,7 @@ def test_delete_channel_url_success(db_session: Session, admin_user_client: Test # Verify URL is gone from DB db_url_after = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_to_delete_id)) + .filter(MockChannelURL.id == url_to_delete_id) .first() ) assert db_url_after is None @@ -808,14 +1063,15 @@ def test_delete_channel_url_success(db_session: Session, admin_user_client: Test def test_delete_channel_url_url_not_found( db_session: Session, admin_user_client: TestClient ): - # Setup priority and create a channel - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority, group and create a channel with a URL + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Del Not Found Group" + ) + channel_data = { "tvg_id": "ch_del_url_not_found.tv", "name": "Channel Del URL Not Found", - "group_title": "URL Del Not Found Group", + "group_id": str(group_id), "tvg_name": "CDUNFName", "tvg_logo": "cdunflogo.png", "urls": [], @@ -834,9 +1090,8 @@ def test_delete_channel_url_url_not_found( def test_delete_channel_url_channel_id_mismatch_is_url_not_found( db_session: Session, admin_user_client: TestClient ): - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority, group 1 and create a channel 1 with a URL + group1_id = create_mock_priorities_and_group(db_session, [(100, "High")], "G1Del") # Create channel 1 with a URL ch1_data = { @@ -844,20 +1099,23 @@ def test_delete_channel_url_channel_id_mismatch_is_url_not_found( "tvg_id": "ch1_del_url_mismatch.tv", "tvg_name": "CH1 Del URL Mismatch", "tvg_logo": "ch1delogo.png", - "group_title": "G1Del", + "group_id": str(group1_id), "urls": [{"url": "http://ch1del.url", "priority_id": 100}], } ch1_resp = admin_user_client.post("/channels/", json=ch1_data) print(ch1_resp.json()) url_id_from_ch1 = ch1_resp.json()["urls"][0]["id"] + # Setup group 2 and create a channel 2 + group2_id = create_mock_priorities_and_group(db_session, [], "G2Del") + # Create channel 2 ch2_data = { "tvg_id": "ch2_del_url_mismatch.tv", "name": "CH2 Del URL Mismatch", "tvg_name": "CH2 Del URL Mismatch", "tvg_logo": "ch2delogo.png", - "group_title": "G2Del", + "group_id": str(group2_id), "urls": [], } ch2_resp = admin_user_client.post("/channels/", json=ch2_data) @@ -871,7 +1129,7 @@ def test_delete_channel_url_channel_id_mismatch_is_url_not_found( # Ensure the original URL on CH1 was not deleted db_url_ch1 = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(url_id_from_ch1)) + .filter(MockChannelURL.id == url_id_from_ch1) .first() ) assert db_url_ch1 is not None @@ -882,14 +1140,15 @@ def test_delete_channel_url_forbidden_for_non_admin( non_admin_user_client: TestClient, admin_user_client: TestClient, ): - # Setup priority and create channel with URL using admin - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority, group and create a channel with a URL + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Del Forbidden Group" + ) + channel_data = { "tvg_id": "ch_del_url_forbidden.tv", "name": "Channel Del URL Forbidden", - "group_title": "URL Del Forbidden Group", + "group_id": str(group_id), "tvg_name": "CDUFName", "tvg_logo": "cduflogo.png", "urls": [ @@ -909,7 +1168,7 @@ def test_delete_channel_url_forbidden_for_non_admin( # Ensure URL was not deleted db_url_not_deleted = ( db_session.query(MockChannelURL) - .filter(MockChannelURL.id.cast(String()) == uuid.UUID(initial_url_id)) + .filter(MockChannelURL.id == initial_url_id) .first() ) assert db_url_not_deleted is not None @@ -919,16 +1178,15 @@ def test_delete_channel_url_forbidden_for_non_admin( def test_list_channel_urls_success(db_session: Session, admin_user_client: TestClient): - # Setup priorities and create a channel with multiple URLs - priority1 = MockPriority(id=100, description="High") - priority2 = MockPriority(id=200, description="Medium") - db_session.add_all([priority1, priority2]) - db_session.commit() + # Setup priorities, group and create a channel with a URL + group_id = create_mock_priorities_and_group( + db_session, [(100, "High"), (200, "Medium")], "URL List Group" + ) channel_data = { "tvg_id": "ch_list_urls.tv", "name": "Channel List URLs", - "group_title": "URL List Group", + "group_id": str(group_id), "tvg_name": "CLUName", "tvg_logo": "clulogo.png", "urls": [ @@ -961,10 +1219,14 @@ def test_list_channel_urls_success(db_session: Session, admin_user_client: TestC def test_list_channel_urls_empty(db_session: Session, admin_user_client: TestClient): # Create a channel with no URLs initially # No need to set up MockPriority if no URLs with priority_id are being created. + + # Setup group + group_id = create_mock_priorities_and_group(db_session, [], "URL List Empty Group") + channel_data = { "tvg_id": "ch_list_empty_urls.tv", "name": "Channel List Empty URLs", - "group_title": "URL List Empty Group", + "group_id": str(group_id), "tvg_name": "CLEUName", "tvg_logo": "cleulogo.png", "urls": [], @@ -991,14 +1253,15 @@ def test_list_channel_urls_forbidden_for_non_admin( non_admin_user_client: TestClient, admin_user_client: TestClient, ): - # Setup priority and create channel with admin - priority1 = MockPriority(id=100, description="High") - db_session.add(priority1) - db_session.commit() + # Setup priority, group and create a channel admin + group_id = create_mock_priorities_and_group( + db_session, [(100, "High")], "URL Del Forbidden Group" + ) + channel_data = { "tvg_id": "ch_list_url_forbidden.tv", "name": "Channel List URL Forbidden", - "group_title": "URL List Forbidden Group", + "group_id": str(group_id), "tvg_name": "CLUFName", "tvg_logo": "cluflogo.png", "urls": [{"url": "http://list_url_forbidden.com", "priority_id": 100}], diff --git a/tests/routers/test_groups.py b/tests/routers/test_groups.py new file mode 100644 index 0000000..16f785b --- /dev/null +++ b/tests/routers/test_groups.py @@ -0,0 +1,422 @@ +import uuid +from datetime import datetime, timezone + +import pytest +from fastapi import status +from sqlalchemy.orm import Session + +from app.auth.dependencies import get_current_user +from app.routers.groups import router as groups_router +from app.utils.database import get_db + +# Import mocks and fixtures +from tests.utils.auth_test_fixtures import ( + admin_user_client, + db_session, + non_admin_user_client, +) +from tests.utils.db_mocks import MockChannelDB, MockGroup, SQLiteUUID + +# --- Test Cases For Group Creation --- + + +def test_create_group_success(db_session: Session, admin_user_client): + group_data = {"name": "Test Group", "sort_order": 1} + response = admin_user_client.post("/groups/", json=group_data) + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "Test Group" + assert data["sort_order"] == 1 + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + + # Verify in DB + db_group = ( + db_session.query(MockGroup).filter(MockGroup.name == "Test Group").first() + ) + assert db_group is not None + assert db_group.sort_order == 1 + + +def test_create_group_duplicate(db_session: Session, admin_user_client): + # Create initial group + initial_group = MockGroup( + id=uuid.uuid4(), + name="Duplicate Group", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(initial_group) + db_session.commit() + + # Attempt to create duplicate + response = admin_user_client.post( + "/groups/", json={"name": "Duplicate Group", "sort_order": 2} + ) + assert response.status_code == status.HTTP_409_CONFLICT + assert "already exists" in response.json()["detail"] + + +def test_create_group_forbidden_for_non_admin( + db_session: Session, non_admin_user_client +): + response = non_admin_user_client.post( + "/groups/", json={"name": "Forbidden Group", "sort_order": 1} + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "required roles" in response.json()["detail"] + + +# --- Test Cases For Get Group --- + + +def test_get_group_success(db_session: Session, admin_user_client): + # Create a group first + test_group = MockGroup( + id=uuid.uuid4(), + name="Get Me Group", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + response = admin_user_client.get(f"/groups/{test_group.id}") + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == str(test_group.id) + assert data["name"] == "Get Me Group" + assert data["sort_order"] == 1 + + +def test_get_group_not_found(db_session: Session, admin_user_client): + random_uuid = uuid.uuid4() + response = admin_user_client.get(f"/groups/{random_uuid}") + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + +# --- Test Cases For Update Group --- + + +def test_update_group_success(db_session: Session, admin_user_client): + # Create initial group + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Update Me", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + update_data = {"name": "Updated Name", "sort_order": 2} + response = admin_user_client.put(f"/groups/{group_id}", json=update_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Updated Name" + assert data["sort_order"] == 2 + + # Verify in DB + db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first() + assert db_group.name == "Updated Name" + assert db_group.sort_order == 2 + + +def test_update_group_conflict(db_session: Session, admin_user_client): + # Create two groups + group1 = MockGroup( + id=uuid.uuid4(), + name="Group One", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + group2 = MockGroup( + id=uuid.uuid4(), + name="Group Two", + sort_order=2, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add_all([group1, group2]) + db_session.commit() + + # Try to rename group2 to conflict with group1 + response = admin_user_client.put(f"/groups/{group2.id}", json={"name": "Group One"}) + assert response.status_code == status.HTTP_409_CONFLICT + assert "already exists" in response.json()["detail"] + + +def test_update_group_not_found(db_session: Session, admin_user_client): + random_uuid = uuid.uuid4() + response = admin_user_client.put( + f"/groups/{random_uuid}", json={"name": "Non-existent"} + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + +def test_update_group_forbidden_for_non_admin( + db_session: Session, non_admin_user_client, admin_user_client +): + # Create group with admin + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Admin Created", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + # Attempt update with non-admin + response = non_admin_user_client.put( + f"/groups/{group_id}", json={"name": "Non-Admin Update"} + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "required roles" in response.json()["detail"] + + +# --- Test Cases For Delete Group --- + + +def test_delete_group_success(db_session: Session, admin_user_client): + # Create group + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Delete Me", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + # Verify exists before delete + assert ( + db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None + ) + + response = admin_user_client.delete(f"/groups/{group_id}") + assert response.status_code == status.HTTP_204_NO_CONTENT + + # Verify deleted + assert db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is None + + +def test_delete_group_with_channels_fails(db_session: Session, admin_user_client): + # Create group with channel + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Group With Channels", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + + # Create channel in this group + test_channel = MockChannelDB( + id=uuid.uuid4(), + tvg_id="channel1.tv", + name="Channel 1", + group_id=group_id, + tvg_name="Channel1", + tvg_logo="logo.png", + ) + db_session.add(test_channel) + db_session.commit() + + response = admin_user_client.delete(f"/groups/{group_id}") + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "existing channels" in response.json()["detail"] + + # Verify group still exists + assert ( + db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None + ) + + +def test_delete_group_not_found(db_session: Session, admin_user_client): + random_uuid = uuid.uuid4() + response = admin_user_client.delete(f"/groups/{random_uuid}") + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + +def test_delete_group_forbidden_for_non_admin( + db_session: Session, non_admin_user_client, admin_user_client +): + # Create group with admin + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Admin Created", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + # Attempt delete with non-admin + response = non_admin_user_client.delete(f"/groups/{group_id}") + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "required roles" in response.json()["detail"] + + # Verify group still exists + assert ( + db_session.query(MockGroup).filter(MockGroup.id == group_id).first() is not None + ) + + +# --- Test Cases For List Groups --- + + +def test_list_groups_empty(db_session: Session, admin_user_client): + response = admin_user_client.get("/groups/") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + +def test_list_groups_with_data(db_session: Session, admin_user_client): + # Create some groups + groups = [ + MockGroup( + id=uuid.uuid4(), + name=f"Group {i}", + sort_order=i, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + for i in range(3) + ] + db_session.add_all(groups) + db_session.commit() + + response = admin_user_client.get("/groups/") + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 3 + assert data[0]["sort_order"] == 0 # Should be sorted by sort_order + assert data[1]["sort_order"] == 1 + assert data[2]["sort_order"] == 2 + + +# --- Test Cases For Sort Order Updates --- + + +def test_update_group_sort_order_success(db_session: Session, admin_user_client): + # Create group + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Sort Me", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + response = admin_user_client.put(f"/groups/{group_id}/sort", json={"sort_order": 5}) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["sort_order"] == 5 + + # Verify in DB + db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first() + assert db_group.sort_order == 5 + + +def test_update_group_sort_order_not_found(db_session: Session, admin_user_client): + """Test that updating sort order for non-existent group returns 404""" + random_uuid = uuid.uuid4() + response = admin_user_client.put( + f"/groups/{random_uuid}/sort", json={"sort_order": 5} + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "Group not found" in response.json()["detail"] + + +def test_bulk_update_sort_orders_success(db_session: Session, admin_user_client): + # Create groups + groups = [ + MockGroup( + id=uuid.uuid4(), + name=f"Group {i}", + sort_order=i, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + for i in range(3) + ] + print(groups) + db_session.add_all(groups) + db_session.commit() + + # Bulk update sort orders (reverse order) + bulk_data = { + "groups": [ + {"group_id": str(groups[0].id), "sort_order": 2}, + {"group_id": str(groups[1].id), "sort_order": 1}, + {"group_id": str(groups[2].id), "sort_order": 0}, + ] + } + response = admin_user_client.post("/groups/reorder", json=bulk_data) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == 3 + + # Create a dictionary for easy lookup of returned group data by ID + returned_groups_map = {item["id"]: item for item in data} + + # Verify each group has its expected new sort_order + assert returned_groups_map[str(groups[0].id)]["sort_order"] == 2 + assert returned_groups_map[str(groups[1].id)]["sort_order"] == 1 + assert returned_groups_map[str(groups[2].id)]["sort_order"] == 0 + + # Verify in DB + db_groups = db_session.query(MockGroup).order_by(MockGroup.sort_order).all() + assert db_groups[0].sort_order == 2 + assert db_groups[1].sort_order == 1 + assert db_groups[2].sort_order == 0 + + +def test_bulk_update_sort_orders_invalid_group(db_session: Session, admin_user_client): + # Create one group + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Valid Group", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add(test_group) + db_session.commit() + + # Try to update with invalid group + bulk_data = { + "groups": [ + {"group_id": str(group_id), "sort_order": 2}, + {"group_id": str(uuid.uuid4()), "sort_order": 1}, # Invalid group + ] + } + response = admin_user_client.post("/groups/reorder", json=bulk_data) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not found" in response.json()["detail"] + + # Verify original sort order unchanged + db_group = db_session.query(MockGroup).filter(MockGroup.id == group_id).first() + assert db_group.sort_order == 1 diff --git a/tests/routers/test_priorities.py b/tests/routers/test_priorities.py index 1673bb8..1390520 100644 --- a/tests/routers/test_priorities.py +++ b/tests/routers/test_priorities.py @@ -1,3 +1,6 @@ +import uuid +from datetime import datetime, timezone + import pytest from fastapi import status from sqlalchemy.orm import Session @@ -11,7 +14,7 @@ from tests.utils.auth_test_fixtures import ( db_session, non_admin_user_client, ) -from tests.utils.db_mocks import MockChannelDB, MockChannelURL, MockPriority +from tests.utils.db_mocks import MockChannelDB, MockChannelURL, MockGroup, MockPriority # --- Test Cases For Priority Creation --- @@ -147,7 +150,15 @@ def test_delete_priority_not_found(db_session: Session, admin_user_client): def test_delete_priority_in_use(db_session: Session, admin_user_client): # Create a priority and a channel URL using it priority = MockPriority(id=100, description="In Use") - db_session.add(priority) + group_id = uuid.uuid4() + test_group = MockGroup( + id=group_id, + name="Group With Channels", + sort_order=1, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + db_session.add_all([priority, test_group]) db_session.commit() # Create a channel first @@ -156,7 +167,7 @@ def test_delete_priority_in_use(db_session: Session, admin_user_client): tvg_id="test.tv", tvg_name="Test", tvg_logo="test.png", - group_title="Test Group", + group_id=group_id, ) db_session.add(channel) db_session.commit() diff --git a/tests/utils/auth_test_fixtures.py b/tests/utils/auth_test_fixtures.py index 0f95630..d212480 100644 --- a/tests/utils/auth_test_fixtures.py +++ b/tests/utils/auth_test_fixtures.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user from app.models.auth import CognitoUser from app.routers.channels import router as channels_router +from app.routers.groups import router as groups_router from app.routers.playlist import router as playlist_router from app.routers.priorities import router as priorities_router from app.utils.database import get_db @@ -60,6 +61,7 @@ def admin_user_client(db_session: Session): test_app.include_router(channels_router) test_app.include_router(priorities_router) test_app.include_router(playlist_router) + test_app.include_router(groups_router) test_app.dependency_overrides[get_db] = mock_get_db test_app.dependency_overrides[get_current_user] = mock_get_current_user_admin with TestClient(test_app) as test_client: @@ -73,6 +75,7 @@ def non_admin_user_client(db_session: Session): test_app.include_router(channels_router) test_app.include_router(priorities_router) test_app.include_router(playlist_router) + test_app.include_router(groups_router) test_app.dependency_overrides[get_db] = mock_get_db test_app.dependency_overrides[get_current_user] = mock_get_current_user_non_admin with TestClient(test_app) as test_client: diff --git a/tests/utils/db_mocks.py b/tests/utils/db_mocks.py index 4427282..ab22477 100644 --- a/tests/utils/db_mocks.py +++ b/tests/utils/db_mocks.py @@ -4,41 +4,25 @@ from unittest.mock import MagicMock, patch import pytest from sqlalchemy import ( - TEXT, Boolean, Column, DateTime, ForeignKey, Integer, String, - TypeDecorator, UniqueConstraint, create_engine, ) -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import declarative_base, relationship, sessionmaker from sqlalchemy.pool import StaticPool +# Import the actual UUID_COLUMN_TYPE and SQLiteUUID from app.models.db +from app.models.db import UUID_COLUMN_TYPE, SQLiteUUID + # Create a mock-specific Base class for testing MockBase = declarative_base() -class SQLiteUUID(TypeDecorator): - """Enables UUID support for SQLite.""" - - impl = TEXT - cache_ok = True - - def process_bind_param(self, value, dialect): - if value is None: - return value - return str(value) - - def process_result_value(self, value, dialect): - if value is None: - return value - return uuid.UUID(value) - - # Model classes for testing - prefix with Mock to avoid pytest collection class MockPriority(MockBase): __tablename__ = "priorities" @@ -46,16 +30,28 @@ class MockPriority(MockBase): description = Column(String, nullable=False) +class MockGroup(MockBase): + __tablename__ = "groups" + id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4) + name = Column(String, nullable=False, unique=True) + sort_order = Column(Integer, nullable=False, default=0) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + channels = relationship("MockChannelDB", back_populates="group") + + class MockChannelDB(MockBase): __tablename__ = "channels" - id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) + id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4) tvg_id = Column(String, nullable=False) name = Column(String, nullable=False) - group_title = Column(String, nullable=False) + group_id = Column(UUID_COLUMN_TYPE, ForeignKey("groups.id"), nullable=False) tvg_name = Column(String) - __table_args__ = ( - UniqueConstraint("group_title", "name", name="uix_group_title_name"), - ) + __table_args__ = (UniqueConstraint("group_id", "name", name="uix_group_id_name"),) tvg_logo = Column(String) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) updated_at = Column( @@ -63,13 +59,14 @@ class MockChannelDB(MockBase): default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc), ) + group = relationship("MockGroup", back_populates="channels") class MockChannelURL(MockBase): __tablename__ = "channels_urls" - id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4) + id = Column(UUID_COLUMN_TYPE, primary_key=True, default=uuid.uuid4) channel_id = Column( - SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False + UUID_COLUMN_TYPE, ForeignKey("channels.id", ondelete="CASCADE"), nullable=False ) url = Column(String, nullable=False) in_use = Column(Boolean, default=False, nullable=False) @@ -82,6 +79,31 @@ class MockChannelURL(MockBase): ) +def create_mock_priorities_and_group(db_session, priorities, group_name): + """Create mock priorities and group for testing purposes. + + Args: + db_session: SQLAlchemy session object + priorities: List of (id, description) tuples for priorities to create + group_name: Name for the new mock group + + Returns: + UUID: The ID of the created group + """ + # Create priorities + priority_objects = [ + MockPriority(id=priority_id, description=description) + for priority_id, description in priorities + ] + + # Create group + group = MockGroup(name=group_name) + db_session.add_all(priority_objects + [group]) + db_session.commit() + + return group.id + + # Create test engine engine_mock = create_engine( "sqlite:///:memory:",