Files
incidentops/tests/worker/test_celery_tasks.py

200 lines
5.4 KiB
Python
Raw Permalink Normal View History

"""End-to-end Celery worker tests against the real Redis broker."""
from __future__ import annotations
import asyncio
import inspect
from uuid import UUID, uuid4
import asyncpg
import pytest
import redis
from app.config import settings
from app.repositories.incident import IncidentRepository
from app.taskqueue import CeleryTaskQueue
from celery.contrib.testing.worker import start_worker
from worker.celery_app import celery_app
# Apply the asyncio marker to every test in this module; the tests below are coroutines.
pytestmark = pytest.mark.asyncio
@pytest.fixture(scope="module", autouse=True)
def ensure_redis_available() -> None:
    """Probe the configured Redis broker and skip this module's tests if it is down.

    A client is opened for ``settings.resolved_task_queue_broker_url``, pinged
    once, and closed again whether or not the ping succeeded.
    """
    broker_url = settings.resolved_task_queue_broker_url
    probe = redis.Redis.from_url(broker_url)
    try:
        probe.ping()
    except redis.RedisError as exc:  # pragma: no cover - diagnostic-only path
        pytest.skip(f"Redis broker unavailable: {exc}")
    finally:
        # Runs on both the success and the skip path so the probe client is released.
        probe.close()
@pytest.fixture(scope="module")
def celery_worker_instance(ensure_redis_available: None):
    """Keep a real Celery worker attached to Redis while this module's tests run.

    The worker consumes the configured default and critical queues with a
    single-process ``solo`` pool. It is stopped when the ``start_worker``
    context exits at fixture teardown.
    """
    worker_options = {
        "loglevel": "INFO",
        "pool": "solo",
        "concurrency": 1,
        "queues": [
            settings.task_queue_default_queue,
            settings.task_queue_critical_queue,
        ],
        # Do not run start_worker's startup ping check.
        "perform_ping_check": False,
    }
    with start_worker(celery_app, **worker_options):
        yield
@pytest.fixture(autouse=True)
def purge_celery_queues():
    """Purge pending Celery tasks around every test so runs stay isolated."""
    control = celery_app.control
    control.purge()  # setup: drop whatever an earlier test left queued
    yield
    control.purge()  # teardown: drop whatever this test left queued
@pytest.fixture
def celery_queue() -> CeleryTaskQueue:
    """Return a ``CeleryTaskQueue`` wired to the configured default and critical queues."""
    queue_names = {
        "default_queue": settings.task_queue_default_queue,
        "critical_queue": settings.task_queue_critical_queue,
    }
    return CeleryTaskQueue(**queue_names)
async def _seed_incident_with_target(conn: asyncpg.Connection) -> tuple[UUID, UUID]:
    """Insert an org, a service, an incident and an enabled webhook target.

    Args:
        conn: Connection on which every insert is executed.

    Returns:
        The ``(org_id, incident_id)`` pair of the freshly created rows.
    """
    org_id, service_id, incident_id, target_id = (uuid4() for _ in range(4))

    # Slugs embed a slice of the generated id so repeated seeding yields distinct values.
    org_row = (org_id, "Celery Org", f"celery-{org_id.hex[:6]}")
    await conn.execute(
        "INSERT INTO orgs (id, name, slug) VALUES ($1, $2, $3)",
        *org_row,
    )

    service_row = (service_id, org_id, "API", f"svc-{service_id.hex[:6]}")
    await conn.execute(
        "INSERT INTO services (id, org_id, name, slug) VALUES ($1, $2, $3, $4)",
        *service_row,
    )

    # The incident goes through the repository rather than raw SQL.
    incident_fields = {
        "incident_id": incident_id,
        "org_id": org_id,
        "service_id": service_id,
        "title": "Latency spike",
        "description": "",
        "severity": "high",
    }
    await IncidentRepository(conn).create(**incident_fields)

    insert_target_sql = (
        "\n"
        "INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled)\n"
        "VALUES ($1, $2, $3, $4, $5, $6)\n"
    )
    target_row = (
        target_id,
        org_id,
        "Primary Webhook",
        "webhook",
        "https://example.com/hook",
        True,
    )
    await conn.execute(insert_target_sql, *target_row)

    return org_id, incident_id
