"""Tests for RefreshTokenRepository with security features.

Covers creation, lookup, defense-in-depth validation, atomic rotation,
reuse detection, chain revocation and bulk revocation. Every test takes a
``db_conn`` fixture (defined outside this file) that supplies an
``asyncpg.Connection``.
"""

from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4

import asyncpg
import pytest

from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.user import UserRepository


class TestRefreshTokenRepository:
    """Tests for basic RefreshTokenRepository operations."""

    async def _create_user(self, conn: asyncpg.Connection, email: str) -> UUID:
        """Create a user with ``email`` and return the new user's id."""
        user_repo = UserRepository(conn)
        user_id = uuid4()
        await user_repo.create(user_id, email, "hash")
        return user_id

    async def _create_org(self, conn: asyncpg.Connection, slug: str) -> UUID:
        """Create an org with ``slug`` and return the new org's id."""
        org_repo = OrgRepository(conn)
        org_id = uuid4()
        await org_repo.create(org_id, f"Org {slug}", slug)
        return org_id

    async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating a refresh token returns the token data including rotated_to."""
        user_id = await self._create_user(db_conn, "token_create@example.com")
        org_id = await self._create_org(db_conn, "token-create-org")
        repo = RefreshTokenRepository(db_conn)

        token_id = uuid4()
        token_hash = "sha256_hashed_token_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)

        assert result["id"] == token_id
        assert result["user_id"] == user_id
        assert result["token_hash"] == token_hash
        assert result["active_org_id"] == org_id
        assert result["expires_at"] is not None
        assert result["revoked_at"] is None
        assert result["rotated_to"] is None  # a freshly created token has not been rotated
        assert result["created_at"] is not None

    async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
        """Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
        user_id = await self._create_user(db_conn, "unique_hash@example.com")
        org_id = await self._create_org(db_conn, "unique-hash-org")
        repo = RefreshTokenRepository(db_conn)

        token_hash = "duplicate_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        with pytest.raises(asyncpg.UniqueViolationError):
            await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

    async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash returns the correct token (even if revoked/expired)."""
        user_id = await self._create_user(db_conn, "get_hash@example.com")
        org_id = await self._create_org(db_conn, "get-hash-org")
        repo = RefreshTokenRepository(db_conn)

        token_id = uuid4()
        token_hash = "lookup_by_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        await repo.create(token_id, user_id, token_hash, org_id, expires_at)

        result = await repo.get_by_hash(token_hash)

        assert result is not None
        assert result["id"] == token_id
        assert result["token_hash"] == token_hash

    async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash returns None for non-existent hash."""
        repo = RefreshTokenRepository(db_conn)

        result = await repo.get_by_hash("nonexistent_hash")

        assert result is None


