Source code for brunns.matchers.dbapi

import logging
from collections.abc import Iterable
from typing import (
    Any,
    Optional,
    Protocol,  # type: ignore[attr-defined]
    cast,
)

from hamcrest import anything, described_as
from hamcrest.core.base_matcher import BaseMatcher
from hamcrest.core.description import Description
from hamcrest.core.matcher import Matcher

from brunns.row.rowwrapper import RowWrapper  # type: ignore[attr-defined]

logger = logging.getLogger(__name__)


[docs] class Cursor(Protocol): def fetchall(self) -> Iterable[tuple[Any, ...]]: # pragma: no cover ... def execute(self, statement: str): # pragma: no cover ... @property def description(self) -> Optional[tuple[tuple[str, str]]]: # pragma: no cover ...
[docs] class Connection(Protocol): def cursor(self) -> Cursor: # pragma: no cover ...
[docs] class SelectReturnsRowsMatching(BaseMatcher[Connection]): def __init__(self, select: str, row_matcher: Matcher[list[Any]]) -> None: self.select = select self.row_matcher = row_matcher def _matches(self, conn: Connection) -> bool: try: rows = self._get_rows(conn, self.select) return self.row_matcher.matches(rows) except Exception: return False @staticmethod def _get_rows(conn: Connection, select: str): cursor = conn.cursor() cursor.execute(select) wrapper = RowWrapper(cast("Any", cursor.description or ())) return [wrapper.wrap(row) for row in cursor.fetchall()]
[docs] def describe_to(self, description: Description) -> None: description.append_text("DB connection for which statement ").append_description_of(self.select).append_text( " returns rows matching ", ).append_description_of(self.row_matcher)
[docs] def describe_mismatch(self, conn: Connection, mismatch_description: Description) -> None: try: rows = self._get_rows(conn, self.select) self.row_matcher.describe_mismatch(rows, mismatch_description) except Exception as e: mismatch_description.append_text("SQL statement ").append_description_of(self.select).append_text( " gives ", ).append_description_of(type(e).__name__).append_text(" ").append_description_of(e)
[docs] def has_table(table: str) -> Matcher[Connection]: """Matches if database has table with name. :param table: Table name. """ select = f"SELECT * FROM {table};" # nosec return described_as( "DB connection has table named %0", given_select_returns_rows_matching(select, anything()), table, )
[docs] def has_table_with_rows(table: str, row_matcher: Matcher[list[Any]]) -> Matcher[Connection]: """Matches if database has table with rows matching. :param table: Table name. :param row_matcher: Row matchers. """ select = f"SELECT * FROM {table};" # nosec return described_as( "DB connection with table %0 with rows matching %1", given_select_returns_rows_matching(select, row_matcher), table, row_matcher, )
def given_select_returns_rows_matching(select: str, row_matcher: Matcher[list[Any]]) -> SelectReturnsRowsMatching: return SelectReturnsRowsMatching(select, row_matcher)