diff --git a/changelog.d/1164.added.rst b/changelog.d/1164.added.rst new file mode 100644 index 00000000..f6e6612b --- /dev/null +++ b/changelog.d/1164.added.rst @@ -0,0 +1 @@ +Added ``loop_factory`` to pytest_asyncio.fixture and asyncio mark diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 9d8a59e3..d0708c87 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -129,6 +129,7 @@ def fixture( *, scope: _ScopeName | Callable[[str, Config], _ScopeName] = ..., loop_scope: _ScopeName | None = ..., + loop_factory: Callable[[], AbstractEventLoop] | None = ..., params: Iterable[object] | None = ..., autouse: bool = ..., ids: ( @@ -146,6 +147,7 @@ def fixture( *, scope: _ScopeName | Callable[[str, Config], _ScopeName] = ..., loop_scope: _ScopeName | None = ..., + loop_factory: Callable[[], AbstractEventLoop] | None = ..., params: Iterable[object] | None = ..., autouse: bool = ..., ids: ( @@ -160,20 +162,26 @@ def fixture( def fixture( fixture_function: FixtureFunction[_P, _R] | None = None, loop_scope: _ScopeName | None = None, + loop_factory: Callable[[], AbstractEventLoop] | None = None, **kwargs: Any, ) -> ( FixtureFunction[_P, _R] | Callable[[FixtureFunction[_P, _R]], FixtureFunction[_P, _R]] ): if fixture_function is not None: - _make_asyncio_fixture_function(fixture_function, loop_scope) + _make_asyncio_fixture_function(fixture_function, loop_scope, loop_factory) return pytest.fixture(fixture_function, **kwargs) else: @functools.wraps(fixture) def inner(fixture_function: FixtureFunction[_P, _R]) -> FixtureFunction[_P, _R]: - return fixture(fixture_function, loop_scope=loop_scope, **kwargs) + return fixture( + fixture_function, + loop_factory=loop_factory, + loop_scope=loop_scope, + **kwargs, + ) return inner @@ -183,12 +191,17 @@ def _is_asyncio_fixture_function(obj: Any) -> bool: return getattr(obj, "_force_asyncio_fixture", False) -def _make_asyncio_fixture_function(obj: Any, loop_scope: _ScopeName | None) -> None: +def _make_asyncio_fixture_function( + obj: Any, + loop_scope: _ScopeName | None, + loop_factory: Callable[[], AbstractEventLoop] | None, +) -> None: if hasattr(obj, "__func__"): # instance method, check the function object obj = obj.__func__ obj._force_asyncio_fixture = True obj._loop_scope = loop_scope + obj._loop_factory = loop_factory def _is_coroutine_or_asyncgen(obj: Any) -> bool: @@ -260,7 +273,9 @@ def pytest_report_header(config: Config) -> list[str]: def _fixture_synchronizer( - fixturedef: FixtureDef, runner: Runner, request: FixtureRequest + fixturedef: FixtureDef, + runner: Runner, + request: FixtureRequest, ) -> Callable: """Returns a synchronous function evaluating the specified fixture.""" fixture_function = resolve_fixture_function(fixturedef, request) @@ -333,7 +348,6 @@ def _wrap_async_fixture( runner: Runner, request: FixtureRequest, ) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]: - @functools.wraps(fixture_function) # type: ignore[arg-type] def _async_fixture_wrapper( *args: AsyncFixtureParams.args, @@ -344,8 +358,8 @@ async def setup(): return res context = contextvars.copy_context() - result = runner.run(setup(), context=context) + result = runner.run(setup(), context=context) # Copy the context vars modified by the setup task into the current # context, and (if needed) add a finalizer to reset them. # @@ -446,6 +460,7 @@ def _can_substitute(item: Function) -> bool: return inspect.iscoroutinefunction(func) def runtest(self) -> None: + # print(self.obj.pytestmark[0].__dict__) synchronized_obj = wrap_in_sync(self.obj) with MonkeyPatch.context() as c: c.setattr(self, "obj", synchronized_obj) @@ -558,16 +573,32 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass( @contextlib.contextmanager -def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[None]: +def _temporary_event_loop_policy( + policy: AbstractEventLoopPolicy, + loop_facotry: Callable[..., AbstractEventLoop] | None, +) -> Iterator[None]: + old_loop_policy = _get_event_loop_policy() try: old_loop = _get_event_loop_no_warn() except RuntimeError: old_loop = None + # XXX: For some reason this function can override runner's + # _loop_factory (At least observed on backported versions of Runner) + # so we need to re-override if existing... + if loop_facotry: + _loop = loop_facotry() + _set_event_loop(_loop) + else: + _loop = None + _set_event_loop_policy(policy) try: yield finally: + if _loop: + # Do not let BaseEventLoop.__del__ complain! + _loop.close() _set_event_loop_policy(old_loop_policy) _set_event_loop(old_loop) @@ -661,9 +692,8 @@ def wrap_in_sync( @functools.wraps(func) def inner(*args, **kwargs): - coro = func(*args, **kwargs) - _loop = _get_event_loop_no_warn() - task = asyncio.ensure_future(coro, loop=_loop) + _loop = asyncio.get_event_loop() + task = asyncio.ensure_future(func(*args, **kwargs), loop=_loop) try: _loop.run_until_complete(task) except BaseException: @@ -713,10 +743,12 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None: or default_loop_scope or fixturedef.scope ) + loop_factory = getattr(fixturedef.func, "loop_factory", None) + runner_fixture_id = f"_{loop_scope}_scoped_runner" runner = request.getfixturevalue(runner_fixture_id) synchronizer = _fixture_synchronizer(fixturedef, runner, request) - _make_asyncio_fixture_function(synchronizer, loop_scope) + _make_asyncio_fixture_function(synchronizer, loop_scope, loop_factory) with MonkeyPatch.context() as c: c.setattr(fixturedef, "func", synchronizer) hook_result = yield @@ -739,9 +771,13 @@ def _get_marked_loop_scope( ) -> _ScopeName: assert asyncio_marker.name == "asyncio" if asyncio_marker.args or ( - asyncio_marker.kwargs and set(asyncio_marker.kwargs) - {"loop_scope", "scope"} + asyncio_marker.kwargs + and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"} ): - raise ValueError("mark.asyncio accepts only a keyword argument 'loop_scope'.") + raise ValueError( + "mark.asyncio accepts only a keyword arguments 'loop_scope'" + " or 'loop_factory'" + ) if "scope" in asyncio_marker.kwargs: if "loop_scope" in asyncio_marker.kwargs: raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR) @@ -770,6 +806,17 @@ def _get_default_test_loop_scope(config: Config) -> _ScopeName: """ +def _get_loop_facotry( + request: FixtureRequest, +) -> Callable[[], AbstractEventLoop] | None: + if asyncio_mark := request._pyfuncitem.get_closest_marker("asyncio"): + factory = asyncio_mark.kwargs.get("loop_factory", None) + print(f"FACTORY {factory}") + return factory + else: + return request.obj.__dict__.get("_loop_factory", None) # type: ignore[attr-defined] + + def _create_scoped_runner_fixture(scope: _ScopeName) -> Callable: @pytest.fixture( scope=scope, @@ -780,9 +827,14 @@ def _scoped_runner( request: FixtureRequest, ) -> Iterator[Runner]: new_loop_policy = event_loop_policy + + # We need to get the factory now because + # _temporary_event_loop_policy can override the Runner + factory = _get_loop_facotry(request) debug_mode = _get_asyncio_debug(request.config) - with _temporary_event_loop_policy(new_loop_policy): - runner = Runner(debug=debug_mode).__enter__() + with _temporary_event_loop_policy(new_loop_policy, factory): + runner = Runner(debug=debug_mode, loop_factory=factory).__enter__() + try: yield runner except Exception as e: diff --git a/tests/markers/test_invalid_arguments.py b/tests/markers/test_invalid_arguments.py index 2d5c3552..fc2c88f1 100644 --- a/tests/markers/test_invalid_arguments.py +++ b/tests/markers/test_invalid_arguments.py @@ -40,9 +40,7 @@ async def test_anything(): ) result = pytester.runpytest_subprocess() result.assert_outcomes(errors=1) - result.stdout.fnmatch_lines( - ["*ValueError: mark.asyncio accepts only a keyword argument*"] - ) + result.stdout.fnmatch_lines([""]) def test_error_when_wrong_keyword_argument_is_passed( @@ -62,7 +60,10 @@ async def test_anything(): result = pytester.runpytest_subprocess() result.assert_outcomes(errors=1) result.stdout.fnmatch_lines( - ["*ValueError: mark.asyncio accepts only a keyword argument 'loop_scope'*"] + [ + "*ValueError: mark.asyncio accepts only a keyword arguments " + "'loop_scope' or 'loop_factory'*" + ] ) @@ -83,5 +84,8 @@ async def test_anything(): result = pytester.runpytest_subprocess() result.assert_outcomes(errors=1) result.stdout.fnmatch_lines( - ["*ValueError: mark.asyncio accepts only a keyword argument*"] + [ + "*ValueError: mark.asyncio accepts only a keyword arguments " + "'loop_scope' or 'loop_factory'*" + ] ) diff --git a/tests/modes/test_strict_mode.py b/tests/modes/test_strict_mode.py index 44f54b7d..0655fbdb 100644 --- a/tests/modes/test_strict_mode.py +++ b/tests/modes/test_strict_mode.py @@ -163,10 +163,7 @@ async def test_anything(any_fixture): result.stdout.fnmatch_lines( [ "*warnings summary*", - ( - "test_strict_mode_marked_test_unmarked_fixture_warning.py::" - "test_anything" - ), + ("test_strict_mode_marked_test_unmarked_fixture_warning.py::test_anything"), ( "*/pytest_asyncio/plugin.py:*: PytestDeprecationWarning: " "asyncio test 'test_anything' requested async " diff --git a/tests/test_asyncio_mark.py b/tests/test_asyncio_mark.py index 81731adb..0e839bbc 100644 --- a/tests/test_asyncio_mark.py +++ b/tests/test_asyncio_mark.py @@ -223,3 +223,45 @@ async def test_a(session_loop_fixture): result = pytester.runpytest("--asyncio-mode=auto") result.assert_outcomes(passed=1) + + +def test_asyncio_marker_event_loop_factories(pytester: Pytester): + pytester.makeini( + dedent( + """\ + [pytest] + asyncio_default_fixture_loop_scope = function + asyncio_default_test_loop_scope = module + """ + ) + ) + + pytester.makepyfile( + dedent( + """\ + import asyncio + import pytest_asyncio + import pytest + + class CustomEventLoop(asyncio.SelectorEventLoop): + pass + + @pytest.mark.asyncio(loop_factory=CustomEventLoop) + async def test_has_different_event_loop(): + assert type(asyncio.get_running_loop()).__name__ == "CustomEventLoop" + + @pytest_asyncio.fixture(loop_factory=CustomEventLoop) + async def custom_fixture(): + yield asyncio.get_running_loop() + + async def test_with_fixture(custom_fixture): + # Both of these should be the same... + type(asyncio.get_running_loop()).__name__ == "CustomEventLoop" + type(custom_fixture).__name__ == "CustomEventLoop" + + """ + ) + ) + + result = pytester.runpytest("--asyncio-mode=auto") + result.assert_outcomes(passed=1)