"""Database connection management using asyncpg.""" from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from contextvars import ContextVar import asyncpg from asyncpg.pool import PoolConnectionProxy import redis.asyncio as redis class Database: """Manages asyncpg connection pool.""" pool: asyncpg.Pool | None = None async def connect(self, dsn: str) -> None: """Create connection pool.""" self.pool = await asyncpg.create_pool( dsn, min_size=5, max_size=20, command_timeout=60, ) async def disconnect(self) -> None: """Close connection pool.""" if self.pool: await self.pool.close() @asynccontextmanager async def connection(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]: """Acquire a connection from the pool.""" if not self.pool: raise RuntimeError("Database not connected") async with self.pool.acquire() as conn: yield conn @asynccontextmanager async def transaction(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]: """Acquire a connection with an active transaction.""" if not self.pool: raise RuntimeError("Database not connected") async with self.pool.acquire() as conn: async with conn.transaction(): yield conn class RedisClient: """Manages Redis connection.""" client: redis.Redis | None = None async def connect(self, url: str) -> None: """Create Redis connection.""" self.client = redis.from_url(url, decode_responses=True) async def disconnect(self) -> None: """Close Redis connection.""" if self.client: await self.client.aclose() async def ping(self) -> bool: """Check if Redis is reachable.""" if not self.client: return False try: await self.client.ping() return True except redis.RedisError: return False # Global instances db = Database() redis_client = RedisClient() _connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar( "db_connection", default=None, ) async def get_conn() -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]: """Dependency that reuses the same DB connection within a request context.""" existing_conn = _connection_ctx.get() if existing_conn is not None: yield existing_conn return if not db.pool: raise RuntimeError("Database not connected") async with db.pool.acquire() as conn: token = _connection_ctx.set(conn) try: yield conn finally: _connection_ctx.reset(token)