feat(auth): implement auth stack

This commit is contained in:
2025-12-29 09:55:30 +00:00
parent 3170f10e86
commit ad94833830
13 changed files with 1199 additions and 11 deletions

101
app/api/deps.py Normal file
View File

@@ -0,0 +1,101 @@
"""Shared FastAPI dependencies (auth, RBAC, ownership)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable
from uuid import UUID
from fastapi import Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.core import exceptions as exc, security
from app.db import db
from app.repositories import OrgRepository, UserRepository
bearer_scheme = HTTPBearer(auto_error=False)
ROLE_RANKS: dict[str, int] = {"viewer": 0, "member": 1, "admin": 2}
@dataclass(slots=True)
class CurrentUser:
"""Authenticated user context derived from the access token."""
user_id: UUID
email: str
org_id: UUID
org_role: str
token: str
async def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> CurrentUser:
"""Extract and validate the current user from the Authorization header."""
if credentials is None or credentials.scheme.lower() != "bearer":
raise exc.UnauthorizedError("Missing bearer token")
try:
payload = security.TokenPayload(security.decode_access_token(credentials.credentials))
except security.JWTError as err: # pragma: no cover - jose error types
raise exc.UnauthorizedError("Invalid access token") from err
async with db.connection() as conn:
user_repo = UserRepository(conn)
user = await user_repo.get_by_id(payload.user_id)
if user is None:
raise exc.UnauthorizedError("User not found")
org_repo = OrgRepository(conn)
membership = await org_repo.get_member(payload.user_id, payload.org_id)
if membership is None:
raise exc.ForbiddenError("Organization access denied")
return CurrentUser(
user_id=payload.user_id,
email=user["email"],
org_id=payload.org_id,
org_role=membership["role"],
token=credentials.credentials,
)
class RoleChecker:
"""Dependency that enforces a minimum organization role."""
def __init__(self, minimum_role: str) -> None:
if minimum_role not in ROLE_RANKS:
raise ValueError(f"Unknown role '{minimum_role}'")
self.minimum_role = minimum_role
def __call__(self, current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
if ROLE_RANKS[current_user.org_role] < ROLE_RANKS[self.minimum_role]:
raise exc.ForbiddenError("Insufficient role for this operation")
return current_user
def require_role(min_role: str) -> Callable[[CurrentUser], CurrentUser]:
"""Factory that returns a dependency enforcing the specified role."""
return RoleChecker(min_role)
def ensure_org_access(resource_org_id: UUID, current_user: CurrentUser) -> None:
"""Verify that the resource belongs to the active org in the token."""
if resource_org_id != current_user.org_id:
raise exc.ForbiddenError("Resource does not belong to the active organization")
__all__ = [
"CurrentUser",
"ROLE_RANKS",
"RoleChecker",
"bearer_scheme",
"ensure_org_access",
"get_current_user",
"require_role",
]

59
app/api/v1/auth.py Normal file
View File

@@ -0,0 +1,59 @@
"""Authentication API endpoints."""
from fastapi import APIRouter, Depends, status
from app.api.deps import CurrentUser, get_current_user
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
TokenResponse,
)
from app.services import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
auth_service = AuthService()
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register_user(payload: RegisterRequest) -> TokenResponse:
"""Register a new user and default org, returning auth tokens."""
return await auth_service.register_user(payload)
@router.post("/login", response_model=TokenResponse)
async def login_user(payload: LoginRequest) -> TokenResponse:
"""Authenticate an existing user and issue tokens."""
return await auth_service.login_user(payload)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(payload: RefreshRequest) -> TokenResponse:
"""Rotate refresh token and mint a new access token."""
return await auth_service.refresh_tokens(payload)
@router.post("/switch-org", response_model=TokenResponse)
async def switch_org(
payload: SwitchOrgRequest,
current_user: CurrentUser = Depends(get_current_user),
) -> TokenResponse:
"""Switch the active organization for the authenticated user."""
return await auth_service.switch_org(current_user, payload)
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
payload: LogoutRequest,
current_user: CurrentUser = Depends(get_current_user),
) -> None:
"""Revoke the provided refresh token for the current session."""
await auth_service.logout(current_user, payload)

View File

@@ -2,8 +2,10 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from contextvars import ContextVar
import asyncpg import asyncpg
from asyncpg.pool import PoolConnectionProxy
import redis.asyncio as redis import redis.asyncio as redis
@@ -27,7 +29,7 @@ class Database:
await self.pool.close() await self.pool.close()
@asynccontextmanager @asynccontextmanager
async def connection(self) -> AsyncGenerator[asyncpg.Connection, None]: async def connection(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
"""Acquire a connection from the pool.""" """Acquire a connection from the pool."""
if not self.pool: if not self.pool:
raise RuntimeError("Database not connected") raise RuntimeError("Database not connected")
@@ -35,7 +37,7 @@ class Database:
yield conn yield conn
@asynccontextmanager @asynccontextmanager
async def transaction(self) -> AsyncGenerator[asyncpg.Connection, None]: async def transaction(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
"""Acquire a connection with an active transaction.""" """Acquire a connection with an active transaction."""
if not self.pool: if not self.pool:
raise RuntimeError("Database not connected") raise RuntimeError("Database not connected")
@@ -74,7 +76,26 @@ db = Database()
redis_client = RedisClient() redis_client = RedisClient()
async def get_conn() -> AsyncGenerator[asyncpg.Connection, None]: _connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar(
"""Dependency for getting a database connection.""" "db_connection",
async with db.connection() as conn: default=None,
yield conn )
async def get_conn() -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
"""Dependency that reuses the same DB connection within a request context."""
existing_conn = _connection_ctx.get()
if existing_conn is not None:
yield existing_conn
return
if not db.pool:
raise RuntimeError("Database not connected")
async with db.pool.acquire() as conn:
token = _connection_ctx.set(conn)
try:
yield conn
finally:
_connection_ctx.reset(token)

View File

@@ -4,8 +4,9 @@ from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from app.api.v1 import health from app.api.v1 import auth, health
from app.config import settings from app.config import settings
from app.db import db, redis_client from app.db import db, redis_client
@@ -26,8 +27,44 @@ app = FastAPI(
title="IncidentOps", title="IncidentOps",
description="Incident management API with multi-tenant org support", description="Incident management API with multi-tenant org support",
version="0.1.0", version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
lifespan=lifespan, lifespan=lifespan,
) )
app.openapi_tags = [
{"name": "auth", "description": "Registration, login, token lifecycle"},
{"name": "health", "description": "Service health probes"},
]
def custom_openapi() -> dict:
"""Add JWT bearer security scheme to the generated OpenAPI schema."""
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
security_schemes = openapi_schema.setdefault("components", {}).setdefault("securitySchemes", {})
security_schemes["BearerToken"] = {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Paste the JWT access token returned by /auth endpoints",
}
openapi_schema["security"] = [{"BearerToken": []}]
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore[assignment]
# Include routers # Include routers
app.include_router(auth.router, prefix=settings.api_v1_prefix)
app.include_router(health.router, prefix=settings.api_v1_prefix, tags=["health"]) app.include_router(health.router, prefix=settings.api_v1_prefix, tags=["health"])

View File

@@ -2,6 +2,7 @@
from app.schemas.auth import ( from app.schemas.auth import (
LoginRequest, LoginRequest,
LogoutRequest,
RefreshRequest, RefreshRequest,
RegisterRequest, RegisterRequest,
SwitchOrgRequest, SwitchOrgRequest,
@@ -27,6 +28,7 @@ from app.schemas.org import (
__all__ = [ __all__ = [
# Auth # Auth
"LoginRequest", "LoginRequest",
"LogoutRequest",
"RefreshRequest", "RefreshRequest",
"RegisterRequest", "RegisterRequest",
"SwitchOrgRequest", "SwitchOrgRequest",

View File

@@ -33,6 +33,12 @@ class SwitchOrgRequest(BaseModel):
refresh_token: str refresh_token: str
class LogoutRequest(BaseModel):
"""Request body for logging out and revoking a refresh token."""
refresh_token: str
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
"""Response containing access and refresh tokens.""" """Response containing access and refresh tokens."""

5
app/services/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Service layer entrypoints."""
from app.services.auth import AuthService
__all__ = ["AuthService"]

