Files
incidentops/worker/tasks/notifications.py

226 lines
7.3 KiB
Python
Raw Normal View History

"""Notification-related Celery tasks and helpers."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from uuid import UUID, uuid4
import asyncpg
from celery import shared_task
from celery.utils.log import get_task_logger
from app.config import settings
from app.repositories.incident import IncidentRepository
from app.repositories.notification import NotificationRepository
# Module-level Celery task logger used by the tasks and helpers below.
logger = get_task_logger(__name__)
@dataclass
class NotificationDispatch:
    """Represents a pending notification attempt for a target."""

    # ``id`` of the attempt record created for this dispatch.
    attempt_id: UUID
    # ``incident_id`` taken from that same attempt record.
    incident_id: UUID
    # Target mapping after ``_serialize_target`` has stringified its UUID values.
    target: dict[str, Any]
def _serialize_target(target: dict[str, Any]) -> dict[str, Any]:
serialized: dict[str, Any] = {}
for key, value in target.items():
if isinstance(value, UUID):
serialized[key] = str(value)
else:
serialized[key] = value
return serialized
async def prepare_notification_dispatches(
    conn: asyncpg.Connection,
    *,
    incident_id: UUID,
    org_id: UUID,
) -> list[NotificationDispatch]:
    """Record one notification attempt for each enabled target of the org.

    Args:
        conn: Open database connection handed to ``NotificationRepository``.
        incident_id: Incident the attempts are created for.
        org_id: Organisation whose enabled targets are looked up.

    Returns:
        One ``NotificationDispatch`` per target, in the order the repository
        returned the targets. Empty when the org has no enabled targets.
    """
    repo = NotificationRepository(conn)

    async def _open_attempt(target_row: Any) -> NotificationDispatch:
        # A fresh UUID is generated for every attempt record.
        attempt_row = await repo.create_attempt(uuid4(), incident_id, target_row["id"])
        return NotificationDispatch(
            attempt_id=attempt_row["id"],
            incident_id=attempt_row["incident_id"],
            target=_serialize_target(target_row),
        )

    enabled_targets = await repo.get_targets_by_org(org_id, enabled_only=True)
    # Awaited one at a time so attempts are created sequentially on ``conn``.
    return [await _open_attempt(target_row) for target_row in enabled_targets]
async def _prepare_dispatches_with_new_connection(
    incident_id: UUID,
    org_id: UUID,
) -> list[NotificationDispatch]:
    """Run ``prepare_notification_dispatches`` on a dedicated connection.

    A connection to ``settings.database_url`` is opened for this call only
    and closed again whether or not the preparation succeeds.
    """
    connection = await asyncpg.connect(settings.database_url)
    try:
        prepared = await prepare_notification_dispatches(
            connection,
            incident_id=incident_id,
            org_id=org_id,
        )
    finally:
        await connection.close()
    return prepared
async def _mark_attempt_success(attempt_id: UUID) -> None:
    """Record a successful attempt, stamped with the current UTC time.

    Opens a short-lived connection to ``settings.database_url`` and always
    closes it afterwards.
    """
    db = await asyncpg.connect(settings.database_url)
    try:
        notifications = NotificationRepository(db)
        now_utc = datetime.now(UTC)
        await notifications.update_attempt_success(attempt_id, now_utc)
    finally:
        await db.close()
async def _mark_attempt_failure(attempt_id: UUID, error: str) -> None:
    """Record a failed attempt together with its error text.

    Opens a short-lived connection to ``settings.database_url`` and always
    closes it afterwards.
    """
    db = await asyncpg.connect(settings.database_url)
    try:
        notifications = NotificationRepository(db)
        await notifications.update_attempt_failure(attempt_id, error)
    finally:
        await db.close()
async def _should_escalate(incident_id: UUID) -> bool:
    """Tell whether the incident exists and its status is still ``"triggered"``.

    Returns ``False`` when the repository finds no incident for the id.
    The connection opened for the lookup is always closed.
    """
    db = await asyncpg.connect(settings.database_url)
    try:
        record = await IncidentRepository(db).get_by_id(incident_id)
        return record is not None and record["status"] == "triggered"
    finally:
        await db.close()
def _simulate_delivery(channel: str, target: dict[str, Any], incident_id: str) -> None:
    """Log a simulated delivery; nothing is actually sent.

    The target is labelled by its ``"name"`` entry, falling back to ``"id"``
    when the name is missing or falsy.
    """
    recipient = target.get("name")
    if not recipient:
        recipient = target.get("id")
    logger.info("Simulated %s delivery for incident %s to %s", channel, incident_id, recipient)
@shared_task(name="worker.tasks.notifications.incident_triggered", bind=True)
def incident_triggered(
self,
*,
incident_id: str,
org_id: str,
triggered_by: str | None = None,
) -> None:
"""Fan-out notifications to all active targets for the incident's org."""
incident_uuid = UUID(incident_id)
org_uuid = UUID(org_id)
try:
dispatches = asyncio.run(_prepare_dispatches_with_new_connection(incident_uuid, org_uuid))
except Exception as exc: # pragma: no cover - logged for observability
logger.exception("Failed to prepare notification dispatches: %s", exc)
raise
if not dispatches:
logger.info("No notification targets for org %s", org_id)
return
for dispatch in dispatches:
target_type = dispatch.target.get("target_type")
kwargs = {
"attempt_id": str(dispatch.attempt_id),
"incident_id": incident_id,
"target": dispatch.target,
}
if target_type == "webhook":
send_webhook.apply_async(kwargs=kwargs, queue=settings.task_queue_default_queue)
elif target_type == "email":
send_email.apply_async(kwargs=kwargs, queue=settings.task_queue_default_queue)
elif target_type == "slack":
send_slack.apply_async(kwargs=kwargs, queue=settings.task_queue_default_queue)
else:
logger.warning("Unsupported notification target type: %s", target_type)
@shared_task(
    name="worker.tasks.notifications.send_webhook",
    bind=True,
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_kwargs={"max_retries": 3},
)
def send_webhook(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None:
    """Simulate a webhook delivery and persist the attempt outcome.

    Any exception is logged, recorded against the attempt, and re-raised.
    The task is declared with ``autoretry_for=(Exception,)``, backoff, and
    ``max_retries=3``.
    """
    try:
        _simulate_delivery("webhook", target, incident_id)
        record_success = _mark_attempt_success(UUID(attempt_id))
        asyncio.run(record_success)
    except Exception as err:  # pragma: no cover - logged for observability
        logger.exception("Webhook delivery failed: %s", err)
        record_failure = _mark_attempt_failure(UUID(attempt_id), str(err))
        asyncio.run(record_failure)
        raise
@shared_task(name="worker.tasks.notifications.send_email", bind=True)
def send_email(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None:
    """Simulate an email delivery and persist the attempt outcome.

    Any exception is logged, recorded against the attempt, and re-raised.
    """
    try:
        _simulate_delivery("email", target, incident_id)
        record_success = _mark_attempt_success(UUID(attempt_id))
        asyncio.run(record_success)
    except Exception as err:  # pragma: no cover
        logger.exception("Email delivery failed: %s", err)
        record_failure = _mark_attempt_failure(UUID(attempt_id), str(err))
        asyncio.run(record_failure)
        raise
@shared_task(name="worker.tasks.notifications.send_slack", bind=True)
def send_slack(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None:
    """Simulate a Slack delivery and persist the attempt outcome.

    Any exception is logged, recorded against the attempt, and re-raised.
    """
    try:
        _simulate_delivery("slack", target, incident_id)
        record_success = _mark_attempt_success(UUID(attempt_id))
        asyncio.run(record_success)
    except Exception as err:  # pragma: no cover
        logger.exception("Slack delivery failed: %s", err)
        record_failure = _mark_attempt_failure(UUID(attempt_id), str(err))
        asyncio.run(record_failure)
        raise
@shared_task(name="worker.tasks.notifications.escalate_if_unacked", bind=True)
def escalate_if_unacked(self, *, incident_id: str, org_id: str) -> None:
    """Queue another notification fan-out while the incident is still triggered.

    Args:
        incident_id: Incident UUID in string form.
        org_id: Organisation UUID in string form, forwarded to the fan-out task.

    Raises:
        ValueError: If ``incident_id`` is not a valid UUID string.
    """
    still_triggered = asyncio.run(_should_escalate(UUID(incident_id)))

    if still_triggered:
        logger.info("Incident %s still triggered; re-fanning notifications", incident_id)
        # The repeat fan-out is sent to the critical queue.
        incident_triggered.apply_async(  # type: ignore[attr-defined]
            kwargs=dict(incident_id=incident_id, org_id=org_id, triggered_by=None),
            queue=settings.task_queue_critical_queue,
        )
        return

    logger.info("Incident %s no longer needs escalation", incident_id)
# Names exported by a star-import of this module; the underscore-prefixed
# helpers defined above are not listed.
__all__ = [
    "NotificationDispatch",
    "incident_triggered",
    "escalate_if_unacked",
    "prepare_notification_dispatches",
    "send_email",
    "send_slack",
    "send_webhook",
]