261 lines
8.0 KiB
Python
261 lines
8.0 KiB
Python
|
|
"""Unit tests covering AuthService flows."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from datetime import UTC, datetime, timedelta
|
||
|
|
from uuid import UUID, uuid4
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.api.deps import CurrentUser
|
||
|
|
from app.core import security
|
||
|
|
from app.db import Database
|
||
|
|
from app.schemas.auth import (
|
||
|
|
LoginRequest,
|
||
|
|
LogoutRequest,
|
||
|
|
RefreshRequest,
|
||
|
|
RegisterRequest,
|
||
|
|
SwitchOrgRequest,
|
||
|
|
)
|
||
|
|
from app.services.auth import AuthService
|
||
|
|
|
||
|
|
|
||
|
|
# Apply the asyncio marker to every test collected from this module.
pytestmark = pytest.mark.asyncio
|
||
|
|
|
||
|
|
|
||
|
|
class _SingleConnectionDatabase(Database):
    """Database stub that reuses a single asyncpg connection.

    Both ``connection()`` and ``transaction()`` hand back the very same
    connection object that was given to the constructor, so the code under
    test reads and writes through the per-test connection.
    """

    def __init__(self, conn) -> None:  # type: ignore[override]
        # Database.__init__ is not called; only the wrapped connection is kept.
        self._conn = conn

    @asynccontextmanager
    async def connection(self):  # type: ignore[override]
        """Yield the wrapped connection unchanged."""
        yield self._conn

    @asynccontextmanager
    async def transaction(self):  # type: ignore[override]
        """Yield the wrapped connection inside a transaction.

        The transaction is committed when the caller's ``async with`` body
        finishes normally and rolled back when it raises.
        """
        # Delegating to the transaction's own async context manager rolls
        # back on *any* exception, including BaseException subclasses such
        # as asyncio.CancelledError, which a manual ``except Exception``
        # would let through without a rollback or a commit.
        async with self._conn.transaction():
            yield self._conn
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
async def auth_service(db_conn):
    """Provide an AuthService whose database wraps the per-test connection."""
    single_connection_db = _SingleConnectionDatabase(db_conn)
    service = AuthService(database=single_connection_db)
    return service
|
||
|
|
|
||
|
|
|
||
|
|
async def _create_user(conn, email: str, password: str) -> UUID:
    """Insert a user row with a hashed password and return the new user's id."""
    new_user_id = uuid4()
    hashed = security.hash_password(password)
    insert_sql = "INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3)"
    await conn.execute(insert_sql, new_user_id, email, hashed)
    return new_user_id
|
||
|
|
|
||
|
|
|
||
|
|
async def _create_org(
|
||
|
|
conn,
|
||
|
|
name: str,
|
||
|
|
slug: str | None = None,
|
||
|
|
*,
|
||
|
|
created_at: datetime | None = None,
|
||
|
|
) -> UUID:
|
||
|
|
org_id = uuid4()
|
||
|
|
slug_value = slug or f"{name.lower().replace(' ', '-')}-{org_id.hex[:6]}"
|
||
|
|
created = created_at or datetime.now(UTC)
|
||
|
|
await conn.execute(
|
||
|
|
"INSERT INTO orgs (id, name, slug, created_at) VALUES ($1, $2, $3, $4)",
|
||
|
|
org_id,
|
||
|
|
name,
|
||
|
|
slug_value,
|
||
|
|
created,
|
||
|
|
)
|
||
|
|
return org_id
|
||
|
|
|
||
|
|
|
||
|
|
async def _add_membership(conn, user_id: UUID, org_id: UUID, role: str) -> None:
    """Insert an org_members row linking the user to the org with the given role."""
    membership_id = uuid4()
    insert_sql = "INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)"
    await conn.execute(insert_sql, membership_id, user_id, org_id, role)
|
||
|
|
|
||
|
|
|
||
|
|
async def test_register_user_creates_admin_membership(auth_service, db_conn):
    """Registering stores the user, an admin membership, and a refresh token row."""
    request = RegisterRequest(
        email="founder@example.com",
        password="SuperSecret1!",
        org_name="Founders Inc",
    )

    tokens = await auth_service.register_user(request)

    claims = security.decode_access_token(tokens.access_token)
    assert claims["org_role"] == "admin"

    registered_user_id = UUID(claims["sub"])
    registered_org_id = UUID(claims["org_id"])

    # The user row exists and carries the requested email.
    user_row = await db_conn.fetchrow(
        "SELECT email FROM users WHERE id = $1",
        registered_user_id,
    )
    assert user_row is not None
    assert user_row["email"] == request.email

    # The registering user is an admin of the newly created org.
    member_row = await db_conn.fetchrow(
        "SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
        registered_user_id,
        registered_org_id,
    )
    assert member_row is not None
    assert member_row["role"] == "admin"

    # The issued refresh token is stored by hash and tied to the user and org.
    token_row = await db_conn.fetchrow(
        "SELECT user_id, active_org_id FROM refresh_tokens WHERE token_hash = $1",
        security.hash_token(tokens.refresh_token),
    )
    assert token_row is not None
    assert token_row["user_id"] == registered_user_id
    assert token_row["active_org_id"] == registered_org_id
