Skip to content

Commit

Permalink
typing: Improve FixtureDefinition and FixtureDef
Browse files Browse the repository at this point in the history
* Carry around parameters and return value in `FixtureFunctionDefinition`.
* Add `FixtureParams` to `FixtureDef`.

Follow up to #12473.
  • Loading branch information
nicoddemus committed Dec 7, 2024
1 parent ecde993 commit ec71fce
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 56 deletions.
108 changes: 63 additions & 45 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
from typing_extensions import TypeAlias
else:
from typing import ParamSpec
from typing import TypeAlias


if TYPE_CHECKING:
from _pytest.python import CallSpec2
Expand All @@ -84,14 +91,17 @@

# The value of the fixture -- return/yield of the fixture function (type variable).
FixtureValue = TypeVar("FixtureValue")
# The type of the fixture function (type variable).
FixtureFunction = TypeVar("FixtureFunction", bound=Callable[..., object])
# The type of a fixture function (type alias generic in fixture value).
_FixtureFunc = Union[
Callable[..., FixtureValue], Callable[..., Generator[FixtureValue]]

# The parameters that a fixture function receives.
FixtureParams = ParamSpec("FixtureParams")

# The type of fixture function (type alias generic in fixture params and value).
_FixtureFunc: TypeAlias = Union[
Callable[FixtureParams, FixtureValue],
Callable[FixtureParams, Generator[FixtureValue, None, None]],
]
# The type of FixtureDef.cached_result (type alias generic in fixture value).
_FixtureCachedResult = Union[
_FixtureCachedResult: TypeAlias = Union[
tuple[
# The result.
FixtureValue,
Expand Down Expand Up @@ -121,7 +131,7 @@ def pytest_sessionstart(session: Session) -> None:

def get_scope_package(
node: nodes.Item,
fixturedef: FixtureDef[object],
fixturedef: FixtureDef[Any, object],
) -> nodes.Node | None:
from _pytest.python import Package

Expand Down Expand Up @@ -318,7 +328,7 @@ class FuncFixtureInfo:
# matching the name which are applicable to this function.
# There may be multiple overriding fixtures with the same name. The
# sequence is ordered from furthest to closes to the function.
name2fixturedefs: dict[str, Sequence[FixtureDef[Any]]]
name2fixturedefs: dict[str, Sequence[FixtureDef[Any, Any]]]

def prune_dependency_tree(self) -> None:
"""Recompute names_closure from initialnames and name2fixturedefs.
Expand Down Expand Up @@ -359,8 +369,8 @@ def __init__(
self,
pyfuncitem: Function,
fixturename: str | None,
arg2fixturedefs: dict[str, Sequence[FixtureDef[Any]]],
fixture_defs: dict[str, FixtureDef[Any]],
arg2fixturedefs: dict[str, Sequence[FixtureDef[Any, Any]]],
fixture_defs: dict[str, FixtureDef[Any, Any]],
*,
_ispytest: bool = False,
) -> None:
Expand Down Expand Up @@ -403,7 +413,7 @@ def scope(self) -> _ScopeName:
@abc.abstractmethod
def _check_scope(
self,
requested_fixturedef: FixtureDef[object] | PseudoFixtureDef[object],
requested_fixturedef: FixtureDef[Any, object] | PseudoFixtureDef[object],
requested_scope: Scope,
) -> None:
raise NotImplementedError()
Expand Down Expand Up @@ -544,7 +554,7 @@ def _iter_chain(self) -> Iterator[SubRequest]:

def _get_active_fixturedef(
self, argname: str
) -> FixtureDef[object] | PseudoFixtureDef[object]:
) -> FixtureDef[Any, object] | PseudoFixtureDef[object]:
if argname == "request":
cached_result = (self, [0], None)
return PseudoFixtureDef(cached_result, Scope.Function)
Expand Down Expand Up @@ -616,7 +626,9 @@ def _get_active_fixturedef(
self._fixture_defs[argname] = fixturedef
return fixturedef

def _check_fixturedef_without_param(self, fixturedef: FixtureDef[object]) -> None:
def _check_fixturedef_without_param(
self, fixturedef: FixtureDef[Any, object]
) -> None:
"""Check that this request is allowed to execute this fixturedef without
a param."""
funcitem = self._pyfuncitem
Expand Down Expand Up @@ -649,7 +661,7 @@ def _check_fixturedef_without_param(self, fixturedef: FixtureDef[object]) -> Non
)
fail(msg, pytrace=False)

def _get_fixturestack(self) -> list[FixtureDef[Any]]:
def _get_fixturestack(self) -> list[FixtureDef[Any, Any]]:
values = [request._fixturedef for request in self._iter_chain()]
values.reverse()
return values
Expand All @@ -674,7 +686,7 @@ def _scope(self) -> Scope:

def _check_scope(
self,
requested_fixturedef: FixtureDef[object] | PseudoFixtureDef[object],
requested_fixturedef: FixtureDef[Any, object] | PseudoFixtureDef[object],
requested_scope: Scope,
) -> None:
# TopRequest always has function scope so always valid.
Expand Down Expand Up @@ -708,7 +720,7 @@ def __init__(
scope: Scope,
param: Any,
param_index: int,
fixturedef: FixtureDef[object],
fixturedef: FixtureDef[Any, object],
*,
_ispytest: bool = False,
) -> None:
Expand All @@ -721,7 +733,7 @@ def __init__(
)
self._parent_request: Final[FixtureRequest] = request
self._scope_field: Final = scope
self._fixturedef: Final[FixtureDef[object]] = fixturedef
self._fixturedef: Final[FixtureDef[Any, object]] = fixturedef
if param is not NOTSET:
self.param = param
self.param_index: Final = param_index
Expand Down Expand Up @@ -751,7 +763,7 @@ def node(self):

def _check_scope(
self,
requested_fixturedef: FixtureDef[object] | PseudoFixtureDef[object],
requested_fixturedef: FixtureDef[Any, object] | PseudoFixtureDef[object],
requested_scope: Scope,
) -> None:
if isinstance(requested_fixturedef, PseudoFixtureDef):
Expand All @@ -772,7 +784,7 @@ def _check_scope(
pytrace=False,
)

def _format_fixturedef_line(self, fixturedef: FixtureDef[object]) -> str:
def _format_fixturedef_line(self, fixturedef: FixtureDef[Any, object]) -> str:
factory = fixturedef.func
path, lineno = getfslineno(factory)
if isinstance(path, Path):
Expand Down Expand Up @@ -886,7 +898,9 @@ def toterminal(self, tw: TerminalWriter) -> None:


def call_fixture_func(
fixturefunc: _FixtureFunc[FixtureValue], request: FixtureRequest, kwargs
fixturefunc: _FixtureFunc[FixtureParams, FixtureValue],
request: FixtureRequest,
kwargs: FixtureParams.kwargs,
) -> FixtureValue:
if inspect.isgeneratorfunction(fixturefunc):
fixturefunc = cast(Callable[..., Generator[FixtureValue]], fixturefunc)
Expand Down Expand Up @@ -945,9 +959,11 @@ def _eval_scope_callable(


@final
class FixtureDef(Generic[FixtureValue]):
class FixtureDef(Generic[FixtureParams, FixtureValue]):
"""A container for a fixture definition.
This is a generic class parametrized on the parameters that a fixture function receives and its return value.
Note: At this time, only explicitly documented fields and methods are
considered public stable API.
"""
Expand All @@ -957,7 +973,7 @@ def __init__(
config: Config,
baseid: str | None,
argname: str,
func: _FixtureFunc[FixtureValue],
func: _FixtureFunc[FixtureParams, FixtureValue],
scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] | None,
params: Sequence[object] | None,
ids: tuple[object | None, ...] | Callable[[Any], object | None] | None = None,
Expand Down Expand Up @@ -1112,8 +1128,8 @@ def __repr__(self) -> str:


def resolve_fixture_function(
fixturedef: FixtureDef[FixtureValue], request: FixtureRequest
) -> _FixtureFunc[FixtureValue]:
fixturedef: FixtureDef[FixtureParams, FixtureValue], request: FixtureRequest
) -> _FixtureFunc[FixtureParams, FixtureValue]:
"""Get the actual callable that can be called to obtain the fixture
value."""
fixturefunc = fixturedef.func
Expand All @@ -1136,7 +1152,7 @@ def resolve_fixture_function(


def pytest_fixture_setup(
fixturedef: FixtureDef[FixtureValue], request: SubRequest
fixturedef: FixtureDef[FixtureParams, FixtureValue], request: SubRequest
) -> FixtureValue:
"""Execution of fixture setup."""
kwargs = {}
Expand Down Expand Up @@ -1192,7 +1208,9 @@ class FixtureFunctionMarker:
def __post_init__(self, _ispytest: bool) -> None:
check_ispytest(_ispytest)

def __call__(self, function: FixtureFunction) -> FixtureFunctionDefinition:
def __call__(
self, function: Callable[FixtureParams, FixtureValue]
) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]:
if inspect.isclass(function):
raise ValueError("class fixtures not supported (maybe in the future)")

Expand All @@ -1219,12 +1237,10 @@ def __call__(self, function: FixtureFunction) -> FixtureFunctionDefinition:
return fixture_definition


# TODO: paramspec/return type annotation tracking and storing
class FixtureFunctionDefinition:
class FixtureFunctionDefinition(Generic[FixtureParams, FixtureValue]):
def __init__(
self,
*,
function: Callable[..., Any],
function: Callable[FixtureParams, FixtureValue],
fixture_function_marker: FixtureFunctionMarker,
instance: object | None = None,
_ispytest: bool = False,
Expand All @@ -1237,7 +1253,7 @@ def __init__(
self._fixture_function_marker = fixture_function_marker
if instance is not None:
self._fixture_function = cast(
Callable[..., Any], function.__get__(instance)
Callable[FixtureParams, FixtureValue], function.__get__(instance)
)
else:
self._fixture_function = function
Expand All @@ -1246,12 +1262,14 @@ def __init__(
def __repr__(self) -> str:
return f"<pytest_fixture({self._fixture_function})>"

def __get__(self, instance, owner=None):
def __get__(
self, obj: object, objtype: type | None = None
) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]:
"""Behave like a method if the function it was applied to was a method."""
return FixtureFunctionDefinition(
function=self._fixture_function,
fixture_function_marker=self._fixture_function_marker,
instance=instance,
instance=obj,
_ispytest=True,
)

Expand All @@ -1270,14 +1288,14 @@ def _get_wrapped_function(self) -> Callable[..., Any]:

@overload
def fixture(
fixture_function: Callable[..., object],
fixture_function: Callable[FixtureParams, FixtureValue],
*,
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
params: Iterable[object] | None = ...,
autouse: bool = ...,
ids: Sequence[object | None] | Callable[[Any], object | None] | None = ...,
name: str | None = ...,
) -> FixtureFunctionDefinition: ...
) -> FixtureFunctionDefinition[FixtureParams, FixtureValue]: ...


@overload
Expand All @@ -1293,14 +1311,14 @@ def fixture(


def fixture(
fixture_function: FixtureFunction | None = None,
fixture_function: Callable[FixtureParams, FixtureValue] | None = None,
*,
scope: _ScopeName | Callable[[str, Config], _ScopeName] = "function",
params: Iterable[object] | None = None,
autouse: bool = False,
ids: Sequence[object | None] | Callable[[Any], object | None] | None = None,
name: str | None = None,
) -> FixtureFunctionMarker | FixtureFunctionDefinition:
) -> FixtureFunctionMarker | FixtureFunctionDefinition[FixtureParams, FixtureValue]:
"""Decorator to mark a fixture factory function.
This decorator can be used, with or without parameters, to define a
Expand Down Expand Up @@ -1507,7 +1525,7 @@ def __init__(self, session: Session) -> None:
# suite/plugins defined with this name. Populated by parsefactories().
# TODO: The order of the FixtureDefs list of each arg is significant,
# explain.
self._arg2fixturedefs: Final[dict[str, list[FixtureDef[Any]]]] = {}
self._arg2fixturedefs: Final[dict[str, list[FixtureDef[Any, Any]]]] = {}
self._holderobjseen: Final[set[object]] = set()
# A mapping from a nodeid to a list of autouse fixtures it defines.
self._nodeid_autousenames: Final[dict[str, list[str]]] = {
Expand Down Expand Up @@ -1598,7 +1616,7 @@ def getfixtureclosure(
parentnode: nodes.Node,
initialnames: tuple[str, ...],
ignore_args: AbstractSet[str],
) -> tuple[list[str], dict[str, Sequence[FixtureDef[Any]]]]:
) -> tuple[list[str], dict[str, Sequence[FixtureDef[Any, Any]]]]:
# Collect the closure of all fixtures, starting with the given
# fixturenames as the initial set. As we have to visit all
# factory definitions anyway, we also return an arg2fixturedefs
Expand All @@ -1608,7 +1626,7 @@ def getfixtureclosure(

fixturenames_closure = list(initialnames)

arg2fixturedefs: dict[str, Sequence[FixtureDef[Any]]] = {}
arg2fixturedefs: dict[str, Sequence[FixtureDef[Any, Any]]] = {}
lastlen = -1
while lastlen != len(fixturenames_closure):
lastlen = len(fixturenames_closure)
Expand Down Expand Up @@ -1688,7 +1706,7 @@ def _register_fixture(
self,
*,
name: str,
func: _FixtureFunc[object],
func: _FixtureFunc[Any, object],
nodeid: str | None,
scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] = "function",
params: Sequence[object] | None = None,
Expand Down Expand Up @@ -1823,7 +1841,7 @@ def parsefactories(

def getfixturedefs(
self, argname: str, node: nodes.Node
) -> Sequence[FixtureDef[Any]] | None:
) -> Sequence[FixtureDef[Any, Any]] | None:
"""Get FixtureDefs for a fixture name which are applicable
to a given node.
Expand All @@ -1842,8 +1860,8 @@ def getfixturedefs(
return tuple(self._matchfactories(fixturedefs, node))

def _matchfactories(
self, fixturedefs: Iterable[FixtureDef[Any]], node: nodes.Node
) -> Iterator[FixtureDef[Any]]:
self, fixturedefs: Iterable[FixtureDef[Any, Any]], node: nodes.Node
) -> Iterator[FixtureDef[Any, Any]]:
parentnodeids = {n.nodeid for n in node.iter_parents()}
for fixturedef in fixturedefs:
if fixturedef.baseid in parentnodeids:
Expand Down Expand Up @@ -1880,7 +1898,7 @@ def get_best_relpath(func) -> str:
loc = getlocation(func, invocation_dir)
return bestrelpath(invocation_dir, Path(loc))

def write_fixture(fixture_def: FixtureDef[object]) -> None:
def write_fixture(fixture_def: FixtureDef[Any, object]) -> None:
argname = fixture_def.argname
if verbose <= 0 and argname.startswith("_"):
return
Expand Down
4 changes: 2 additions & 2 deletions src/_pytest/hookspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ def pytest_report_from_serializable(

@hookspec(firstresult=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[Any], request: SubRequest
fixturedef: FixtureDef[Any, Any], request: SubRequest
) -> object | None:
"""Perform fixture setup execution.
Expand Down Expand Up @@ -894,7 +894,7 @@ def pytest_fixture_setup(


def pytest_fixture_post_finalizer(
fixturedef: FixtureDef[Any], request: SubRequest
fixturedef: FixtureDef[Any, Any], request: SubRequest
) -> None:
"""Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not
Expand Down
6 changes: 3 additions & 3 deletions src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def get_direct_param_fixture_func(request: FixtureRequest) -> Any:


# Used for storing pseudo fixturedefs for direct parametrization.
name2pseudofixturedef_key = StashKey[dict[str, FixtureDef[Any]]]()
name2pseudofixturedef_key = StashKey[dict[str, FixtureDef[Any, Any]]]()


@final
Expand Down Expand Up @@ -1271,7 +1271,7 @@ def parametrize(
if node is None:
name2pseudofixturedef = None
else:
default: dict[str, FixtureDef[Any]] = {}
default: dict[str, FixtureDef[Any, Any]] = {}
name2pseudofixturedef = node.stash.setdefault(
name2pseudofixturedef_key, default
)
Expand Down Expand Up @@ -1458,7 +1458,7 @@ def _recompute_direct_params_indices(self) -> None:

def _find_parametrized_scope(
argnames: Sequence[str],
arg2fixturedefs: Mapping[str, Sequence[fixtures.FixtureDef[object]]],
arg2fixturedefs: Mapping[str, Sequence[fixtures.FixtureDef[Any, object]]],
indirect: bool | Sequence[str],
) -> Scope:
"""Find the most appropriate scope for a parametrized call based on its arguments.
Expand Down
Loading

0 comments on commit ec71fce

Please sign in to comment.