2025-12-07 12:00:00 +00:00
|
|
|
"""Shared pytest fixtures for all tests."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import os
|
2025-12-29 09:55:30 +00:00
|
|
|
from contextlib import asynccontextmanager
|
2026-01-07 20:51:13 -05:00
|
|
|
from typing import AsyncGenerator, Callable, Generator
|
2025-12-29 09:55:30 +00:00
|
|
|
from uuid import UUID, uuid4
|
2025-12-07 12:00:00 +00:00
|
|
|
|
|
|
|
|
import asyncpg
|
2025-12-29 09:55:30 +00:00
|
|
|
import httpx
|
2025-12-07 12:00:00 +00:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
# Test-only defaults. They are applied before the app modules are imported
# further down in this file; ``setdefault`` leaves any value that is already
# present in the environment untouched.
_TEST_ENV_DEFAULTS = {
    "DATABASE_URL": "postgresql://incidentops:incidentops@localhost:5432/incidentops_test",
    "JWT_SECRET_KEY": "test-secret-key-for-testing-only",
    "REDIS_URL": "redis://localhost:6379/1",
    "TASK_QUEUE_DRIVER": "inmemory",
    "TASK_QUEUE_BROKER_URL": "redis://localhost:6379/2",
}
for _env_name, _env_value in _TEST_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
|
2025-12-07 12:00:00 +00:00
|
|
|
|
2025-12-29 09:55:30 +00:00
|
|
|
from app.main import app
|
2026-01-07 20:51:13 -05:00
|
|
|
from app.taskqueue import task_queue
|
2025-12-29 09:55:30 +00:00
|
|
|
|
2025-12-07 12:00:00 +00:00
|
|
|
|
|
|
|
|
# Module-level setup: create database and run migrations once
_db_initialized = False


def _admin_dsn_and_db_name(dsn: str) -> tuple[str, str]:
    """Split *dsn* into the admin DSN and the name of the database it targets.

    The last path segment of the URL is taken as the database name and is
    replaced by ``postgres`` to form the admin DSN. A ``?query`` suffix, if
    present, is carried over unchanged.

    Raises:
        RuntimeError: if the URL has no database segment, or if the database
            name does not contain ``test`` (the caller drops that database).
    """
    # Split the query off first so a "/" inside it (e.g. a socket path) cannot
    # be mistaken for the separator in front of the database name.
    base, query_sep, query = dsn.partition("?")
    prefix, _, db_name = base.rpartition("/")
    if "://" not in prefix or not db_name:
        # The DSN itself is not echoed: it may contain a password.
        raise RuntimeError(
            "DATABASE_URL must include a database name, "
            "e.g. postgresql://user:pass@host:5432/incidentops_test"
        )
    if "test" not in db_name.lower():
        raise RuntimeError(
            f"Refusing to drop database {db_name!r}: the test suite only "
            "recreates databases whose name contains 'test'"
        )
    return f"{prefix}/postgres{query_sep}{query}", db_name


def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
    return '"' + name.replace('"', '""') + '"'


async def _init_test_db() -> None:
    """Recreate the test database and apply all migrations (once per process).

    The database named by ``DATABASE_URL`` is dropped and created again through
    a connection to the ``postgres`` database on the same server. Every
    ``*.sql`` file in ``../migrations`` (relative to this file) is then
    executed in sorted filename order. Later calls return immediately.

    Raises:
        KeyError: if ``DATABASE_URL`` is not set.
        RuntimeError: if ``DATABASE_URL`` has no database name or names a
            database that does not look like a test database.
    """
    global _db_initialized
    if _db_initialized:
        return

    test_dsn = os.environ["DATABASE_URL"]
    # Derive the name from the DSN so the database that gets recreated is
    # always the same one the migrations (and the tests) connect to.
    admin_dsn, test_db_name = _admin_dsn_and_db_name(test_dsn)
    quoted_db_name = _quote_ident(test_db_name)

    admin_conn = await asyncpg.connect(admin_dsn)
    try:
        # Terminate existing connections to the test database; DROP DATABASE
        # cannot proceed while other sessions are attached to it.
        await admin_conn.execute(
            """
            SELECT pg_terminate_backend(pid)
            FROM pg_stat_activity
            WHERE datname = $1
              AND pid <> pg_backend_pid()
            """,
            test_db_name,
        )
        # Drop and recreate test database. Identifiers cannot be bound as
        # query parameters, hence the explicit quoting.
        await admin_conn.execute(f"DROP DATABASE IF EXISTS {quoted_db_name}")
        await admin_conn.execute(f"CREATE DATABASE {quoted_db_name}")
    finally:
        await admin_conn.close()

    # Connect to test database and run migrations
    conn = await asyncpg.connect(test_dsn)
    try:
        migrations_dir = os.path.join(os.path.dirname(__file__), "..", "migrations")
        migration_files = sorted(
            name for name in os.listdir(migrations_dir) if name.endswith(".sql")
        )
        for migration_file in migration_files:
            migration_path = os.path.join(migrations_dir, migration_file)
            with open(migration_path, encoding="utf-8") as migration:
                sql = migration.read()
            await conn.execute(sql)
    finally:
        await conn.close()

    # Only mark as done after every migration succeeded, so a failed setup is
    # retried by the next caller instead of being silently skipped.
    _db_initialized = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]:
    """Yield a connection whose work is rolled back when the test finishes.

    The test database is initialized first (a no-op after the first call).
    A transaction is then opened on a fresh connection to ``DATABASE_URL``
    and rolled back on teardown, so nothing the test writes through this
    connection is committed.

    Yields:
        An open ``asyncpg.Connection`` inside a started transaction.
    """
    await _init_test_db()

    conn = await asyncpg.connect(os.environ["DATABASE_URL"])
    try:
        # Start a transaction that will be rolled back
        tr = conn.transaction()
        await tr.start()
        try:
            yield conn
        finally:
            await tr.rollback()
    finally:
        # Runs even when start() or rollback() raised, so the connection is
        # never left open.
        await conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def make_user_id() -> Callable[[], UUID]:
    """Provide a zero-argument callable that returns a new random user UUID."""

    def _new_user_id() -> UUID:
        return uuid4()

    return _new_user_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def make_org_id() -> Callable[[], UUID]:
    """Provide a zero-argument callable that returns a new random org UUID."""

    def _new_org_id() -> UUID:
        return uuid4()

    return _new_org_id
|
2025-12-29 09:55:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tables emptied by `_truncate_all_tables`. All names are joined into one
# `TRUNCATE TABLE ... CASCADE` statement.
# NOTE(review): the list looks ordered dependents-first; confirm against the
# schema if the CASCADE is ever removed.
TABLES_TO_TRUNCATE = [
    # incident data
    "incident_events",
    "notification_attempts",
    "incidents",
    # per-org configuration
    "notification_targets",
    "services",
    # auth / membership
    "refresh_tokens",
    "org_members",
    "orgs",
    "users",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _truncate_all_tables() -> None:
    """Empty every table in ``TABLES_TO_TRUNCATE`` with one TRUNCATE ... CASCADE.

    Opens its own short-lived connection to ``DATABASE_URL`` and closes it
    again whether or not the statement succeeds.
    """
    connection = await asyncpg.connect(os.environ["DATABASE_URL"])
    try:
        table_list = ", ".join(TABLES_TO_TRUNCATE)
        statement = f"TRUNCATE TABLE {table_list} CASCADE"
        await connection.execute(statement)
    finally:
        await connection.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
async def clean_database() -> AsyncGenerator[None, None]:
    """Give the test an initialized database with empty tables.

    Tables are truncated once before the test body runs and once more after
    it, so data committed by the test does not remain for the next one.
    """
    # Setup: make sure the schema exists, then start from empty tables.
    await _init_test_db()
    await _truncate_all_tables()

    yield None

    # Teardown: remove whatever the test committed.
    await _truncate_all_tables()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def _lifespan_manager() -> AsyncGenerator[None, None]:
    """Run the app's lifespan context (if it has one) around the ``async with`` body."""
    lifespan_factory = app.router.lifespan_context

    # No lifespan registered on the router: nothing to enter or exit.
    if lifespan_factory is None:
        yield
        return

    async with lifespan_factory(app):
        yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Yield an ``httpx.AsyncClient`` that talks to the app in-process over ASGI.

    The app's lifespan context is entered before the client is created and
    exited after the client is closed. Depends on ``clean_database``, so each
    test starts with empty tables.
    """
    async with _lifespan_manager():
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://testserver",
        ) as http_client:
            yield http_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]:
    """Yield a plain connection for arranging or inspecting API test data.

    No transaction is opened here, so statements are not rolled back by this
    fixture; cleanup is left to ``clean_database``, which it depends on.
    """
    connection = await asyncpg.connect(os.environ["DATABASE_URL"])
    try:
        yield connection
    finally:
        await connection.close()
|
2026-01-07 20:51:13 -05:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_task_queue() -> Generator[None, None, None]:
    """Reset the task queue before and after every test (autouse).

    The reset is skipped when the configured ``task_queue`` object has no
    ``reset`` attribute.
    """

    def _reset_if_supported() -> None:
        if hasattr(task_queue, "reset"):
            task_queue.reset()

    _reset_if_supported()
    yield
    _reset_if_supported()
|