async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.1) -> None:
deadline = asyncio.get_running_loop().time() + timeout
while True:
result = predicate()
if inspect.isawaitable(result):
result = await result
if result:
return
if asyncio.get_running_loop().time() >= deadline:
raise AssertionError("Timed out waiting for Celery worker to finish")
await asyncio.sleep(interval)
async def _attempt_sent(conn: asyncpg.Connection, incident_id: UUID) -> bool:
    """Return True when the fetched notification attempt for the incident is ``sent``.

    Only the single row returned by ``fetchrow`` is inspected; a missing row
    counts as not sent.
    """
    attempt = await conn.fetchrow(
        "SELECT status FROM notification_attempts WHERE incident_id = $1",
        incident_id,
    )
    if not attempt:
        return False
    return attempt["status"] == "sent"
async def _attempt_count(conn: asyncpg.Connection, incident_id: UUID) -> int:
    """Return how many notification attempts are recorded for the incident."""
    total = await conn.fetchval(
        "SELECT COUNT(*) FROM notification_attempts WHERE incident_id = $1",
        incident_id,
    )
    # A falsy result (None or 0) is reported as 0.
    return int(total) if total else 0
async def _attempt_count_is(conn: asyncpg.Connection, incident_id: UUID, expected: int) -> bool:
    """Return True when the incident has exactly *expected* notification attempts."""
    actual = await _attempt_count(conn, incident_id)
    return actual == expected
async def test_incident_triggered_task_marks_attempt_sent(
    db_admin: asyncpg.Connection,
    celery_worker_instance: None,
    celery_queue: CeleryTaskQueue,
) -> None:
    """Enqueue ``incident_triggered`` and wait for the attempt row to read ``sent``."""
    org_id, incident_id = await _seed_incident_with_target(db_admin)

    triggering_user = uuid4()
    celery_queue.incident_triggered(
        incident_id=incident_id,
        org_id=org_id,
        triggered_by=triggering_user,
    )

    def attempt_is_sent():
        # Returns a coroutine; _wait_until awaits it on every poll.
        return _attempt_sent(db_admin, incident_id)

    await _wait_until(attempt_is_sent)
async def test_escalate_task_refires_when_incident_still_triggered(
    db_admin: asyncpg.Connection,
    celery_worker_instance: None,
    celery_queue: CeleryTaskQueue,
) -> None:
    """Schedule an immediate escalation check and wait for exactly one attempt row."""
    org_id, incident_id = await _seed_incident_with_target(db_admin)

    celery_queue.schedule_escalation_check(
        incident_id=incident_id,
        org_id=org_id,
        delay_seconds=0,
    )

    def exactly_one_attempt():
        # Returns a coroutine; _wait_until awaits it on every poll.
        return _attempt_count_is(db_admin, incident_id, 1)

    await _wait_until(exactly_one_attempt)
async def test_escalate_task_skips_when_incident_acknowledged(
    db_admin: asyncpg.Connection,
    celery_worker_instance: None,
    celery_queue: CeleryTaskQueue,
) -> None:
    """Acknowledge the incident first, then check that escalation records no attempt."""
    org_id, incident_id = await _seed_incident_with_target(db_admin)

    acknowledge_sql = "UPDATE incidents SET status = 'acknowledged' WHERE id = $1"
    await db_admin.execute(acknowledge_sql, incident_id)

    celery_queue.schedule_escalation_check(
        incident_id=incident_id,
        org_id=org_id,
        delay_seconds=0,
    )

    # Fixed grace period before checking that nothing was written.
    # NOTE(review): if the worker needs longer than this to pick up the task,
    # the assertion below would pass vacuously - confirm this is acceptable.
    await asyncio.sleep(1)

    recorded_attempts = await _attempt_count(db_admin, incident_id)
    assert recorded_attempts == 0