"""Notification-related Celery tasks and helpers.""" from __future__ import annotations import asyncio from dataclasses import dataclass from datetime import UTC, datetime from typing import Any from uuid import UUID, uuid4 import asyncpg from celery import shared_task from celery.utils.log import get_task_logger from app.config import settings from app.repositories.incident import IncidentRepository from app.repositories.notification import NotificationRepository logger = get_task_logger(__name__) @dataclass class NotificationDispatch: """Represents a pending notification attempt for a target.""" attempt_id: UUID incident_id: UUID target: dict[str, Any] def _serialize_target(target: dict[str, Any]) -> dict[str, Any]: serialized: dict[str, Any] = {} for key, value in target.items(): if isinstance(value, UUID): serialized[key] = str(value) else: serialized[key] = value return serialized async def prepare_notification_dispatches( conn: asyncpg.Connection, *, incident_id: UUID, org_id: UUID, ) -> list[NotificationDispatch]: """Create notification attempts for all enabled targets in the org.""" notification_repo = NotificationRepository(conn) targets = await notification_repo.get_targets_by_org(org_id, enabled_only=True) dispatches: list[NotificationDispatch] = [] for target in targets: attempt = await notification_repo.create_attempt(uuid4(), incident_id, target["id"]) dispatches.append( NotificationDispatch( attempt_id=attempt["id"], incident_id=attempt["incident_id"], target=_serialize_target(target), ) ) return dispatches async def _prepare_dispatches_with_new_connection( incident_id: UUID, org_id: UUID, ) -> list[NotificationDispatch]: conn = await asyncpg.connect(settings.database_url) try: return await prepare_notification_dispatches(conn, incident_id=incident_id, org_id=org_id) finally: await conn.close() async def _mark_attempt_success(attempt_id: UUID) -> None: conn = await asyncpg.connect(settings.database_url) try: repo = NotificationRepository(conn) await repo.update_attempt_success(attempt_id, datetime.now(UTC)) finally: await conn.close() async def _mark_attempt_failure(attempt_id: UUID, error: str) -> None: conn = await asyncpg.connect(settings.database_url) try: repo = NotificationRepository(conn) await repo.update_attempt_failure(attempt_id, error) finally: await conn.close() async def _should_escalate(incident_id: UUID) -> bool: conn = await asyncpg.connect(settings.database_url) try: repo = IncidentRepository(conn) incident = await repo.get_by_id(incident_id) if incident is None: return False return incident["status"] == "triggered" finally: await conn.close() def _simulate_delivery(channel: str, target: dict[str, Any], incident_id: str) -> None: target_name = target.get("name") or target.get("id") logger.info("Simulated %s delivery for incident %s to %s", channel, incident_id, target_name) @shared_task(name="worker.tasks.notifications.incident_triggered", bind=True) def incident_triggered( self, *, incident_id: str, org_id: str, triggered_by: str | None = None, ) -> None: """Fan-out notifications to all active targets for the incident's org.""" incident_uuid = UUID(incident_id) org_uuid = UUID(org_id) try: dispatches = asyncio.run(_prepare_dispatches_with_new_connection(incident_uuid, org_uuid)) except Exception as exc: # pragma: no cover - logged for observability logger.exception("Failed to prepare notification dispatches: %s", exc) raise if not dispatches: logger.info("No notification targets for org %s", org_id) return for dispatch in dispatches: target_type = dispatch.target.get("target_type") kwargs = { "attempt_id": str(dispatch.attempt_id), "incident_id": incident_id, "target": dispatch.target, } if target_type == "webhook": send_webhook.apply_async(kwargs=kwargs, queue=settings.task_queue_default_queue) elif target_type == "email": send_email.apply_async(kwargs=kwargs, queue=settings.task_queue_default_queue) elif target_type == "slack": send_slack.apply_async(kwargs=kwargs, queue=settings.task_queue_default_queue) else: logger.warning("Unsupported notification target type: %s", target_type) @shared_task( name="worker.tasks.notifications.send_webhook", bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 3}, ) def send_webhook(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None: """Simulate webhook delivery and mark the attempt status.""" try: _simulate_delivery("webhook", target, incident_id) asyncio.run(_mark_attempt_success(UUID(attempt_id))) except Exception as exc: # pragma: no cover - logged for observability logger.exception("Webhook delivery failed: %s", exc) asyncio.run(_mark_attempt_failure(UUID(attempt_id), str(exc))) raise @shared_task(name="worker.tasks.notifications.send_email", bind=True) def send_email(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None: """Simulate email delivery for the notification attempt.""" try: _simulate_delivery("email", target, incident_id) asyncio.run(_mark_attempt_success(UUID(attempt_id))) except Exception as exc: # pragma: no cover logger.exception("Email delivery failed: %s", exc) asyncio.run(_mark_attempt_failure(UUID(attempt_id), str(exc))) raise @shared_task(name="worker.tasks.notifications.send_slack", bind=True) def send_slack(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None: """Simulate Slack delivery for the notification attempt.""" try: _simulate_delivery("slack", target, incident_id) asyncio.run(_mark_attempt_success(UUID(attempt_id))) except Exception as exc: # pragma: no cover logger.exception("Slack delivery failed: %s", exc) asyncio.run(_mark_attempt_failure(UUID(attempt_id), str(exc))) raise @shared_task(name="worker.tasks.notifications.escalate_if_unacked", bind=True) def escalate_if_unacked(self, *, incident_id: str, org_id: str) -> None: """Re-dispatch notifications if the incident remains unacknowledged.""" incident_uuid = UUID(incident_id) should_escalate = asyncio.run(_should_escalate(incident_uuid)) if not should_escalate: logger.info("Incident %s no longer needs escalation", incident_id) return logger.info("Incident %s still triggered; re-fanning notifications", incident_id) incident_triggered.apply_async( # type: ignore[attr-defined] kwargs={ "incident_id": incident_id, "org_id": org_id, "triggered_by": None, }, queue=settings.task_queue_critical_queue, ) __all__ = [ "NotificationDispatch", "incident_triggered", "escalate_if_unacked", "prepare_notification_dispatches", "send_email", "send_slack", "send_webhook", ]