107 lines
2.7 KiB
Python
107 lines
2.7 KiB
Python
|
|
"""Security utilities for JWT and password hashing."""
|
||
|
|
|
||
|
|
import hashlib
|
||
|
|
import secrets
|
||
|
|
from datetime import UTC, datetime, timedelta
|
||
|
|
from typing import Any
|
||
|
|
from uuid import UUID, uuid4
|
||
|
|
|
||
|
|
import bcrypt
|
||
|
|
from jose import JWTError, jwt
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
|
||
|
|
|
||
|
|
def hash_password(password: str) -> str:
    """Return a salted bcrypt hash of *password* as a text string.

    A fresh salt from ``bcrypt.gensalt()`` (its default work factor) is
    generated on every call, so hashing the same password twice produces
    different strings. Check passwords with ``verify_password`` rather than by
    comparing hashes for equality.

    NOTE(review): bcrypt only looks at the first 72 bytes of the UTF-8 encoded
    password; depending on the installed bcrypt version, longer input is either
    silently truncated or rejected with ValueError. Confirm that callers
    enforce a length limit.
    """
    secret = password.encode("utf-8")
    hashed = bcrypt.hashpw(secret, bcrypt.gensalt())
    return hashed.decode("utf-8")
|
||
|
|
|
||
|
|
|
||
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash.

    Args:
        plain_password: Candidate password supplied by the user.
        hashed_password: Stored hash, as produced by ``hash_password``.

    Returns:
        True if the password matches the hash. False if it does not match,
        or if bcrypt cannot perform the comparison (for example an empty or
        non-bcrypt stored hash).
    """
    candidate = plain_password.encode()
    stored = hashed_password.encode()
    try:
        return bcrypt.checkpw(candidate, stored)
    except ValueError:
        # bcrypt raises ValueError ("Invalid salt") when the stored value is
        # not a well-formed bcrypt hash. Depending on the bcrypt version, it
        # may also raise it for passwords longer than 72 bytes. A boolean
        # verifier should fail closed here instead of surfacing an unhandled
        # error from the authentication path.
        return False
|
||
|
|
|
||
|
|
|
||
|
|
def create_access_token(
    sub: str,
    org_id: str,
    org_role: str,
    expires_delta: timedelta | None = None,
) -> str:
    """Build and sign a JWT access token that carries organisation context.

    Args:
        sub: Value for the ``sub`` claim. ``TokenPayload`` parses it as a
            UUID, so it should be a UUID string.
        org_id: Value for the ``org_id`` claim (also parsed as a UUID by
            ``TokenPayload``).
        org_role: Value for the ``org_role`` claim.
        expires_delta: Token lifetime measured from now. When ``None``, the
            lifetime is ``settings.access_token_expire_minutes`` minutes.

    Returns:
        The encoded JWT, signed with ``settings.jwt_secret_key`` using
        ``settings.jwt_algorithm``. Besides the org claims, the token carries
        ``iss``/``aud`` from settings, a random ``jti``, and ``iat``/``exp``.
    """
    if expires_delta is not None:
        lifetime = expires_delta
    else:
        lifetime = timedelta(minutes=settings.access_token_expire_minutes)

    issued_at = datetime.now(UTC)
    claims = dict(
        sub=sub,
        org_id=org_id,
        org_role=org_role,
        iss=settings.jwt_issuer,
        aud=settings.jwt_audience,
        # Random per-token identifier.
        jti=str(uuid4()),
        # Timezone-aware datetimes; python-jose's encode turns iat/exp
        # datetimes into numeric timestamps.
        iat=issued_at,
        exp=issued_at + lifetime,
    )
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||
|
|
|
||
|
|
|
||
|
|
def decode_access_token(token: str) -> dict[str, Any]:
    """Decode and validate a JWT access token.

    Verifies the signature with ``settings.jwt_secret_key`` and accepts only
    ``settings.jwt_algorithm``. Checks ``iss`` and ``aud`` against the
    configured issuer and audience, and checks ``exp``/``iat``. It also
    requires the standard claims that ``create_access_token`` always emits to
    be present.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        JWTError: If the token is malformed, has a bad signature, is expired,
            has the wrong issuer or audience, or is missing a required claim.
    """
    return jwt.decode(
        token,
        settings.jwt_secret_key,
        algorithms=[settings.jwt_algorithm],
        issuer=settings.jwt_issuer,
        audience=settings.jwt_audience,
        # python-jose does not require these claims by default. A signed
        # token with no "aud" is accepted even when audience= is given, and a
        # token with no "exp" never expires. Requiring them makes such tokens
        # fail here with JWTError, instead of being accepted (or failing later
        # with KeyError when the claims are read).
        options={
            "require_aud": True,
            "require_iss": True,
            "require_exp": True,
            "require_iat": True,
            "require_sub": True,
            "require_jti": True,
        },
    )
|
||
|
|
|
||
|
|
|
||
|
|
def generate_refresh_token() -> str:
    """Return a new opaque refresh token.

    The token is 32 random bytes from the ``secrets`` CSPRNG, encoded as
    URL-safe base64 without padding, which gives 43 characters.
    """
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
|
||
|
|
|
||
|
|
|
||
|
|
def hash_token(token: str) -> str:
    """Return the SHA-256 hex digest of *token*, for storing refresh tokens.

    The token is UTF-8 encoded before hashing. The digest is 64 lowercase hex
    characters. Hashing is unsalted and deterministic, so the same token
    always produces the same digest.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
