"""Refresh token repository for database operations.

Security considerations implemented:
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
- Token chain tracking via rotated_to for reuse/theft detection
- Defense-in-depth validation with user_id and active_org_id checks
- Uses RETURNING for robust row counting instead of string parsing
"""

from datetime import datetime
from uuid import UUID

import asyncpg

# Column list shared by every query that returns a full token record, so all
# methods return dicts with the same keys. This is a fixed module constant,
# never user input, so interpolating it into SQL text is safe.
_TOKEN_COLUMNS = (
    "id, user_id, token_hash, active_org_id, expires_at, "
    "revoked_at, rotated_to, created_at"
)


class RefreshTokenRepository:
    """Database operations for refresh tokens."""

    def __init__(self, conn: asyncpg.Connection) -> None:
        self.conn = conn

    async def create(
        self,
        token_id: UUID,
        user_id: UUID,
        token_hash: str,
        active_org_id: UUID,
        expires_at: datetime,
    ) -> dict:
        """Create a new refresh token.

        Returns:
            The inserted token record.
        """
        row = await self.conn.fetchrow(
            f"""
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING {_TOKEN_COLUMNS}
            """,
            token_id,
            user_id,
            token_hash,
            active_org_id,
            expires_at,
        )
        return dict(row)

    async def get_by_hash(self, token_hash: str) -> dict | None:
        """Get refresh token by hash (includes revoked/expired for auditing)."""
        row = await self.conn.fetchrow(
            f"""
            SELECT {_TOKEN_COLUMNS}
            FROM refresh_tokens
            WHERE token_hash = $1
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def get_valid_by_hash(
        self,
        token_hash: str,
        user_id: UUID | None = None,
        active_org_id: UUID | None = None,
    ) -> dict | None:
        """Get refresh token by hash, only if valid.

        Validates:
        - Token exists and matches hash
        - Token is not revoked
        - Token is not expired
        - Token has not been rotated (rotated_to is NULL)
        - Optionally: user_id matches (defense-in-depth)
        - Optionally: active_org_id matches (defense-in-depth)

        Args:
            token_hash: The hashed token value
            user_id: If provided, token must belong to this user
            active_org_id: If provided, token must be bound to this org

        Returns:
            Token dict if valid, None otherwise
        """
        query = f"""
            SELECT {_TOKEN_COLUMNS}
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list[object] = [token_hash]
        # Placeholder numbers are derived from the param count, so the SQL
        # text and the bound values cannot drift apart. Only the placeholder
        # index is interpolated; the values themselves stay parameterized.
        if user_id is not None:
            params.append(user_id)
            query += f" AND user_id = ${len(params)}"
        if active_org_id is not None:
            params.append(active_org_id)
            query += f" AND active_org_id = ${len(params)}"

        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def get_valid_for_rotation(
        self,
        token_hash: str,
        user_id: UUID | None = None,
    ) -> dict | None:
        """Get and lock a valid token for rotation using SELECT FOR UPDATE.

        This acquires a row-level lock to prevent concurrent rotation attempts.
        Must be called within a transaction: outside one, the statement
        autocommits and the lock is released immediately.

        Args:
            token_hash: The hashed token value
            user_id: If provided, token must belong to this user

        Returns:
            Token dict if valid and locked, None otherwise
        """
        query = f"""
            SELECT {_TOKEN_COLUMNS}
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list[object] = [token_hash]
        if user_id is not None:
            params.append(user_id)
            query += f" AND user_id = ${len(params)}"
        query += " FOR UPDATE"

        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def check_token_reuse(self, token_hash: str) -> dict | None:
        """Check if a token has already been rotated (potential theft).

        If a token is presented that has rotated_to set, it means:
        1. The token was legitimately rotated earlier
        2. Someone is now trying to use the old token
        3. This indicates the token may have been stolen

        Returns:
            Token dict if this is a reused/stolen token, None if not found or not rotated
        """
        row = await self.conn.fetchrow(
            f"""
            SELECT {_TOKEN_COLUMNS}
            FROM refresh_tokens
            WHERE token_hash = $1 AND rotated_to IS NOT NULL
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def revoke_token_chain(self, token_id: UUID) -> int:
        """Revoke a token and all tokens in its chain (for breach response).

        When token reuse is detected, this revokes:
        1. The original stolen token
        2. Any token it was rotated to (and their rotations, recursively)

        Args:
            token_id: The ID of the compromised token

        Returns:
            Count of tokens revoked (tokens already revoked are not counted)
        """
        # Recursive CTE follows rotated_to forward from the given token.
        # UNION (not UNION ALL) discards rows already produced, so a corrupted
        # cyclic chain terminates instead of recursing forever; for a normal
        # acyclic chain the result is identical.
        rows = await self.conn.fetch(
            """
            WITH RECURSIVE token_chain AS (
                -- Start with the given token
                SELECT id, rotated_to FROM refresh_tokens WHERE id = $1
                UNION
                -- Follow the chain via rotated_to
                SELECT rt.id, rt.rotated_to
                FROM refresh_tokens rt
                INNER JOIN token_chain tc ON rt.id = tc.rotated_to
            )
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id IN (SELECT id FROM token_chain)
              AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return len(rows)

    async def rotate(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None = None,
        expected_user_id: UUID | None = None,
    ) -> dict | None:
        """Atomically rotate a refresh token.

        This method:
        1. Validates the old token (not expired, not revoked, not already rotated)
        2. Locks the row to prevent concurrent rotation
        3. Creates new token with updated org if specified
        4. Marks old token as rotated (sets rotated_to)
        5. All in a single atomic operation

        If the connection is already inside a transaction, the work joins it
        and the caller controls commit/rollback. Otherwise a transaction is
        opened for the duration of the rotation.

        Args:
            old_token_hash: Hash of the token being rotated
            new_token_id: UUID for the new token
            new_token_hash: Hash for the new token
            new_expires_at: Expiry time for the new token
            new_active_org_id: New org ID (for org-switch), or None to keep current
            expected_user_id: If provided, validates token belongs to this user

        Returns:
            New token dict if rotation succeeded, None if old token invalid/expired
        """
        if self.conn.is_in_transaction():
            # Caller owns the transaction; the FOR UPDATE lock lives until it ends.
            return await self._rotate_locked(
                old_token_hash,
                new_token_id,
                new_token_hash,
                new_expires_at,
                new_active_org_id,
                expected_user_id,
            )
        # Without a transaction, the SELECT ... FOR UPDATE would autocommit and
        # release its lock at once. Two concurrent callers could then both
        # rotate the same token and mint two valid successors.
        async with self.conn.transaction():
            return await self._rotate_locked(
                old_token_hash,
                new_token_id,
                new_token_hash,
                new_expires_at,
                new_active_org_id,
                expected_user_id,
            )

    async def _rotate_locked(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None,
        expected_user_id: UUID | None,
    ) -> dict | None:
        """Perform the rotation steps; must run inside an open transaction."""
        # First, get and lock the old token
        old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
        if old_token is None:
            return None

        # Determine the org for the new token
        active_org_id = (
            new_active_org_id
            if new_active_org_id is not None
            else old_token["active_org_id"]
        )

        # Create the new token (inserted before the old row points at it)
        new_token = await self.conn.fetchrow(
            f"""
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING {_TOKEN_COLUMNS}
            """,
            new_token_id,
            old_token["user_id"],
            new_token_hash,
            active_org_id,
            new_expires_at,
        )

        # Mark the old token as rotated (not revoked - for reuse detection)
        await self.conn.execute(
            """
            UPDATE refresh_tokens
            SET rotated_to = $2
            WHERE id = $1
            """,
            old_token["id"],
            new_token_id,
        )

        return dict(new_token)

    async def revoke(self, token_id: UUID) -> bool:
        """Revoke a refresh token by ID.

        Returns:
            True if token was revoked, False if not found or already revoked
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return row is not None

    async def revoke_by_hash(self, token_hash: str) -> bool:
        """Revoke a refresh token by hash.

        Returns:
            True if token was revoked, False if not found or already revoked
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE token_hash = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_hash,
        )
        return row is not None

    async def revoke_all_for_user(self, user_id: UUID) -> int:
        """Revoke all not-yet-revoked refresh tokens for a user.

        Use this for:
        - User-initiated logout from all devices
        - Password change
        - Account compromise response

        Returns:
            Count of tokens revoked
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            user_id,
        )
        return len(rows)

    async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
        """Revoke all tokens for a user except one (logout other sessions).

        Args:
            user_id: The user whose tokens to revoke
            keep_token_id: The token ID to keep active (current session)

        Returns:
            Count of tokens revoked
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
            RETURNING id
            """,
            user_id,
            keep_token_id,
        )
        return len(rows)

    async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
        """Get all active (non-revoked, non-expired, non-rotated) tokens for a user.

        Useful for:
        - Showing active sessions
        - Auditing

        Returns:
            List of active token records, newest first
        """
        rows = await self.conn.fetch(
            f"""
            SELECT {_TOKEN_COLUMNS}
            FROM refresh_tokens
            WHERE user_id = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
            ORDER BY created_at DESC
            """,
            user_id,
        )
        return [dict(row) for row in rows]

    async def cleanup_expired(self, older_than_days: int = 30) -> int:
        """Delete expired tokens older than specified days.

        Note: This performs a hard delete. For audit trails, consider:
        - Archiving to a separate table first
        - Using partitioning with retention policies
        - Only calling this for tokens well past their expiry

        Args:
            older_than_days: Only delete tokens expired more than this many days ago

        Returns:
            Count of tokens deleted

        Raises:
            ValueError: If older_than_days is negative.
        """
        # A negative value would move the cutoff into the future and delete
        # tokens that are still valid.
        if older_than_days < 0:
            raise ValueError(
                f"older_than_days must be non-negative, got {older_than_days}"
            )
        rows = await self.conn.fetch(
            """
            DELETE FROM refresh_tokens
            WHERE expires_at < clock_timestamp() - interval '1 day' * $1
            RETURNING id
            """,
            older_than_days,
        )
        return len(rows)