Linted and formatted all files
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# Test constants
|
||||
@@ -7,12 +8,15 @@ TEST_CLIENT_ID = "test_client_id"
|
||||
TEST_CLIENT_SECRET = "test_client_secret"
|
||||
|
||||
# Patch constants before importing the module
|
||||
with patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID), \
|
||||
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET):
|
||||
from app.auth.cognito import initiate_auth, get_user_from_token
|
||||
with (
|
||||
patch("app.utils.constants.COGNITO_CLIENT_ID", TEST_CLIENT_ID),
|
||||
patch("app.utils.constants.COGNITO_CLIENT_SECRET", TEST_CLIENT_SECRET),
|
||||
):
|
||||
from app.auth.cognito import get_user_from_token, initiate_auth
|
||||
from app.models.auth import CognitoUser
|
||||
from app.utils.constants import USER_ROLE_ATTRIBUTE
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_cognito_client():
|
||||
with patch("app.auth.cognito.cognito_client") as mock_client:
|
||||
@@ -26,13 +30,14 @@ def mock_cognito_client():
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_initiate_auth_success(mock_cognito_client):
|
||||
# Mock successful authentication response
|
||||
mock_cognito_client.initiate_auth.return_value = {
|
||||
"AuthenticationResult": {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token"
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,104 +45,125 @@ def test_initiate_auth_success(mock_cognito_client):
|
||||
assert result == {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token"
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
def test_initiate_auth_with_secret_hash(mock_cognito_client):
|
||||
with patch("app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash") as mock_hash:
|
||||
with patch(
|
||||
"app.auth.cognito.calculate_secret_hash", return_value="mocked_secret_hash"
|
||||
) as mock_hash:
|
||||
mock_cognito_client.initiate_auth.return_value = {
|
||||
"AuthenticationResult": {"AccessToken": "token"}
|
||||
}
|
||||
|
||||
result = initiate_auth("test_user", "test_pass")
|
||||
|
||||
|
||||
initiate_auth("test_user", "test_pass")
|
||||
|
||||
# Verify calculate_secret_hash was called
|
||||
mock_hash.assert_called_once_with("test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET)
|
||||
|
||||
mock_hash.assert_called_once_with(
|
||||
"test_user", TEST_CLIENT_ID, TEST_CLIENT_SECRET
|
||||
)
|
||||
|
||||
# Verify SECRET_HASH was included in auth params
|
||||
call_args = mock_cognito_client.initiate_auth.call_args[1]
|
||||
assert "SECRET_HASH" in call_args["AuthParameters"]
|
||||
assert call_args["AuthParameters"]["SECRET_HASH"] == "mocked_secret_hash"
|
||||
|
||||
|
||||
def test_initiate_auth_not_authorized(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
|
||||
mock_cognito_client.initiate_auth.side_effect = (
|
||||
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("invalid_user", "wrong_pass")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid username or password"
|
||||
|
||||
|
||||
def test_initiate_auth_user_not_found(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = mock_cognito_client.exceptions.UserNotFoundException()
|
||||
|
||||
mock_cognito_client.initiate_auth.side_effect = (
|
||||
mock_cognito_client.exceptions.UserNotFoundException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("nonexistent_user", "any_pass")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert exc_info.value.detail == "User not found"
|
||||
|
||||
|
||||
def test_initiate_auth_generic_error(mock_cognito_client):
|
||||
mock_cognito_client.initiate_auth.side_effect = Exception("Some error")
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
initiate_auth("test_user", "test_pass")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "An error occurred during authentication" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_user_from_token_success(mock_cognito_client):
|
||||
mock_response = {
|
||||
"Username": "test_user",
|
||||
"UserAttributes": [
|
||||
{"Name": "sub", "Value": "123"},
|
||||
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"}
|
||||
]
|
||||
{"Name": USER_ROLE_ATTRIBUTE, "Value": "admin,user"},
|
||||
],
|
||||
}
|
||||
mock_cognito_client.get_user.return_value = mock_response
|
||||
|
||||
|
||||
result = get_user_from_token("valid_token")
|
||||
|
||||
|
||||
assert isinstance(result, CognitoUser)
|
||||
assert result.username == "test_user"
|
||||
assert set(result.roles) == {"admin", "user"}
|
||||
|
||||
|
||||
def test_get_user_from_token_no_roles(mock_cognito_client):
|
||||
mock_response = {
|
||||
"Username": "test_user",
|
||||
"UserAttributes": [{"Name": "sub", "Value": "123"}]
|
||||
"UserAttributes": [{"Name": "sub", "Value": "123"}],
|
||||
}
|
||||
mock_cognito_client.get_user.return_value = mock_response
|
||||
|
||||
|
||||
result = get_user_from_token("valid_token")
|
||||
|
||||
|
||||
assert isinstance(result, CognitoUser)
|
||||
assert result.username == "test_user"
|
||||
assert result.roles == []
|
||||
|
||||
|
||||
def test_get_user_from_token_invalid_token(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
|
||||
mock_cognito_client.get_user.side_effect = (
|
||||
mock_cognito_client.exceptions.NotAuthorizedException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("invalid_token")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid or expired token."
|
||||
|
||||
|
||||
def test_get_user_from_token_user_not_found(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = mock_cognito_client.exceptions.UserNotFoundException()
|
||||
|
||||
mock_cognito_client.get_user.side_effect = (
|
||||
mock_cognito_client.exceptions.UserNotFoundException()
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("token_for_nonexistent_user")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "User not found or invalid token."
|
||||
|
||||
|
||||
def test_get_user_from_token_generic_error(mock_cognito_client):
|
||||
mock_cognito_client.get_user.side_effect = Exception("Some error")
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_from_token("test_token")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert "Token verification failed" in exc_info.value.detail
|
||||
assert "Token verification failed" in exc_info.value.detail
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi import HTTPException, Depends, Request
|
||||
from app.auth.dependencies import get_current_user, require_roles, oauth2_scheme
|
||||
|
||||
from app.auth.dependencies import get_current_user, oauth2_scheme, require_roles
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
# Mock user for testing
|
||||
@@ -11,24 +13,30 @@ TEST_USER = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "user"],
|
||||
groups=["test_group"]
|
||||
groups=["test_group"],
|
||||
)
|
||||
|
||||
|
||||
# Mock the underlying get_user_from_token function
|
||||
def mock_get_user_from_token(token: str) -> CognitoUser:
|
||||
if token == "valid_token":
|
||||
return TEST_USER
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
|
||||
# Mock endpoint for testing the require_roles decorator
|
||||
@require_roles("admin")
|
||||
async def mock_protected_endpoint(user: CognitoUser = Depends(get_current_user)):
|
||||
return {"message": "Success", "user": user.username}
|
||||
|
||||
|
||||
# Patch the get_user_from_token function for testing
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(monkeypatch):
|
||||
monkeypatch.setattr("app.auth.dependencies.get_user_from_token", mock_get_user_from_token)
|
||||
monkeypatch.setattr(
|
||||
"app.auth.dependencies.get_user_from_token", mock_get_user_from_token
|
||||
)
|
||||
|
||||
|
||||
# Test get_current_user dependency
|
||||
def test_get_current_user_success():
|
||||
@@ -37,54 +45,53 @@ def test_get_current_user_success():
|
||||
assert user.username == "testuser"
|
||||
assert user.roles == ["admin", "user"]
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
get_current_user("invalid_token")
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
# Test require_roles decorator
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_success():
|
||||
# Create test user with required role
|
||||
user = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin"],
|
||||
groups=[]
|
||||
username="testuser", email="test@example.com", roles=["admin"], groups=[]
|
||||
)
|
||||
|
||||
|
||||
result = await mock_protected_endpoint(user=user)
|
||||
assert result == {"message": "Success", "user": "testuser"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_missing_role():
|
||||
# Create test user without required role
|
||||
user = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["user"],
|
||||
groups=[]
|
||||
username="testuser", email="test@example.com", roles=["user"], groups=[]
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_protected_endpoint(user=user)
|
||||
assert exc.value.status_code == 403
|
||||
assert exc.value.detail == "You do not have the required roles to access this endpoint."
|
||||
assert (
|
||||
exc.value.detail
|
||||
== "You do not have the required roles to access this endpoint."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_no_roles():
|
||||
# Create test user with no roles
|
||||
user = CognitoUser(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=[],
|
||||
groups=[]
|
||||
username="testuser", email="test@example.com", roles=[], groups=[]
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_protected_endpoint(user=user)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_roles_multiple_roles():
|
||||
# Test requiring multiple roles
|
||||
@@ -97,7 +104,7 @@ async def test_require_roles_multiple_roles():
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "super_user", "user"],
|
||||
groups=[]
|
||||
groups=[],
|
||||
)
|
||||
result = await mock_multi_role_endpoint(user=user_with_roles)
|
||||
assert result == {"message": "Success"}
|
||||
@@ -107,56 +114,62 @@ async def test_require_roles_multiple_roles():
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["admin", "user"],
|
||||
groups=[]
|
||||
groups=[],
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mock_multi_role_endpoint(user=user_missing_role)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth2_scheme_configuration():
|
||||
# Verify that we have a properly configured OAuth2PasswordBearer instance
|
||||
assert isinstance(oauth2_scheme, OAuth2PasswordBearer)
|
||||
|
||||
|
||||
# Create a mock request with no Authorization header
|
||||
mock_request = Request(scope={
|
||||
'type': 'http',
|
||||
'headers': [],
|
||||
'method': 'GET',
|
||||
'scheme': 'http',
|
||||
'path': '/',
|
||||
'query_string': b'',
|
||||
'client': ('127.0.0.1', 8000)
|
||||
})
|
||||
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"headers": [],
|
||||
"method": "GET",
|
||||
"scheme": "http",
|
||||
"path": "/",
|
||||
"query_string": b"",
|
||||
"client": ("127.0.0.1", 8000),
|
||||
}
|
||||
)
|
||||
|
||||
# Test that the scheme raises 401 when no token is provided
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await oauth2_scheme(mock_request)
|
||||
assert exc.value.status_code == 401
|
||||
assert exc.value.detail == "Not authenticated"
|
||||
|
||||
|
||||
def test_mock_auth_import(monkeypatch):
|
||||
# Save original env var value
|
||||
original_value = os.environ.get("MOCK_AUTH")
|
||||
|
||||
|
||||
try:
|
||||
# Set MOCK_AUTH to true
|
||||
monkeypatch.setenv("MOCK_AUTH", "true")
|
||||
|
||||
|
||||
# Reload the dependencies module to trigger the import condition
|
||||
import app.auth.dependencies
|
||||
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
|
||||
# Verify that mock_get_user_from_token was imported
|
||||
from app.auth.dependencies import get_user_from_token
|
||||
assert get_user_from_token.__module__ == 'app.auth.mock_auth'
|
||||
|
||||
|
||||
assert get_user_from_token.__module__ == "app.auth.mock_auth"
|
||||
|
||||
finally:
|
||||
# Restore original env var
|
||||
if original_value is None:
|
||||
monkeypatch.delenv("MOCK_AUTH", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("MOCK_AUTH", original_value)
|
||||
|
||||
|
||||
# Reload again to restore original state
|
||||
importlib.reload(app.auth.dependencies)
|
||||
importlib.reload(app.auth.dependencies)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth.mock_auth import mock_get_user_from_token, mock_initiate_auth
|
||||
from app.models.auth import CognitoUser
|
||||
|
||||
|
||||
def test_mock_get_user_from_token_success():
|
||||
"""Test successful token validation returns expected user"""
|
||||
user = mock_get_user_from_token("testuser")
|
||||
@@ -10,27 +12,30 @@ def test_mock_get_user_from_token_success():
|
||||
assert user.username == "testuser"
|
||||
assert user.roles == ["admin"]
|
||||
|
||||
|
||||
def test_mock_get_user_from_token_invalid():
|
||||
"""Test invalid token raises expected exception"""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_get_user_from_token("invalid_token")
|
||||
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid mock token - use 'testuser'"
|
||||
|
||||
|
||||
def test_mock_initiate_auth():
|
||||
"""Test mock authentication returns expected token response"""
|
||||
result = mock_initiate_auth("any_user", "any_password")
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["AccessToken"] == "testuser"
|
||||
assert result["ExpiresIn"] == 3600
|
||||
assert result["TokenType"] == "Bearer"
|
||||
|
||||
|
||||
def test_mock_initiate_auth_different_credentials():
|
||||
"""Test mock authentication works with any credentials"""
|
||||
result1 = mock_initiate_auth("user1", "pass1")
|
||||
result2 = mock_initiate_auth("user2", "pass2")
|
||||
|
||||
|
||||
# Should return same mock token regardless of credentials
|
||||
assert result1 == result2
|
||||
assert result1 == result2
|
||||
|
||||
@@ -1,34 +1,35 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_auth():
|
||||
return {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token",
|
||||
"RefreshToken": "mock_refresh_token"
|
||||
"RefreshToken": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_auth_no_refresh():
|
||||
return {
|
||||
"AccessToken": "mock_access_token",
|
||||
"IdToken": "mock_id_token"
|
||||
}
|
||||
return {"AccessToken": "mock_access_token", "IdToken": "mock_id_token"}
|
||||
|
||||
|
||||
def test_signin_success(mock_successful_auth):
|
||||
"""Test successful signin with all tokens"""
|
||||
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth):
|
||||
with patch("app.routers.auth.initiate_auth", return_value=mock_successful_auth):
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser", "password": "testpass"}
|
||||
"/auth/signin", json={"username": "testuser", "password": "testpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
@@ -36,14 +37,16 @@ def test_signin_success(mock_successful_auth):
|
||||
assert data["refresh_token"] == "mock_refresh_token"
|
||||
assert data["token_type"] == "Bearer"
|
||||
|
||||
|
||||
def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
|
||||
"""Test successful signin without refresh token"""
|
||||
with patch('app.routers.auth.initiate_auth', return_value=mock_successful_auth_no_refresh):
|
||||
with patch(
|
||||
"app.routers.auth.initiate_auth", return_value=mock_successful_auth_no_refresh
|
||||
):
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser", "password": "testpass"}
|
||||
"/auth/signin", json={"username": "testuser", "password": "testpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["access_token"] == "mock_access_token"
|
||||
@@ -51,57 +54,48 @@ def test_signin_success_no_refresh(mock_successful_auth_no_refresh):
|
||||
assert data["refresh_token"] is None
|
||||
assert data["token_type"] == "Bearer"
|
||||
|
||||
|
||||
def test_signin_invalid_input():
|
||||
"""Test signin with invalid input format"""
|
||||
# Missing password
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser"}
|
||||
)
|
||||
response = client.post("/auth/signin", json={"username": "testuser"})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Missing username
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"password": "testpass"}
|
||||
)
|
||||
response = client.post("/auth/signin", json={"password": "testpass"})
|
||||
assert response.status_code == 422
|
||||
|
||||
# Empty payload
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={}
|
||||
)
|
||||
response = client.post("/auth/signin", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_signin_auth_failure():
|
||||
"""Test signin with authentication failure"""
|
||||
with patch('app.routers.auth.initiate_auth') as mock_auth:
|
||||
with patch("app.routers.auth.initiate_auth") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password"
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "testuser", "password": "wrongpass"}
|
||||
"/auth/signin", json={"username": "testuser", "password": "wrongpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert data["detail"] == "Invalid username or password"
|
||||
|
||||
|
||||
def test_signin_user_not_found():
|
||||
"""Test signin with non-existent user"""
|
||||
with patch('app.routers.auth.initiate_auth') as mock_auth:
|
||||
with patch("app.routers.auth.initiate_auth") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
response = client.post(
|
||||
"/auth/signin",
|
||||
json={"username": "nonexistent", "password": "testpass"}
|
||||
"/auth/signin", json={"username": "nonexistent", "password": "testpass"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["detail"] == "User not found"
|
||||
assert data["detail"] == "User not found"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,24 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app, lifespan
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client for FastAPI app"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_root_endpoint(client):
|
||||
"""Test root endpoint returns expected message"""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "IPTV Updater API"}
|
||||
|
||||
|
||||
def test_openapi_schema_generation(client):
|
||||
"""Test OpenAPI schema is properly generated"""
|
||||
# First call - generate schema
|
||||
@@ -23,7 +28,7 @@ def test_openapi_schema_generation(client):
|
||||
assert schema["openapi"] == "3.1.0"
|
||||
assert "securitySchemes" in schema["components"]
|
||||
assert "Bearer" in schema["components"]["securitySchemes"]
|
||||
|
||||
|
||||
# Test empty components initialization
|
||||
with patch("app.main.get_openapi", return_value={"info": {}}):
|
||||
# Clear cached schema
|
||||
@@ -35,26 +40,28 @@ def test_openapi_schema_generation(client):
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
|
||||
|
||||
def test_openapi_schema_caching(mocker):
|
||||
"""Test OpenAPI schema caching behavior"""
|
||||
# Clear any existing schema
|
||||
app.openapi_schema = None
|
||||
|
||||
|
||||
# Mock get_openapi to return test schema
|
||||
mock_schema = {"test": "schema"}
|
||||
mocker.patch("app.main.get_openapi", return_value=mock_schema)
|
||||
|
||||
|
||||
# First call - should call get_openapi
|
||||
schema = app.openapi()
|
||||
assert schema == mock_schema
|
||||
assert app.openapi_schema == mock_schema
|
||||
|
||||
|
||||
# Second call - should return cached schema
|
||||
with patch("app.main.get_openapi") as mock_get_openapi:
|
||||
schema = app.openapi()
|
||||
assert schema == mock_schema
|
||||
mock_get_openapi.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_init_db(mocker):
|
||||
"""Test lifespan manager initializes database"""
|
||||
@@ -63,6 +70,7 @@ async def test_lifespan_init_db(mocker):
|
||||
pass # Just enter/exit context
|
||||
mock_init_db.assert_called_once()
|
||||
|
||||
|
||||
def test_router_inclusion():
|
||||
"""Test all routers are properly included"""
|
||||
route_paths = {route.path for route in app.routes}
|
||||
@@ -70,4 +78,4 @@ def test_router_inclusion():
|
||||
assert any(path.startswith("/auth") for path in route_paths)
|
||||
assert any(path.startswith("/channels") for path in route_paths)
|
||||
assert any(path.startswith("/playlist") for path in route_paths)
|
||||
assert any(path.startswith("/priorities") for path in route_paths)
|
||||
assert any(path.startswith("/priorities") for path in route_paths)
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.orm import Session, sessionmaker, declarative_base
|
||||
from sqlalchemy import create_engine, TypeDecorator, TEXT, Column, String, DateTime, UniqueConstraint, ForeignKey, Boolean, Integer
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import (
|
||||
TEXT,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
TypeDecorator,
|
||||
UniqueConstraint,
|
||||
create_engine,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Create a mock-specific Base class for testing
|
||||
MockBase = declarative_base()
|
||||
|
||||
|
||||
class SQLiteUUID(TypeDecorator):
|
||||
"""Enables UUID support for SQLite."""
|
||||
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
@@ -25,12 +38,14 @@ class SQLiteUUID(TypeDecorator):
|
||||
return value
|
||||
return uuid.UUID(value)
|
||||
|
||||
|
||||
# Model classes for testing - prefix with Mock to avoid pytest collection
|
||||
class MockPriority(MockBase):
|
||||
__tablename__ = "priorities"
|
||||
id = Column(Integer, primary_key=True)
|
||||
description = Column(String, nullable=False)
|
||||
|
||||
|
||||
class MockChannelDB(MockBase):
|
||||
__tablename__ = "channels"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
@@ -39,32 +54,45 @@ class MockChannelDB(MockBase):
|
||||
group_title = Column(String, nullable=False)
|
||||
tvg_name = Column(String)
|
||||
__table_args__ = (
|
||||
UniqueConstraint('group_title', 'name', name='uix_group_title_name'),
|
||||
UniqueConstraint("group_title", "name", name="uix_group_title_name"),
|
||||
)
|
||||
tvg_logo = Column(String)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class MockChannelURL(MockBase):
|
||||
__tablename__ = "channels_urls"
|
||||
id = Column(SQLiteUUID(), primary_key=True, default=uuid.uuid4)
|
||||
channel_id = Column(SQLiteUUID(), ForeignKey('channels.id', ondelete='CASCADE'), nullable=False)
|
||||
channel_id = Column(
|
||||
SQLiteUUID(), ForeignKey("channels.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
url = Column(String, nullable=False)
|
||||
in_use = Column(Boolean, default=False, nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey('priorities.id'), nullable=False)
|
||||
priority_id = Column(Integer, ForeignKey("priorities.id"), nullable=False)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# Create test engine
|
||||
engine_mock = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# Create test session
|
||||
session_mock = sessionmaker(autocommit=False, autoflush=False, bind=engine_mock)
|
||||
|
||||
|
||||
# Mock the actual database functions
|
||||
def mock_get_db():
|
||||
db = session_mock()
|
||||
@@ -73,6 +101,7 @@ def mock_get_db():
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env(monkeypatch):
|
||||
"""Fixture for mocking environment variables"""
|
||||
@@ -82,14 +111,13 @@ def mock_env(monkeypatch):
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_NAME", "testdb")
|
||||
monkeypatch.setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssm():
|
||||
"""Fixture for mocking boto3 SSM client"""
|
||||
with patch('boto3.client') as mock_client:
|
||||
with patch("boto3.client") as mock_client:
|
||||
mock_ssm = MagicMock()
|
||||
mock_client.return_value = mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {
|
||||
'Parameter': {'Value': 'mocked_value'}
|
||||
}
|
||||
yield mock_ssm
|
||||
mock_ssm.get_parameter.return_value = {"Parameter": {"Value": "mocked_value"}}
|
||||
yield mock_ssm
|
||||
|
||||
@@ -1,43 +1,45 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from sqlalchemy.orm import Session
|
||||
from app.utils.database import get_db_credentials, get_db
|
||||
from tests.utils.db_mocks import (
|
||||
session_mock,
|
||||
mock_get_db,
|
||||
mock_env,
|
||||
mock_ssm
|
||||
)
|
||||
|
||||
from app.utils.database import get_db, get_db_credentials
|
||||
from tests.utils.db_mocks import mock_env, mock_ssm, session_mock
|
||||
|
||||
|
||||
def test_get_db_credentials_env(mock_env):
|
||||
"""Test getting DB credentials from environment variables"""
|
||||
conn_str = get_db_credentials()
|
||||
assert conn_str == "postgresql://testuser:testpass@localhost/testdb"
|
||||
|
||||
|
||||
def test_get_db_credentials_ssm(mock_ssm):
|
||||
"""Test getting DB credentials from SSM"""
|
||||
os.environ.pop("MOCK_AUTH", None)
|
||||
conn_str = get_db_credentials()
|
||||
assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
|
||||
expected_conn = "postgresql://mocked_value:mocked_value@mocked_value/mocked_value"
|
||||
assert expected_conn in conn_str
|
||||
mock_ssm.get_parameter.assert_called()
|
||||
|
||||
|
||||
def test_get_db_credentials_ssm_exception(mock_ssm):
|
||||
"""Test SSM credential fetching failure raises RuntimeError"""
|
||||
os.environ.pop("MOCK_AUTH", None)
|
||||
mock_ssm.get_parameter.side_effect = Exception("SSM timeout")
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
get_db_credentials()
|
||||
|
||||
|
||||
assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_session_creation():
|
||||
"""Test database session creation"""
|
||||
session = session_mock()
|
||||
assert isinstance(session, Session)
|
||||
session.close()
|
||||
|
||||
|
||||
def test_get_db_generator():
|
||||
"""Test get_db dependency generator"""
|
||||
db_gen = get_db()
|
||||
@@ -48,18 +50,20 @@ def test_get_db_generator():
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
def test_init_db(mocker, mock_env):
|
||||
"""Test database initialization creates tables"""
|
||||
mock_create_all = mocker.patch('app.models.Base.metadata.create_all')
|
||||
|
||||
mock_create_all = mocker.patch("app.models.Base.metadata.create_all")
|
||||
|
||||
# Mock get_db_credentials to return SQLite test connection
|
||||
mocker.patch(
|
||||
'app.utils.database.get_db_credentials',
|
||||
return_value="sqlite:///:memory:"
|
||||
"app.utils.database.get_db_credentials",
|
||||
return_value="sqlite:///:memory:",
|
||||
)
|
||||
|
||||
from app.utils.database import init_db, engine
|
||||
|
||||
from app.utils.database import engine, init_db
|
||||
|
||||
init_db()
|
||||
|
||||
|
||||
# Verify create_all was called with the engine
|
||||
mock_create_all.assert_called_once_with(bind=engine)
|
||||
mock_create_all.assert_called_once_with(bind=engine)
|
||||
|
||||
Reference in New Issue
Block a user