diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 86d55d57..1ba879bf 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -16,10 +16,10 @@ jobs: SETUPTOOLS_SCM_PRETEND_VERSION: ${{ github.event.inputs.version }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Build and Check Package - uses: hynek/build-and-inspect-python-package@v1.5 + uses: hynek/build-and-inspect-python-package@v2.2 deploy: needs: package @@ -30,16 +30,16 @@ jobs: contents: write # For tag. steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Download Package - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: Packages path: dist - name: Publish package to PyPI - uses: pypa/gh-action-pypi-publish@v1.8.5 + uses: pypa/gh-action-pypi-publish@v1.8.14 - name: Push tag run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c3b0e71f..a5ead60f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,9 +20,9 @@ jobs: package: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Build and Check Package - uses: hynek/build-and-inspect-python-package@v1.5 + uses: hynek/build-and-inspect-python-package@v2.2 test: @@ -34,28 +34,26 @@ jobs: fail-fast: false matrix: os: [ windows-latest, ubuntu-latest ] - python: [ "3.7","3.8","3.10","3.11", "pypy-3.7" ] + python: [ "3.8","3.10","3.11","3.12", "pypy-3.8" ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Download Package - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.4 with: name: Packages path: dist - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Install tox - run: | - python -m pip install --upgrade pip - pip install tox + run: pip install tox - name: Test shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2cec5ad2..eaf6fc6d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,33 +1,24 @@ repos: - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black - repo: https://github.com/asottile/blacken-docs - rev: 1.14.0 + rev: 1.16.0 hooks: - id: blacken-docs additional_dependencies: [black==22.12.0] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - id: check-yaml -- repo: https://github.com/asottile/pyupgrade - rev: v3.8.0 - hooks: - - id: pyupgrade - args: [--py37-plus] -- repo: https://github.com/asottile/reorder-python-imports - rev: v3.10.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.5 hooks: - - id: reorder-python-imports - args: ['--application-directories=execnet', --py37-plus] + - id: ruff + args: [ --fix ] + exclude: "^doc/" + - id: ruff-format - repo: https://github.com/PyCQA/doc8 rev: 'v1.1.1' hooks: @@ -35,6 +26,10 @@ repos: args: ["--ignore", "D001"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.4.1' + rev: 'v1.9.0' hooks: - id: mypy + additional_dependencies: + - pytest + - types-pywin32 + - types-gevent diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..516834fe --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,11 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3" + +python: + install: + - method: pip + path: . diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3776bad1..dfe36275 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,33 @@ +2.1.1 (2024-04-08) +------------------ + +* `#267 `__ Fixed regression + in 2.1.0 where the ``strconfig`` argument to ``load``/``loads`` is ignored. + +2.1.0 (2024-04-05) +------------------ + +* `#243 `__: Added ``main_thread_only`` + execmodel which is derived from the thread execmodel and only executes ``remote_exec`` + calls in the main thread. + + Callers of ``remote_exec`` must use the returned channel to wait for a task to complete + before they call remote_exec again, otherwise the ``remote_exec`` call will fail with a + ``concurrent remote_exec would cause deadlock`` error. The main_thread_only execmodel + provides solutions for `#96 `__ and + `pytest-dev/pytest-xdist#620 `__ + (pending a new `pytest-xdist` release). + + Also fixed ``init_popen_io`` to use ``closefd=False`` for shared stdin and stdout file + descriptors, preventing ``Bad file descriptor`` errors triggered by test_stdouterrin_setnull. +* The library is now typed and the typing is exposed to type-checkers. +* Re-exported ``Gateway``, ``Channel``, ``DumpError`` and ``LoadError`` from + ``execnet``. The constructors are private. +* Fixed ``GatewayBase.join()`` timeout argument getting ignored. +* Removed support for Python 3.7. +* Added official support for Python 3.12. + + 2.0.2 (2023-07-09) ------------------ @@ -130,7 +160,7 @@ (this also fixes the bpython interaction issues) - Fix issue38: provide ability to connect to Vagrant VMs easily - using :code:`vagrant_ssh=defaut` or :code:`vagrant_ssh=machinename` + using :code:`vagrant_ssh=default` or :code:`vagrant_ssh=machinename` this feature is experimental and will be refined in future releases. Thanks Christian Theune for the discussion and the initial pull request. @@ -432,7 +462,7 @@ * make internal protocols more robust against serialization failures -* fix a seralization bug with nested tuples containing empty tuples +* fix a serialization bug with nested tuples containing empty tuples (thanks to ronny for discovering it) * setting the environment variable EXECNET_DEBUG will generate per diff --git a/README.rst b/README.rst index a79845eb..7624091a 100644 --- a/README.rst +++ b/README.rst @@ -1,9 +1,6 @@ execnet: distributed Python deployment and communication ======================================================== -Important ---------- - .. image:: https://img.shields.io/pypi/v/execnet.svg :target: https://pypi.org/project/execnet/ @@ -19,7 +16,7 @@ Important .. image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/python/black -.. _execnet: http://codespeak.net/execnet +.. _execnet: https://execnet.readthedocs.io execnet_ provides carefully tested means to ad-hoc interact with Python interpreters across version, platform and network barriers. It provides diff --git a/doc/_templates/layout.html b/doc/_templates/layout.html index f4463252..b6665ee9 100644 --- a/doc/_templates/layout.html +++ b/doc/_templates/layout.html @@ -18,16 +18,3 @@

execnet: Distributed Python deployment and communication

{% endblock %} - -{% block footer %} -{{ super() }} - - -{% endblock %} diff --git a/doc/basics.rst b/doc/basics.rst index aa6dabaf..78672647 100644 --- a/doc/basics.rst +++ b/doc/basics.rst @@ -6,7 +6,7 @@ execnet ad-hoc instantiates local and remote Python interpreters. Each interpreter is accessible through a **Gateway** which manages code and data communication. **Channels** allow to exchange data between the local and the remote end. **Groups** -help to manage creation and termination of sub interpreters. +help to manage creation and termination of sub-interpreters. .. image:: _static/basic1.png @@ -26,10 +26,10 @@ Here is an example which instantiates a simple Python subprocess:: >>> gateway = execnet.makegateway() -gateways allow to `remote execute code`_ and +Gateways allow to `remote execute code`_ and `exchange data`_ bidirectionally. -examples for valid gateway specifications +Examples for valid gateway specifications ------------------------------------------- * ``ssh=wyvern//python=python3.3//chdir=mycache`` specifies a Python3.3 @@ -82,7 +82,7 @@ in the instantiated subprocess-interpreter: .. automethod:: Gateway.remote_exec(source) It is allowed to pass a module object as source code -in which case it's source code will be obtained and +in which case its source code will be obtained and get sent for remote execution. ``remote_exec`` returns a channel object whose symmetric counterpart channel is available to the remotely executing source. @@ -90,7 +90,7 @@ is available to the remotely executing source. .. method:: Gateway.reconfigure([py2str_as_py3str=True, py3str_as_py2str=False]) - reconfigures the string-coercion behaviour of the gateway + Reconfigures the string-coercion behaviour of the gateway .. _`Channel`: .. _`channel-api`: @@ -138,14 +138,14 @@ processes then you often want to call ``group.terminate()`` yourself and specify a larger or not timeout. -threading models: gevent, eventlet, thread -=========================================== +threading models: gevent, eventlet, thread, main_thread_only +==================================================================== .. versionadded:: 1.2 (status: experimental!) -execnet supports "thread", "eventlet" and "gevent" as thread models -on each of the two sides. You need to decide which model to use -before you create any gateways:: +execnet supports "main_thread_only", "thread", "eventlet" and "gevent" +as thread models on each of the two sides. You need to decide which +model to use before you create any gateways:: # content of threadmodel.py import execnet @@ -164,17 +164,10 @@ you can execute this little test file:: 1 -.. note:: - - With python3 you can (as of December 2013) only use the thread model - because neither eventlet-0.14.0 nor gevent-1.0 support Python3. - When they start to support Python3, execnet will probably just work - because it is itself Python3 compatible. - How to execute in the main thread ------------------------------------------------ -When the remote side of a gateway uses the 'thread' model, execution +When the remote side of a gateway uses the "thread" model, execution will preferably run in the main thread. This allows GUI loops or other code to behave correctly. If you, however, start multiple executions concurrently, they will run in non-main threads. @@ -227,7 +220,7 @@ configure a tracing mechanism: .. _`dumps/loads`: .. _`dumps/loads API`: -cross-interpreter serialization of python objects +Cross-interpreter serialization of Python objects ======================================================= .. versionadded:: 1.1 diff --git a/doc/conf.py b/doc/conf.py index c12a8c44..3d380529 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -26,7 +26,11 @@ # Add any Sphinx extension module names here, as strings. # They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ["sphinx.ext.autodoc", "sphinx.ext.doctest"] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", +] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -83,6 +87,16 @@ # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), +} + +nitpicky = True +nitpick_ignore = [ + ("py:class", "execnet.gateway_base.ChannelFileRead"), + ("py:class", "execnet.gateway_base.ChannelFileWrite"), + ("py:class", "execnet.gateway.Gateway"), +] # -- Options for HTML output -------------------------------------------------- diff --git a/doc/example/popen_read_multiple.py b/doc/example/popen_read_multiple.py index 571faf34..43d0ed04 100644 --- a/doc/example/popen_read_multiple.py +++ b/doc/example/popen_read_multiple.py @@ -3,6 +3,7 @@ reading results from possibly blocking code running in sub processes. """ + import execnet NUM_PROCESSES = 5 diff --git a/doc/example/redirect_remote_output.py b/doc/example/redirect_remote_output.py index 8900a658..0fea4ebb 100644 --- a/doc/example/redirect_remote_output.py +++ b/doc/example/redirect_remote_output.py @@ -7,6 +7,7 @@ - setting a callback for receiving channel data """ + import execnet gw = execnet.makegateway() @@ -26,7 +27,7 @@ def write(data): print("received:", repr(data)) -outchan.setcallback(write) +outchan.setcallback(write) # type: ignore[attr-defined] gw.remote_exec( """ diff --git a/doc/example/svn-sync-repo.py b/doc/example/svn-sync-repo.py index 24e7bd58..4377dc3f 100644 --- a/doc/example/svn-sync-repo.py +++ b/doc/example/svn-sync-repo.py @@ -5,6 +5,7 @@ uses execnet. """ + import os import pathlib import subprocess diff --git a/doc/example/sysinfo.py b/doc/example/sysinfo.py index ce3bb9ce..b3fa2d5e 100644 --- a/doc/example/sysinfo.py +++ b/doc/example/sysinfo.py @@ -5,13 +5,13 @@ (c) Holger Krekel, MIT license """ + import optparse import re import sys import execnet - parser = optparse.OptionParser(usage=__doc__) parser.add_option( "-f", @@ -129,7 +129,7 @@ def getinfo(sshname, ssh_config=None, loginfo=sys.stdout): if ssh_config: spec = f"ssh=-F {ssh_config} {sshname}" else: - spec += "ssh=%s" % sshname + spec = "ssh=%s" % sshname debug("connecting to", repr(spec)) try: gw = execnet.makegateway(spec) diff --git a/doc/example/test_debug.rst b/doc/example/test_debug.rst index f0a71641..144a197f 100644 --- a/doc/example/test_debug.rst +++ b/doc/example/test_debug.rst @@ -1,5 +1,5 @@ -Debugging execnet / Wire messages +Debugging execnet / wire messages =============================================================== By setting the environment variable ``EXECNET_DEBUG`` you can diff --git a/doc/example/test_group.rst b/doc/example/test_group.rst index ed7a07b6..dd6275b5 100644 --- a/doc/example/test_group.rst +++ b/doc/example/test_group.rst @@ -5,7 +5,7 @@ Usings Groups for managing multiple gateways ------------------------------------------------------ Use ``execnet.Group`` to manage membership and lifetime of -of multiple gateways:: +multiple gateways:: >>> import execnet >>> group = execnet.Group(['popen'] * 2) @@ -25,7 +25,7 @@ of multiple gateways:: >>> group -Assigning Gateway IDs +Assigning gateway IDs ------------------------------------------------------ All gateways are created as part of a group and receive @@ -66,12 +66,12 @@ you actually use the ``execnet.default_group``:: >>> execnet.default_group.defaultspec # used for empty makegateway() calls 'popen' -Robust Termination of ssh/popen processes +Robust termination of SSH/popen processes ----------------------------------------------- Use ``group.terminate(timeout)`` if you want to terminate -member gateways and ensure that no local sub processes remain -you can specify a ``timeout`` after which an attempt at killing +member gateways and ensure that no local subprocesses remain. +You can specify a ``timeout`` after which an attempt at killing the related process is made:: >>> import execnet @@ -86,7 +86,7 @@ the related process is made:: execnet aims to provide totally robust termination so if you have left-over processes or other termination issues -please :doc:`report them <../support>`. thanks! +please :doc:`report them <../support>`. Thanks! Using Groups to manage a certain type of gateway diff --git a/doc/example/test_info.rst b/doc/example/test_info.rst index 8dbfa538..0777dd7b 100644 --- a/doc/example/test_info.rst +++ b/doc/example/test_info.rst @@ -1,5 +1,5 @@ -basic local and remote communication -========================================= +Basic local and remote communication +==================================== Execute source code in subprocess, communicate through a channel ------------------------------------------------------------------- @@ -25,8 +25,8 @@ messages between two processes. .. _`share-nothing model`: http://en.wikipedia.org/wiki/Shared_nothing_architecture -remote-exec a function (avoiding inlined source part I) -------------------------------------------------------------------- +Remote-exec a function (avoiding inlined source part I) +------------------------------------------------------- You can send and remote execute parametrized pure functions like this: @@ -49,8 +49,8 @@ Notes: between the nodes). -remote-exec a module (avoiding inlined source part II) --------------------------------------------------------------- +Remote-exec a module (avoiding inlined source part II) +------------------------------------------------------ You can pass a module object to ``remote_exec`` in which case its source code will be sent. No dependencies will be transferred @@ -86,8 +86,8 @@ A local subprocess gateway has the same working directory as the instantiatior:: "ssh" gateways default to the login home directory. -Get information from remote ssh account --------------------------------------------- +Get information from remote SSH account +--------------------------------------- Use simple execution to obtain information from remote environments:: @@ -143,7 +143,7 @@ and use it to transfer information:: -a simple command loop pattern +A simple command loop pattern -------------------------------------------------------------- If you want the remote side to serve a number @@ -189,6 +189,6 @@ itself into the remote socket endpoint:: gw = execnet.makegateway("socket=TARGET-IP:8888") That's it, you can now use the gateway object just like -a popen- or ssh-based one. +a popen- or SSH-based one. .. include:: test_ssh_fileserver.rst diff --git a/doc/example/test_multi.rst b/doc/example/test_multi.rst index 611a3808..306b960d 100644 --- a/doc/example/test_multi.rst +++ b/doc/example/test_multi.rst @@ -1,4 +1,4 @@ -advanced (multi) channel communication +Advanced (multi) channel communication ===================================================== MultiChannel: container for multiple channels @@ -19,7 +19,7 @@ Use ``execnet.MultiChannel`` to work with multiple channels:: >>> sum(mch.receive_each()) 3 -receive results from sub processes with a Queue +Receive results from sub processes with a Queue ----------------------------------------------------- Use ``MultiChannel.make_receive_queue()`` to get a queue @@ -52,7 +52,7 @@ data immediately and without blocking execution:: Note that the callback function will be executed in the receiver thread and should not block or run for too long. -robustly receive results and termination notification +Robustly receive results and termination notification ----------------------------------------------------- Use ``MultiChannel.make_receive_queue(endmarker)`` to specify @@ -76,7 +76,7 @@ is blocked in execution and is terminated/killed:: -saturate multiple Hosts and CPUs with tasks to process +Saturate multiple Hosts and CPUs with tasks to process -------------------------------------------------------- If you have multiple CPUs or hosts you can create as many diff --git a/doc/example/test_proxy.rst b/doc/example/test_proxy.rst index 0a90f258..f2af992d 100644 --- a/doc/example/test_proxy.rst +++ b/doc/example/test_proxy.rst @@ -1,13 +1,13 @@ -Managing Proxied gateways +Managing proxied gateways ========================== -Simple Proxying +Simple proxying ---------------- -Using the via arg of specs we can create a gateway +Using the ``via`` arg of specs we can create a gateway whose io is created on a remote gateway and proxied to the master. -The simlest use case, is where one creates one master process +The simplest use case, is where one creates one master process and uses it to control new workers and their environment :: diff --git a/doc/example/test_ssh_fileserver.rst b/doc/example/test_ssh_fileserver.rst index 54807452..e7a82817 100644 --- a/doc/example/test_ssh_fileserver.rst +++ b/doc/example/test_ssh_fileserver.rst @@ -1,4 +1,3 @@ - Receive file contents from remote SSH account ----------------------------------------------------- diff --git a/doc/implnotes.rst b/doc/implnotes.rst index 9223fd9b..d13d9e87 100644 --- a/doc/implnotes.rst +++ b/doc/implnotes.rst @@ -1,8 +1,7 @@ - gateway_base.py ---------------------- -the code of this module is sent to the "other side" +The code of this module is sent to the "other side" as a means of bootstrapping a Gateway object capable of receiving and executing code, and routing data through channels. diff --git a/doc/index.rst b/doc/index.rst index fe1028c0..19a20eb8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,5 +1,3 @@ - - .. image:: _static/pythonring.png :align: right @@ -17,9 +15,9 @@ communication for distributing execution across many Python interpreters across version, platform and network barriers. It has a minimal and fast API targeting the following uses: -* distribute tasks to (many) local or remote CPUs -* write and deploy hybrid multi-process applications -* write scripts to administer multiple environments +* Distribute tasks to (many) local or remote CPUs +* Write and deploy hybrid multi-process applications +* Write scripts to administer multiple environments .. _`channel-send/receive`: http://en.wikipedia.org/wiki/Channel_(programming) .. _`share-nothing model`: http://en.wikipedia.org/wiki/Shared_nothing_architecture @@ -30,25 +28,25 @@ a minimal and fast API targeting the following uses: Features ------------------ -* automatic bootstrapping: no manual remote installation. +* Automatic bootstrapping: no manual remote installation. -* safe and simple serialization of python builtin +* Safe and simple serialization of Python builtin types for sending/receiving structured data messages. (New in 1.1) execnet offers a new :ref:`dumps/loads ` API which allows cross-interpreter compatible serialization of Python builtin types. -* flexible communication: synchronous send/receive as well as +* Flexible communication: synchronous send/receive as well as callback/queue mechanisms supported -* easy creation, handling and termination of multiple processes +* Easy creation, handling and termination of multiple processes -* well tested interactions between CPython 2.5-2.7, CPython-3.3, Jython 2.5.1 +* Well tested interactions between CPython 2.5-2.7, CPython-3.3, Jython 2.5.1 and PyPy interpreters. -* fully interoperable between Windows and Unix-ish systems. +* Fully interoperable between Windows and Unix-ish systems. -* many tested :doc:`examples` +* Many tested :doc:`examples` Known uses ------------------- @@ -64,7 +62,7 @@ Known uses * Ronny Pfannschmidt uses it for his `anyvc`_ VCS-abstraction project to bridge the Python2/Python3 version gap. -* sysadmins and developers are using it for ad-hoc custom scripting +* Sysadmins and developers are using it for ad-hoc custom scripting .. _`quora`: http://quora.com .. _`connecting CPython and PyPy`: http://www.quora.com/Quora-Infrastructure/Did-Quoras-switch-to-PyPy-result-in-increased-memory-consumption @@ -87,7 +85,6 @@ used as backend of the popular `pytest-xdist `_ +* Join `execnet-dev`_ for general discussions +* Join `execnet-commit`_ to be notified of changes +* Clone the `github repository`_ and submit patches +* Hang out on the #pytest channel on `irc.libera.chat `_ (using an IRC client, via `webchat `_, or `via Matrix `_). diff --git a/pyproject.toml b/pyproject.toml index c010499f..3a7ffd44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dynamic = ["version"] description = "execnet: rapid multi-Python deployment" readme = {"file" = "README.rst", "content-type" = "text/x-rst"} license = "MIT" -requires-python = ">=3.7" +requires-python = ">=3.8" authors = [ { name = "holger krekel and others" }, ] @@ -22,11 +22,11 @@ classifiers = [ "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Software Development :: Libraries", @@ -45,6 +45,41 @@ testing = [ [project.urls] Homepage = "https://execnet.readthedocs.io/en/latest/" +[tool.ruff.lint] +extend-select = [ + "B", # bugbear + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "PYI", # flake8-pyi + "UP", # pyupgrade + "RUF", # ruff + "W", # pycodestyle + "PIE", # flake8-pie + "PGH", # pygrep-hooks + "PLE", # pylint error + "PLW", # pylint warning +] +ignore = [ + # bugbear ignore + "B007", # Loop control variable `i` not used within loop body + "B011", # Do not `assert False` (`python -O` removes these calls) + # pycodestyle ignore + "E501", # Line too long + "E741", # Ambiguous variable name + # ruff ignore + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` + # pylint ignore + "PLW0603", # Using the global statement + "PLW0120", # remove the else and dedent its contents + "PLW2901", # for loop variable overwritten by assignment target + "PLR5501", # Use `elif` instead of `else` then `if` +] + +[tool.ruff.lint.isort] +force-single-line = true +known-third-party = ["src"] + [tool.hatch.version] source = "vcs" @@ -60,4 +95,19 @@ include = [ ] [tool.mypy] -python_version = "3.7" +python_version = "3.8" +mypy_path = ["src"] +files = ["src", "testing"] +strict = true +warn_unreachable = true +warn_unused_ignores = false +disallow_untyped_calls = false +disallow_untyped_defs = false +disallow_incomplete_defs = false + +[[tool.mypy.overrides]] +module = [ + "eventlet.*", + "gevent.thread.*", +] +ignore_missing_imports = true diff --git a/src/execnet/__init__.py b/src/execnet/__init__.py index c403e2bb..5c3f20e3 100644 --- a/src/execnet/__init__.py +++ b/src/execnet/__init__.py @@ -6,24 +6,28 @@ (c) 2012, Holger Krekel and others """ + from ._version import version as __version__ +from .gateway import Gateway +from .gateway_base import Channel from .gateway_base import DataFormatError +from .gateway_base import DumpError +from .gateway_base import LoadError +from .gateway_base import RemoteError +from .gateway_base import TimeoutError from .gateway_base import dump from .gateway_base import dumps from .gateway_base import load from .gateway_base import loads -from .gateway_base import RemoteError -from .gateway_base import TimeoutError from .gateway_bootstrap import HostNotFound -from .multi import default_group from .multi import Group -from .multi import makegateway from .multi import MultiChannel +from .multi import default_group +from .multi import makegateway from .multi import set_execmodel from .rsync import RSync from .xspec import XSpec - __all__ = [ "__version__", "makegateway", @@ -32,13 +36,17 @@ "RemoteError", "TimeoutError", "XSpec", + "Gateway", "Group", "MultiChannel", "RSync", "default_group", + "Channel", "dumps", + "dump", + "DumpError", "loads", "load", - "dump", + "LoadError", "DataFormatError", ] diff --git a/src/execnet/gateway.py b/src/execnet/gateway.py index 6e0b8a7d..547be291 100644 --- a/src/execnet/gateway.py +++ b/src/execnet/gateway.py @@ -1,50 +1,57 @@ -""" -gateway code for initiating popen, socket and ssh connections. +"""Gateway code for initiating popen, socket and ssh connections. + (c) 2004-2013, Holger Krekel and others """ + +from __future__ import annotations + import inspect import linecache -import os -import sys import textwrap import types - -import execnet +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable from . import gateway_base +from .gateway_base import IO +from .gateway_base import Channel from .gateway_base import Message - -importdir = os.path.dirname(os.path.dirname(execnet.__file__)) +from .multi import Group +from .xspec import XSpec class Gateway(gateway_base.BaseGateway): """Gateway to a local or remote Python Interpreter.""" - def __init__(self, io, spec): + _group: Group + + def __init__(self, io: IO, spec: XSpec) -> None: + """:private:""" super().__init__(io=io, id=spec.id, _startcount=1) self.spec = spec self._initreceive() @property - def remoteaddress(self): - return self._io.remoteaddress + def remoteaddress(self) -> str: + # Only defined for remote IO types. + return self._io.remoteaddress # type: ignore[attr-defined,no-any-return] - def __repr__(self): - """return string representing gateway type and status.""" + def __repr__(self) -> str: + """A string representing gateway type and status.""" try: - r = self.hasreceiver() and "receive-live" or "not-receiving" - i = len(self._channelfactory.channels()) + r: str = self.hasreceiver() and "receive-live" or "not-receiving" + i = str(len(self._channelfactory.channels())) except AttributeError: r = "uninitialized" i = "no" - return "<{} id={!r} {}, {} model, {} active channels>".format( - self.__class__.__name__, self.id, r, self.execmodel.backend, i - ) + return f"<{self.__class__.__name__} id={self.id!r} {r}, {self.execmodel.backend} model, {i} active channels>" - def exit(self): - """trigger gateway exit. Defer waiting for finishing - of receiver-thread and subprocess activity to when - group.terminate() is called. + def exit(self) -> None: + """Trigger gateway exit. + + Defer waiting for finishing of receiver-thread and subprocess activity + to when group.terminate() is called. """ self._trace("gateway.exit() called") if self not in self._group: @@ -56,23 +63,24 @@ def exit(self): self._send(Message.GATEWAY_TERMINATE) self._trace("--> io.close_write") self._io.close_write() - except (ValueError, EOFError, OSError): - v = sys.exc_info()[1] + except (ValueError, EOFError, OSError) as exc: self._trace("io-error: could not send termination sequence") - self._trace(" exception: %r" % v) + self._trace(" exception: %r" % exc) - def reconfigure(self, py2str_as_py3str=True, py3str_as_py2str=False): - """ - set the string coercion for this gateway - the default is to try to convert py2 str as py3 str, - but not to try and convert py3 str to py2 str + def reconfigure( + self, py2str_as_py3str: bool = True, py3str_as_py2str: bool = False + ) -> None: + """Set the string coercion for this gateway. + + The default is to try to convert py2 str as py3 str, but not to try and + convert py3 str to py2 str. """ self._strconfig = (py2str_as_py3str, py3str_as_py2str) data = gateway_base.dumps_internal(self._strconfig) self._send(Message.RECONFIGURE, data=data) - def _rinfo(self, update=False): - """return some sys/env information from remote.""" + def _rinfo(self, update: bool = False) -> RInfo: + """Return some sys/env information from remote.""" if update or not hasattr(self, "_cache_rinfo"): ch = self.remote_exec(rinfo_source) try: @@ -81,12 +89,12 @@ def _rinfo(self, update=False): ch.waitclose() return self._cache_rinfo - def hasreceiver(self): - """return True if gateway is able to receive data.""" + def hasreceiver(self) -> bool: + """Whether gateway is able to receive data.""" return self._receivepool.active_count() > 0 - def remote_status(self): - """return information object about remote execution status.""" + def remote_status(self) -> RemoteStatus: + """Obtain information about the remote execution status.""" channel = self.newchannel() self._send(Message.STATUS, channel.id) statusdict = channel.receive() @@ -95,8 +103,12 @@ def remote_status(self): self._channelfactory._local_close(channel.id) return RemoteStatus(statusdict) - def remote_exec(self, source, **kwargs): - """return channel object and connect it to a remote + def remote_exec( + self, + source: str | types.FunctionType | Callable[..., object] | types.ModuleType, + **kwargs: object, + ) -> Channel: + """Return channel object and connect it to a remote execution thread where the given ``source`` executes. * ``source`` is a string: execute source string remotely @@ -105,7 +117,7 @@ def remote_exec(self, source, **kwargs): call function with ``**kwargs``, adding a ``channel`` object to the keyword arguments. * ``source`` is a pure module: execute source of module - with a ``channel`` in its global namespace + with a ``channel`` in its global namespace. In all cases the binding ``__name__='__channelexec__'`` will be available in the global namespace of the remotely @@ -115,7 +127,7 @@ def remote_exec(self, source, **kwargs): file_name = None if isinstance(source, types.ModuleType): file_name = inspect.getsourcefile(source) - linecache.updatecache(file_name) + linecache.updatecache(file_name) # type: ignore[arg-type] source = inspect.getsource(source) elif isinstance(source, types.FunctionType): call_name = source.__name__ @@ -135,26 +147,30 @@ def remote_exec(self, source, **kwargs): ) return channel - def remote_init_threads(self, num=None): + def remote_init_threads(self, num: int | None = None) -> None: """DEPRECATED. Is currently a NO-OPERATION already.""" - print("WARNING: remote_init_threads()" " is a no-operation in execnet-1.2") + print("WARNING: remote_init_threads() is a no-operation in execnet-1.2") class RInfo: - def __init__(self, kwargs): + def __init__(self, kwargs) -> None: self.__dict__.update(kwargs) - def __repr__(self): - info = ", ".join("%s=%s" % item for item in sorted(self.__dict__.items())) + def __repr__(self) -> str: + info = ", ".join(f"{k}={v}" for k, v in sorted(self.__dict__.items())) return "" % info + if TYPE_CHECKING: + + def __getattr__(self, name: str) -> Any: ... + RemoteStatus = RInfo -def rinfo_source(channel): - import sys +def rinfo_source(channel) -> None: import os + import sys channel.send( dict( @@ -167,7 +183,7 @@ def rinfo_source(channel): ) -def _find_non_builtin_globals(source, codeobj): +def _find_non_builtin_globals(source: str, codeobj: types.CodeType) -> list[str]: import ast import builtins @@ -181,7 +197,7 @@ def _find_non_builtin_globals(source, codeobj): ] -def _source_of_function(function): +def _source_of_function(function: types.FunctionType | Callable[..., object]) -> str: if function.__name__ == "": raise ValueError("can't evaluate lambda functions'") # XXX: we dont check before remote instantiation @@ -203,8 +219,8 @@ def _source_of_function(function): try: source = inspect.getsource(function) - except OSError: - raise ValueError("can't find source file for %s" % function) + except OSError as e: + raise ValueError("can't find source file for %s" % function) from e source = textwrap.dedent(source) # just for inner functions diff --git a/src/execnet/gateway_base.py b/src/execnet/gateway_base.py index 83c23e90..9ba25b42 100644 --- a/src/execnet/gateway_base.py +++ b/src/execnet/gateway_base.py @@ -1,7 +1,4 @@ -""" -base execnet gateway code send to the other side for bootstrapping. - -NOTE: aims to be compatible to Python 2.5-3.X, Jython and IronPython +"""Base execnet gateway code send to the other side for bootstrapping. :copyright: 2004-2015 :authors: @@ -11,6 +8,7 @@ - Ronny Pfannschmidt - many others """ + from __future__ import annotations import abc @@ -21,16 +19,59 @@ import weakref from _thread import interrupt_main from io import BytesIO +from typing import Any from typing import Callable +from typing import Iterator +from typing import Literal +from typing import MutableSet +from typing import Protocol +from typing import cast +from typing import overload + + +class WriteIO(Protocol): + def write(self, data: bytes, /) -> None: ... + + +class ReadIO(Protocol): + def read(self, numbytes: int, /) -> bytes: ... + + +class IO(Protocol): + execmodel: ExecModel + + def read(self, numbytes: int, /) -> bytes: ... + + def write(self, data: bytes, /) -> None: ... + + def close_read(self) -> None: ... + + def close_write(self) -> None: ... + + def wait(self) -> int | None: ... + + def kill(self) -> None: ... + + +class Event(Protocol): + """Protocol for types which look like threading.Event.""" + + def is_set(self) -> bool: ... + + def set(self) -> None: ... + + def clear(self) -> None: ... + + def wait(self, timeout: float | None = None) -> bool: ... class ExecModel(metaclass=abc.ABCMeta): @property @abc.abstractmethod - def backend(self): + def backend(self) -> str: raise NotImplementedError() - def __repr__(self): + def __repr__(self) -> str: return "" % self.backend @property @@ -49,19 +90,19 @@ def socket(self): raise NotImplementedError() @abc.abstractmethod - def start(self, func, args=()): + def start(self, func, args=()) -> None: raise NotImplementedError() @abc.abstractmethod - def get_ident(self): + def get_ident(self) -> int: raise NotImplementedError() @abc.abstractmethod - def sleep(self, delay): + def sleep(self, delay: float) -> None: raise NotImplementedError() @abc.abstractmethod - def fdopen(self, fd, mode, bufsize=1): + def fdopen(self, fd, mode, bufsize=1, closefd=True): raise NotImplementedError() @abc.abstractmethod @@ -73,7 +114,7 @@ def RLock(self): raise NotImplementedError() @abc.abstractmethod - def Event(self): + def Event(self) -> Event: raise NotImplementedError() @@ -98,25 +139,25 @@ def socket(self): return socket - def get_ident(self): + def get_ident(self) -> int: import _thread return _thread.get_ident() - def sleep(self, delay): + def sleep(self, delay: float) -> None: import time time.sleep(delay) - def start(self, func, args=()): + def start(self, func, args=()) -> None: import _thread - return _thread.start_new_thread(func, args) + _thread.start_new_thread(func, args) - def fdopen(self, fd, mode, bufsize=1): + def fdopen(self, fd, mode, bufsize=1, closefd=True): import os - return os.fdopen(fd, mode, bufsize, encoding="utf-8") + return os.fdopen(fd, mode, bufsize, encoding="utf-8", closefd=closefd) def Lock(self): import threading @@ -134,6 +175,10 @@ def Event(self): return threading.Event() +class MainThreadOnlyExecModel(ThreadExecModel): + backend = "main_thread_only" + + class EventletExecModel(ExecModel): backend = "eventlet" @@ -155,25 +200,25 @@ def socket(self): return eventlet.green.socket - def get_ident(self): + def get_ident(self) -> int: import eventlet.green.thread - return eventlet.green.thread.get_ident() + return eventlet.green.thread.get_ident() # type: ignore[no-any-return] - def sleep(self, delay): + def sleep(self, delay: float) -> None: import eventlet eventlet.sleep(delay) - def start(self, func, args=()): + def start(self, func, args=()) -> None: import eventlet - return eventlet.spawn_n(func, *args) + eventlet.spawn_n(func, *args) - def fdopen(self, fd, mode, bufsize=1): + def fdopen(self, fd, mode, bufsize=1, closefd=True): import eventlet.green.os - return eventlet.green.os.fdopen(fd, mode, bufsize) + return eventlet.green.os.fdopen(fd, mode, bufsize, closefd=closefd) def Lock(self): import eventlet.green.threading @@ -212,26 +257,26 @@ def socket(self): return gevent.socket - def get_ident(self): + def get_ident(self) -> int: import gevent.thread - return gevent.thread.get_ident() + return gevent.thread.get_ident() # type: ignore[no-any-return] - def sleep(self, delay): + def sleep(self, delay: float) -> None: import gevent gevent.sleep(delay) - def start(self, func, args=()): + def start(self, func, args=()) -> None: import gevent - return gevent.spawn(func, *args) + gevent.spawn(func, *args) - def fdopen(self, fd, mode, bufsize=1): + def fdopen(self, fd, mode, bufsize=1, closefd=True): # XXX import gevent.fileobject - return gevent.fileobject.FileObjectThread(fd, mode, bufsize) + return gevent.fileobject.FileObjectThread(fd, mode, bufsize, closefd=closefd) def Lock(self): import gevent.lock @@ -249,11 +294,13 @@ def Event(self): return gevent.event.Event() -def get_execmodel(backend): - if hasattr(backend, "backend"): +def get_execmodel(backend: str | ExecModel) -> ExecModel: + if isinstance(backend, ExecModel): return backend if backend == "thread": return ThreadExecModel() + elif backend == "main_thread_only": + return MainThreadOnlyExecModel() elif backend == "eventlet": return EventletExecModel() elif backend == "gevent": @@ -263,17 +310,15 @@ def get_execmodel(backend): class Reply: - """reply instances provide access to the result - of a function execution that got dispatched - through WorkerPool.spawn() - """ + """Provide access to the result of a function execution that got dispatched + through WorkerPool.spawn().""" - def __init__(self, task, threadmodel): + def __init__(self, task, threadmodel: ExecModel) -> None: self.task = task self._result_ready = threadmodel.Event() self.running = True - def get(self, timeout=None): + def get(self, timeout: float | None = None): """get the result object from an asynchronous function execution. if the function execution raised an exception, then calling get() will reraise that exception @@ -283,21 +328,19 @@ def get(self, timeout=None): try: return self._result except AttributeError: - raise self._excinfo[1].with_traceback(self._excinfo[2]) + raise self._exc from None - def waitfinish(self, timeout=None): + def waitfinish(self, timeout: float | None = None) -> None: if not self._result_ready.wait(timeout): raise OSError(f"timeout waiting for {self.task!r}") - def run(self): + def run(self) -> None: func, args, kwargs = self.task try: try: self._result = func(*args, **kwargs) - except BaseException: - # sys may be already None when shutting down the interpreter - if sys is not None: - self._excinfo = sys.exc_info() + except BaseException as exc: + self._exc = exc finally: self._result_ready.set() self.running = False @@ -312,28 +355,31 @@ class WorkerPool: itself into performing function execution through calling integrate_as_primary_thread() which will return when the pool received a trigger_shutdown(). + + By default allows unlimited number of spawns. """ - def __init__(self, execmodel, hasprimary=False): - """by default allow unlimited number of spawns.""" + _primary_thread_task: Reply | None + + def __init__(self, execmodel: ExecModel, hasprimary: bool = False) -> None: self.execmodel = execmodel self._running_lock = self.execmodel.Lock() - self._running = set() + self._running: MutableSet[Reply] = set() self._shuttingdown = False - self._waitall_events = [] + self._waitall_events: list[Event] = [] if hasprimary: - if self.execmodel.backend != "thread": + if self.execmodel.backend not in ("thread", "main_thread_only"): raise ValueError("hasprimary=True requires thread model") - self._primary_thread_task_ready = self.execmodel.Event() + self._primary_thread_task_ready: Event | None = self.execmodel.Event() else: self._primary_thread_task_ready = None - def integrate_as_primary_thread(self): - """integrate the thread with which we are called as a primary - thread for executing functions triggered with spawn(). - """ - assert self.execmodel.backend == "thread", self.execmodel + def integrate_as_primary_thread(self) -> None: + """Integrate the thread with which we are called as a primary + thread for executing functions triggered with spawn().""" + assert self.execmodel.backend in ("thread", "main_thread_only"), self.execmodel primary_thread_task_ready = self._primary_thread_task_ready + assert primary_thread_task_ready is not None # interacts with code at REF1 while 1: primary_thread_task_ready.wait() @@ -345,19 +391,23 @@ def integrate_as_primary_thread(self): with self._running_lock: if self._shuttingdown: break - primary_thread_task_ready.clear() + # Only clear if _try_send_to_primary_thread has not + # yet set the next self._primary_thread_task reply + # after waiting for this one to complete. + if reply is self._primary_thread_task: + primary_thread_task_ready.clear() - def trigger_shutdown(self): + def trigger_shutdown(self) -> None: with self._running_lock: self._shuttingdown = True if self._primary_thread_task_ready is not None: self._primary_thread_task = None self._primary_thread_task_ready.set() - def active_count(self): + def active_count(self) -> int: return len(self._running) - def _perform_spawn(self, reply): + def _perform_spawn(self, reply: Reply) -> None: reply.run() with self._running_lock: self._running.remove(reply) @@ -366,7 +416,7 @@ def _perform_spawn(self, reply): waitall_event = self._waitall_events.pop() waitall_event.set() - def _try_send_to_primary_thread(self, reply): + def _try_send_to_primary_thread(self, reply: Reply) -> bool: # REF1 in 'thread' model we give priority to running in main thread # note that we should be called with _running_lock hold primary_thread_task_ready = self._primary_thread_task_ready @@ -376,12 +426,23 @@ def _try_send_to_primary_thread(self, reply): # wake up primary thread primary_thread_task_ready.set() return True + elif ( + self.execmodel.backend == "main_thread_only" + and self._primary_thread_task is not None + ): + self._primary_thread_task.waitfinish() + self._primary_thread_task = reply + # wake up primary thread (it's okay if this is already set + # because we waited for the previous task to finish above + # and integrate_as_primary_thread will not clear it when + # it enters self._running_lock if it detects that a new + # task is available) + primary_thread_task_ready.set() + return True return False - def spawn(self, func, *args, **kwargs): - """return Reply object for the asynchronous dispatch - of the given func(*args, **kwargs). - """ + def spawn(self, func, *args, **kwargs) -> Reply: + """Asynchronously dispatch func(*args, **kwargs) and return a Reply.""" reply = Reply((func, args, kwargs), self.execmodel) with self._running_lock: if self._shuttingdown: @@ -391,13 +452,13 @@ def spawn(self, func, *args, **kwargs): self.execmodel.start(self._perform_spawn, (reply,)) return reply - def terminate(self, timeout=None): - """trigger shutdown and wait for completion of all executions.""" + def terminate(self, timeout: float | None = None) -> bool: + """Trigger shutdown and wait for completion of all executions.""" self.trigger_shutdown() return self.waitall(timeout=timeout) - def waitall(self, timeout=None): - """wait until all active spawns have finished executing.""" + def waitall(self, timeout: float | None = None) -> bool: + """Wait until all active spawns have finished executing.""" with self._running_lock: if not self._running: return True @@ -416,7 +477,7 @@ def waitall(self, timeout=None): pid = os.getpid() if DEBUG == "2": - def trace(*msg): + def trace(*msg: object) -> None: try: line = " ".join(map(str, msg)) sys.stderr.write(f"[{pid}] {line}\n") @@ -425,22 +486,21 @@ def trace(*msg): pass # nothing we can do, likely interpreter-shutdown elif DEBUG: - import tempfile import os + import tempfile fn = os.path.join(tempfile.gettempdir(), "execnet-debug-%d" % pid) # sys.stderr.write("execnet-debug at %r" % (fn,)) debugfile = open(fn, "w") - def trace(*msg): + def trace(*msg: object) -> None: try: line = " ".join(map(str, msg)) debugfile.write(line + "\n") debugfile.flush() - except Exception: + except Exception as exc: try: - v = sys.exc_info()[1] - sys.stderr.write(f"[{pid}] exception during tracing: {v!r}\n") + sys.stderr.write(f"[{pid}] exception during tracing: {exc!r}\n") except Exception: pass # nothing we can do, likely interpreter-shutdown @@ -451,7 +511,7 @@ def trace(*msg): class Popen2IO: error = (IOError, OSError, EOFError) - def __init__(self, outfile, infile, execmodel): + def __init__(self, outfile, infile, execmodel: ExecModel) -> None: # we need raw byte streams self.outfile, self.infile = outfile, infile if sys.platform == "win32": @@ -466,7 +526,7 @@ def __init__(self, outfile, infile, execmodel): self._write = getattr(outfile, "buffer", outfile).write self.execmodel = execmodel - def read(self, numbytes): + def read(self, numbytes: int) -> bytes: """Read exactly 'numbytes' bytes from the pipe.""" # a file in non-blocking mode may return less bytes, so we loop buf = b"" @@ -477,62 +537,60 @@ def read(self, numbytes): buf += data return buf - def write(self, data): - """write out all data bytes.""" + def write(self, data: bytes) -> None: + """Write out all data bytes.""" assert isinstance(data, bytes) self._write(data) self.outfile.flush() - def close_read(self): + def close_read(self) -> None: self.infile.close() - def close_write(self): + def close_write(self) -> None: self.outfile.close() class Message: - """encapsulates Messages and their wire protocol.""" + """Encapsulates Messages and their wire protocol.""" # message code -> name, handler _types: dict[int, tuple[str, Callable[[Message, BaseGateway], None]]] = {} - def __init__(self, msgcode, channelid=0, data=b""): + def __init__(self, msgcode: int, channelid: int = 0, data: bytes = b"") -> None: self.msgcode = msgcode self.channelid = channelid self.data = data @staticmethod - def from_io(io): + def from_io(io: ReadIO) -> Message: try: header = io.read(9) # type 1, channel 4, payload 4 if not header: raise EOFError("empty read") - except EOFError: - e = sys.exc_info()[1] - raise EOFError("couldn't load message header, " + e.args[0]) + except EOFError as e: + raise EOFError("couldn't load message header, " + e.args[0]) from None msgtype, channel, payload = struct.unpack("!bii", header) return Message(msgtype, channel, io.read(payload)) - def to_io(self, io): + def to_io(self, io: WriteIO) -> None: header = struct.pack("!bii", self.msgcode, self.channelid, len(self.data)) io.write(header + self.data) - def received(self, gateway): + def received(self, gateway: BaseGateway) -> None: handler = self._types[self.msgcode][1] handler(self, gateway) - def __repr__(self): + def __repr__(self) -> str: name = self._types[self.msgcode][0] - return "".format( - name, self.channelid, len(self.data) - ) + return f"" - def _status(message, gateway): + def _status(message: Message, gateway: BaseGateway) -> None: # we use the channelid to send back information # but don't instantiate a channel object d = { "numchannels": len(gateway._channelfactory._channels), - "numexecuting": gateway._execpool.active_count(), + # TODO(typing): Attribute `_execpool` is only on WorkerGateway. + "numexecuting": gateway._execpool.active_count(), # type: ignore[attr-defined] "execmodel": gateway.execmodel.backend, } gateway._send(Message.CHANNEL_DATA, message.channelid, dumps_internal(d)) @@ -541,49 +599,53 @@ def _status(message, gateway): STATUS = 0 _types[STATUS] = ("STATUS", _status) - def _reconfigure(message, gateway): + def _reconfigure(message: Message, gateway: BaseGateway) -> None: + data = loads_internal(message.data, gateway) + assert isinstance(data, tuple) + strconfig: tuple[bool, bool] = data if message.channelid == 0: - target = gateway + gateway._strconfig = strconfig else: - target = gateway._channelfactory.new(message.channelid) - target._strconfig = loads_internal(message.data, gateway) + gateway._channelfactory.new(message.channelid)._strconfig = strconfig RECONFIGURE = 1 _types[RECONFIGURE] = ("RECONFIGURE", _reconfigure) - def _gateway_terminate(message, gateway): + def _gateway_terminate(message: Message, gateway: BaseGateway) -> None: raise GatewayReceivedTerminate(gateway) GATEWAY_TERMINATE = 2 _types[GATEWAY_TERMINATE] = ("GATEWAY_TERMINATE", _gateway_terminate) - def _channel_exec(message, gateway): + def _channel_exec(message: Message, gateway: BaseGateway) -> None: channel = gateway._channelfactory.new(message.channelid) gateway._local_schedulexec(channel=channel, sourcetask=message.data) CHANNEL_EXEC = 3 _types[CHANNEL_EXEC] = ("CHANNEL_EXEC", _channel_exec) - def _channel_data(message, gateway): + def _channel_data(message: Message, gateway: BaseGateway) -> None: gateway._channelfactory._local_receive(message.channelid, message.data) CHANNEL_DATA = 4 _types[CHANNEL_DATA] = ("CHANNEL_DATA", _channel_data) - def _channel_close(message, gateway): + def _channel_close(message: Message, gateway: BaseGateway) -> None: gateway._channelfactory._local_close(message.channelid) CHANNEL_CLOSE = 5 _types[CHANNEL_CLOSE] = ("CHANNEL_CLOSE", _channel_close) - def _channel_close_error(message, gateway): - remote_error = RemoteError(loads_internal(message.data)) + def _channel_close_error(message: Message, gateway: BaseGateway) -> None: + error_message = loads_internal(message.data) + assert isinstance(error_message, str) + remote_error = RemoteError(error_message) gateway._channelfactory._local_close(message.channelid, remote_error) CHANNEL_CLOSE_ERROR = 6 _types[CHANNEL_CLOSE_ERROR] = ("CHANNEL_CLOSE_ERROR", _channel_close_error) - def _channel_last_message(message, gateway): + def _channel_last_message(message: Message, gateway: BaseGateway) -> None: gateway._channelfactory._local_close(message.channelid, sendonly=True) CHANNEL_LAST_MESSAGE = 7 @@ -594,31 +656,37 @@ class GatewayReceivedTerminate(Exception): """Receiverthread got termination message.""" -def geterrortext(excinfo, format_exception=traceback.format_exception, sysex=sysex): +def geterrortext( + exc: BaseException, + format_exception=traceback.format_exception, + sysex: tuple[type[BaseException], ...] = sysex, +) -> str: try: - l = format_exception(*excinfo) + # In py310, can change this to: + # l = format_exception(exc) + l = format_exception(type(exc), exc, exc.__traceback__) errortext = "".join(l) except sysex: raise except BaseException: - errortext = f"{excinfo[0].__name__}: {excinfo[1]}" + errortext = f"{type(exc).__name__}: {exc}" return errortext class RemoteError(Exception): """Exception containing a stringified error from the other side.""" - def __init__(self, formatted): + def __init__(self, formatted: str) -> None: super().__init__() self.formatted = formatted - def __str__(self): + def __str__(self) -> str: return self.formatted - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}: {self.formatted}" - def warn(self): + def warn(self) -> None: if self.formatted != INTERRUPT_TEXT: # XXX do this better sys.stderr.write(f"[{os.getpid()}] Warning: unhandled {self!r}\n") @@ -632,13 +700,15 @@ class TimeoutError(IOError): class Channel: - "Communication channel between two Python Interpreter execution points." + """Communication channel between two Python Interpreter execution points.""" + RemoteError = RemoteError TimeoutError = TimeoutError _INTERNALWAKEUP = 1000 _executing = False - def __init__(self, gateway, id): + def __init__(self, gateway: BaseGateway, id: int) -> None: + """:private:""" assert isinstance(id, int) assert not isinstance(gateway, type) self.gateway = gateway @@ -648,15 +718,19 @@ def __init__(self, gateway, id): self._items = self.gateway.execmodel.queue.Queue() self._closed = False self._receiveclosed = self.gateway.execmodel.Event() - self._remoteerrors = [] + self._remoteerrors: list[RemoteError] = [] - def _trace(self, *msg): + def _trace(self, *msg: object) -> None: self.gateway._trace(self.id, *msg) - def setcallback(self, callback, endmarker=NO_ENDMARKER_WANTED): - """set a callback function for receiving items. + def setcallback( + self, + callback: Callable[[Any], Any], + endmarker: object = NO_ENDMARKER_WANTED, + ) -> None: + """Set a callback function for receiving items. - All already queued items will immediately trigger the callback. + All already-queued items will immediately trigger the callback. Afterwards the callback will execute in the receiver thread for each received data item and calls to ``receive()`` will raise an error. @@ -685,13 +759,14 @@ def setcallback(self, callback, endmarker=NO_ENDMARKER_WANTED): else: callback(olditem) - def __repr__(self): + def __repr__(self) -> str: flag = self.isclosed() and "closed" or "open" return "" % (self.id, flag) - def __del__(self): + def __del__(self) -> None: if self.gateway is None: # can be None in tests - return + return # type: ignore[unreachable] + self._trace("channel.__del__") # no multithreading issues here, because we have the last ref to 'self' if self._closed: @@ -731,16 +806,34 @@ def _getremoteerror(self): # # public API for channel objects # - def isclosed(self): - """return True if the channel is closed. A closed - channel may still hold items. + def isclosed(self) -> bool: + """Return True if the channel is closed. + + A closed channel may still hold items. """ return self._closed - def makefile(self, mode="w", proxyclose=False): - """return a file-like object. + @overload + def makefile(self, mode: Literal["r"], proxyclose: bool = ...) -> ChannelFileRead: + pass + + @overload + def makefile( + self, + mode: Literal["w"] = ..., + proxyclose: bool = ..., + ) -> ChannelFileWrite: + pass + + def makefile( + self, + mode: Literal["r", "w"] = "w", + proxyclose: bool = False, + ) -> ChannelFileWrite | ChannelFileRead: + """Return a file-like object. + mode can be 'w' or 'r' for writeable/readable files. - if proxyclose is true file.close() will also close the channel. + If proxyclose is true, file.close() will also close the channel. """ if mode == "w": return ChannelFileWrite(channel=self, proxyclose=proxyclose) @@ -748,8 +841,9 @@ def makefile(self, mode="w", proxyclose=False): return ChannelFileRead(channel=self, proxyclose=proxyclose) raise ValueError(f"mode {mode!r} not available") - def close(self, error=None): - """close down this channel with an optional error message. + def close(self, error=None) -> None: + """Close down this channel with an optional error message. + Note that closing of a channel tied to remote_exec happens automatically at the end of execution and cannot be done explicitly. @@ -780,14 +874,19 @@ def close(self, error=None): queue.put(ENDMARKER) self.gateway._channelfactory._no_longer_opened(self.id) - def waitclose(self, timeout=None): - """wait until this channel is closed (or the remote side + def waitclose(self, timeout: float | None = None) -> None: + """Wait until this channel is closed (or the remote side otherwise signalled that no more data was being sent). + The channel may still hold receiveable items, but not receive - any more after waitclose() has returned. Exceptions from executing - code on the other side are reraised as local channel.RemoteErrors. + any more after waitclose() has returned. + + Exceptions from executing code on the other side are reraised as local + channel.RemoteErrors. + EOFError is raised if the reading-connection was prematurely closed, which often indicates a dying process. + self.TimeoutError is raised after the specified number of seconds (default is None, i.e. wait indefinitely). """ @@ -799,22 +898,26 @@ def waitclose(self, timeout=None): if error: raise error - def send(self, item): - """sends the given item to the other side of the channel, + def send(self, item: object) -> None: + """Sends the given item to the other side of the channel, possibly blocking if the sender queue is full. - The item must be a simple python type and will be - copied to the other side by value. IOError is - raised if the write pipe was prematurely closed. + + The item must be a simple Python type and will be + copied to the other side by value. + + OSError is raised if the write pipe was prematurely closed. """ if self.isclosed(): raise OSError(f"cannot send to {self!r}") self.gateway._send(Message.CHANNEL_DATA, self.id, dumps_internal(item)) - def receive(self, timeout=None): - """receive a data item that was sent from the other side. - timeout: None [default] blocked waiting. A positive number + def receive(self, timeout: float | None = None) -> Any: + """Receive a data item that was sent from the other side. + + timeout: None [default] blocked waiting. A positive number indicates the number of seconds after which a channel.TimeoutError exception will be raised if no item was received. + Note that exceptions from the remotely executing code will be reraised as channel.RemoteError exceptions containing a textual representation of the remote traceback. @@ -825,28 +928,30 @@ def receive(self, timeout=None): try: x = itemqueue.get(timeout=timeout) except self.gateway.execmodel.queue.Empty: - raise self.TimeoutError("no item after %r seconds" % timeout) + raise self.TimeoutError("no item after %r seconds" % timeout) from None if x is ENDMARKER: itemqueue.put(x) # for other receivers raise self._getremoteerror() or EOFError() else: return x - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return self - def next(self): + def next(self) -> Any: try: return self.receive() except EOFError: - raise StopIteration + raise StopIteration from None __next__ = next - def reconfigure(self, py2str_as_py3str=True, py3str_as_py2str=False): - """ - set the string coercion for this channel - the default is to try to convert py2 str as py3 str, + def reconfigure( + self, py2str_as_py3str: bool = True, py3str_as_py2str: bool = False + ) -> None: + """Set the string coercion for this channel. + + The default is to try to convert py2 str as py3 str, but not to try and convert py3 str to py2 str """ self._strconfig = (py2str_as_py3str, py3str_as_py2str) @@ -856,20 +961,28 @@ def reconfigure(self, py2str_as_py3str=True, py3str_as_py2str=False): ENDMARKER = object() INTERRUPT_TEXT = "keyboard-interrupted" +MAIN_THREAD_ONLY_DEADLOCK_TEXT = ( + "concurrent remote_exec would cause deadlock for main_thread_only execmodel" +) class ChannelFactory: - def __init__(self, gateway, startcount=1): - self._channels = weakref.WeakValueDictionary() - self._callbacks = {} + def __init__(self, gateway: BaseGateway, startcount: int = 1) -> None: + self._channels: weakref.WeakValueDictionary[int, Channel] = ( + weakref.WeakValueDictionary() + ) + # Channel ID => (callback, end marker, strconfig) + self._callbacks: dict[ + int, tuple[Callable[[Any], Any], object, tuple[bool, bool]] + ] = {} self._writelock = gateway.execmodel.Lock() self.gateway = gateway self.count = startcount self.finished = False self._list = list # needed during interp-shutdown - def new(self, id=None): - """create a new Channel with 'id' (or create new id if None).""" + def new(self, id: int | None = None) -> Channel: + """Create a new Channel with 'id' (or create new id if None).""" with self._writelock: if self.finished: raise OSError(f"connection already closed: {self.gateway}") @@ -882,13 +995,13 @@ def new(self, id=None): channel = self._channels[id] = Channel(self.gateway, id) return channel - def channels(self): + def channels(self) -> list[Channel]: return self._list(self._channels.values()) # # internal methods, called from the receiver thread # - def _no_longer_opened(self, id): + def _no_longer_opened(self, id: int) -> None: try: del self._channels[id] except KeyError: @@ -901,7 +1014,7 @@ def _no_longer_opened(self, id): if endmarker is not NO_ENDMARKER_WANTED: callback(endmarker) - def _local_close(self, id, remoteerror=None, sendonly=False): + def _local_close(self, id: int, remoteerror=None, sendonly: bool = False) -> None: channel = self._channels.get(id) if channel is None: # channel already in "deleted" state @@ -920,13 +1033,13 @@ def _local_close(self, id, remoteerror=None, sendonly=False): channel._closed = True # --> "closed" channel._receiveclosed.set() - def _local_receive(self, id, data): + def _local_receive(self, id: int, data) -> None: # executes in receiver thread channel = self._channels.get(id) try: callback, endmarker, strconfig = self._callbacks[id] except KeyError: - queue = channel and channel._items + queue = channel._items if channel is not None else None if queue is None: pass # drop data else: @@ -936,16 +1049,15 @@ def _local_receive(self, id, data): try: data = loads_internal(data, channel, strconfig) callback(data) # even if channel may be already closed - except Exception: - excinfo = sys.exc_info() - self.gateway._trace("exception during callback: %s" % excinfo[1]) - errortext = self.gateway._geterrortext(excinfo) + except Exception as exc: + self.gateway._trace("exception during callback: %s" % exc) + errortext = self.gateway._geterrortext(exc) self.gateway._send( Message.CHANNEL_CLOSE_ERROR, id, dumps_internal(errortext) ) self._local_close(id, errortext) - def _finished_receiving(self): + def _finished_receiving(self) -> None: with self._writelock: self.finished = True for id in self._list(self._channels): @@ -955,41 +1067,41 @@ def _finished_receiving(self): class ChannelFile: - def __init__(self, channel, proxyclose=True): + def __init__(self, channel: Channel, proxyclose: bool = True) -> None: self.channel = channel self._proxyclose = proxyclose - def isatty(self): + def isatty(self) -> bool: return False - def close(self): + def close(self) -> None: if self._proxyclose: self.channel.close() - def __repr__(self): + def __repr__(self) -> str: state = self.channel.isclosed() and "closed" or "open" return "" % (self.channel.id, state) class ChannelFileWrite(ChannelFile): - def write(self, out): + def write(self, out: bytes) -> None: self.channel.send(out) - def flush(self): + def flush(self) -> None: pass class ChannelFileRead(ChannelFile): - def __init__(self, channel, proxyclose=True): + def __init__(self, channel: Channel, proxyclose: bool = True) -> None: super().__init__(channel, proxyclose) - self._buffer = None + self._buffer: str | None = None - def read(self, n): + def read(self, n: int) -> str: try: if self._buffer is None: - self._buffer = self.channel.receive() + self._buffer = cast(str, self.channel.receive()) while len(self._buffer) < n: - self._buffer += self.channel.receive() + self._buffer += cast(str, self.channel.receive()) except EOFError: self.close() if self._buffer is None: @@ -999,7 +1111,7 @@ def read(self, n): self._buffer = self._buffer[n:] return ret - def readline(self): + def readline(self) -> str: if self._buffer is not None: i = self._buffer.find("\n") if i != -1: @@ -1016,11 +1128,10 @@ def readline(self): class BaseGateway: - exc_info = sys.exc_info _sysex = sysex id = "" - def __init__(self, io, id, _startcount=2): + def __init__(self, io: IO, id, _startcount: int = 2) -> None: self.execmodel = io.execmodel self._io = io self.id = id @@ -1032,14 +1143,14 @@ def __init__(self, io, id, _startcount=2): self._geterrortext = geterrortext self._receivepool = WorkerPool(self.execmodel) - def _trace(self, *msg): + def _trace(self, *msg: object) -> None: self.__trace(self.id, *msg) - def _initreceive(self): + def _initreceive(self) -> None: self._receivepool.spawn(self._thread_receiver) - def _thread_receiver(self): - def log(*msg): + def _thread_receiver(self) -> None: + def log(*msg: object) -> None: self._trace("[receiver-thread]", *msg) log("RECEIVERTHREAD: starting to run") @@ -1053,11 +1164,11 @@ def log(*msg): del msg except (KeyboardInterrupt, GatewayReceivedTerminate): pass - except EOFError: + except EOFError as exc: log("EOF without prior gateway termination message") - self._error = self.exc_info()[1] - except Exception: - log(self._geterrortext(self.exc_info())) + self._error = exc + except Exception as exc: + log(self._geterrortext(exc)) log("finishing receiving thread") # wake up and terminate any execution waiting to receive self._channelfactory._finished_receiving() @@ -1070,21 +1181,20 @@ def log(*msg): log("terminating our receive pseudo pool") self._receivepool.trigger_shutdown() - def _terminate_execution(self): + def _terminate_execution(self) -> None: pass - def _send(self, msgcode, channelid=0, data=b""): + def _send(self, msgcode: int, channelid: int = 0, data: bytes = b"") -> None: message = Message(msgcode, channelid, data) try: message.to_io(self._io) self._trace("sent", message) - except (OSError, ValueError): - e = sys.exc_info()[1] + except (OSError, ValueError) as e: self._trace("failed to send", message, e) # ValueError might be because the IO is already closed - raise OSError("cannot send (already closed?)") + raise OSError("cannot send (already closed?)") from e - def _local_schedulexec(self, channel, sourcetask): + def _local_schedulexec(self, channel: Channel, sourcetask: bytes) -> None: channel.close("execution disallowed") # _____________________________________________________________________ @@ -1092,22 +1202,37 @@ def _local_schedulexec(self, channel, sourcetask): # High Level Interface # _____________________________________________________________________ # - def newchannel(self): - """return a new independent channel.""" + def newchannel(self) -> Channel: + """Return a new independent channel.""" return self._channelfactory.new() - def join(self, timeout=None): + def join(self, timeout: float | None = None) -> None: """Wait for receiverthread to terminate.""" self._trace("waiting for receiver thread to finish") - self._receivepool.waitall() + self._receivepool.waitall(timeout) class WorkerGateway(BaseGateway): - def _local_schedulexec(self, channel, sourcetask): - sourcetask = loads_internal(sourcetask) - self._execpool.spawn(self.executetask, (channel, sourcetask)) - - def _terminate_execution(self): + def _local_schedulexec(self, channel: Channel, sourcetask: bytes) -> None: + if self._execpool.execmodel.backend == "main_thread_only": + assert self._executetask_complete is not None + # It's necessary to wait for a short time in order to ensure + # that we do not report a false-positive deadlock error, since + # channel close does not elicit a response that would provide + # a guarantee to remote_exec callers that the previous task + # has released the main thread. If the timeout expires then it + # should be practically impossible to report a false-positive. + if not self._executetask_complete.wait(timeout=1): + channel.close(MAIN_THREAD_ONLY_DEADLOCK_TEXT) + return + # It's only safe to clear here because the above wait proves + # that there is not a previous task about to set it again. + self._executetask_complete.clear() + + sourcetask_ = loads_internal(sourcetask) + self._execpool.spawn(self.executetask, (channel, sourcetask_)) + + def _terminate_execution(self) -> None: # called from receiverthread self._trace("shutting down execution pool") self._execpool.trigger_shutdown() @@ -1127,12 +1252,18 @@ def _terminate_execution(self): ) os._exit(1) - def serve(self): - def trace(msg): + def serve(self) -> None: + def trace(msg: str) -> None: self._trace("[serve] " + msg) - hasprimary = self.execmodel.backend == "thread" + hasprimary = self.execmodel.backend in ("thread", "main_thread_only") self._execpool = WorkerPool(self.execmodel, hasprimary=hasprimary) + self._executetask_complete = None + if self.execmodel.backend == "main_thread_only": + self._executetask_complete = self.execmodel.Event() + # Initialize state to indicate that there is no previous task + # executing so that we don't need a separate flag to track this. + self._executetask_complete.set() trace("spawning receiver thread") self._initreceive() try: @@ -1146,10 +1277,13 @@ def trace(msg): # in the worker we can't really do anything sensible trace("swallowing keyboardinterrupt, serve finished") - def executetask(self, item): + def executetask( + self, + item: tuple[Channel, tuple[str, str | None, str | None, dict[str, object]]], + ) -> None: try: channel, (source, file_name, call_name, kwargs) = item - loc = {"channel": channel, "__name__": "__channelexec__"} + loc: dict[str, Any] = {"channel": channel, "__name__": "__channelexec__"} self._trace(f"execution starts[{channel.id}]: {repr(source)[:50]}") channel._executing = True try: @@ -1165,16 +1299,20 @@ def executetask(self, item): except KeyboardInterrupt: channel.close(INTERRUPT_TEXT) raise - except BaseException: - excinfo = self.exc_info() - if not isinstance(excinfo[1], EOFError): + except BaseException as exc: + if not isinstance(exc, EOFError): if not channel.gateway._channelfactory.finished: - self._trace(f"got exception: {excinfo[1]!r}") - errortext = self._geterrortext(excinfo) + self._trace(f"got exception: {exc!r}") + errortext = self._geterrortext(exc) channel.close(errortext) return self._trace("ignoring EOFError because receiving finished") channel.close() + if self._executetask_complete is not None: + # Indicate that this task has finished executing, meaning + # that there is no possibility of it triggering a deadlock + # for the next spawn call. + self._executetask_complete.set() # @@ -1194,7 +1332,7 @@ class LoadError(DataFormatError): """Error while unserializing an object.""" -def bchr(n): +def bchr(n: int) -> bytes: return bytes([n]) @@ -1213,7 +1351,7 @@ class _Stop(Exception): class opcode: - """container for name -> num mappings.""" + """Container for name -> num mappings.""" BUILDTUPLE = b"@" BYTES = b"A" @@ -1243,20 +1381,32 @@ class Unserializer: py2str_as_py3str = True # True py3str_as_py2str = False # false means py2 will get unicode - def __init__(self, stream, channel_or_gateway=None, strconfig=None): - gateway = getattr(channel_or_gateway, "gateway", channel_or_gateway) - strconfig = getattr(channel_or_gateway, "_strconfig", strconfig) + def __init__( + self, + stream: ReadIO, + channel_or_gateway: Channel | BaseGateway | None = None, + strconfig: tuple[bool, bool] | None = None, + ) -> None: + if isinstance(channel_or_gateway, Channel): + gw: BaseGateway | None = channel_or_gateway.gateway + else: + gw = channel_or_gateway + if channel_or_gateway is not None: + strconfig = channel_or_gateway._strconfig if strconfig: self.py2str_as_py3str, self.py3str_as_py2str = strconfig self.stream = stream - self.channelfactory = getattr(gateway, "_channelfactory", gateway) + if gw is None: + self.channelfactory = None + else: + self.channelfactory = gw._channelfactory - def load(self, versioned=False): + def load(self, versioned: bool = False) -> Any: if versioned: ver = self.stream.read(1) if ver != DUMPFORMAT_VERSION: raise LoadError("wrong dumpformat version %r" % ver) - self.stack = [] + self.stack: list[object] = [] try: while True: opcode = self.stream.read(1) @@ -1266,38 +1416,38 @@ def load(self, versioned=False): loader = self.num2func[opcode] except KeyError: raise LoadError( - "unknown opcode %r - " "wire protocol corruption?" % (opcode,) - ) + f"unknown opcode {opcode!r} - wire protocol corruption?" + ) from None loader(self) except _Stop: if len(self.stack) != 1: - raise LoadError("internal unserialization error") + raise LoadError("internal unserialization error") from None return self.stack.pop(0) else: raise LoadError("didn't get STOP") - def load_none(self): + def load_none(self) -> None: self.stack.append(None) num2func[opcode.NONE] = load_none - def load_true(self): + def load_true(self) -> None: self.stack.append(True) num2func[opcode.TRUE] = load_true - def load_false(self): + def load_false(self) -> None: self.stack.append(False) num2func[opcode.FALSE] = load_false - def load_int(self): + def load_int(self) -> None: i = self._read_int4() self.stack.append(i) num2func[opcode.INT] = load_int - def load_longint(self): + def load_longint(self) -> None: s = self._read_byte_string() self.stack.append(int(s)) @@ -1308,27 +1458,28 @@ def load_longint(self): load_longlong = load_longint num2func[opcode.LONGLONG] = load_longlong - def load_float(self): + def load_float(self) -> None: binary = self.stream.read(FLOAT_FORMAT_SIZE) self.stack.append(struct.unpack(FLOAT_FORMAT, binary)[0]) num2func[opcode.FLOAT] = load_float - def load_complex(self): + def load_complex(self) -> None: binary = self.stream.read(COMPLEX_FORMAT_SIZE) self.stack.append(complex(*struct.unpack(COMPLEX_FORMAT, binary))) num2func[opcode.COMPLEX] = load_complex - def _read_int4(self): - return struct.unpack("!i", self.stream.read(4))[0] + def _read_int4(self) -> int: + value: int = struct.unpack("!i", self.stream.read(4))[0] + return value - def _read_byte_string(self): + def _read_byte_string(self) -> bytes: length = self._read_int4() as_bytes = self.stream.read(length) return as_bytes - def load_py3string(self): + def load_py3string(self) -> None: as_bytes = self._read_byte_string() if self.py3str_as_py2str: # XXX Should we try to decode into latin-1? @@ -1338,48 +1489,48 @@ def load_py3string(self): num2func[opcode.PY3STRING] = load_py3string - def load_py2string(self): + def load_py2string(self) -> None: as_bytes = self._read_byte_string() if self.py2str_as_py3str: - s = as_bytes.decode("latin-1") + s: bytes | str = as_bytes.decode("latin-1") else: s = as_bytes self.stack.append(s) num2func[opcode.PY2STRING] = load_py2string - def load_bytes(self): + def load_bytes(self) -> None: s = self._read_byte_string() self.stack.append(s) num2func[opcode.BYTES] = load_bytes - def load_unicode(self): + def load_unicode(self) -> None: self.stack.append(self._read_byte_string().decode("utf-8")) num2func[opcode.UNICODE] = load_unicode - def load_newlist(self): + def load_newlist(self) -> None: length = self._read_int4() self.stack.append([None] * length) num2func[opcode.NEWLIST] = load_newlist - def load_setitem(self): + def load_setitem(self) -> None: if len(self.stack) < 3: raise LoadError("not enough items for setitem") value = self.stack.pop() key = self.stack.pop() - self.stack[-1][key] = value + self.stack[-1][key] = value # type: ignore[index] num2func[opcode.SETITEM] = load_setitem - def load_newdict(self): + def load_newdict(self) -> None: self.stack.append({}) num2func[opcode.NEWDICT] = load_newdict - def _load_collection(self, type_): + def _load_collection(self, type_: type) -> None: length = self._read_int4() if length: res = type_(self.stack[-length:]) @@ -1388,60 +1539,63 @@ def _load_collection(self, type_): else: self.stack.append(type_()) - def load_buildtuple(self): + def load_buildtuple(self) -> None: self._load_collection(tuple) num2func[opcode.BUILDTUPLE] = load_buildtuple - def load_set(self): + def load_set(self) -> None: self._load_collection(set) num2func[opcode.SET] = load_set - def load_frozenset(self): + def load_frozenset(self) -> None: self._load_collection(frozenset) num2func[opcode.FROZENSET] = load_frozenset - def load_stop(self): + def load_stop(self) -> None: raise _Stop num2func[opcode.STOP] = load_stop - def load_channel(self): + def load_channel(self) -> None: id = self._read_int4() + assert self.channelfactory is not None newchannel = self.channelfactory.new(id) self.stack.append(newchannel) num2func[opcode.CHANNEL] = load_channel -def dumps(obj): - """return a serialized bytestring of the given obj. +def dumps(obj: object) -> bytes: + """Serialize the given obj to a bytestring. The obj and all contained objects must be of a builtin - python type (so nested dicts, sets, etc. are all ok but + Python type (so nested dicts, sets, etc. are all OK but not user-level instances). """ - return _Serializer().save(obj, versioned=True) + return _Serializer().save(obj, versioned=True) # type: ignore[return-value] -def dump(byteio, obj): +def dump(byteio, obj: object) -> None: """write a serialized bytestring of the given obj to the given stream.""" _Serializer(write=byteio.write).save(obj, versioned=True) -def loads(bytestring, py2str_as_py3str=False, py3str_as_py2str=False): - """return the object as deserialized from the given bytestring. +def loads( + bytestring: bytes, py2str_as_py3str: bool = False, py3str_as_py2str: bool = False +) -> Any: + """Deserialize the given bytestring to an object. - py2str_as_py3str: if true then string (str) objects previously + py2str_as_py3str: If true then string (str) objects previously dumped on Python2 will be loaded as Python3 strings which really are text objects. - py3str_as_py2str: if true then string (str) objects previously + py3str_as_py2str: If true then string (str) objects previously dumped on Python3 will be loaded as Python2 strings instead of unicode objects. - if the bytestring was dumped with an incompatible protocol + If the bytestring was dumped with an incompatible protocol version or if the bytestring is corrupted, the ``execnet.DataFormatError`` will be raised. """ @@ -1451,8 +1605,10 @@ def loads(bytestring, py2str_as_py3str=False, py3str_as_py2str=False): ) -def load(io, py2str_as_py3str=False, py3str_as_py2str=False): - """derserialize an object form the specified stream. +def load( + io: ReadIO, py2str_as_py3str: bool = False, py3str_as_py2str: bool = False +) -> Any: + """Derserialize an object form the specified stream. Behaviour and parameters are otherwise the same as with ``loads`` """ @@ -1460,25 +1616,29 @@ def load(io, py2str_as_py3str=False, py3str_as_py2str=False): return Unserializer(io, strconfig=strconfig).load(versioned=True) -def loads_internal(bytestring, channelfactory=None, strconfig=None): +def loads_internal( + bytestring: bytes, + channelfactory=None, + strconfig: tuple[bool, bool] | None = None, +) -> Any: io = BytesIO(bytestring) return Unserializer(io, channelfactory, strconfig).load() -def dumps_internal(obj): - return _Serializer().save(obj) +def dumps_internal(obj: object) -> bytes: + return _Serializer().save(obj) # type: ignore[return-value] class _Serializer: _dispatch: dict[type, Callable[[_Serializer, object], None]] = {} - def __init__(self, write=None): + def __init__(self, write: Callable[[bytes], None] | None = None) -> None: if write is None: - self._streamlist = [] + self._streamlist: list[bytes] = [] write = self._streamlist.append self._write = write - def save(self, obj, versioned=False): + def save(self, obj: object, versioned: bool = False) -> bytes | None: # calling here is not re-entrant but multiple instances # may write to the same stream because of the common platform # atomic-write guarantee (concurrent writes each happen atomically) @@ -1492,47 +1652,49 @@ def save(self, obj, versioned=False): return None return b"".join(streamlist) - def _save(self, obj): + def _save(self, obj: object) -> None: tp = type(obj) try: dispatch = self._dispatch[tp] except KeyError: methodname = "save_" + tp.__name__ - meth = getattr(self.__class__, methodname, None) + meth: Callable[[_Serializer, object], None] | None = getattr( + self.__class__, methodname, None + ) if meth is None: - raise DumpError(f"can't serialize {tp}") + raise DumpError(f"can't serialize {tp}") from None dispatch = self._dispatch[tp] = meth dispatch(self, obj) - def save_NoneType(self, non): + def save_NoneType(self, non: None) -> None: self._write(opcode.NONE) - def save_bool(self, boolean): + def save_bool(self, boolean: bool) -> None: if boolean: self._write(opcode.TRUE) else: self._write(opcode.FALSE) - def save_bytes(self, bytes_): + def save_bytes(self, bytes_: bytes) -> None: self._write(opcode.BYTES) self._write_byte_sequence(bytes_) - def save_str(self, s): + def save_str(self, s: str) -> None: self._write(opcode.PY3STRING) self._write_unicode_string(s) - def _write_unicode_string(self, s): + def _write_unicode_string(self, s: str) -> None: try: as_bytes = s.encode("utf-8") - except UnicodeEncodeError: - raise DumpError("strings must be utf-8 encodable") + except UnicodeEncodeError as e: + raise DumpError("strings must be utf-8 encodable") from e self._write_byte_sequence(as_bytes) - def _write_byte_sequence(self, bytes_): + def _write_byte_sequence(self, bytes_: bytes) -> None: self._write_int4(len(bytes_), "string is too long") self._write(bytes_) - def _save_integral(self, i, short_op, long_op): + def _save_integral(self, i: int, short_op: bytes, long_op: bytes) -> None: if i <= FOUR_BYTE_INT_MAX: self._write(short_op) self._write_int4(i) @@ -1540,65 +1702,67 @@ def _save_integral(self, i, short_op, long_op): self._write(long_op) self._write_byte_sequence(str(i).rstrip("L").encode("ascii")) - def save_int(self, i): + def save_int(self, i: int) -> None: self._save_integral(i, opcode.INT, opcode.LONGINT) - def save_long(self, l): + def save_long(self, l: int) -> None: self._save_integral(l, opcode.LONG, opcode.LONGLONG) - def save_float(self, flt): + def save_float(self, flt: float) -> None: self._write(opcode.FLOAT) self._write(struct.pack(FLOAT_FORMAT, flt)) - def save_complex(self, cpx): + def save_complex(self, cpx: complex) -> None: self._write(opcode.COMPLEX) self._write(struct.pack(COMPLEX_FORMAT, cpx.real, cpx.imag)) - def _write_int4(self, i, error="int must be less than %i" % (FOUR_BYTE_INT_MAX,)): + def _write_int4( + self, i: int, error: str = "int must be less than %i" % (FOUR_BYTE_INT_MAX,) + ) -> None: if i > FOUR_BYTE_INT_MAX: raise DumpError(error) self._write(struct.pack("!i", i)) - def save_list(self, L): + def save_list(self, L: list[object]) -> None: self._write(opcode.NEWLIST) self._write_int4(len(L), "list is too long") for i, item in enumerate(L): self._write_setitem(i, item) - def _write_setitem(self, key, value): + def _write_setitem(self, key: object, value: object) -> None: self._save(key) self._save(value) self._write(opcode.SETITEM) - def save_dict(self, d): + def save_dict(self, d: dict[object, object]) -> None: self._write(opcode.NEWDICT) for key, value in d.items(): self._write_setitem(key, value) - def save_tuple(self, tup): + def save_tuple(self, tup: tuple[object, ...]) -> None: for item in tup: self._save(item) self._write(opcode.BUILDTUPLE) self._write_int4(len(tup), "tuple is too long") - def _write_set(self, s, op): + def _write_set(self, s: set[object] | frozenset[object], op: bytes) -> None: for item in s: self._save(item) self._write(op) self._write_int4(len(s), "set is too long") - def save_set(self, s): + def save_set(self, s: set[object]) -> None: self._write_set(s, opcode.SET) - def save_frozenset(self, s): + def save_frozenset(self, s: frozenset[object]) -> None: self._write_set(s, opcode.FROZENSET) - def save_Channel(self, channel): + def save_Channel(self, channel: Channel) -> None: self._write(opcode.CHANNEL) self._write_int4(channel.id) -def init_popen_io(execmodel): +def init_popen_io(execmodel: ExecModel) -> Popen2IO: if not hasattr(os, "dup"): # jython io = Popen2IO(sys.stdout, sys.stdin, execmodel) import tempfile @@ -1630,11 +1794,13 @@ def init_popen_io(execmodel): os.dup2(fd, 2) os.close(fd) io = Popen2IO(stdout, stdin, execmodel) - sys.stdin = execmodel.fdopen(0, "r", 1) - sys.stdout = execmodel.fdopen(1, "w", 1) + # Use closefd=False since 0 and 1 are shared with + # sys.__stdin__ and sys.__stdout__. + sys.stdin = execmodel.fdopen(0, "r", 1, closefd=False) + sys.stdout = execmodel.fdopen(1, "w", 1, closefd=False) return io -def serve(io, id): +def serve(io: IO, id) -> None: trace(f"creating workergateway on {io!r}") WorkerGateway(io=io, id=id, _startcount=2).serve() diff --git a/src/execnet/gateway_bootstrap.py b/src/execnet/gateway_bootstrap.py index ba5bf104..e9d7efe1 100644 --- a/src/execnet/gateway_bootstrap.py +++ b/src/execnet/gateway_bootstrap.py @@ -1,13 +1,15 @@ -""" -code to initialize the remote side of a gateway once the io is created -""" +"""Code to initialize the remote side of a gateway once the IO is created.""" + +from __future__ import annotations + import inspect import os import execnet from . import gateway_base -from .gateway import Gateway +from .gateway_base import IO +from .xspec import XSpec importdir = os.path.dirname(os.path.dirname(execnet.__file__)) @@ -16,10 +18,10 @@ class HostNotFound(Exception): pass -def bootstrap_import(io, spec): - # only insert the importdir into the path if we must. This prevents +def bootstrap_import(io: IO, spec: XSpec) -> None: + # Only insert the importdir into the path if we must. This prevents # bugs where backports expect to be shadowed by the standard library on - # newer versions of python but would instead shadow the standard library + # newer versions of python but would instead shadow the standard library. sendexec( io, "import sys", @@ -35,7 +37,7 @@ def bootstrap_import(io, spec): assert s == b"1", repr(s) -def bootstrap_exec(io, spec): +def bootstrap_exec(io: IO, spec: XSpec) -> None: try: sendexec( io, @@ -49,11 +51,11 @@ def bootstrap_exec(io, spec): assert s == b"1" except EOFError: ret = io.wait() - if ret == 255: - raise HostNotFound(io.remoteaddress) + if ret == 255 and hasattr(io, "remoteaddress"): + raise HostNotFound(io.remoteaddress) from None -def bootstrap_socket(io, id): +def bootstrap_socket(io: IO, id) -> None: # XXX: switch to spec from execnet.gateway_socket import SocketIO @@ -73,26 +75,12 @@ def bootstrap_socket(io, id): assert s == b"1" -def sendexec(io, *sources): +def sendexec(io: IO, *sources: str) -> None: source = "\n".join(sources) io.write((repr(source) + "\n").encode("utf-8")) -def fix_pid_for_jython_popen(gw): - """ - fix for jython 2.5.1 - """ - spec, io = gw.spec, gw._io - if spec.popen and not spec.via: - # XXX: handle the case of remote being jython - # and not having the popen pid - if io.popen.pid is None: - io.popen.pid = gw.remote_exec( - "import os; channel.send(os.getpid())" - ).receive() - - -def bootstrap(io, spec): +def bootstrap(io: IO, spec: XSpec) -> execnet.Gateway: if spec.popen: if spec.via or spec.python: bootstrap_exec(io, spec) @@ -104,6 +92,5 @@ def bootstrap(io, spec): bootstrap_socket(io, spec) else: raise ValueError("unknown gateway type, can't bootstrap") - gw = Gateway(io, spec) - fix_pid_for_jython_popen(gw) + gw = execnet.Gateway(io, spec) return gw diff --git a/src/execnet/gateway_io.py b/src/execnet/gateway_io.py index c631f8d9..21285ab4 100644 --- a/src/execnet/gateway_io.py +++ b/src/execnet/gateway_io.py @@ -1,48 +1,57 @@ -""" -execnet io initialization code +"""execnet IO initialization code. -creates io instances used for gateway io +Creates IO instances used for gateway IO. """ -import os + +from __future__ import annotations + import shlex import sys +from typing import TYPE_CHECKING +from typing import cast + +if TYPE_CHECKING: + from execnet.gateway_base import Channel + from execnet.gateway_base import ExecModel + from execnet.xspec import XSpec try: - from execnet.gateway_base import Popen2IO, Message + from execnet.gateway_base import Message + from execnet.gateway_base import Popen2IO except ImportError: - from __main__ import Popen2IO, Message # type: ignore[no-redef] + from __main__ import Message # type: ignore[no-redef] + from __main__ import Popen2IO # type: ignore[no-redef] from functools import partial class Popen2IOMaster(Popen2IO): - def __init__(self, args, execmodel): + # Set externally, for some specs only. + remoteaddress: str + + def __init__(self, args, execmodel: ExecModel) -> None: PIPE = execmodel.subprocess.PIPE self.popen = p = execmodel.subprocess.Popen(args, stdout=PIPE, stdin=PIPE) super().__init__(p.stdin, p.stdout, execmodel=execmodel) - def wait(self): + def wait(self) -> int | None: try: - return self.popen.wait() + return self.popen.wait() # type: ignore[no-any-return] except OSError: - pass # subprocess probably dead already + return None - def kill(self): - killpopen(self.popen) - - -def killpopen(popen): - try: - popen.kill() - except OSError as e: - sys.stderr.write("ERROR killing: %s\n" % e) - sys.stderr.flush() + def kill(self) -> None: + try: + self.popen.kill() + except OSError as e: + sys.stderr.write("ERROR killing: %s\n" % e) + sys.stderr.flush() popen_bootstrapline = "import sys;exec(eval(sys.stdin.readline()))" -def shell_split_path(path): +def shell_split_path(path: str) -> list[str]: """ Use shell lexer to split the given path into a list of components, taking care to handle Windows' '\' correctly. @@ -53,7 +62,7 @@ def shell_split_path(path): return shlex.split(path) -def popen_args(spec): +def popen_args(spec: XSpec) -> list[str]: args = shell_split_path(spec.python) if spec.python else [sys.executable] args.append("-u") if spec.dont_write_bytecode: @@ -62,7 +71,7 @@ def popen_args(spec): return args -def ssh_args(spec): +def ssh_args(spec: XSpec) -> list[str]: # NOTE: If changing this, you need to sync those changes to vagrant_args # as well, or, take some time to further refactor the commonalities of # ssh_args and vagrant_args. @@ -71,19 +80,21 @@ def ssh_args(spec): if spec.ssh_config is not None: args.extend(["-F", str(spec.ssh_config)]) + assert spec.ssh is not None args.extend(spec.ssh.split()) remotecmd = f'{remotepython} -c "{popen_bootstrapline}"' args.append(remotecmd) return args -def vagrant_ssh_args(spec): +def vagrant_ssh_args(spec: XSpec) -> list[str]: # This is the vagrant-wrapped version of SSH. Unfortunately the # command lines are incompatible to just channel through ssh_args # due to ordering/templating issues. # NOTE: This should be kept in sync with the ssh_args behaviour. # spec.vagrant is identical to spec.ssh in that they both carry # the remote host "address". + assert spec.vagrant_ssh is not None remotepython = spec.python or "python" args = ["vagrant", "ssh", spec.vagrant_ssh, "--", "-C"] if spec.ssh_config is not None: @@ -93,7 +104,7 @@ def vagrant_ssh_args(spec): return args -def create_io(spec, execmodel): +def create_io(spec: XSpec, execmodel: ExecModel) -> Popen2IOMaster: if spec.popen: args = popen_args(spec) return Popen2IOMaster(args, execmodel) @@ -107,6 +118,7 @@ def create_io(spec, execmodel): io = Popen2IOMaster(args, execmodel) io.remoteaddress = spec.vagrant_ssh return io + assert False # @@ -124,14 +136,15 @@ def create_io(spec, execmodel): class ProxyIO: """A Proxy IO object allows to instantiate a Gateway - through another "via" gateway. A master:ProxyIO object - provides an IO object effectively connected to the sub - via the forwarder. To achieve this, master:ProxyIO interacts - with forwarder:serve_proxy_io() which itself - instantiates and interacts with the sub. + through another "via" gateway. + + A master:ProxyIO object provides an IO object effectively connected to the + sub via the forwarder. To achieve this, master:ProxyIO interacts with + forwarder:serve_proxy_io() which itself instantiates and interacts with the + sub. """ - def __init__(self, proxy_channel, execmodel): + def __init__(self, proxy_channel: Channel, execmodel: ExecModel) -> None: # after exchanging the control channel we use proxy_channel # for messaging IO self.controlchan = proxy_channel.gateway.newchannel() @@ -140,69 +153,80 @@ def __init__(self, proxy_channel, execmodel): self.iochan_file = self.iochan.makefile("r") self.execmodel = execmodel - def read(self, nbytes): - return self.iochan_file.read(nbytes) + def read(self, nbytes: int) -> bytes: + # TODO(typing): The IO protocol requires bytes here but ChannelFileRead + # returns str. + return self.iochan_file.read(nbytes) # type: ignore[return-value] - def write(self, data): - return self.iochan.send(data) + def write(self, data: bytes) -> None: + self.iochan.send(data) - def _controll(self, event): + def _controll(self, event: int) -> object: self.controlchan.send(event) return self.controlchan.receive() - def close_write(self): + def close_write(self) -> None: self._controll(RIO_CLOSE_WRITE) - def kill(self): + def close_read(self) -> None: + raise NotImplementedError() + + def kill(self) -> None: self._controll(RIO_KILL) - def wait(self): - return self._controll(RIO_WAIT) + def wait(self) -> int | None: + response = self._controll(RIO_WAIT) + assert response is None or isinstance(response, int) + return response @property - def remoteaddress(self): - return self._controll(RIO_REMOTEADDRESS) + def remoteaddress(self) -> str: + response = self._controll(RIO_REMOTEADDRESS) + assert isinstance(response, str) + return response - def __repr__(self): + def __repr__(self) -> str: return f"" class PseudoSpec: - def __init__(self, vars): + def __init__(self, vars) -> None: self.__dict__.update(vars) - def __getattr__(self, name): + def __getattr__(self, name: str) -> None: return None -def serve_proxy_io(proxy_channelX): +def serve_proxy_io(proxy_channelX: Channel) -> None: execmodel = proxy_channelX.gateway.execmodel log = partial( proxy_channelX.gateway._trace, "serve_proxy_io:%s" % proxy_channelX.id ) - spec = PseudoSpec(proxy_channelX.receive()) + spec = cast("XSpec", PseudoSpec(proxy_channelX.receive())) # create sub IO object which we will proxy back to our proxy initiator sub_io = create_io(spec, execmodel) - control_chan = proxy_channelX.receive() + control_chan = cast("Channel", proxy_channelX.receive()) log("got control chan", control_chan) # read data from master, forward it to the sub # XXX writing might block, thus blocking the receiver thread - def forward_to_sub(data): + def forward_to_sub(data: bytes) -> None: log("forward data to sub, size %s" % len(data)) sub_io.write(data) proxy_channelX.setcallback(forward_to_sub) - def control(data): + def control(data: int) -> None: if data == RIO_WAIT: control_chan.send(sub_io.wait()) elif data == RIO_KILL: - control_chan.send(sub_io.kill()) + sub_io.kill() + control_chan.send(None) elif data == RIO_REMOTEADDRESS: control_chan.send(sub_io.remoteaddress) elif data == RIO_CLOSE_WRITE: - control_chan.send(sub_io.close_write()) + sub_io.close_write() + control_chan.send(None) control_chan.setcallback(control) @@ -228,4 +252,4 @@ def control(data): if __name__ == "__channelexec__": - serve_proxy_io(channel) # type: ignore[name-defined] + serve_proxy_io(channel) # type: ignore[name-defined] # noqa:F821 diff --git a/src/execnet/gateway_socket.py b/src/execnet/gateway_socket.py index 4379e015..be42f1ab 100644 --- a/src/execnet/gateway_socket.py +++ b/src/execnet/gateway_socket.py @@ -1,10 +1,19 @@ +from __future__ import annotations + import sys +from typing import cast +from execnet.gateway import Gateway +from execnet.gateway_base import ExecModel from execnet.gateway_bootstrap import HostNotFound +from execnet.multi import Group +from execnet.xspec import XSpec class SocketIO: - def __init__(self, sock, execmodel): + remoteaddress: str + + def __init__(self, sock, execmodel: ExecModel) -> None: self.sock = sock self.execmodel = execmodel socket = execmodel.socket @@ -15,7 +24,7 @@ def __init__(self, sock, execmodel): except (AttributeError, OSError): sys.stderr.write("WARNING: cannot set socketoption") - def read(self, numbytes): + def read(self, numbytes: int) -> bytes: "Read exactly 'bytes' bytes from the socket." buf = b"" while len(buf) < numbytes: @@ -25,31 +34,34 @@ def read(self, numbytes): buf += t return buf - def write(self, data): + def write(self, data: bytes) -> None: self.sock.sendall(data) - def close_read(self): + def close_read(self) -> None: try: self.sock.shutdown(0) except self.execmodel.socket.error: pass - def close_write(self): + def close_write(self) -> None: try: self.sock.shutdown(1) except self.execmodel.socket.error: pass - def wait(self): + def wait(self) -> None: pass - def kill(self): + def kill(self) -> None: pass -def start_via(gateway, hostport=None): - """return a host, port tuple, - after instantiating a socketserver on the given gateway +def start_via( + gateway: Gateway, hostport: tuple[str, int] | None = None +) -> tuple[str, int]: + """Instantiate a socketserver on the given gateway. + + Returns a host, port tuple. """ if hostport is None: host, port = ("localhost", 0) @@ -61,7 +73,7 @@ def start_via(gateway, hostport=None): # execute the above socketserverbootstrap on the other side channel = gateway.remote_exec(socketserver) channel.send((host, port)) - (realhost, realport) = channel.receive() + realhost, realport = cast("tuple[str, int]", channel.receive()) # self._trace("new_remote received" # "port=%r, hostname = %r" %(realport, hostname)) if not realhost or realhost == "0.0.0.0": @@ -69,14 +81,15 @@ def start_via(gateway, hostport=None): return realhost, realport -def create_io(spec, group, execmodel): +def create_io(spec: XSpec, group: Group, execmodel: ExecModel) -> SocketIO: + assert spec.socket is not None assert not spec.python, "socket: specifying python executables not yet supported" gateway_id = spec.installvia if gateway_id: host, port = start_via(group[gateway_id]) else: - host, port = spec.socket.split(":") - port = int(port) + host, port_str = spec.socket.split(":") + port = int(port_str) socket = execmodel.socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -84,6 +97,6 @@ def create_io(spec, group, execmodel): io.remoteaddress = "%s:%d" % (host, port) try: sock.connect((host, port)) - except execmodel.socket.gaierror: - raise HostNotFound(str(sys.exc_info()[1])) + except execmodel.socket.gaierror as e: + raise HostNotFound() from e return io diff --git a/src/execnet/multi.py b/src/execnet/multi.py index 64e95017..42a2dc36 100644 --- a/src/execnet/multi.py +++ b/src/execnet/multi.py @@ -3,34 +3,54 @@ (c) 2008-2014, Holger Krekel and others """ + +from __future__ import annotations + import atexit -import sys +import types from functools import partial from threading import Lock +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Iterable +from typing import Iterator +from typing import Literal +from typing import Sequence +from typing import overload from . import gateway_bootstrap from . import gateway_io +from .gateway_base import Channel +from .gateway_base import ExecModel +from .gateway_base import WorkerPool from .gateway_base import get_execmodel from .gateway_base import trace -from .gateway_base import WorkerPool from .xspec import XSpec +if TYPE_CHECKING: + from .gateway import Gateway + + NO_ENDMARKER_WANTED = object() class Group: - """Gateway Groups.""" + """Gateway Group.""" defaultspec = "popen" - def __init__(self, xspecs=(), execmodel="thread"): - """initialize group and make gateways as specified. - execmodel can be 'thread' or 'eventlet'. + def __init__( + self, xspecs: Iterable[XSpec | str | None] = (), execmodel: str = "thread" + ) -> None: + """Initialize a group and make gateways as specified. + + execmodel can be one of the supported execution models. """ - self._gateways = [] + self._gateways: list[Gateway] = [] self._autoidcounter = 0 self._autoidlock = Lock() - self._gateways_to_join = [] + self._gateways_to_join: list[Gateway] = [] # we use the same execmodel for all of the Gateway objects # we spawn on our side. Probably we should not allow different # execmodels between different groups but not clear. @@ -42,23 +62,23 @@ def __init__(self, xspecs=(), execmodel="thread"): atexit.register(self._cleanup_atexit) @property - def execmodel(self): + def execmodel(self) -> ExecModel: return self._execmodel @property - def remote_execmodel(self): + def remote_execmodel(self) -> ExecModel: return self._remote_execmodel - def set_execmodel(self, execmodel, remote_execmodel=None): + def set_execmodel( + self, execmodel: str, remote_execmodel: str | None = None + ) -> None: """Set the execution model for local and remote site. - execmodel can be one of "thread" or "eventlet" (XXX gevent). + execmodel can be one of the supported execution models. It determines the execution model for any newly created gateway. - If remote_execmodel is not specified it takes on the value - of execmodel. + If remote_execmodel is not specified it takes on the value of execmodel. NOTE: Execution models can only be set before any gateway is created. - """ if self._gateways: raise ValueError( @@ -69,11 +89,11 @@ def set_execmodel(self, execmodel, remote_execmodel=None): self._execmodel = get_execmodel(execmodel) self._remote_execmodel = get_execmodel(remote_execmodel) - def __repr__(self): + def __repr__(self) -> str: idgateways = [gw.id for gw in self] return "" % idgateways - def __getitem__(self, key): + def __getitem__(self, key: int | str | Gateway) -> Gateway: if isinstance(key, int): return self._gateways[key] for gw in self._gateways: @@ -81,21 +101,22 @@ def __getitem__(self, key): return gw raise KeyError(key) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: try: self[key] return True except KeyError: return False - def __len__(self): + def __len__(self) -> int: return len(self._gateways) - def __iter__(self): + def __iter__(self) -> Iterator[Gateway]: return iter(list(self._gateways)) - def makegateway(self, spec=None): - """create and configure a gateway to a Python interpreter. + def makegateway(self, spec: XSpec | str | None = None) -> Gateway: + """Create and configure a gateway to a Python interpreter. + The ``spec`` string encodes the target gateway type and configuration information. The general format is:: @@ -107,7 +128,7 @@ def makegateway(self, spec=None): id= specifies the gateway id python= specifies which python interpreter to execute - execmodel=model 'thread', 'eventlet', 'gevent' model for execution + execmodel=model 'thread', 'main_thread_only', 'eventlet', 'gevent' execution model chdir= specifies to which directory to change nice= specifies process priority of new process env:NAME=value specifies a remote environment variable setting. @@ -134,8 +155,8 @@ def makegateway(self, spec=None): elif spec.socket: from . import gateway_socket - io = gateway_socket.create_io(spec, self, execmodel=self.execmodel) - gw = gateway_bootstrap.bootstrap(io, spec) + sio = gateway_socket.create_io(spec, self, execmodel=self.execmodel) + gw = gateway_bootstrap.bootstrap(sio, spec) else: raise ValueError(f"no gateway type found for {spec._spec!r}") gw.spec = spec @@ -161,7 +182,7 @@ def makegateway(self, spec=None): channel.waitclose() return gw - def allocate_id(self, spec): + def allocate_id(self, spec: XSpec) -> None: """(re-entrant) allocate id for the given xspec object.""" if spec.id is None: with self._autoidlock: @@ -171,43 +192,45 @@ def allocate_id(self, spec): raise ValueError(f"already have gateway with id {id!r}") spec.id = id - def _register(self, gateway): + def _register(self, gateway: Gateway) -> None: assert not hasattr(gateway, "_group") assert gateway.id assert gateway.id not in self self._gateways.append(gateway) gateway._group = self - def _unregister(self, gateway): + def _unregister(self, gateway: Gateway) -> None: self._gateways.remove(gateway) self._gateways_to_join.append(gateway) - def _cleanup_atexit(self): + def _cleanup_atexit(self) -> None: trace(f"=== atexit cleanup {self!r} ===") self.terminate(timeout=1.0) - def terminate(self, timeout=None): - """trigger exit of member gateways and wait for termination - of member gateways and associated subprocesses. After waiting - timeout seconds try to to kill local sub processes of popen- - and ssh-gateways. Timeout defaults to None meaning - open-ended waiting and no kill attempts. - """ + def terminate(self, timeout: float | None = None) -> None: + """Trigger exit of member gateways and wait for termination + of member gateways and associated subprocesses. + After waiting timeout seconds try to to kill local sub processes of + popen- and ssh-gateways. + + Timeout defaults to None meaning open-ended waiting and no kill + attempts. + """ while self: - vias = {} + vias: set[str] = set() for gw in self: if gw.spec.via: - vias[gw.spec.via] = True + vias.add(gw.spec.via) for gw in self: if gw.id not in vias: gw.exit() - def join_wait(gw): + def join_wait(gw: Gateway) -> None: gw.join() gw._io.wait() - def kill(gw): + def kill(gw: Gateway) -> None: trace("Gateways did not come down after timeout: %r" % gw) gw._io.kill() @@ -221,10 +244,13 @@ def kill(gw): ) self._gateways_to_join[:] = [] - def remote_exec(self, source, **kwargs): + def remote_exec( + self, + source: str | types.FunctionType | Callable[..., object] | types.ModuleType, + **kwargs, + ) -> MultiChannel: """remote_exec source on all member gateways and return - MultiChannel connecting to all sub processes. - """ + a MultiChannel connecting to all sub processes.""" channels = [] for gw in self: channels.append(gw.remote_exec(source, **kwargs)) @@ -232,28 +258,38 @@ def remote_exec(self, source, **kwargs): class MultiChannel: - def __init__(self, channels): + def __init__(self, channels: Sequence[Channel]) -> None: self._channels = channels - def __len__(self): + def __len__(self) -> int: return len(self._channels) - def __iter__(self): + def __iter__(self) -> Iterator[Channel]: return iter(self._channels) - def __getitem__(self, key): + def __getitem__(self, key: int) -> Channel: return self._channels[key] - def __contains__(self, chan): + def __contains__(self, chan: Channel) -> bool: return chan in self._channels - def send_each(self, item): + def send_each(self, item: object) -> None: for ch in self._channels: ch.send(item) - def receive_each(self, withchannel=False): + @overload + def receive_each(self, withchannel: Literal[False] = ...) -> list[Any]: + pass + + @overload + def receive_each(self, withchannel: Literal[True]) -> list[tuple[Channel, Any]]: + pass + + def receive_each( + self, withchannel: bool = False + ) -> list[tuple[Channel, Any]] | list[Any]: assert not hasattr(self, "_queue") - l = [] + l: list[object] = [] for ch in self._channels: obj = ch.receive() if withchannel: @@ -262,17 +298,17 @@ def receive_each(self, withchannel=False): l.append(obj) return l - def make_receive_queue(self, endmarker=NO_ENDMARKER_WANTED): + def make_receive_queue(self, endmarker: object = NO_ENDMARKER_WANTED): try: - return self._queue + return self._queue # type: ignore[has-type] except AttributeError: self._queue = None for ch in self._channels: if self._queue is None: self._queue = ch.gateway.execmodel.queue.Queue() - def putreceived(obj, channel=ch): - self._queue.put((channel, obj)) + def putreceived(obj, channel: Channel = ch) -> None: + self._queue.put((channel, obj)) # type: ignore[union-attr] if endmarker is NO_ENDMARKER_WANTED: ch.setcallback(putreceived) @@ -280,22 +316,24 @@ def putreceived(obj, channel=ch): ch.setcallback(putreceived, endmarker=endmarker) return self._queue - def waitclose(self): + def waitclose(self) -> None: first = None for ch in self._channels: try: ch.waitclose() - except ch.RemoteError: + except ch.RemoteError as exc: if first is None: - first = sys.exc_info() + first = exc if first: - raise first[1].with_traceback(first[2]) + raise first -def safe_terminate(execmodel, timeout, list_of_paired_functions): +def safe_terminate( + execmodel: ExecModel, timeout: float | None, list_of_paired_functions +) -> None: workerpool = WorkerPool(execmodel) - def termkill(termfunc, killfunc): + def termkill(termfunc, killfunc) -> None: termreply = workerpool.spawn(termfunc) try: termreply.get(timeout=timeout) diff --git a/src/execnet/py.typed b/src/execnet/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/execnet/rsync.py b/src/execnet/rsync.py index 1484d49d..84820532 100644 --- a/src/execnet/rsync.py +++ b/src/execnet/rsync.py @@ -3,12 +3,20 @@ (c) 2006-2009, Armin Rigo, Holger Krekel, Maciej Fijalkowski """ + +from __future__ import annotations + import os import stat from hashlib import md5 from queue import Queue +from typing import Callable +from typing import Literal import execnet.rsync_remote +from execnet.gateway import Gateway +from execnet.gateway_base import BaseGateway +from execnet.gateway_base import Channel class RSync: @@ -21,48 +29,64 @@ class RSync: a path on remote side). """ - def __init__(self, sourcedir, callback=None, verbose=True): + def __init__(self, sourcedir, callback=None, verbose: bool = True) -> None: self._sourcedir = str(sourcedir) self._verbose = verbose - assert callback is None or hasattr(callback, "__call__") + assert callback is None or callable(callback) self._callback = callback - self._channels = {} - self._receivequeue = Queue() - self._links = [] - - def filter(self, path): + self._channels: dict[Channel, Callable[[], None] | None] = {} + self._receivequeue: Queue[ + tuple[ + Channel, + ( + None + | tuple[Literal["send"], tuple[list[str], bytes]] + | tuple[Literal["list_done"], None] + | tuple[Literal["ack"], str] + | tuple[Literal["links"], None] + | tuple[Literal["done"], None] + ), + ] + ] = Queue() + self._links: list[tuple[Literal["linkbase", "link"], str, str]] = [] + + def filter(self, path: str) -> bool: return True - def _end_of_channel(self, channel): + def _end_of_channel(self, channel: Channel) -> None: if channel in self._channels: # too early! we must have got an error channel.waitclose() # or else we raise one raise OSError(f"connection unexpectedly closed: {channel.gateway} ") - def _process_link(self, channel): + def _process_link(self, channel: Channel) -> None: for link in self._links: channel.send(link) # completion marker, this host is done channel.send(42) - def _done(self, channel): - """Call all callbacks""" + def _done(self, channel: Channel) -> None: + """Call all callbacks.""" finishedcallback = self._channels.pop(channel) if finishedcallback: finishedcallback() channel.waitclose() - def _list_done(self, channel): + def _list_done(self, channel: Channel) -> None: # sum up all to send if self._callback: s = sum([self._paths[i] for i in self._to_send[channel]]) self._callback("list", s, channel) - def _send_item(self, channel, data): - """Send one item""" - modified_rel_path, checksum = data - modifiedpath = os.path.join(self._sourcedir, *modified_rel_path) + def _send_item( + self, + channel: Channel, + modified_rel_path_components: list[str], + checksum: bytes, + ) -> None: + """Send one item.""" + modifiedpath = os.path.join(self._sourcedir, *modified_rel_path_components) try: f = open(modifiedpath, "rb") data = f.read() @@ -70,7 +94,7 @@ def _send_item(self, channel, data): data = None # provide info to progress callback function - modified_rel_path = "/".join(modified_rel_path) + modified_rel_path = "/".join(modified_rel_path_components) if data is not None: self._paths[modified_rel_path] = len(data) else: @@ -88,14 +112,15 @@ def _send_item(self, channel, data): self._report_send_file(channel.gateway, modified_rel_path) channel.send(data) - def _report_send_file(self, gateway, modified_rel_path): + def _report_send_file(self, gateway: BaseGateway, modified_rel_path: str) -> None: if self._verbose: print(f"{gateway} <= {modified_rel_path}") - def send(self, raises=True): - """Sends a sourcedir to all added targets. Flag indicates - whether to raise an error or return in case of lack of - targets + def send(self, raises: bool = True) -> None: + """Sends a sourcedir to all added targets. + + raises indicates whether to raise an error or return in case of lack of + targets. """ if not self._channels: if raises: @@ -110,8 +135,8 @@ def send(self, raises=True): # paths and to_send are only used for doing # progress-related callbacks - self._paths = {} - self._to_send = {} + self._paths: dict[str, int] = {} + self._to_send: dict[Channel, list[str]] = {} # send modified file to clients while self._channels: @@ -119,30 +144,33 @@ def send(self, raises=True): if req is None: self._end_of_channel(channel) else: - command, data = req - if command == "links": + if req[0] == "links": self._process_link(channel) - elif command == "done": + elif req[0] == "done": self._done(channel) - elif command == "ack": + elif req[0] == "ack": if self._callback: - self._callback("ack", self._paths[data], channel) - elif command == "list_done": + self._callback("ack", self._paths[req[1]], channel) + elif req[0] == "list_done": self._list_done(channel) - elif command == "send": - self._send_item(channel, data) - del data + elif req[0] == "send": + self._send_item(channel, req[1][0], req[1][1]) else: - assert "Unknown command %s" % command - - def add_target(self, gateway, destdir, finishedcallback=None, **options): - """Adds a remote target specified via a gateway - and a remote destination directory. - """ + assert "Unknown command %s" % req[0] # type: ignore[unreachable] + + def add_target( + self, + gateway: Gateway, + destdir: str | os.PathLike[str], + finishedcallback: Callable[[], None] | None = None, + **options, + ) -> None: + """Add a remote target specified via a gateway and a remote destination + directory.""" for name in options: assert name in ("delete",) - def itemcallback(req): + def itemcallback(req) -> None: self._receivequeue.put((channel, req)) channel = gateway.remote_exec(execnet.rsync_remote) @@ -151,14 +179,19 @@ def itemcallback(req): channel.send((str(destdir), options)) self._channels[channel] = finishedcallback - def _broadcast(self, msg): + def _broadcast(self, msg: object) -> None: for channel in self._channels: channel.send(msg) - def _send_link(self, linktype, basename, linkpoint): + def _send_link( + self, + linktype: Literal["linkbase", "link"], + basename: str, + linkpoint: str, + ) -> None: self._links.append((linktype, basename, linkpoint)) - def _send_directory(self, path): + def _send_directory(self, path: str) -> None: # dir: send a list of entries names = [] subpaths = [] @@ -168,11 +201,11 @@ def _send_directory(self, path): names.append(name) subpaths.append(p) mode = os.lstat(path).st_mode - self._broadcast([mode] + names) + self._broadcast([mode, *names]) for p in subpaths: self._send_directory_structure(p) - def _send_link_structure(self, path): + def _send_link_structure(self, path: str) -> None: sourcedir = self._sourcedir basename = path[len(self._sourcedir) + 1 :] linkpoint = os.readlink(path) @@ -189,8 +222,10 @@ def _send_link_structure(self, path): relpath = os.path.relpath(linkpoint, sourcedir) except ValueError: relpath = None - if relpath not in (None, os.curdir, os.pardir) and not relpath.startswith( - os.pardir + os.sep + if ( + relpath is not None + and relpath not in (os.curdir, os.pardir) + and not relpath.startswith(os.pardir + os.sep) ): self._send_link("linkbase", basename, relpath) else: @@ -198,7 +233,7 @@ def _send_link_structure(self, path): self._send_link("link", basename, linkpoint) self._broadcast(None) - def _send_directory_structure(self, path): + def _send_directory_structure(self, path: str) -> None: try: st = os.lstat(path) except OSError: diff --git a/src/execnet/rsync_remote.py b/src/execnet/rsync_remote.py index 4ac1880f..b560df75 100644 --- a/src/execnet/rsync_remote.py +++ b/src/execnet/rsync_remote.py @@ -2,17 +2,26 @@ (c) 2006-2013, Armin Rigo, Holger Krekel, Maciej Fijalkowski """ +from __future__ import annotations -def serve_rsync(channel): +from typing import TYPE_CHECKING +from typing import Literal +from typing import cast + +if TYPE_CHECKING: + from execnet.gateway_base import Channel + + +def serve_rsync(channel: Channel) -> None: import os - import stat import shutil + import stat from hashlib import md5 - destdir, options = channel.receive() + destdir, options = cast("tuple[str, dict[str, object]]", channel.receive()) modifiedfiles = [] - def remove(path): + def remove(path: str) -> None: assert path.startswith(destdir) try: os.unlink(path) @@ -20,7 +29,7 @@ def remove(path): # assume it's a dir shutil.rmtree(path, True) - def receive_directory_structure(path, relcomponents): + def receive_directory_structure(path: str, relcomponents: list[str]) -> None: try: st = os.lstat(path) except OSError: @@ -42,7 +51,7 @@ def receive_directory_structure(path, relcomponents): entrynames = {} for entryname in msg: destpath = os.path.join(path, entryname) - receive_directory_structure(destpath, relcomponents + [entryname]) + receive_directory_structure(destpath, [*relcomponents, entryname]) entrynames[entryname] = True if options.get("delete"): for othername in os.listdir(path): @@ -77,7 +86,7 @@ def receive_directory_structure(path, relcomponents): channel.send(("list_done", None)) for path, (mode, time, size) in modifiedfiles: - data = channel.receive() + data = cast(bytes, channel.receive()) channel.send(("ack", path[len(destdir) + 1 :])) if data is not None: if STRICT_CHECK and len(data) != size: @@ -97,7 +106,9 @@ def receive_directory_structure(path, relcomponents): msg = channel.receive() while msg != 42: # we get symlink - _type, relpath, linkpoint = msg + _type, relpath, linkpoint = cast( + "tuple[Literal['linkbase', 'link'], str, str]", msg + ) path = os.path.join(destdir, relpath) try: remove(path) @@ -114,4 +125,4 @@ def receive_directory_structure(path, relcomponents): if __name__ == "__channelexec__": - serve_rsync(channel) # type: ignore[name-defined] + serve_rsync(channel) # type: ignore[name-defined] # noqa:F821 diff --git a/src/execnet/script/quitserver.py b/src/execnet/script/quitserver.py index 9d5cbb15..4c94c383 100644 --- a/src/execnet/script/quitserver.py +++ b/src/execnet/script/quitserver.py @@ -1,14 +1,14 @@ """ - send a "quit" signal to a remote server +send a "quit" signal to a remote server """ + from __future__ import annotations import socket import sys - host, port = sys.argv[1].split(":") hostport = (host, int(port)) diff --git a/src/execnet/script/shell.py b/src/execnet/script/shell.py index f47cd4d3..569d4042 100644 --- a/src/execnet/script/shell.py +++ b/src/execnet/script/shell.py @@ -4,49 +4,51 @@ for injection into startserver.py """ + import os import select import socket import sys from threading import Thread from traceback import print_exc +from typing import NoReturn -def clientside(): +def clientside() -> NoReturn: print("client side starting") - host, port = sys.argv[1].split(":") - port = int(port) + host, portstr = sys.argv[1].split(":") + port = int(portstr) myself = open(os.path.abspath(sys.argv[0])).read() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect((host, port)) - sock.sendall(repr(myself) + "\n") + sock.sendall((repr(myself) + "\n").encode()) print("send boot string") inputlist = [sock, sys.stdin] try: while 1: r, w, e = select.select(inputlist, [], []) if sys.stdin in r: - line = raw_input() - sock.sendall(line + "\n") + line = input() + sock.sendall((line + "\n").encode()) if sock in r: - line = sock.recv(4096) + line = sock.recv(4096).decode() sys.stdout.write(line) sys.stdout.flush() except BaseException: import traceback - print(traceback.print_exc()) + traceback.print_exc() sys.exit(1) class promptagent(Thread): - def __init__(self, clientsock): + def __init__(self, clientsock) -> None: print("server side starting") - super.__init__() + super.__init__() # type: ignore[call-overload] self.clientsock = clientsock - def run(self): + def run(self) -> None: print("Entering thread prompt loop") clientfile = self.clientsock.makefile("w") @@ -55,7 +57,7 @@ def run(self): while 1: try: - clientfile.write("%s %s >>> " % loc) + clientfile.write("{} {} >>> ".format(*loc)) clientfile.flush() line = filein.readline() if not line: diff --git a/src/execnet/script/socketserver.py b/src/execnet/script/socketserver.py index 034b4346..b58d1a4c 100644 --- a/src/execnet/script/socketserver.py +++ b/src/execnet/script/socketserver.py @@ -1,31 +1,34 @@ #! /usr/bin/env python """ - start socket based minimal readline exec server +start socket based minimal readline exec server - it can exeuted in 2 modes of operation +it can exeuted in 2 modes of operation - 1. as normal script, that listens for new connections +1. as normal script, that listens for new connections - 2. via existing_gateway.remote_exec (as imported module) +2. via existing_gateway.remote_exec (as imported module) """ + # this part of the program only executes on the server side # +from __future__ import annotations + import os import sys +from typing import TYPE_CHECKING -progname = "socket_readline_exec_server-1.2" - +try: + import fcntl +except ImportError: + fcntl = None # type: ignore[assignment] -def get_fcntl(): - try: - import fcntl - except ImportError: - fcntl = None - return fcntl +if TYPE_CHECKING: + from execnet.gateway_base import Channel + from execnet.gateway_base import ExecModel +progname = "socket_readline_exec_server-1.2" -fcntl = get_fcntl() debug = 0 @@ -35,7 +38,7 @@ def get_fcntl(): sys.stdout = sys.stderr = f -def print_(*args): +def print_(*args) -> None: print(" ".join(str(arg) for arg in args)) @@ -45,10 +48,10 @@ def print_(*args): ) -def exec_from_one_connection(serversock): +def exec_from_one_connection(serversock) -> None: print_(progname, "Entering Accept loop", serversock.getsockname()) clientsock, address = serversock.accept() - print_(progname, "got new connection from %s %s" % address) + print_(progname, "got new connection from {} {}".format(*address)) clientfile = clientsock.makefile("rb") print_("reading line") # rstrip so that we can use \r\n for telnet testing @@ -60,14 +63,14 @@ def exec_from_one_connection(serversock): co = compile(source + "\n", "", "exec") print_(progname, "compiled source, executing") try: - exec_(co, g) # noqa + exec_(co, g) # type: ignore[name-defined] # noqa: F821 finally: print_(progname, "finished executing code") # background thread might hold a reference to this (!?) # clientsock.close() -def bind_and_listen(hostport, execmodel): +def bind_and_listen(hostport: str | tuple[str, int], execmodel: ExecModel): socket = execmodel.socket if isinstance(hostport, str): host, port = hostport.split(":") @@ -86,7 +89,7 @@ def bind_and_listen(hostport, execmodel): return serversock -def startserver(serversock, loop=False): +def startserver(serversock, loop: bool = False) -> None: execute_path = os.getcwd() try: while 1: @@ -94,14 +97,13 @@ def startserver(serversock, loop=False): exec_from_one_connection(serversock) except (KeyboardInterrupt, SystemExit): raise - except BaseException: + except BaseException as exc: if debug: import traceback traceback.print_exc() else: - excinfo = sys.exc_info() - print_("got exception", excinfo[1]) + print_("got exception", exc) os.chdir(execute_path) if not loop: break @@ -124,9 +126,10 @@ def startserver(serversock, loop=False): startserver(serversock, loop=True) elif __name__ == "__channelexec__": - chan = globals()["channel"] + chan: Channel = globals()["channel"] execmodel = chan.gateway.execmodel bindname = chan.receive() + assert isinstance(bindname, (str, tuple)) sock = bind_and_listen(bindname, execmodel) port = sock.getsockname() chan.send(port) diff --git a/src/execnet/script/socketserverservice.py b/src/execnet/script/socketserverservice.py index 3d64f139..9c18c12c 100644 --- a/src/execnet/script/socketserverservice.py +++ b/src/execnet/script/socketserverservice.py @@ -5,7 +5,7 @@ python socketserverservice.py register net start ExecNetSocketServer """ -import socketserver + import sys import threading @@ -15,6 +15,9 @@ import win32service import win32serviceutil +from execnet.gateway_base import get_execmodel + +from . import socketserver appname = "ExecNetSocketServer" @@ -24,7 +27,7 @@ class SocketServerService(win32serviceutil.ServiceFramework): _svc_display_name_ = "%s" % appname _svc_deps_ = ["EventLog"] - def __init__(self, args): + def __init__(self, args) -> None: # The exe-file has messages for the Event Log Viewer. # Register the exe-file as event source. # @@ -40,11 +43,11 @@ def __init__(self, args): self.hWaitStop = win32event.CreateEvent(None, 0, 0, None) self.WAIT_TIME = 1000 # in milliseconds - def SvcStop(self): + def SvcStop(self) -> None: self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) win32event.SetEvent(self.hWaitStop) - def SvcDoRun(self): + def SvcDoRun(self) -> None: # Redirect stdout and stderr to prevent "IOError: [Errno 9] # Bad file descriptor". Windows services don't have functional # output streams. @@ -62,7 +65,8 @@ def SvcDoRun(self): hostport = ":8888" print("Starting py.execnet SocketServer on %s" % hostport) - serversock = socketserver.bind_and_listen(hostport) + exec_model = get_execmodel("thread") + serversock = socketserver.bind_and_listen(hostport, exec_model) thread = threading.Thread( target=socketserver.startserver, args=(serversock,), kwargs={"loop": True} ) diff --git a/src/execnet/script/xx.py b/src/execnet/script/xx.py deleted file mode 100644 index 687cc81e..00000000 --- a/src/execnet/script/xx.py +++ /dev/null @@ -1,12 +0,0 @@ -import sys - -import register -import rlcompleter2 - -rlcompleter2.setup() - -try: - hostport = sys.argv[1] -except BaseException: - hostport = ":8888" -gw = register.ServerGateway(hostport) diff --git a/src/execnet/xspec.py b/src/execnet/xspec.py index 4d33ad67..0559ed8c 100644 --- a/src/execnet/xspec.py +++ b/src/execnet/xspec.py @@ -2,26 +2,40 @@ (c) 2008-2013, holger krekel """ +from __future__ import annotations + class XSpec: """Execution Specification: key1=value1//key2=value2 ... - * keys need to be unique within the specification scope - * neither key nor value are allowed to contain "//" - * keys are not allowed to contain "=" - * keys are not allowed to start with underscore - * if no "=value" is given, assume a boolean True value + + * Keys need to be unique within the specification scope + * Neither key nor value are allowed to contain "//" + * Keys are not allowed to contain "=" + * Keys are not allowed to start with underscore + * If no "=value" is given, assume a boolean True value """ # XXX allow customization, for only allow specific key names - popen = ( - ssh - ) = socket = python = chdir = nice = dont_write_bytecode = execmodel = None + chdir: str | None = None + dont_write_bytecode: bool | None = None + execmodel: str | None = None + id: str | None = None + installvia: str | None = None + nice: str | None = None + popen: bool | None = None + python: str | None = None + socket: str | None = None + ssh: str | None = None + ssh_config: str | None = None + vagrant_ssh: str | None = None + via: str | None = None - def __init__(self, string): + def __init__(self, string: str) -> None: self._spec = string self.env = {} for keyvalue in string.split("//"): i = keyvalue.find("=") + value: str | bool if i == -1: key, value = keyvalue, True else: @@ -35,25 +49,25 @@ def __init__(self, string): else: setattr(self, key, value) - def __getattr__(self, name): + def __getattr__(self, name: str) -> None | bool | str: if name[0] == "_": raise AttributeError(name) return None - def __repr__(self): + def __repr__(self) -> str: return f"" - def __str__(self): + def __str__(self) -> str: return self._spec - def __hash__(self): + def __hash__(self) -> int: return hash(self._spec) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self._spec == getattr(other, "_spec", None) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return self._spec != getattr(other, "_spec", None) - def _samefilesystem(self): + def _samefilesystem(self) -> bool: return self.popen is not None and self.chdir is None diff --git a/testing/conftest.py b/testing/conftest.py index 066df23a..c5fa4cf5 100644 --- a/testing/conftest.py +++ b/testing/conftest.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import shutil import sys from functools import lru_cache from typing import Callable +from typing import Generator from typing import Iterator import execnet import pytest -from execnet.gateway_base import get_execmodel +from execnet.gateway import Gateway +from execnet.gateway_base import ExecModel from execnet.gateway_base import WorkerPool +from execnet.gateway_base import get_execmodel collect_ignore = ["build", "doc/_build"] @@ -15,7 +20,7 @@ @pytest.hookimpl(hookwrapper=True) -def pytest_runtest_setup(item): +def pytest_runtest_setup(item: pytest.Item) -> Generator[None, None, None]: if item.fspath.purebasename in ("test_group", "test_info"): getspecssh(item.config) # will skip if no gx given yield @@ -31,7 +36,7 @@ def group_function() -> Iterator[execnet.Group]: @pytest.fixture -def makegateway(group_function) -> Callable[[str], execnet.gateway.Gateway]: +def makegateway(group_function: execnet.Group) -> Callable[[str], Gateway]: return group_function.makegateway @@ -39,7 +44,7 @@ def makegateway(group_function) -> Callable[[str], execnet.gateway.Gateway]: # configuration information for tests -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: group = parser.getgroup("execnet", "execnet testing options") group.addoption( "--gx", @@ -66,20 +71,20 @@ def pytest_addoption(parser): @pytest.fixture -def specssh(request): +def specssh(request: pytest.FixtureRequest) -> execnet.XSpec: return getspecssh(request.config) @pytest.fixture -def specsocket(request): +def specsocket(request: pytest.FixtureRequest) -> execnet.XSpec: return getsocketspec(request.config) -def getgspecs(config): - return map(execnet.XSpec, config.getvalueorskip("gspecs")) +def getgspecs(config: pytest.Config) -> list[execnet.XSpec]: + return [execnet.XSpec(gspec) for gspec in config.getvalueorskip("gspecs")] -def getspecssh(config): +def getspecssh(config: pytest.Config) -> execnet.XSpec: xspecs = getgspecs(config) for spec in xspecs: if spec.ssh: @@ -89,7 +94,7 @@ def getspecssh(config): pytest.skip("need '--gx ssh=...'") -def getsocketspec(config): +def getsocketspec(config: pytest.Config) -> execnet.XSpec: xspecs = getgspecs(config) for spec in xspecs: if spec.socket: @@ -97,7 +102,7 @@ def getsocketspec(config): pytest.skip("need '--gx socket=...'") -def pytest_generate_tests(metafunc): +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: if "gw" in metafunc.fixturenames: assert "anypython" not in metafunc.fixturenames, "need combine?" if hasattr(metafunc.function, "gwtypes"): @@ -109,35 +114,39 @@ def pytest_generate_tests(metafunc): metafunc.parametrize("gw", gwtypes, indirect=True) -@lru_cache() -def getexecutable(name): +@lru_cache +def getexecutable(name: str) -> str | None: if name == "sys.executable": return sys.executable return shutil.which(name) @pytest.fixture(params=("sys.executable", "pypy3")) -def anypython(request): +def anypython(request: pytest.FixtureRequest) -> str: name = request.param executable = getexecutable(name) if executable is None: pytest.skip(f"no {name} found") if "execmodel" in request.fixturenames and name != "sys.executable": backend = request.getfixturevalue("execmodel").backend - if backend != "thread": + if backend not in ("thread", "main_thread_only"): pytest.xfail(f"cannot run {backend!r} execmodel with bare {name}") return executable @pytest.fixture(scope="session") -def group(): +def group() -> Iterator[execnet.Group]: g = execnet.Group() yield g g.terminate(timeout=1) @pytest.fixture -def gw(request, execmodel, group): +def gw( + request: pytest.FixtureRequest, + execmodel: ExecModel, + group: execnet.Group, +) -> Gateway: try: return group[request.param] except KeyError: @@ -152,10 +161,11 @@ def gw(request, execmodel, group): proxygw = group.makegateway("popen//id=%s" % pname) # assert group['proxygw'].remote_status().receiving gw = group.makegateway( - "socket//id=socket//installvia=%s" - "//execmodel=%s" % (pname, execmodel.backend) + f"socket//id=socket//installvia={pname}" + f"//execmodel={execmodel.backend}" ) - gw.proxygw = proxygw + # TODO(typing): Clarify this assignment. + gw.proxygw = proxygw # type: ignore[attr-defined] assert pname in group elif request.param == "ssh": sshhost = request.getfixturevalue("specssh").ssh @@ -173,9 +183,11 @@ def gw(request, execmodel, group): return gw -@pytest.fixture(params=["thread", "eventlet", "gevent"], scope="session") -def execmodel(request): - if request.param != "thread": +@pytest.fixture( + params=["thread", "main_thread_only", "eventlet", "gevent"], scope="session" +) +def execmodel(request: pytest.FixtureRequest) -> ExecModel: + if request.param not in ("thread", "main_thread_only"): pytest.importorskip(request.param) if request.param in ("eventlet", "gevent") and sys.platform == "win32": pytest.xfail(request.param + " does not work on win32") @@ -183,5 +195,5 @@ def execmodel(request): @pytest.fixture -def pool(execmodel): +def pool(execmodel: ExecModel) -> WorkerPool: return WorkerPool(execmodel=execmodel) diff --git a/testing/test_basics.py b/testing/test_basics.py index 5820c492..5ed53e97 100644 --- a/testing/test_basics.py +++ b/testing/test_basics.py @@ -1,3 +1,4 @@ +# ruff: noqa: B018 from __future__ import annotations import inspect @@ -9,6 +10,7 @@ from io import BytesIO from pathlib import Path from typing import Any +from typing import Callable import execnet import pytest @@ -16,10 +18,10 @@ from execnet import gateway_base from execnet import gateway_io from execnet.gateway_base import ChannelFactory +from execnet.gateway_base import ExecModel from execnet.gateway_base import Message from execnet.gateway_base import Popen2IO - skip_win_pypy = pytest.mark.xfail( condition=hasattr(sys, "pypy_version_info") and sys.platform.startswith("win"), reason="failing on Windows on PyPy (#63)", @@ -28,12 +30,12 @@ @pytest.mark.parametrize("val", ["123", 42, [1, 2, 3], ["23", 25]]) class TestSerializeAPI: - def test_serializer_api(self, val): + def test_serializer_api(self, val: object) -> None: dumped = execnet.dumps(val) val2 = execnet.loads(dumped) assert val == val2 - def test_mmap(self, tmp_path, val): + def test_mmap(self, tmp_path: Path, val: object) -> None: mmap = pytest.importorskip("mmap").mmap p = tmp_path / "data.bin" @@ -43,7 +45,7 @@ def test_mmap(self, tmp_path, val): val2 = execnet.load(m) assert val == val2 - def test_bytesio(self, val): + def test_bytesio(self, val: object) -> None: f = BytesIO() execnet.dump(f, val) read = BytesIO(f.getvalue()) @@ -51,7 +53,7 @@ def test_bytesio(self, val): assert val == val2 -def test_serializer_api_version_error(monkeypatch): +def test_serializer_api_version_error(monkeypatch: pytest.MonkeyPatch) -> None: bchr = gateway_base.bchr monkeypatch.setattr(gateway_base, "DUMPFORMAT_VERSION", bchr(1)) dumped = execnet.dumps(42) @@ -59,13 +61,13 @@ def test_serializer_api_version_error(monkeypatch): pytest.raises(execnet.DataFormatError, lambda: execnet.loads(dumped)) -def test_errors_on_execnet(): +def test_errors_on_execnet() -> None: assert hasattr(execnet, "RemoteError") assert hasattr(execnet, "TimeoutError") assert hasattr(execnet, "DataFormatError") -def test_subprocess_interaction(anypython): +def test_subprocess_interaction(anypython: str) -> None: line = gateway_io.popen_bootstrapline compile(line, "xyz", "exec") args = [str(anypython), "-c", line] @@ -77,11 +79,16 @@ def test_subprocess_interaction(anypython): stdout=subprocess.PIPE, ) - def send(line): + assert popen.stdin is not None + assert popen.stdout is not None + + def send(line: str) -> None: + assert popen.stdin is not None popen.stdin.write(line) popen.stdin.flush() - def receive(): + def receive() -> str: + assert popen.stdout is not None return popen.stdout.readline() try: @@ -102,7 +109,7 @@ def receive(): popen.wait() -def read_write_loop(): +def read_write_loop() -> None: sys.stdout.write("ok\n") sys.stdout.flush() while 1: @@ -119,10 +126,7 @@ def read_write_loop(): IO_MESSAGE_EXTRA_SOURCE = """ import sys backend = sys.argv[1] -try: - from io import BytesIO -except ImportError: - from StringIO import StringIO as BytesIO +from io import BytesIO import tempfile temp_out = BytesIO() temp_in = BytesIO() @@ -172,7 +176,7 @@ def checker(anypython: str, tmp_path: Path) -> Checker: return Checker(python=anypython, path=tmp_path) -def test_io_message(checker, execmodel): +def test_io_message(checker: Checker, execmodel: ExecModel) -> None: out = checker.run_check( inspect.getsource(gateway_base) + IO_MESSAGE_EXTRA_SOURCE, execmodel.backend ) @@ -180,7 +184,7 @@ def test_io_message(checker, execmodel): assert "all passed" in out.stdout -def test_popen_io(checker, execmodel): +def test_popen_io(checker: Checker, execmodel: ExecModel) -> None: out = checker.run_check( inspect.getsource(gateway_base) + f""" @@ -195,22 +199,22 @@ def test_popen_io(checker, execmodel): assert "hello" in out.stdout -def test_popen_io_readloop(monkeypatch, execmodel): +def test_popen_io_readloop(execmodel: ExecModel) -> None: sio = BytesIO(b"test") io = Popen2IO(sio, sio, execmodel) real_read = io._read - def newread(numbytes): + def newread(numbytes: int) -> bytes: if numbytes > 1: numbytes = numbytes - 1 - return real_read(numbytes) + return real_read(numbytes) # type: ignore[no-any-return] io._read = newread result = io.read(3) assert result == b"tes" -def test_rinfo_source(checker): +def test_rinfo_source(checker: Checker) -> None: out = checker.run_check( f""" class Channel: @@ -226,20 +230,18 @@ def send(self, data): assert "all passed" in out.stdout -def test_geterrortext(checker): +def test_geterrortext(checker: Checker) -> None: out = checker.run_check( inspect.getsource(gateway_base) + """ -class Arg: +class Arg(Exception): pass -errortext = geterrortext((Arg, "1", 4)) +errortext = geterrortext(Arg()) assert "Arg" in errortext -import sys try: raise ValueError("17") -except ValueError: - excinfo = sys.exc_info() - s = geterrortext(excinfo) +except ValueError as exc: + s = geterrortext(exc) assert "17" in s print ("all passed") """ @@ -249,13 +251,28 @@ class Arg: @pytest.mark.skipif("not hasattr(os, 'dup')") -def test_stdouterrin_setnull(execmodel, capfd): - gateway_base.init_popen_io(execmodel) - os.write(1, b"hello") - os.read(0, 1) - out, err = capfd.readouterr() - assert not out - assert not err +def test_stdouterrin_setnull( + execmodel: ExecModel, capfd: pytest.CaptureFixture[str] +) -> None: + # Backup and restore stdin state, and rely on capfd to handle + # this for stdout and stderr. + orig_stdin = sys.stdin + orig_stdin_fd = os.dup(0) + try: + # The returned Popen2IO instance can be garbage collected + # prematurely since we don't hold a reference here, but we + # tolerate this because it is intended to leave behind a + # sane state afterwards. + gateway_base.init_popen_io(execmodel) + os.write(1, b"hello") + os.read(0, 1) + out, err = capfd.readouterr() + assert not out + assert not err + finally: + sys.stdin = orig_stdin + os.dup2(orig_stdin_fd, 0) + os.close(orig_stdin_fd) class PseudoChannel: @@ -263,33 +280,34 @@ class gateway: class _channelfactory: finished = False - def __init__(self): - self._sent = [] - self._closed = [] + def __init__(self) -> None: + self._sent: list[object] = [] + self._closed: list[str | None] = [] self.id = 1000 - def send(self, obj): + def send(self, obj: object) -> None: self._sent.append(obj) - def close(self, errortext=None): + def close(self, errortext: str | None = None) -> None: self._closed.append(errortext) -def test_exectask(execmodel): +def test_exectask(execmodel: ExecModel) -> None: io = BytesIO() - io.execmodel = execmodel - gw = gateway_base.WorkerGateway(io, id="something") + io.execmodel = execmodel # type: ignore[attr-defined] + gw = gateway_base.WorkerGateway(io, id="something") # type: ignore[arg-type] ch = PseudoChannel() - gw.executetask((ch, ("raise ValueError()", None, {}))) + gw.executetask((ch, ("raise ValueError()", None, {}))) # type: ignore[arg-type] assert "ValueError" in str(ch._closed[0]) class TestMessage: - def test_wire_protocol(self): + def test_wire_protocol(self) -> None: for i, handler in enumerate(Message._types): one = BytesIO() data = b"23" - Message(i, 42, data).to_io(one) + # TODO(typing): Maybe make this work. + Message(i, 42, data).to_io(one) # type: ignore[arg-type] two = BytesIO(one.getvalue()) msg = Message.from_io(two) assert msg.msgcode == i @@ -301,87 +319,89 @@ def test_wire_protocol(self): class TestPureChannel: @pytest.fixture - def fac(self, execmodel): + def fac(self, execmodel: ExecModel) -> ChannelFactory: class FakeGateway: - def _trace(self, *args): + def _trace(self, *args) -> None: pass - def _send(self, *k): + def _send(self, *k) -> None: pass - FakeGateway.execmodel = execmodel - return ChannelFactory(FakeGateway()) + FakeGateway.execmodel = execmodel # type: ignore[attr-defined] + return ChannelFactory(FakeGateway()) # type: ignore[arg-type] - def test_factory_create(self, fac): + def test_factory_create(self, fac: ChannelFactory) -> None: chan1 = fac.new() assert chan1.id == 1 chan2 = fac.new() assert chan2.id == 3 - def test_factory_getitem(self, fac): + def test_factory_getitem(self, fac: ChannelFactory) -> None: chan1 = fac.new() assert fac._channels[chan1.id] == chan1 chan2 = fac.new() assert fac._channels[chan2.id] == chan2 - def test_channel_timeouterror(self, fac): + def test_channel_timeouterror(self, fac: ChannelFactory) -> None: channel = fac.new() pytest.raises(IOError, channel.waitclose, timeout=0.01) - def test_channel_makefile_incompatmode(self, fac): + def test_channel_makefile_incompatmode(self, fac) -> None: channel = fac.new() with pytest.raises(ValueError): channel.makefile("rw") class TestSourceOfFunction: - def test_lambda_unsupported(self): + def test_lambda_unsupported(self) -> None: pytest.raises(ValueError, gateway._source_of_function, lambda: 1) - def test_wrong_prototype_fails(self): - def prototype(wrong): + def test_wrong_prototype_fails(self) -> None: + def prototype(wrong) -> None: pass pytest.raises(ValueError, gateway._source_of_function, prototype) - def test_function_without_known_source_fails(self): + def test_function_without_known_source_fails(self) -> None: # this one won't be able to find the source - mess = {} + mess: dict[str, Any] = {} exec("def fail(channel): pass", mess, mess) print(inspect.getsourcefile(mess["fail"])) - pytest.raises(ValueError, gateway._source_of_function, mess["fail"]) + with pytest.raises(ValueError): + gateway._source_of_function(mess["fail"]) - def test_function_with_closure_fails(self): - mess = {} + def test_function_with_closure_fails(self) -> None: + mess: dict[str, Any] = {} - def closure(channel): + def closure(channel: object) -> None: print(mess) - pytest.raises(ValueError, gateway._source_of_function, closure) + with pytest.raises(ValueError): + gateway._source_of_function(closure) - def test_source_of_nested_function(self): - def working(channel): + def test_source_of_nested_function(self) -> None: + def working(channel: object) -> None: pass send_source = gateway._source_of_function(working).lstrip("\r\n") - expected = "def working(channel):\n pass\n" + expected = "def working(channel: object) -> None:\n pass\n" assert send_source == expected class TestGlobalFinder: - def check(self, func): + def check(self, func) -> list[str]: src = textwrap.dedent(inspect.getsource(func)) code = func.__code__ return gateway._find_non_builtin_globals(src, code) - def test_local(self): + def test_local(self) -> None: def f(a, b, c): d = 3 return d assert self.check(f) == [] - def test_global(self): + def test_global(self) -> None: def f(a, b): sys d = 4 @@ -389,19 +409,19 @@ def f(a, b): assert self.check(f) == ["sys"] - def test_builtin(self): - def f(): + def test_builtin(self) -> None: + def f() -> None: len assert self.check(f) == [] - def test_function_with_global_fails(self): - def func(channel): + def test_function_with_global_fails(self) -> None: + def func(channel) -> None: sys pytest.raises(ValueError, gateway._source_of_function, func) - def test_method_call(self): + def test_method_call(self) -> None: # method names are reason # for the simple code object based heusteric failing def f(channel): @@ -411,8 +431,10 @@ def f(channel): @skip_win_pypy -def test_remote_exec_function_with_kwargs(anypython, makegateway): - def func(channel, data): +def test_remote_exec_function_with_kwargs( + anypython: str, makegateway: Callable[[str], gateway.Gateway] +) -> None: + def func(channel, data) -> None: channel.send(data) gw = makegateway("popen//python=%s" % anypython) @@ -423,7 +445,7 @@ def func(channel, data): assert result == 1 -def test_remote_exc__no_kwargs(makegateway): +def test_remote_exc__no_kwargs(makegateway: Callable[[], gateway.Gateway]) -> None: gw = makegateway() with pytest.raises(TypeError): gw.remote_exec(gateway_base, kwarg=1) @@ -432,7 +454,9 @@ def test_remote_exc__no_kwargs(makegateway): @skip_win_pypy -def test_remote_exec_inspect_stack(makegateway): +def test_remote_exec_inspect_stack( + makegateway: Callable[[], gateway.Gateway], +) -> None: gw = makegateway() ch = gw.remote_exec( """ @@ -442,5 +466,7 @@ def test_remote_exec_inspect_stack(makegateway): channel.send('\\n'.join(traceback.format_stack())) """ ) - assert 'File ""' in ch.receive() + received = ch.receive() + assert isinstance(received, str) + assert 'File ""' in received ch.waitclose() diff --git a/testing/test_channel.py b/testing/test_channel.py index a1bfbae9..f500f117 100644 --- a/testing/test_channel.py +++ b/testing/test_channel.py @@ -1,10 +1,14 @@ """ mostly functional tests of gateways. """ + +from __future__ import annotations + import time import pytest - +from execnet.gateway import Gateway +from execnet.gateway_base import Channel needs_early_gc = pytest.mark.skipif("not hasattr(sys, 'getrefcount')") needs_osdup = pytest.mark.skipif("not hasattr(os, 'dup')") @@ -12,16 +16,16 @@ class TestChannelBasicBehaviour: - def test_serialize_error(self, gw): + def test_serialize_error(self, gw: Gateway) -> None: ch = gw.remote_exec("channel.send(ValueError(42))") excinfo = pytest.raises(ch.RemoteError, ch.receive) assert "can't serialize" in str(excinfo.value) - def test_channel_close_and_then_receive_error(self, gw): + def test_channel_close_and_then_receive_error(self, gw: Gateway) -> None: channel = gw.remote_exec("raise ValueError") pytest.raises(channel.RemoteError, channel.receive) - def test_channel_finish_and_then_EOFError(self, gw): + def test_channel_finish_and_then_EOFError(self, gw: Gateway) -> None: channel = gw.remote_exec("channel.send(42)") x = channel.receive() assert x == 42 @@ -29,20 +33,22 @@ def test_channel_finish_and_then_EOFError(self, gw): pytest.raises(EOFError, channel.receive) pytest.raises(EOFError, channel.receive) - def test_waitclose_timeouterror(self, gw): + def test_waitclose_timeouterror(self, gw: Gateway) -> None: channel = gw.remote_exec("channel.receive()") pytest.raises(channel.TimeoutError, channel.waitclose, 0.02) channel.send(1) channel.waitclose(timeout=TESTTIMEOUT) - def test_channel_receive_timeout(self, gw): + def test_channel_receive_timeout(self, gw: Gateway) -> None: channel = gw.remote_exec("channel.send(channel.receive())") with pytest.raises(channel.TimeoutError): channel.receive(timeout=0.2) channel.send(1) channel.receive(timeout=TESTTIMEOUT) - def test_channel_receive_internal_timeout(self, gw, monkeypatch): + def test_channel_receive_internal_timeout( + self, gw: Gateway, monkeypatch: pytest.MonkeyPatch + ) -> None: channel = gw.remote_exec( """ import time @@ -53,23 +59,23 @@ def test_channel_receive_internal_timeout(self, gw, monkeypatch): monkeypatch.setattr(channel.__class__, "_INTERNALWAKEUP", 0.2) channel.receive() - def test_channel_close_and_then_receive_error_multiple(self, gw): + def test_channel_close_and_then_receive_error_multiple(self, gw: Gateway) -> None: channel = gw.remote_exec("channel.send(42) ; raise ValueError") x = channel.receive() assert x == 42 pytest.raises(channel.RemoteError, channel.receive) - def test_channel__local_close(self, gw): + def test_channel__local_close(self, gw: Gateway) -> None: channel = gw._channelfactory.new() gw._channelfactory._local_close(channel.id) channel.waitclose(0.1) - def test_channel__local_close_error(self, gw): + def test_channel__local_close_error(self, gw: Gateway) -> None: channel = gw._channelfactory.new() gw._channelfactory._local_close(channel.id, channel.RemoteError("error")) pytest.raises(channel.RemoteError, channel.waitclose, 0.01) - def test_channel_error_reporting(self, gw): + def test_channel_error_reporting(self, gw: Gateway) -> None: channel = gw.remote_exec("def foo():\n return foobar()\nfoo()\n") excinfo = pytest.raises(channel.RemoteError, channel.receive) msg = str(excinfo.value) @@ -77,7 +83,7 @@ def test_channel_error_reporting(self, gw): assert "NameError" in msg assert "foobar" in msg - def test_channel_syntax_error(self, gw): + def test_channel_syntax_error(self, gw: Gateway) -> None: # missing colon channel = gw.remote_exec("def foo()\n return 1\nfoo()\n") excinfo = pytest.raises(channel.RemoteError, channel.receive) @@ -85,7 +91,7 @@ def test_channel_syntax_error(self, gw): assert msg.startswith("Traceback (most recent call last):") assert "SyntaxError" in msg - def test_channel_iter(self, gw): + def test_channel_iter(self, gw: Gateway) -> None: channel = gw.remote_exec( """ for x in range(3): @@ -95,7 +101,7 @@ def test_channel_iter(self, gw): l = list(channel) assert l == [0, 1, 2] - def test_channel_pass_in_structure(self, gw): + def test_channel_pass_in_structure(self, gw: Gateway) -> None: channel = gw.remote_exec( """ ch1, ch2 = channel.receive() @@ -110,7 +116,7 @@ def test_channel_pass_in_structure(self, gw): data = newchan2.receive() assert data == 2 - def test_channel_multipass(self, gw): + def test_channel_multipass(self, gw: Gateway) -> None: channel = gw.remote_exec( """ channel.send(channel) @@ -123,15 +129,16 @@ def test_channel_multipass(self, gw): channel.send(newchan) channel.waitclose() - def test_channel_passing_over_channel(self, gw): + def test_channel_passing_over_channel(self, gw: Gateway) -> None: channel = gw.remote_exec( """ - c = channel.gateway.newchannel() - channel.send(c) - c.send(42) - """ + c = channel.gateway.newchannel() + channel.send(c) + c.send(42) + """ ) c = channel.receive() + assert isinstance(c, Channel) x = c.receive() assert x == 42 @@ -140,15 +147,15 @@ def test_channel_passing_over_channel(self, gw): # assert c.id not in gw._channelfactory newchan = gw.remote_exec( """ - assert %d not in channel.gateway._channelfactory._channels - """ + assert %d not in channel.gateway._channelfactory._channels + """ % channel.id ) newchan.waitclose(TESTTIMEOUT) assert channel.id not in gw._channelfactory._channels - def test_channel_receiver_callback(self, gw): - l = [] + def test_channel_receiver_callback(self, gw: Gateway) -> None: + l: list[int] = [] # channel = gw.newchannel(receiver=l.append) channel = gw.remote_exec( source=""" @@ -164,8 +171,8 @@ def test_channel_receiver_callback(self, gw): assert l[:2] == [42, 13] assert isinstance(l[2], channel.__class__) - def test_channel_callback_after_receive(self, gw): - l = [] + def test_channel_callback_after_receive(self, gw: Gateway) -> None: + l: list[int] = [] channel = gw.remote_exec( source=""" channel.send(42) @@ -182,10 +189,10 @@ def test_channel_callback_after_receive(self, gw): assert l[0] == 13 assert isinstance(l[1], channel.__class__) - def test_waiting_for_callbacks(self, gw): + def test_waiting_for_callbacks(self, gw: Gateway) -> None: l = [] - def callback(msg): + def callback(msg) -> None: import time time.sleep(0.2) @@ -200,14 +207,16 @@ def callback(msg): channel.waitclose(TESTTIMEOUT) assert l == [42] - def test_channel_callback_stays_active(self, gw): + def test_channel_callback_stays_active(self, gw: Gateway) -> None: self.check_channel_callback_stays_active(gw, earlyfree=True) - def check_channel_callback_stays_active(self, gw, earlyfree=True): + def check_channel_callback_stays_active( + self, gw: Gateway, earlyfree: bool = True + ) -> Channel | None: if gw.spec.execmodel == "gevent": pytest.xfail("investigate gevent failure") # with 'earlyfree==True', this tests the "sendonly" channel state. - l = [] + l: list[int] = [] channel = gw.remote_exec( source=""" import _thread @@ -224,11 +233,10 @@ def producer(subchannel): subchannel = gw.newchannel() subchannel.setcallback(l.append) channel.send(subchannel) - if earlyfree: - subchannel = None + subchan = None if earlyfree else subchannel counter = 100 while len(l) < 5: - if subchannel and subchannel.isclosed(): + if subchan and subchan.isclosed(): break counter -= 1 print(counter) @@ -236,16 +244,17 @@ def producer(subchannel): pytest.fail("timed out waiting for the answer[%d]" % len(l)) time.sleep(0.04) # busy-wait assert l == [0, 100, 200, 300, 400] - return subchannel + return subchan @needs_early_gc - def test_channel_callback_remote_freed(self, gw): + def test_channel_callback_remote_freed(self, gw: Gateway) -> None: channel = self.check_channel_callback_stays_active(gw, earlyfree=False) + assert channel is not None # freed automatically at the end of producer() channel.waitclose(TESTTIMEOUT) - def test_channel_endmarker_callback(self, gw): - l = [] + def test_channel_endmarker_callback(self, gw: Gateway) -> None: + l: list[int | Channel] = [] channel = gw.remote_exec( source=""" channel.send(42) @@ -261,7 +270,7 @@ def test_channel_endmarker_callback(self, gw): assert isinstance(l[2], channel.__class__) assert l[3] == 999 - def test_channel_endmarker_callback_error(self, gw): + def test_channel_endmarker_callback_error(self, gw: Gateway) -> None: q = gw.execmodel.queue.Queue() channel = gw.remote_exec( source=""" @@ -275,7 +284,7 @@ def test_channel_endmarker_callback_error(self, gw): assert err assert str(err).find("ValueError") != -1 - def test_channel_callback_error(self, gw): + def test_channel_callback_error(self, gw: Gateway) -> None: channel = gw.remote_exec( """ def f(item): @@ -288,6 +297,7 @@ def f(item): """ ) subchan = channel.receive() + assert isinstance(subchan, Channel) subchan.send(1) with pytest.raises(subchan.RemoteError) as excinfo: subchan.waitclose(TESTTIMEOUT) @@ -297,7 +307,7 @@ def f(item): class TestChannelFile: - def test_channel_file_write(self, gw): + def test_channel_file_write(self, gw: Gateway) -> None: channel = gw.remote_exec( """ f = channel.makefile() @@ -307,19 +317,20 @@ def test_channel_file_write(self, gw): """ ) first = channel.receive() + assert isinstance(first, str) assert first.strip() == "hello world" second = channel.receive() assert second == 42 - def test_channel_file_write_error(self, gw): + def test_channel_file_write_error(self, gw: Gateway) -> None: channel = gw.remote_exec("pass") f = channel.makefile() assert not f.isatty() channel.waitclose(TESTTIMEOUT) with pytest.raises(IOError): - f.write("hello") + f.write(b"hello") - def test_channel_file_proxyclose(self, gw): + def test_channel_file_proxyclose(self, gw: Gateway) -> None: channel = gw.remote_exec( """ f = channel.makefile(proxyclose=True) @@ -329,10 +340,11 @@ def test_channel_file_proxyclose(self, gw): """ ) first = channel.receive() + assert isinstance(first, str) assert first.strip() == "hello world" pytest.raises(channel.RemoteError, channel.receive) - def test_channel_file_read(self, gw): + def test_channel_file_read(self, gw: Gateway) -> None: channel = gw.remote_exec( """ f = channel.makefile(mode='r') @@ -348,7 +360,7 @@ def test_channel_file_read(self, gw): assert s1 == "xy" assert s2 == "abcde" - def test_channel_file_read_empty(self, gw): + def test_channel_file_read_empty(self, gw: Gateway) -> None: channel = gw.remote_exec("pass") f = channel.makefile(mode="r") s = f.read(3) @@ -356,7 +368,7 @@ def test_channel_file_read_empty(self, gw): s = f.read(5) assert s == "" - def test_channel_file_readline_remote(self, gw): + def test_channel_file_readline_remote(self, gw: Gateway) -> None: channel = gw.remote_exec( """ channel.send('123\\n45') @@ -369,7 +381,7 @@ def test_channel_file_readline_remote(self, gw): s = f.readline() assert s == "45" - def test_channel_makefile_incompatmode(self, gw): + def test_channel_makefile_incompatmode(self, gw: Gateway) -> None: channel = gw.newchannel() with pytest.raises(ValueError): - channel.makefile("rw") + channel.makefile("rw") # type: ignore[call-overload] diff --git a/testing/test_compatibility_regressions.py b/testing/test_compatibility_regressions.py index bc788e20..5343e7d5 100644 --- a/testing/test_compatibility_regressions.py +++ b/testing/test_compatibility_regressions.py @@ -1,7 +1,7 @@ from execnet import gateway_base -def test_opcodes(): +def test_opcodes() -> None: data = vars(gateway_base.opcode) computed = {k: v for k, v in data.items() if "__" not in k} assert computed == { diff --git a/testing/test_gateway.py b/testing/test_gateway.py index 809a13d9..4291404e 100644 --- a/testing/test_gateway.py +++ b/testing/test_gateway.py @@ -1,17 +1,22 @@ """ mostly functional tests of gateways. """ + +from __future__ import annotations + import os import pathlib import shutil import signal import sys from textwrap import dedent +from typing import Callable import execnet import pytest from execnet import gateway_base from execnet import gateway_io +from execnet.gateway import Gateway TESTTIMEOUT = 10.0 # seconds needs_osdup = pytest.mark.skipif("not hasattr(os, 'dup')") @@ -27,25 +32,25 @@ class TestBasicGateway: - def test_correct_setup(self, gw): + def test_correct_setup(self, gw: Gateway) -> None: assert gw.hasreceiver() assert gw in gw._group assert gw.id in gw._group assert gw.spec - def test_repr_doesnt_crash(self, gw): + def test_repr_doesnt_crash(self, gw: Gateway) -> None: assert isinstance(repr(gw), str) - def test_attribute__name__(self, gw): + def test_attribute__name__(self, gw: Gateway) -> None: channel = gw.remote_exec("channel.send(__name__)") name = channel.receive() assert name == "__channelexec__" - def test_gateway_status_simple(self, gw): + def test_gateway_status_simple(self, gw: Gateway) -> None: status = gw.remote_status() assert status.numexecuting == 0 - def test_exc_info_is_clear_after_gateway_startup(self, gw): + def test_exc_info_is_clear_after_gateway_startup(self, gw: Gateway) -> None: ch = gw.remote_exec( """ import traceback, sys @@ -61,7 +66,7 @@ def test_exc_info_is_clear_after_gateway_startup(self, gw): if res != 0: pytest.fail("remote raised\n%s" % res) - def test_gateway_status_no_real_channel(self, gw): + def test_gateway_status_no_real_channel(self, gw: Gateway) -> None: numchan = gw._channelfactory.channels() gw.remote_status() numchan2 = gw._channelfactory.channels() @@ -71,7 +76,7 @@ def test_gateway_status_no_real_channel(self, gw): assert numchan2 == numchan @flakytest - def test_gateway_status_busy(self, gw): + def test_gateway_status_busy(self, gw: Gateway) -> None: numchannels = gw.remote_status().numchannels ch1 = gw.remote_exec("channel.send(1); channel.receive()") ch2 = gw.remote_exec("channel.receive()") @@ -92,7 +97,7 @@ def test_gateway_status_busy(self, gw): # race condition assert status.numchannels <= numchannels - def test_remote_exec_module(self, tmp_path, gw): + def test_remote_exec_module(self, tmp_path: pathlib.Path, gw: Gateway) -> None: p = tmp_path / "remotetest.py" p.write_text("channel.send(1)") mod = type(os)("remotetest") @@ -105,7 +110,9 @@ def test_remote_exec_module(self, tmp_path, gw): name = channel.receive() assert name == 2 - def test_remote_exec_module_is_removed(self, gw, tmp_path, monkeypatch): + def test_remote_exec_module_is_removed( + self, gw: Gateway, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch + ) -> None: remotetest = tmp_path / "remote.py" remotetest.write_text( dedent( @@ -122,7 +129,7 @@ def remote(): ) monkeypatch.syspath_prepend(tmp_path) - import remote + import remote # type: ignore[import-not-found] ch = gw.remote_exec(remote) # simulate sending the code to a remote location that does not have @@ -136,7 +143,12 @@ def remote(): assert result is True - def test_remote_exec_module_with_traceback(self, gw, tmp_path, monkeypatch): + def test_remote_exec_module_with_traceback( + self, + gw: Gateway, + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: remotetestpy = tmp_path / "remotetest.py" remotetestpy.write_text( dedent( @@ -151,7 +163,7 @@ def run_me(channel=None): ) monkeypatch.syspath_prepend(tmp_path) - import remotetest + import remotetest # type: ignore[import-not-found] ch = gw.remote_exec(remotetest) try: @@ -171,7 +183,7 @@ def run_me(channel=None): finally: ch.close() - def test_correct_setup_no_py(self, gw): + def test_correct_setup_no_py(self, gw: Gateway) -> None: channel = gw.remote_exec( """ import sys @@ -179,32 +191,33 @@ def test_correct_setup_no_py(self, gw): """ ) remotemodules = channel.receive() + assert isinstance(remotemodules, list) assert "py" not in remotemodules, "py should not be imported on remote side" - def test_remote_exec_waitclose(self, gw): + def test_remote_exec_waitclose(self, gw: Gateway) -> None: channel = gw.remote_exec("pass") channel.waitclose(TESTTIMEOUT) - def test_remote_exec_waitclose_2(self, gw): + def test_remote_exec_waitclose_2(self, gw: Gateway) -> None: channel = gw.remote_exec("def gccycle(): pass") channel.waitclose(TESTTIMEOUT) - def test_remote_exec_waitclose_noarg(self, gw): + def test_remote_exec_waitclose_noarg(self, gw: Gateway) -> None: channel = gw.remote_exec("pass") channel.waitclose() - def test_remote_exec_error_after_close(self, gw): + def test_remote_exec_error_after_close(self, gw: Gateway) -> None: channel = gw.remote_exec("pass") channel.waitclose(TESTTIMEOUT) pytest.raises(IOError, channel.send, 0) - def test_remote_exec_no_explicit_close(self, gw): + def test_remote_exec_no_explicit_close(self, gw: Gateway) -> None: channel = gw.remote_exec("channel.close()") with pytest.raises(channel.RemoteError) as excinfo: channel.waitclose(TESTTIMEOUT) assert "explicit" in excinfo.value.formatted - def test_remote_exec_channel_anonymous(self, gw): + def test_remote_exec_channel_anonymous(self, gw: Gateway) -> None: channel = gw.remote_exec( """ obj = channel.receive() @@ -216,7 +229,7 @@ def test_remote_exec_channel_anonymous(self, gw): assert result == 42 @needs_osdup - def test_confusion_from_os_write_stdout(self, gw): + def test_confusion_from_os_write_stdout(self, gw: Gateway) -> None: channel = gw.remote_exec( """ import os @@ -233,7 +246,7 @@ def test_confusion_from_os_write_stdout(self, gw): assert res == 42 @needs_osdup - def test_confusion_from_os_write_stderr(self, gw): + def test_confusion_from_os_write_stderr(self, gw: Gateway) -> None: channel = gw.remote_exec( """ import os @@ -249,7 +262,7 @@ def test_confusion_from_os_write_stderr(self, gw): res = channel.receive() assert res == 42 - def test__rinfo(self, gw): + def test__rinfo(self, gw: Gateway) -> None: rinfo = gw._rinfo() assert rinfo.executable assert rinfo.cwd @@ -276,20 +289,23 @@ def test__rinfo(self, gw): class TestPopenGateway: gwtype = "popen" - def test_chdir_separation(self, tmp_path, makegateway): + def test_chdir_separation( + self, tmp_path: pathlib.Path, makegateway: Callable[[str], Gateway] + ) -> None: with pytest.MonkeyPatch.context() as mp: mp.chdir(tmp_path) gw = makegateway("popen") c = gw.remote_exec("import os ; channel.send(os.getcwd())") x = c.receive() + assert isinstance(x, str) assert x.lower() == str(tmp_path).lower() - def test_remoteerror_readable_traceback(self, gw): + def test_remoteerror_readable_traceback(self, gw: Gateway) -> None: with pytest.raises(gateway_base.RemoteError) as e: gw.remote_exec("x y").waitclose() assert "gateway_base" in e.value.formatted - def test_many_popen(self, makegateway): + def test_many_popen(self, makegateway: Callable[[str], Gateway]) -> None: num = 4 l = [] for i in range(num): @@ -303,13 +319,15 @@ def test_many_popen(self, makegateway): ret = channel.receive() assert ret == 42 - def test_rinfo_popen(self, gw): + def test_rinfo_popen(self, gw: Gateway) -> None: rinfo = gw._rinfo() assert rinfo.executable == sys.executable assert rinfo.cwd == os.getcwd() assert rinfo.version_info == sys.version_info - def test_waitclose_on_remote_killed(self, makegateway): + def test_waitclose_on_remote_killed( + self, makegateway: Callable[[str], Gateway] + ) -> None: gw = makegateway("popen") channel = gw.remote_exec( """ @@ -320,6 +338,7 @@ def test_waitclose_on_remote_killed(self, makegateway): """ ) remotepid = channel.receive() + assert isinstance(remotepid, int) os.kill(remotepid, signal.SIGTERM) with pytest.raises(EOFError): channel.waitclose(TESTTIMEOUT) @@ -328,7 +347,7 @@ def test_waitclose_on_remote_killed(self, makegateway): with pytest.raises(EOFError): channel.receive() - def test_receive_on_remote_sysexit(self, gw): + def test_receive_on_remote_sysexit(self, gw: Gateway) -> None: channel = gw.remote_exec( """ raise SystemExit() @@ -336,7 +355,7 @@ def test_receive_on_remote_sysexit(self, gw): ) pytest.raises(channel.RemoteError, channel.receive) - def test_dont_write_bytecode(self, makegateway): + def test_dont_write_bytecode(self, makegateway: Callable[[str], Gateway]) -> None: check_sys_dont_write_bytecode = """ import sys channel.send(sys.dont_write_bytecode) @@ -353,36 +372,41 @@ def test_dont_write_bytecode(self, makegateway): @pytest.mark.skipif("config.option.broken_isp") -def test_socket_gw_host_not_found(gw, makegateway): - pytest.raises(execnet.HostNotFound, lambda: makegateway("socket=qwepoipqwe:9000")) +def test_socket_gw_host_not_found(makegateway: Callable[[str], Gateway]) -> None: + with pytest.raises(execnet.HostNotFound): + makegateway("socket=qwepoipqwe:9000") class TestSshPopenGateway: gwtype = "ssh" - def test_sshconfig_config_parsing(self, monkeypatch, makegateway): + def test_sshconfig_config_parsing( + self, monkeypatch: pytest.MonkeyPatch, makegateway: Callable[[str], Gateway] + ) -> None: l = [] monkeypatch.setattr( gateway_io, "Popen2IOMaster", lambda *args, **kwargs: l.append(args[0]) ) - pytest.raises(AttributeError, lambda: makegateway("ssh=xyz//ssh_config=qwe")) + with pytest.raises(AttributeError): + makegateway("ssh=xyz//ssh_config=qwe") assert len(l) == 1 popen_args = l[0] i = popen_args.index("-F") assert popen_args[i + 1] == "qwe" - def test_sshaddress(self, gw, specssh): + def test_sshaddress(self, gw: Gateway, specssh: execnet.XSpec) -> None: assert gw.remoteaddress == specssh.ssh - def test_host_not_found(self, gw, makegateway): - pytest.raises( - execnet.HostNotFound, lambda: makegateway("ssh=nowhere.codespeak.net") - ) + def test_host_not_found( + self, gw: Gateway, makegateway: Callable[[str], Gateway] + ) -> None: + with pytest.raises(execnet.HostNotFound): + makegateway("ssh=nowhere.codespeak.net") class TestThreads: - def test_threads(self, makegateway): + def test_threads(self, makegateway: Callable[[str], Gateway]) -> None: gw = makegateway("popen") gw.remote_init_threads(3) c1 = gw.remote_exec("channel.send(channel.receive())") @@ -394,7 +418,7 @@ def test_threads(self, makegateway): res = c1.receive() assert res == 42 - def test_threads_race_sending(self, makegateway): + def test_threads_race_sending(self, makegateway: Callable[[str], Gateway]) -> None: # multiple threads sending data in parallel gw = makegateway("popen") num = 5 @@ -418,7 +442,7 @@ def test_threads_race_sending(self, makegateway): ch.waitclose(TESTTIMEOUT) @flakytest - def test_status_with_threads(self, makegateway): + def test_status_with_threads(self, makegateway: Callable[[str], Gateway]) -> None: gw = makegateway("popen") c1 = gw.remote_exec("channel.send(1) ; channel.receive()") c2 = gw.remote_exec("channel.send(2) ; channel.receive()") @@ -441,7 +465,12 @@ def test_status_with_threads(self, makegateway): class TestTracing: - def test_popen_filetracing(self, tmp_path, monkeypatch, makegateway): + def test_popen_filetracing( + self, + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, + makegateway: Callable[[str], Gateway], + ) -> None: monkeypatch.setenv("TMP", str(tmp_path)) monkeypatch.setenv("TEMP", str(tmp_path)) # windows monkeypatch.setenv("EXECNET_DEBUG", "1") @@ -450,6 +479,7 @@ def test_popen_filetracing(self, tmp_path, monkeypatch, makegateway): fn = gw.remote_exec( "import execnet;channel.send(execnet.gateway_base.fn)" ).receive() + assert isinstance(fn, str) workerfile = pathlib.Path(fn) assert workerfile.exists() worker_line = "creating workergateway" @@ -463,7 +493,12 @@ def test_popen_filetracing(self, tmp_path, monkeypatch, makegateway): @skip_win_pypy @flakytest - def test_popen_stderr_tracing(self, capfd, monkeypatch, makegateway): + def test_popen_stderr_tracing( + self, + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, + makegateway: Callable[[str], Gateway], + ) -> None: monkeypatch.setenv("EXECNET_DEBUG", "2") gw = makegateway("popen") pid = gw.remote_exec("import os ; channel.send(os.getpid())").receive() @@ -493,8 +528,8 @@ def test_no_tracing_by_default(self): ('popen//python="/u/test me/python" -e', ["/u/test me/python", "-e"]), ], ) -def test_popen_args(spec, expected_args): - expected_args = expected_args + ["-u", "-c", gateway_io.popen_bootstrapline] +def test_popen_args(spec: str, expected_args: list[str]) -> None: + expected_args = [*expected_args, "-u", "-c", gateway_io.popen_bootstrapline] args = gateway_io.popen_args(execnet.XSpec(spec)) assert args == expected_args @@ -512,16 +547,101 @@ def test_popen_args(spec, expected_args): ), ], ) -def test_regression_gevent_hangs(group, interleave_getstatus): +def test_regression_gevent_hangs( + group: execnet.Group, interleave_getstatus: bool +) -> None: pytest.importorskip("gevent") gw = group.makegateway("popen//execmodel=gevent") print(gw.remote_status()) - def sendback(channel): + def sendback(channel) -> None: channel.send(1234) ch = gw.remote_exec(sendback) if interleave_getstatus: print(gw.remote_status()) assert ch.receive(timeout=0.5) == 1234 + + +def test_assert_main_thread_only( + execmodel: gateway_base.ExecModel, makegateway: Callable[[str], Gateway] +) -> None: + if execmodel.backend != "main_thread_only": + pytest.skip("can only run with main_thread_only") + + gw = makegateway(f"execmodel={execmodel.backend}//popen") + + try: + # Submit multiple remote_exec requests in quick succession and + # assert that all tasks execute in the main thread. It is + # necessary to call receive on each channel before the next + # remote_exec call, since the channel will raise an error if + # concurrent remote_exec requests are submitted as in + # test_main_thread_only_concurrent_remote_exec_deadlock. + for i in range(10): + ch = gw.remote_exec( + """ + import time, threading + time.sleep(0.02) + channel.send(threading.current_thread() is threading.main_thread()) + """ + ) + + try: + res = ch.receive() + finally: + ch.close() + # This doesn't actually block because we closed + # the channel already, but it does check for remote + # errors and raise them. + ch.waitclose() + if res is not True: + pytest.fail("remote raised\n%s" % res) + finally: + gw.exit() + gw.join() + + +def test_main_thread_only_concurrent_remote_exec_deadlock( + execmodel: gateway_base.ExecModel, makegateway: Callable[[str], Gateway] +) -> None: + if execmodel.backend != "main_thread_only": + pytest.skip("can only run with main_thread_only") + + gw = makegateway(f"execmodel={execmodel.backend}//popen") + channels = [] + try: + # Submit multiple remote_exec requests in quick succession and + # assert that MAIN_THREAD_ONLY_DEADLOCK_TEXT is raised if + # concurrent remote_exec requests are submitted for the + # main_thread_only execmodel (as compensation for the lack of + # back pressure in remote_exec calls which do not attempt to + # block until the remote main thread is idle). + for i in range(2): + channels.append( + gw.remote_exec( + """ + import threading + channel.send(threading.current_thread() is threading.main_thread()) + # Wait forever, ensuring that the deadlock case triggers. + channel.gateway.execmodel.Event().wait() + """ + ) + ) + + expected_results = ( + True, + execnet.gateway_base.MAIN_THREAD_ONLY_DEADLOCK_TEXT, + ) + for expected, ch in zip(expected_results, channels): + try: + res = ch.receive() + except execnet.RemoteError as e: + res = e.formatted + assert res == expected + finally: + for ch in channels: + ch.close() + gw.exit() + gw.join() diff --git a/testing/test_multi.py b/testing/test_multi.py index 79861b11..edc93037 100644 --- a/testing/test_multi.py +++ b/testing/test_multi.py @@ -1,19 +1,27 @@ """ - tests for multi channels and gateway Groups +tests for multi channels and gateway Groups """ + +from __future__ import annotations + import gc from time import sleep +from typing import Callable import execnet import pytest from execnet import XSpec +from execnet.gateway import Gateway from execnet.gateway_base import Channel +from execnet.gateway_base import ExecModel from execnet.multi import Group from execnet.multi import safe_terminate class TestMultiChannelAndGateway: - def test_multichannel_container_basics(self, gw, execmodel): + def test_multichannel_container_basics( + self, gw: Gateway, execmodel: ExecModel + ) -> None: mch = execnet.MultiChannel([Channel(gw, i) for i in range(3)]) assert len(mch) == 3 channels = list(mch) @@ -26,21 +34,21 @@ def test_multichannel_container_basics(self, gw, execmodel): assert channels[1] in mch assert channels[2] in mch - def test_multichannel_receive_each(self): + def test_multichannel_receive_each(self) -> None: class pseudochannel: - def receive(self): + def receive(self) -> object: return 12 pc1 = pseudochannel() pc2 = pseudochannel() - multichannel = execnet.MultiChannel([pc1, pc2]) + multichannel = execnet.MultiChannel([pc1, pc2]) # type: ignore[list-item] l = multichannel.receive_each(withchannel=True) assert len(l) == 2 - assert l == [(pc1, 12), (pc2, 12)] - l = multichannel.receive_each(withchannel=False) - assert l == [12, 12] + assert l == [(pc1, 12), (pc2, 12)] # type: ignore[comparison-overlap] + l2 = multichannel.receive_each(withchannel=False) + assert l2 == [12, 12] - def test_multichannel_send_each(self): + def test_multichannel_send_each(self) -> None: gm = execnet.Group(["popen"] * 2) mc = gm.remote_exec( """ @@ -52,12 +60,12 @@ def test_multichannel_send_each(self): l = mc.receive_each() assert l == [42, 42] - def test_Group_execmodel_setting(self): + def test_Group_execmodel_setting(self) -> None: gm = execnet.Group() gm.set_execmodel("thread") assert gm.execmodel.backend == "thread" assert gm.remote_execmodel.backend == "thread" - gm._gateways.append(1) + gm._gateways.append(1) # type: ignore[arg-type] try: with pytest.raises(ValueError): gm.set_execmodel("eventlet") @@ -65,7 +73,7 @@ def test_Group_execmodel_setting(self): finally: gm._gateways.pop() - def test_multichannel_receive_queue_for_two_subprocesses(self): + def test_multichannel_receive_queue_for_two_subprocesses(self) -> None: gm = execnet.Group(["popen"] * 2) mc = gm.remote_exec( """ @@ -81,23 +89,23 @@ def test_multichannel_receive_queue_for_two_subprocesses(self): assert item != item2 mc.waitclose() - def test_multichannel_waitclose(self): + def test_multichannel_waitclose(self) -> None: l = [] class pseudochannel: - def waitclose(self): + def waitclose(self) -> None: l.append(0) - multichannel = execnet.MultiChannel([pseudochannel(), pseudochannel()]) + multichannel = execnet.MultiChannel([pseudochannel(), pseudochannel()]) # type: ignore[list-item] multichannel.waitclose() assert len(l) == 2 class TestGroup: - def test_basic_group(self, monkeypatch): + def test_basic_group(self, monkeypatch: pytest.MonkeyPatch) -> None: import atexit - atexitlist = [] + atexitlist: list[Callable[[], object]] = [] monkeypatch.setattr(atexit, "register", atexitlist.append) group = Group() assert atexitlist == [group._cleanup_atexit] @@ -105,7 +113,7 @@ def test_basic_group(self, monkeypatch): joinlist = [] class PseudoIO: - def wait(self): + def wait(self) -> None: pass class PseudoSpec: @@ -116,15 +124,15 @@ class PseudoGW: _io = PseudoIO() spec = PseudoSpec() - def exit(self): + def exit(self) -> None: exitlist.append(self) - group._unregister(self) + group._unregister(self) # type: ignore[arg-type] - def join(self): + def join(self) -> None: joinlist.append(self) gw = PseudoGW() - group._register(gw) + group._register(gw) # type: ignore[arg-type] assert len(exitlist) == 0 assert len(joinlist) == 0 group._cleanup_atexit() @@ -136,12 +144,12 @@ def join(self): assert len(exitlist) == 1 assert len(joinlist) == 1 - def test_group_default_spec(self): + def test_group_default_spec(self) -> None: group = Group() group.defaultspec = "not-existing-type" pytest.raises(ValueError, group.makegateway) - def test_group_PopenGateway(self): + def test_group_PopenGateway(self) -> None: group = Group() gw = group.makegateway("popen") assert list(group) == [gw] @@ -150,7 +158,7 @@ def test_group_PopenGateway(self): group._cleanup_atexit() assert not group._gateways - def test_group_ordering_and_termination(self): + def test_group_ordering_and_termination(self) -> None: group = Group() group.makegateway("popen//id=3") group.makegateway("popen//id=2") @@ -165,7 +173,7 @@ def test_group_ordering_and_termination(self): assert not group assert repr(group) == "" - def test_group_id_allocation(self): + def test_group_id_allocation(self) -> None: group = Group() specs = [XSpec("popen"), XSpec("popen//id=hello")] group.allocate_id(specs[0]) @@ -178,14 +186,14 @@ def test_group_id_allocation(self): # group.allocate_id, XSpec("popen//id=hello")) group.terminate() - def test_gateway_and_id(self): + def test_gateway_and_id(self) -> None: group = Group() gw = group.makegateway("popen//id=hello") assert group["hello"] == gw with pytest.raises((TypeError, AttributeError)): - del group["hello"] + del group["hello"] # type: ignore[attr-defined] with pytest.raises((TypeError, AttributeError)): - group["hello"] = 5 + group["hello"] = 5 # type: ignore[index] assert "hello" in group assert gw in group assert len(group) == 1 @@ -194,7 +202,7 @@ def test_gateway_and_id(self): with pytest.raises(KeyError): _ = group["hello"] - def test_default_group(self): + def test_default_group(self) -> None: oldlist = list(execnet.default_group) gw = execnet.makegateway("popen") try: @@ -205,26 +213,27 @@ def test_default_group(self): finally: gw.exit() - def test_remote_exec_args(self): + def test_remote_exec_args(self) -> None: group = Group() group.makegateway("popen") - def fun(channel, arg): + def fun(channel, arg) -> None: channel.send(arg) mch = group.remote_exec(fun, arg=1) result = mch.receive_each() assert result == [1] - def test_terminate_with_proxying(self): + def test_terminate_with_proxying(self) -> None: group = Group() group.makegateway("popen//id=master") group.makegateway("popen//via=master//id=worker") group.terminate(1.0) -def test_safe_terminate(execmodel): - if execmodel.backend != "threading": +@pytest.mark.xfail(reason="active_count() has been broken for some time") +def test_safe_terminate(execmodel: ExecModel) -> None: + if execmodel.backend not in ("thread", "main_thread_only"): pytest.xfail( "execution model %r does not support task count" % execmodel.backend ) @@ -233,21 +242,22 @@ def test_safe_terminate(execmodel): active = threading.active_count() l = [] - def term(): + def term() -> None: sleep(3) - def kill(): + def kill() -> None: l.append(1) safe_terminate(execmodel, 1, [(term, kill)] * 10) assert len(l) == 10 sleep(0.1) gc.collect() - assert execmodel.active_count() == active + assert execmodel.active_count() == active # type: ignore[attr-defined] -def test_safe_terminate2(execmodel): - if execmodel.backend != "threading": +@pytest.mark.xfail(reason="active_count() has been broken for some time") +def test_safe_terminate2(execmodel: ExecModel) -> None: + if execmodel.backend not in ("thread", "main_thread_only"): pytest.xfail( "execution model %r does not support task count" % execmodel.backend ) @@ -256,10 +266,10 @@ def test_safe_terminate2(execmodel): active = threading.active_count() l = [] - def term(): + def term() -> None: return - def kill(): + def kill() -> None: l.append(1) safe_terminate(execmodel, 3, [(term, kill)] * 10) diff --git a/testing/test_rsync.py b/testing/test_rsync.py index fefd9a06..6f40bc44 100644 --- a/testing/test_rsync.py +++ b/testing/test_rsync.py @@ -7,24 +7,25 @@ import execnet import pytest from execnet import RSync +from execnet.gateway import Gateway @pytest.fixture(scope="module") -def group(request): +def group(request: pytest.FixtureRequest) -> execnet.Group: group = execnet.Group() request.addfinalizer(group.terminate) return group @pytest.fixture(scope="module") -def gw1(request, group): +def gw1(request: pytest.FixtureRequest, group: execnet.Group) -> Gateway: gw = group.makegateway("popen//id=gw1") request.addfinalizer(gw.exit) return gw @pytest.fixture(scope="module") -def gw2(request, group): +def gw2(request: pytest.FixtureRequest, group: execnet.Group) -> Gateway: gw = group.makegateway("popen//id=gw2") request.addfinalizer(gw.exit) return gw @@ -44,7 +45,7 @@ class _dirs(types.SimpleNamespace): @pytest.fixture -def dirs(request, tmp_path) -> _dirs: +def dirs(tmp_path: pathlib.Path) -> _dirs: dirs = _dirs( source=tmp_path / "source", dest1=tmp_path / "dest1", @@ -56,7 +57,7 @@ def dirs(request, tmp_path) -> _dirs: return dirs -def are_paths_equal(path1, path2): +def are_paths_equal(path1: pathlib.Path, path2: pathlib.Path) -> bool: if os.path.__name__ == "ntpath": # On Windows, os.readlink returns an extended path (\\?\) # for absolute symlinks. However, extended does not compare @@ -72,13 +73,13 @@ def are_paths_equal(path1, path2): class TestRSync: - def test_notargets(self, dirs): + def test_notargets(self, dirs: _dirs) -> None: rsync = RSync(dirs.source) with pytest.raises(IOError): rsync.send() - assert rsync.send(raises=False) is None + assert rsync.send(raises=False) is None # type: ignore[func-returns-value] - def test_dirsync(self, dirs, gw1, gw2): + def test_dirsync(self, dirs: _dirs, gw1: Gateway, gw2: Gateway) -> None: dest = dirs.dest1 dest2 = dirs.dest2 source = dirs.source @@ -115,7 +116,7 @@ def test_dirsync(self, dirs, gw1, gw2): assert not dest.joinpath("subdir", "file1").exists() assert dest2.joinpath("subdir", "file1").exists() - def test_dirsync_twice(self, dirs, gw1, gw2): + def test_dirsync_twice(self, dirs: _dirs, gw1: Gateway, gw2: Gateway) -> None: source = dirs.source source.joinpath("hello").touch() rsync = RSync(source) @@ -124,15 +125,17 @@ def test_dirsync_twice(self, dirs, gw1, gw2): assert dirs.dest1.joinpath("hello").exists() with pytest.raises(IOError): rsync.send() - assert rsync.send(raises=False) is None + assert rsync.send(raises=False) is None # type: ignore[func-returns-value] rsync.add_target(gw1, dirs.dest2) rsync.send() assert dirs.dest2.joinpath("hello").exists() with pytest.raises(IOError): rsync.send() - assert rsync.send(raises=False) is None + assert rsync.send(raises=False) is None # type: ignore[func-returns-value] - def test_rsync_default_reporting(self, capsys, dirs, gw1): + def test_rsync_default_reporting( + self, capsys: pytest.CaptureFixture[str], dirs: _dirs, gw1: Gateway + ) -> None: source = dirs.source source.joinpath("hello").touch() rsync = RSync(source) @@ -141,7 +144,9 @@ def test_rsync_default_reporting(self, capsys, dirs, gw1): out, err = capsys.readouterr() assert out.find("hello") != -1 - def test_rsync_non_verbose(self, capsys, dirs, gw1): + def test_rsync_non_verbose( + self, capsys: pytest.CaptureFixture[str], dirs: _dirs, gw1: Gateway + ) -> None: source = dirs.source source.joinpath("hello").touch() rsync = RSync(source, verbose=False) @@ -155,7 +160,7 @@ def test_rsync_non_verbose(self, capsys, dirs, gw1): sys.platform == "win32" or getattr(os, "_name", "") == "nt", reason="irrelevant on windows", ) - def test_permissions(self, dirs, gw1, gw2): + def test_permissions(self, dirs: _dirs, gw1: Gateway, gw2: Gateway) -> None: source = dirs.source dest = dirs.dest1 onedir = dirs.source / "one" @@ -194,7 +199,7 @@ def test_permissions(self, dirs, gw1, gw2): sys.platform == "win32" or getattr(os, "_name", "") == "nt", reason="irrelevant on windows", ) - def test_read_only_directories(self, dirs, gw1): + def test_read_only_directories(self, dirs: _dirs, gw1: Gateway) -> None: source = dirs.source dest = dirs.dest1 sub = source / "sub" @@ -214,7 +219,7 @@ def test_read_only_directories(self, dirs, gw1): assert dest.joinpath("sub", "subsub").stat().st_mode & 0o700 @needssymlink - def test_symlink_rsync(self, dirs, gw1): + def test_symlink_rsync(self, dirs: _dirs, gw1: Gateway) -> None: source = dirs.source dest = dirs.dest1 subdir = dirs.source / "subdir" @@ -236,7 +241,7 @@ def test_symlink_rsync(self, dirs, gw1): assert are_paths_equal(abslink, expected) @needssymlink - def test_symlink2_rsync(self, dirs, gw1): + def test_symlink2_rsync(self, dirs: _dirs, gw1: Gateway) -> None: source = dirs.source dest = dirs.dest1 subdir = dirs.source / "subdir" @@ -261,7 +266,7 @@ def test_symlink2_rsync(self, dirs, gw1): link3 = pathlib.Path(os.readlink(str(destsub / "link3"))) assert are_paths_equal(link3, source.parent) - def test_callback(self, dirs, gw1): + def test_callback(self, dirs: _dirs, gw1: Gateway) -> None: dest = dirs.dest1 source = dirs.source source.joinpath("existent").write_text("a" * 100) @@ -278,15 +283,15 @@ def callback(cmd, lgt, channel): assert total == {("list", 110): True, ("ack", 100): True, ("ack", 10): True} - def test_file_disappearing(self, dirs, gw1): + def test_file_disappearing(self, dirs: _dirs, gw1: Gateway) -> None: dest = dirs.dest1 source = dirs.source source.joinpath("ex").write_text("a" * 100) source.joinpath("ex2").write_text("a" * 100) class DRsync(RSync): - def filter(self, x): - assert x != source + def filter(self, x: str) -> bool: + assert x != str(source) if x.endswith("ex2"): self.x = 1 source.joinpath("ex2").unlink() diff --git a/testing/test_serializer.py b/testing/test_serializer.py index fed4a5de..aff41a98 100644 --- a/testing/test_serializer.py +++ b/testing/test_serializer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess import sys @@ -6,13 +8,12 @@ import execnet import pytest - # We use the execnet folder in order to avoid triggering a missing apipkg. pyimportdir = os.fspath(Path(execnet.__file__).parent) class PythonWrapper: - def __init__(self, executable, tmp_path): + def __init__(self, executable: str, tmp_path: Path) -> None: self.executable = executable self.tmp_path = tmp_path @@ -32,7 +33,7 @@ def dump(self, obj_rep: str) -> bytes: ) return res.stdout - def load(self, data: bytes): + def load(self, data: bytes) -> list[str]: script_file = self.tmp_path.joinpath("load.py") script_file.write_text( rf""" @@ -51,28 +52,27 @@ def load(self, data: bytes): res = subprocess.run( [str(self.executable), str(script_file)], capture_output=True, + check=True, ) - if res.returncode: - raise ValueError(res.stderr) return res.stdout.decode("ascii").splitlines() - def __repr__(self): + def __repr__(self) -> str: return f"" @pytest.fixture -def py3(request, tmp_path): +def py3(tmp_path: Path) -> PythonWrapper: return PythonWrapper(sys.executable, tmp_path) @pytest.fixture -def dump(py3): +def dump(py3: PythonWrapper): return py3.dump @pytest.fixture -def load(py3): +def load(py3: PythonWrapper): return py3.load @@ -88,14 +88,14 @@ def load(py3): @pytest.mark.parametrize(["tp_name", "repr"], simple_tests) -def test_simple(tp_name, repr, dump, load): +def test_simple(tp_name, repr, dump, load) -> None: p = dump(repr) tp, v = load(p) assert tp == tp_name assert v == repr -def test_set(load, dump): +def test_set(load, dump) -> None: p = dump("set((1, 2, 3))") tp, v = load(p) @@ -116,7 +116,7 @@ def test_frozenset(load, dump): assert v == "frozenset({1, 2, 3})" -def test_long(load, dump): +def test_long(load, dump) -> None: really_big = "9223372036854775807324234" p = dump(really_big) tp, v = load(p) @@ -124,42 +124,47 @@ def test_long(load, dump): assert v == really_big -def test_bytes(dump, load): +def test_bytes(dump, load) -> None: p = dump("b'hi'") tp, v = load(p) assert tp == "bytes" assert v == "b'hi'" -def test_str(dump, load): +def test_str(dump, load) -> None: p = dump("'xyz'") tp, s = load(p) assert tp == "str" assert s == "'xyz'" -def test_unicode(load, dump): +def test_unicode(load, dump) -> None: p = dump("u'hi'") tp, s = load(p) assert tp == "str" assert s == "'hi'" -def test_bool(dump, load): +def test_bool(dump, load) -> None: p = dump("True") tp, s = load(p) assert s == "True" assert tp == "bool" -def test_none(dump, load): +def test_none(dump, load) -> None: p = dump("None") tp, s = load(p) assert s == "None" -def test_tuple_nested_with_empty_in_between(dump, load): +def test_tuple_nested_with_empty_in_between(dump, load) -> None: p = dump("(1, (), 3)") tp, s = load(p) assert tp == "tuple" assert s == "(1, (), 3)" + + +def test_py2_string_loads() -> None: + """Regression test for #267.""" + assert execnet.loads(b"\x02M\x00\x00\x00\x01aQ") == b"a" diff --git a/testing/test_termination.py b/testing/test_termination.py index 282c1979..44bd1a52 100644 --- a/testing/test_termination.py +++ b/testing/test_termination.py @@ -4,9 +4,13 @@ import signal import subprocess import sys +from typing import Callable import execnet import pytest +from execnet.gateway import Gateway +from execnet.gateway_base import ExecModel +from execnet.gateway_base import WorkerPool from test_gateway import TESTTIMEOUT execnetdir = pathlib.Path(execnet.__file__).parent.parent @@ -17,7 +21,9 @@ ) -def test_exit_blocked_worker_execution_gateway(anypython, makegateway, pool): +def test_exit_blocked_worker_execution_gateway( + anypython: str, makegateway: Callable[[str], Gateway], pool: WorkerPool +) -> None: gateway = makegateway("popen//python=%s" % anypython) gateway.remote_exec( """ @@ -26,7 +32,7 @@ def test_exit_blocked_worker_execution_gateway(anypython, makegateway, pool): """ ) - def doit(): + def doit() -> int: gateway.exit() return 17 @@ -35,8 +41,10 @@ def doit(): assert x == 17 -def test_endmarker_delivery_on_remote_killterm(makegateway, execmodel): - if execmodel.backend != "thread": +def test_endmarker_delivery_on_remote_killterm( + makegateway: Callable[[str], Gateway], execmodel: ExecModel +) -> None: + if execmodel.backend not in ("thread", "main_thread_only"): pytest.xfail("test and execnet not compatible to greenlets yet") gw = makegateway("popen") q = execmodel.queue.Queue() @@ -48,6 +56,7 @@ def test_endmarker_delivery_on_remote_killterm(makegateway, execmodel): """ ) pid = channel.receive() + assert isinstance(pid, int) os.kill(pid, signal.SIGTERM) channel.setcallback(q.put, endmarker=999) val = q.get(TESTTIMEOUT) @@ -57,7 +66,9 @@ def test_endmarker_delivery_on_remote_killterm(makegateway, execmodel): @skip_win_pypy -def test_termination_on_remote_channel_receive(monkeypatch, makegateway): +def test_termination_on_remote_channel_receive( + monkeypatch: pytest.MonkeyPatch, makegateway: Callable[[str], Gateway] +) -> None: if not shutil.which("ps"): pytest.skip("need 'ps' command to externally check process status") monkeypatch.setenv("EXECNET_DEBUG", "2") @@ -66,12 +77,14 @@ def test_termination_on_remote_channel_receive(monkeypatch, makegateway): gw.remote_exec("channel.receive()") gw._group.terminate() command = ["ps", "-p", str(pid)] - output = subprocess.run(command, capture_output=True, text=True) + output = subprocess.run(command, capture_output=True, text=True, check=False) assert str(pid) not in output.stdout, output -def test_close_initiating_remote_no_error(testdir, anypython): - p = testdir.makepyfile( +def test_close_initiating_remote_no_error( + pytester: pytest.Pytester, anypython: str +) -> None: + p = pytester.makepyfile( """ import sys sys.path.insert(0, sys.argv[1]) @@ -86,20 +99,28 @@ def test_close_initiating_remote_no_error(testdir, anypython): """ ) popen = subprocess.Popen( - [str(anypython), str(p), str(execnetdir)], stdout=None, stderr=subprocess.PIPE + [anypython, str(p), str(execnetdir)], stdout=None, stderr=subprocess.PIPE ) out, err = popen.communicate() print(err) - err = err.decode("utf8") - lines = [x for x in err.splitlines() if "*sys-package" not in x] - # print (lines) + errstr = err.decode("utf8") + lines = [x for x in errstr.splitlines() if "*sys-package" not in x] assert not lines -def test_terminate_implicit_does_trykill(testdir, anypython, capfd, pool): - if pool.execmodel != "thread": +def test_terminate_implicit_does_trykill( + pytester: pytest.Pytester, + anypython: str, + capfd: pytest.CaptureFixture[str], + pool: WorkerPool, +) -> None: + if pool.execmodel.backend not in ("thread", "main_thread_only"): pytest.xfail("only os threading model supported") - p = testdir.makepyfile( + if sys.version_info >= (3, 12): + pytest.xfail( + "since python3.12 this test triggers RuntimeError: can't create new thread at interpreter shutdown" + ) + p = pytester.makepyfile( """ import sys sys.path.insert(0, %r) @@ -125,6 +146,7 @@ def flush(self): ) popen = subprocess.Popen([str(anypython), str(p)], stdout=subprocess.PIPE) # sync with start-up + assert popen.stdout is not None popen.stdout.readline() reply = pool.spawn(popen.communicate) reply.get(timeout=50) diff --git a/testing/test_threadpool.py b/testing/test_threadpool.py index 47a226d0..480282d9 100644 --- a/testing/test_threadpool.py +++ b/testing/test_threadpool.py @@ -1,11 +1,12 @@ import os -import sys +from pathlib import Path import pytest +from execnet.gateway_base import ExecModel from execnet.gateway_base import WorkerPool -def test_execmodel(execmodel, tmp_path): +def test_execmodel(execmodel: ExecModel, tmp_path: Path) -> None: assert execmodel.backend p = tmp_path / "somefile" p.write_text("content") @@ -15,22 +16,22 @@ def test_execmodel(execmodel, tmp_path): f.close() -def test_execmodel_basic_attrs(execmodel): +def test_execmodel_basic_attrs(execmodel: ExecModel) -> None: m = execmodel assert callable(m.start) assert m.get_ident() -def test_simple(pool): +def test_simple(pool: WorkerPool) -> None: reply = pool.spawn(lambda: 42) assert reply.get() == 42 -def test_some(pool, execmodel): +def test_some(pool: WorkerPool, execmodel: ExecModel) -> None: q = execmodel.queue.Queue() num = 4 - def f(i): + def f(i: int) -> None: q.put(i) while q.qsize(): execmodel.sleep(0.01) @@ -45,10 +46,10 @@ def f(i): assert len(pool._running) == 0 -def test_running_semnatics(pool, execmodel): +def test_running_semnatics(pool: WorkerPool, execmodel: ExecModel) -> None: q = execmodel.queue.Queue() - def first(): + def first() -> None: q.get() reply = pool.spawn(first) @@ -60,7 +61,7 @@ def first(): assert not reply.running -def test_waitfinish_on_reply(pool): +def test_waitfinish_on_reply(pool: WorkerPool) -> None: l = [] reply = pool.spawn(lambda: l.append(1)) reply.waitfinish() @@ -71,20 +72,20 @@ def test_waitfinish_on_reply(pool): @pytest.mark.xfail(reason="WorkerPool does not implement limited size") -def test_limited_size(execmodel): - pool = WorkerPool(execmodel, size=1) +def test_limited_size(execmodel: ExecModel) -> None: + pool = WorkerPool(execmodel, size=1) # type: ignore[call-arg] q = execmodel.queue.Queue() q2 = execmodel.queue.Queue() q3 = execmodel.queue.Queue() - def first(): + def first() -> None: q.put(1) q2.get() pool.spawn(first) assert q.get() == 1 - def second(): + def second() -> None: q3.put(3) # we spawn a second pool to spawn the second function @@ -98,8 +99,8 @@ def second(): assert pool.waitall() -def test_get(pool): - def f(): +def test_get(pool: WorkerPool) -> None: + def f() -> int: return 42 reply = pool.spawn(f) @@ -107,8 +108,8 @@ def f(): assert result == 42 -def test_get_timeout(execmodel, pool): - def f(): +def test_get_timeout(execmodel: ExecModel, pool: WorkerPool) -> None: + def f() -> int: execmodel.sleep(0.2) return 42 @@ -117,8 +118,8 @@ def f(): reply.get(timeout=0.01) -def test_get_excinfo(pool): - def f(): +def test_get_excinfo(pool: WorkerPool) -> None: + def f() -> None: raise ValueError("42") reply = pool.spawn(f) @@ -128,10 +129,10 @@ def f(): reply.get(1.0) -def test_waitall_timeout(pool, execmodel): +def test_waitall_timeout(pool: WorkerPool, execmodel: ExecModel) -> None: q = execmodel.queue.Queue() - def f(): + def f() -> None: q.get() reply = pool.spawn(f) @@ -142,10 +143,12 @@ def f(): @pytest.mark.skipif(not hasattr(os, "dup"), reason="no os.dup") -def test_pool_clean_shutdown(pool, capfd): +def test_pool_clean_shutdown( + pool: WorkerPool, capfd: pytest.CaptureFixture[str] +) -> None: q = pool.execmodel.queue.Queue() - def f(): + def f() -> None: q.get() pool.spawn(f) @@ -154,7 +157,7 @@ def f(): with pytest.raises(ValueError): pool.spawn(f) - def wait_then_put(): + def wait_then_put() -> None: pool.execmodel.sleep(0.1) q.put(1) @@ -164,21 +167,21 @@ def wait_then_put(): assert err == "" -def test_primary_thread_integration(execmodel): - if execmodel.backend != "thread": +def test_primary_thread_integration(execmodel: ExecModel) -> None: + if execmodel.backend not in ("thread", "main_thread_only"): with pytest.raises(ValueError): WorkerPool(execmodel=execmodel, hasprimary=True) return pool = WorkerPool(execmodel=execmodel, hasprimary=True) queue = execmodel.queue.Queue() - def do_integrate(): + def do_integrate() -> None: queue.put(execmodel.get_ident()) pool.integrate_as_primary_thread() execmodel.start(do_integrate) - def func(): + def func() -> None: queue.put(execmodel.get_ident()) pool.spawn(func) @@ -188,13 +191,13 @@ def func(): pool.terminate() -def test_primary_thread_integration_shutdown(execmodel): - if execmodel.backend != "thread": +def test_primary_thread_integration_shutdown(execmodel: ExecModel) -> None: + if execmodel.backend not in ("thread", "main_thread_only"): pytest.skip("can only run with threading") pool = WorkerPool(execmodel=execmodel, hasprimary=True) queue = execmodel.queue.Queue() - def do_integrate(): + def do_integrate() -> None: queue.put(execmodel.get_ident()) pool.integrate_as_primary_thread() @@ -203,7 +206,7 @@ def do_integrate(): queue2 = execmodel.queue.Queue() - def get_two(): + def get_two() -> None: queue.put(execmodel.get_ident()) queue2.get() diff --git a/testing/test_xspec.py b/testing/test_xspec.py index 2218e736..4c9ff8d6 100644 --- a/testing/test_xspec.py +++ b/testing/test_xspec.py @@ -4,16 +4,17 @@ import shutil import subprocess import sys +from pathlib import Path +from typing import Callable import execnet import pytest +from execnet import XSpec +from execnet.gateway import Gateway from execnet.gateway_io import popen_args from execnet.gateway_io import ssh_args from execnet.gateway_io import vagrant_ssh_args -XSpec = execnet.XSpec - - skip_win_pypy = pytest.mark.xfail( condition=hasattr(sys, "pypy_version_info") and sys.platform.startswith("win"), reason="failing on Windows on PyPy (#63)", @@ -21,7 +22,7 @@ class TestXSpec: - def test_norm_attributes(self): + def test_norm_attributes(self) -> None: spec = XSpec( r"socket=192.168.102.2:8888//python=c:/this/python3.8//chdir=d:\hello" ) @@ -32,7 +33,7 @@ def test_norm_attributes(self): assert not hasattr(spec, "_xyz") with pytest.raises(AttributeError): - spec._hello() + spec._hello() # type: ignore[misc,operator] spec = XSpec("socket=192.168.102.2:8888//python=python2.5//nice=3") assert spec.socket == "192.168.102.2:8888" @@ -48,7 +49,7 @@ def test_norm_attributes(self): spec = XSpec("popen") assert spec.popen is True - def test_ssh_options(self): + def test_ssh_options(self) -> None: spec = XSpec("ssh=-p 22100 user@host//python=python3") assert spec.ssh == "-p 22100 user@host" assert spec.python == "python3" @@ -60,22 +61,22 @@ def test_ssh_options(self): assert spec.ssh == "-i ~/.ssh/id_rsa-passwordless_login -p 22100 user@host" assert spec.python == "python3" - def test_execmodel(self): + def test_execmodel(self) -> None: spec = XSpec("execmodel=thread") assert spec.execmodel == "thread" spec = XSpec("execmodel=eventlet") assert spec.execmodel == "eventlet" - def test_ssh_options_and_config(self): + def test_ssh_options_and_config(self) -> None: spec = XSpec("ssh=-p 22100 user@host//python=python3") spec.ssh_config = "/home/user/ssh_config" assert ssh_args(spec)[:6] == ["ssh", "-C", "-F", spec.ssh_config, "-p", "22100"] - def test_vagrant_options(self): + def test_vagrant_options(self) -> None: spec = XSpec("vagrant_ssh=default//python=python3") assert vagrant_ssh_args(spec)[:-1] == ["vagrant", "ssh", "default", "--", "-C"] - def test_popen_with_sudo_python(self): + def test_popen_with_sudo_python(self) -> None: spec = XSpec("popen//python=sudo python3") assert popen_args(spec) == [ "sudo", @@ -85,33 +86,33 @@ def test_popen_with_sudo_python(self): "import sys;exec(eval(sys.stdin.readline()))", ] - def test_env(self): + def test_env(self) -> None: xspec = XSpec("popen//env:NAME=value1") assert xspec.env["NAME"] == "value1" - def test__samefilesystem(self): + def test__samefilesystem(self) -> None: assert XSpec("popen")._samefilesystem() assert XSpec("popen//python=123")._samefilesystem() assert not XSpec("popen//chdir=hello")._samefilesystem() - def test__spec_spec(self): + def test__spec_spec(self) -> None: for x in ("popen", "popen//python=this"): assert XSpec(x)._spec == x - def test_samekeyword_twice_raises(self): + def test_samekeyword_twice_raises(self) -> None: pytest.raises(ValueError, XSpec, "popen//popen") pytest.raises(ValueError, XSpec, "popen//popen=123") - def test_unknown_keys_allowed(self): + def test_unknown_keys_allowed(self) -> None: xspec = XSpec("hello=3") assert xspec.hello == "3" - def test_repr_and_string(self): + def test_repr_and_string(self) -> None: for x in ("popen", "popen//python=this"): assert repr(XSpec(x)).find("popen") != -1 assert str(XSpec(x)) == x - def test_hash_equality(self): + def test_hash_equality(self) -> None: assert XSpec("popen") == XSpec("popen") assert hash(XSpec("popen")) == hash(XSpec("popen")) assert XSpec("popen//python=123") != XSpec("popen") @@ -119,11 +120,11 @@ def test_hash_equality(self): class TestMakegateway: - def test_no_type(self, makegateway): + def test_no_type(self, makegateway: Callable[[str], Gateway]) -> None: pytest.raises(ValueError, lambda: makegateway("hello")) @skip_win_pypy - def test_popen_default(self, makegateway): + def test_popen_default(self, makegateway: Callable[[str], Gateway]) -> None: gw = makegateway("") assert gw.spec.popen assert gw.spec.python is None @@ -134,10 +135,10 @@ def test_popen_default(self, makegateway): @pytest.mark.skipif("not hasattr(os, 'nice')") @pytest.mark.xfail(reason="fails due to timing problems on busy single-core VMs") - def test_popen_nice(self, makegateway): + def test_popen_nice(self, makegateway: Callable[[str], Gateway]) -> None: gw = makegateway("popen") - def getnice(channel): + def getnice(channel) -> None: import os if hasattr(os, "nice"): @@ -146,13 +147,14 @@ def getnice(channel): channel.send(None) remotenice = gw.remote_exec(getnice).receive() + assert isinstance(remotenice, int) gw.exit() if remotenice is not None: gw = makegateway("popen//nice=5") remotenice2 = gw.remote_exec(getnice).receive() assert remotenice2 == remotenice + 5 - def test_popen_env(self, makegateway): + def test_popen_env(self, makegateway: Callable[[str], Gateway]) -> None: gw = makegateway("popen//env:NAME123=123") ch = gw.remote_exec( """ @@ -164,7 +166,7 @@ def test_popen_env(self, makegateway): assert value == "123" @skip_win_pypy - def test_popen_explicit(self, makegateway): + def test_popen_explicit(self, makegateway: Callable[[str], Gateway]) -> None: gw = makegateway("popen//python=%s" % sys.executable) assert gw.spec.python == sys.executable rinfo = gw._rinfo() @@ -173,31 +175,32 @@ def test_popen_explicit(self, makegateway): assert rinfo.version_info == sys.version_info @skip_win_pypy - def test_popen_chdir_absolute(self, tmp_path, makegateway): + def test_popen_chdir_absolute( + self, tmp_path: Path, makegateway: Callable[[str], Gateway] + ) -> None: gw = makegateway("popen//chdir=%s" % tmp_path) rinfo = gw._rinfo() assert rinfo.cwd == str(tmp_path.resolve()) @skip_win_pypy - def test_popen_chdir_newsub(self, monkeypatch, tmp_path, makegateway): + def test_popen_chdir_newsub( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + makegateway: Callable[[str], Gateway], + ) -> None: monkeypatch.chdir(tmp_path) gw = makegateway("popen//chdir=hello") rinfo = gw._rinfo() expected = str(tmp_path.joinpath("hello").resolve()).lower() assert rinfo.cwd.lower() == expected - def test_ssh(self, specssh, makegateway): + def test_ssh(self, specssh: XSpec, makegateway: Callable[[str], Gateway]) -> None: sshhost = specssh.ssh gw = makegateway("ssh=%s//id=ssh1" % sshhost) - rinfo = gw._rinfo() assert gw.id == "ssh1" - gw2 = execnet.SshGateway(sshhost) - rinfo2 = gw2._rinfo() - assert rinfo.executable == rinfo2.executable - assert rinfo.cwd == rinfo2.cwd - assert rinfo.version_info == rinfo2.version_info - def test_vagrant(self, makegateway): + def test_vagrant(self, makegateway: Callable[[str], Gateway]) -> None: vagrant_bin = shutil.which("vagrant") if vagrant_bin is None: pytest.skip("Vagrant binary not in PATH") @@ -206,6 +209,7 @@ def test_vagrant(self, makegateway): capture_output=True, encoding="utf-8", errors="replace", + check=True, ).stdout print(res) if ",default,state,shutoff\n" in res: @@ -217,10 +221,12 @@ def test_vagrant(self, makegateway): gw = makegateway("vagrant_ssh=default//python=python3") rinfo = gw._rinfo() - rinfo.cwd == "/home/vagrant" - rinfo.executable == "/usr/bin/python" + assert rinfo.cwd == "/home/vagrant" + assert rinfo.executable == "/usr/bin/python" - def test_socket(self, specsocket, makegateway): + def test_socket( + self, specsocket: XSpec, makegateway: Callable[[str], Gateway] + ) -> None: gw = makegateway("socket=%s//id=sock1" % specsocket.socket) rinfo = gw._rinfo() assert rinfo.executable @@ -230,7 +236,9 @@ def test_socket(self, specsocket, makegateway): # we cannot instantiate a second gateway @pytest.mark.xfail(reason="we can't instantiate a second gateway") - def test_socket_second(self, specsocket, makegateway): + def test_socket_second( + self, specsocket: XSpec, makegateway: Callable[[str], Gateway] + ) -> None: gw = makegateway("socket=%s//id=sock1" % specsocket.socket) gw2 = makegateway("socket=%s//id=sock1" % specsocket.socket) rinfo = gw._rinfo() @@ -239,7 +247,7 @@ def test_socket_second(self, specsocket, makegateway): assert rinfo.cwd == rinfo2.cwd assert rinfo.version_info == rinfo2.version_info - def test_socket_installvia(self): + def test_socket_installvia(self) -> None: group = execnet.Group() group.makegateway("popen//id=p1") gw = group.makegateway("socket//installvia=p1//id=s1") diff --git a/tox.ini b/tox.ini index 2449698a..4c743a66 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist=py{37,38,39,310,311,pypy37},docs,linting +envlist=py{38,39,310,311,312,pypy38},docs,linting isolated_build = true [testenv]