270 lines
9.5 KiB
Python
270 lines
9.5 KiB
Python
|
|
"""Authentication service providing business logic for auth flows."""
|
||
|
|
|
||
|
|
from __future__ import annotations

import re
import unicodedata
from typing import NoReturn, cast
from uuid import UUID, uuid4

import asyncpg
from asyncpg.pool import PoolConnectionProxy

from app.api.deps import CurrentUser
from app.config import settings
from app.core import exceptions as exc, security
from app.db import Database, db
from app.repositories import OrgRepository, RefreshTokenRepository, UserRepository
from app.schemas.auth import (
    LoginRequest,
    LogoutRequest,
    RefreshRequest,
    RegisterRequest,
    SwitchOrgRequest,
    TokenResponse,
)
|
||
|
|
|
||
|
|
|
||
|
|
# Matches every run of characters that are not lowercase ASCII letters or
# digits; slug generation replaces each such run with a single "-".
_SLUG_PATTERN: re.Pattern[str] = re.compile(r"[^a-z0-9]+")
|
||
|
|
|
||
|
|
|
||
|
|
def _as_conn(conn: asyncpg.Connection | PoolConnectionProxy) -> asyncpg.Connection:
    """Return ``conn`` unchanged, re-typed as a plain ``asyncpg.Connection``.

    ``typing.cast`` does nothing at runtime; it only lets the type checker
    accept a pool proxy wherever an ``asyncpg.Connection`` is annotated.
    """
    return cast("asyncpg.Connection", conn)