|
||
|
|
|
||
|
|
|
||
|
|
def get_refresh_token_expiry() -> datetime:
    """Return when a refresh token issued right now should expire.

    The result is a timezone-aware UTC datetime,
    ``settings.refresh_token_expire_days`` days from the current time.
    """
    lifetime = timedelta(days=settings.refresh_token_expire_days)
    return datetime.now(UTC) + lifetime
|
||
|
|
|
||
|
|
|
||
|
|
class TokenPayload:
    """Attribute-style view of a decoded access-token claims dict.

    Built from a claims mapping such as the one returned by
    ``decode_access_token``. Every claim listed below must be present: a
    missing key raises ``KeyError``, and a ``sub``/``org_id``/``jti`` value
    that is not a valid UUID string raises (typically ``ValueError``).
    Claims are read in the order listed, so the first bad one is the one
    reported.
    """

    user_id: UUID  # parsed from the "sub" claim
    org_id: UUID  # parsed from the "org_id" claim
    org_role: Any  # "org_role" claim, stored as-is
    issuer: Any  # "iss" claim, stored as-is
    audience: Any  # "aud" claim, stored as-is (JWT permits a string or a list)
    jti: UUID  # parsed from the "jti" claim
    issued_at: Any  # "iat" claim, stored as-is; presumably a Unix timestamp (confirm)
    expires_at: Any  # "exp" claim, stored as-is; presumably a Unix timestamp (confirm)

    def __init__(self, payload: dict[str, Any]) -> None:
        # The right-hand side is evaluated left to right before any attribute
        # is bound, in the same claim order as the attributes above.
        (
            self.user_id,
            self.org_id,
            self.org_role,
            self.issuer,
            self.audience,
            self.jti,
            self.issued_at,
            self.expires_at,
        ) = (
            UUID(payload["sub"]),
            UUID(payload["org_id"]),
            payload["org_role"],
            payload["iss"],
            payload["aud"],
            UUID(payload["jti"]),
            payload["iat"],
            payload["exp"],
        )
|
||
|
|
|
||
|
|
|
||
|
|
# Public API of this module. JWTError is re-exported from python-jose so that
# callers of decode_access_token can catch token failures by importing from
# this module.
__all__ = [
    "JWTError",
    "TokenPayload",
    "create_access_token",
    "decode_access_token",
    "generate_refresh_token",
    "get_refresh_token_expiry",
    "hash_password",
    "hash_token",
    "verify_password",
]
|