269
app/services/auth.py Normal file
View File

@@ -0,0 +1,269 @@
"""Authentication service providing business logic for auth flows."""
from __future__ import annotations
import re
from typing import cast
from uuid import UUID, uuid4
import asyncpg
from asyncpg.pool import PoolConnectionProxy
from app.api.deps import CurrentUser
from app.config import settings
from app.core import exceptions as exc, security
from app.db import Database, db
from app.repositories import OrgRepository, RefreshTokenRepository, UserRepository
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
TokenResponse,
)
_SLUG_PATTERN = re.compile(r"[^a-z0-9]+")
def _as_conn(conn: asyncpg.Connection | PoolConnectionProxy) -> asyncpg.Connection:
"""Helper to satisfy typing when a pool proxy is returned."""
return cast(asyncpg.Connection, conn)
class AuthService:
"""Encapsulates authentication workflows (register/login/refresh/logout)."""
def __init__(self, database: Database | None = None) -> None:
self.db = database or db
self._access_token_expires_in = settings.access_token_expire_minutes * 60
async def register_user(self, data: RegisterRequest) -> TokenResponse:
"""Create a new user, default org, membership, and token pair."""
async with self.db.transaction() as conn:
db_conn = _as_conn(conn)
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
refresh_repo = RefreshTokenRepository(db_conn)
if await user_repo.exists_by_email(data.email):
raise exc.ConflictError("Email already registered")
user_id = uuid4()
org_id = uuid4()
member_id = uuid4()
password_hash = security.hash_password(data.password)
await user_repo.create(user_id, data.email, password_hash)
slug = await self._generate_unique_org_slug(org_repo, data.org_name)
await org_repo.create(org_id, data.org_name, slug)
await org_repo.add_member(member_id, user_id, org_id, "admin")
return await self._issue_token_pair(
refresh_repo,
user_id=user_id,
org_id=org_id,
role="admin",
)
async def login_user(self, data: LoginRequest) -> TokenResponse:
"""Authenticate a user and issue tokens for their first organization."""
async with self.db.connection() as conn:
db_conn = _as_conn(conn)
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
refresh_repo = RefreshTokenRepository(db_conn)
user = await user_repo.get_by_email(data.email)
if not user or not security.verify_password(data.password, user["password_hash"]):
raise exc.UnauthorizedError("Invalid email or password")
orgs = await org_repo.get_user_orgs(user["id"])
if not orgs:
raise exc.ForbiddenError("User does not belong to any organization")
active_org = orgs[0]
return await self._issue_token_pair(
refresh_repo,
user_id=user["id"],
org_id=active_org["id"],
role=active_org["role"],
)
async def refresh_tokens(self, data: RefreshRequest) -> TokenResponse:
"""Rotate refresh token and mint a new access token."""
old_hash = security.hash_token(data.refresh_token)
new_refresh_token = security.generate_refresh_token()
new_refresh_hash = security.hash_token(new_refresh_token)
new_refresh_id = uuid4()
new_refresh_expiry = security.get_refresh_token_expiry()
rotated: dict | None = None
membership: dict | None = None
async with self.db.transaction() as conn:
db_conn = _as_conn(conn)
refresh_repo = RefreshTokenRepository(db_conn)
rotated = await refresh_repo.rotate(
old_token_hash=old_hash,
new_token_id=new_refresh_id,
new_token_hash=new_refresh_hash,
new_expires_at=new_refresh_expiry,
)
if rotated is not None:
org_repo = OrgRepository(db_conn)
membership = await org_repo.get_member(rotated["user_id"], rotated["active_org_id"])
if membership is None:
raise exc.UnauthorizedError("Invalid refresh token")
if rotated is None or membership is None:
await self._handle_invalid_refresh(old_hash)
assert rotated is not None and membership is not None
access_token = security.create_access_token(
sub=str(rotated["user_id"]),
org_id=str(rotated["active_org_id"]),
org_role=membership["role"],
)
return TokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
expires_in=self._access_token_expires_in,
)
async def switch_org(
self,
current_user: CurrentUser,
data: SwitchOrgRequest,
) -> TokenResponse:
"""Switch active organization (rotates refresh token + issues new JWT)."""
target_org_id = data.org_id
old_hash = security.hash_token(data.refresh_token)
new_refresh_token = security.generate_refresh_token()
new_refresh_hash = security.hash_token(new_refresh_token)
new_refresh_expiry = security.get_refresh_token_expiry()
rotated: dict | None = None
membership: dict | None = None
async with self.db.transaction() as conn:
db_conn = _as_conn(conn)
org_repo = OrgRepository(db_conn)
membership = await org_repo.get_member(current_user.user_id, target_org_id)
if membership is None:
raise exc.ForbiddenError("Not a member of the requested organization")
refresh_repo = RefreshTokenRepository(db_conn)
rotated = await refresh_repo.rotate(
old_token_hash=old_hash,
new_token_id=uuid4(),
new_token_hash=new_refresh_hash,
new_expires_at=new_refresh_expiry,
new_active_org_id=target_org_id,
expected_user_id=current_user.user_id,
)
if rotated is None:
await self._handle_invalid_refresh(old_hash)
access_token = security.create_access_token(
sub=str(current_user.user_id),
org_id=str(target_org_id),
org_role=membership["role"],
)
return TokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
expires_in=self._access_token_expires_in,
)
async def logout(self, current_user: CurrentUser, data: LogoutRequest) -> None:
"""Revoke the provided refresh token for the current session."""
token_hash = security.hash_token(data.refresh_token)
async with self.db.transaction() as conn:
refresh_repo = RefreshTokenRepository(_as_conn(conn))
token = await refresh_repo.get_by_hash(token_hash)
if token and token["user_id"] != current_user.user_id:
raise exc.ForbiddenError("Refresh token does not belong to this user")
if not token:
return
await refresh_repo.revoke(token["id"])
async def _issue_token_pair(
self,
refresh_repo: RefreshTokenRepository,
*,
user_id: UUID,
org_id: UUID,
role: str,
) -> TokenResponse:
"""Create access/refresh tokens and persist the refresh token."""
access_token = security.create_access_token(
sub=str(user_id),
org_id=str(org_id),
org_role=role,
)
refresh_token = security.generate_refresh_token()
await refresh_repo.create(
token_id=uuid4(),
user_id=user_id,
token_hash=security.hash_token(refresh_token),
active_org_id=org_id,
expires_at=security.get_refresh_token_expiry(),
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=self._access_token_expires_in,
)
async def _handle_invalid_refresh(self, token_hash: str) -> None:
"""Raise appropriate errors for invalid/compromised refresh tokens."""
async with self.db.connection() as conn:
refresh_repo = RefreshTokenRepository(_as_conn(conn))
reused = await refresh_repo.check_token_reuse(token_hash)
if reused:
await refresh_repo.revoke_token_chain(reused["id"])
raise exc.UnauthorizedError("Refresh token reuse detected")
raise exc.UnauthorizedError("Invalid refresh token")
async def _generate_unique_org_slug(
self,
org_repo: OrgRepository,
org_name: str,
) -> str:
"""Slugify the org name and append a counter until unique."""
base_slug = self._slugify(org_name)
candidate = base_slug
counter = 1
while await org_repo.slug_exists(candidate):
suffix = f"-{counter}"
max_base_len = 50 - len(suffix)
candidate = f"{base_slug[:max_base_len]}{suffix}"
counter += 1
return candidate
def _slugify(self, value: str) -> str:
"""Convert arbitrary text into a URL-friendly slug."""
slug = _SLUG_PATTERN.sub("-", value.strip().lower()).strip("-")
return slug[:50] or "org"

