226 lines
7.3 KiB
Python
226 lines
7.3 KiB
Python
|
|
"""Notification-related Celery tasks and helpers."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from datetime import UTC, datetime
|
||
|
|
from typing import Any
|
||
|
|
from uuid import UUID, uuid4
|
||
|
|
|
||
|
|
import asyncpg
|
||
|
|
from celery import shared_task
|
||
|
|
from celery.utils.log import get_task_logger
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
from app.repositories.incident import IncidentRepository
|
||
|
|
from app.repositories.notification import NotificationRepository
|
||
|
|
|
||
|
|
# Module logger obtained from Celery's get_task_logger; the tasks and helpers
# below emit their info/warning/exception messages through it.
logger = get_task_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(init=True, repr=True, eq=True)
class NotificationDispatch:
    """A recorded notification attempt paired with the target it is addressed to.

    One instance is produced per enabled target when an incident's
    notifications are prepared.
    """

    # Identifier of the attempt row created for this target.
    attempt_id: UUID
    # Incident the attempt belongs to, as returned on the attempt row.
    incident_id: UUID
    # Target record with top-level UUID values already converted to strings.
    target: dict[str, Any]
|
||
|
|
|
||
|
|
|
||
|
|
def _serialize_target(target: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *target* whose top-level UUID values are strings.

    Every other value is copied through unchanged; nested containers are not
    inspected.
    """
    return {
        key: str(value) if isinstance(value, UUID) else value
        for key, value in target.items()
    }
|
||
|
|
|
||
|
|
|
||
|
|
async def prepare_notification_dispatches(
    conn: asyncpg.Connection,
    *,
    incident_id: UUID,
    org_id: UUID,
) -> list[NotificationDispatch]:
    """Create notification attempts for all enabled targets in the org.

    For every enabled target of *org_id*, one attempt row is created through
    ``NotificationRepository.create_attempt`` and paired with a copy of the
    target whose top-level UUID values are stringified.

    All attempt rows are written inside a single ``conn.transaction()``. If
    ``conn`` is already in a transaction, asyncpg makes this a savepoint. A
    failure part-way through the loop therefore rolls back the attempts created
    so far, rather than leaving rows that will never be dispatched.

    Args:
        conn: Open asyncpg connection used for every query.
        incident_id: Incident the attempts are created for.
        org_id: Organization whose enabled targets are notified.

    Returns:
        One ``NotificationDispatch`` per enabled target, in the order the
        repository returned them. The list is empty when the org has no
        enabled targets.

    Raises:
        Any exception from the repository or asyncpg propagates unchanged.
    """
    notification_repo = NotificationRepository(conn)
    targets = await notification_repo.get_targets_by_org(org_id, enabled_only=True)
    dispatches: list[NotificationDispatch] = []

    # Make the fan-out all-or-nothing: with an autocommit connection each insert
    # would otherwise commit on its own, and a mid-loop error would orphan rows.
    async with conn.transaction():
        for target in targets:
            attempt = await notification_repo.create_attempt(uuid4(), incident_id, target["id"])
            dispatches.append(
                NotificationDispatch(
                    attempt_id=attempt["id"],
                    incident_id=attempt["incident_id"],
                    target=_serialize_target(target),
                )
            )

    return dispatches
|
||
|
|
|
||
|
|
|
||
|
|
async def _prepare_dispatches_with_new_connection(
    incident_id: UUID,
    org_id: UUID,
) -> list[NotificationDispatch]:
    """Run ``prepare_notification_dispatches`` on a short-lived connection.

    The connection is opened from ``settings.database_url`` and is closed
    whether or not preparation succeeds.
    """
    connection = await asyncpg.connect(settings.database_url)
    try:
        prepared = await prepare_notification_dispatches(
            connection,
            incident_id=incident_id,
            org_id=org_id,
        )
    finally:
        await connection.close()
    return prepared
|
||
|
|
|
||
|
|
|
||
|
|
async def _mark_attempt_success(attempt_id: UUID) -> None:
    """Record a successful delivery for *attempt_id*.

    The current UTC time is passed to
    ``NotificationRepository.update_attempt_success``. The update runs on a
    dedicated connection that is closed even if it raises.
    """
    connection = await asyncpg.connect(settings.database_url)
    try:
        completed_at = datetime.now(UTC)
        await NotificationRepository(connection).update_attempt_success(attempt_id, completed_at)
    finally:
        await connection.close()
