# incidentops/tests/repositories/test_refresh_token.py
"""Tests for RefreshTokenRepository with security features."""
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4

import asyncpg
import pytest

from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.user import UserRepository
class TestRefreshTokenRepository:
    """Tests for basic RefreshTokenRepository operations."""

    async def _create_user(self, conn: asyncpg.Connection, email: str) -> UUID:
        """Create a user with the given email and return its id."""
        user_repo = UserRepository(conn)
        user_id = uuid4()
        await user_repo.create(user_id, email, "hash")
        return user_id

    async def _create_org(self, conn: asyncpg.Connection, slug: str) -> UUID:
        """Create an org with the given slug and return its id."""
        org_repo = OrgRepository(conn)
        org_id = uuid4()
        await org_repo.create(org_id, f"Org {slug}", slug)
        return org_id

    async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating a refresh token returns the token data including rotated_to."""
        user_id = await self._create_user(db_conn, "token_create@example.com")
        org_id = await self._create_org(db_conn, "token-create-org")
        repo = RefreshTokenRepository(db_conn)

        token_id = uuid4()
        token_hash = "sha256_hashed_token_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)

        assert result["id"] == token_id
        assert result["user_id"] == user_id
        assert result["token_hash"] == token_hash
        assert result["active_org_id"] == org_id
        assert result["expires_at"] is not None
        assert result["revoked_at"] is None
        # A freshly created token has not been rotated yet.
        assert result["rotated_to"] is None
        assert result["created_at"] is not None

    async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
        """Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
        user_id = await self._create_user(db_conn, "unique_hash@example.com")
        org_id = await self._create_org(db_conn, "unique-hash-org")
        repo = RefreshTokenRepository(db_conn)

        token_hash = "duplicate_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        with pytest.raises(asyncpg.UniqueViolationError):
            await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

    async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash returns the token matching the given hash."""
        user_id = await self._create_user(db_conn, "get_hash@example.com")
        org_id = await self._create_org(db_conn, "get-hash-org")
        repo = RefreshTokenRepository(db_conn)

        token_id = uuid4()
        token_hash = "lookup_by_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        await repo.create(token_id, user_id, token_hash, org_id, expires_at)

        result = await repo.get_by_hash(token_hash)

        assert result is not None
        assert result["id"] == token_id
        assert result["token_hash"] == token_hash

    async def test_get_by_hash_returns_revoked_token(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash does not filter out revoked tokens (unlike get_valid_by_hash)."""
        user_id = await self._create_user(db_conn, "get_hash_revoked@example.com")
        org_id = await self._create_org(db_conn, "get-hash-revoked-org")
        repo = RefreshTokenRepository(db_conn)

        token_id = uuid4()
        token_hash = "lookup_revoked_by_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        await repo.create(token_id, user_id, token_hash, org_id, expires_at)
        assert await repo.revoke(token_id) is True

        result = await repo.get_by_hash(token_hash)

        assert result is not None
        assert result["id"] == token_id
        assert result["revoked_at"] is not None

    async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash returns None for non-existent hash."""
        repo = RefreshTokenRepository(db_conn)

        result = await repo.get_by_hash("nonexistent_hash")

        assert result is None
class TestGetValidByHash:
    """Tests for get_valid_by_hash with defense-in-depth validation."""

    async def _setup_token(
        self,
        conn: asyncpg.Connection,
        suffix: str = "",
        expires_in: timedelta = timedelta(days=30),
    ) -> tuple[UUID, UUID, UUID, str, RefreshTokenRepository]:
        """Create a fresh user, org and refresh token.

        Args:
            conn: Database connection shared by all repositories.
            suffix: Appended to the generated token hash (aids debugging).
            expires_in: Offset from now used for the token's ``expires_at``;
                pass a negative timedelta to create an already-expired token.

        Returns:
            ``(token_id, user_id, org_id, token_hash, token_repo)``.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"

        await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + expires_in
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, user_id, org_id, token_hash, token_repo

    async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns token if not expired, not revoked, not rotated."""
        _, _, _, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.get_valid_by_hash(token_hash)

        assert result is not None
        assert result["token_hash"] == token_hash

    async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns None for expired token."""
        _, _, _, token_hash, repo = await self._setup_token(
            db_conn, "_expired", expires_in=timedelta(days=-1)  # already expired
        )

        result = await repo.get_valid_by_hash(token_hash)

        assert result is None

    async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns None for revoked token."""
        token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")

        await repo.revoke(token_id)

        result = await repo.get_valid_by_hash(token_hash)
        assert result is None

    async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns None for an already-rotated token."""
        _, _, _, old_hash, repo = await self._setup_token(db_conn, "_rotated")

        new_hash = f"new_token_{uuid4().hex[:8]}"
        new_expires = datetime.now(UTC) + timedelta(days=30)
        rotated = await repo.rotate(old_hash, uuid4(), new_hash, new_expires)
        # Guard: if rotation silently failed, the None check below would prove nothing.
        assert rotated is not None

        # Old token should no longer be valid...
        assert await repo.get_valid_by_hash(old_hash) is None
        # ...while its replacement is.
        assert await repo.get_valid_by_hash(new_hash) is not None

    async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash validates user_id when provided (defense-in-depth)."""
        _, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check")

        # Correct user_id should work
        result = await repo.get_valid_by_hash(token_hash, user_id=user_id)
        assert result is not None

        # Wrong user_id should return None
        result = await repo.get_valid_by_hash(token_hash, user_id=uuid4())
        assert result is None

    async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
        _, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check")

        # Correct org_id should work
        result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id)
        assert result is not None

        # Wrong org_id should return None
        result = await repo.get_valid_by_hash(token_hash, active_org_id=uuid4())
        assert result is None

    async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash validates both user_id and active_org_id together."""
        _, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both")

        # Both correct should work
        result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id)
        assert result is not None

        # Either wrong should fail
        result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id)
        assert result is None

        result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4())
        assert result is None