class TestGetValidByHash:
    """Tests for get_valid_by_hash with defense-in-depth validation."""

    async def _setup_token(
        self, conn: asyncpg.Connection, suffix: str = ""
    ) -> tuple[UUID, UUID, UUID, str, RefreshTokenRepository]:
        """Create a fresh user, org and 30-day token.

        Returns ``(token_id, user_id, org_id, token_hash, token_repo)``.
        Random suffixes keep emails, slugs and hashes unique between tests.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"

        await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, user_id, org_id, token_hash, token_repo

    async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns token if not expired, not revoked, not rotated."""
        _, _, _, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.get_valid_by_hash(token_hash)

        assert result is not None
        assert result["token_hash"] == token_hash

    async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns None for expired token."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "expired@example.com", "hash")
        await org_repo.create(org_id, "Org", "expired-org")

        token_hash = "expired_token_hash"
        expires_at = datetime.now(UTC) - timedelta(days=1)  # already expired
        await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        result = await repo.get_valid_by_hash(token_hash)

        assert result is None

    async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns None for revoked token."""
        token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")
        await repo.revoke(token_id)

        result = await repo.get_valid_by_hash(token_hash)

        assert result is None

    async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash returns None for already-rotated token."""
        _, _, _, old_hash, repo = await self._setup_token(db_conn, "_rotated")

        new_hash = f"new_token_{uuid4().hex[:8]}"
        new_expires = datetime.now(UTC) + timedelta(days=30)
        await repo.rotate(old_hash, uuid4(), new_hash, new_expires)

        # Once rotated, the old token must no longer be accepted.
        result = await repo.get_valid_by_hash(old_hash)
        assert result is None

    async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash validates user_id when provided (defense-in-depth)."""
        _, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check")

        # Correct user_id should work.
        result = await repo.get_valid_by_hash(token_hash, user_id=user_id)
        assert result is not None

        # Wrong user_id should return None.
        result = await repo.get_valid_by_hash(token_hash, user_id=uuid4())
        assert result is None

    async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
        _, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check")

        # Correct org_id should work.
        result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id)
        assert result is not None

        # Wrong org_id should return None.
        result = await repo.get_valid_by_hash(token_hash, active_org_id=uuid4())
        assert result is None

    async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
        """get_valid_by_hash validates both user_id and active_org_id together."""
        _, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both")

        # Both correct should work.
        result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id)
        assert result is not None

        # Either one wrong should fail.
        result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id)
        assert result is None

        result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4())
        assert result is None


class TestAtomicRotation:
    """Tests for atomic token rotation per SPECS.md."""

    async def _setup_token(
        self, conn: asyncpg.Connection
    ) -> tuple[UUID, UUID, UUID, str, RefreshTokenRepository]:
        """Create a fresh user, org and 30-day token.

        Returns ``(token_id, user_id, org_id, token_hash, token_repo)``.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"rotate_token_{uuid4().hex[:8]}"

        await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, user_id, org_id, token_hash, token_repo

    async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() creates a new token and returns it."""
        _, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)

        new_id = uuid4()
        new_hash = f"new_rotated_{uuid4().hex[:8]}"
        new_expires = datetime.now(UTC) + timedelta(days=30)

        result = await repo.rotate(old_hash, new_id, new_hash, new_expires)

        assert result is not None
        assert result["id"] == new_id
        assert result["token_hash"] == new_hash
        assert result["user_id"] == user_id
        assert result["active_org_id"] == org_id

    async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
        """rotate() sets rotated_to on the old token (not revoked_at)."""
        _, _, _, old_hash, repo = await self._setup_token(db_conn)

        new_id = uuid4()
        new_hash = f"new_{uuid4().hex[:8]}"
        new_expires = datetime.now(UTC) + timedelta(days=30)

        await repo.rotate(old_hash, new_id, new_hash, new_expires)

        old_token = await repo.get_by_hash(old_hash)
        assert old_token is not None
        assert old_token["rotated_to"] == new_id
        assert old_token["revoked_at"] is None  # rotated, not revoked

    async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token is invalid."""
        _, _, _, _, repo = await self._setup_token(db_conn)

        result = await repo.rotate(
            "nonexistent_hash",
            uuid4(),
            f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30),
        )

        assert result is None

    async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token is expired."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "exp_rotate@example.com", "hash")
        await org_repo.create(org_id, "Org", "exp-rotate-org")

        old_hash = "expired_for_rotation"
        await repo.create(
            uuid4(), user_id, old_hash, org_id, datetime.now(UTC) - timedelta(days=1)  # already expired
        )

        result = await repo.rotate(
            old_hash, uuid4(), f"new_{uuid4().hex[:8]}", datetime.now(UTC) + timedelta(days=30)
        )

        assert result is None

    async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token is revoked."""
        old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
        await repo.revoke(old_id)

        result = await repo.rotate(
            old_hash, uuid4(), f"new_{uuid4().hex[:8]}", datetime.now(UTC) + timedelta(days=30)
        )

        assert result is None

    async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token was already rotated."""
        _, _, _, old_hash, repo = await self._setup_token(db_conn)

        # First rotation should succeed.
        result1 = await repo.rotate(
            old_hash, uuid4(), f"new1_{uuid4().hex[:8]}", datetime.now(UTC) + timedelta(days=30)
        )
        assert result1 is not None

        # A second rotation of the same token should fail.
        result2 = await repo.rotate(
            old_hash, uuid4(), f"new2_{uuid4().hex[:8]}", datetime.now(UTC) + timedelta(days=30)
        )
        assert result2 is None

    async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
        """rotate() can change active_org_id (for org-switch flow)."""
        _, _, old_org_id, old_hash, repo = await self._setup_token(db_conn)

        # Create a second org for the user to switch to.
        org_repo = OrgRepository(db_conn)
        new_org_id = uuid4()
        await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")

        new_hash = f"switched_{uuid4().hex[:8]}"
        result = await repo.rotate(
            old_hash,
            uuid4(),
            new_hash,
            datetime.now(UTC) + timedelta(days=30),
            new_active_org_id=new_org_id,
        )

        assert result is not None
        assert result["active_org_id"] == new_org_id
        assert result["active_org_id"] != old_org_id

    async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
        """rotate() fails if expected_user_id doesn't match token's user."""
        _, user_id, _, old_hash, repo = await self._setup_token(db_conn)

        # A mismatched user must not be able to rotate the token.
        result = await repo.rotate(
            old_hash,
            uuid4(),
            f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30),
            expected_user_id=uuid4(),
        )
        assert result is None

        # The failed attempt must not consume the token: the owner can still rotate it.
        result = await repo.rotate(
            old_hash,
            uuid4(),
            f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30),
            expected_user_id=user_id,
        )
        assert result is not None