|
||
|
|
|
||
|
|
|
||
|
|
async def _mark_attempt_failure(attempt_id: UUID, error: str) -> None:
    """Record a failed delivery for *attempt_id*, passing *error* to the repository.

    The update goes through ``NotificationRepository.update_attempt_failure``
    on a dedicated connection that is closed even if the update raises.
    """
    connection = await asyncpg.connect(settings.database_url)
    try:
        await NotificationRepository(connection).update_attempt_failure(attempt_id, error)
    finally:
        await connection.close()
|
||
|
|
|
||
|
|
|
||
|
|
async def _should_escalate(incident_id: UUID) -> bool:
    """Return True only if the incident exists and its status is ``"triggered"``.

    The lookup uses a dedicated connection that is always closed afterwards.
    """
    connection = await asyncpg.connect(settings.database_url)
    try:
        record = await IncidentRepository(connection).get_by_id(incident_id)
        return record is not None and record["status"] == "triggered"
    finally:
        await connection.close()
|
||
|
|
|
||
|
|
|
||
|
|
def _simulate_delivery(channel: str, target: dict[str, Any], incident_id: str) -> None:
    """Log a stand-in delivery instead of contacting the real channel.

    The recipient is shown by the target's ``name`` when that is truthy,
    and by its ``id`` otherwise.
    """
    recipient = target.get("name")
    if not recipient:
        recipient = target.get("id")
    logger.info("Simulated %s delivery for incident %s to %s", channel, incident_id, recipient)
|
||
|
|
|
||
|
|
|
||
|
|
@shared_task(name="worker.tasks.notifications.incident_triggered", bind=True)
def incident_triggered(
    self,
    *,
    incident_id: str,
    org_id: str,
    triggered_by: str | None = None,
) -> None:
    """Fan-out notifications to all active targets for the incident's org.

    One notification attempt is created per enabled target, and the matching
    channel task (webhook, email or slack) is enqueued for it on the default
    queue. An attempt whose target type has no delivery task is marked failed
    so it does not stay pending.

    Args:
        incident_id: Incident UUID as a string.
        org_id: Organization UUID as a string.
        triggered_by: Accepted for callers; not used by this task.

    Raises:
        ValueError: If ``incident_id`` or ``org_id`` is not a valid UUID.
        Exception: Any error while creating the attempts is logged and re-raised.
    """
    incident_uuid = UUID(incident_id)
    org_uuid = UUID(org_id)
    try:
        dispatches = asyncio.run(_prepare_dispatches_with_new_connection(incident_uuid, org_uuid))
    except Exception as exc:  # pragma: no cover - logged for observability
        logger.exception("Failed to prepare notification dispatches: %s", exc)
        raise

    if not dispatches:
        logger.info("No notification targets for org %s", org_id)
        return

    # Built at call time because the delivery tasks are defined later in the module.
    delivery_tasks = {
        "webhook": send_webhook,
        "email": send_email,
        "slack": send_slack,
    }
    unsupported: list[tuple[UUID, str]] = []

    for dispatch in dispatches:
        target_type = dispatch.target.get("target_type")
        task = delivery_tasks.get(target_type)
        if task is None:
            logger.warning("Unsupported notification target type: %s", target_type)
            unsupported.append(
                (dispatch.attempt_id, f"Unsupported notification target type: {target_type}")
            )
            continue
        task.apply_async(
            kwargs={
                "attempt_id": str(dispatch.attempt_id),
                "incident_id": incident_id,
                "target": dispatch.target,
            },
            queue=settings.task_queue_default_queue,
        )

    # The attempt rows already exist; without this they would never leave the
    # pending state. This runs after enqueueing so supported targets are not delayed.
    # It is best-effort, so a DB hiccup here does not fail the whole fan-out.
    for attempt_id, reason in unsupported:
        try:
            asyncio.run(_mark_attempt_failure(attempt_id, reason))
        except Exception:  # pragma: no cover - logged for observability
            logger.exception("Could not mark notification attempt %s as failed", attempt_id)