class TestAtomicRotation:
    """Tests for atomic token rotation per SPECS.md."""

    @staticmethod
    def _new_expiry() -> datetime:
        """Return an expiry 30 days from now, used for replacement tokens."""
        return datetime.now(UTC) + timedelta(days=30)

    async def _setup_token(
        self,
        conn: asyncpg.Connection,
        expires_in: timedelta = timedelta(days=30),
    ) -> tuple[UUID, UUID, UUID, str, RefreshTokenRepository]:
        """Create a fresh user, org and refresh token.

        Args:
            conn: Database connection shared by all repositories.
            expires_in: Offset from now used for the token's ``expires_at``;
                pass a negative timedelta to create an already-expired token.

        Returns:
            ``(token_id, user_id, org_id, token_hash, token_repo)``.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"rotate_token_{uuid4().hex[:8]}"

        await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + expires_in
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, user_id, org_id, token_hash, token_repo

    async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() creates a new token (inheriting user and org) and returns it."""
        _, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)

        new_id = uuid4()
        new_hash = f"new_rotated_{uuid4().hex[:8]}"

        result = await repo.rotate(old_hash, new_id, new_hash, self._new_expiry())

        assert result is not None
        assert result["id"] == new_id
        assert result["token_hash"] == new_hash
        assert result["user_id"] == user_id
        assert result["active_org_id"] == org_id

    async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
        """rotate() sets rotated_to on the old token (not revoked_at)."""
        _, _, _, old_hash, repo = await self._setup_token(db_conn)

        new_id = uuid4()
        await repo.rotate(old_hash, new_id, f"new_{uuid4().hex[:8]}", self._new_expiry())

        old_token = await repo.get_by_hash(old_hash)
        assert old_token["rotated_to"] == new_id
        assert old_token["revoked_at"] is None  # Not revoked, just rotated

    async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None, and inserts nothing, if the old token does not exist."""
        _, _, _, _, repo = await self._setup_token(db_conn)
        new_hash = f"new_{uuid4().hex[:8]}"

        result = await repo.rotate("nonexistent_hash", uuid4(), new_hash, self._new_expiry())

        assert result is None
        # Atomicity: a rejected rotation must not leave a replacement token behind.
        assert await repo.get_by_hash(new_hash) is None

    async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None, and inserts nothing, if the old token is expired."""
        _, _, _, old_hash, repo = await self._setup_token(
            db_conn, expires_in=timedelta(days=-1)  # already expired
        )
        new_hash = f"new_{uuid4().hex[:8]}"

        result = await repo.rotate(old_hash, uuid4(), new_hash, self._new_expiry())

        assert result is None
        assert await repo.get_by_hash(new_hash) is None

    async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None, and inserts nothing, if the old token is revoked."""
        old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
        await repo.revoke(old_id)
        new_hash = f"new_{uuid4().hex[:8]}"

        result = await repo.rotate(old_hash, uuid4(), new_hash, self._new_expiry())

        assert result is None
        assert await repo.get_by_hash(new_hash) is None

    async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if the old token was already rotated."""
        _, _, _, old_hash, repo = await self._setup_token(db_conn)

        # First rotation should succeed
        first_id = uuid4()
        result1 = await repo.rotate(
            old_hash, first_id, f"new1_{uuid4().hex[:8]}", self._new_expiry()
        )
        assert result1 is not None

        # Second rotation of the same token should fail without side effects.
        second_hash = f"new2_{uuid4().hex[:8]}"
        result2 = await repo.rotate(old_hash, uuid4(), second_hash, self._new_expiry())
        assert result2 is None
        assert await repo.get_by_hash(second_hash) is None

        # The old token must still point at the first replacement.
        old_token = await repo.get_by_hash(old_hash)
        assert old_token["rotated_to"] == first_id

    async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
        """rotate() can change active_org_id (for org-switch flow)."""
        _, _, old_org_id, old_hash, repo = await self._setup_token(db_conn)

        # Create a new org for the user to switch to
        org_repo = OrgRepository(db_conn)
        new_org_id = uuid4()
        await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")

        new_hash = f"switched_{uuid4().hex[:8]}"
        result = await repo.rotate(
            old_hash,
            uuid4(),
            new_hash,
            self._new_expiry(),
            new_active_org_id=new_org_id,  # Switch org
        )

        assert result is not None
        assert result["active_org_id"] == new_org_id
        assert result["active_org_id"] != old_org_id

    async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
        """rotate() fails if expected_user_id doesn't match the token's user."""
        _, user_id, _, old_hash, repo = await self._setup_token(db_conn)

        # Wrong user should fail and insert nothing
        rejected_hash = f"new_{uuid4().hex[:8]}"
        result = await repo.rotate(
            old_hash,
            uuid4(),
            rejected_hash,
            self._new_expiry(),
            expected_user_id=uuid4(),  # Wrong user
        )
        assert result is None
        assert await repo.get_by_hash(rejected_hash) is None

        # The failed attempt must not have consumed the token: the owner can still rotate it.
        result = await repo.rotate(
            old_hash,
            uuid4(),
            f"new_{uuid4().hex[:8]}",
            self._new_expiry(),
            expected_user_id=user_id,  # Correct user
        )
        assert result is not None
        assert result["user_id"] == user_id