|
||
|
|
|
||
|
|
|
||
|
|
async def test_login_user_returns_tokens_for_valid_credentials(auth_service, db_conn):
    """Valid credentials yield tokens scoped to the user's org."""
    email, password = "member@example.com", "Password123!"

    user_id = await _create_user(db_conn, email, password)
    one_day_ago = datetime.now(UTC) - timedelta(days=1)
    org_id = await _create_org(
        db_conn,
        name="Member Org",
        slug="member-org",
        created_at=one_day_ago,
    )
    await _add_membership(db_conn, user_id, org_id, "member")

    credentials = LoginRequest(email=email, password=password)
    tokens = await auth_service.login_user(credentials)

    claims = security.decode_access_token(tokens.access_token)
    assert claims["sub"] == str(user_id)
    assert claims["org_id"] == str(org_id)

    # The refresh token is stored by hash with the org as its active org.
    stored_token = await db_conn.fetchrow(
        "SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
        security.hash_token(tokens.refresh_token),
    )
    assert stored_token is not None
    assert stored_token["active_org_id"] == org_id
|
||
|
|
|
||
|
|
|
||
|
|
async def test_refresh_tokens_rotates_existing_token(auth_service, db_conn):
    """Refreshing issues a new refresh token and marks the old one as rotated."""
    email, password = "rotate@example.com", "Rotate123!"

    user_id = await _create_user(db_conn, email, password)
    org_id = await _create_org(db_conn, name="Rotate Org", slug="rotate-org")
    await _add_membership(db_conn, user_id, org_id, "member")

    credentials = LoginRequest(email=email, password=password)
    first_tokens = await auth_service.login_user(credentials)

    refresh_request = RefreshRequest(refresh_token=first_tokens.refresh_token)
    second_tokens = await auth_service.refresh_tokens(refresh_request)

    assert second_tokens.refresh_token != first_tokens.refresh_token

    # The original token's row now has a non-null rotated_to value.
    previous_row = await db_conn.fetchrow(
        "SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
        security.hash_token(first_tokens.refresh_token),
    )
    assert previous_row is not None
    assert previous_row["rotated_to"] is not None

    # The replacement token is stored and belongs to the same user.
    replacement_row = await db_conn.fetchrow(
        "SELECT user_id FROM refresh_tokens WHERE token_hash = $1",
        security.hash_token(second_tokens.refresh_token),
    )
    assert replacement_row is not None
    assert replacement_row["user_id"] == user_id
|
||
|
|
|
||
|
|
|
||
|
|
async def test_switch_org_updates_active_org(auth_service, db_conn):
    """Switching orgs yields tokens scoped to the target org and its role."""
    email, password = "switcher@example.com", "Switch123!"
    user_id = await _create_user(db_conn, email, password)

    # Two orgs with distinct creation times; the user is a member of the
    # older one and an admin of the newer one.
    primary_created = datetime.now(UTC) - timedelta(days=2)
    primary_org = await _create_org(
        db_conn,
        name="Primary Org",
        slug="primary-org",
        created_at=primary_created,
    )
    await _add_membership(db_conn, user_id, primary_org, "member")

    secondary_created = datetime.now(UTC) - timedelta(days=1)
    secondary_org = await _create_org(
        db_conn,
        name="Secondary Org",
        slug="secondary-org",
        created_at=secondary_created,
    )
    await _add_membership(db_conn, user_id, secondary_org, "admin")

    credentials = LoginRequest(email=email, password=password)
    login_tokens = await auth_service.login_user(credentials)
    acting_user = CurrentUser(
        user_id=user_id,
        email=email,
        org_id=primary_org,
        org_role="member",
        token=login_tokens.access_token,
    )

    switch_request = SwitchOrgRequest(
        org_id=secondary_org,
        refresh_token=login_tokens.refresh_token,
    )
    switched_tokens = await auth_service.switch_org(acting_user, switch_request)

    claims = security.decode_access_token(switched_tokens.access_token)
    assert claims["org_id"] == str(secondary_org)
    assert claims["org_role"] == "admin"

    # The new refresh token is stored with the secondary org as active.
    stored_token = await db_conn.fetchrow(
        "SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
        security.hash_token(switched_tokens.refresh_token),
    )
    assert stored_token is not None
    assert stored_token["active_org_id"] == secondary_org
|
||
|
|
|
||
|
|
|
||
|
|
async def test_logout_revokes_refresh_token(auth_service, db_conn):
    """Logging out stamps revoked_at on the presented refresh token's row."""
    email, password = "logout@example.com", "Logout123!"

    user_id = await _create_user(db_conn, email, password)
    org_id = await _create_org(db_conn, name="Logout Org", slug="logout-org")
    await _add_membership(db_conn, user_id, org_id, "member")

    credentials = LoginRequest(email=email, password=password)
    login_tokens = await auth_service.login_user(credentials)
    acting_user = CurrentUser(
        user_id=user_id,
        email=email,
        org_id=org_id,
        org_role="member",
        token=login_tokens.access_token,
    )

    logout_request = LogoutRequest(refresh_token=login_tokens.refresh_token)
    await auth_service.logout(acting_user, logout_request)

    revoked_row = await db_conn.fetchrow(
        "SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
        security.hash_token(login_tokens.refresh_token),
    )
    assert revoked_row is not None
    assert revoked_row["revoked_at"] is not None
|