class TestTokenReuseDetection:
    """Tests for detecting token reuse (stolen token attacks)."""

    async def _setup_rotated_token(
        self, conn: asyncpg.Connection
    ) -> tuple[UUID, str, str, RefreshTokenRepository]:
        """Create a token, rotate it once and return ``(old_id, old_hash, new_hash, repo)``."""
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")

        old_hash = f"old_token_{uuid4().hex[:8]}"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at)

        new_hash = f"new_token_{uuid4().hex[:8]}"
        rotated = await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at)
        # Fail here, in the fixture, if setup rotation did not happen; otherwise
        # the tests below would fail with a misleading message.
        assert rotated is not None

        return old_token["id"], old_hash, new_hash, token_repo

    async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns token if it has been rotated."""
        old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)

        result = await repo.check_token_reuse(old_hash)

        assert result is not None
        assert result["id"] == old_id
        assert result["rotated_to"] is not None

    async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns None for token that hasn't been rotated."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active@example.com", "hash")
        await org_repo.create(org_id, "Org", "active-org")

        token_hash = "active_token_hash"
        await repo.create(uuid4(), user_id, token_hash, org_id, datetime.now(UTC) + timedelta(days=30))

        result = await repo.check_token_reuse(token_hash)

        assert result is None

    async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns None for non-existent token."""
        repo = RefreshTokenRepository(db_conn)

        result = await repo.check_token_reuse("nonexistent_hash")

        assert result is None


class TestTokenChainRevocation:
    """Tests for revoking entire token chains (breach response)."""

    async def _setup_token_chain(
        self, conn: asyncpg.Connection, chain_length: int = 3
    ) -> tuple[list[UUID], list[str], UUID, RefreshTokenRepository]:
        """Create a chain of ``chain_length`` tokens linked by successive rotations.

        Returns ``(token_ids, token_hashes, user_id, repo)``, with both lists
        ordered from the original token to the current (leaf) token.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)

        first_hash = f"chain_token_0_{uuid4().hex[:8]}"
        first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at)
        token_ids: list[UUID] = [first_token["id"]]
        token_hashes: list[str] = [first_hash]

        # Each rotation consumes the current leaf and appends a new one.
        current_hash = first_hash
        for i in range(1, chain_length):
            new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
            new_id = uuid4()
            rotated = await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
            assert rotated is not None, f"setup rotation {i} failed"
            token_ids.append(new_id)
            token_hashes.append(new_hash)
            current_hash = new_hash

        return token_ids, token_hashes, user_id, token_repo

    async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
        """revoke_token_chain revokes the token and all its rotations."""
        token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)

        count = await repo.revoke_token_chain(token_ids[0])

        # Earlier tokens were already consumed by rotation, so whether they are
        # counted depends on the repository; the leaf must be counted, and the
        # count can never exceed the number of tokens in the chain.
        assert 1 <= count <= len(token_ids)

        leaf_token = await repo.get_by_hash(token_hashes[-1])
        assert leaf_token is not None
        assert leaf_token["revoked_at"] is not None

        # No token anywhere in the chain may still be usable.
        for token_hash in token_hashes:
            assert await repo.get_valid_by_hash(token_hash) is None

    async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
        """revoke_token_chain returns count of actually revoked tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "single@example.com", "hash")
        await org_repo.create(org_id, "Single Org", "single-org")

        token_hash = "single_chain_token"
        token = await repo.create(uuid4(), user_id, token_hash, org_id, datetime.now(UTC) + timedelta(days=30))

        count = await repo.revoke_token_chain(token["id"])

        assert count == 1


