# incidentops/app/repositories/refresh_token.py

"""Refresh token repository for database operations.
Security considerations implemented:
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
- Token chain tracking via rotated_to for reuse/theft detection
- Defense-in-depth validation with user_id and active_org_id checks
- Uses RETURNING for robust row counting instead of string parsing
"""
from datetime import datetime
from uuid import UUID
import asyncpg
class RefreshTokenRepository:
    """Database operations for refresh tokens.

    Every query runs on the ``asyncpg.Connection`` given to the constructor.
    Single-row lookups return plain ``dict`` objects (or ``None``); bulk
    mutations return the number of affected rows, counted from ``RETURNING id``.
    """

    def __init__(self, conn: asyncpg.Connection) -> None:
        """Bind the repository to an open asyncpg connection."""
        self.conn = conn

    async def create(
        self,
        token_id: UUID,
        user_id: UUID,
        token_hash: str,
        active_org_id: UUID,
        expires_at: datetime,
    ) -> dict:
        """Insert a new refresh token and return the stored row.

        Args:
            token_id: Primary key for the new token.
            user_id: Owner of the token.
            token_hash: Hash of the token value; only the hash is persisted.
            active_org_id: Organisation the token is bound to.
            expires_at: Absolute expiry time of the token.

        Returns:
            The inserted row as a dict.
        """
        row = await self.conn.fetchrow(
            """
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, user_id, token_hash, active_org_id, expires_at,
                      revoked_at, rotated_to, created_at
            """,
            token_id,
            user_id,
            token_hash,
            active_org_id,
            expires_at,
        )
        return dict(row)

    async def get_by_hash(self, token_hash: str) -> dict | None:
        """Get a refresh token by hash, regardless of its state.

        Revoked, rotated and expired tokens are returned too, which makes this
        suitable for auditing but NOT for authentication decisions; use
        :meth:`get_valid_by_hash` for those.

        Returns:
            Token dict, or None if no token has this hash.
        """
        row = await self.conn.fetchrow(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def get_valid_by_hash(
        self,
        token_hash: str,
        user_id: UUID | None = None,
        active_org_id: UUID | None = None,
    ) -> dict | None:
        """Get refresh token by hash, only if valid.

        Validates:
        - Token exists and matches hash
        - Token is not revoked
        - Token is not expired (compared against the database clock)
        - Token has not been rotated (rotated_to is NULL)
        - Optionally: user_id matches (defense-in-depth)
        - Optionally: active_org_id matches (defense-in-depth)

        Args:
            token_hash: The hashed token value
            user_id: If provided, token must belong to this user
            active_org_id: If provided, token must be bound to this org

        Returns:
            Token dict if valid, None otherwise
        """
        query = """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list = [token_hash]
        # Optional filters are appended with positional placeholders only;
        # values always travel as bind parameters, never via string formatting.
        param_idx = 2
        if user_id is not None:
            query += f" AND user_id = ${param_idx}"
            params.append(user_id)
            param_idx += 1
        if active_org_id is not None:
            query += f" AND active_org_id = ${param_idx}"
            params.append(active_org_id)
        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def get_valid_for_rotation(
        self,
        token_hash: str,
        user_id: UUID | None = None,
    ) -> dict | None:
        """Get and lock a valid token for rotation using SELECT FOR UPDATE.

        This acquires a row-level lock to prevent concurrent rotation attempts.
        Must be called within a transaction: row locks last only until the
        enclosing transaction ends, so in autocommit mode the lock would be
        released as soon as this SELECT completes.

        Args:
            token_hash: The hashed token value
            user_id: If provided, token must belong to this user

        Returns:
            Token dict if valid and locked, None otherwise
        """
        query = """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list = [token_hash]
        if user_id is not None:
            query += " AND user_id = $2"
            params.append(user_id)
        query += " FOR UPDATE"
        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def check_token_reuse(self, token_hash: str) -> dict | None:
        """Check if a token has already been rotated (potential theft).

        If a token is presented that has rotated_to set, it means:
        1. The token was legitimately rotated earlier
        2. Someone is now trying to use the old token
        3. This indicates the token may have been stolen

        Returns:
            Token dict if this is a reused/stolen token, None if not found or not rotated
        """
        row = await self.conn.fetchrow(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1 AND rotated_to IS NOT NULL
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def revoke_token_chain(self, token_id: UUID) -> int:
        """Revoke a token and all tokens in its chain (for breach response).

        When token reuse is detected, this revokes:
        1. The given (compromised) token
        2. Every token reachable from it by following rotated_to, recursively

        Only successors are followed; tokens that were rotated INTO token_id
        are not touched. Tokens that are already revoked are skipped and not
        counted.

        Args:
            token_id: The ID of the compromised token

        Returns:
            Count of tokens revoked by this call
        """
        # Recursive CTE walks the rotated_to chain. UNION (not UNION ALL)
        # discards rows already produced, so even a corrupted, cyclic chain
        # terminates instead of recursing forever.
        rows = await self.conn.fetch(
            """
            WITH RECURSIVE token_chain AS (
                -- Start with the given token
                SELECT id, rotated_to
                FROM refresh_tokens
                WHERE id = $1
                UNION
                -- Follow the chain via rotated_to
                SELECT rt.id, rt.rotated_to
                FROM refresh_tokens rt
                INNER JOIN token_chain tc ON rt.id = tc.rotated_to
            )
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id IN (SELECT id FROM token_chain)
              AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return len(rows)

    async def rotate(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None = None,
        expected_user_id: UUID | None = None,
    ) -> dict | None:
        """Atomically rotate a refresh token.

        This method:
        1. Validates the old token (not expired, not revoked, not already rotated)
        2. Locks the row to prevent concurrent rotation
        3. Creates the new token, with the new org if one was specified
        4. Marks the old token as rotated (sets rotated_to)

        Steps 2-4 must share one transaction; otherwise the row lock from
        step 2 is released as soon as the SELECT autocommits, and two
        concurrent callers could both rotate the same token. If the
        connection is already inside a transaction, the work joins it and the
        caller controls commit/rollback. Otherwise a transaction is opened
        here and committed when the rotation finishes.

        Args:
            old_token_hash: Hash of the token being rotated
            new_token_id: UUID for the new token
            new_token_hash: Hash for the new token
            new_expires_at: Expiry time for the new token
            new_active_org_id: New org ID (for org-switch), or None to keep current
            expected_user_id: If provided, validates token belongs to this user

        Returns:
            New token dict if rotation succeeded, None if old token invalid/expired
        """
        args = (
            old_token_hash,
            new_token_id,
            new_token_hash,
            new_expires_at,
            new_active_org_id,
            expected_user_id,
        )
        if self.conn.is_in_transaction():
            # Join the caller's transaction; the lock lives until it ends.
            return await self._rotate_locked(*args)
        async with self.conn.transaction():
            return await self._rotate_locked(*args)

    async def _rotate_locked(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None,
        expected_user_id: UUID | None,
    ) -> dict | None:
        """Run the rotation steps; the caller guarantees an open transaction."""
        old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
        if old_token is None:
            return None

        # Keep the old token's org unless an org switch was requested.
        active_org_id = (
            new_active_org_id
            if new_active_org_id is not None
            else old_token["active_org_id"]
        )

        # The new row is inserted before the old one points at it; presumably
        # rotated_to references refresh_tokens(id) — verify against the schema.
        new_token = await self.conn.fetchrow(
            """
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, user_id, token_hash, active_org_id, expires_at,
                      revoked_at, rotated_to, created_at
            """,
            new_token_id,
            old_token["user_id"],
            new_token_hash,
            active_org_id,
            new_expires_at,
        )

        # Mark the old token as rotated, not revoked, so that a later replay
        # of it is detectable via check_token_reuse().
        await self.conn.execute(
            """
            UPDATE refresh_tokens
            SET rotated_to = $2
            WHERE id = $1
            """,
            old_token["id"],
            new_token_id,
        )
        return dict(new_token)

    async def revoke(self, token_id: UUID) -> bool:
        """Revoke a refresh token by ID.

        Returns:
            True if token was revoked, False if not found or already revoked
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return row is not None

    async def revoke_by_hash(self, token_hash: str) -> bool:
        """Revoke a refresh token by hash.

        Returns:
            True if token was revoked, False if not found or already revoked
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE token_hash = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_hash,
        )
        return row is not None

    async def revoke_all_for_user(self, user_id: UUID) -> int:
        """Revoke all not-yet-revoked refresh tokens for a user.

        Use this for:
        - User-initiated logout from all devices
        - Password change
        - Account compromise response

        Returns:
            Count of tokens revoked
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            user_id,
        )
        return len(rows)

    async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
        """Revoke all tokens for a user except one (logout other sessions).

        Args:
            user_id: The user whose tokens to revoke
            keep_token_id: The token ID to keep active (current session)

        Returns:
            Count of tokens revoked
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
            RETURNING id
            """,
            user_id,
            keep_token_id,
        )
        return len(rows)

    async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
        """Get all active (non-revoked, non-expired, non-rotated) tokens for a user.

        Useful for:
        - Showing active sessions
        - Auditing

        Returns:
            List of active token records, newest first
        """
        rows = await self.conn.fetch(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE user_id = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
            ORDER BY created_at DESC
            """,
            user_id,
        )
        return [dict(row) for row in rows]

    async def cleanup_expired(self, older_than_days: int = 30) -> int:
        """Delete tokens that expired more than ``older_than_days`` days ago.

        Note: This performs a hard delete. To preserve audit trails, consider:
        - Archiving to a separate table first
        - Using partitioning with retention policies
        - Only calling this for tokens well past their expiry

        Args:
            older_than_days: Only delete tokens expired more than this many days
                ago. Must be non-negative.

        Returns:
            Count of tokens deleted

        Raises:
            ValueError: If older_than_days is negative. A negative value would
                move the cutoff into the future and delete still-valid tokens.
        """
        if older_than_days < 0:
            raise ValueError(
                f"older_than_days must be non-negative, got {older_than_days}"
            )
        rows = await self.conn.fetch(
            """
            DELETE FROM refresh_tokens
            WHERE expires_at < clock_timestamp() - interval '1 day' * $1
            RETURNING id
            """,
            older_than_days,
        )
        return len(rows)