|
||
|
|
|
||
|
|
|
||
|
|
def _deliver_and_record(
    channel: str,
    label: str,
    *,
    attempt_id: str,
    target: dict[str, Any],
    incident_id: str,
) -> None:
    """Run the simulated delivery for one attempt and persist its outcome.

    Shared by the channel tasks below so that they record outcomes identically.

    Args:
        channel: Channel name passed to ``_simulate_delivery`` (e.g. ``"webhook"``).
        label: Capitalized channel name used in the failure log message.
        attempt_id: Attempt UUID as a string.
        target: Serialized target record.
        incident_id: Incident UUID as a string.

    Raises:
        ValueError: If ``attempt_id`` is not a valid UUID. This is checked
            before anything is delivered.
        Exception: A delivery error is re-raised after the attempt has been
            marked failed. An error while marking success propagates without
            being recorded as a delivery failure.
    """
    # Parse up front: a malformed id should fail before delivering, not be
    # re-parsed (and re-raise) inside the failure handler.
    attempt_uuid = UUID(attempt_id)
    try:
        _simulate_delivery(channel, target, incident_id)
    except Exception as exc:  # pragma: no cover - logged for observability
        logger.exception("%s delivery failed: %s", label, exc)
        try:
            asyncio.run(_mark_attempt_failure(attempt_uuid, str(exc)))
        except Exception:  # pragma: no cover - keep the delivery error as the one raised
            logger.exception("Could not record failure for notification attempt %s", attempt_id)
        raise
    # Only reached when delivery succeeded; a bookkeeping error here must not be
    # logged or stored as a delivery failure.
    asyncio.run(_mark_attempt_success(attempt_uuid))


@shared_task(
    name="worker.tasks.notifications.send_webhook",
    bind=True,
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_kwargs={"max_retries": 3},
)
def send_webhook(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None:
    """Simulate webhook delivery and mark the attempt status.

    Any exception makes Celery retry the task automatically, with backoff,
    at most 3 times.
    NOTE: because every exception is retried, an error while recording success
    causes the webhook to be delivered again on retry.
    """
    _deliver_and_record(
        "webhook", "Webhook", attempt_id=attempt_id, target=target, incident_id=incident_id
    )


@shared_task(name="worker.tasks.notifications.send_email", bind=True)
def send_email(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None:
    """Simulate email delivery for the notification attempt and mark its status."""
    _deliver_and_record(
        "email", "Email", attempt_id=attempt_id, target=target, incident_id=incident_id
    )


@shared_task(name="worker.tasks.notifications.send_slack", bind=True)
def send_slack(self, *, attempt_id: str, target: dict[str, Any], incident_id: str) -> None:
    """Simulate Slack delivery for the notification attempt and mark its status."""
    _deliver_and_record(
        "slack", "Slack", attempt_id=attempt_id, target=target, incident_id=incident_id
    )
|
||
|
|
|
||
|
|
|
||
|
|
@shared_task(name="worker.tasks.notifications.escalate_if_unacked", bind=True)
def escalate_if_unacked(self, *, incident_id: str, org_id: str) -> None:
    """Re-run the incident fan-out while the incident is still ``triggered``.

    The incident is checked with ``_should_escalate``. If it is missing or no
    longer ``"triggered"``, the task logs and stops. Otherwise a new
    ``incident_triggered`` run is enqueued on the critical queue.
    """
    still_triggered = asyncio.run(_should_escalate(UUID(incident_id)))
    if still_triggered:
        logger.info("Incident %s still triggered; re-fanning notifications", incident_id)
        refan_kwargs = {
            "incident_id": incident_id,
            "org_id": org_id,
            "triggered_by": None,
        }
        incident_triggered.apply_async(  # type: ignore[attr-defined]
            kwargs=refan_kwargs,
            queue=settings.task_queue_critical_queue,
        )
    else:
        logger.info("Incident %s no longer needs escalation", incident_id)
|
||
|
|
|
||
|
|
|
||
|
|
# Names exported by `from ... import *`: the dispatch dataclass, the
# preparation coroutine, and the Celery tasks. Underscore-prefixed helpers
# are not listed.
__all__ = [
    "NotificationDispatch",
    "incident_triggered",
    "escalate_if_unacked",
    "prepare_notification_dispatches",
    "send_email",
    "send_slack",
    "send_webhook",
]
|