feat(api): Pydantic schemas + Data Repositories
This commit is contained in:
1
tests/repositories/__init__.py
Normal file
1
tests/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Repository tests."""
|
||||
389
tests/repositories/test_incident.py
Normal file
389
tests/repositories/test_incident.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""Tests for IncidentRepository."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.incident import IncidentRepository
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestIncidentRepository:
|
||||
"""Tests for IncidentRepository conforming to SPECS.md."""
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
|
||||
"""Helper to create an org."""
|
||||
org_repo = OrgRepository(conn)
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, f"Org {slug}", slug)
|
||||
return org_id
|
||||
|
||||
async def _create_service(self, conn: asyncpg.Connection, org_id: uuid4, slug: str) -> uuid4:
|
||||
"""Helper to create a service."""
|
||||
service_repo = ServiceRepository(conn)
|
||||
service_id = uuid4()
|
||||
await service_repo.create(service_id, org_id, f"Service {slug}", slug)
|
||||
return service_id
|
||||
|
||||
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
|
||||
"""Helper to create a user."""
|
||||
user_repo = UserRepository(conn)
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, email, "hash")
|
||||
return user_id
|
||||
|
||||
async def test_create_incident_returns_incident_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating an incident returns the incident data with triggered status."""
|
||||
org_id = await self._create_org(db_conn, "incident-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "incident-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
incident_id = uuid4()
|
||||
|
||||
result = await repo.create(
|
||||
incident_id, org_id, service_id,
|
||||
title="Server Down",
|
||||
description="Main API server is not responding",
|
||||
severity="critical"
|
||||
)
|
||||
|
||||
assert result["id"] == incident_id
|
||||
assert result["org_id"] == org_id
|
||||
assert result["service_id"] == service_id
|
||||
assert result["title"] == "Server Down"
|
||||
assert result["description"] == "Main API server is not responding"
|
||||
assert result["status"] == "triggered" # Initial status per SPECS.md
|
||||
assert result["severity"] == "critical"
|
||||
assert result["version"] == 1
|
||||
assert result["created_at"] is not None
|
||||
assert result["updated_at"] is not None
|
||||
|
||||
async def test_create_incident_initial_status_is_triggered(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""New incidents always start with 'triggered' status per SPECS.md state machine."""
|
||||
org_id = await self._create_org(db_conn, "triggered-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "triggered-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
result = await repo.create(uuid4(), org_id, service_id, "Test", None, "low")
|
||||
|
||||
assert result["status"] == "triggered"
|
||||
|
||||
async def test_create_incident_initial_version_is_one(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""New incidents start with version 1 for optimistic locking."""
|
||||
org_id = await self._create_org(db_conn, "version-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "version-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
result = await repo.create(uuid4(), org_id, service_id, "Test", None, "medium")
|
||||
|
||||
assert result["version"] == 1
|
||||
|
||||
async def test_create_incident_severity_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Severity must be critical, high, medium, or low per SPECS.md."""
|
||||
org_id = await self._create_org(db_conn, "severity-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "severity-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
# Valid severities
|
||||
for severity in ["critical", "high", "medium", "low"]:
|
||||
result = await repo.create(uuid4(), org_id, service_id, f"Test {severity}", None, severity)
|
||||
assert result["severity"] == severity
|
||||
|
||||
# Invalid severity
|
||||
with pytest.raises(asyncpg.CheckViolationError):
|
||||
await repo.create(uuid4(), org_id, service_id, "Invalid", None, "extreme")
|
||||
|
||||
async def test_get_by_id_returns_incident(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns the correct incident."""
|
||||
org_id = await self._create_org(db_conn, "getbyid-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "getbyid-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
incident_id = uuid4()
|
||||
|
||||
await repo.create(incident_id, org_id, service_id, "My Incident", "Details", "high")
|
||||
result = await repo.get_by_id(incident_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == incident_id
|
||||
assert result["title"] == "My Incident"
|
||||
|
||||
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns None for non-existent incident."""
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_org_returns_org_incidents(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org returns incidents for the organization."""
|
||||
org_id = await self._create_org(db_conn, "list-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "list-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org_id, service_id, "Incident 1", None, "low")
|
||||
await repo.create(uuid4(), org_id, service_id, "Incident 2", None, "medium")
|
||||
await repo.create(uuid4(), org_id, service_id, "Incident 3", None, "high")
|
||||
|
||||
result = await repo.get_by_org(org_id)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
async def test_get_by_org_filters_by_status(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org can filter by status."""
|
||||
org_id = await self._create_org(db_conn, "filter-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "filter-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
# Create incidents and transition some
|
||||
inc1 = uuid4()
|
||||
inc2 = uuid4()
|
||||
await repo.create(inc1, org_id, service_id, "Triggered", None, "low")
|
||||
await repo.create(inc2, org_id, service_id, "Will be Acked", None, "low")
|
||||
await repo.update_status(inc2, "acknowledged", 1)
|
||||
|
||||
result = await repo.get_by_org(org_id, status="triggered")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "Triggered"
|
||||
|
||||
async def test_get_by_org_pagination_with_cursor(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org supports cursor-based pagination."""
|
||||
org_id = await self._create_org(db_conn, "pagination-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "pagination-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
# Create 5 incidents
|
||||
for i in range(5):
|
||||
await repo.create(uuid4(), org_id, service_id, f"Incident {i}", None, "low")
|
||||
|
||||
# Get first page - should return limit+1 to check for more
|
||||
page1 = await repo.get_by_org(org_id, limit=2)
|
||||
assert len(page1) == 3
|
||||
|
||||
# Verify total is 5 when we get all
|
||||
all_incidents = await repo.get_by_org(org_id, limit=10)
|
||||
assert len(all_incidents) == 5
|
||||
|
||||
async def test_get_by_org_orders_by_created_at_desc(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org returns incidents ordered by created_at descending."""
|
||||
org_id = await self._create_org(db_conn, "order-org")
|
||||
service_id = await self._create_service(db_conn, org_id, "order-service")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org_id, service_id, "First", None, "low")
|
||||
await repo.create(uuid4(), org_id, service_id, "Second", None, "low")
|
||||
await repo.create(uuid4(), org_id, service_id, "Third", None, "low")
|
||||
|
||||
result = await repo.get_by_org(org_id)
|
||||
|
||||
# Verify ordering - newer items should come first (or same time due to fast execution)
|
||||
assert len(result) == 3
|
||||
for i in range(len(result) - 1):
|
||||
assert result[i]["created_at"] >= result[i + 1]["created_at"]
|
||||
|
||||
async def test_get_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org only returns incidents for the specified org."""
|
||||
org1 = await self._create_org(db_conn, "tenant-org-1")
|
||||
org2 = await self._create_org(db_conn, "tenant-org-2")
|
||||
service1 = await self._create_service(db_conn, org1, "tenant-service-1")
|
||||
service2 = await self._create_service(db_conn, org2, "tenant-service-2")
|
||||
repo = IncidentRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org1, service1, "Org1 Incident", None, "low")
|
||||
await repo.create(uuid4(), org2, service2, "Org2 Incident", None, "low")
|
||||
|
||||
result = await repo.get_by_org(org1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "Org1 Incident"
|
||||
|
||||
|
||||
class TestIncidentStatusTransitions:
|
||||
"""Tests for incident status transitions per SPECS.md state machine."""
|
||||
|
||||
async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[uuid4, IncidentRepository]:
|
||||
"""Helper to create org, service, and incident."""
|
||||
org_repo = OrgRepository(conn)
|
||||
service_repo = ServiceRepository(conn)
|
||||
incident_repo = IncidentRepository(conn)
|
||||
|
||||
org_id = uuid4()
|
||||
service_id = uuid4()
|
||||
incident_id = uuid4()
|
||||
|
||||
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
|
||||
await service_repo.create(service_id, org_id, "Test Service", f"test-service-{uuid4().hex[:8]}")
|
||||
await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
|
||||
|
||||
return incident_id, incident_repo
|
||||
|
||||
async def test_update_status_increments_version(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_status increments version for optimistic locking."""
|
||||
incident_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
result = await repo.update_status(incident_id, "acknowledged", 1)
|
||||
|
||||
assert result is not None
|
||||
assert result["version"] == 2
|
||||
|
||||
async def test_update_status_fails_on_version_mismatch(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_status returns None on version mismatch (optimistic locking)."""
|
||||
incident_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
# Try with wrong version
|
||||
result = await repo.update_status(incident_id, "acknowledged", 999)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_update_status_updates_updated_at(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_status updates the updated_at timestamp."""
|
||||
incident_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
before = await repo.get_by_id(incident_id)
|
||||
result = await repo.update_status(incident_id, "acknowledged", 1)
|
||||
|
||||
# updated_at should be at least as recent as before (may be same in fast execution)
|
||||
assert result["updated_at"] >= before["updated_at"]
|
||||
# Also verify status was actually updated
|
||||
assert result["status"] == "acknowledged"
|
||||
|
||||
async def test_status_must_be_valid_value(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Status must be triggered, acknowledged, mitigated, or resolved per SPECS.md."""
|
||||
incident_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.CheckViolationError):
|
||||
await repo.update_status(incident_id, "invalid_status", 1)
|
||||
|
||||
async def test_valid_status_transitions(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Test the valid status values per SPECS.md."""
|
||||
incident_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
# Triggered -> Acknowledged
|
||||
result = await repo.update_status(incident_id, "acknowledged", 1)
|
||||
assert result["status"] == "acknowledged"
|
||||
|
||||
# Acknowledged -> Mitigated
|
||||
result = await repo.update_status(incident_id, "mitigated", 2)
|
||||
assert result["status"] == "mitigated"
|
||||
|
||||
# Mitigated -> Resolved
|
||||
result = await repo.update_status(incident_id, "resolved", 3)
|
||||
assert result["status"] == "resolved"
|
||||
|
||||
|
||||
class TestIncidentEvents:
|
||||
"""Tests for incident events (timeline) per SPECS.md incident_events table."""
|
||||
|
||||
async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[uuid4, uuid4, IncidentRepository]:
|
||||
"""Helper to create org, service, user, and incident."""
|
||||
org_repo = OrgRepository(conn)
|
||||
service_repo = ServiceRepository(conn)
|
||||
user_repo = UserRepository(conn)
|
||||
incident_repo = IncidentRepository(conn)
|
||||
|
||||
org_id = uuid4()
|
||||
service_id = uuid4()
|
||||
user_id = uuid4()
|
||||
incident_id = uuid4()
|
||||
|
||||
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
|
||||
await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}")
|
||||
await user_repo.create(user_id, f"user-{uuid4().hex[:8]}@example.com", "hash")
|
||||
await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
|
||||
|
||||
return incident_id, user_id, incident_repo
|
||||
|
||||
async def test_add_event_creates_event(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""add_event creates an event in the timeline."""
|
||||
incident_id, user_id, repo = await self._setup_incident(db_conn)
|
||||
event_id = uuid4()
|
||||
|
||||
result = await repo.add_event(
|
||||
event_id, incident_id, "status_changed",
|
||||
actor_user_id=user_id,
|
||||
payload={"from": "triggered", "to": "acknowledged"}
|
||||
)
|
||||
|
||||
assert result["id"] == event_id
|
||||
assert result["incident_id"] == incident_id
|
||||
assert result["event_type"] == "status_changed"
|
||||
assert result["actor_user_id"] == user_id
|
||||
assert result["payload"] == {"from": "triggered", "to": "acknowledged"}
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_add_event_allows_null_actor(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""add_event allows null actor_user_id (system events)."""
|
||||
incident_id, _, repo = await self._setup_incident(db_conn)
|
||||
|
||||
result = await repo.add_event(
|
||||
uuid4(), incident_id, "auto_escalated",
|
||||
actor_user_id=None,
|
||||
payload={"reason": "Unacknowledged after 30 minutes"}
|
||||
)
|
||||
|
||||
assert result["actor_user_id"] is None
|
||||
|
||||
async def test_add_event_allows_null_payload(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""add_event allows null payload."""
|
||||
incident_id, user_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
result = await repo.add_event(
|
||||
uuid4(), incident_id, "viewed",
|
||||
actor_user_id=user_id,
|
||||
payload=None
|
||||
)
|
||||
|
||||
assert result["payload"] is None
|
||||
|
||||
async def test_get_events_returns_all_incident_events(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_events returns all events for an incident."""
|
||||
incident_id, user_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
await repo.add_event(uuid4(), incident_id, "created", user_id, {"title": "Test"})
|
||||
await repo.add_event(uuid4(), incident_id, "status_changed", user_id, {"to": "acked"})
|
||||
await repo.add_event(uuid4(), incident_id, "comment_added", user_id, {"text": "Working on it"})
|
||||
|
||||
result = await repo.get_events(incident_id)
|
||||
|
||||
assert len(result) == 3
|
||||
event_types = [e["event_type"] for e in result]
|
||||
assert event_types == ["created", "status_changed", "comment_added"]
|
||||
|
||||
async def test_get_events_orders_by_created_at(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_events returns events in chronological order."""
|
||||
incident_id, user_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
await repo.add_event(uuid4(), incident_id, "first", user_id, None)
|
||||
await repo.add_event(uuid4(), incident_id, "second", user_id, None)
|
||||
await repo.add_event(uuid4(), incident_id, "third", user_id, None)
|
||||
|
||||
result = await repo.get_events(incident_id)
|
||||
|
||||
assert result[0]["event_type"] == "first"
|
||||
assert result[1]["event_type"] == "second"
|
||||
assert result[2]["event_type"] == "third"
|
||||
|
||||
async def test_get_events_returns_empty_for_no_events(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_events returns empty list for incident with no events."""
|
||||
incident_id, _, repo = await self._setup_incident(db_conn)
|
||||
|
||||
result = await repo.get_events(incident_id)
|
||||
|
||||
assert result == []
|
||||
|
||||
async def test_event_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""incident_events.incident_id must reference existing incident."""
|
||||
incident_id, user_id, repo = await self._setup_incident(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.add_event(uuid4(), uuid4(), "test", user_id, None)
|
||||
|
||||
async def test_event_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""incident_events.actor_user_id must reference existing user if not null."""
|
||||
incident_id, _, repo = await self._setup_incident(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.add_event(uuid4(), incident_id, "test", uuid4(), None)
|
||||
362
tests/repositories/test_notification.py
Normal file
362
tests/repositories/test_notification.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""Tests for NotificationRepository."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.incident import IncidentRepository
|
||||
from app.repositories.notification import NotificationRepository
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
|
||||
|
||||
class TestNotificationTargetRepository:
|
||||
"""Tests for notification targets per SPECS.md notification_targets table."""
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
|
||||
"""Helper to create an org."""
|
||||
org_repo = OrgRepository(conn)
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, f"Org {slug}", slug)
|
||||
return org_id
|
||||
|
||||
async def test_create_target_returns_target_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating a notification target returns the target data."""
|
||||
org_id = await self._create_org(db_conn, "target-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
target_id = uuid4()
|
||||
|
||||
result = await repo.create_target(
|
||||
target_id, org_id, "Slack Alerts",
|
||||
target_type="webhook",
|
||||
webhook_url="https://hooks.slack.com/services/xxx",
|
||||
enabled=True
|
||||
)
|
||||
|
||||
assert result["id"] == target_id
|
||||
assert result["org_id"] == org_id
|
||||
assert result["name"] == "Slack Alerts"
|
||||
assert result["target_type"] == "webhook"
|
||||
assert result["webhook_url"] == "https://hooks.slack.com/services/xxx"
|
||||
assert result["enabled"] is True
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_create_target_type_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Target type must be webhook, email, or slack per SPECS.md."""
|
||||
org_id = await self._create_org(db_conn, "type-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
# Valid types
|
||||
for target_type in ["webhook", "email", "slack"]:
|
||||
result = await repo.create_target(
|
||||
uuid4(), org_id, f"{target_type} target",
|
||||
target_type=target_type
|
||||
)
|
||||
assert result["target_type"] == target_type
|
||||
|
||||
# Invalid type
|
||||
with pytest.raises(asyncpg.CheckViolationError):
|
||||
await repo.create_target(
|
||||
uuid4(), org_id, "Invalid",
|
||||
target_type="sms"
|
||||
)
|
||||
|
||||
async def test_create_target_webhook_url_optional(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""webhook_url is optional (for email/slack types)."""
|
||||
org_id = await self._create_org(db_conn, "optional-url-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
result = await repo.create_target(
|
||||
uuid4(), org_id, "Email Alerts",
|
||||
target_type="email",
|
||||
webhook_url=None
|
||||
)
|
||||
|
||||
assert result["webhook_url"] is None
|
||||
|
||||
async def test_create_target_enabled_defaults_to_true(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""enabled defaults to True."""
|
||||
org_id = await self._create_org(db_conn, "default-enabled-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
result = await repo.create_target(
|
||||
uuid4(), org_id, "Default Enabled",
|
||||
target_type="webhook"
|
||||
)
|
||||
|
||||
assert result["enabled"] is True
|
||||
|
||||
async def test_get_target_by_id_returns_target(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_target_by_id returns the correct target."""
|
||||
org_id = await self._create_org(db_conn, "getbyid-target-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
target_id = uuid4()
|
||||
|
||||
await repo.create_target(target_id, org_id, "My Target", "webhook")
|
||||
result = await repo.get_target_by_id(target_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == target_id
|
||||
assert result["name"] == "My Target"
|
||||
|
||||
async def test_get_target_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_target_by_id returns None for non-existent target."""
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
result = await repo.get_target_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_targets_by_org_returns_all_targets(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_targets_by_org returns all targets for an organization."""
|
||||
org_id = await self._create_org(db_conn, "multi-target-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
await repo.create_target(uuid4(), org_id, "Target A", "webhook")
|
||||
await repo.create_target(uuid4(), org_id, "Target B", "email")
|
||||
await repo.create_target(uuid4(), org_id, "Target C", "slack")
|
||||
|
||||
result = await repo.get_targets_by_org(org_id)
|
||||
|
||||
assert len(result) == 3
|
||||
names = {t["name"] for t in result}
|
||||
assert names == {"Target A", "Target B", "Target C"}
|
||||
|
||||
async def test_get_targets_by_org_filters_enabled(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_targets_by_org can filter to only enabled targets."""
|
||||
org_id = await self._create_org(db_conn, "enabled-filter-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
await repo.create_target(uuid4(), org_id, "Enabled", "webhook", enabled=True)
|
||||
await repo.create_target(uuid4(), org_id, "Disabled", "webhook", enabled=False)
|
||||
|
||||
result = await repo.get_targets_by_org(org_id, enabled_only=True)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Enabled"
|
||||
|
||||
async def test_get_targets_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_targets_by_org only returns targets for the specified org."""
|
||||
org1 = await self._create_org(db_conn, "isolated-target-org-1")
|
||||
org2 = await self._create_org(db_conn, "isolated-target-org-2")
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
await repo.create_target(uuid4(), org1, "Org1 Target", "webhook")
|
||||
await repo.create_target(uuid4(), org2, "Org2 Target", "webhook")
|
||||
|
||||
result = await repo.get_targets_by_org(org1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Org1 Target"
|
||||
|
||||
async def test_update_target_updates_fields(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_target updates the specified fields."""
|
||||
org_id = await self._create_org(db_conn, "update-target-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
target_id = uuid4()
|
||||
|
||||
await repo.create_target(target_id, org_id, "Original", "webhook", enabled=True)
|
||||
result = await repo.update_target(target_id, name="Updated", enabled=False)
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "Updated"
|
||||
assert result["enabled"] is False
|
||||
|
||||
async def test_update_target_partial_update(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_target only updates provided fields."""
|
||||
org_id = await self._create_org(db_conn, "partial-update-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
target_id = uuid4()
|
||||
|
||||
await repo.create_target(
|
||||
target_id, org_id, "Original Name", "webhook",
|
||||
webhook_url="https://original.com", enabled=True
|
||||
)
|
||||
result = await repo.update_target(target_id, name="New Name")
|
||||
|
||||
assert result["name"] == "New Name"
|
||||
assert result["webhook_url"] == "https://original.com"
|
||||
assert result["enabled"] is True
|
||||
|
||||
async def test_delete_target_removes_target(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""delete_target removes the target."""
|
||||
org_id = await self._create_org(db_conn, "delete-target-org")
|
||||
repo = NotificationRepository(db_conn)
|
||||
target_id = uuid4()
|
||||
|
||||
await repo.create_target(target_id, org_id, "To Delete", "webhook")
|
||||
result = await repo.delete_target(target_id)
|
||||
|
||||
assert result is True
|
||||
assert await repo.get_target_by_id(target_id) is None
|
||||
|
||||
async def test_delete_target_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""delete_target returns False for non-existent target."""
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
result = await repo.delete_target(uuid4())
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_target_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""notification_targets.org_id must reference existing org."""
|
||||
repo = NotificationRepository(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.create_target(uuid4(), uuid4(), "Orphan Target", "webhook")
|
||||
|
||||
|
||||
class TestNotificationAttemptRepository:
|
||||
"""Tests for notification attempts per SPECS.md notification_attempts table."""
|
||||
|
||||
async def _setup_incident_and_target(self, conn: asyncpg.Connection) -> tuple[uuid4, uuid4, NotificationRepository]:
|
||||
"""Helper to create org, service, incident, and notification target."""
|
||||
org_repo = OrgRepository(conn)
|
||||
service_repo = ServiceRepository(conn)
|
||||
incident_repo = IncidentRepository(conn)
|
||||
notification_repo = NotificationRepository(conn)
|
||||
|
||||
org_id = uuid4()
|
||||
service_id = uuid4()
|
||||
incident_id = uuid4()
|
||||
target_id = uuid4()
|
||||
|
||||
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
|
||||
await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}")
|
||||
await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
|
||||
await notification_repo.create_target(target_id, org_id, "Test Target", "webhook")
|
||||
|
||||
return incident_id, target_id, notification_repo
|
||||
|
||||
async def test_create_attempt_returns_attempt_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating a notification attempt returns the attempt data."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
attempt_id = uuid4()
|
||||
|
||||
result = await repo.create_attempt(attempt_id, incident_id, target_id)
|
||||
|
||||
assert result["id"] == attempt_id
|
||||
assert result["incident_id"] == incident_id
|
||||
assert result["target_id"] == target_id
|
||||
assert result["status"] == "pending"
|
||||
assert result["error"] is None
|
||||
assert result["sent_at"] is None
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_create_attempt_idempotent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""create_attempt is idempotent per SPECS.md (unique constraint on incident+target)."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
# First attempt
|
||||
result1 = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
# Second attempt with same incident+target
|
||||
result2 = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
|
||||
# Should return the same attempt
|
||||
assert result1["id"] == result2["id"]
|
||||
|
||||
async def test_get_attempt_returns_attempt(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_attempt returns the attempt for incident and target."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
result = await repo.get_attempt(incident_id, target_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["incident_id"] == incident_id
|
||||
assert result["target_id"] == target_id
|
||||
|
||||
async def test_get_attempt_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_attempt returns None for non-existent attempt."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
result = await repo.get_attempt(incident_id, target_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_update_attempt_success_sets_sent_status(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_attempt_success marks attempt as sent."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
sent_at = datetime.now(UTC)
|
||||
|
||||
result = await repo.update_attempt_success(attempt["id"], sent_at)
|
||||
|
||||
assert result is not None
|
||||
assert result["status"] == "sent"
|
||||
assert result["sent_at"] is not None
|
||||
assert result["error"] is None
|
||||
|
||||
async def test_update_attempt_failure_sets_failed_status(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""update_attempt_failure marks attempt as failed with error."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
|
||||
result = await repo.update_attempt_failure(attempt["id"], "Connection timeout")
|
||||
|
||||
assert result is not None
|
||||
assert result["status"] == "failed"
|
||||
assert result["error"] == "Connection timeout"
|
||||
|
||||
async def test_attempt_status_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Attempt status must be pending, sent, or failed per SPECS.md."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
# Create with default 'pending' status - valid
|
||||
result = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
assert result["status"] == "pending"
|
||||
|
||||
# Transition to 'sent' - valid
|
||||
result = await repo.update_attempt_success(result["id"], datetime.now(UTC))
|
||||
assert result["status"] == "sent"
|
||||
|
||||
async def test_get_pending_attempts_returns_pending_with_target_info(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_pending_attempts returns pending attempts with target details."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
result = await repo.get_pending_attempts(incident_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["status"] == "pending"
|
||||
assert result[0]["target_id"] == target_id
|
||||
assert "target_type" in result[0]
|
||||
assert "target_name" in result[0]
|
||||
|
||||
async def test_get_pending_attempts_excludes_sent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_pending_attempts excludes sent attempts."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
await repo.update_attempt_success(attempt["id"], datetime.now(UTC))
|
||||
|
||||
result = await repo.get_pending_attempts(incident_id)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
async def test_get_pending_attempts_excludes_failed(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_pending_attempts excludes failed attempts."""
|
||||
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
|
||||
await repo.update_attempt_failure(attempt["id"], "Error")
|
||||
|
||||
result = await repo.get_pending_attempts(incident_id)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
async def test_attempt_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""notification_attempts.incident_id must reference existing incident."""
|
||||
_, target_id, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.create_attempt(uuid4(), uuid4(), target_id)
|
||||
|
||||
async def test_attempt_requires_valid_target_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""notification_attempts.target_id must reference existing target."""
|
||||
incident_id, _, repo = await self._setup_incident_and_target(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.create_attempt(uuid4(), incident_id, uuid4())
|
||||
250
tests/repositories/test_org.py
Normal file
250
tests/repositories/test_org.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for OrgRepository."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestOrgRepository:
|
||||
"""Tests for OrgRepository conforming to SPECS.md."""
|
||||
|
||||
async def test_create_org_returns_org_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating an org returns the org data."""
|
||||
repo = OrgRepository(db_conn)
|
||||
org_id = uuid4()
|
||||
name = "Test Organization"
|
||||
slug = "test-org"
|
||||
|
||||
result = await repo.create(org_id, name, slug)
|
||||
|
||||
assert result["id"] == org_id
|
||||
assert result["name"] == name
|
||||
assert result["slug"] == slug
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_create_org_slug_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Org slug uniqueness constraint per SPECS.md orgs table."""
|
||||
repo = OrgRepository(db_conn)
|
||||
slug = "unique-slug"
|
||||
|
||||
await repo.create(uuid4(), "Org One", slug)
|
||||
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.create(uuid4(), "Org Two", slug)
|
||||
|
||||
async def test_get_by_id_returns_org(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns the correct organization."""
|
||||
repo = OrgRepository(db_conn)
|
||||
org_id = uuid4()
|
||||
|
||||
await repo.create(org_id, "My Org", "my-org")
|
||||
result = await repo.get_by_id(org_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == org_id
|
||||
assert result["name"] == "My Org"
|
||||
|
||||
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns None for non-existent org."""
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_slug_returns_org(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_slug returns the correct organization."""
|
||||
repo = OrgRepository(db_conn)
|
||||
org_id = uuid4()
|
||||
slug = "slug-lookup"
|
||||
|
||||
await repo.create(org_id, "Slug Test", slug)
|
||||
result = await repo.get_by_slug(slug)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == org_id
|
||||
|
||||
async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_slug returns None for non-existent slug."""
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_slug("nonexistent-slug")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""slug_exists returns True when slug exists."""
|
||||
repo = OrgRepository(db_conn)
|
||||
slug = "existing-slug"
|
||||
|
||||
await repo.create(uuid4(), "Existing Org", slug)
|
||||
result = await repo.slug_exists(slug)
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""slug_exists returns False when slug doesn't exist."""
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.slug_exists("no-such-slug")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestOrgMembership:
|
||||
"""Tests for org membership operations per SPECS.md org_members table."""
|
||||
|
||||
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
|
||||
"""Helper to create a user."""
|
||||
user_repo = UserRepository(conn)
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, email, "hash")
|
||||
return user_id
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
|
||||
"""Helper to create an org."""
|
||||
org_repo = OrgRepository(conn)
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, f"Org {slug}", slug)
|
||||
return org_id
|
||||
|
||||
async def test_add_member_creates_membership(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""add_member creates a membership record."""
|
||||
user_id = await self._create_user(db_conn, "member@example.com")
|
||||
org_id = await self._create_org(db_conn, "member-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.add_member(uuid4(), user_id, org_id, "member")
|
||||
|
||||
assert result["user_id"] == user_id
|
||||
assert result["org_id"] == org_id
|
||||
assert result["role"] == "member"
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_add_member_role_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Role must be admin, member, or viewer per SPECS.md."""
|
||||
org_id = await self._create_org(db_conn, "role-test-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
# Valid roles should work
|
||||
for role in ["admin", "member", "viewer"]:
|
||||
member_id = uuid4()
|
||||
# Need a new user for each since user+org must be unique
|
||||
new_user_id = await self._create_user(db_conn, f"{role}@example.com")
|
||||
result = await repo.add_member(member_id, new_user_id, org_id, role)
|
||||
assert result["role"] == role
|
||||
|
||||
# Invalid role should fail
|
||||
another_user = await self._create_user(db_conn, "invalid_role@example.com")
|
||||
with pytest.raises(asyncpg.CheckViolationError):
|
||||
await repo.add_member(uuid4(), another_user, org_id, "superuser")
|
||||
|
||||
async def test_add_member_user_org_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""User can only be member of an org once (unique constraint)."""
|
||||
user_id = await self._create_user(db_conn, "unique_member@example.com")
|
||||
org_id = await self._create_org(db_conn, "unique-member-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
await repo.add_member(uuid4(), user_id, org_id, "member")
|
||||
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.add_member(uuid4(), user_id, org_id, "admin")
|
||||
|
||||
async def test_get_member_returns_membership(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_member returns the membership for user and org."""
|
||||
user_id = await self._create_user(db_conn, "get_member@example.com")
|
||||
org_id = await self._create_org(db_conn, "get-member-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
await repo.add_member(uuid4(), user_id, org_id, "admin")
|
||||
result = await repo.get_member(user_id, org_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["user_id"] == user_id
|
||||
assert result["org_id"] == org_id
|
||||
assert result["role"] == "admin"
|
||||
|
||||
async def test_get_member_returns_none_for_nonmember(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_member returns None if user is not a member."""
|
||||
user_id = await self._create_user(db_conn, "nonmember@example.com")
|
||||
org_id = await self._create_org(db_conn, "nonmember-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.get_member(user_id, org_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_members_returns_all_org_members(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_members returns all members with their emails."""
|
||||
org_id = await self._create_org(db_conn, "all-members-org")
|
||||
user1 = await self._create_user(db_conn, "user1@example.com")
|
||||
user2 = await self._create_user(db_conn, "user2@example.com")
|
||||
user3 = await self._create_user(db_conn, "user3@example.com")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
await repo.add_member(uuid4(), user1, org_id, "admin")
|
||||
await repo.add_member(uuid4(), user2, org_id, "member")
|
||||
await repo.add_member(uuid4(), user3, org_id, "viewer")
|
||||
|
||||
result = await repo.get_members(org_id)
|
||||
|
||||
assert len(result) == 3
|
||||
emails = {m["email"] for m in result}
|
||||
assert emails == {"user1@example.com", "user2@example.com", "user3@example.com"}
|
||||
|
||||
async def test_get_members_returns_empty_list_for_no_members(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_members returns empty list for org with no members."""
|
||||
org_id = await self._create_org(db_conn, "empty-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.get_members(org_id)
|
||||
|
||||
assert result == []
|
||||
|
||||
async def test_get_user_orgs_returns_all_user_memberships(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_user_orgs returns all orgs a user belongs to with their role."""
|
||||
user_id = await self._create_user(db_conn, "multi_org@example.com")
|
||||
org1 = await self._create_org(db_conn, "user-org-1")
|
||||
org2 = await self._create_org(db_conn, "user-org-2")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
await repo.add_member(uuid4(), user_id, org1, "admin")
|
||||
await repo.add_member(uuid4(), user_id, org2, "member")
|
||||
|
||||
result = await repo.get_user_orgs(user_id)
|
||||
|
||||
assert len(result) == 2
|
||||
slugs = {o["slug"] for o in result}
|
||||
assert slugs == {"user-org-1", "user-org-2"}
|
||||
# Check role is included
|
||||
roles = {o["role"] for o in result}
|
||||
assert roles == {"admin", "member"}
|
||||
|
||||
async def test_get_user_orgs_returns_empty_for_no_memberships(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_user_orgs returns empty list for user with no memberships."""
|
||||
user_id = await self._create_user(db_conn, "no_orgs@example.com")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
result = await repo.get_user_orgs(user_id)
|
||||
|
||||
assert result == []
|
||||
|
||||
async def test_member_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""org_members.user_id must reference existing user."""
|
||||
org_id = await self._create_org(db_conn, "fk-test-org")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.add_member(uuid4(), uuid4(), org_id, "member")
|
||||
|
||||
async def test_member_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""org_members.org_id must reference existing org."""
|
||||
user_id = await self._create_user(db_conn, "fk_user@example.com")
|
||||
repo = OrgRepository(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.add_member(uuid4(), user_id, uuid4(), "member")
|
||||
788
tests/repositories/test_refresh_token.py
Normal file
788
tests/repositories/test_refresh_token.py
Normal file
@@ -0,0 +1,788 @@
|
||||
"""Tests for RefreshTokenRepository with security features."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.refresh_token import RefreshTokenRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestRefreshTokenRepository:
|
||||
"""Tests for basic RefreshTokenRepository operations."""
|
||||
|
||||
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
|
||||
"""Helper to create a user."""
|
||||
user_repo = UserRepository(conn)
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, email, "hash")
|
||||
return user_id
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
|
||||
"""Helper to create an org."""
|
||||
org_repo = OrgRepository(conn)
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, f"Org {slug}", slug)
|
||||
return org_id
|
||||
|
||||
async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating a refresh token returns the token data including rotated_to."""
|
||||
user_id = await self._create_user(db_conn, "token_create@example.com")
|
||||
org_id = await self._create_org(db_conn, "token-create-org")
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
token_id = uuid4()
|
||||
token_hash = "sha256_hashed_token_value"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
assert result["id"] == token_id
|
||||
assert result["user_id"] == user_id
|
||||
assert result["token_hash"] == token_hash
|
||||
assert result["active_org_id"] == org_id
|
||||
assert result["expires_at"] is not None
|
||||
assert result["revoked_at"] is None
|
||||
assert result["rotated_to"] is None # New field
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
|
||||
user_id = await self._create_user(db_conn, "unique_hash@example.com")
|
||||
org_id = await self._create_org(db_conn, "unique-hash-org")
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
token_hash = "duplicate_hash_value"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_hash returns the correct token (even if revoked/expired)."""
|
||||
user_id = await self._create_user(db_conn, "get_hash@example.com")
|
||||
org_id = await self._create_org(db_conn, "get-hash-org")
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
token_id = uuid4()
|
||||
token_hash = "lookup_by_hash_value"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
await repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
result = await repo.get_by_hash(token_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == token_id
|
||||
assert result["token_hash"] == token_hash
|
||||
|
||||
async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_hash returns None for non-existent hash."""
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_hash("nonexistent_hash")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetValidByHash:
|
||||
"""Tests for get_valid_by_hash with defense-in-depth validation."""
|
||||
|
||||
async def _setup_token(
|
||||
self, conn: asyncpg.Connection, suffix: str = ""
|
||||
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"
|
||||
|
||||
await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, user_id, org_id, token_hash, token_repo
|
||||
|
||||
async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns token if not expired, not revoked, not rotated."""
|
||||
_, _, _, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["token_hash"] == token_hash
|
||||
|
||||
async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for expired token."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "expired@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "expired-org")
|
||||
|
||||
token_hash = "expired_token_hash"
|
||||
expires_at = datetime.now(UTC) - timedelta(days=1) # Already expired
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for revoked token."""
|
||||
token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")
|
||||
|
||||
await repo.revoke(token_id)
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for already-rotated token."""
|
||||
_, user_id, org_id, old_hash, repo = await self._setup_token(db_conn, "_rotated")
|
||||
|
||||
# Rotate the token
|
||||
new_hash = f"new_token_{uuid4().hex[:8]}"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=30)
|
||||
await repo.rotate(old_hash, uuid4(), new_hash, new_expires)
|
||||
|
||||
# Old token should no longer be valid
|
||||
result = await repo.get_valid_by_hash(old_hash)
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash validates user_id when provided (defense-in-depth)."""
|
||||
_, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check")
|
||||
|
||||
# Correct user_id should work
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=user_id)
|
||||
assert result is not None
|
||||
|
||||
# Wrong user_id should return None
|
||||
wrong_user_id = uuid4()
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=wrong_user_id)
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
|
||||
_, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check")
|
||||
|
||||
# Correct org_id should work
|
||||
result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id)
|
||||
assert result is not None
|
||||
|
||||
# Wrong org_id should return None
|
||||
wrong_org_id = uuid4()
|
||||
result = await repo.get_valid_by_hash(token_hash, active_org_id=wrong_org_id)
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash validates both user_id and active_org_id together."""
|
||||
_, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both")
|
||||
|
||||
# Both correct should work
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id)
|
||||
assert result is not None
|
||||
|
||||
# Either wrong should fail
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id)
|
||||
assert result is None
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4())
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAtomicRotation:
|
||||
"""Tests for atomic token rotation per SPECS.md."""
|
||||
|
||||
async def _setup_token(
|
||||
self, conn: asyncpg.Connection
|
||||
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"rotate_token_{uuid4().hex[:8]}"
|
||||
|
||||
await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, user_id, org_id, token_hash, token_repo
|
||||
|
||||
async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() creates a new token and returns it."""
|
||||
old_id, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
new_id = uuid4()
|
||||
new_hash = f"new_rotated_{uuid4().hex[:8]}"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
result = await repo.rotate(old_hash, new_id, new_hash, new_expires)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == new_id
|
||||
assert result["token_hash"] == new_hash
|
||||
assert result["user_id"] == user_id
|
||||
assert result["active_org_id"] == org_id
|
||||
|
||||
async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() sets rotated_to on the old token (not revoked_at)."""
|
||||
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
new_id = uuid4()
|
||||
new_hash = f"new_{uuid4().hex[:8]}"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
await repo.rotate(old_hash, new_id, new_hash, new_expires)
|
||||
|
||||
# Check old token state
|
||||
old_token = await repo.get_by_hash(old_hash)
|
||||
assert old_token["rotated_to"] == new_id
|
||||
assert old_token["revoked_at"] is None # Not revoked, just rotated
|
||||
|
||||
async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token is invalid."""
|
||||
_, _, _, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.rotate(
|
||||
"nonexistent_hash",
|
||||
uuid4(),
|
||||
f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token is expired."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "exp_rotate@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "exp-rotate-org")
|
||||
|
||||
old_hash = "expired_for_rotation"
|
||||
await repo.create(
|
||||
uuid4(), user_id, old_hash, org_id,
|
||||
datetime.now(UTC) - timedelta(days=1) # Already expired
|
||||
)
|
||||
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token is revoked."""
|
||||
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
await repo.revoke(old_id)
|
||||
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token was already rotated."""
|
||||
_, _, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
# First rotation should succeed
|
||||
result1 = await repo.rotate(
|
||||
old_hash, uuid4(), f"new1_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# Second rotation of same token should fail
|
||||
result2 = await repo.rotate(
|
||||
old_hash, uuid4(), f"new2_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() can change active_org_id (for org-switch flow)."""
|
||||
_, user_id, old_org_id, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
# Create a new org for the user to switch to
|
||||
org_repo = OrgRepository(db_conn)
|
||||
new_org_id = uuid4()
|
||||
await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")
|
||||
|
||||
new_hash = f"switched_{uuid4().hex[:8]}"
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), new_hash,
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
new_active_org_id=new_org_id # Switch org
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["active_org_id"] == new_org_id
|
||||
assert result["active_org_id"] != old_org_id
|
||||
|
||||
async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() fails if expected_user_id doesn't match token's user."""
|
||||
_, user_id, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
# Wrong user should fail
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
expected_user_id=uuid4() # Wrong user
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Correct user should work
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
expected_user_id=user_id # Correct user
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestTokenReuseDetection:
|
||||
"""Tests for detecting token reuse (stolen token attacks)."""
|
||||
|
||||
async def _setup_rotated_token(
|
||||
self, conn: asyncpg.Connection
|
||||
) -> tuple[uuid4, str, str, RefreshTokenRepository]:
|
||||
"""Create a token and rotate it, returning old and new hashes."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")
|
||||
|
||||
old_hash = f"old_token_{uuid4().hex[:8]}"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at)
|
||||
|
||||
new_hash = f"new_token_{uuid4().hex[:8]}"
|
||||
await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at)
|
||||
|
||||
return old_token["id"], old_hash, new_hash, token_repo
|
||||
|
||||
async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""check_token_reuse returns token if it has been rotated."""
|
||||
old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)
|
||||
|
||||
result = await repo.check_token_reuse(old_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == old_id
|
||||
assert result["rotated_to"] is not None
|
||||
|
||||
async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""check_token_reuse returns None for token that hasn't been rotated."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "active@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "active-org")
|
||||
|
||||
token_hash = "active_token_hash"
|
||||
await repo.create(
|
||||
uuid4(), user_id, token_hash, org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
result = await repo.check_token_reuse(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""check_token_reuse returns None for non-existent token."""
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
result = await repo.check_token_reuse("nonexistent_hash")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTokenChainRevocation:
|
||||
"""Tests for revoking entire token chains (breach response)."""
|
||||
|
||||
async def _setup_token_chain(
|
||||
self, conn: asyncpg.Connection, chain_length: int = 3
|
||||
) -> tuple[list[uuid4], list[str], uuid4, RefreshTokenRepository]:
|
||||
"""Create a chain of rotated tokens."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")
|
||||
|
||||
token_ids = []
|
||||
token_hashes = []
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
# Create first token
|
||||
first_hash = f"chain_token_0_{uuid4().hex[:8]}"
|
||||
first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at)
|
||||
token_ids.append(first_token["id"])
|
||||
token_hashes.append(first_hash)
|
||||
|
||||
# Rotate to create chain
|
||||
current_hash = first_hash
|
||||
for i in range(1, chain_length):
|
||||
new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
|
||||
new_id = uuid4()
|
||||
await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
|
||||
token_ids.append(new_id)
|
||||
token_hashes.append(new_hash)
|
||||
current_hash = new_hash
|
||||
|
||||
return token_ids, token_hashes, user_id, token_repo
|
||||
|
||||
async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_token_chain revokes the token and all its rotations."""
|
||||
token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)
|
||||
|
||||
# Revoke starting from the first token
|
||||
count = await repo.revoke_token_chain(token_ids[0])
|
||||
|
||||
# Should revoke all 3 tokens in the chain
|
||||
# But note: only the last one wasn't already "consumed" by rotation
|
||||
# Let's check that revoke was called on all that were eligible
|
||||
assert count >= 1 # At least the leaf token
|
||||
|
||||
# Verify the leaf token is revoked
|
||||
leaf_token = await repo.get_by_hash(token_hashes[-1])
|
||||
assert leaf_token["revoked_at"] is not None
|
||||
|
||||
async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_token_chain returns count of actually revoked tokens."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "single@example.com", "hash")
|
||||
await org_repo.create(org_id, "Single Org", "single-org")
|
||||
|
||||
token_hash = "single_chain_token"
|
||||
token = await repo.create(
|
||||
uuid4(), user_id, token_hash, org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
count = await repo.revoke_token_chain(token["id"])
|
||||
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestTokenRevocation:
|
||||
"""Tests for token revocation methods."""
|
||||
|
||||
async def _setup_token(self, conn: asyncpg.Connection) -> tuple[uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"revoke_token_{uuid4().hex[:8]}"
|
||||
|
||||
await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, token_hash, token_repo
|
||||
|
||||
async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() sets the revoked_at timestamp."""
|
||||
token_id, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke(token_id)
|
||||
|
||||
assert result is True
|
||||
token = await repo.get_by_hash(token_hash)
|
||||
assert token["revoked_at"] is not None
|
||||
|
||||
async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() returns True when token is revoked."""
|
||||
token_id, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke(token_id)
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() returns False if token already revoked."""
|
||||
token_id, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
await repo.revoke(token_id)
|
||||
result = await repo.revoke(token_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() returns False for non-existent token."""
|
||||
_, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke(uuid4())
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_by_hash() revokes token by hash value."""
|
||||
_, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke_by_hash(token_hash)
|
||||
|
||||
assert result is True
|
||||
token = await repo.get_by_hash(token_hash)
|
||||
assert token["revoked_at"] is not None
|
||||
|
||||
async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_by_hash() returns False for non-existent hash."""
|
||||
_, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke_by_hash("nonexistent_hash")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRevokeAllForUser:
|
||||
"""Tests for revoking all tokens for a user."""
|
||||
|
||||
async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user() revokes all tokens for the user."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "multi_token@example.com", "hash")
|
||||
await org_repo.create(org_id, "Multi Token Org", "multi-token-org")
|
||||
|
||||
# Create multiple tokens
|
||||
hashes = []
|
||||
for i in range(3):
|
||||
token_hash = f"token_{i}_{uuid4().hex[:8]}"
|
||||
hashes.append(token_hash)
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
result = await token_repo.revoke_all_for_user(user_id)
|
||||
|
||||
assert result == 3
|
||||
for token_hash in hashes:
|
||||
token = await token_repo.get_valid_by_hash(token_hash)
|
||||
assert token is None
|
||||
|
||||
async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user() returns 0 if user has no tokens."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, "no_tokens@example.com", "hash")
|
||||
|
||||
result = await token_repo.revoke_all_for_user(user_id)
|
||||
|
||||
assert result == 0
|
||||
|
||||
async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user() doesn't affect other users' tokens."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user1 = uuid4()
|
||||
user2 = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user1, "user1@example.com", "hash")
|
||||
await user_repo.create(user2, "user2@example.com", "hash")
|
||||
await org_repo.create(org_id, "Shared Org", "shared-org")
|
||||
|
||||
user1_hash = f"user1_token_{uuid4().hex[:8]}"
|
||||
user2_hash = f"user2_token_{uuid4().hex[:8]}"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
|
||||
await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)
|
||||
|
||||
await token_repo.revoke_all_for_user(user1)
|
||||
|
||||
# User1's token is revoked
|
||||
assert await token_repo.get_valid_by_hash(user1_hash) is None
|
||||
# User2's token is still valid
|
||||
assert await token_repo.get_valid_by_hash(user2_hash) is not None
|
||||
|
||||
|
||||
class TestRevokeAllExcept:
|
||||
"""Tests for revoking all tokens except current session."""
|
||||
|
||||
async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user_except() keeps the specified token active."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "except@example.com", "hash")
|
||||
await org_repo.create(org_id, "Except Org", "except-org")
|
||||
|
||||
# Create multiple tokens
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
keep_token_id = uuid4()
|
||||
keep_hash = f"keep_token_{uuid4().hex[:8]}"
|
||||
await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at)
|
||||
|
||||
other_hashes = []
|
||||
for i in range(2):
|
||||
other_hash = f"other_token_{i}_{uuid4().hex[:8]}"
|
||||
other_hashes.append(other_hash)
|
||||
await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at)
|
||||
|
||||
result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id)
|
||||
|
||||
assert result == 2 # Revoked 2 other tokens
|
||||
|
||||
# Keep token is still valid
|
||||
assert await token_repo.get_valid_by_hash(keep_hash) is not None
|
||||
|
||||
# Other tokens are revoked
|
||||
for other_hash in other_hashes:
|
||||
assert await token_repo.get_valid_by_hash(other_hash) is None
|
||||
|
||||
|
||||
class TestActiveTokensForUser:
|
||||
"""Tests for listing active tokens for a user."""
|
||||
|
||||
async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "active_list@example.com", "hash")
|
||||
await org_repo.create(org_id, "Active List Org", "active-list-org")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
expired_at = datetime.now(UTC) - timedelta(days=1)
|
||||
|
||||
# Create active token
|
||||
active_hash = f"active_{uuid4().hex[:8]}"
|
||||
await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at)
|
||||
|
||||
# Create revoked token
|
||||
revoked_id = uuid4()
|
||||
revoked_hash = f"revoked_{uuid4().hex[:8]}"
|
||||
await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at)
|
||||
await token_repo.revoke(revoked_id)
|
||||
|
||||
# Create expired token
|
||||
expired_hash = f"expired_{uuid4().hex[:8]}"
|
||||
await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at)
|
||||
|
||||
# Create rotated token
|
||||
rotated_hash = f"rotated_{uuid4().hex[:8]}"
|
||||
await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at)
|
||||
await token_repo.rotate(rotated_hash, uuid4(), f"new_{uuid4().hex[:8]}", expires_at)
|
||||
|
||||
result = await token_repo.get_active_tokens_for_user(user_id)
|
||||
|
||||
# Should only return the active token and the new rotated token
|
||||
assert len(result) == 2
|
||||
hashes = {t["token_hash"] for t in result}
|
||||
assert active_hash in hashes
|
||||
assert revoked_hash not in hashes
|
||||
assert expired_hash not in hashes
|
||||
assert rotated_hash not in hashes
|
||||
|
||||
|
||||
class TestTokenForeignKeys:
|
||||
"""Tests for refresh token foreign key constraints."""
|
||||
|
||||
async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""refresh_tokens.user_id must reference existing user."""
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, "FK Test Org", "fk-test-org")
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await token_repo.create(
|
||||
uuid4(), uuid4(), "orphan_token", org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""refresh_tokens.active_org_id must reference existing org."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, "fk_org_test@example.com", "hash")
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await token_repo.create(
|
||||
uuid4(), user_id, "orphan_org_token", uuid4(),
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Token stores active_org_id for org context per SPECS.md."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "active_org@example.com", "hash")
|
||||
await org_repo.create(org_id, "Active Org", "active-org")
|
||||
|
||||
token_hash = f"active_org_token_{uuid4().hex[:8]}"
|
||||
await token_repo.create(
|
||||
uuid4(), user_id, token_hash, org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
token = await token_repo.get_by_hash(token_hash)
|
||||
assert token["active_org_id"] == org_id
|
||||
201
tests/repositories/test_service.py
Normal file
201
tests/repositories/test_service.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Tests for ServiceRepository."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
|
||||
|
||||
class TestServiceRepository:
|
||||
"""Tests for ServiceRepository conforming to SPECS.md."""
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
|
||||
"""Helper to create an org."""
|
||||
org_repo = OrgRepository(conn)
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, f"Org {slug}", slug)
|
||||
return org_id
|
||||
|
||||
async def test_create_service_returns_service_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating a service returns the service data."""
|
||||
org_id = await self._create_org(db_conn, "service-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
service_id = uuid4()
|
||||
|
||||
result = await repo.create(service_id, org_id, "API Gateway", "api-gateway")
|
||||
|
||||
assert result["id"] == service_id
|
||||
assert result["org_id"] == org_id
|
||||
assert result["name"] == "API Gateway"
|
||||
assert result["slug"] == "api-gateway"
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_create_service_slug_unique_per_org(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Service slug must be unique within an org per SPECS.md."""
|
||||
org_id = await self._create_org(db_conn, "unique-slug-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org_id, "Service One", "my-service")
|
||||
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.create(uuid4(), org_id, "Service Two", "my-service")
|
||||
|
||||
async def test_same_slug_allowed_in_different_orgs(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Same slug can exist in different orgs."""
|
||||
org1 = await self._create_org(db_conn, "org-one")
|
||||
org2 = await self._create_org(db_conn, "org-two")
|
||||
repo = ServiceRepository(db_conn)
|
||||
slug = "shared-slug"
|
||||
|
||||
# Both should succeed
|
||||
result1 = await repo.create(uuid4(), org1, "Service Org1", slug)
|
||||
result2 = await repo.create(uuid4(), org2, "Service Org2", slug)
|
||||
|
||||
assert result1["slug"] == slug
|
||||
assert result2["slug"] == slug
|
||||
assert result1["org_id"] != result2["org_id"]
|
||||
|
||||
async def test_get_by_id_returns_service(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns the correct service."""
|
||||
org_id = await self._create_org(db_conn, "get-service-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
service_id = uuid4()
|
||||
|
||||
await repo.create(service_id, org_id, "My Service", "my-service")
|
||||
result = await repo.get_by_id(service_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == service_id
|
||||
assert result["name"] == "My Service"
|
||||
|
||||
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns None for non-existent service."""
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_org_returns_all_org_services(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org returns all services for an organization."""
|
||||
org_id = await self._create_org(db_conn, "multi-service-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org_id, "Service A", "service-a")
|
||||
await repo.create(uuid4(), org_id, "Service B", "service-b")
|
||||
await repo.create(uuid4(), org_id, "Service C", "service-c")
|
||||
|
||||
result = await repo.get_by_org(org_id)
|
||||
|
||||
assert len(result) == 3
|
||||
names = {s["name"] for s in result}
|
||||
assert names == {"Service A", "Service B", "Service C"}
|
||||
|
||||
async def test_get_by_org_returns_empty_for_no_services(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org returns empty list for org with no services."""
|
||||
org_id = await self._create_org(db_conn, "empty-service-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_org(org_id)
|
||||
|
||||
assert result == []
|
||||
|
||||
async def test_get_by_org_only_returns_own_services(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org doesn't return services from other orgs (tenant isolation)."""
|
||||
org1 = await self._create_org(db_conn, "isolated-org-1")
|
||||
org2 = await self._create_org(db_conn, "isolated-org-2")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org1, "Org1 Service", "org1-service")
|
||||
await repo.create(uuid4(), org2, "Org2 Service", "org2-service")
|
||||
|
||||
result = await repo.get_by_org(org1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Org1 Service"
|
||||
|
||||
async def test_get_by_slug_returns_service(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_slug returns service by org and slug."""
|
||||
org_id = await self._create_org(db_conn, "slug-lookup-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
service_id = uuid4()
|
||||
|
||||
await repo.create(service_id, org_id, "Slug Service", "slug-service")
|
||||
result = await repo.get_by_slug(org_id, "slug-service")
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == service_id
|
||||
|
||||
async def test_get_by_slug_returns_none_for_wrong_org(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_slug returns None if slug exists but in different org."""
|
||||
org1 = await self._create_org(db_conn, "slug-org-1")
|
||||
org2 = await self._create_org(db_conn, "slug-org-2")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org1, "Service", "the-slug")
|
||||
result = await repo.get_by_slug(org2, "the-slug")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_slug returns None for non-existent slug."""
|
||||
org_id = await self._create_org(db_conn, "no-slug-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_slug(org_id, "nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""slug_exists returns True when slug exists in org."""
|
||||
org_id = await self._create_org(db_conn, "exists-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org_id, "Exists Service", "exists-slug")
|
||||
result = await repo.slug_exists(org_id, "exists-slug")
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""slug_exists returns False when slug doesn't exist in org."""
|
||||
org_id = await self._create_org(db_conn, "not-exists-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
result = await repo.slug_exists(org_id, "no-such-slug")
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_slug_exists_returns_false_for_other_org(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""slug_exists returns False for slug in different org."""
|
||||
org1 = await self._create_org(db_conn, "other-org-1")
|
||||
org2 = await self._create_org(db_conn, "other-org-2")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org1, "Service", "cross-org-slug")
|
||||
result = await repo.slug_exists(org2, "cross-org-slug")
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_service_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""services.org_id must reference existing org."""
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await repo.create(uuid4(), uuid4(), "Orphan Service", "orphan")
|
||||
|
||||
async def test_get_by_org_orders_by_name(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_org returns services ordered by name."""
|
||||
org_id = await self._create_org(db_conn, "ordered-org")
|
||||
repo = ServiceRepository(db_conn)
|
||||
|
||||
await repo.create(uuid4(), org_id, "Zebra", "zebra")
|
||||
await repo.create(uuid4(), org_id, "Alpha", "alpha")
|
||||
await repo.create(uuid4(), org_id, "Middle", "middle")
|
||||
|
||||
result = await repo.get_by_org(org_id)
|
||||
|
||||
names = [s["name"] for s in result]
|
||||
assert names == ["Alpha", "Middle", "Zebra"]
|
||||
133
tests/repositories/test_user.py
Normal file
133
tests/repositories/test_user.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for UserRepository."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestUserRepository:
|
||||
"""Tests for UserRepository conforming to SPECS.md."""
|
||||
|
||||
async def test_create_user_returns_user_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating a user returns the user data with id, email, created_at."""
|
||||
repo = UserRepository(db_conn)
|
||||
user_id = uuid4()
|
||||
email = "test@example.com"
|
||||
password_hash = "hashed_password_123"
|
||||
|
||||
result = await repo.create(user_id, email, password_hash)
|
||||
|
||||
assert result["id"] == user_id
|
||||
assert result["email"] == email
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_create_user_stores_password_hash(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Password hash is stored correctly in the database."""
|
||||
repo = UserRepository(db_conn)
|
||||
user_id = uuid4()
|
||||
email = "hash_test@example.com"
|
||||
password_hash = "bcrypt_hashed_value"
|
||||
|
||||
await repo.create(user_id, email, password_hash)
|
||||
user = await repo.get_by_id(user_id)
|
||||
|
||||
assert user["password_hash"] == password_hash
|
||||
|
||||
async def test_create_user_email_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Email uniqueness constraint per SPECS.md users table."""
|
||||
repo = UserRepository(db_conn)
|
||||
email = "duplicate@example.com"
|
||||
|
||||
await repo.create(uuid4(), email, "hash1")
|
||||
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.create(uuid4(), email, "hash2")
|
||||
|
||||
async def test_get_by_id_returns_user(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns the correct user."""
|
||||
repo = UserRepository(db_conn)
|
||||
user_id = uuid4()
|
||||
email = "getbyid@example.com"
|
||||
|
||||
await repo.create(user_id, email, "hash")
|
||||
result = await repo.get_by_id(user_id)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == user_id
|
||||
assert result["email"] == email
|
||||
|
||||
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_id returns None for non-existent user."""
|
||||
repo = UserRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_email_returns_user(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_email returns the correct user."""
|
||||
repo = UserRepository(db_conn)
|
||||
user_id = uuid4()
|
||||
email = "getbyemail@example.com"
|
||||
|
||||
await repo.create(user_id, email, "hash")
|
||||
result = await repo.get_by_email(email)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == user_id
|
||||
assert result["email"] == email
|
||||
|
||||
async def test_get_by_email_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_email returns None for non-existent email."""
|
||||
repo = UserRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_email("nonexistent@example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_email_is_case_sensitive(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Email lookup is case-sensitive (stored as provided)."""
|
||||
repo = UserRepository(db_conn)
|
||||
email = "CaseSensitive@Example.com"
|
||||
|
||||
await repo.create(uuid4(), email, "hash")
|
||||
|
||||
# Exact match works
|
||||
result = await repo.get_by_email(email)
|
||||
assert result is not None
|
||||
|
||||
# Different case returns None
|
||||
result = await repo.get_by_email(email.lower())
|
||||
assert result is None
|
||||
|
||||
async def test_exists_by_email_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""exists_by_email returns True when email exists."""
|
||||
repo = UserRepository(db_conn)
|
||||
email = "exists@example.com"
|
||||
|
||||
await repo.create(uuid4(), email, "hash")
|
||||
result = await repo.exists_by_email(email)
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_exists_by_email_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""exists_by_email returns False when email doesn't exist."""
|
||||
repo = UserRepository(db_conn)
|
||||
|
||||
result = await repo.exists_by_email("notexists@example.com")
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_user_id_is_uuid_primary_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""User ID must be a valid UUID (primary key)."""
|
||||
repo = UserRepository(db_conn)
|
||||
user_id = uuid4()
|
||||
|
||||
await repo.create(user_id, "pk_test@example.com", "hash")
|
||||
|
||||
# Duplicate ID should fail
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.create(user_id, "other@example.com", "hash")
|
||||
Reference in New Issue
Block a user