class TestTokenReuseDetection:
    """Tests for detecting token reuse (stolen token attacks)."""

    async def _setup_rotated_token(
        self, conn: asyncpg.Connection
    ) -> tuple[UUID, UUID, str, str, RefreshTokenRepository]:
        """Create a token and rotate it once.

        Returns:
            ``(old_id, new_id, old_hash, new_hash, token_repo)``, where
            ``new_*`` describe the replacement issued by the rotation.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)

        old_id = uuid4()
        old_hash = f"old_token_{uuid4().hex[:8]}"
        await token_repo.create(old_id, user_id, old_hash, org_id, expires_at)

        new_id = uuid4()
        new_hash = f"new_token_{uuid4().hex[:8]}"
        rotated = await token_repo.rotate(old_hash, new_id, new_hash, expires_at)
        # Fail here rather than with a confusing assertion in the test body.
        assert rotated is not None, "setup rotation failed"

        return old_id, new_id, old_hash, new_hash, token_repo

    async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns the token if it has been rotated."""
        old_id, new_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)

        result = await repo.check_token_reuse(old_hash)

        assert result is not None
        assert result["id"] == old_id
        assert result["rotated_to"] == new_id

    async def test_check_token_reuse_returns_none_for_replacement_token(
        self, db_conn: asyncpg.Connection
    ) -> None:
        """The replacement issued by a rotation is not itself flagged as reused."""
        _, _, _, new_hash, repo = await self._setup_rotated_token(db_conn)

        result = await repo.check_token_reuse(new_hash)

        assert result is None

    async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns None for a token that hasn't been rotated."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        # Unique suffixes avoid clashing with fixed emails/slugs used elsewhere in this module.
        await user_repo.create(user_id, f"active_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Org", f"active-org-{uuid4().hex[:8]}")

        token_hash = f"active_token_{uuid4().hex[:8]}"
        await repo.create(
            uuid4(), user_id, token_hash, org_id,
            datetime.now(UTC) + timedelta(days=30),
        )

        result = await repo.check_token_reuse(token_hash)

        assert result is None

    async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns None for a non-existent token."""
        repo = RefreshTokenRepository(db_conn)

        result = await repo.check_token_reuse("nonexistent_hash")

        assert result is None
class TestTokenChainRevocation:
    """Tests for revoking entire token chains (breach response)."""

    async def _setup_token_chain(
        self, conn: asyncpg.Connection, chain_length: int = 3
    ) -> tuple[list[UUID], list[str], UUID, RefreshTokenRepository]:
        """Create a chain of ``chain_length`` tokens linked by successive rotations.

        Returns:
            ``(token_ids, token_hashes, user_id, token_repo)``; index 0 is the
            original token and index -1 is the current (un-rotated) leaf.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)

        # Create first token
        first_id = uuid4()
        first_hash = f"chain_token_0_{uuid4().hex[:8]}"
        await token_repo.create(first_id, user_id, first_hash, org_id, expires_at)
        token_ids = [first_id]
        token_hashes = [first_hash]

        # Rotate repeatedly to extend the chain
        current_hash = first_hash
        for i in range(1, chain_length):
            new_id = uuid4()
            new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
            rotated = await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
            # A silent failure here would leave a broken chain and mislead the tests.
            assert rotated is not None, f"rotation {i} failed while building the chain"
            token_ids.append(new_id)
            token_hashes.append(new_hash)
            current_hash = new_hash

        return token_ids, token_hashes, user_id, token_repo

    async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
        """revoke_token_chain leaves no token in the chain usable."""
        token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)

        # Revoke starting from the first (oldest) token
        count = await repo.revoke_token_chain(token_ids[0])

        # Earlier links were already consumed by rotation (rotated_to set,
        # revoked_at still NULL), so whether they are counted is an
        # implementation detail; the leaf must be, and nothing outside the chain.
        assert 1 <= count <= len(token_ids)

        leaf_token = await repo.get_by_hash(token_hashes[-1])
        assert leaf_token["revoked_at"] is not None

        # Breach response: no member of the chain may still be accepted.
        for token_hash in token_hashes:
            assert await repo.get_valid_by_hash(token_hash) is None

    async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
        """revoke_token_chain returns the count of actually revoked tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"single_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Single Org", f"single-org-{uuid4().hex[:8]}")

        token_hash = f"single_chain_token_{uuid4().hex[:8]}"
        token = await repo.create(
            uuid4(), user_id, token_hash, org_id,
            datetime.now(UTC) + timedelta(days=30),
        )

        count = await repo.revoke_token_chain(token["id"])

        assert count == 1
        assert (await repo.get_by_hash(token_hash))["revoked_at"] is not None
class TestTokenRevocation:
    """Tests for token revocation methods."""

    async def _setup_token(
        self, conn: asyncpg.Connection
    ) -> tuple[UUID, str, RefreshTokenRepository]:
        """Create a fresh user, org and active refresh token.

        Returns:
            ``(token_id, token_hash, token_repo)``.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"revoke_token_{uuid4().hex[:8]}"

        await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, token_hash, token_repo

    async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
        """revoke() sets the revoked_at timestamp."""
        token_id, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.revoke(token_id)

        assert result is True
        token = await repo.get_by_hash(token_hash)
        assert token["revoked_at"] is not None

    async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns True when the token is revoked."""
        token_id, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke(token_id)

        assert result is True

    async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns False if the token is already revoked."""
        token_id, _, repo = await self._setup_token(db_conn)

        await repo.revoke(token_id)
        result = await repo.revoke(token_id)

        assert result is False

    async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns False for a non-existent token."""
        _, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke(uuid4())

        assert result is False

    async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
        """revoke_by_hash() revokes a token by its hash value."""
        _, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.revoke_by_hash(token_hash)

        assert result is True
        token = await repo.get_by_hash(token_hash)
        assert token["revoked_at"] is not None

    async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """revoke_by_hash() returns False for a non-existent hash."""
        _, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke_by_hash("nonexistent_hash")

        assert result is False
class TestRevokeAllForUser:
    """Tests for revoking all tokens for a user."""

    async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() revokes every token for the user and returns the count."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "multi_token@example.com", "hash")
        await org_repo.create(org_id, "Multi Token Org", "multi-token-org")

        # Create multiple tokens sharing one (loop-invariant) expiry
        expires_at = datetime.now(UTC) + timedelta(days=30)
        hashes = [f"token_{i}_{uuid4().hex[:8]}" for i in range(3)]
        for token_hash in hashes:
            await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        result = await token_repo.revoke_all_for_user(user_id)

        assert result == len(hashes)
        for token_hash in hashes:
            assert await token_repo.get_valid_by_hash(token_hash) is None

    async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() returns 0 if the user has no tokens."""
        user_repo = UserRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        await user_repo.create(user_id, "no_tokens@example.com", "hash")

        result = await token_repo.revoke_all_for_user(user_id)

        assert result == 0

    async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() doesn't affect other users' tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user1 = uuid4()
        user2 = uuid4()
        org_id = uuid4()
        await user_repo.create(user1, "user1@example.com", "hash")
        await user_repo.create(user2, "user2@example.com", "hash")
        await org_repo.create(org_id, "Shared Org", "shared-org")

        user1_hash = f"user1_token_{uuid4().hex[:8]}"
        user2_hash = f"user2_token_{uuid4().hex[:8]}"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
        await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)

        revoked = await token_repo.revoke_all_for_user(user1)

        # Only user1's single token may be counted.
        assert revoked == 1
        # User1's token is revoked
        assert await token_repo.get_valid_by_hash(user1_hash) is None
        # User2's token is still valid
        assert await token_repo.get_valid_by_hash(user2_hash) is not None
