feat(auth): implement auth stack

This commit is contained in:
2025-12-29 09:55:30 +00:00
parent 3170f10e86
commit ad94833830
13 changed files with 1199 additions and 11 deletions

65
tests/api/helpers.py Normal file
View File

@@ -0,0 +1,65 @@
"""Shared helpers for API integration tests."""
from __future__ import annotations
from typing import Any
from uuid import UUID, uuid4
import asyncpg
from httpx import AsyncClient
API_PREFIX = "/v1"
async def register_user(
client: AsyncClient,
*,
email: str,
password: str,
org_name: str = "Test Org",
) -> dict[str, Any]:
"""Call the register endpoint and return JSON body (raises on failure)."""
response = await client.post(
f"{API_PREFIX}/auth/register",
json={"email": email, "password": password, "org_name": org_name},
)
response.raise_for_status()
return response.json()
async def create_org(
conn: asyncpg.Connection,
*,
name: str,
slug: str | None = None,
) -> UUID:
"""Insert an organization row and return its ID."""
org_id = uuid4()
slug_value = slug or name.lower().replace(" ", "-")
await conn.execute(
"INSERT INTO orgs (id, name, slug) VALUES ($1, $2, $3)",
org_id,
name,
slug_value,
)
return org_id
async def add_membership(
conn: asyncpg.Connection,
*,
user_id: UUID,
org_id: UUID,
role: str,
) -> None:
"""Insert a membership record for the user/org pair."""
await conn.execute(
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
uuid4(),
user_id,
org_id,
role,
)

213
tests/api/test_auth.py Normal file
View File

