feat(api): Pydantic schemas + Data Repositories
This commit is contained in:
396
app/repositories/refresh_token.py
Normal file
396
app/repositories/refresh_token.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""Refresh token repository for database operations.
|
||||
|
||||
Security considerations implemented:
|
||||
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
|
||||
- Token chain tracking via rotated_to for reuse/theft detection
|
||||
- Defense-in-depth validation with user_id and active_org_id checks
|
||||
- Uses RETURNING for robust row counting instead of string parsing
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class RefreshTokenRepository:
|
||||
"""Database operations for refresh tokens."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create(
|
||||
self,
|
||||
token_id: UUID,
|
||||
user_id: UUID,
|
||||
token_hash: str,
|
||||
active_org_id: UUID,
|
||||
expires_at: datetime,
|
||||
) -> dict:
|
||||
"""Create a new refresh token."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
""",
|
||||
token_id,
|
||||
user_id,
|
||||
token_hash,
|
||||
active_org_id,
|
||||
expires_at,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_by_hash(self, token_hash: str) -> dict | None:
|
||||
"""Get refresh token by hash (includes revoked/expired for auditing)."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
""",
|
||||
token_hash,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_valid_by_hash(
|
||||
self,
|
||||
token_hash: str,
|
||||
user_id: UUID | None = None,
|
||||
active_org_id: UUID | None = None,
|
||||
) -> dict | None:
|
||||
"""Get refresh token by hash, only if valid.
|
||||
|
||||
Validates:
|
||||
- Token exists and matches hash
|
||||
- Token is not revoked
|
||||
- Token is not expired
|
||||
- Token has not been rotated (rotated_to is NULL)
|
||||
- Optionally: user_id matches (defense-in-depth)
|
||||
- Optionally: active_org_id matches (defense-in-depth)
|
||||
|
||||
Args:
|
||||
token_hash: The hashed token value
|
||||
user_id: If provided, token must belong to this user
|
||||
active_org_id: If provided, token must be bound to this org
|
||||
|
||||
Returns:
|
||||
Token dict if valid, None otherwise
|
||||
"""
|
||||
query = """
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
AND revoked_at IS NULL
|
||||
AND rotated_to IS NULL
|
||||
AND expires_at > clock_timestamp()
|
||||
"""
|
||||
params: list = [token_hash]
|
||||
param_idx = 2
|
||||
|
||||
if user_id is not None:
|
||||
query += f" AND user_id = ${param_idx}"
|
||||
params.append(user_id)
|
||||
param_idx += 1
|
||||
|
||||
if active_org_id is not None:
|
||||
query += f" AND active_org_id = ${param_idx}"
|
||||
params.append(active_org_id)
|
||||
|
||||
row = await self.conn.fetchrow(query, *params)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_valid_for_rotation(
|
||||
self,
|
||||
token_hash: str,
|
||||
user_id: UUID | None = None,
|
||||
) -> dict | None:
|
||||
"""Get and lock a valid token for rotation using SELECT FOR UPDATE.
|
||||
|
||||
This acquires a row-level lock to prevent concurrent rotation attempts.
|
||||
Must be called within a transaction.
|
||||
|
||||
Args:
|
||||
token_hash: The hashed token value
|
||||
user_id: If provided, token must belong to this user
|
||||
|
||||
Returns:
|
||||
Token dict if valid and locked, None otherwise
|
||||
"""
|
||||
query = """
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
AND revoked_at IS NULL
|
||||
AND rotated_to IS NULL
|
||||
AND expires_at > clock_timestamp()
|
||||
"""
|
||||
params: list = [token_hash]
|
||||
|
||||
if user_id is not None:
|
||||
query += " AND user_id = $2"
|
||||
params.append(user_id)
|
||||
|
||||
query += " FOR UPDATE"
|
||||
|
||||
row = await self.conn.fetchrow(query, *params)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def check_token_reuse(self, token_hash: str) -> dict | None:
|
||||
"""Check if a token has already been rotated (potential theft).
|
||||
|
||||
If a token is presented that has rotated_to set, it means:
|
||||
1. The token was legitimately rotated earlier
|
||||
2. Someone is now trying to use the old token
|
||||
3. This indicates the token may have been stolen
|
||||
|
||||
Returns:
|
||||
Token dict if this is a reused/stolen token, None if not found or not rotated
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1 AND rotated_to IS NOT NULL
|
||||
""",
|
||||
token_hash,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def revoke_token_chain(self, token_id: UUID) -> int:
|
||||
"""Revoke a token and all tokens in its chain (for breach response).
|
||||
|
||||
When token reuse is detected, this revokes:
|
||||
1. The original stolen token
|
||||
2. Any token it was rotated to (and their rotations, recursively)
|
||||
|
||||
Args:
|
||||
token_id: The ID of the compromised token
|
||||
|
||||
Returns:
|
||||
Count of tokens revoked
|
||||
"""
|
||||
# Use recursive CTE to find all tokens in the chain
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
WITH RECURSIVE token_chain AS (
|
||||
-- Start with the given token
|
||||
SELECT id, rotated_to
|
||||
FROM refresh_tokens
|
||||
WHERE id = $1
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- Follow the chain via rotated_to
|
||||
SELECT rt.id, rt.rotated_to
|
||||
FROM refresh_tokens rt
|
||||
INNER JOIN token_chain tc ON rt.id = tc.rotated_to
|
||||
)
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE id IN (SELECT id FROM token_chain)
|
||||
AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
token_id,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
async def rotate(
|
||||
self,
|
||||
old_token_hash: str,
|
||||
new_token_id: UUID,
|
||||
new_token_hash: str,
|
||||
new_expires_at: datetime,
|
||||
new_active_org_id: UUID | None = None,
|
||||
expected_user_id: UUID | None = None,
|
||||
) -> dict | None:
|
||||
"""Atomically rotate a refresh token.
|
||||
|
||||
This method:
|
||||
1. Validates the old token (not expired, not revoked, not already rotated)
|
||||
2. Locks the row to prevent concurrent rotation
|
||||
3. Marks old token as rotated (sets rotated_to)
|
||||
4. Creates new token with updated org if specified
|
||||
5. All in a single atomic operation
|
||||
|
||||
Args:
|
||||
old_token_hash: Hash of the token being rotated
|
||||
new_token_id: UUID for the new token
|
||||
new_token_hash: Hash for the new token
|
||||
new_expires_at: Expiry time for the new token
|
||||
new_active_org_id: New org ID (for org-switch), or None to keep current
|
||||
expected_user_id: If provided, validates token belongs to this user
|
||||
|
||||
Returns:
|
||||
New token dict if rotation succeeded, None if old token invalid/expired
|
||||
"""
|
||||
# First, get and lock the old token
|
||||
old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
|
||||
if old_token is None:
|
||||
return None
|
||||
|
||||
# Determine the org for the new token
|
||||
active_org_id = new_active_org_id or old_token["active_org_id"]
|
||||
user_id = old_token["user_id"]
|
||||
|
||||
# Create the new token
|
||||
new_token = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
""",
|
||||
new_token_id,
|
||||
user_id,
|
||||
new_token_hash,
|
||||
active_org_id,
|
||||
new_expires_at,
|
||||
)
|
||||
|
||||
# Mark the old token as rotated (not revoked - for reuse detection)
|
||||
await self.conn.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET rotated_to = $2
|
||||
WHERE id = $1
|
||||
""",
|
||||
old_token["id"],
|
||||
new_token_id,
|
||||
)
|
||||
|
||||
return dict(new_token)
|
||||
|
||||
async def revoke(self, token_id: UUID) -> bool:
|
||||
"""Revoke a refresh token by ID.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if not found or already revoked
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE id = $1 AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
token_id,
|
||||
)
|
||||
return row is not None
|
||||
|
||||
async def revoke_by_hash(self, token_hash: str) -> bool:
|
||||
"""Revoke a refresh token by hash.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if not found or already revoked
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
token_hash,
|
||||
)
|
||||
return row is not None
|
||||
|
||||
async def revoke_all_for_user(self, user_id: UUID) -> int:
|
||||
"""Revoke all active refresh tokens for a user.
|
||||
|
||||
Use this for:
|
||||
- User-initiated logout from all devices
|
||||
- Password change
|
||||
- Account compromise response
|
||||
|
||||
Returns:
|
||||
Count of tokens revoked
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
|
||||
"""Revoke all tokens for a user except one (logout other sessions).
|
||||
|
||||
Args:
|
||||
user_id: The user whose tokens to revoke
|
||||
keep_token_id: The token ID to keep active (current session)
|
||||
|
||||
Returns:
|
||||
Count of tokens revoked
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
|
||||
RETURNING id
|
||||
""",
|
||||
user_id,
|
||||
keep_token_id,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
|
||||
"""Get all active (non-revoked, non-expired, non-rotated) tokens for a user.
|
||||
|
||||
Useful for:
|
||||
- Showing active sessions
|
||||
- Auditing
|
||||
|
||||
Returns:
|
||||
List of active token records
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE user_id = $1
|
||||
AND revoked_at IS NULL
|
||||
AND rotated_to IS NULL
|
||||
AND expires_at > clock_timestamp()
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def cleanup_expired(self, older_than_days: int = 30) -> int:
|
||||
"""Delete expired tokens older than specified days.
|
||||
|
||||
Note: This performs a hard delete. For audit trails, I think we should:
|
||||
- Archiving to a separate table first
|
||||
- Using partitioning with retention policies
|
||||
- Only calling this for tokens well past their expiry
|
||||
|
||||
Args:
|
||||
older_than_days: Only delete tokens expired more than this many days ago
|
||||
|
||||
Returns:
|
||||
Count of tokens deleted
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < clock_timestamp() - interval '1 day' * $1
|
||||
RETURNING id
|
||||
""",
|
||||
older_than_days,
|
||||
)
|
||||
return len(rows)
|
||||
Reference in New Issue
Block a user