class TestRevokeAllExcept:
    """Tests for revoking every session of a user except the current one."""

    async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user_except() keeps the specified token active."""
        tokens = RefreshTokenRepository(db_conn)

        user_id, org_id = uuid4(), uuid4()
        await UserRepository(db_conn).create(user_id, "except@example.com", "hash")
        await OrgRepository(db_conn).create(org_id, "Except Org", "except-org")

        expiry = datetime.now(UTC) + timedelta(days=30)

        # The "current" session that must survive the bulk revocation.
        current_id = uuid4()
        current_hash = f"keep_token_{uuid4().hex[:8]}"
        await tokens.create(current_id, user_id, current_hash, org_id, expiry)

        # Two sibling sessions that should be revoked.
        sibling_hashes = []
        for n in range(2):
            sibling_hash = f"other_token_{n}_{uuid4().hex[:8]}"
            sibling_hashes.append(sibling_hash)
            await tokens.create(uuid4(), user_id, sibling_hash, org_id, expiry)

        revoked_count = await tokens.revoke_all_for_user_except(user_id, current_id)
        assert revoked_count == 2  # exactly the two siblings

        survivor = await tokens.get_valid_by_hash(current_hash)
        assert survivor is not None

        for sibling_hash in sibling_hashes:
            assert await tokens.get_valid_by_hash(sibling_hash) is None
class TestActiveTokensForUser:
    """Tests for listing active tokens for a user."""

    async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
        """get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active_list@example.com", "hash")
        await org_repo.create(org_id, "Active List Org", "active-list-org")

        future_expiry = datetime.now(UTC) + timedelta(days=30)
        past_expiry = datetime.now(UTC) - timedelta(days=1)

        # Active token: must be listed
        active_hash = f"active_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, active_hash, org_id, future_expiry)

        # Revoked token: must not be listed
        revoked_id = uuid4()
        revoked_hash = f"revoked_{uuid4().hex[:8]}"
        await token_repo.create(revoked_id, user_id, revoked_hash, org_id, future_expiry)
        await token_repo.revoke(revoked_id)

        # Expired token: must not be listed
        expired_hash = f"expired_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, expired_hash, org_id, past_expiry)

        # Rotated token: the old one must not be listed, its replacement must be
        rotated_hash = f"rotated_{uuid4().hex[:8]}"
        replacement_hash = f"new_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, rotated_hash, org_id, future_expiry)
        replacement = await token_repo.rotate(rotated_hash, uuid4(), replacement_hash, future_expiry)
        assert replacement is not None

        result = await token_repo.get_active_tokens_for_user(user_id)

        # Exactly the active token and the rotation's replacement, no duplicates.
        assert len(result) == 2
        hashes = {t["token_hash"] for t in result}
        assert hashes == {active_hash, replacement_hash}
        assert revoked_hash not in hashes
        assert expired_hash not in hashes
        assert rotated_hash not in hashes
