from itertools import chain, zip_longest
from typing import Any, Union, cast
from unittest.mock import Mock, _Call
from hamcrest.core.base_matcher import BaseMatcher
from hamcrest.core.description import Description
from hamcrest.core.helpers.wrap_matcher import wrap_matcher
from hamcrest.core.matcher import Matcher
class CallHasPositionalArg(BaseMatcher[_Call]):
    """Match a ``mock.call`` whose positional argument at ``index`` satisfies a matcher.

    Accepts both call layouts produced by ``unittest.mock``: the three-element
    ``(name, args, kwargs)`` entries of ``Mock.mock_calls`` and the two-element
    ``(args, kwargs)`` entries of ``Mock.call_args`` / ``Mock.call_args_list``.
    """

    def __init__(self, index: int, expected: Any) -> None:
        """
        :param index: Position of the argument to check. Negative values count
            from the end, as with ordinary sequence indexing.
        :param expected: Value or matcher the argument must satisfy; it is passed
            through ``wrap_matcher`` so plain values are accepted too.
        """
        super().__init__()
        self.index = index
        self.expected = wrap_matcher(expected)

    def _index_in_range(self, args: Any) -> bool:
        """Return True if ``self.index`` addresses an existing element of ``args``."""
        # Checking both bounds stops negative indices from passing a one-sided
        # length check and then raising IndexError on lookup.
        return -len(args) <= self.index < len(args)

    def _matches(self, actual_call: _Call) -> bool:
        # Positional args are always second-to-last, whether the call is
        # (name, args, kwargs) or (args, kwargs).
        args = actual_call[-2]
        return self._index_in_range(args) and self.expected.matches(args[self.index])

    def describe_to(self, description: Description) -> None:
        """Describe the expected call: the argument index and the expected matcher."""
        description.append_text("mock.call with argument index ").append_description_of(self.index).append_text(
            " matching ",
        )
        self.expected.describe_to(description)

    def describe_mismatch(self, actual_call: _Call, mismatch_description: Description) -> None:
        """Report the actual value at ``index``, or that the call has no such argument."""
        args = actual_call[-2]
        if self._index_in_range(args):
            mismatch_description.append_text("got mock.call with argument index ").append_description_of(
                self.index,
            ).append_text(" with value ").append_description_of(args[self.index])
        else:
            mismatch_description.append_text("got mock.call without argument index ").append_description_of(
                self.index,
            )
class CallHasKeywordArg(BaseMatcher[_Call]):
    """Match a ``mock.call`` whose keyword argument ``key`` is present and satisfies a matcher.

    Accepts both call layouts produced by ``unittest.mock``: the three-element
    ``(name, args, kwargs)`` entries of ``Mock.mock_calls`` and the two-element
    ``(args, kwargs)`` entries of ``Mock.call_args`` / ``Mock.call_args_list``.
    """

    def __init__(self, key: str, expected: Any) -> None:
        """
        :param key: Name of the keyword argument to check.
        :param expected: Value or matcher the argument must satisfy; it is passed
            through ``wrap_matcher`` so plain values are accepted too.
        """
        super().__init__()
        self.key = key
        self.expected = wrap_matcher(expected)

    def _matches(self, actual_call: _Call) -> bool:
        # Keyword args are always last, whether the call is
        # (name, args, kwargs) or (args, kwargs).
        kwargs = actual_call[-1]
        # A missing key is a mismatch; it is not compared against a default.
        return self.key in kwargs and self.expected.matches(kwargs[self.key])

    def describe_to(self, description: Description) -> None:
        """Describe the expected call: the keyword name and the expected matcher."""
        description.append_text("mock.call with keyword argument ").append_description_of(self.key).append_text(
            " matching ",
        )
        self.expected.describe_to(description)

    def describe_mismatch(self, actual_call: _Call, mismatch_description: Description) -> None:
        """Report the actual value for ``key``, or that the call does not pass it."""
        kwargs = actual_call[-1]
        if self.key in kwargs:
            mismatch_description.append_text("got mock.call with keyword argument ").append_description_of(
                self.key,
            ).append_text(" with value ").append_description_of(kwargs[self.key])
        else:
            mismatch_description.append_text("got mock.call without keyword argument ").append_description_of(
                self.key,
            )
