"""Task queue abstractions for scheduling background work.""" from __future__ import annotations import asyncio from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any from uuid import UUID from app.config import settings try: from worker.celery_app import celery_app except Exception: # pragma: no cover - celery app may not import during docs builds celery_app = None # type: ignore[assignment] class TaskQueue(ABC): """Interface for enqueueing background work.""" async def startup(self) -> None: # pragma: no cover - default no-op """Hook for queue initialization.""" async def shutdown(self) -> None: # pragma: no cover - default no-op """Hook for queue teardown.""" async def ping(self) -> bool: """Check if the queue backend is reachable.""" return True def reset(self) -> None: # pragma: no cover - optional for in-memory impls """Reset any in-memory state (used in tests).""" @abstractmethod def incident_triggered( self, *, incident_id: UUID, org_id: UUID, triggered_by: UUID | None, ) -> None: """Fan out an incident triggered notification.""" @abstractmethod def schedule_escalation_check( self, *, incident_id: UUID, org_id: UUID, delay_seconds: int, ) -> None: """Schedule a delayed escalation check.""" class CeleryTaskQueue(TaskQueue): """Celery-backed task queue that can use Redis or SQS brokers.""" def __init__(self, default_queue: str, critical_queue: str) -> None: if celery_app is None: # pragma: no cover - guarded by try/except raise RuntimeError("Celery application is unavailable") self._celery = celery_app self._default_queue = default_queue self._critical_queue = critical_queue def incident_triggered( self, *, incident_id: UUID, org_id: UUID, triggered_by: UUID | None, ) -> None: self._celery.send_task( "worker.tasks.notifications.incident_triggered", kwargs={ "incident_id": str(incident_id), "org_id": str(org_id), "triggered_by": str(triggered_by) if triggered_by else None, }, queue=self._default_queue, ) def schedule_escalation_check( self, *, incident_id: UUID, org_id: UUID, delay_seconds: int, ) -> None: self._celery.send_task( "worker.tasks.notifications.escalate_if_unacked", kwargs={ "incident_id": str(incident_id), "org_id": str(org_id), }, countdown=max(delay_seconds, 0), queue=self._critical_queue, ) async def ping(self) -> bool: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, self._ping_sync) def _ping_sync(self) -> bool: connection = self._celery.connection() try: connection.connect() return True except Exception: return False finally: try: connection.release() except Exception: # pragma: no cover - release best effort pass @dataclass class InMemoryTaskQueue(TaskQueue): """Test-friendly queue that records dispatched tasks in memory.""" dispatched: list[tuple[str, dict[str, Any]]] | None = None def __post_init__(self) -> None: if self.dispatched is None: self.dispatched = [] def incident_triggered( self, *, incident_id: UUID, org_id: UUID, triggered_by: UUID | None, ) -> None: self.dispatched.append( ( "incident_triggered", { "incident_id": incident_id, "org_id": org_id, "triggered_by": triggered_by, }, ) ) def schedule_escalation_check( self, *, incident_id: UUID, org_id: UUID, delay_seconds: int, ) -> None: self.dispatched.append( ( "escalate_if_unacked", { "incident_id": incident_id, "org_id": org_id, "delay_seconds": delay_seconds, }, ) ) def reset(self) -> None: if self.dispatched is not None: self.dispatched.clear() def _build_task_queue() -> TaskQueue: if settings.task_queue_driver == "inmemory": return InMemoryTaskQueue() return CeleryTaskQueue( default_queue=settings.task_queue_default_queue, critical_queue=settings.task_queue_critical_queue, ) task_queue = _build_task_queue() __all__ = [ "CeleryTaskQueue", "InMemoryTaskQueue", "TaskQueue", "task_queue", ]