81 lines
2.0 KiB
Python
81 lines
2.0 KiB
Python
|
|
"""Tests for the get_conn dependency helper."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.db import db, get_conn
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level marker: pytest applies it to every test in this file, so the
# ``async def`` tests below are run as coroutines (marker provided by the
# pytest-asyncio plugin).
pytestmark = pytest.mark.asyncio
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeConnection:
    """Stand-in for a pooled database connection, tagged with an index."""

    def __init__(self, idx: int) -> None:
        # _FakePool.acquire passes its running acquire count here, so the
        # index tells apart connections handed out by successive acquires.
        self.idx = idx

    def __repr__(self) -> str:
        # Identity assertions (``a is b``) print both operands on failure;
        # showing ``idx`` makes it obvious which acquisitions were compared.
        return f"{type(self).__name__}(idx={self.idx!r})"
|
||
|
|
|
||
|
|
|
||
|
|
class _AcquireContext:
    """Async context manager returned by ``_FakePool.acquire()``.

    Entering bumps the owning pool's ``active`` counter and yields the wrapped
    connection. Exiting decrements the counter again and never suppresses an
    exception raised inside the ``async with`` body.
    """

    def __init__(self, conn: _FakeConnection, tracker: _FakePool) -> None:
        self._conn = conn
        self._tracker = tracker

    async def __aenter__(self) -> _FakeConnection:
        pool = self._tracker
        pool.active = pool.active + 1
        return self._conn

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Implicitly returning None (falsy) lets exceptions propagate.
        pool = self._tracker
        pool.active = pool.active - 1
|
||
|
|
|
||
|
|
|
||
|
|
class _FakePool:
    """Minimal pool double that counts acquisitions and live checkouts."""

    def __init__(self) -> None:
        # Number of acquire() calls so far; also the idx given to the next
        # connection handed out.
        self.acquire_calls = 0
        # Connections currently inside an acquire context; maintained by
        # _AcquireContext.__aenter__/__aexit__.
        self.active = 0

    def acquire(self) -> _AcquireContext:
        """Return an async context manager wrapping a brand-new connection."""
        index = self.acquire_calls
        self.acquire_calls = index + 1
        return _AcquireContext(_FakeConnection(index), self)
|
||
|
|
|
||
|
|
|
||
|
|
async def _collect_single_connection():
    """Run one root ``get_conn()`` scope from start to finish.

    The generator is drained completely, so whatever cleanup it performs
    after yielding has run by the time this returns.

    Returns:
        The last value ``get_conn()`` yielded, or ``None`` if it yielded
        nothing.
    """
    yielded = [conn async for conn in get_conn()]
    return yielded[-1] if yielded else None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_conn_reuses_connection_within_scope():
    """A ``get_conn()`` nested inside another must reuse the outer connection.

    Only one ``acquire()`` may reach the pool. The connection must also be
    released once the outermost scope has exited.
    """
    original_pool = db.pool
    fake_pool = _FakePool()
    db.pool = fake_pool
    try:
        captured: list[_FakeConnection] = []

        async for outer in get_conn():
            captured.append(outer)
            async for inner in get_conn():
                captured.append(inner)

        assert len(captured) == 2
        assert captured[0] is captured[1]
        assert fake_pool.acquire_calls == 1
        # ``active`` only returns to zero when the acquire context has been
        # exited, i.e. the shared connection was actually released.
        assert fake_pool.active == 0
    finally:
        # Always restore the real pool so other tests are unaffected.
        db.pool = original_pool
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_conn_acquires_new_connection_per_root_scope():
    """Each independent (root) ``get_conn()`` scope acquires its own connection.

    Two sequential root scopes must hit the pool twice and receive distinct
    connections. Nothing may remain checked out afterwards.
    """
    original_pool = db.pool
    fake_pool = _FakePool()
    db.pool = fake_pool
    try:
        first = await _collect_single_connection()
        second = await _collect_single_connection()

        # Separate asserts so a failure names the exact missing connection.
        assert first is not None
        assert second is not None
        assert first is not second
        assert fake_pool.acquire_calls == 2
        # Both root scopes must have exited their acquire context.
        assert fake_pool.active == 0
    finally:
        # Always restore the real pool so other tests are unaffected.
        db.pool = original_pool
|