class HasCall(BaseMatcher[Mock]):
    """Match a ``Mock`` when at least one entry of its ``mock_calls`` satisfies a call matcher."""

    def __init__(self, call_matcher: Matcher) -> None:
        """
        :param call_matcher: Matcher applied to each recorded call in turn.
        """
        super().__init__()
        self.call_matcher = call_matcher

    def _matches(self, mock: Mock) -> bool:
        # Stop at the first recorded call the inner matcher accepts.
        for recorded in mock.mock_calls:
            if self.call_matcher.matches(recorded):
                return True
        return False

    def describe_to(self, description: Description) -> None:
        """Describe the expectation as "has call matching <inner description>"."""
        inner = self.call_matcher
        description.append_text("has call matching ")
        inner.describe_to(description)

    def describe_mismatch(self, mock: Mock, mismatch_description: Description) -> None:
        """List every call the mock actually received."""
        # Each recorded call is converted with str() before listing.
        # NOTE(review): presumably done because _Call answers arbitrary attribute
        # lookups (e.g. describe_to); confirm before passing raw call objects.
        rendered = list(map(str, mock.mock_calls))
        mismatch_description.append_list("got calls [", ", ", "]", rendered)
class CallHasArgs(BaseMatcher[_Call]):
    """Match a ``mock.call`` against expected positional and keyword arguments.

    Positional expectations are compared in order with the call's leading
    positional arguments; extra positional arguments in the call are ignored.
    Only the keyword names given are checked; other keyword arguments are
    ignored. An expected argument that the call does not supply at all is a
    mismatch. It is not compared against ``None``, which keeps this consistent
    with ``CallHasPositionalArg`` and ``CallHasKeywordArg``.

    Accepts both the three-element ``(name, args, kwargs)`` and the two-element
    ``(args, kwargs)`` call layouts used by ``unittest.mock``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        :param args: Expected values or matchers for the leading positional arguments.
        :param kwargs: Expected values or matchers for the named keyword arguments.
        """
        super().__init__()
        self.args = [wrap_matcher(arg) for arg in args]
        self.kwargs = {key: wrap_matcher(value) for key, value in kwargs.items()}

    def _matches(self, actual_call: _Call) -> bool:
        # Index from the end so both (name, args, kwargs) and (args, kwargs) work.
        actual_positional = actual_call[-2]
        actual_keyword = actual_call[-1]
        # Too few positional args is a mismatch. Padding with None would let
        # e.g. call_has_args(None) match call().
        if len(actual_positional) < len(self.args):
            return False
        positional_ok = all(m.matches(a) for m, a in zip(self.args, actual_positional))
        return positional_ok and all(
            key in actual_keyword and matcher.matches(actual_keyword[key]) for key, matcher in self.kwargs.items()
        )

    def describe_to(self, description: Description) -> None:
        """Describe the expected call as ``mock.call with arguments (m1, m2, k=m3)``."""
        description.append_text("mock.call with arguments (").append_text(
            ", ".join(chain((str(a) for a in self.args), (f"{k}={v}" for k, v in self.kwargs.items()))),
        ).append_text(")")

    def describe_mismatch(self, call: _Call, mismatch_description: Description) -> None:
        """Report the call's actual positional and keyword arguments using their reprs."""
        actual_positional = call[-2]
        actual_keyword = call[-1]
        mismatch_description.append_text("got arguments (").append_text(
            ", ".join(chain((repr(a) for a in actual_positional), (f"{k}={v!r}" for k, v in actual_keyword.items()))),
        ).append_text(")")
def call_has_arg(arg: Union[int, str], expected: Any) -> BaseMatcher[_Call]:
    """Build a matcher for a single argument of a ``mock.call``.

    :param arg: An ``int`` selects a positional argument by index. Any other
        value (normally a ``str``) selects a keyword argument by name.
    :param expected: The expected value, or a matcher, for that argument.
    :return: A ``CallHasPositionalArg`` or ``CallHasKeywordArg`` matcher.
    """
    if not isinstance(arg, int):
        return CallHasKeywordArg(arg, expected)
    return CallHasPositionalArg(arg, expected)
def has_call(call_matcher: Matcher) -> HasCall:
    """Build a matcher for a ``unittest.mock.Mock`` that passes when any recorded call matches.

    :param call_matcher: Matcher applied to each entry of the mock's ``mock_calls``
        (for example one built by ``call_has_arg`` or ``call_has_args``).
    :return: A ``HasCall`` matcher wrapping ``call_matcher``.
    """
    matcher = HasCall(call_matcher)
    return matcher
def call_has_args(*args, **kwargs) -> CallHasArgs:
    """Build a matcher for a ``mock.call`` that checks several arguments at once.

    Positional expectations are compared in order with the call's leading
    positional arguments; extra positional arguments in the call are ignored.
    Keyword matching is loose: only the names given here are checked.

    :param args: Expected values or matchers for the positional arguments.
    :param kwargs: Expected values or matchers for the keyword arguments.
    :return: A ``CallHasArgs`` matcher.
    """
    matcher = CallHasArgs(*args, **kwargs)
    return matcher