"""Tests for the get_conn dependency helper.""" from __future__ import annotations import pytest from app.db import db, get_conn pytestmark = pytest.mark.asyncio class _FakeConnection: def __init__(self, idx: int) -> None: self.idx = idx class _AcquireContext: def __init__(self, conn: _FakeConnection, tracker: "_FakePool") -> None: self._conn = conn self._tracker = tracker async def __aenter__(self) -> _FakeConnection: self._tracker.active += 1 return self._conn async def __aexit__(self, exc_type, exc, tb) -> None: self._tracker.active -= 1 class _FakePool: def __init__(self) -> None: self.acquire_calls = 0 self.active = 0 def acquire(self) -> _AcquireContext: conn = _FakeConnection(self.acquire_calls) self.acquire_calls += 1 return _AcquireContext(conn, self) async def _collect_single_connection(): connection = None async for conn in get_conn(): connection = conn return connection async def test_get_conn_reuses_connection_within_scope(): original_pool = db.pool fake_pool = _FakePool() db.pool = fake_pool try: captured: list[_FakeConnection] = [] async for outer in get_conn(): captured.append(outer) async for inner in get_conn(): captured.append(inner) assert len(captured) == 2 assert captured[0] is captured[1] assert fake_pool.acquire_calls == 1 finally: db.pool = original_pool async def test_get_conn_acquires_new_connection_per_root_scope(): original_pool = db.pool fake_pool = _FakePool() db.pool = fake_pool try: first = await _collect_single_connection() second = await _collect_single_connection() assert first is not None and second is not None assert first is not second assert fake_pool.acquire_calls == 2 finally: db.pool = original_pool