2025-11-21 12:00:00 +00:00
|
|
|
"""Database connection management using asyncpg."""
|
|
|
|
|
|
|
|
|
|
from collections.abc import AsyncGenerator
|
|
|
|
|
from contextlib import asynccontextmanager
|
2025-12-29 09:55:30 +00:00
|
|
|
from contextvars import ContextVar
|
2025-11-21 12:00:00 +00:00
|
|
|
|
|
|
|
|
import asyncpg
|
2025-12-29 09:55:30 +00:00
|
|
|
from asyncpg.pool import PoolConnectionProxy
|
2025-11-21 12:00:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class Database:
    """Manages an asyncpg connection pool.

    ``pool`` is ``None`` until :meth:`connect` has been awaited and is reset
    to ``None`` again by :meth:`disconnect`, so the "not connected" checks in
    :meth:`connection` and :meth:`transaction` stay accurate across the whole
    lifecycle.
    """

    # The active pool, or None while the database is not connected.
    pool: asyncpg.Pool | None = None

    async def connect(self, dsn: str) -> None:
        """Create the connection pool.

        Args:
            dsn: Connection string handed to ``asyncpg.create_pool``.
        """
        self.pool = await asyncpg.create_pool(
            dsn,
            min_size=5,
            max_size=20,
            command_timeout=60,
        )

    async def disconnect(self) -> None:
        """Close the connection pool, if any, and mark the database as disconnected.

        Calling this when no pool is set is a no-op.
        """
        # Detach the pool before closing it so that ``connection()`` and
        # ``transaction()`` report "Database not connected" instead of trying
        # to acquire from a pool that is closing or already closed.
        pool, self.pool = self.pool, None
        if pool is not None:
            await pool.close()

    @asynccontextmanager
    async def connection(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
        """Acquire a connection from the pool.

        Yields:
            The connection acquired from the pool; it is released when the
            ``async with`` block exits.

        Raises:
            RuntimeError: If no pool is set (not connected, or disconnected).
        """
        if self.pool is None:
            raise RuntimeError("Database not connected")
        async with self.pool.acquire() as conn:
            yield conn

    @asynccontextmanager
    async def transaction(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
        """Acquire a connection with an active transaction.

        The caller's block runs inside ``conn.transaction()``, nested in the
        pool acquisition, so the transaction context exits before the
        connection is released.

        Yields:
            The acquired connection, with a transaction open on it.

        Raises:
            RuntimeError: If no pool is set (not connected, or disconnected).
        """
        if self.pool is None:
            raise RuntimeError("Database not connected")
        async with self.pool.acquire() as conn:
            async with conn.transaction():
                yield conn
|
|
|
|
|
|
|
|
|
|
|
2026-01-07 20:51:13 -05:00
|
|
|
# Module-level Database instance; get_conn() below acquires from its pool.
db = Database()

# Connection bound to the current context by get_conn(); None when no
# connection has been bound yet.
_connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar(
    "db_connection", default=None
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_conn() -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
    """Yield a DB connection, reusing the one already bound to the current context.

    If ``_connection_ctx`` already holds a connection, that connection is
    yielded as-is and nothing is acquired or released here. Otherwise a
    connection is acquired from ``db.pool``, stored in ``_connection_ctx``
    for the duration of the yield, and the context variable is restored
    before the connection goes back to the pool.

    Raises:
        RuntimeError: If no connection is bound and ``db.pool`` is not set.
    """
    bound = _connection_ctx.get()
    if bound is None:
        pool = db.pool
        if not pool:
            raise RuntimeError("Database not connected")

        async with pool.acquire() as acquired:
            # Publish the connection so nested lookups in this context reuse it.
            reset_token = _connection_ctx.set(acquired)
            try:
                yield acquired
            finally:
                # Restore the previous value even if the consumer raised.
                _connection_ctx.reset(reset_token)
    else:
        yield bound
|