@@ -0,0 +1,213 @@
"""Integration tests for FastAPI auth endpoints."""
from __future__ import annotations
from uuid import UUID
import asyncpg
import pytest
from httpx import AsyncClient
from app.core import security
from tests.api import helpers
pytestmark = pytest.mark.asyncio
API_PREFIX = "/v1/auth"
async def test_register_endpoint_persists_user_and_membership(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
data = await helpers.register_user(
api_client,
email="api-register@example.com",
password="SuperSecret1!",
org_name="API Org",
)
assert "access_token" in data and "refresh_token" in data
token_payload = security.decode_access_token(data["access_token"])
assert token_payload["org_role"] == "admin"
stored_user = await db_admin.fetchrow("SELECT email FROM users WHERE email = $1", "api-register@example.com")
assert stored_user is not None
membership = await db_admin.fetchrow(
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
UUID(token_payload["sub"]),
UUID(token_payload["org_id"]),
)
assert membership is not None and membership["role"] == "admin"
async def test_login_endpoint_rejects_bad_credentials(
api_client: AsyncClient,
) -> None:
register_payload = {
"email": "api-login@example.com",
"password": "CorrectHorse1!",
"org_name": "Login Org",
}
await helpers.register_user(api_client, **register_payload)
response = await api_client.post(
f"{API_PREFIX}/login",
json={"email": register_payload["email"], "password": "wrong"},
)
assert response.status_code == 401
async def test_refresh_endpoint_rotates_refresh_token(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
register_payload = {
"email": "api-refresh@example.com",
"password": "RefreshPass1!",
"org_name": "Refresh Org",
}
initial = await helpers.register_user(api_client, **register_payload)
response = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": initial["refresh_token"]},
)
assert response.status_code == 200
data = response.json()
assert data["refresh_token"] != initial["refresh_token"]
old_hash = security.hash_token(initial["refresh_token"])
old_row = await db_admin.fetchrow(
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["rotated_to"] is not None
async def test_refresh_endpoint_detects_reuse(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
tokens = await helpers.register_user(
api_client,
email="api-reuse@example.com",
password="ReusePass1!",
org_name="Reuse Org",
)
rotated = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert rotated.status_code == 200
reuse_response = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert reuse_response.status_code == 401
old_hash = security.hash_token(tokens["refresh_token"])
old_row = await db_admin.fetchrow(
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["revoked_at"] is not None
async def test_switch_org_changes_active_org(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
email = "api-switch@example.com"
register_payload = {
"email": email,
"password": "SwitchPass1!",
"org_name": "Primary Org",
}
tokens = await helpers.register_user(api_client, **register_payload)
user_id_row = await db_admin.fetchrow("SELECT id FROM users WHERE email = $1", email)
assert user_id_row is not None
user_id = user_id_row["id"]
target_org_id = await helpers.create_org(db_admin, name="Secondary Org", slug="secondary-org")
await helpers.add_membership(db_admin, user_id=user_id, org_id=target_org_id, role="member")
response = await api_client.post(
f"{API_PREFIX}/switch-org",
json={"org_id": str(target_org_id), "refresh_token": tokens["refresh_token"]},
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
data = response.json()
payload = security.decode_access_token(data["access_token"])
assert payload["org_id"] == str(target_org_id)
assert payload["org_role"] == "member"
new_hash = security.hash_token(data["refresh_token"])
new_row = await db_admin.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["active_org_id"] == target_org_id
async def test_switch_org_forbidden_without_membership(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
tokens = await helpers.register_user(
api_client,
email="api-switch-no-access@example.com",
password="SwitchBlock1!",
org_name="Primary",
)
foreign_org = await helpers.create_org(db_admin, name="Foreign Org", slug="foreign-org")
response = await api_client.post(
f"{API_PREFIX}/switch-org",
json={"org_id": str(foreign_org), "refresh_token": tokens["refresh_token"]},
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
# ensure refresh token still valid after failed attempt
retry = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert retry.status_code == 200
async def test_logout_revokes_refresh_token(
api_client: AsyncClient,
) -> None:
register_payload = {
"email": "api-logout@example.com",
"password": "LogoutPass1!",
"org_name": "Logout Org",
}
tokens = await helpers.register_user(api_client, **register_payload)
logout_response = await api_client.post(
f"{API_PREFIX}/logout",
json={"refresh_token": tokens["refresh_token"]},
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert logout_response.status_code == 204
refresh_response = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert refresh_response.status_code == 401

View File

@@ -3,9 +3,12 @@
from __future__ import annotations
import os
from uuid import uuid4
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable
from uuid import UUID, uuid4
import asyncpg
import httpx
import pytest
# Set test environment variables before importing app modules
@@ -13,6 +16,8 @@ os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@loca
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1")
from app.main import app
# Module-level setup: create database and run migrations once
_db_initialized = False
@@ -65,7 +70,7 @@ async def _init_test_db() -> None:
@pytest.fixture
async def db_conn() -> asyncpg.Connection:
async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]:
"""Get a database connection with transaction rollback for test isolation."""
await _init_test_db()
@@ -84,12 +89,77 @@ async def db_conn() -> asyncpg.Connection:
@pytest.fixture
def make_user_id() -> uuid4:
def make_user_id() -> Callable[[], UUID]:
"""Factory for generating user IDs."""
return lambda: uuid4()
@pytest.fixture
def make_org_id() -> uuid4:
def make_org_id() -> Callable[[], UUID]:
"""Factory for generating org IDs."""
return lambda: uuid4()
TABLES_TO_TRUNCATE = [
"incident_events",
"notification_attempts",
"incidents",
"notification_targets",
"services",
"refresh_tokens",
"org_members",
"orgs",
"users",
]
async def _truncate_all_tables() -> None:
test_dsn = os.environ["DATABASE_URL"]
conn = await asyncpg.connect(test_dsn)
try:
tables = ", ".join(TABLES_TO_TRUNCATE)
await conn.execute(f"TRUNCATE TABLE {tables} CASCADE")
finally:
await conn.close()
@pytest.fixture
async def clean_database() -> AsyncGenerator[None, None]:
"""Ensure the database is initialized and truncated before/after tests."""
await _init_test_db()
await _truncate_all_tables()
yield
await _truncate_all_tables()
@asynccontextmanager
async def _lifespan_manager() -> AsyncGenerator[None, None]:
lifespan = app.router.lifespan_context
if lifespan is None:
yield
else:
async with lifespan(app):
yield
@pytest.fixture
async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]:
"""HTTPX async client bound to the FastAPI app with lifespan support."""
async with _lifespan_manager():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
yield client
@pytest.fixture
async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]:
"""Plain connection for arranging/inspecting API test data (no rollback)."""
test_dsn = os.environ["DATABASE_URL"]
conn = await asyncpg.connect(test_dsn)
try:
yield conn
finally:
await conn.close()

80
tests/db/test_get_conn.py Normal file
View File

@@ -0,0 +1,80 @@
"""Tests for the get_conn dependency helper."""
from __future__ import annotations
import pytest
from app.db import db, get_conn
pytestmark = pytest.mark.asyncio
class _FakeConnection:
def __init__(self, idx: int) -> None:
self.idx = idx
class _AcquireContext:
def __init__(self, conn: _FakeConnection, tracker: "_FakePool") -> None:
self._conn = conn
self._tracker = tracker
async def __aenter__(self) -> _FakeConnection:
self._tracker.active += 1
return self._conn
async def __aexit__(self, exc_type, exc, tb) -> None:
self._tracker.active -= 1
class _FakePool:
def __init__(self) -> None:
self.acquire_calls = 0
self.active = 0
def acquire(self) -> _AcquireContext:
conn = _FakeConnection(self.acquire_calls)
self.acquire_calls += 1
return _AcquireContext(conn, self)
async def _collect_single_connection():
connection = None
async for conn in get_conn():
connection = conn
return connection
async def test_get_conn_reuses_connection_within_scope():
original_pool = db.pool
fake_pool = _FakePool()
db.pool = fake_pool
try:
captured: list[_FakeConnection] = []
async for outer in get_conn():
captured.append(outer)
async for inner in get_conn():
captured.append(inner)
assert len(captured) == 2
assert captured[0] is captured[1]
assert fake_pool.acquire_calls == 1
finally:
db.pool = original_pool
async def test_get_conn_acquires_new_connection_per_root_scope():
original_pool = db.pool
fake_pool = _FakePool()
db.pool = fake_pool
try:
first = await _collect_single_connection()
second = await _collect_single_connection()
assert first is not None and second is not None
assert first is not second
assert fake_pool.acquire_calls == 2
finally:
db.pool = original_pool

View File

@@ -0,0 +1,260 @@
"""Unit tests covering AuthService flows."""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import pytest
from app.api.deps import CurrentUser
from app.core import security
from app.db import Database
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
)
from app.services.auth import AuthService
pytestmark = pytest.mark.asyncio
class _SingleConnectionDatabase(Database):
"""Database stub that reuses a single asyncpg connection."""
def __init__(self, conn) -> None: # type: ignore[override]
self._conn = conn
@asynccontextmanager
async def connection(self): # type: ignore[override]
yield self._conn
@asynccontextmanager
async def transaction(self): # type: ignore[override]
tr = self._conn.transaction()
await tr.start()
try:
yield self._conn
except Exception:
await tr.rollback()
raise
else:
await tr.commit()
@pytest.fixture
async def auth_service(db_conn):
"""AuthService bound to the per-test database connection."""
return AuthService(database=_SingleConnectionDatabase(db_conn))
async def _create_user(conn, email: str, password: str) -> UUID:
user_id = uuid4()
password_hash = security.hash_password(password)
await conn.execute(
"INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3)",
user_id,
email,
password_hash,
)
return user_id
async def _create_org(
conn,
name: str,
slug: str | None = None,
*,
created_at: datetime | None = None,
) -> UUID:
org_id = uuid4()
slug_value = slug or f"{name.lower().replace(' ', '-')}-{org_id.hex[:6]}"
created = created_at or datetime.now(UTC)
await conn.execute(
"INSERT INTO orgs (id, name, slug, created_at) VALUES ($1, $2, $3, $4)",
org_id,
name,
slug_value,
created,
)
return org_id
async def _add_membership(conn, user_id: UUID, org_id: UUID, role: str) -> None:
await conn.execute(
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
uuid4(),
user_id,
org_id,
role,
)
async def test_register_user_creates_admin_membership(auth_service, db_conn):
request = RegisterRequest(
email="founder@example.com",
password="SuperSecret1!",
org_name="Founders Inc",
)
response = await auth_service.register_user(request)
payload = security.decode_access_token(response.access_token)
assert payload["org_role"] == "admin"
user_id = UUID(payload["sub"])
org_id = UUID(payload["org_id"])
user = await db_conn.fetchrow("SELECT email FROM users WHERE id = $1", user_id)
assert user is not None and user["email"] == request.email
membership = await db_conn.fetchrow(
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
user_id,
org_id,
)
assert membership is not None and membership["role"] == "admin"
refresh_hash = security.hash_token(response.refresh_token)
refresh_row = await db_conn.fetchrow(
"SELECT user_id, active_org_id FROM refresh_tokens WHERE token_hash = $1",
refresh_hash,
)
assert refresh_row is not None
assert refresh_row["user_id"] == user_id
assert refresh_row["active_org_id"] == org_id
async def test_login_user_returns_tokens_for_valid_credentials(auth_service, db_conn):
email = "member@example.com"
password = "Password123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(
db_conn,
name="Member Org",
slug="member-org",
created_at=datetime.now(UTC) - timedelta(days=1),
)
await _add_membership(db_conn, user_id, org_id, "member")
response = await auth_service.login_user(LoginRequest(email=email, password=password))
payload = security.decode_access_token(response.access_token)
assert payload["sub"] == str(user_id)
assert payload["org_id"] == str(org_id)
refresh_hash = security.hash_token(response.refresh_token)
refresh_row = await db_conn.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
refresh_hash,
)
assert refresh_row is not None and refresh_row["active_org_id"] == org_id
async def test_refresh_tokens_rotates_existing_token(auth_service, db_conn):
email = "rotate@example.com"
password = "Rotate123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(db_conn, name="Rotate Org", slug="rotate-org")
await _add_membership(db_conn, user_id, org_id, "member")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
rotated = await auth_service.refresh_tokens(
RefreshRequest(refresh_token=initial.refresh_token)
)
assert rotated.refresh_token != initial.refresh_token
old_hash = security.hash_token(initial.refresh_token)
old_row = await db_conn.fetchrow(
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["rotated_to"] is not None
new_hash = security.hash_token(rotated.refresh_token)
new_row = await db_conn.fetchrow(
"SELECT user_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["user_id"] == user_id
async def test_switch_org_updates_active_org(auth_service, db_conn):
email = "switcher@example.com"
password = "Switch123!"
user_id = await _create_user(db_conn, email, password)
primary_org = await _create_org(
db_conn,
name="Primary Org",
slug="primary-org",
created_at=datetime.now(UTC) - timedelta(days=2),
)
await _add_membership(db_conn, user_id, primary_org, "member")
secondary_org = await _create_org(
db_conn,
name="Secondary Org",
slug="secondary-org",
created_at=datetime.now(UTC) - timedelta(days=1),
)
await _add_membership(db_conn, user_id, secondary_org, "admin")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
current_user = CurrentUser(
user_id=user_id,
email=email,
org_id=primary_org,
org_role="member",
token=initial.access_token,
)
switched = await auth_service.switch_org(
current_user,
SwitchOrgRequest(org_id=secondary_org, refresh_token=initial.refresh_token),
)
payload = security.decode_access_token(switched.access_token)
assert payload["org_id"] == str(secondary_org)
assert payload["org_role"] == "admin"
new_hash = security.hash_token(switched.refresh_token)
new_row = await db_conn.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["active_org_id"] == secondary_org
async def test_logout_revokes_refresh_token(auth_service, db_conn):
email = "logout@example.com"
password = "Logout123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(db_conn, name="Logout Org", slug="logout-org")
await _add_membership(db_conn, user_id, org_id, "member")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
current_user = CurrentUser(
user_id=user_id,
email=email,
org_id=org_id,
org_role="member",
token=initial.access_token,
)
await auth_service.logout(current_user, LogoutRequest(refresh_token=initial.refresh_token))
token_hash = security.hash_token(initial.refresh_token)
row = await db_conn.fetchrow(
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
token_hash,
)
assert row is not None and row["revoked_at"] is not None