|
||
|
|
|
||
|
|
|
||
|
|
class AuthService:
    """Encapsulates authentication workflows (register/login/refresh/switch-org/logout).

    Refresh tokens are persisted only as their ``security.hash_token`` hash; the
    plaintext token is returned to the client and never stored by this service.
    """

    # Maximum length of a generated organization slug, numeric suffix included.
    _MAX_SLUG_LENGTH = 50

    def __init__(self, database: Database | None = None) -> None:
        """Bind the service to ``database``, falling back to the shared ``db``."""
        self.db = database or db
        # Reported to clients as ``expires_in`` (seconds).
        self._access_token_expires_in = settings.access_token_expire_minutes * 60

    async def register_user(self, data: RegisterRequest) -> TokenResponse:
        """Create a user, their default org and admin membership, then issue tokens.

        Everything, including the refresh-token row, is written in one transaction.

        Raises:
            exc.ConflictError: ``data.email`` is already registered.
        """
        # Hashing needs no database access, so do it before taking a connection
        # to keep the transaction as short as possible.
        password_hash = security.hash_password(data.password)

        async with self.db.transaction() as conn:
            db_conn = _as_conn(conn)
            user_repo = UserRepository(db_conn)
            org_repo = OrgRepository(db_conn)

            # NOTE(review): check-then-insert is not atomic; a concurrent signup
            # with the same email would presumably hit a DB unique constraint
            # (if one exists) instead of this ConflictError — confirm.
            if await user_repo.exists_by_email(data.email):
                raise exc.ConflictError("Email already registered")

            user_id = uuid4()
            org_id = uuid4()
            await user_repo.create(user_id, data.email, password_hash)
            slug = await self._generate_unique_org_slug(org_repo, data.org_name)
            await org_repo.create(org_id, data.org_name, slug)
            await org_repo.add_member(uuid4(), user_id, org_id, "admin")

            return await self._issue_token_pair(
                RefreshTokenRepository(db_conn),
                user_id=user_id,
                org_id=org_id,
                role="admin",
            )

    async def login_user(self, data: LoginRequest) -> TokenResponse:
        """Authenticate by email/password and issue tokens for the user's first org.

        Raises:
            exc.UnauthorizedError: unknown email or wrong password (one message for both).
            exc.ForbiddenError: the user belongs to no organization.
        """
        async with self.db.connection() as conn:
            db_conn = _as_conn(conn)
            user_repo = UserRepository(db_conn)
            org_repo = OrgRepository(db_conn)

            user = await user_repo.get_by_email(data.email)
            if not user or not security.verify_password(data.password, user["password_hash"]):
                raise exc.UnauthorizedError("Invalid email or password")

            orgs = await org_repo.get_user_orgs(user["id"])
            if not orgs:
                raise exc.ForbiddenError("User does not belong to any organization")

            # NOTE(review): which org is "first" depends on get_user_orgs' ordering.
            active_org = orgs[0]
            return await self._issue_token_pair(
                RefreshTokenRepository(db_conn),
                user_id=user["id"],
                org_id=active_org["id"],
                role=active_org["role"],
            )

    async def refresh_tokens(self, data: RefreshRequest) -> TokenResponse:
        """Rotate the presented refresh token and mint a new access token.

        The new token keeps the old token's active org; the role is re-read from the
        current membership, so role changes take effect on refresh.

        Raises:
            exc.UnauthorizedError: the token could not be rotated, reuse was detected,
                or the user is no longer a member of the token's active org.
        """
        old_hash = security.hash_token(data.refresh_token)
        new_refresh_token = security.generate_refresh_token()
        new_refresh_hash = security.hash_token(new_refresh_token)
        new_refresh_expiry = security.get_refresh_token_expiry()

        rotated: dict | None = None
        membership: dict | None = None

        async with self.db.transaction() as conn:
            db_conn = _as_conn(conn)
            rotated = await RefreshTokenRepository(db_conn).rotate(
                old_token_hash=old_hash,
                new_token_id=uuid4(),
                new_token_hash=new_refresh_hash,
                new_expires_at=new_refresh_expiry,
            )
            if rotated is not None:
                membership = await OrgRepository(db_conn).get_member(
                    rotated["user_id"], rotated["active_org_id"]
                )
                if membership is None:
                    # Raised inside the transaction block, so the rotation is
                    # presumably rolled back rather than committed.
                    raise exc.UnauthorizedError("Invalid refresh token")

        if rotated is None or membership is None:
            # Always raises; may first revoke a reused token's chain.
            await self._handle_invalid_refresh(old_hash)

        return self._build_token_response(
            user_id=rotated["user_id"],
            org_id=rotated["active_org_id"],
            role=membership["role"],
            refresh_token=new_refresh_token,
        )

    async def switch_org(
        self,
        current_user: CurrentUser,
        data: SwitchOrgRequest,
    ) -> TokenResponse:
        """Make ``data.org_id`` the active org: rotate the refresh token, issue a new JWT.

        Raises:
            exc.ForbiddenError: the current user is not a member of the target org.
            exc.UnauthorizedError: the refresh token could not be rotated for this
                user (reuse is detected and handled by ``_handle_invalid_refresh``).
        """
        target_org_id = data.org_id
        old_hash = security.hash_token(data.refresh_token)
        new_refresh_token = security.generate_refresh_token()
        new_refresh_hash = security.hash_token(new_refresh_token)
        new_refresh_expiry = security.get_refresh_token_expiry()

        rotated: dict | None = None
        membership: dict | None = None

        async with self.db.transaction() as conn:
            db_conn = _as_conn(conn)
            membership = await OrgRepository(db_conn).get_member(
                current_user.user_id, target_org_id
            )
            if membership is None:
                raise exc.ForbiddenError("Not a member of the requested organization")

            rotated = await RefreshTokenRepository(db_conn).rotate(
                old_token_hash=old_hash,
                new_token_id=uuid4(),
                new_token_hash=new_refresh_hash,
                new_expires_at=new_refresh_expiry,
                new_active_org_id=target_org_id,
                # Presumably makes rotate() return None for another user's token.
                expected_user_id=current_user.user_id,
            )

        if rotated is None:
            await self._handle_invalid_refresh(old_hash)

        return self._build_token_response(
            user_id=current_user.user_id,
            org_id=target_org_id,
            role=membership["role"],
            refresh_token=new_refresh_token,
        )

    async def logout(self, current_user: CurrentUser, data: LogoutRequest) -> None:
        """Revoke the presented refresh token; unknown tokens are silently ignored.

        Raises:
            exc.ForbiddenError: the token exists but belongs to a different user.
        """
        token_hash = security.hash_token(data.refresh_token)

        async with self.db.transaction() as conn:
            refresh_repo = RefreshTokenRepository(_as_conn(conn))
            token = await refresh_repo.get_by_hash(token_hash)
            if not token:
                return
            if token["user_id"] != current_user.user_id:
                raise exc.ForbiddenError("Refresh token does not belong to this user")
            await refresh_repo.revoke(token["id"])

    async def _issue_token_pair(
        self,
        refresh_repo: RefreshTokenRepository,
        *,
        user_id: UUID,
        org_id: UUID,
        role: str,
    ) -> TokenResponse:
        """Persist a new refresh token for (user, org) and pair it with an access token.

        Only the refresh token's hash is stored; the plaintext goes to the client.
        """
        refresh_token = security.generate_refresh_token()
        await refresh_repo.create(
            token_id=uuid4(),
            user_id=user_id,
            token_hash=security.hash_token(refresh_token),
            active_org_id=org_id,
            expires_at=security.get_refresh_token_expiry(),
        )
        return self._build_token_response(
            user_id=user_id,
            org_id=org_id,
            role=role,
            refresh_token=refresh_token,
        )

    def _build_token_response(
        self,
        *,
        user_id: UUID,
        org_id: UUID,
        role: str,
        refresh_token: str,
    ) -> TokenResponse:
        """Mint an access token for (user, org, role) and wrap it with ``refresh_token``."""
        access_token = security.create_access_token(
            sub=str(user_id),
            org_id=str(org_id),
            org_role=role,
        )
        return TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=self._access_token_expires_in,
        )

    async def _handle_invalid_refresh(self, token_hash: str) -> NoReturn:
        """Raise for a refresh token that could not be rotated; never returns.

        If ``check_token_reuse`` reports the hash as reused, the token's chain is
        revoked before raising.

        Raises:
            exc.UnauthorizedError: always.
        """
        async with self.db.connection() as conn:
            refresh_repo = RefreshTokenRepository(_as_conn(conn))
            reused = await refresh_repo.check_token_reuse(token_hash)
            if reused:
                await refresh_repo.revoke_token_chain(reused["id"])

        # Raise only after leaving the connection block, so the revocation above
        # could not be undone even if this block were switched to a transaction.
        if reused:
            raise exc.UnauthorizedError("Refresh token reuse detected")
        raise exc.UnauthorizedError("Invalid refresh token")

    async def _generate_unique_org_slug(
        self,
        org_repo: OrgRepository,
        org_name: str,
    ) -> str:
        """Return a slug for ``org_name`` that ``org_repo.slug_exists`` reports as free.

        Tries the bare slug, then ``<slug>-1``, ``<slug>-2``, ...; the base is truncated
        so every candidate fits within ``_MAX_SLUG_LENGTH`` characters.
        """
        base_slug = self._slugify(org_name)
        candidate = base_slug
        counter = 1
        while await org_repo.slug_exists(candidate):
            suffix = f"-{counter}"
            # Drop a hyphen left dangling by truncation so we never emit "--N".
            stem = base_slug[: self._MAX_SLUG_LENGTH - len(suffix)].rstrip("-")
            candidate = f"{stem}{suffix}"
            counter += 1
        # NOTE(review): check-then-insert is not atomic; a concurrent insert of the
        # same slug can still race between slug_exists() and the caller's create().
        return candidate

    def _slugify(self, value: str) -> str:
        """Convert arbitrary text into a lowercase ASCII, hyphen-separated slug.

        Accented letters are transliterated (e.g. "Ölwerk" -> "olwerk") instead of
        being dropped. The result is at most ``_MAX_SLUG_LENGTH`` characters, has no
        leading/trailing hyphen, and falls back to "org" when nothing usable remains.
        """
        # casefold() handles cases like "ß" -> "ss"; NFKD splits accented letters
        # into base letter + combining mark, and the mark is dropped by the ASCII encode.
        decomposed = unicodedata.normalize("NFKD", value.casefold())
        ascii_text = decomposed.encode("ascii", "ignore").decode("ascii")
        slug = _SLUG_PATTERN.sub("-", ascii_text).strip("-")
        # Re-strip after truncation so a cut mid-separator leaves no trailing "-".
        slug = slug[: self._MAX_SLUG_LENGTH].rstrip("-")
        return slug or "org"
|