65
tests/api/helpers.py Normal file
View File

@@ -0,0 +1,65 @@
"""Shared helpers for API integration tests."""
from __future__ import annotations
from typing import Any
from uuid import UUID, uuid4
import asyncpg
from httpx import AsyncClient
API_PREFIX = "/v1"
async def register_user(
client: AsyncClient,
*,
email: str,
password: str,
org_name: str = "Test Org",
) -> dict[str, Any]:
"""Call the register endpoint and return JSON body (raises on failure)."""
response = await client.post(
f"{API_PREFIX}/auth/register",
json={"email": email, "password": password, "org_name": org_name},
)
response.raise_for_status()
return response.json()
async def create_org(
conn: asyncpg.Connection,
*,
name: str,
slug: str | None = None,
) -> UUID:
"""Insert an organization row and return its ID."""
org_id = uuid4()
slug_value = slug or name.lower().replace(" ", "-")
await conn.execute(
"INSERT INTO orgs (id, name, slug) VALUES ($1, $2, $3)",
org_id,
name,
slug_value,
)
return org_id
async def add_membership(
conn: asyncpg.Connection,
*,
user_id: UUID,
org_id: UUID,
role: str,
) -> None:
"""Insert a membership record for the user/org pair."""
await conn.execute(
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
uuid4(),
user_id,
org_id,
role,
)

