65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import os
|
|
import pytest
|
|
from unittest.mock import patch
|
|
from sqlalchemy.orm import Session
|
|
from app.utils.database import get_db_credentials, get_db
|
|
from tests.utils.db_mocks import (
|
|
session_mock,
|
|
mock_get_db,
|
|
mock_env,
|
|
mock_ssm
|
|
)
|
|
|
|
def test_get_db_credentials_env(mock_env):
    """Check the connection string returned while the ``mock_env`` fixture is active."""
    expected = "postgresql://testuser:testpass@localhost/testdb"
    assert get_db_credentials() == expected
|
|
|
|
def test_get_db_credentials_ssm(mock_ssm):
    """Test getting DB credentials from SSM.

    ``MOCK_AUTH`` is removed from the environment before the call
    (presumably so the lookup goes through SSM rather than a mocked-auth
    path -- confirm against ``get_db_credentials``).
    """
    # patch.dict snapshots os.environ and restores it on exit, so dropping
    # MOCK_AUTH here cannot leak into tests that run afterwards.
    with patch.dict(os.environ):
        os.environ.pop("MOCK_AUTH", None)
        conn_str = get_db_credentials()

    assert "postgresql://mocked_value:mocked_value@mocked_value/mocked_value" in conn_str
    mock_ssm.get_parameter.assert_called()
|
|
|
|
def test_get_db_credentials_ssm_exception(mock_ssm):
    """Test SSM credential fetching failure raises RuntimeError.

    The mocked ``get_parameter`` is made to raise, and the resulting
    ``RuntimeError`` message is expected to include the underlying error text.
    """
    # patch.dict snapshots os.environ and restores it on exit, so dropping
    # MOCK_AUTH here cannot leak into tests that run afterwards.
    with patch.dict(os.environ):
        os.environ.pop("MOCK_AUTH", None)
        mock_ssm.get_parameter.side_effect = Exception("SSM timeout")

        with pytest.raises(RuntimeError) as excinfo:
            get_db_credentials()

    assert "Failed to fetch DB credentials from SSM: SSM timeout" in str(excinfo.value)
|
|
|
|
def test_session_creation():
    """Test database session creation.

    ``session_mock()`` must hand back a SQLAlchemy ``Session`` instance.
    """
    session = session_mock()
    try:
        assert isinstance(session, Session)
    finally:
        # Close even when the assertion fails so the session is not leaked.
        session.close()
|
|
|
|
def test_get_db_generator():
    """Test that the get_db dependency generator yields exactly one Session."""
    db_gen = get_db()
    db = next(db_gen)
    assert isinstance(db, Session)

    # A bare try/except StopIteration would also pass if a second value were
    # yielded; pytest.raises makes exhaustion after the first item mandatory.
    with pytest.raises(StopIteration):
        next(db_gen)
|
|
|
|
def test_init_db(mocker, mock_env):
    """Check that ``init_db`` calls ``create_all`` exactly once, bound to the module's engine."""
    create_all_spy = mocker.patch("app.models.Base.metadata.create_all")

    # Replace the credential lookup with an in-memory SQLite URL.
    mocker.patch("app.utils.database.get_db_credentials", return_value="sqlite:///:memory:")

    # Bind the names only after both patches are installed.
    from app.utils.database import engine, init_db

    init_db()

    # Exactly one call, and it must receive the engine as ``bind``.
    create_all_spy.assert_called_once_with(bind=engine)