class TestTokenRevocation:
    """Tests for token revocation methods."""

    async def _setup_token(self, conn: asyncpg.Connection) -> tuple[UUID, str, RefreshTokenRepository]:
        """Create a fresh user, org and 30-day token; return ``(token_id, token_hash, repo)``."""
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"revoke_token_{uuid4().hex[:8]}"

        await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, token_hash, token_repo

    async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
        """revoke() sets the revoked_at timestamp."""
        token_id, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.revoke(token_id)

        assert result is True
        token = await repo.get_by_hash(token_hash)
        assert token is not None
        assert token["revoked_at"] is not None

    async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns True when token is revoked."""
        token_id, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke(token_id)

        assert result is True

    async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns False if token already revoked."""
        token_id, _, repo = await self._setup_token(db_conn)
        await repo.revoke(token_id)

        result = await repo.revoke(token_id)

        assert result is False

    async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns False for non-existent token."""
        _, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke(uuid4())

        assert result is False

    async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
        """revoke_by_hash() revokes token by hash value."""
        _, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.revoke_by_hash(token_hash)

        assert result is True
        token = await repo.get_by_hash(token_hash)
        assert token is not None
        assert token["revoked_at"] is not None

    async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """revoke_by_hash() returns False for non-existent hash."""
        _, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke_by_hash("nonexistent_hash")

        assert result is False


class TestRevokeAllForUser:
    """Tests for revoking all tokens for a user."""

    async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() revokes all tokens for the user."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "multi_token@example.com", "hash")
        await org_repo.create(org_id, "Multi Token Org", "multi-token-org")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        hashes = [f"token_{i}_{uuid4().hex[:8]}" for i in range(3)]
        for token_hash in hashes:
            await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        result = await token_repo.revoke_all_for_user(user_id)

        assert result == 3
        for token_hash in hashes:
            assert await token_repo.get_valid_by_hash(token_hash) is None

    async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() returns 0 if user has no tokens."""
        user_repo = UserRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        await user_repo.create(user_id, "no_tokens@example.com", "hash")

        result = await token_repo.revoke_all_for_user(user_id)

        assert result == 0

    async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() doesn't affect other users' tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user1 = uuid4()
        user2 = uuid4()
        org_id = uuid4()
        await user_repo.create(user1, "user1@example.com", "hash")
        await user_repo.create(user2, "user2@example.com", "hash")
        await org_repo.create(org_id, "Shared Org", "shared-org")

        user1_hash = f"user1_token_{uuid4().hex[:8]}"
        user2_hash = f"user2_token_{uuid4().hex[:8]}"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
        await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)

        await token_repo.revoke_all_for_user(user1)

        # user1's token is revoked ...
        assert await token_repo.get_valid_by_hash(user1_hash) is None
        # ... while user2's token is untouched.
        assert await token_repo.get_valid_by_hash(user2_hash) is not None


class TestRevokeAllExcept:
    """Tests for revoking all tokens except current session."""

    async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user_except() keeps the specified token active."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "except@example.com", "hash")
        await org_repo.create(org_id, "Except Org", "except-org")

        expires_at = datetime.now(UTC) + timedelta(days=30)

        keep_token_id = uuid4()
        keep_hash = f"keep_token_{uuid4().hex[:8]}"
        await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at)

        other_hashes = [f"other_token_{i}_{uuid4().hex[:8]}" for i in range(2)]
        for other_hash in other_hashes:
            await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at)

        result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id)

        assert result == 2  # only the two other tokens are revoked
        assert await token_repo.get_valid_by_hash(keep_hash) is not None
        for other_hash in other_hashes:
            assert await token_repo.get_valid_by_hash(other_hash) is None


class TestActiveTokensForUser:
    """Tests for listing active tokens for a user."""

    async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
        """get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active_list@example.com", "hash")
        await org_repo.create(org_id, "Active List Org", "active-list-org")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        expired_at = datetime.now(UTC) - timedelta(days=1)

        # Active token.
        active_hash = f"active_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at)

        # Revoked token.
        revoked_id = uuid4()
        revoked_hash = f"revoked_{uuid4().hex[:8]}"
        await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at)
        await token_repo.revoke(revoked_id)

        # Expired token.
        expired_hash = f"expired_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at)

        # Rotated token; its successor is itself an active token.
        rotated_hash = f"rotated_{uuid4().hex[:8]}"
        successor_hash = f"new_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at)
        await token_repo.rotate(rotated_hash, uuid4(), successor_hash, expires_at)

        result = await token_repo.get_active_tokens_for_user(user_id)

        # Exactly the untouched active token and the rotation successor remain.
        assert len(result) == 2
        hashes = {t["token_hash"] for t in result}
        assert active_hash in hashes
        assert successor_hash in hashes
        assert revoked_hash not in hashes
        assert expired_hash not in hashes
        assert rotated_hash not in hashes


class TestTokenForeignKeys:
    """Tests for refresh token foreign key constraints."""

    async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """refresh_tokens.user_id must reference existing user."""
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        org_id = uuid4()
        await org_repo.create(org_id, "FK Test Org", "fk-test-org")

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await token_repo.create(
                uuid4(), uuid4(), "orphan_token", org_id, datetime.now(UTC) + timedelta(days=30)
            )

    async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """refresh_tokens.active_org_id must reference existing org."""
        user_repo = UserRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        await user_repo.create(user_id, "fk_org_test@example.com", "hash")

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await token_repo.create(
                uuid4(), user_id, "orphan_org_token", uuid4(), datetime.now(UTC) + timedelta(days=30)
            )

    async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
        """Token stores active_org_id for org context per SPECS.md."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active_org@example.com", "hash")
        # Randomised slug: "active-org" is already used by the reuse-detection tests.
        await org_repo.create(org_id, "Active Org", f"active-org-{uuid4().hex[:8]}")

        token_hash = f"active_org_token_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, token_hash, org_id, datetime.now(UTC) + timedelta(days=30))

        token = await token_repo.get_by_hash(token_hash)
        assert token is not None
        assert token["active_org_id"] == org_id