class TestTokenForeignKeys:
    """Foreign-key constraints on the refresh_tokens table."""

    async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """Inserting a token whose user_id has no matching user is rejected."""
        tokens = RefreshTokenRepository(db_conn)
        org_id = uuid4()
        await OrgRepository(db_conn).create(org_id, "FK Test Org", "fk-test-org")

        missing_user_id = uuid4()
        expiry = datetime.now(UTC) + timedelta(days=30)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await tokens.create(uuid4(), missing_user_id, "orphan_token", org_id, expiry)

    async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """Inserting a token whose active_org_id has no matching org is rejected."""
        tokens = RefreshTokenRepository(db_conn)
        user_id = uuid4()
        await UserRepository(db_conn).create(user_id, "fk_org_test@example.com", "hash")

        missing_org_id = uuid4()
        expiry = datetime.now(UTC) + timedelta(days=30)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await tokens.create(uuid4(), user_id, "orphan_org_token", missing_org_id, expiry)

    async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
        """The stored token keeps the active_org_id it was created with (SPECS.md org context)."""
        tokens = RefreshTokenRepository(db_conn)
        user_id, org_id = uuid4(), uuid4()
        await UserRepository(db_conn).create(user_id, "active_org@example.com", "hash")
        await OrgRepository(db_conn).create(org_id, "Active Org", "active-org")

        token_hash = f"active_org_token_{uuid4().hex[:8]}"
        expiry = datetime.now(UTC) + timedelta(days=30)
        await tokens.create(uuid4(), user_id, token_hash, org_id, expiry)

        stored = await tokens.get_by_hash(token_hash)
        assert stored["active_org_id"] == org_id