From ad94833830721ac229f1863c1e9268b98bddc62b Mon Sep 17 00:00:00 2001 From: minhtrannhat Date: Mon, 29 Dec 2025 09:55:30 +0000 Subject: [PATCH] feat(auth): implement auth stack --- app/api/deps.py | 101 +++++++++++ app/api/v1/auth.py | 59 ++++++ app/db.py | 33 +++- app/main.py | 39 +++- app/schemas/__init__.py | 2 + app/schemas/auth.py | 6 + app/services/__init__.py | 5 + app/services/auth.py | 269 ++++++++++++++++++++++++++++ tests/api/helpers.py | 65 +++++++ tests/api/test_auth.py | 213 ++++++++++++++++++++++ tests/conftest.py | 78 +++++++- tests/db/test_get_conn.py | 80 +++++++++ tests/services/test_auth_service.py | 260 +++++++++++++++++++++++++++ 13 files changed, 1199 insertions(+), 11 deletions(-) create mode 100644 app/api/deps.py create mode 100644 app/api/v1/auth.py create mode 100644 app/services/__init__.py create mode 100644 app/services/auth.py create mode 100644 tests/api/helpers.py create mode 100644 tests/api/test_auth.py create mode 100644 tests/db/test_get_conn.py create mode 100644 tests/services/test_auth_service.py diff --git a/app/api/deps.py b/app/api/deps.py new file mode 100644 index 0000000..ce8407a --- /dev/null +++ b/app/api/deps.py @@ -0,0 +1,101 @@ +"""Shared FastAPI dependencies (auth, RBAC, ownership).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable +from uuid import UUID + +from fastapi import Depends +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.core import exceptions as exc, security +from app.db import db +from app.repositories import OrgRepository, UserRepository + + +bearer_scheme = HTTPBearer(auto_error=False) + +ROLE_RANKS: dict[str, int] = {"viewer": 0, "member": 1, "admin": 2} + + +@dataclass(slots=True) +class CurrentUser: + """Authenticated user context derived from the access token.""" + + user_id: UUID + email: str + org_id: UUID + org_role: str + token: str + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), +) -> CurrentUser: + """Extract and validate the current user from the Authorization header.""" + + if credentials is None or credentials.scheme.lower() != "bearer": + raise exc.UnauthorizedError("Missing bearer token") + + try: + payload = security.TokenPayload(security.decode_access_token(credentials.credentials)) + except security.JWTError as err: # pragma: no cover - jose error types + raise exc.UnauthorizedError("Invalid access token") from err + + async with db.connection() as conn: + user_repo = UserRepository(conn) + user = await user_repo.get_by_id(payload.user_id) + if user is None: + raise exc.UnauthorizedError("User not found") + + org_repo = OrgRepository(conn) + membership = await org_repo.get_member(payload.user_id, payload.org_id) + if membership is None: + raise exc.ForbiddenError("Organization access denied") + + return CurrentUser( + user_id=payload.user_id, + email=user["email"], + org_id=payload.org_id, + org_role=membership["role"], + token=credentials.credentials, + ) + + +class RoleChecker: + """Dependency that enforces a minimum organization role.""" + + def __init__(self, minimum_role: str) -> None: + if minimum_role not in ROLE_RANKS: + raise ValueError(f"Unknown role '{minimum_role}'") + self.minimum_role = minimum_role + + def __call__(self, current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser: + if ROLE_RANKS[current_user.org_role] < ROLE_RANKS[self.minimum_role]: + raise exc.ForbiddenError("Insufficient role for this operation") + return current_user + + +def require_role(min_role: str) -> Callable[[CurrentUser], CurrentUser]: + """Factory that returns a dependency enforcing the specified role.""" + + return RoleChecker(min_role) + + +def ensure_org_access(resource_org_id: UUID, current_user: CurrentUser) -> None: + """Verify that the resource belongs to the active org in the token.""" + + if resource_org_id != current_user.org_id: + raise exc.ForbiddenError("Resource does not belong to the active organization") + + +__all__ = [ + "CurrentUser", + "ROLE_RANKS", + "RoleChecker", + "bearer_scheme", + "ensure_org_access", + "get_current_user", + "require_role", +] diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py new file mode 100644 index 0000000..061b60e --- /dev/null +++ b/app/api/v1/auth.py @@ -0,0 +1,59 @@ +"""Authentication API endpoints.""" + +from fastapi import APIRouter, Depends, status + +from app.api.deps import CurrentUser, get_current_user +from app.schemas.auth import ( + LoginRequest, + LogoutRequest, + RefreshRequest, + RegisterRequest, + SwitchOrgRequest, + TokenResponse, +) +from app.services import AuthService + + +router = APIRouter(prefix="/auth", tags=["auth"]) +auth_service = AuthService() + + +@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) +async def register_user(payload: RegisterRequest) -> TokenResponse: + """Register a new user and default org, returning auth tokens.""" + + return await auth_service.register_user(payload) + + +@router.post("/login", response_model=TokenResponse) +async def login_user(payload: LoginRequest) -> TokenResponse: + """Authenticate an existing user and issue tokens.""" + + return await auth_service.login_user(payload) + + +@router.post("/refresh", response_model=TokenResponse) +async def refresh_tokens(payload: RefreshRequest) -> TokenResponse: + """Rotate refresh token and mint a new access token.""" + + return await auth_service.refresh_tokens(payload) + + +@router.post("/switch-org", response_model=TokenResponse) +async def switch_org( + payload: SwitchOrgRequest, + current_user: CurrentUser = Depends(get_current_user), +) -> TokenResponse: + """Switch the active organization for the authenticated user.""" + + return await auth_service.switch_org(current_user, payload) + + +@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT) +async def logout( + payload: LogoutRequest, + current_user: CurrentUser = Depends(get_current_user), +) -> None: + """Revoke the provided refresh token for the current session.""" + + await auth_service.logout(current_user, payload) diff --git a/app/db.py b/app/db.py index 0729aa2..fb167ba 100644 --- a/app/db.py +++ b/app/db.py @@ -2,8 +2,10 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from contextvars import ContextVar import asyncpg +from asyncpg.pool import PoolConnectionProxy import redis.asyncio as redis @@ -27,7 +29,7 @@ class Database: await self.pool.close() @asynccontextmanager - async def connection(self) -> AsyncGenerator[asyncpg.Connection, None]: + async def connection(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]: """Acquire a connection from the pool.""" if not self.pool: raise RuntimeError("Database not connected") @@ -35,7 +37,7 @@ class Database: yield conn @asynccontextmanager - async def transaction(self) -> AsyncGenerator[asyncpg.Connection, None]: + async def transaction(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]: """Acquire a connection with an active transaction.""" if not self.pool: raise RuntimeError("Database not connected") @@ -74,7 +76,26 @@ db = Database() redis_client = RedisClient() -async def get_conn() -> AsyncGenerator[asyncpg.Connection, None]: - """Dependency for getting a database connection.""" - async with db.connection() as conn: - yield conn +_connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar( + "db_connection", + default=None, +) + + +async def get_conn() -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]: + """Dependency that reuses the same DB connection within a request context.""" + + existing_conn = _connection_ctx.get() + if existing_conn is not None: + yield existing_conn + return + + if not db.pool: + raise RuntimeError("Database not connected") + + async with db.pool.acquire() as conn: + token = _connection_ctx.set(conn) + try: + yield conn + finally: + _connection_ctx.reset(token) diff --git a/app/main.py b/app/main.py index 92c1e77..9395cb9 100644 --- a/app/main.py +++ b/app/main.py @@ -4,8 +4,9 @@ from contextlib import asynccontextmanager from typing import AsyncGenerator from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi -from app.api.v1 import health +from app.api.v1 import auth, health from app.config import settings from app.db import db, redis_client @@ -26,8 +27,44 @@ app = FastAPI( title="IncidentOps", description="Incident management API with multi-tenant org support", version="0.1.0", + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", lifespan=lifespan, ) +app.openapi_tags = [ + {"name": "auth", "description": "Registration, login, token lifecycle"}, + {"name": "health", "description": "Service health probes"}, +] + + +def custom_openapi() -> dict: + """Add JWT bearer security scheme to the generated OpenAPI schema.""" + + if app.openapi_schema: + return app.openapi_schema + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes, + ) + security_schemes = openapi_schema.setdefault("components", {}).setdefault("securitySchemes", {}) + security_schemes["BearerToken"] = { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + "description": "Paste the JWT access token returned by /auth endpoints", + } + openapi_schema["security"] = [{"BearerToken": []}] + app.openapi_schema = openapi_schema + return app.openapi_schema + + +app.openapi = custom_openapi # type: ignore[assignment] + # Include routers +app.include_router(auth.router, prefix=settings.api_v1_prefix) app.include_router(health.router, prefix=settings.api_v1_prefix, tags=["health"]) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index c8c6da1..a546b21 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -2,6 +2,7 @@ from app.schemas.auth import ( LoginRequest, + LogoutRequest, RefreshRequest, RegisterRequest, SwitchOrgRequest, @@ -27,6 +28,7 @@ from app.schemas.org import ( __all__ = [ # Auth "LoginRequest", + "LogoutRequest", "RefreshRequest", "RegisterRequest", "SwitchOrgRequest", diff --git a/app/schemas/auth.py b/app/schemas/auth.py index 9e5b5a0..5e7bda3 100644 --- a/app/schemas/auth.py +++ b/app/schemas/auth.py @@ -33,6 +33,12 @@ class SwitchOrgRequest(BaseModel): refresh_token: str +class LogoutRequest(BaseModel): + """Request body for logging out and revoking a refresh token.""" + + refresh_token: str + + class TokenResponse(BaseModel): """Response containing access and refresh tokens.""" diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..aef1bc2 --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,5 @@ +"""Service layer entrypoints.""" + +from app.services.auth import AuthService + +__all__ = ["AuthService"] diff --git a/app/services/auth.py b/app/services/auth.py new file mode 100644 index 0000000..82a3b67 --- /dev/null +++ b/app/services/auth.py @@ -0,0 +1,269 @@ +"""Authentication service providing business logic for auth flows.""" + +from __future__ import annotations + +import re +from typing import cast +from uuid import UUID, uuid4 + +import asyncpg +from asyncpg.pool import PoolConnectionProxy + +from app.api.deps import CurrentUser +from app.config import settings +from app.core import exceptions as exc, security +from app.db import Database, db +from app.repositories import OrgRepository, RefreshTokenRepository, UserRepository +from app.schemas.auth import ( + LoginRequest, + LogoutRequest, + RefreshRequest, + RegisterRequest, + SwitchOrgRequest, + TokenResponse, +) + + +_SLUG_PATTERN = re.compile(r"[^a-z0-9]+") + + +def _as_conn(conn: asyncpg.Connection | PoolConnectionProxy) -> asyncpg.Connection: + """Helper to satisfy typing when a pool proxy is returned.""" + + return cast(asyncpg.Connection, conn) + + +class AuthService: + """Encapsulates authentication workflows (register/login/refresh/logout).""" + + def __init__(self, database: Database | None = None) -> None: + self.db = database or db + self._access_token_expires_in = settings.access_token_expire_minutes * 60 + + async def register_user(self, data: RegisterRequest) -> TokenResponse: + """Create a new user, default org, membership, and token pair.""" + + async with self.db.transaction() as conn: + db_conn = _as_conn(conn) + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + refresh_repo = RefreshTokenRepository(db_conn) + + if await user_repo.exists_by_email(data.email): + raise exc.ConflictError("Email already registered") + + user_id = uuid4() + org_id = uuid4() + member_id = uuid4() + password_hash = security.hash_password(data.password) + + await user_repo.create(user_id, data.email, password_hash) + slug = await self._generate_unique_org_slug(org_repo, data.org_name) + await org_repo.create(org_id, data.org_name, slug) + await org_repo.add_member(member_id, user_id, org_id, "admin") + + return await self._issue_token_pair( + refresh_repo, + user_id=user_id, + org_id=org_id, + role="admin", + ) + + async def login_user(self, data: LoginRequest) -> TokenResponse: + """Authenticate a user and issue tokens for their first organization.""" + + async with self.db.connection() as conn: + db_conn = _as_conn(conn) + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + refresh_repo = RefreshTokenRepository(db_conn) + + user = await user_repo.get_by_email(data.email) + if not user or not security.verify_password(data.password, user["password_hash"]): + raise exc.UnauthorizedError("Invalid email or password") + + orgs = await org_repo.get_user_orgs(user["id"]) + if not orgs: + raise exc.ForbiddenError("User does not belong to any organization") + + active_org = orgs[0] + return await self._issue_token_pair( + refresh_repo, + user_id=user["id"], + org_id=active_org["id"], + role=active_org["role"], + ) + + async def refresh_tokens(self, data: RefreshRequest) -> TokenResponse: + """Rotate refresh token and mint a new access token.""" + + old_hash = security.hash_token(data.refresh_token) + new_refresh_token = security.generate_refresh_token() + new_refresh_hash = security.hash_token(new_refresh_token) + new_refresh_id = uuid4() + new_refresh_expiry = security.get_refresh_token_expiry() + + rotated: dict | None = None + membership: dict | None = None + + async with self.db.transaction() as conn: + db_conn = _as_conn(conn) + refresh_repo = RefreshTokenRepository(db_conn) + rotated = await refresh_repo.rotate( + old_token_hash=old_hash, + new_token_id=new_refresh_id, + new_token_hash=new_refresh_hash, + new_expires_at=new_refresh_expiry, + ) + + if rotated is not None: + org_repo = OrgRepository(db_conn) + membership = await org_repo.get_member(rotated["user_id"], rotated["active_org_id"]) + if membership is None: + raise exc.UnauthorizedError("Invalid refresh token") + + if rotated is None or membership is None: + await self._handle_invalid_refresh(old_hash) + + assert rotated is not None and membership is not None + access_token = security.create_access_token( + sub=str(rotated["user_id"]), + org_id=str(rotated["active_org_id"]), + org_role=membership["role"], + ) + + return TokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + expires_in=self._access_token_expires_in, + ) + + async def switch_org( + self, + current_user: CurrentUser, + data: SwitchOrgRequest, + ) -> TokenResponse: + """Switch active organization (rotates refresh token + issues new JWT).""" + + target_org_id = data.org_id + old_hash = security.hash_token(data.refresh_token) + new_refresh_token = security.generate_refresh_token() + new_refresh_hash = security.hash_token(new_refresh_token) + new_refresh_expiry = security.get_refresh_token_expiry() + + rotated: dict | None = None + membership: dict | None = None + + async with self.db.transaction() as conn: + db_conn = _as_conn(conn) + org_repo = OrgRepository(db_conn) + membership = await org_repo.get_member(current_user.user_id, target_org_id) + if membership is None: + raise exc.ForbiddenError("Not a member of the requested organization") + + refresh_repo = RefreshTokenRepository(db_conn) + rotated = await refresh_repo.rotate( + old_token_hash=old_hash, + new_token_id=uuid4(), + new_token_hash=new_refresh_hash, + new_expires_at=new_refresh_expiry, + new_active_org_id=target_org_id, + expected_user_id=current_user.user_id, + ) + + if rotated is None: + await self._handle_invalid_refresh(old_hash) + + access_token = security.create_access_token( + sub=str(current_user.user_id), + org_id=str(target_org_id), + org_role=membership["role"], + ) + + return TokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + expires_in=self._access_token_expires_in, + ) + + async def logout(self, current_user: CurrentUser, data: LogoutRequest) -> None: + """Revoke the provided refresh token for the current session.""" + + token_hash = security.hash_token(data.refresh_token) + + async with self.db.transaction() as conn: + refresh_repo = RefreshTokenRepository(_as_conn(conn)) + token = await refresh_repo.get_by_hash(token_hash) + if token and token["user_id"] != current_user.user_id: + raise exc.ForbiddenError("Refresh token does not belong to this user") + + if not token: + return + + await refresh_repo.revoke(token["id"]) + + async def _issue_token_pair( + self, + refresh_repo: RefreshTokenRepository, + *, + user_id: UUID, + org_id: UUID, + role: str, + ) -> TokenResponse: + """Create access/refresh tokens and persist the refresh token.""" + + access_token = security.create_access_token( + sub=str(user_id), + org_id=str(org_id), + org_role=role, + ) + + refresh_token = security.generate_refresh_token() + await refresh_repo.create( + token_id=uuid4(), + user_id=user_id, + token_hash=security.hash_token(refresh_token), + active_org_id=org_id, + expires_at=security.get_refresh_token_expiry(), + ) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + expires_in=self._access_token_expires_in, + ) + + async def _handle_invalid_refresh(self, token_hash: str) -> None: + """Raise appropriate errors for invalid/compromised refresh tokens.""" + + async with self.db.connection() as conn: + refresh_repo = RefreshTokenRepository(_as_conn(conn)) + reused = await refresh_repo.check_token_reuse(token_hash) + if reused: + await refresh_repo.revoke_token_chain(reused["id"]) + raise exc.UnauthorizedError("Refresh token reuse detected") + + raise exc.UnauthorizedError("Invalid refresh token") + + async def _generate_unique_org_slug( + self, + org_repo: OrgRepository, + org_name: str, + ) -> str: + """Slugify the org name and append a counter until unique.""" + + base_slug = self._slugify(org_name) + candidate = base_slug + counter = 1 + while await org_repo.slug_exists(candidate): + suffix = f"-{counter}" + max_base_len = 50 - len(suffix) + candidate = f"{base_slug[:max_base_len]}{suffix}" + counter += 1 + return candidate + + def _slugify(self, value: str) -> str: + """Convert arbitrary text into a URL-friendly slug.""" + + slug = _SLUG_PATTERN.sub("-", value.strip().lower()).strip("-") + return slug[:50] or "org" diff --git a/tests/api/helpers.py b/tests/api/helpers.py new file mode 100644 index 0000000..a37c239 --- /dev/null +++ b/tests/api/helpers.py @@ -0,0 +1,65 @@ +"""Shared helpers for API integration tests.""" + +from __future__ import annotations + +from typing import Any +from uuid import UUID, uuid4 + +import asyncpg +from httpx import AsyncClient + +API_PREFIX = "/v1" + + +async def register_user( + client: AsyncClient, + *, + email: str, + password: str, + org_name: str = "Test Org", +) -> dict[str, Any]: + """Call the register endpoint and return JSON body (raises on failure).""" + + response = await client.post( + f"{API_PREFIX}/auth/register", + json={"email": email, "password": password, "org_name": org_name}, + ) + response.raise_for_status() + return response.json() + + +async def create_org( + conn: asyncpg.Connection, + *, + name: str, + slug: str | None = None, +) -> UUID: + """Insert an organization row and return its ID.""" + + org_id = uuid4() + slug_value = slug or name.lower().replace(" ", "-") + await conn.execute( + "INSERT INTO orgs (id, name, slug) VALUES ($1, $2, $3)", + org_id, + name, + slug_value, + ) + return org_id + + +async def add_membership( + conn: asyncpg.Connection, + *, + user_id: UUID, + org_id: UUID, + role: str, +) -> None: + """Insert a membership record for the user/org pair.""" + + await conn.execute( + "INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)", + uuid4(), + user_id, + org_id, + role, + ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py new file mode 100644 index 0000000..2c05d05 --- /dev/null +++ b/tests/api/test_auth.py @@ -0,0 +1,213 @@ +"""Integration tests for FastAPI auth endpoints.""" + +from __future__ import annotations + +from uuid import UUID + +import asyncpg +import pytest +from httpx import AsyncClient + +from app.core import security +from tests.api import helpers + + +pytestmark = pytest.mark.asyncio + +API_PREFIX = "/v1/auth" + + +async def test_register_endpoint_persists_user_and_membership( + api_client: AsyncClient, + db_admin: asyncpg.Connection, +) -> None: + data = await helpers.register_user( + api_client, + email="api-register@example.com", + password="SuperSecret1!", + org_name="API Org", + ) + assert "access_token" in data and "refresh_token" in data + + token_payload = security.decode_access_token(data["access_token"]) + assert token_payload["org_role"] == "admin" + + stored_user = await db_admin.fetchrow("SELECT email FROM users WHERE email = $1", "api-register@example.com") + assert stored_user is not None + + membership = await db_admin.fetchrow( + "SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2", + UUID(token_payload["sub"]), + UUID(token_payload["org_id"]), + ) + assert membership is not None and membership["role"] == "admin" + + +async def test_login_endpoint_rejects_bad_credentials( + api_client: AsyncClient, +) -> None: + register_payload = { + "email": "api-login@example.com", + "password": "CorrectHorse1!", + "org_name": "Login Org", + } + await helpers.register_user(api_client, **register_payload) + + response = await api_client.post( + f"{API_PREFIX}/login", + json={"email": register_payload["email"], "password": "wrong"}, + ) + + assert response.status_code == 401 + + +async def test_refresh_endpoint_rotates_refresh_token( + api_client: AsyncClient, + db_admin: asyncpg.Connection, +) -> None: + register_payload = { + "email": "api-refresh@example.com", + "password": "RefreshPass1!", + "org_name": "Refresh Org", + } + initial = await helpers.register_user(api_client, **register_payload) + + response = await api_client.post( + f"{API_PREFIX}/refresh", + json={"refresh_token": initial["refresh_token"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["refresh_token"] != initial["refresh_token"] + + old_hash = security.hash_token(initial["refresh_token"]) + old_row = await db_admin.fetchrow( + "SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1", + old_hash, + ) + assert old_row is not None and old_row["rotated_to"] is not None + + +async def test_refresh_endpoint_detects_reuse( + api_client: AsyncClient, + db_admin: asyncpg.Connection, +) -> None: + tokens = await helpers.register_user( + api_client, + email="api-reuse@example.com", + password="ReusePass1!", + org_name="Reuse Org", + ) + + rotated = await api_client.post( + f"{API_PREFIX}/refresh", + json={"refresh_token": tokens["refresh_token"]}, + ) + assert rotated.status_code == 200 + + reuse_response = await api_client.post( + f"{API_PREFIX}/refresh", + json={"refresh_token": tokens["refresh_token"]}, + ) + assert reuse_response.status_code == 401 + + old_hash = security.hash_token(tokens["refresh_token"]) + old_row = await db_admin.fetchrow( + "SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1", + old_hash, + ) + assert old_row is not None and old_row["revoked_at"] is not None + + +async def test_switch_org_changes_active_org( + api_client: AsyncClient, + db_admin: asyncpg.Connection, +) -> None: + email = "api-switch@example.com" + register_payload = { + "email": email, + "password": "SwitchPass1!", + "org_name": "Primary Org", + } + tokens = await helpers.register_user(api_client, **register_payload) + + user_id_row = await db_admin.fetchrow("SELECT id FROM users WHERE email = $1", email) + assert user_id_row is not None + user_id = user_id_row["id"] + + target_org_id = await helpers.create_org(db_admin, name="Secondary Org", slug="secondary-org") + await helpers.add_membership(db_admin, user_id=user_id, org_id=target_org_id, role="member") + + response = await api_client.post( + f"{API_PREFIX}/switch-org", + json={"org_id": str(target_org_id), "refresh_token": tokens["refresh_token"]}, + headers={"Authorization": f"Bearer {tokens['access_token']}"}, + ) + + assert response.status_code == 200 + data = response.json() + payload = security.decode_access_token(data["access_token"]) + assert payload["org_id"] == str(target_org_id) + assert payload["org_role"] == "member" + + new_hash = security.hash_token(data["refresh_token"]) + new_row = await db_admin.fetchrow( + "SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1", + new_hash, + ) + assert new_row is not None and new_row["active_org_id"] == target_org_id + + +async def test_switch_org_forbidden_without_membership( + api_client: AsyncClient, + db_admin: asyncpg.Connection, +) -> None: + tokens = await helpers.register_user( + api_client, + email="api-switch-no-access@example.com", + password="SwitchBlock1!", + org_name="Primary", + ) + + foreign_org = await helpers.create_org(db_admin, name="Foreign Org", slug="foreign-org") + + response = await api_client.post( + f"{API_PREFIX}/switch-org", + json={"org_id": str(foreign_org), "refresh_token": tokens["refresh_token"]}, + headers={"Authorization": f"Bearer {tokens['access_token']}"}, + ) + assert response.status_code == 403 + + # ensure refresh token still valid after failed attempt + retry = await api_client.post( + f"{API_PREFIX}/refresh", + json={"refresh_token": tokens["refresh_token"]}, + ) + assert retry.status_code == 200 + + +async def test_logout_revokes_refresh_token( + api_client: AsyncClient, +) -> None: + register_payload = { + "email": "api-logout@example.com", + "password": "LogoutPass1!", + "org_name": "Logout Org", + } + tokens = await helpers.register_user(api_client, **register_payload) + + logout_response = await api_client.post( + f"{API_PREFIX}/logout", + json={"refresh_token": tokens["refresh_token"]}, + headers={"Authorization": f"Bearer {tokens['access_token']}"}, + ) + + assert logout_response.status_code == 204 + + refresh_response = await api_client.post( + f"{API_PREFIX}/refresh", + json={"refresh_token": tokens["refresh_token"]}, + ) + + assert refresh_response.status_code == 401 diff --git a/tests/conftest.py b/tests/conftest.py index 490f7a7..44b23e8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,12 @@ from __future__ import annotations import os -from uuid import uuid4 +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Callable +from uuid import UUID, uuid4 import asyncpg +import httpx import pytest # Set test environment variables before importing app modules @@ -13,6 +16,8 @@ os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@loca os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only") os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1") +from app.main import app + # Module-level setup: create database and run migrations once _db_initialized = False @@ -65,7 +70,7 @@ async def _init_test_db() -> None: @pytest.fixture -async def db_conn() -> asyncpg.Connection: +async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]: """Get a database connection with transaction rollback for test isolation.""" await _init_test_db() @@ -84,12 +89,77 @@ async def db_conn() -> asyncpg.Connection: @pytest.fixture -def make_user_id() -> uuid4: +def make_user_id() -> Callable[[], UUID]: """Factory for generating user IDs.""" return lambda: uuid4() @pytest.fixture -def make_org_id() -> uuid4: +def make_org_id() -> Callable[[], UUID]: """Factory for generating org IDs.""" return lambda: uuid4() + + +TABLES_TO_TRUNCATE = [ + "incident_events", + "notification_attempts", + "incidents", + "notification_targets", + "services", + "refresh_tokens", + "org_members", + "orgs", + "users", +] + + +async def _truncate_all_tables() -> None: + test_dsn = os.environ["DATABASE_URL"] + conn = await asyncpg.connect(test_dsn) + try: + tables = ", ".join(TABLES_TO_TRUNCATE) + await conn.execute(f"TRUNCATE TABLE {tables} CASCADE") + finally: + await conn.close() + + +@pytest.fixture +async def clean_database() -> AsyncGenerator[None, None]: + """Ensure the database is initialized and truncated before/after tests.""" + + await _init_test_db() + await _truncate_all_tables() + yield + await _truncate_all_tables() + + +@asynccontextmanager +async def _lifespan_manager() -> AsyncGenerator[None, None]: + lifespan = app.router.lifespan_context + if lifespan is None: + yield + else: + async with lifespan(app): + yield + + +@pytest.fixture +async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]: + """HTTPX async client bound to the FastAPI app with lifespan support.""" + + async with _lifespan_manager(): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest.fixture +async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]: + """Plain connection for arranging/inspecting API test data (no rollback).""" + + test_dsn = os.environ["DATABASE_URL"] + conn = await asyncpg.connect(test_dsn) + try: + yield conn + finally: + await conn.close() diff --git a/tests/db/test_get_conn.py b/tests/db/test_get_conn.py new file mode 100644 index 0000000..6df975e --- /dev/null +++ b/tests/db/test_get_conn.py @@ -0,0 +1,80 @@ +"""Tests for the get_conn dependency helper.""" + +from __future__ import annotations + +import pytest + +from app.db import db, get_conn + + +pytestmark = pytest.mark.asyncio + + +class _FakeConnection: + def __init__(self, idx: int) -> None: + self.idx = idx + + +class _AcquireContext: + def __init__(self, conn: _FakeConnection, tracker: "_FakePool") -> None: + self._conn = conn + self._tracker = tracker + + async def __aenter__(self) -> _FakeConnection: + self._tracker.active += 1 + return self._conn + + async def __aexit__(self, exc_type, exc, tb) -> None: + self._tracker.active -= 1 + + +class _FakePool: + def __init__(self) -> None: + self.acquire_calls = 0 + self.active = 0 + + def acquire(self) -> _AcquireContext: + conn = _FakeConnection(self.acquire_calls) + self.acquire_calls += 1 + return _AcquireContext(conn, self) + + +async def _collect_single_connection(): + connection = None + async for conn in get_conn(): + connection = conn + return connection + + +async def test_get_conn_reuses_connection_within_scope(): + original_pool = db.pool + fake_pool = _FakePool() + db.pool = fake_pool + try: + captured: list[_FakeConnection] = [] + + async for outer in get_conn(): + captured.append(outer) + async for inner in get_conn(): + captured.append(inner) + + assert len(captured) == 2 + assert captured[0] is captured[1] + assert fake_pool.acquire_calls == 1 + finally: + db.pool = original_pool + + +async def test_get_conn_acquires_new_connection_per_root_scope(): + original_pool = db.pool + fake_pool = _FakePool() + db.pool = fake_pool + try: + first = await _collect_single_connection() + second = await _collect_single_connection() + + assert first is not None and second is not None + assert first is not second + assert fake_pool.acquire_calls == 2 + finally: + db.pool = original_pool diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py new file mode 100644 index 0000000..b325963 --- /dev/null +++ b/tests/services/test_auth_service.py @@ -0,0 +1,260 @@ +"""Unit tests covering AuthService flows.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from app.api.deps import CurrentUser +from app.core import security +from app.db import Database +from app.schemas.auth import ( + LoginRequest, + LogoutRequest, + RefreshRequest, + RegisterRequest, + SwitchOrgRequest, +) +from app.services.auth import AuthService + + +pytestmark = pytest.mark.asyncio + + +class _SingleConnectionDatabase(Database): + """Database stub that reuses a single asyncpg connection.""" + + def __init__(self, conn) -> None: # type: ignore[override] + self._conn = conn + + @asynccontextmanager + async def connection(self): # type: ignore[override] + yield self._conn + + @asynccontextmanager + async def transaction(self): # type: ignore[override] + tr = self._conn.transaction() + await tr.start() + try: + yield self._conn + except Exception: + await tr.rollback() + raise + else: + await tr.commit() + + +@pytest.fixture +async def auth_service(db_conn): + """AuthService bound to the per-test database connection.""" + + return AuthService(database=_SingleConnectionDatabase(db_conn)) + + +async def _create_user(conn, email: str, password: str) -> UUID: + user_id = uuid4() + password_hash = security.hash_password(password) + await conn.execute( + "INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3)", + user_id, + email, + password_hash, + ) + return user_id + + +async def _create_org( + conn, + name: str, + slug: str | None = None, + *, + created_at: datetime | None = None, +) -> UUID: + org_id = uuid4() + slug_value = slug or f"{name.lower().replace(' ', '-')}-{org_id.hex[:6]}" + created = created_at or datetime.now(UTC) + await conn.execute( + "INSERT INTO orgs (id, name, slug, created_at) VALUES ($1, $2, $3, $4)", + org_id, + name, + slug_value, + created, + ) + return org_id + + +async def _add_membership(conn, user_id: UUID, org_id: UUID, role: str) -> None: + await conn.execute( + "INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)", + uuid4(), + user_id, + org_id, + role, + ) + + +async def test_register_user_creates_admin_membership(auth_service, db_conn): + request = RegisterRequest( + email="founder@example.com", + password="SuperSecret1!", + org_name="Founders Inc", + ) + + response = await auth_service.register_user(request) + + payload = security.decode_access_token(response.access_token) + assert payload["org_role"] == "admin" + + user_id = UUID(payload["sub"]) + org_id = UUID(payload["org_id"]) + + user = await db_conn.fetchrow("SELECT email FROM users WHERE id = $1", user_id) + assert user is not None and user["email"] == request.email + + membership = await db_conn.fetchrow( + "SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2", + user_id, + org_id, + ) + assert membership is not None and membership["role"] == "admin" + + refresh_hash = security.hash_token(response.refresh_token) + refresh_row = await db_conn.fetchrow( + "SELECT user_id, active_org_id FROM refresh_tokens WHERE token_hash = $1", + refresh_hash, + ) + assert refresh_row is not None + assert refresh_row["user_id"] == user_id + assert refresh_row["active_org_id"] == org_id + + +async def test_login_user_returns_tokens_for_valid_credentials(auth_service, db_conn): + email = "member@example.com" + password = "Password123!" + user_id = await _create_user(db_conn, email, password) + org_id = await _create_org( + db_conn, + name="Member Org", + slug="member-org", + created_at=datetime.now(UTC) - timedelta(days=1), + ) + await _add_membership(db_conn, user_id, org_id, "member") + + response = await auth_service.login_user(LoginRequest(email=email, password=password)) + + payload = security.decode_access_token(response.access_token) + assert payload["sub"] == str(user_id) + assert payload["org_id"] == str(org_id) + + refresh_hash = security.hash_token(response.refresh_token) + refresh_row = await db_conn.fetchrow( + "SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1", + refresh_hash, + ) + assert refresh_row is not None and refresh_row["active_org_id"] == org_id + + +async def test_refresh_tokens_rotates_existing_token(auth_service, db_conn): + email = "rotate@example.com" + password = "Rotate123!" + user_id = await _create_user(db_conn, email, password) + org_id = await _create_org(db_conn, name="Rotate Org", slug="rotate-org") + await _add_membership(db_conn, user_id, org_id, "member") + + initial = await auth_service.login_user(LoginRequest(email=email, password=password)) + + rotated = await auth_service.refresh_tokens( + RefreshRequest(refresh_token=initial.refresh_token) + ) + + assert rotated.refresh_token != initial.refresh_token + + old_hash = security.hash_token(initial.refresh_token) + old_row = await db_conn.fetchrow( + "SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1", + old_hash, + ) + assert old_row is not None and old_row["rotated_to"] is not None + + new_hash = security.hash_token(rotated.refresh_token) + new_row = await db_conn.fetchrow( + "SELECT user_id FROM refresh_tokens WHERE token_hash = $1", + new_hash, + ) + assert new_row is not None and new_row["user_id"] == user_id + + +async def test_switch_org_updates_active_org(auth_service, db_conn): + email = "switcher@example.com" + password = "Switch123!" + user_id = await _create_user(db_conn, email, password) + + primary_org = await _create_org( + db_conn, + name="Primary Org", + slug="primary-org", + created_at=datetime.now(UTC) - timedelta(days=2), + ) + await _add_membership(db_conn, user_id, primary_org, "member") + + secondary_org = await _create_org( + db_conn, + name="Secondary Org", + slug="secondary-org", + created_at=datetime.now(UTC) - timedelta(days=1), + ) + await _add_membership(db_conn, user_id, secondary_org, "admin") + + initial = await auth_service.login_user(LoginRequest(email=email, password=password)) + current_user = CurrentUser( + user_id=user_id, + email=email, + org_id=primary_org, + org_role="member", + token=initial.access_token, + ) + + switched = await auth_service.switch_org( + current_user, + SwitchOrgRequest(org_id=secondary_org, refresh_token=initial.refresh_token), + ) + + payload = security.decode_access_token(switched.access_token) + assert payload["org_id"] == str(secondary_org) + assert payload["org_role"] == "admin" + + new_hash = security.hash_token(switched.refresh_token) + new_row = await db_conn.fetchrow( + "SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1", + new_hash, + ) + assert new_row is not None and new_row["active_org_id"] == secondary_org + + +async def test_logout_revokes_refresh_token(auth_service, db_conn): + email = "logout@example.com" + password = "Logout123!" + user_id = await _create_user(db_conn, email, password) + org_id = await _create_org(db_conn, name="Logout Org", slug="logout-org") + await _add_membership(db_conn, user_id, org_id, "member") + + initial = await auth_service.login_user(LoginRequest(email=email, password=password)) + current_user = CurrentUser( + user_id=user_id, + email=email, + org_id=org_id, + org_role="member", + token=initial.access_token, + ) + + await auth_service.logout(current_user, LogoutRequest(refresh_token=initial.refresh_token)) + + token_hash = security.hash_token(initial.refresh_token) + row = await db_conn.fetchrow( + "SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1", + token_hash, + ) + assert row is not None and row["revoked_at"] is not None