213
tests/api/test_auth.py Normal file
View File

@@ -0,0 +1,213 @@
"""Integration tests for FastAPI auth endpoints."""
from __future__ import annotations
from uuid import UUID
import asyncpg
import pytest
from httpx import AsyncClient
from app.core import security
from tests.api import helpers
pytestmark = pytest.mark.asyncio
API_PREFIX = "/v1/auth"
async def test_register_endpoint_persists_user_and_membership(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
data = await helpers.register_user(
api_client,
email="api-register@example.com",
password="SuperSecret1!",
org_name="API Org",
)
assert "access_token" in data and "refresh_token" in data
token_payload = security.decode_access_token(data["access_token"])
assert token_payload["org_role"] == "admin"
stored_user = await db_admin.fetchrow("SELECT email FROM users WHERE email = $1", "api-register@example.com")
assert stored_user is not None
membership = await db_admin.fetchrow(
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
UUID(token_payload["sub"]),
UUID(token_payload["org_id"]),
)
assert membership is not None and membership["role"] == "admin"
async def test_login_endpoint_rejects_bad_credentials(
api_client: AsyncClient,
) -> None:
register_payload = {
"email": "api-login@example.com",
"password": "CorrectHorse1!",
"org_name": "Login Org",
}
await helpers.register_user(api_client, **register_payload)
response = await api_client.post(
f"{API_PREFIX}/login",
json={"email": register_payload["email"], "password": "wrong"},
)
assert response.status_code == 401
async def test_refresh_endpoint_rotates_refresh_token(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
register_payload = {
"email": "api-refresh@example.com",
"password": "RefreshPass1!",
"org_name": "Refresh Org",
}
initial = await helpers.register_user(api_client, **register_payload)
response = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": initial["refresh_token"]},
)
assert response.status_code == 200
data = response.json()
assert data["refresh_token"] != initial["refresh_token"]
old_hash = security.hash_token(initial["refresh_token"])
old_row = await db_admin.fetchrow(
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["rotated_to"] is not None
async def test_refresh_endpoint_detects_reuse(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
tokens = await helpers.register_user(
api_client,
email="api-reuse@example.com",
password="ReusePass1!",
org_name="Reuse Org",
)
rotated = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert rotated.status_code == 200
reuse_response = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert reuse_response.status_code == 401
old_hash = security.hash_token(tokens["refresh_token"])
old_row = await db_admin.fetchrow(
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["revoked_at"] is not None
async def test_switch_org_changes_active_org(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
email = "api-switch@example.com"
register_payload = {
"email": email,
"password": "SwitchPass1!",
"org_name": "Primary Org",
}
tokens = await helpers.register_user(api_client, **register_payload)
user_id_row = await db_admin.fetchrow("SELECT id FROM users WHERE email = $1", email)
assert user_id_row is not None
user_id = user_id_row["id"]
target_org_id = await helpers.create_org(db_admin, name="Secondary Org", slug="secondary-org")
await helpers.add_membership(db_admin, user_id=user_id, org_id=target_org_id, role="member")
response = await api_client.post(
f"{API_PREFIX}/switch-org",
json={"org_id": str(target_org_id), "refresh_token": tokens["refresh_token"]},
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
data = response.json()
payload = security.decode_access_token(data["access_token"])
assert payload["org_id"] == str(target_org_id)
assert payload["org_role"] == "member"
new_hash = security.hash_token(data["refresh_token"])
new_row = await db_admin.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["active_org_id"] == target_org_id
async def test_switch_org_forbidden_without_membership(
api_client: AsyncClient,
db_admin: asyncpg.Connection,
) -> None:
tokens = await helpers.register_user(
api_client,
email="api-switch-no-access@example.com",
password="SwitchBlock1!",
org_name="Primary",
)
foreign_org = await helpers.create_org(db_admin, name="Foreign Org", slug="foreign-org")
response = await api_client.post(
f"{API_PREFIX}/switch-org",
json={"org_id": str(foreign_org), "refresh_token": tokens["refresh_token"]},
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
# ensure refresh token still valid after failed attempt
retry = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert retry.status_code == 200
async def test_logout_revokes_refresh_token(
api_client: AsyncClient,
) -> None:
register_payload = {
"email": "api-logout@example.com",
"password": "LogoutPass1!",
"org_name": "Logout Org",
}
tokens = await helpers.register_user(api_client, **register_payload)
logout_response = await api_client.post(
f"{API_PREFIX}/logout",
json={"refresh_token": tokens["refresh_token"]},
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert logout_response.status_code == 204
refresh_response = await api_client.post(
f"{API_PREFIX}/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert refresh_response.status_code == 401

View File

@@ -3,9 +3,12 @@
from __future__ import annotations from __future__ import annotations
import os import os
from uuid import uuid4 from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable
from uuid import UUID, uuid4
import asyncpg import asyncpg
import httpx
import pytest import pytest
# Set test environment variables before importing app modules # Set test environment variables before importing app modules
@@ -13,6 +16,8 @@ os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@loca
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only") os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1") os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1")
from app.main import app
# Module-level setup: create database and run migrations once # Module-level setup: create database and run migrations once
_db_initialized = False _db_initialized = False
@@ -65,7 +70,7 @@ async def _init_test_db() -> None:
@pytest.fixture @pytest.fixture
async def db_conn() -> asyncpg.Connection: async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]:
"""Get a database connection with transaction rollback for test isolation.""" """Get a database connection with transaction rollback for test isolation."""
await _init_test_db() await _init_test_db()
@@ -84,12 +89,77 @@ async def db_conn() -> asyncpg.Connection:
@pytest.fixture @pytest.fixture
def make_user_id() -> uuid4: def make_user_id() -> Callable[[], UUID]:
"""Factory for generating user IDs.""" """Factory for generating user IDs."""
return lambda: uuid4() return lambda: uuid4()
@pytest.fixture @pytest.fixture
def make_org_id() -> uuid4: def make_org_id() -> Callable[[], UUID]:
"""Factory for generating org IDs.""" """Factory for generating org IDs."""
return lambda: uuid4() return lambda: uuid4()
TABLES_TO_TRUNCATE = [
"incident_events",
"notification_attempts",
"incidents",
"notification_targets",
"services",
"refresh_tokens",
"org_members",
"orgs",
"users",
]
async def _truncate_all_tables() -> None:
test_dsn = os.environ["DATABASE_URL"]
conn = await asyncpg.connect(test_dsn)
try:
tables = ", ".join(TABLES_TO_TRUNCATE)
await conn.execute(f"TRUNCATE TABLE {tables} CASCADE")
finally:
await conn.close()
@pytest.fixture
async def clean_database() -> AsyncGenerator[None, None]:
"""Ensure the database is initialized and truncated before/after tests."""
await _init_test_db()
await _truncate_all_tables()
yield
await _truncate_all_tables()
@asynccontextmanager
async def _lifespan_manager() -> AsyncGenerator[None, None]:
lifespan = app.router.lifespan_context
if lifespan is None:
yield
else:
async with lifespan(app):
yield
@pytest.fixture
async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]:
"""HTTPX async client bound to the FastAPI app with lifespan support."""
async with _lifespan_manager():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
yield client
@pytest.fixture
async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]:
"""Plain connection for arranging/inspecting API test data (no rollback)."""
test_dsn = os.environ["DATABASE_URL"]
conn = await asyncpg.connect(test_dsn)
try:
yield conn
finally:
await conn.close()

80
tests/db/test_get_conn.py Normal file
View File

@@ -0,0 +1,80 @@
"""Tests for the get_conn dependency helper."""
from __future__ import annotations
import pytest
from app.db import db, get_conn
pytestmark = pytest.mark.asyncio
class _FakeConnection:
def __init__(self, idx: int) -> None:
self.idx = idx
class _AcquireContext:
def __init__(self, conn: _FakeConnection, tracker: "_FakePool") -> None:
self._conn = conn
self._tracker = tracker
async def __aenter__(self) -> _FakeConnection:
self._tracker.active += 1
return self._conn
async def __aexit__(self, exc_type, exc, tb) -> None:
self._tracker.active -= 1
class _FakePool:
def __init__(self) -> None:
self.acquire_calls = 0
self.active = 0
def acquire(self) -> _AcquireContext:
conn = _FakeConnection(self.acquire_calls)
self.acquire_calls += 1
return _AcquireContext(conn, self)
async def _collect_single_connection():
connection = None
async for conn in get_conn():
connection = conn
return connection
async def test_get_conn_reuses_connection_within_scope():
original_pool = db.pool
fake_pool = _FakePool()
db.pool = fake_pool
try:
captured: list[_FakeConnection] = []
async for outer in get_conn():
captured.append(outer)
async for inner in get_conn():
captured.append(inner)
assert len(captured) == 2
assert captured[0] is captured[1]
assert fake_pool.acquire_calls == 1
finally:
db.pool = original_pool
async def test_get_conn_acquires_new_connection_per_root_scope():
original_pool = db.pool
fake_pool = _FakePool()
db.pool = fake_pool
try:
first = await _collect_single_connection()
second = await _collect_single_connection()
assert first is not None and second is not None
assert first is not second
assert fake_pool.acquire_calls == 2
finally:
db.pool = original_pool

View File

@@ -0,0 +1,260 @@
"""Unit tests covering AuthService flows."""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import pytest
from app.api.deps import CurrentUser
from app.core import security
from app.db import Database
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
)
from app.services.auth import AuthService
pytestmark = pytest.mark.asyncio
class _SingleConnectionDatabase(Database):
"""Database stub that reuses a single asyncpg connection."""
def __init__(self, conn) -> None: # type: ignore[override]
self._conn = conn
@asynccontextmanager
async def connection(self): # type: ignore[override]
yield self._conn
@asynccontextmanager
async def transaction(self): # type: ignore[override]
tr = self._conn.transaction()
await tr.start()
try:
yield self._conn
except Exception:
await tr.rollback()
raise
else:
await tr.commit()
@pytest.fixture
async def auth_service(db_conn):
"""AuthService bound to the per-test database connection."""
return AuthService(database=_SingleConnectionDatabase(db_conn))
async def _create_user(conn, email: str, password: str) -> UUID:
user_id = uuid4()
password_hash = security.hash_password(password)
await conn.execute(
"INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3)",
user_id,
email,
password_hash,
)
return user_id
async def _create_org(
conn,
name: str,
slug: str | None = None,
*,
created_at: datetime | None = None,
) -> UUID:
org_id = uuid4()
slug_value = slug or f"{name.lower().replace(' ', '-')}-{org_id.hex[:6]}"
created = created_at or datetime.now(UTC)
await conn.execute(
"INSERT INTO orgs (id, name, slug, created_at) VALUES ($1, $2, $3, $4)",
org_id,
name,
slug_value,
created,
)
return org_id
async def _add_membership(conn, user_id: UUID, org_id: UUID, role: str) -> None:
await conn.execute(
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
uuid4(),
user_id,
org_id,
role,
)
async def test_register_user_creates_admin_membership(auth_service, db_conn):
request = RegisterRequest(
email="founder@example.com",
password="SuperSecret1!",
org_name="Founders Inc",
)
response = await auth_service.register_user(request)
payload = security.decode_access_token(response.access_token)
assert payload["org_role"] == "admin"
user_id = UUID(payload["sub"])
org_id = UUID(payload["org_id"])
user = await db_conn.fetchrow("SELECT email FROM users WHERE id = $1", user_id)
assert user is not None and user["email"] == request.email
membership = await db_conn.fetchrow(
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
user_id,
org_id,
)
assert membership is not None and membership["role"] == "admin"
refresh_hash = security.hash_token(response.refresh_token)
refresh_row = await db_conn.fetchrow(
"SELECT user_id, active_org_id FROM refresh_tokens WHERE token_hash = $1",
refresh_hash,
)
assert refresh_row is not None
assert refresh_row["user_id"] == user_id
assert refresh_row["active_org_id"] == org_id
async def test_login_user_returns_tokens_for_valid_credentials(auth_service, db_conn):
email = "member@example.com"
password = "Password123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(
db_conn,
name="Member Org",
slug="member-org",
created_at=datetime.now(UTC) - timedelta(days=1),
)
await _add_membership(db_conn, user_id, org_id, "member")
response = await auth_service.login_user(LoginRequest(email=email, password=password))
payload = security.decode_access_token(response.access_token)
assert payload["sub"] == str(user_id)
assert payload["org_id"] == str(org_id)
refresh_hash = security.hash_token(response.refresh_token)
refresh_row = await db_conn.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
refresh_hash,
)
assert refresh_row is not None and refresh_row["active_org_id"] == org_id
async def test_refresh_tokens_rotates_existing_token(auth_service, db_conn):
email = "rotate@example.com"
password = "Rotate123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(db_conn, name="Rotate Org", slug="rotate-org")
await _add_membership(db_conn, user_id, org_id, "member")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
rotated = await auth_service.refresh_tokens(
RefreshRequest(refresh_token=initial.refresh_token)
)
assert rotated.refresh_token != initial.refresh_token
old_hash = security.hash_token(initial.refresh_token)
old_row = await db_conn.fetchrow(
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["rotated_to"] is not None
new_hash = security.hash_token(rotated.refresh_token)
new_row = await db_conn.fetchrow(
"SELECT user_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["user_id"] == user_id
async def test_switch_org_updates_active_org(auth_service, db_conn):
email = "switcher@example.com"
password = "Switch123!"
user_id = await _create_user(db_conn, email, password)
primary_org = await _create_org(
db_conn,
name="Primary Org",
slug="primary-org",
created_at=datetime.now(UTC) - timedelta(days=2),
)
await _add_membership(db_conn, user_id, primary_org, "member")
secondary_org = await _create_org(
db_conn,
name="Secondary Org",
slug="secondary-org",
created_at=datetime.now(UTC) - timedelta(days=1),
)
await _add_membership(db_conn, user_id, secondary_org, "admin")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
current_user = CurrentUser(
user_id=user_id,
email=email,
org_id=primary_org,
org_role="member",
token=initial.access_token,
)
switched = await auth_service.switch_org(
current_user,
SwitchOrgRequest(org_id=secondary_org, refresh_token=initial.refresh_token),
)
payload = security.decode_access_token(switched.access_token)
assert payload["org_id"] == str(secondary_org)
assert payload["org_role"] == "admin"
new_hash = security.hash_token(switched.refresh_token)
new_row = await db_conn.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["active_org_id"] == secondary_org
async def test_logout_revokes_refresh_token(auth_service, db_conn):
email = "logout@example.com"
password = "Logout123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(db_conn, name="Logout Org", slug="logout-org")
await _add_membership(db_conn, user_id, org_id, "member")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
current_user = CurrentUser(
user_id=user_id,
email=email,
org_id=org_id,
org_role="member",
token=initial.access_token,
)
await auth_service.logout(current_user, LogoutRequest(refresh_token=initial.refresh_token))
token_hash = security.hash_token(initial.refresh_token)
row = await db_conn.fetchrow(
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
token_hash,
)
assert row is not None and row["revoked_at"] is not None