189 lines
4.9 KiB
Python
189 lines
4.9 KiB
Python
|
|
"""Task queue abstractions for scheduling background work."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Any
|
||
|
|
from uuid import UUID
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
|
||
|
|
try:
|
||
|
|
from worker.celery_app import celery_app
|
||
|
|
except Exception: # pragma: no cover - celery app may not import during docs builds
|
||
|
|
celery_app = None # type: ignore[assignment]
|
||
|
|
|
||
|
|
|
||
|
|
class TaskQueue(ABC):
    """Abstract contract implemented by every background-work queue backend.

    Subclasses must implement ``incident_triggered`` and
    ``schedule_escalation_check``. The lifecycle hooks (``startup``,
    ``shutdown``), ``ping`` and ``reset`` ship with harmless defaults and
    only need overriding when a backend has real work to do there.
    """

    @abstractmethod
    def incident_triggered(
        self,
        *,
        incident_id: UUID,
        org_id: UUID,
        triggered_by: UUID | None,
    ) -> None:
        """Enqueue the notification fan-out for a triggered incident.

        Args:
            incident_id: Identifier of the incident that was triggered.
            org_id: Identifier of the organization the incident belongs to.
            triggered_by: Identifier of whoever triggered it, or ``None``.
        """

    @abstractmethod
    def schedule_escalation_check(
        self,
        *,
        incident_id: UUID,
        org_id: UUID,
        delay_seconds: int,
    ) -> None:
        """Enqueue an escalation check that should run after a delay.

        Args:
            incident_id: Identifier of the incident to check.
            org_id: Identifier of the organization the incident belongs to.
            delay_seconds: How long to wait before the check runs.
        """

    async def startup(self) -> None:  # pragma: no cover - default no-op
        """Initialize the queue; the base implementation does nothing."""
        return None

    async def shutdown(self) -> None:  # pragma: no cover - default no-op
        """Tear the queue down; the base implementation does nothing."""
        return None

    async def ping(self) -> bool:
        """Report whether the backend is reachable.

        Returns:
            ``True`` unconditionally in this base implementation; backends
            with a real connection override it.
        """
        reachable = True
        return reachable

    def reset(self) -> None:  # pragma: no cover - optional for in-memory impls
        """Drop any in-memory state (used in tests); no-op by default."""
        return None
|
||
|
|
|
||
|
|
|
||
|
|
class CeleryTaskQueue(TaskQueue):
    """Task queue that dispatches work by task name through the Celery app.

    Broker selection (the project mentions Redis or SQS) is not decided in
    this class; it only uses the module-level ``celery_app`` and the two
    queue names it is given.
    """

    def __init__(self, default_queue: str, critical_queue: str) -> None:
        """Bind to the module-level Celery app and remember the queue names.

        Args:
            default_queue: Queue name used for incident-triggered tasks.
            critical_queue: Queue name used for escalation checks.

        Raises:
            RuntimeError: If ``celery_app`` is ``None`` (its import failed).
        """
        app = celery_app
        if app is None:  # pragma: no cover - guarded by try/except
            raise RuntimeError("Celery application is unavailable")
        self._celery = app
        self._default_queue = default_queue
        self._critical_queue = critical_queue

    def incident_triggered(
        self,
        *,
        incident_id: UUID,
        org_id: UUID,
        triggered_by: UUID | None,
    ) -> None:
        """Send the incident-triggered notification task to the default queue.

        UUIDs are converted to strings before being placed in the task
        kwargs; a falsy ``triggered_by`` is sent as ``None``.
        """
        payload = {
            "incident_id": str(incident_id),
            "org_id": str(org_id),
            "triggered_by": str(triggered_by) if triggered_by else None,
        }
        self._celery.send_task(
            "worker.tasks.notifications.incident_triggered",
            kwargs=payload,
            queue=self._default_queue,
        )

    def schedule_escalation_check(
        self,
        *,
        incident_id: UUID,
        org_id: UUID,
        delay_seconds: int,
    ) -> None:
        """Send the escalation task to the critical queue with a countdown.

        Negative ``delay_seconds`` values are clamped to zero.
        """
        payload = {
            "incident_id": str(incident_id),
            "org_id": str(org_id),
        }
        countdown = max(delay_seconds, 0)
        self._celery.send_task(
            "worker.tasks.notifications.escalate_if_unacked",
            kwargs=payload,
            countdown=countdown,
            queue=self._critical_queue,
        )

    async def ping(self) -> bool:
        """Run the blocking connection probe in the loop's default executor."""
        running_loop = asyncio.get_running_loop()
        outcome = await running_loop.run_in_executor(None, self._ping_sync)
        return outcome

    def _ping_sync(self) -> bool:
        """Try to connect through the Celery app's connection object.

        Returns:
            ``True`` if ``connect()`` completed, ``False`` if it raised an
            ``Exception``. The connection is released either way.
        """
        conn = self._celery.connection()
        try:
            conn.connect()
        except Exception:
            reachable = False
        else:
            reachable = True
        finally:
            # Releasing is best effort: a failure here must not mask the
            # probe result.
            try:
                conn.release()
            except Exception:  # pragma: no cover - release best effort
                pass
        return reachable
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class InMemoryTaskQueue(TaskQueue):
    """Queue double that records each dispatch in a list instead of sending it.

    Every call appends a ``(task_name, payload)`` tuple to ``dispatched``;
    payload values are stored exactly as passed in (no string conversion).
    """

    # Recorded dispatches in call order. ``None`` is accepted at construction
    # time and replaced by a fresh list in ``__post_init__``.
    dispatched: list[tuple[str, dict[str, Any]]] | None = None

    def __post_init__(self) -> None:
        """Give the instance its own empty list when none was supplied."""
        if self.dispatched is not None:
            return
        self.dispatched = []

    def incident_triggered(
        self,
        *,
        incident_id: UUID,
        org_id: UUID,
        triggered_by: UUID | None,
    ) -> None:
        """Record an ``"incident_triggered"`` dispatch with its arguments."""
        payload = {
            "incident_id": incident_id,
            "org_id": org_id,
            "triggered_by": triggered_by,
        }
        record = ("incident_triggered", payload)
        self.dispatched.append(record)

    def schedule_escalation_check(
        self,
        *,
        incident_id: UUID,
        org_id: UUID,
        delay_seconds: int,
    ) -> None:
        """Record an ``"escalate_if_unacked"`` dispatch with its arguments."""
        payload = {
            "incident_id": incident_id,
            "org_id": org_id,
            "delay_seconds": delay_seconds,
        }
        record = ("escalate_if_unacked", payload)
        self.dispatched.append(record)

    def reset(self) -> None:
        """Empty the recorded dispatches in place (the list object is kept)."""
        if self.dispatched is None:
            return
        self.dispatched.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def _build_task_queue() -> TaskQueue:
    """Create the queue implementation selected by ``settings``.

    Returns:
        An ``InMemoryTaskQueue`` when ``settings.task_queue_driver`` is
        ``"inmemory"``; for any other value, a ``CeleryTaskQueue`` built
        from the configured default and critical queue names.
    """
    driver = settings.task_queue_driver
    if driver != "inmemory":
        # Anything other than the in-memory driver falls through to Celery.
        return CeleryTaskQueue(
            default_queue=settings.task_queue_default_queue,
            critical_queue=settings.task_queue_critical_queue,
        )
    return InMemoryTaskQueue()
|
||
|
|
|
||
|
|
|
||
|
|
# Shared queue instance, constructed once when this module is imported.
# NOTE: with a driver other than "inmemory", construction goes through
# CeleryTaskQueue.__init__, which raises RuntimeError when celery_app is None.
task_queue = _build_task_queue()

# Public names exported by this module.
__all__ = ["CeleryTaskQueue", "InMemoryTaskQueue", "TaskQueue", "task_queue"]
|