From 310c29512955b37fffee685120108795a8436b6c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:30:14 +0200 Subject: [PATCH 001/109] Rename workflow for making a release. --- .github/workflows/{wheels.yml => release.yml} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename .github/workflows/{wheels.yml => release.yml} (97%) diff --git a/.github/workflows/wheels.yml b/.github/workflows/release.yml similarity index 97% rename from .github/workflows/wheels.yml rename to .github/workflows/release.yml index 707ef2c60..90f24b6f1 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,4 @@ -name: Build wheels +name: Make release on: push: @@ -64,8 +64,8 @@ jobs: with: path: wheelhouse/*.whl - release: - name: Release + upload: + name: Upload needs: - sdist - wheels From 88e702ddaf214b46fcf6b3ceca25961f79ca9d00 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:39:20 +0200 Subject: [PATCH 002/109] Upgrade to Trusted Publishing. --- .github/workflows/release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 90f24b6f1..8fad13529 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,8 @@ jobs: runs-on: ubuntu-latest # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + permissions: + id-token: write steps: - name: Download artifacts uses: actions/download-artifact@v3 @@ -80,8 +82,6 @@ jobs: path: dist - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From 5121bd15f988cc446db95b15a0bcac8dc64b68ab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:54:27 +0200 Subject: [PATCH 003/109] Blind fix for automatic release creation. --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8fad13529..6e895e64e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -85,4 +85,4 @@ jobs: - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: gh release create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." + run: gh release -R python-websockets/websockets create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." From 2431e09eebc75578e310627f0eab38cd81df2f6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Nov 2023 07:55:23 +0100 Subject: [PATCH 004/109] Fix import style (likely autogenerated). --- src/websockets/sync/server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 14767968c..d12da0c65 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -11,10 +11,9 @@ from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type -from websockets.frames import CloseCode - from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode from ..headers import validate_subprotocols from ..http import USER_AGENT from ..http11 import Request, Response From ec3bd2ab06278602c1d6018b476699e090036373 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Nov 2023 08:22:33 +0100 Subject: [PATCH 005/109] Make sync reassembler more readable. No logic changes. --- src/websockets/sync/messages.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 67a22313c..d98ff855b 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -47,13 +47,13 @@ def __init__(self) -> None: # queue for transferring frames from the writing thread (library code) # to the reading thread (user code). We're buffering when chunks_queue # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the stream, superseding message_complete. + # value marking the end of the message, superseding message_complete. # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None - # This flag marks the end of the stream. + # This flag marks the end of the connection. self.closed = False def get(self, timeout: Optional[float] = None) -> Data: @@ -108,12 +108,12 @@ def get(self, timeout: Optional[float] = None) -> Data: # mypy cannot figure out that chunks have the proper type. message: Data = joiner.join(self.chunks) # type: ignore - assert not self.message_fetched.is_set() - self.message_fetched.set() - self.chunks = [] assert self.chunks_queue is None + assert not self.message_fetched.is_set() + self.message_fetched.set() + return message def get_iter(self) -> Iterator[Data]: @@ -169,26 +169,26 @@ def get_iter(self) -> Iterator[Data]: with self.mutex: self.get_in_progress = False - assert self.message_complete.is_set() - self.message_complete.clear() - # get_iter() was unblocked by close() rather than put(). if self.closed: raise EOFError("stream of frames ended") - assert not self.message_fetched.is_set() - self.message_fetched.set() + assert self.message_complete.is_set() + self.message_complete.clear() assert self.chunks == [] self.chunks_queue = None + assert not self.message_fetched.is_set() + self.message_fetched.set() + def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, either by calling :meth:`get` or by fully - consuming the return value of :meth:`get_iter`. + the message is fetched, which can be achieved by calling :meth:`get` or + by fully consuming the return value of :meth:`get_iter`. :meth:`put` assumes that the stream of frames respects the protocol. If it doesn't, the behavior is undefined. @@ -247,13 +247,13 @@ def put(self, frame: Frame) -> None: with self.mutex: self.put_in_progress = False - assert self.message_fetched.is_set() - self.message_fetched.clear() - # put() was unblocked by close() rather than get() or get_iter(). if self.closed: raise EOFError("stream of frames ended") + assert self.message_fetched.is_set() + self.message_fetched.clear() + self.decoder = None def close(self) -> None: From 5737b474ad7d4a3a5e04d68299f4e5ec34bd62ac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Nov 2023 14:46:44 +0100 Subject: [PATCH 006/109] Start version 12.1. This commit should have been made right after releasing 12.0. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 264e6e42d..200ca7ef3 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +12.1 +---- + +*In development* + 12.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index d1c99458e..f1de3cbf4 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "12.0" +tag = version = commit = "12.1" if not released: # pragma: no cover From 94dd203f63bb52b1a30faa228e63ada2f0f2e874 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 9 Dec 2023 06:25:11 +0000 Subject: [PATCH 007/109] Bump actions/setup-python from 4 to 5 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 4 ++-- .github/workflows/tests.yml | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6e895e64e..7d56b9aa5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Build sdist @@ -47,7 +47,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Set up QEMU diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 470f5bc96..b128defb5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox @@ -36,7 +36,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox @@ -74,7 +74,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python ${{ matrix.python }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Install tox From fe1833fb103f4d63baee525c5b62dedd24b9884e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Nov 2023 14:48:59 +0100 Subject: [PATCH 008/109] Confirm support for Python 3.12. Fix #1417. --- .github/workflows/tests.yml | 1 + docs/project/changelog.rst | 5 +++++ pyproject.toml | 1 + tox.ini | 1 + 4 files changed, 8 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b128defb5..8161f1cbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,6 +61,7 @@ jobs: - "3.9" - "3.10" - "3.11" + - "3.12" - "pypy-3.8" - "pypy-3.9" is_main: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 200ca7ef3..963353d0e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ notice. *In development* +New features +............ + +* Validated compatibility with Python 3.12. + 12.0 ---- diff --git a/pyproject.toml b/pyproject.toml index f24616dd7..a7b4a6a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] dynamic = ["version", "readme"] diff --git a/tox.ini b/tox.ini index 939d8c0cd..538b638d9 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ envlist = py39 py310 py311 + py312 coverage black ruff From beeb9387dedb574c8d1a6c2a6e7312c17788c858 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Dec 2023 06:24:38 +0000 Subject: [PATCH 009/109] Bump actions/download-artifact from 3 to 4 Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7d56b9aa5..c1b750c80 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -76,7 +76,7 @@ jobs: id-token: write steps: - name: Download artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: artifact path: dist From 33b20e11e86f8490770185c78ed39adab8db4560 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Dec 2023 06:24:43 +0000 Subject: [PATCH 010/109] Bump actions/upload-artifact from 3 to 4 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c1b750c80..4a00bf8fc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ jobs: - name: Build sdist run: python setup.py sdist - name: Save sdist - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: dist/*.tar.gz - name: Install wheel @@ -30,7 +30,7 @@ jobs: BUILD_EXTENSION: no run: python setup.py bdist_wheel - name: Save wheel - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: dist/*.whl @@ -60,7 +60,7 @@ jobs: env: BUILD_EXTENSION: yes - name: Save wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: wheelhouse/*.whl From b3c51958849c80209b4d68fca081ef3fffc5e2bd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:02:51 +0100 Subject: [PATCH 011/109] Make test_local/remote_address more robust. Fix #1427. --- tests/sync/test_connection.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 63544d4ad..e128425d8 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -656,13 +656,17 @@ def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - def test_local_address(self): - """Connection has a local_address attribute.""" - self.assertIsNotNone(self.connection.local_address) - - def test_remote_address(self): - """Connection has a remote_address attribute.""" - self.assertIsNotNone(self.connection.remote_address) + @unittest.mock.patch("socket.socket.getsockname", return_value=("sock", 1234)) + def test_local_address(self, getsockname): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + getsockname.assert_called_with() + + @unittest.mock.patch("socket.socket.getpeername", return_value=("peer", 1234)) + def test_remote_address(self, getpeername): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + getpeername.assert_called_with() def test_request(self): """Connection has a request attribute.""" From 9038a62e7261af21109977407907038a1a0efc65 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:38:53 +0100 Subject: [PATCH 012/109] Make mypy 1.8.0 happy. --- src/websockets/legacy/auth.py | 2 +- src/websockets/typing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index d3425836e..e8d6b75d5 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -159,7 +159,7 @@ def basic_auth_protocol_factory( if is_credentials(credentials): credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(credentials) + credentials_list = list(cast(Iterable[Credentials], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index cc3e3ec0d..e073e650d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,7 +2,7 @@ import http import logging -from typing import List, NewType, Optional, Tuple, Union +from typing import Any, List, NewType, Optional, Tuple, Union __all__ = [ @@ -28,7 +28,7 @@ """ -LoggerLike = Union[logging.Logger, logging.LoggerAdapter] +LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" From 230d5052a33c0d940d926a1fc88909d39f57efd8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:41:54 +0100 Subject: [PATCH 013/109] Add tests for abstract classes. This prevents Python 3.12 to complain that no test cases were run and to exit with code 5 (which breaks maxi_cov). --- pyproject.toml | 1 - tests/extensions/test_base.py | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a7b4a6a9e..c4c5412c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,6 @@ exclude_lines = [ "if typing.TYPE_CHECKING:", "pragma: no cover", "raise AssertionError", - "raise NotImplementedError", "self.fail\\(\".*\"\\)", "@unittest.skip", ] diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index b18ffb6fb..62250b07f 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,30 @@ +import unittest + from websockets.extensions.base import * +from websockets.frames import Frame, Opcode + + +class ExtensionTests(unittest.TestCase): + def test_encode(self): + with self.assertRaises(NotImplementedError): + Extension().encode(Frame(Opcode.TEXT, b"")) + + def test_decode(self): + with self.assertRaises(NotImplementedError): + Extension().decode(Frame(Opcode.TEXT, b"")) + + +class ClientExtensionFactoryTests(unittest.TestCase): + def test_get_request_params(self): + with self.assertRaises(NotImplementedError): + ClientExtensionFactory().get_request_params() + + def test_process_response_params(self): + with self.assertRaises(NotImplementedError): + ClientExtensionFactory().process_response_params([], []) -# Abstract classes don't provide any behavior to test. +class ServerExtensionFactoryTests(unittest.TestCase): + def test_process_request_params(self): + with self.assertRaises(NotImplementedError): + ServerExtensionFactory().process_request_params([], []) From 5209b2a1cba00b28b8f62502157d5dbb98625a49 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 16:11:16 +0100 Subject: [PATCH 014/109] Remove empty test modules. This prevents Python 3.12 to complain that no test cases were run and to exit with code 5 (which breaks maxi_cov). --- tests/maxi_cov.py | 33 ++++++++++++++++++++++----------- tests/test_auth.py | 1 - tests/test_http.py | 7 +++++++ tests/test_typing.py | 1 - 4 files changed, 29 insertions(+), 13 deletions(-) delete mode 100644 tests/test_auth.py delete mode 100644 tests/test_typing.py diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 2568dcf18..bc4a44e8c 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -8,8 +8,15 @@ import sys -UNMAPPED_SRC_FILES = ["websockets/version.py"] -UNMAPPED_TEST_FILES = ["tests/test_exports.py"] +UNMAPPED_SRC_FILES = [ + "websockets/auth.py", + "websockets/typing.py", + "websockets/version.py", +] + +UNMAPPED_TEST_FILES = [ + "tests/test_exports.py", +] def check_environment(): @@ -60,7 +67,7 @@ def get_mapping(src_dir="src"): # Map source files to test files. mapping = {} - unmapped_test_files = [] + unmapped_test_files = set() for test_file in test_files: dir_name, file_name = os.path.split(test_file) @@ -73,26 +80,30 @@ def get_mapping(src_dir="src"): if src_file in src_files: mapping[src_file] = test_file else: - unmapped_test_files.append(test_file) + unmapped_test_files.add(test_file) - unmapped_src_files = list(set(src_files) - set(mapping)) + unmapped_src_files = set(src_files) - set(mapping) # Ensure that all files are mapped. - assert unmapped_src_files == UNMAPPED_SRC_FILES - assert unmapped_test_files == UNMAPPED_TEST_FILES + assert unmapped_src_files == set(UNMAPPED_SRC_FILES) + assert unmapped_test_files == set(UNMAPPED_TEST_FILES) return mapping def get_ignored_files(src_dir="src"): """Return the list of files to exclude from coverage measurement.""" - + # */websockets matches src/websockets and .tox/**/site-packages/websockets. return [ - # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module and for compatibility modules. + # There are no tests for the __main__ module. "*/websockets/__main__.py", + # There is nothing to test on type declarations. + "*/websockets/typing.py", + # We don't test compatibility modules with previous versions of Python + # or websockets (import locations). "*/websockets/*/compatibility.py", + "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", @@ -125,7 +136,7 @@ def run_coverage(mapping, src_dir="src"): "-m", "unittest", ] - + UNMAPPED_TEST_FILES, + + list(UNMAPPED_TEST_FILES), check=True, ) # Append coverage of each source module by the corresponding test module. diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 28db93155..000000000 --- a/tests/test_auth.py +++ /dev/null @@ -1 +0,0 @@ -from websockets.auth import * diff --git a/tests/test_http.py b/tests/test_http.py index 036bc1410..baaa7d416 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1 +1,8 @@ +import unittest + from websockets.http import * + + +class HTTPTests(unittest.TestCase): + def test_user_agent(self): + USER_AGENT # exists diff --git a/tests/test_typing.py b/tests/test_typing.py deleted file mode 100644 index 202de840f..000000000 --- a/tests/test_typing.py +++ /dev/null @@ -1 +0,0 @@ -from websockets.typing import * From 3c6b1aab96adde1a4b0d3e8f1a93b7f2c7310af0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jan 2024 20:46:54 +0100 Subject: [PATCH 015/109] Restore compatibility with Python < 3.11. Broken in 9038a62e. --- src/websockets/typing.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index e073e650d..5dfecf66f 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,6 +2,7 @@ import http import logging +import typing from typing import Any, List, NewType, Optional, Tuple, Union @@ -28,8 +29,12 @@ """ -LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] -"""Types accepted where a :class:`~logging.Logger` is expected.""" +if typing.TYPE_CHECKING: + LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] + """Types accepted where a :class:`~logging.Logger` is expected.""" +else: # remove this branch when dropping support for Python < 3.11 + LoggerLike = Union[logging.Logger, logging.LoggerAdapter] + """Types accepted where a :class:`~logging.Logger` is expected.""" StatusLike = Union[http.HTTPStatus, int] From 7b522ec0df8f4e26abe09046a0ae7861714f5a2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jan 2024 21:54:08 +0100 Subject: [PATCH 016/109] Simplify code. It had to be written in that way with asyncio.wait_for but that isn't necessary anymore with asyncio.timeout. --- src/websockets/legacy/client.py | 55 ++++++++++++++++----------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 48622523e..b85d22867 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -640,38 +640,35 @@ async def __aexit__( def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl_timeout__().__await__() - - async def __await_impl_timeout__(self) -> WebSocketClientProtocol: - async with asyncio_timeout(self.open_timeout): - return await self.__await_impl__() + return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketClientProtocol: - for redirects in range(self.MAX_REDIRECTS_ALLOWED): - _transport, _protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, _protocol) - try: - await protocol.handshake( - self._wsuri, - origin=protocol.origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except RedirectHandshake as exc: - protocol.fail_connection() - await protocol.wait_closed() - self.handle_redirect(exc.uri) - # Avoid leaking a connected socket when the handshake fails. - except (Exception, asyncio.CancelledError): - protocol.fail_connection() - await protocol.wait_closed() - raise + async with asyncio_timeout(self.open_timeout): + for _redirects in range(self.MAX_REDIRECTS_ALLOWED): + _transport, _protocol = await self._create_connection() + protocol = cast(WebSocketClientProtocol, _protocol) + try: + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() + self.handle_redirect(exc.uri) + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol else: - self.protocol = protocol - return protocol - else: - raise SecurityError("too many redirects") + raise SecurityError("too many redirects") # ... = yield from connect(...) - remove when dropping Python < 3.10 From 35bc7dd8288445289134c335aae8af859862ccd1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Jan 2024 21:20:06 +0100 Subject: [PATCH 017/109] Create futures with create_future. This is the preferred way to create Futures in asyncio. --- README.rst | 2 +- compliance/test_server.py | 2 +- docs/intro/tutorial1.rst | 2 +- docs/topics/broadcast.rst | 5 +++-- example/django/authentication.py | 2 +- example/echo.py | 2 +- example/faq/health_check_server.py | 2 +- example/legacy/basic_auth_server.py | 2 +- example/legacy/unix_server.py | 2 +- example/quickstart/counter.py | 2 +- example/quickstart/server.py | 2 +- example/quickstart/server_secure.py | 2 +- example/quickstart/show_time.py | 2 +- example/tutorial/step1/app.py | 2 +- example/tutorial/step2/app.py | 2 +- experiments/broadcast/server.py | 5 +++-- src/websockets/legacy/protocol.py | 4 ++-- src/websockets/legacy/server.py | 6 ++++-- 18 files changed, 26 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index 870b208ba..94cd79ab9 100644 --- a/README.rst +++ b/README.rst @@ -55,7 +55,7 @@ Here's an echo server with the ``asyncio`` API: async def main(): async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/compliance/test_server.py b/compliance/test_server.py index 92f895d92..5701e4485 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -21,7 +21,7 @@ async def echo(ws): async def main(): with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): try: - await asyncio.Future() + await asyncio.get_running_loop().create_future() # run forever except KeyboardInterrupt: pass diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index ff85003b5..6b32d47f6 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -195,7 +195,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 1acb372d4..b6ddda734 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -273,10 +273,11 @@ Here's a message stream that supports multiple consumers:: class PubSub: def __init__(self): - self.waiter = asyncio.Future() + self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): - waiter, self.waiter = self.waiter, asyncio.Future() + waiter = self.waiter + self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): diff --git a/example/django/authentication.py b/example/django/authentication.py index f6dad0f55..83e128f07 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -23,7 +23,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "localhost", 8888): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/echo.py b/example/echo.py index 2e47e52d9..d11b33527 100755 --- a/example/echo.py +++ b/example/echo.py @@ -9,6 +9,6 @@ async def echo(websocket): async def main(): async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 7b8bded77..6c7681e8a 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -17,6 +17,6 @@ async def main(): echo, "localhost", 8765, process_request=health_check, ): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/legacy/basic_auth_server.py b/example/legacy/basic_auth_server.py index d2efeb7e5..6f6020253 100755 --- a/example/legacy/basic_auth_server.py +++ b/example/legacy/basic_auth_server.py @@ -16,6 +16,6 @@ async def main(): realm="example", credentials=("mary", "p@ssw0rd") ), ): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/legacy/unix_server.py b/example/legacy/unix_server.py index 335039c35..5bfb66072 100755 --- a/example/legacy/unix_server.py +++ b/example/legacy/unix_server.py @@ -18,6 +18,6 @@ async def hello(websocket): async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") async with websockets.unix_serve(hello, socket_path): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 566e12965..414919e04 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -43,7 +43,7 @@ async def counter(websocket): async def main(): async with websockets.serve(counter, "localhost", 6789): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server.py b/example/quickstart/server.py index 31b182972..64d7adeb6 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -14,7 +14,7 @@ async def hello(websocket): async def main(): async with websockets.serve(hello, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index de41d30dc..11db5fb3a 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -20,7 +20,7 @@ async def hello(websocket): async def main(): async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index a83078e8a..add226869 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -13,7 +13,7 @@ async def show_time(websocket): async def main(): async with websockets.serve(show_time, "localhost", 5678): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 3b0fbd786..6ec1c60b8 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -58,7 +58,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index 2693d4304..db3e36374 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -183,7 +183,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 9c9907b7f..b0407ba34 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -27,10 +27,11 @@ async def relay(queue, websocket): class PubSub: def __init__(self): - self.waiter = asyncio.Future() + self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): - waiter, self.waiter = self.waiter, asyncio.Future() + waiter = self.waiter + self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 19cee0e65..47f948b7a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -664,7 +664,7 @@ async def send( return opcode, data = prepare_data(fragment) - self._fragmented_message_waiter = asyncio.Future() + self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) @@ -709,7 +709,7 @@ async def send( return opcode, data = prepare_data(fragment) - self._fragmented_message_waiter = asyncio.Future() + self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 7c24dd74a..d95bec4f6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -897,7 +897,8 @@ class Serve: Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object provides a :meth:`~WebSocketServer.close` method to shut down the server:: - stop = asyncio.Future() # set this future to exit the server + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() server = await serve(...) await stop @@ -906,7 +907,8 @@ class Serve: :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: - stop = asyncio.Future() # set this future to exit the server + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() async with serve(...): await stop From cba4c242614734a722891992e8bc005bc848c0c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 18:56:07 +0100 Subject: [PATCH 018/109] Avoid duplicating signature in API docs. --- docs/reference/sansio/client.rst | 2 +- docs/reference/sansio/common.rst | 2 +- docs/reference/sansio/server.rst | 2 +- docs/reference/sync/client.rst | 4 ++-- docs/reference/sync/server.rst | 4 ++-- src/websockets/sync/client.py | 18 ++++++++++-------- src/websockets/sync/server.py | 14 +++++++++----- 7 files changed, 26 insertions(+), 20 deletions(-) diff --git a/docs/reference/sansio/client.rst b/docs/reference/sansio/client.rst index 09bafc745..12f88b8ed 100644 --- a/docs/reference/sansio/client.rst +++ b/docs/reference/sansio/client.rst @@ -5,7 +5,7 @@ Client (`Sans-I/O`_) .. currentmodule:: websockets.client -.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ClientProtocol .. automethod:: receive_data diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst index cd1ef3c63..7d5447ac9 100644 --- a/docs/reference/sansio/common.rst +++ b/docs/reference/sansio/common.rst @@ -7,7 +7,7 @@ Both sides (`Sans-I/O`_) .. automodule:: websockets.protocol -.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) +.. autoclass:: Protocol .. automethod:: receive_data diff --git a/docs/reference/sansio/server.rst b/docs/reference/sansio/server.rst index d70df6277..3152f174e 100644 --- a/docs/reference/sansio/server.rst +++ b/docs/reference/sansio/server.rst @@ -5,7 +5,7 @@ Server (`Sans-I/O`_) .. currentmodule:: websockets.server -.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ServerProtocol .. automethod:: receive_data diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index 6cccd6ec4..af1132412 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -6,9 +6,9 @@ Client (:mod:`threading`) Opening a connection -------------------- -.. autofunction:: connect(uri, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: connect -.. autofunction:: unix_connect(path, uri=None, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: unix_connect Using a connection ------------------ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 35c112046..7ed744df2 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -6,9 +6,9 @@ Server (:mod:`threading`) Creating a server ----------------- -.. autofunction:: serve(handler, host=None, port=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: serve -.. autofunction:: unix_serve(handler, path=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: unix_serve Running a server ---------------- diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 087ff5f56..78a9a3c86 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -126,12 +126,10 @@ def recv_events(self) -> None: def connect( uri: str, *, - # TCP/TLS — unix and path are only for unix_connect() + # TCP/TLS sock: Optional[socket.socket] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, - unix: bool = False, - path: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -148,6 +146,7 @@ def connect( logger: Optional[LoggerLike] = None, # Escape hatch for advanced customization create_connection: Optional[Type[ClientConnection]] = None, + **kwargs: Any, ) -> ClientConnection: """ Connect to the WebSocket server at ``uri``. @@ -210,13 +209,15 @@ def connect( if not wsuri.secure and ssl_context is not None: raise TypeError("ssl_context argument is incompatible with a ws:// URI") + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: Optional[str] = kwargs.pop("path", None) + if unix: if path is None and sock is None: raise TypeError("missing path argument") elif path is not None and sock is not None: raise TypeError("path and sock arguments are incompatible") - else: - assert path is None # private argument, only set by unix_connect() if subprotocols is not None: validate_subprotocols(subprotocols) @@ -241,7 +242,7 @@ def connect( if unix: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(deadline.timeout()) - assert path is not None # validated above -- this is for mpypy + assert path is not None # mypy cannot figure this out sock.connect(path) else: sock = socket.create_connection( @@ -308,8 +309,9 @@ def unix_connect( """ Connect to a WebSocket server listening on a Unix socket. - This function is identical to :func:`connect`, except for the additional - ``path`` argument. It's only available on Unix. + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index d12da0c65..7faab0a3d 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -266,11 +266,9 @@ def serve( host: Optional[str] = None, port: Optional[int] = None, *, - # TCP/TLS — unix and path are only for unix_serve() + # TCP/TLS sock: Optional[socket.socket] = None, ssl_context: Optional[ssl.SSLContext] = None, - unix: bool = False, - path: Optional[str] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -304,6 +302,7 @@ def serve( logger: Optional[LoggerLike] = None, # Escape hatch for advanced customization create_connection: Optional[Type[ServerConnection]] = None, + **kwargs: Any, ) -> WebSocketServer: """ Create a WebSocket server listening on ``host`` and ``port``. @@ -397,6 +396,10 @@ def handler(websocket): # Bind socket and listen + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: Optional[str] = kwargs.pop("path", None) + if sock is None: if unix: if path is None: @@ -515,8 +518,9 @@ def unix_serve( """ Create a WebSocket server listening on a Unix socket. - This function is identical to :func:`serve`, except the ``host`` and - ``port`` arguments are replaced by ``path``. It's only available on Unix. + This function accepts the same keyword arguments as :func:`serve`. + + It's only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. From cd4bc7960658db6d51f60f528b3b53c718426591 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 19:08:49 +0100 Subject: [PATCH 019/109] Pass arguments to create_server/connection. --- src/websockets/sync/client.py | 8 ++++---- src/websockets/sync/server.py | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 78a9a3c86..79af0132f 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -195,6 +195,8 @@ def connect( the connection. Set it to a wrapper or a subclass to customize connection handling. + Any other keyword arguments are passed to :func:`~socket.create_connection`. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP connection fails. @@ -245,10 +247,8 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) else: - sock = socket.create_connection( - (wsuri.host, wsuri.port), - deadline.timeout(), - ) + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs) sock.settimeout(None) # Disable Nagle algorithm diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 7faab0a3d..c19992849 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -379,6 +379,9 @@ def handler(websocket): create_connection: Factory for the :class:`ServerConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. + + Any other keyword arguments are passed to :func:`~socket.create_server`. + """ # Process parameters @@ -404,9 +407,10 @@ def handler(websocket): if unix: if path is None: raise TypeError("missing path argument") - sock = socket.create_server(path, family=socket.AF_UNIX) + kwargs.setdefault("family", socket.AF_UNIX) + sock = socket.create_server(path, **kwargs) else: - sock = socket.create_server((host, port)) + sock = socket.create_server((host, port), **kwargs) else: if path is not None: raise TypeError("path and sock arguments are incompatible") From 03ecfa5611f0c87ea9cfa7497f78e0c85408060e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 19:14:28 +0100 Subject: [PATCH 020/109] Standardize on .encode(). We had a mix of .encode() and .encode("utf-8") -- which is the default. --- experiments/compression/benchmark.py | 2 +- src/websockets/frames.py | 6 +- src/websockets/sync/connection.py | 8 +- tests/extensions/test_permessage_deflate.py | 26 +++--- tests/legacy/test_framing.py | 6 +- tests/legacy/test_protocol.py | 90 ++++++++++----------- tests/test_frames.py | 4 +- 7 files changed, 69 insertions(+), 73 deletions(-) diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index c5b13c8fa..4fbdf6220 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -66,7 +66,7 @@ def _run(data): for _ in range(REPEAT): for item in data: if isinstance(item, str): - item = item.encode("utf-8") + item = item.encode() # Taken from PerMessageDeflate.encode item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) if item.endswith(b"\x00\x00\xff\xff"): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 6b1befb2e..63c35ed4d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -364,7 +364,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): - return OP_TEXT, data.encode("utf-8") + return OP_TEXT, data.encode() elif isinstance(data, BytesLike): return OP_BINARY, data else: @@ -387,7 +387,7 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): - return data.encode("utf-8") + return data.encode() elif isinstance(data, BytesLike): return bytes(data) else: @@ -456,7 +456,7 @@ def serialize(self) -> bytes: """ self.check() - return struct.pack("!H", self.code) + self.reason.encode("utf-8") + return struct.pack("!H", self.code) + self.reason.encode() def check(self) -> None: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 4a8879e37..62aa17ffd 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -287,7 +287,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode("utf-8")) + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -324,7 +324,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: ) self.send_in_progress = True self.protocol.send_text( - chunk.encode("utf-8"), + chunk.encode(), fin=False, ) elif isinstance(chunk, BytesLike): @@ -349,7 +349,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: with self.send_context(): assert self.send_in_progress self.protocol.send_continuation( - chunk.encode("utf-8"), + chunk.encode(), fin=False, ) elif isinstance(chunk, BytesLike) and not text: @@ -630,7 +630,7 @@ def send_context( socket:: with self.send_context(): - self.protocol.send_text(message.encode("utf-8")) + self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected to close on exit, or when an unexpected error happens, terminating the diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 0e698566f..ee09813c4 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -84,7 +84,7 @@ def test_no_encode_decode_close_frame(self): # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame = self.extension.encode(frame) @@ -112,9 +112,9 @@ def test_encode_decode_binary_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) - frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) - frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode(), fin=False) + frame2 = Frame(OP_CONT, " & ".encode(), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode()) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -168,7 +168,7 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_no_decode_text_frame(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) @@ -180,9 +180,9 @@ def test_no_decode_binary_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) - frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) - frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode(), fin=False) + frame2 = Frame(OP_CONT, " & ".encode(), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode()) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -203,7 +203,7 @@ def test_no_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -215,7 +215,7 @@ def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -233,7 +233,7 @@ def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. self.extension = PerMessageDeflate(True, True, 15, 15) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -253,7 +253,7 @@ def test_compress_settings(self): # Configure an extension so that no compression actually occurs. extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame = extension.encode(frame) @@ -269,7 +269,7 @@ def test_compress_settings(self): # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): - frame = Frame(OP_TEXT, ("a" * 20).encode("utf-8")) + frame = Frame(OP_TEXT, ("a" * 20).encode()) enc_frame = self.extension.encode(frame) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index e1e4c891b..6f811bd5e 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -76,14 +76,12 @@ def test_binary_masked(self): ) def test_non_ascii_text(self): - self.round_trip( - b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) - ) + self.round_trip(b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode())) def test_non_ascii_text_masked(self): self.round_trip( b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(True, OP_TEXT, "café".encode()), mask=True, ) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f2eb0fea0..f3dcd9ac7 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -428,7 +428,7 @@ def test_close_reason_not_set(self): # Test the recv coroutine. def test_recv_text(self): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -458,7 +458,7 @@ def test_recv_on_closed_connection(self): self.loop.run_until_complete(self.protocol.recv()) def test_recv_protocol_error(self): - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "café".encode())) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") @@ -469,7 +469,7 @@ def test_recv_unicode_error(self): def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205)) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") @@ -481,7 +481,7 @@ def test_recv_binary_payload_too_big(self): def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café" * 205) @@ -498,7 +498,7 @@ def test_recv_queue_empty(self): asyncio.wait_for(asyncio.shield(recv), timeout=MS) ) - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(recv) self.assertEqual(data, "café") @@ -507,7 +507,7 @@ def test_recv_queue_full(self): # Test internals because it's hard to verify buffers from the outside. self.assertEqual(list(self.protocol.messages), []) - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.run_loop_once() self.assertEqual(list(self.protocol.messages), ["café"]) @@ -535,7 +535,7 @@ def test_recv_queue_no_limit(self): self.protocol.max_queue = None for _ in range(100): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.run_loop_once() # Incoming message queue can contain at least 100 messages. @@ -562,7 +562,7 @@ def test_recv_canceled(self): self.loop.run_until_complete(recv) # The next frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -570,15 +570,13 @@ def test_recv_canceled_race_condition(self): recv = self.loop.create_task( asyncio.wait_for(self.protocol.recv(), timeout=0.000_001) ) - self.loop.call_soon( - self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) - ) + self.loop.call_soon(self.receive_frame, Frame(True, OP_TEXT, "café".encode())) with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(recv) # The previous frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "tea".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "tea".encode())) data = self.loop.run_until_complete(self.protocol.recv()) # If we're getting "tea" there, it means "café" was swallowed (ha, ha). self.assertEqual(data, "café") @@ -586,7 +584,7 @@ def test_recv_canceled_race_condition(self): def test_recv_when_transfer_data_cancelled(self): # Clog incoming queue. self.protocol.max_queue = 1 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.run_loop_once() @@ -620,7 +618,7 @@ def test_recv_prevents_concurrent_calls(self): def test_send_text(self): self.loop.run_until_complete(self.protocol.send("café")) - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_send_binary(self): self.loop.run_until_complete(self.protocol.send(b"tea")) @@ -647,9 +645,9 @@ def test_send_type_error(self): def test_send_iterable_text(self): self.loop.run_until_complete(self.protocol.send(["ca", "fé"])) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), ) def test_send_iterable_binary(self): @@ -687,7 +685,7 @@ def test_send_iterable_mixed_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), + (False, OP_TEXT, "café".encode()), (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) @@ -710,18 +708,18 @@ async def run_concurrently(): self.loop.run_until_complete(run_concurrently()) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), (True, OP_BINARY, b"tea"), ) def test_send_async_iterable_text(self): self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), ) def test_send_async_iterable_binary(self): @@ -761,7 +759,7 @@ def test_send_async_iterable_mixed_type_error(self): self.protocol.send(async_iterable(["café", b"tea"])) ) self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), + (False, OP_TEXT, "café".encode()), (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) @@ -784,9 +782,9 @@ async def run_concurrently(): self.loop.run_until_complete(run_concurrently()) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), (True, OP_BINARY, b"tea"), ) @@ -829,7 +827,7 @@ def test_ping_default(self): def test_ping_text(self): self.loop.run_until_complete(self.protocol.ping("café")) - self.assertOneFrameSent(True, OP_PING, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_PING, "café".encode()) def test_ping_binary(self): self.loop.run_until_complete(self.protocol.ping(b"tea")) @@ -882,7 +880,7 @@ def test_pong_default(self): def test_pong_text(self): self.loop.run_until_complete(self.protocol.pong("café")) - self.assertOneFrameSent(True, OP_PONG, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_PONG, "café".encode()) def test_pong_binary(self): self.loop.run_until_complete(self.protocol.pong(b"tea")) @@ -1072,8 +1070,8 @@ def test_return_latency_on_pong(self): # Test the protocol's logic for rebuilding fragmented messages. def test_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) + self.receive_frame(Frame(True, OP_CONT, "fé".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -1086,8 +1084,8 @@ def test_fragmented_binary(self): def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105)) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") @@ -1100,8 +1098,8 @@ def test_fragmented_binary_payload_too_big(self): def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café" * 205) @@ -1113,22 +1111,22 @@ def test_fragmented_binary_no_max_size(self): self.assertEqual(data, b"tea" * 342) def test_control_frame_within_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.receive_frame(Frame(True, OP_PING, b"")) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "fé".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") self.assertOneFrameSent(True, OP_PONG, b"") def test_unterminated_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_handshake_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.receive_frame(Frame(True, OP_CLOSE, b"")) self.process_invalid_frames() # The RFC may have overlooked this case: it says that control frames @@ -1138,7 +1136,7 @@ def test_close_handshake_in_fragmented_text(self): self.assertConnectionClosed(CloseCode.NO_STATUS_RCVD, "") def test_connection_close_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") @@ -1472,7 +1470,7 @@ def test_remote_close_during_send(self): def test_broadcast_text(self): broadcast([self.protocol], "café") - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_broadcast_binary(self): broadcast([self.protocol], b"tea") @@ -1489,8 +1487,8 @@ def test_broadcast_no_clients(self): def test_broadcast_two_clients(self): broadcast([self.protocol, self.protocol], "café") self.assertFramesSent( - (True, OP_TEXT, "café".encode("utf-8")), - (True, OP_TEXT, "café".encode("utf-8")), + (True, OP_TEXT, "café".encode()), + (True, OP_TEXT, "café".encode()), ) def test_broadcast_skips_closed_connection(self): @@ -1513,7 +1511,7 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() - self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") @@ -1530,7 +1528,7 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() - self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertRaises(ExceptionGroup) as raised: broadcast([self.protocol], "café", raise_exceptions=True) diff --git a/tests/test_frames.py b/tests/test_frames.py index e323b3b57..3e9f5d6f8 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -77,14 +77,14 @@ def test_binary_masked(self): def test_non_ascii_text_unmasked(self): self.assertFrameData( - Frame(OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode()), b"\x81\x05caf\xc3\xa9", mask=False, ) def test_non_ascii_text_masked(self): self.assertFrameData( - Frame(OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode()), b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) From ebc9890d4f2a7b1675d50d4fea167b9107082e9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 09:27:40 +0100 Subject: [PATCH 021/109] Remove redundant return types from docs. sphinx picks them from function signatures. This changes slightly the output e.g. ExtensionParameter becomes Tuple[str, str | None]. While this can be a bit less readable, it looks like an improvement because the information is available without needing to navigate to the definition of ExtensionParameter. --- src/websockets/client.py | 6 +++--- src/websockets/extensions/base.py | 13 ++++++------- src/websockets/legacy/auth.py | 2 +- src/websockets/legacy/handshake.py | 4 ++-- src/websockets/legacy/protocol.py | 11 +++++------ src/websockets/legacy/server.py | 11 +++++------ src/websockets/protocol.py | 6 +++--- src/websockets/server.py | 16 +++++++--------- src/websockets/uri.py | 2 +- 9 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index b2f622042..85bc81b47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -101,7 +101,7 @@ def connect(self) -> Request: You can modify it before sending it, for example to add HTTP headers. Returns: - Request: WebSocket handshake request event to send to the server. + WebSocket handshake request event to send to the server. """ headers = Headers() @@ -213,7 +213,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: headers: WebSocket handshake response headers. Returns: - List[Extension]: List of accepted extensions. + List of accepted extensions. Raises: InvalidHandshake: to abort the handshake. @@ -271,7 +271,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: headers: WebSocket handshake response headers. Returns: - Optional[Subprotocol]: Subprotocol, if one was selected. + Subprotocol, if one was selected. """ subprotocol: Optional[Subprotocol] = None diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 6c481a46c..9eba6c9e7 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -32,7 +32,7 @@ def decode( max_size: maximum payload size in bytes. Returns: - Frame: Decoded frame. + Decoded frame. Raises: PayloadTooBig: if decoding the payload exceeds ``max_size``. @@ -48,7 +48,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: frame (Frame): outgoing frame. Returns: - Frame: Encoded frame. + Encoded frame. """ raise NotImplementedError @@ -68,7 +68,7 @@ def get_request_params(self) -> List[ExtensionParameter]: Build parameters to send to the server for this extension. Returns: - List[ExtensionParameter]: Parameters to send to the server. + Parameters to send to the server. """ raise NotImplementedError @@ -88,7 +88,7 @@ def process_response_params( accepted extensions. Returns: - Extension: An extension instance. + An extension instance. Raises: NegotiationError: if parameters aren't acceptable. @@ -121,9 +121,8 @@ def process_request_params( accepted extensions. Returns: - Tuple[List[ExtensionParameter], Extension]: To accept the offer, - parameters to send to the client for this extension and an - extension instance. + To accept the offer, parameters to send to the client for this + extension and an extension instance. Raises: NegotiationError: to reject the offer, if parameters received from diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e8d6b75d5..8217afedd 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -66,7 +66,7 @@ async def check_credentials(self, username: str, password: str) -> bool: password: HTTP Basic Auth password. Returns: - bool: :obj:`True` if the handshake should continue; + :obj:`True` if the handshake should continue; :obj:`False` if it should fail with an HTTP 401 error. """ diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index ad8faf040..5853c31db 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -24,7 +24,7 @@ def build_request(headers: Headers) -> str: headers: Handshake request headers. Returns: - str: ``key`` that must be passed to :func:`check_response`. + ``key`` that must be passed to :func:`check_response`. """ key = generate_key() @@ -48,7 +48,7 @@ def check_request(headers: Headers) -> str: headers: Handshake request headers. Returns: - str: ``key`` that must be passed to :func:`build_response`. + ``key`` that must be passed to :func:`build_response`. Raises: InvalidHandshake: If the handshake request is invalid. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 47f948b7a..a9fbd5a7a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -518,7 +518,7 @@ async def recv(self) -> Data: :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. Returns: - Data: A string (:class:`str`) for a Text_ frame. A bytestring + A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -805,7 +805,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: """ Send a Ping_. @@ -827,10 +827,9 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: containing four random bytes. Returns: - ~asyncio.Future[float]: A future that will be completed when the - corresponding pong is received. You can ignore it if you don't - intend to wait. The result of the future is the latency of the - connection in seconds. + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. :: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index d95bec4f6..297613591 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -349,8 +349,8 @@ async def process_request( request_headers: request headers. Returns: - Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None` - to continue the WebSocket handshake normally. + Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, headers, and body, to abort the WebSocket handshake and return @@ -534,8 +534,7 @@ def select_subprotocol( server_subprotocols: list of subprotocols available on the server. Returns: - Optional[Subprotocol]: Selected subprotocol, if a common subprotocol - was found. + Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. @@ -572,7 +571,7 @@ async def handshake( the handshake succeeds. Returns: - str: path of the URI of the request. + path of the URI of the request. Raises: InvalidHandshake: if the handshake fails. @@ -968,7 +967,7 @@ class Serve: outside of websockets. Returns: - WebSocketServer: WebSocket server. + WebSocket server. """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 765e6b9bb..342aba413 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -452,7 +452,7 @@ def events_received(self) -> List[Event]: Process resulting events, likely by passing them to the application. Returns: - List[Event]: Events read from the connection. + Events read from the connection. """ events, self.events = self.events, [] return events @@ -473,7 +473,7 @@ def data_to_send(self) -> List[bytes]: connection. Returns: - List[bytes]: Data to write to the connection. + Data to write to the connection. """ writes, self.writes = self.writes, [] @@ -490,7 +490,7 @@ def close_expected(self) -> bool: short timeout if the other side hasn't already closed it. Returns: - bool: Whether the TCP connection is expected to close soon. + Whether the TCP connection is expected to close soon. """ # We expect a TCP close if and only if we sent a close frame: diff --git a/src/websockets/server.py b/src/websockets/server.py index 191660553..58391d3cf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -213,7 +213,6 @@ def process_request( request: WebSocket handshake request received from the client. Returns: - Tuple[str, Optional[str], Optional[str]]: ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and ``Sec-WebSocket-Protocol`` headers for the handshake response. @@ -294,7 +293,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: headers: WebSocket handshake request headers. Returns: - Optional[Origin]: origin, if it is acceptable. + origin, if it is acceptable. Raises: InvalidHandshake: if the Origin header is invalid. @@ -344,8 +343,8 @@ def process_extensions( headers: WebSocket handshake request headers. Returns: - Tuple[Optional[str], List[Extension]]: ``Sec-WebSocket-Extensions`` - HTTP response header and list of accepted extensions. + ``Sec-WebSocket-Extensions`` HTTP response header and list of + accepted extensions. Raises: InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. @@ -401,8 +400,8 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: headers: WebSocket handshake request headers. Returns: - Optional[Subprotocol]: Subprotocol, if one was selected; this is - also the value of the ``Sec-WebSocket-Protocol`` response header. + Subprotocol, if one was selected; this is also the value of the + ``Sec-WebSocket-Protocol`` response header. Raises: InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. @@ -449,8 +448,7 @@ def select_subprotocol(protocol, subprotocols): subprotocols: list of subprotocols offered by the client. Returns: - Optional[Subprotocol]: Selected subprotocol, if a common subprotocol - was found. + Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. @@ -499,7 +497,7 @@ def reject( text: HTTP response body; will be encoded to UTF-8. Returns: - Response: WebSocket handshake response event to send to the client. + WebSocket handshake response event to send to the client. """ # If a user passes an int instead of a HTTPStatus, fix it automatically. diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 385090f66..970020e26 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -66,7 +66,7 @@ def parse_uri(uri: str) -> WebSocketURI: uri: WebSocket URI. Returns: - WebSocketURI: Parsed WebSocket URI. + Parsed WebSocket URI. Raises: InvalidURI: if ``uri`` isn't a valid WebSocket URI. From c53fc3b7eed17c12c1b4db5d456b7921c2cde98f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 10:01:53 +0100 Subject: [PATCH 022/109] Start argument descriptions with uppercase letter. This change was automated with regex replaces: ^( Args:\n(?: .*\n(?: .*\n)*)*?(?: \w+(?: \(.*\))?): )([a-z]) $1\U$2 ^( Args:\n(?: .*\n(?: .*\n)*)*?(?: \w+(?: \(.*\))?): )([a-z]) $1\U$2 Also remove redundant type annotations. --- src/websockets/client.py | 12 ++++---- src/websockets/datastructures.py | 2 +- src/websockets/extensions/base.py | 18 +++++------- .../extensions/permessage_deflate.py | 22 +++++++-------- src/websockets/frames.py | 14 +++++----- src/websockets/headers.py | 6 ++-- src/websockets/http11.py | 12 ++++---- src/websockets/legacy/protocol.py | 11 +++----- src/websockets/legacy/server.py | 28 +++++++++---------- src/websockets/protocol.py | 6 ++-- src/websockets/server.py | 14 +++++----- src/websockets/streams.py | 8 +++--- src/websockets/sync/utils.py | 2 +- src/websockets/utils.py | 4 +-- 14 files changed, 76 insertions(+), 83 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 85bc81b47..028e7ce47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -53,17 +53,17 @@ class ClientProtocol(Protocol): Args: wsuri: URI of the WebSocket server, parsed with :func:`~websockets.uri.parse_uri`. - origin: value of the ``Origin`` header. This is useful when connecting + origin: Value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + subprotocols: List of supported subprotocols, in order of decreasing preference. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; + logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index a0a648463..c2a5acfee 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -152,7 +152,7 @@ def get_all(self, key: str) -> List[str]: Return the (possibly empty) list of all values for a header. Args: - key: header name. + key: Header name. """ return self._dict.get(key.lower(), []) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 9eba6c9e7..cca3fe513 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -28,8 +28,8 @@ def decode( Decode an incoming frame. Args: - frame (Frame): incoming frame. - max_size: maximum payload size in bytes. + frame: Incoming frame. + max_size: Maximum payload size in bytes. Returns: Decoded frame. @@ -45,7 +45,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: Encode an outgoing frame. Args: - frame (Frame): outgoing frame. + frame: Outgoing frame. Returns: Encoded frame. @@ -82,10 +82,8 @@ def process_response_params( Process parameters received from the server. Args: - params (Sequence[ExtensionParameter]): parameters received from - the server for this extension. - accepted_extensions (Sequence[Extension]): list of previously - accepted extensions. + params: Parameters received from the server for this extension. + accepted_extensions: List of previously accepted extensions. Returns: An extension instance. @@ -115,10 +113,8 @@ def process_request_params( Process parameters received from the client. Args: - params (Sequence[ExtensionParameter]): parameters received from - the client for this extension. - accepted_extensions (Sequence[Extension]): list of previously - accepted extensions. + params: Parameters received from the client for this extension. + accepted_extensions: List of previously accepted extensions. Returns: To accept the offer, parameters to send to the client for this diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index b391837c6..edccac3ca 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -268,14 +268,14 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): value or to an integer value to include them with this value. Args: - server_no_context_takeover: prevent server from using context takeover. - client_no_context_takeover: prevent client from using context takeover. - server_max_window_bits: maximum size of the server's LZ77 sliding window + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. - client_max_window_bits: maximum size of the client's LZ77 sliding window + client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15, or :obj:`True` to indicate support without setting a limit. - compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. """ @@ -468,15 +468,15 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): value or to an integer value to include them with this value. Args: - server_no_context_takeover: prevent server from using context takeover. - client_no_context_takeover: prevent client from using context takeover. - server_max_window_bits: maximum size of the server's LZ77 sliding window + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. - client_max_window_bits: maximum size of the client's LZ77 sliding window + client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15. - compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. - require_client_max_window_bits: do not enable compression at all if + require_client_max_window_bits: Do not enable compression at all if client doesn't advertise support for ``client_max_window_bits``; the default behavior is to enable compression without enforcing ``client_max_window_bits``. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 63c35ed4d..e5e2af8b4 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -208,12 +208,12 @@ def parse( This is a generator-based coroutine. Args: - read_exact: generator-based coroutine that reads the requested + read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. - mask: whether the frame should be masked i.e. whether the read + mask: Whether the frame should be masked i.e. whether the read happens on the server side. - max_size: maximum payload size in bytes. - extensions: list of extensions, applied in reverse order. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. Raises: EOFError: if the connection is closed without a full WebSocket frame. @@ -280,9 +280,9 @@ def serialize( Serialize a WebSocket frame. Args: - mask: whether the frame should be masked i.e. whether the write + mask: Whether the frame should be masked i.e. whether the write happens on the client side. - extensions: list of extensions, applied in order. + extensions: List of extensions, applied in order. Raises: ProtocolError: if the frame contains incorrect values. @@ -432,7 +432,7 @@ def parse(cls, data: bytes) -> Close: Parse the payload of a close frame. Args: - data: payload of the close frame. + data: Payload of the close frame. Raises: ProtocolError: if data is ill-formed. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9ae3035a5..8391ad26c 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -289,7 +289,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: Return a list of HTTP protocols. Args: - header: value of the ``Upgrade`` header. + header: Value of the ``Upgrade`` header. Raises: InvalidHeaderFormat: on invalid inputs. @@ -486,7 +486,7 @@ def build_www_authenticate_basic(realm: str) -> str: Build a ``WWW-Authenticate`` header for HTTP Basic Auth. Args: - realm: identifier of the protection space. + realm: Identifier of the protection space. """ # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 @@ -532,7 +532,7 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: Return a ``(username, password)`` tuple. Args: - header: value of the ``Authorization`` header. + header: Value of the ``Authorization`` header. Raises: InvalidHeaderFormat: on invalid inputs. diff --git a/src/websockets/http11.py b/src/websockets/http11.py index ec4e3b8b7..c0a96f878 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -93,7 +93,7 @@ def parse( body, it may be read from the data stream after :meth:`parse` returns. Args: - read_line: generator-based coroutine that reads a LF-terminated + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data Raises: @@ -193,11 +193,11 @@ def parse( characters. Other characters are represented with surrogate escapes. Args: - read_line: generator-based coroutine that reads a LF-terminated + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. - read_exact: generator-based coroutine that reads the requested + read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. - read_to_eof: generator-based coroutine that reads until the end + read_to_eof: Generator-based coroutine that reads until the end of the stream. Raises: @@ -295,7 +295,7 @@ def parse_headers( Non-ASCII characters are represented with surrogate escapes. Args: - read_line: generator-based coroutine that reads a LF-terminated line + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: @@ -346,7 +346,7 @@ def parse_line( CRLF is stripped from the return value. Args: - read_line: generator-based coroutine that reads a LF-terminated line + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a9fbd5a7a..26d50a2cc 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -624,8 +624,7 @@ async def send( error or a network failure. Args: - message (Union[Data, Iterable[Data], AsyncIterable[Data]): message - to send. + message: Message to send. Raises: ConnectionClosed: When the connection is closed. @@ -822,9 +821,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: effect. Args: - data (Optional[Data]): payload of the ping; a string will be - encoded to UTF-8; or :obj:`None` to generate a payload - containing four random bytes. + data: Payload of the ping. A string will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. Returns: A future that will be completed when the corresponding pong is @@ -878,8 +876,7 @@ async def pong(self, data: Data = b"") -> None: wait, you should close the connection. Args: - data (Data): Payload of the pong. A string will be encoded to - UTF-8. + data: Payload of the pong. A string will be encoded to UTF-8. Raises: ConnectionClosed: When the connection is closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 297613591..4af7ed109 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -345,8 +345,8 @@ async def process_request( from shutting down. Args: - path: request path, including optional query string. - request_headers: request headers. + path: Request path, including optional query string. + request_headers: Request headers. Returns: Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to @@ -377,8 +377,8 @@ def process_origin( Handle the Origin HTTP request header. Args: - headers: request headers. - origins: optional list of acceptable origins. + headers: Request headers. + origins: Optional list of acceptable origins. Raises: InvalidOrigin: if the origin isn't acceptable. @@ -428,8 +428,8 @@ def process_extensions( order of extensions, may be implemented by overriding this method. Args: - headers: request headers. - extensions: optional list of supported extensions. + headers: Request headers. + extensions: Optional list of supported extensions. Raises: InvalidHandshake: to abort the handshake with an HTTP 400 error. @@ -488,8 +488,8 @@ def process_subprotocol( as the selected subprotocol. Args: - headers: request headers. - available_subprotocols: optional list of supported subprotocols. + headers: Request headers. + available_subprotocols: Optional list of supported subprotocols. Raises: InvalidHandshake: to abort the handshake with an HTTP 400 error. @@ -530,8 +530,8 @@ def select_subprotocol( subprotocol. Args: - client_subprotocols: list of subprotocols offered by the client. - server_subprotocols: list of subprotocols available on the server. + client_subprotocols: List of subprotocols offered by the client. + server_subprotocols: List of subprotocols available on the server. Returns: Selected subprotocol, if a common subprotocol was found. @@ -561,13 +561,13 @@ async def handshake( Perform the server side of the opening handshake. Args: - origins: list of acceptable values of the Origin HTTP header; + origins: List of acceptable values of the Origin HTTP header; include :obj:`None` if the lack of an origin is acceptable. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of + subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the response when + extra_headers: Arbitrary HTTP headers to add to the response when the handshake succeeds. Returns: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 342aba413..99c9ee1a8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -74,10 +74,10 @@ class Protocol: Args: side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; depending on ``side``, + logger: Logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/server.py b/src/websockets/server.py index 58391d3cf..6711a0bba 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -53,22 +53,22 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: acceptable values of the ``Origin`` header; include + origins: Acceptable values of the ``Origin`` header; include :obj:`None` in the list if the lack of an origin is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + subprotocols: List of supported subprotocols, in order of decreasing preference. select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It has the same signature as the :meth:`select_subprotocol` method, including a :class:`ServerProtocol` instance as first argument. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; + logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. @@ -445,7 +445,7 @@ def select_subprotocol(protocol, subprotocols): return "chat" Args: - subprotocols: list of subprotocols offered by the client. + subprotocols: List of subprotocols offered by the client. Returns: Selected subprotocol, if a common subprotocol was found. diff --git a/src/websockets/streams.py b/src/websockets/streams.py index f861d4bd2..d288cf0cc 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -26,7 +26,7 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: The return value includes the LF character. Args: - m: maximum number bytes to read; this is a security limit. + m: Maximum number bytes to read; this is a security limit. Raises: EOFError: if the stream ends without a LF. @@ -58,7 +58,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: This is a generator-based coroutine. Args: - n: how many bytes to read. + n: How many bytes to read. Raises: EOFError: if the stream ends in less than ``n`` bytes. @@ -81,7 +81,7 @@ def read_to_eof(self, m: int) -> Generator[None, None, bytes]: This is a generator-based coroutine. Args: - m: maximum number bytes to read; this is a security limit. + m: Maximum number bytes to read; this is a security limit. Raises: RuntimeError: if the stream ends in more than ``m`` bytes. @@ -119,7 +119,7 @@ def feed_data(self, data: bytes) -> None: :meth:`feed_data` cannot be called after :meth:`feed_eof`. Args: - data: data to write. + data: Data to write. Raises: EOFError: if the stream has ended. diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 471f32e19..3364bdc2d 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -28,7 +28,7 @@ def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: Calculate a timeout from a deadline. Args: - raise_if_elapsed (bool): Whether to raise :exc:`TimeoutError` + raise_if_elapsed: Whether to raise :exc:`TimeoutError` if the deadline lapsed. Raises: diff --git a/src/websockets/utils.py b/src/websockets/utils.py index c40404906..62d2dc177 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -26,7 +26,7 @@ def accept_key(key: str) -> str: Compute the value of the Sec-WebSocket-Accept header. Args: - key: value of the Sec-WebSocket-Key header. + key: Value of the Sec-WebSocket-Key header. """ sha1 = hashlib.sha1((key + GUID).encode()).digest() @@ -38,7 +38,7 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: Apply masking to the data of a WebSocket message. Args: - data: data to mask. + data: Data to mask. mask: 4-bytes mask. """ From 2865bdcc8b93f78d019aa0c605c86535dd66d026 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 10:14:25 +0100 Subject: [PATCH 023/109] Start exception descriptions with uppercase letter. This change was automated with regex replaces: ^( Raises:\n(?: .*\n(?: .*\n)*)*?(?: \w+): )([a-z]) $1\U$2 ^( Raises:\n(?: .*\n(?: .*\n)*)*?(?: \w+): )([a-z]) $1\U$2 --- src/websockets/client.py | 4 ++-- src/websockets/extensions/base.py | 6 +++--- src/websockets/frames.py | 22 +++++++++++----------- src/websockets/headers.py | 30 +++++++++++++++--------------- src/websockets/http11.py | 24 ++++++++++++------------ src/websockets/legacy/server.py | 10 +++++----- src/websockets/protocol.py | 16 ++++++++-------- src/websockets/server.py | 12 ++++++------ src/websockets/streams.py | 12 ++++++------ src/websockets/uri.py | 2 +- 10 files changed, 69 insertions(+), 69 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 028e7ce47..633b1960b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -144,7 +144,7 @@ def process_response(self, response: Response) -> None: request: WebSocket handshake response received from the server. Raises: - InvalidHandshake: if the handshake response is invalid. + InvalidHandshake: If the handshake response is invalid. """ @@ -216,7 +216,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: List of accepted extensions. Raises: - InvalidHandshake: to abort the handshake. + InvalidHandshake: To abort the handshake. """ accepted_extensions: List[Extension] = [] diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index cca3fe513..7446c990c 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -35,7 +35,7 @@ def decode( Decoded frame. Raises: - PayloadTooBig: if decoding the payload exceeds ``max_size``. + PayloadTooBig: If decoding the payload exceeds ``max_size``. """ raise NotImplementedError @@ -89,7 +89,7 @@ def process_response_params( An extension instance. Raises: - NegotiationError: if parameters aren't acceptable. + NegotiationError: If parameters aren't acceptable. """ raise NotImplementedError @@ -121,7 +121,7 @@ def process_request_params( extension and an extension instance. Raises: - NegotiationError: to reject the offer, if parameters received from + NegotiationError: To reject the offer, if parameters received from the client aren't acceptable. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index e5e2af8b4..201bc5068 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -216,10 +216,10 @@ def parse( extensions: List of extensions, applied in reverse order. Raises: - EOFError: if the connection is closed without a full WebSocket frame. - UnicodeDecodeError: if the frame contains invalid UTF-8. - PayloadTooBig: if the frame's payload size exceeds ``max_size``. - ProtocolError: if the frame contains incorrect values. + EOFError: If the connection is closed without a full WebSocket frame. + UnicodeDecodeError: If the frame contains invalid UTF-8. + PayloadTooBig: If the frame's payload size exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. """ # Read the header. @@ -285,7 +285,7 @@ def serialize( extensions: List of extensions, applied in order. Raises: - ProtocolError: if the frame contains incorrect values. + ProtocolError: If the frame contains incorrect values. """ self.check() @@ -334,7 +334,7 @@ def check(self) -> None: Check that reserved bits and opcode have acceptable values. Raises: - ProtocolError: if a reserved bit or the opcode is invalid. + ProtocolError: If a reserved bit or the opcode is invalid. """ if self.rsv1 or self.rsv2 or self.rsv3: @@ -360,7 +360,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: object. Raises: - TypeError: if ``data`` doesn't have a supported type. + TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -383,7 +383,7 @@ def prepare_ctrl(data: Data) -> bytes: If ``data`` is a bytes-like object, return a :class:`bytes` object. Raises: - TypeError: if ``data`` doesn't have a supported type. + TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -435,8 +435,8 @@ def parse(cls, data: bytes) -> Close: data: Payload of the close frame. Raises: - ProtocolError: if data is ill-formed. - UnicodeDecodeError: if the reason isn't valid UTF-8. + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. """ if len(data) >= 2: @@ -463,7 +463,7 @@ def check(self) -> None: Check that the close code has a valid value for a close frame. Raises: - ProtocolError: if the close code is invalid. + ProtocolError: If the close code is invalid. """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 8391ad26c..463df3061 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -103,7 +103,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _token_re.match(header, pos) @@ -127,7 +127,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i Return the unquoted value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _quoted_string_re.match(header, pos) @@ -180,7 +180,7 @@ def parse_list( Return a list of items. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient @@ -234,7 +234,7 @@ def parse_connection_option( Return the protocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -251,7 +251,7 @@ def parse_connection(header: str) -> List[ConnectionOption]: header: value of the ``Connection`` header. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") @@ -271,7 +271,7 @@ def parse_upgrade_protocol( Return the protocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _protocol_re.match(header, pos) @@ -292,7 +292,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: header: Value of the ``Upgrade`` header. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") @@ -307,7 +307,7 @@ def parse_extension_item_param( Return a ``(name, value)`` pair and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Extract parameter name. @@ -344,7 +344,7 @@ def parse_extension_item( list of ``(name, value)`` pairs, and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Extract extension name. @@ -379,7 +379,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: Parameter values are :obj:`None` when no value is provided. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") @@ -431,7 +431,7 @@ def parse_subprotocol_item( Return the subprotocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -445,7 +445,7 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: Return a list of WebSocket subprotocols. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") @@ -505,7 +505,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _token68_re.match(header, pos) @@ -535,8 +535,8 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: header: Value of the ``Authorization`` header. Raises: - InvalidHeaderFormat: on invalid inputs. - InvalidHeaderValue: on unsupported inputs. + InvalidHeaderFormat: On invalid inputs. + InvalidHeaderValue: On unsupported inputs. """ # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 diff --git a/src/websockets/http11.py b/src/websockets/http11.py index c0a96f878..6fe775eec 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -97,9 +97,9 @@ def parse( line or raises an exception if there isn't enough data Raises: - EOFError: if the connection is closed without a full HTTP request. - SecurityError: if the request exceeds a security limit. - ValueError: if the request isn't well formatted. + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -201,10 +201,10 @@ def parse( of the stream. Raises: - EOFError: if the connection is closed without a full HTTP response. - SecurityError: if the response exceeds a security limit. - LookupError: if the response isn't well formatted. - ValueError: if the response isn't well formatted. + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + LookupError: If the response isn't well formatted. + ValueError: If the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 @@ -299,9 +299,9 @@ def parse_headers( or raises an exception if there isn't enough data. Raises: - EOFError: if the connection is closed without complete headers. - SecurityError: if the request exceeds a security limit. - ValueError: if the request isn't well formatted. + EOFError: If the connection is closed without complete headers. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 @@ -350,8 +350,8 @@ def parse_line( or raises an exception if there isn't enough data. Raises: - EOFError: if the connection is closed without a CRLF. - SecurityError: if the response exceeds a security limit. + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4af7ed109..e8cf8220f 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -271,7 +271,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: after this coroutine returns. Raises: - InvalidMessage: if the HTTP message is malformed or isn't an + InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET request. """ @@ -381,7 +381,7 @@ def process_origin( origins: Optional list of acceptable origins. Raises: - InvalidOrigin: if the origin isn't acceptable. + InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -432,7 +432,7 @@ def process_extensions( extensions: Optional list of supported extensions. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: To abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -492,7 +492,7 @@ def process_subprotocol( available_subprotocols: Optional list of supported subprotocols. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: To abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -574,7 +574,7 @@ async def handshake( path of the URI of the request. Raises: - InvalidHandshake: if the handshake fails. + InvalidHandshake: If the handshake fails. """ path, request_headers = await self.read_http_request() diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 99c9ee1a8..6851f3b1f 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -217,7 +217,7 @@ def close_exc(self) -> ConnectionClosed: known only once the connection is closed. Raises: - AssertionError: if the connection isn't closed yet. + AssertionError: If the connection isn't closed yet. """ assert self.state is CLOSED, "connection isn't closed yet" @@ -252,7 +252,7 @@ def receive_data(self, data: bytes) -> None: - You should call :meth:`events_received` and process resulting events. Raises: - EOFError: if :meth:`receive_eof` was called earlier. + EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_data(data) @@ -270,7 +270,7 @@ def receive_eof(self) -> None: any new events. Raises: - EOFError: if :meth:`receive_eof` was called earlier. + EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_eof() @@ -292,7 +292,7 @@ def send_continuation(self, data: bytes, fin: bool) -> None: of a fragmented message and to :obj:`False` otherwise. Raises: - ProtocolError: if a fragmented message isn't in progress. + ProtocolError: If a fragmented message isn't in progress. """ if not self.expect_continuation_frame: @@ -313,7 +313,7 @@ def send_text(self, data: bytes, fin: bool = True) -> None: a fragmented message. Raises: - ProtocolError: if a fragmented message is in progress. + ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -334,7 +334,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: a fragmented message. Raises: - ProtocolError: if a fragmented message is in progress. + ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -354,7 +354,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: reason: close reason. Raises: - ProtocolError: if a fragmented message is being sent, if the code + ProtocolError: If a fragmented message is being sent, if the code isn't valid, or if a reason is provided without a code """ @@ -412,7 +412,7 @@ def fail(self, code: int, reason: str = "") -> None: reason: close reason Raises: - ProtocolError: if the code isn't valid. + ProtocolError: If the code isn't valid. """ # 7.1.7. Fail the WebSocket Connection diff --git a/src/websockets/server.py b/src/websockets/server.py index 6711a0bba..330e54f37 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -217,7 +217,7 @@ def process_request( ``Sec-WebSocket-Protocol`` headers for the handshake response. Raises: - InvalidHandshake: if the handshake request is invalid; + InvalidHandshake: If the handshake request is invalid; then the server must return 400 Bad Request error. """ @@ -296,8 +296,8 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: origin, if it is acceptable. Raises: - InvalidHandshake: if the Origin header is invalid. - InvalidOrigin: if the origin isn't acceptable. + InvalidHandshake: If the Origin header is invalid. + InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -347,7 +347,7 @@ def process_extensions( accepted extensions. Raises: - InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. + InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ response_header_value: Optional[str] = None @@ -404,7 +404,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: ``Sec-WebSocket-Protocol`` response header. Raises: - InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. + InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid. """ subprotocols: Sequence[Subprotocol] = sum( @@ -453,7 +453,7 @@ def select_subprotocol(protocol, subprotocols): :obj:`None` to continue without a subprotocol. Raises: - NegotiationError: custom implementations may raise this exception + NegotiationError: Custom implementations may raise this exception to abort the handshake with an HTTP 400 error. """ diff --git a/src/websockets/streams.py b/src/websockets/streams.py index d288cf0cc..956f139d4 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -29,8 +29,8 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: m: Maximum number bytes to read; this is a security limit. Raises: - EOFError: if the stream ends without a LF. - RuntimeError: if the stream ends in more than ``m`` bytes. + EOFError: If the stream ends without a LF. + RuntimeError: If the stream ends in more than ``m`` bytes. """ n = 0 # number of bytes to read @@ -61,7 +61,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: n: How many bytes to read. Raises: - EOFError: if the stream ends in less than ``n`` bytes. + EOFError: If the stream ends in less than ``n`` bytes. """ assert n >= 0 @@ -84,7 +84,7 @@ def read_to_eof(self, m: int) -> Generator[None, None, bytes]: m: Maximum number bytes to read; this is a security limit. Raises: - RuntimeError: if the stream ends in more than ``m`` bytes. + RuntimeError: If the stream ends in more than ``m`` bytes. """ while not self.eof: @@ -122,7 +122,7 @@ def feed_data(self, data: bytes) -> None: data: Data to write. Raises: - EOFError: if the stream has ended. + EOFError: If the stream has ended. """ if self.eof: @@ -136,7 +136,7 @@ def feed_eof(self) -> None: :meth:`feed_eof` cannot be called more than once. Raises: - EOFError: if the stream has ended. + EOFError: If the stream has ended. """ if self.eof: diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 970020e26..8cf581743 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -69,7 +69,7 @@ def parse_uri(uri: str) -> WebSocketURI: Parsed WebSocket URI. Raises: - InvalidURI: if ``uri`` isn't a valid WebSocket URI. + InvalidURI: If ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) From 908c7ba23168da52d0006d67bc068e315e90daae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Jan 2024 10:02:56 +0100 Subject: [PATCH 024/109] Clean up sync message assembler. Remove support for control frames, which isn't actually used. --- src/websockets/sync/messages.py | 34 +++++----- tests/sync/test_messages.py | 113 ++++++++++++-------------------- 2 files changed, 60 insertions(+), 87 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index d98ff855b..dcba183d9 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -5,7 +5,7 @@ import threading from typing import Iterator, List, Optional, cast -from ..frames import Frame, Opcode +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -25,8 +25,11 @@ def __init__(self) -> None: # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to ensure proper interleaving of - # writing and reading messages. + # We create a latch with two events to synchronize the production of + # frames and the consumption of messages (or frames) without a buffer. + # This design requires a switch between the library thread and the user + # thread for each message; that shouldn't be a performance bottleneck. + # put() sets this event to tell get() that a message can be fetched. self.message_complete = threading.Event() # get() sets this event to let put() that the message was fetched. @@ -72,8 +75,10 @@ def get(self, timeout: Optional[float] = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` + RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. + TimeoutError: If a timeout is provided and elapses before a + complete message is received. """ with self.mutex: @@ -131,7 +136,7 @@ def get_iter(self) -> Iterator[Data]: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` + RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. """ @@ -159,11 +164,10 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - yield from chunks - while True: - chunk = self.chunks_queue.get() - if chunk is None: - break + chunk: Optional[Data] + for chunk in chunks: + yield chunk + while (chunk := self.chunks_queue.get()) is not None: yield chunk with self.mutex: @@ -205,15 +209,12 @@ def put(self, frame: Frame) -> None: if self.put_in_progress: raise RuntimeError("put is already running") - if frame.opcode is Opcode.TEXT: + if frame.opcode is OP_TEXT: self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is Opcode.BINARY: + elif frame.opcode is OP_BINARY: self.decoder = None - elif frame.opcode is Opcode.CONT: - pass else: - # Ignore control frames. - return + assert frame.opcode is OP_CONT data: Data if self.decoder is not None: @@ -242,6 +243,7 @@ def put(self, frame: Frame) -> None: self.put_in_progress = True # Release the lock to allow get() to run and eventually set the event. + # Locking with put_in_progress ensures only one coroutine can get here. self.message_fetched.wait() with self.mutex: diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 825eb8797..c134b8304 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,6 +1,6 @@ import time -from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from websockets.sync.messages import * from ..utils import MS @@ -350,76 +350,6 @@ def test_get_with_timeout_times_out(self): with self.assertRaises(TimeoutError): self.assembler.get(MS) - # Test control frames - - def test_control_frame_before_message_is_ignored(self): - """get ignores control frames between messages.""" - - def putter(): - self.assembler.put(Frame(OP_PING, b"")) - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - - self.assertEqual(message, "café") - - def test_control_frame_in_fragmented_message_is_ignored(self): - """get ignores control frames within fragmented messages.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_PING, b"")) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_PONG, b"")) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - - self.assertEqual(message, b"tea") - - # Test concurrency - - def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" - with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): - self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_fails_when_get_iter_is_running(self): - """get cannot be called concurrently with get_iter.""" - with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): - self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_iter_fails_when_get_is_running(self): - """get_iter cannot be called concurrently with get.""" - with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): - list(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" - with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): - list(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently with itself.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - with self.assertRaises(RuntimeError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread - # Test termination def test_get_fails_when_interrupted_by_close(self): @@ -477,3 +407,44 @@ def test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() self.assembler.close() + + # Test (non-)concurrency + + def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_put_fails_when_put_is_running(self): + """put cannot be called concurrently with itself.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + with self.assertRaises(RuntimeError): + self.assembler.put(Frame(OP_BINARY, b"tea")) + self.assembler.get() # unblock other thread From e21811e751f3f4fef18ad13b1b6f7064be004af6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 11:12:27 +0100 Subject: [PATCH 025/109] Rename ssl_context to ssl in sync implementation. --- docs/project/changelog.rst | 14 ++++++++++- src/websockets/sync/client.py | 27 ++++++++++++-------- src/websockets/sync/server.py | 21 ++++++++++------ tests/sync/client.py | 3 ++- tests/sync/test_client.py | 47 ++++++++++++++++++++--------------- tests/sync/test_server.py | 23 +++++++++++------ 6 files changed, 89 insertions(+), 46 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 963353d0e..e288831be 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,23 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -12.1 +13.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` is renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + New features ............ diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 79af0132f..6faca7789 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -1,8 +1,9 @@ from __future__ import annotations import socket -import ssl +import ssl as ssl_module import threading +import warnings from typing import Any, Optional, Sequence, Type from ..client import ClientProtocol @@ -128,7 +129,7 @@ def connect( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, server_hostname: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, @@ -166,7 +167,7 @@ def connect( sock: Preexisting TCP socket. ``sock`` overrides the host and port from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. server_hostname: Host name for the TLS handshake. ``server_hostname`` overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. @@ -207,9 +208,14 @@ def connect( # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + wsuri = parse_uri(uri) - if not wsuri.secure and ssl_context is not None: - raise TypeError("ssl_context argument is incompatible with a ws:// URI") + if not wsuri.secure and ssl is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -259,12 +265,12 @@ def connect( # Initialize TLS wrapper and perform TLS handshake if wsuri.secure: - if ssl_context is None: - ssl_context = ssl.create_default_context() + if ssl is None: + ssl = ssl_module.create_default_context() if server_hostname is None: server_hostname = wsuri.host sock.settimeout(deadline.timeout()) - sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) # Initialize WebSocket connection @@ -318,12 +324,13 @@ def unix_connect( Args: path: File system path to the Unix socket. uri: URI of the WebSocket server. ``uri`` defaults to - ``ws://localhost/`` or, when a ``ssl_context`` is provided, to + ``ws://localhost/`` or, when a ``ssl`` is provided, to ``wss://localhost/``. """ if uri is None: - if kwargs.get("ssl_context") is None: + # Backwards compatibility: ssl used to be called ssl_context. + if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None: uri = "ws://localhost/" else: uri = "wss://localhost/" diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c19992849..fa6087d54 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -5,9 +5,10 @@ import os import selectors import socket -import ssl +import ssl as ssl_module import sys import threading +import warnings from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type @@ -268,7 +269,7 @@ def serve( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -337,7 +338,7 @@ def handler(websocket): sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. You may call :func:`socket.create_server` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` in the list if the lack of an origin is acceptable. @@ -386,6 +387,11 @@ def handler(websocket): # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + if subprotocols is not None: validate_subprotocols(subprotocols) @@ -417,8 +423,8 @@ def handler(websocket): # Initialize TLS wrapper - if ssl_context is not None: - sock = ssl_context.wrap_socket( + if ssl is not None: + sock = ssl.wrap_socket( sock, server_side=True, # Delay TLS handshake until after we set a timeout on the socket. @@ -441,9 +447,10 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: # Perform TLS handshake - if ssl_context is not None: + if ssl is not None: sock.settimeout(deadline.timeout()) - assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out + # mypy cannot figure this out + assert isinstance(sock, ssl_module.SSLSocket) sock.do_handshake() sock.settimeout(None) diff --git a/tests/sync/client.py b/tests/sync/client.py index 683893e88..bb4855c7f 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -25,7 +25,8 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): else: assert isinstance(wsuri_or_server, WebSocketServer) if secure is None: - secure = "ssl_context" in kwargs + # Backwards compatibility: ssl used to be called ssl_context. + secure = "ssl" in kwargs or "ssl_context" in kwargs protocol = "wss" if secure else "ws" host, port = wsuri_or_server.socket.getsockname() wsuri = f"{protocol}://{host}:{port}{resource_name}" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c900f3b0f..fa363debf 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server @@ -137,18 +137,18 @@ def close_connection(self, request): class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -156,17 +156,17 @@ def test_set_server_hostname_implicitly(self): def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, server_hostname="overridden", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: self[ -]signed certificate", @@ -177,15 +177,13 @@ def test_reject_invalid_server_certificate(self): def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: Hostname mismatch", ): # This hostname isn't included in the test certificate. - with run_client( - server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid" - ): + with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") @@ -212,8 +210,8 @@ class SecureUnixClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -221,23 +219,23 @@ def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") class ClientUsageErrorsTests(unittest.TestCase): - def test_ssl_context_without_secure_uri(self): - """Client rejects ssl_context when URI isn't secure.""" + def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" with self.assertRaisesRegex( TypeError, - "ssl_context argument is incompatible with a ws:// URI", + "ssl argument is incompatible with a ws:// URI", ): - connect("ws://localhost/", ssl_context=CLIENT_CONTEXT) + connect("ws://localhost/", ssl=CLIENT_CONTEXT) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" @@ -272,3 +270,12 @@ def test_unsupported_compression(self): "unsupported compression: False", ): connect("ws://localhost/", compression=False) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_client(server, ssl_context=CLIENT_CONTEXT): + pass diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9db84246..5e7e79c52 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,7 +14,7 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import ( SERVER_CONTEXT, @@ -274,20 +274,20 @@ def handler(sock, addr): class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server: + with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: with socket.create_connection(server.socket.getsockname()) as sock: self.assertEqual(sock.recv(4096), b"") def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: # Patch handler to record a reference to the thread running it. server_thread = None conn_received = threading.Event() @@ -325,8 +325,8 @@ class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -386,3 +386,12 @@ def test_shutdown(self): # Check that the server socket is closed. with self.assertRaises(OSError): server.socket.accept() + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT): + pass From 45d8de7495ea33724bf93d753d65cad932472aac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 21:06:38 +0100 Subject: [PATCH 026/109] Standardize style for testing exceptions. --- tests/legacy/test_client_server.py | 19 +++-- tests/legacy/test_http.py | 76 +++++++++++++++----- tests/sync/test_client.py | 90 ++++++++++++----------- tests/sync/test_connection.py | 54 +++++++------- tests/sync/test_server.py | 110 ++++++++++++++++------------- tests/test_server.py | 73 +++++++++++++------ 6 files changed, 265 insertions(+), 157 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c49d91b70..4a21f7cea 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1331,20 +1331,24 @@ def test_checking_origin_succeeds(self): @with_server(origins=["http://localhost"]) def test_checking_origin_fails(self): - with self.assertRaisesRegex( - InvalidHandshake, "server rejected WebSocket connection: HTTP 403" - ): + with self.assertRaises(InvalidHandshake) as raised: self.start_client(origin="http://otherhost") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) @with_server(origins=["http://localhost"]) def test_checking_origins_fails_with_multiple_headers(self): - with self.assertRaisesRegex( - InvalidHandshake, "server rejected WebSocket connection: HTTP 400" - ): + with self.assertRaises(InvalidHandshake) as raised: self.start_client( origin="http://localhost", extra_headers=[("Origin", "http://otherhost")], ) + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) @with_server(origins=[None]) @with_client() @@ -1574,8 +1578,9 @@ async def run_client(): pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: - with self.assertRaisesRegex(Exception, "BOOM"): + with self.assertRaises(Exception) as raised: self.loop.run_until_complete(run_client()) + self.assertEqual(str(raised.exception), "BOOM") # Iteration 1 self.assertEqual( diff --git a/tests/legacy/test_http.py b/tests/legacy/test_http.py index 15d53e08d..76af61122 100644 --- a/tests/legacy/test_http.py +++ b/tests/legacy/test_http.py @@ -31,30 +31,48 @@ async def test_read_request(self): async def test_read_request_empty(self): self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP request line" - ): + with self.assertRaises(EOFError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "connection closed while reading HTTP request line", + ) async def test_read_request_invalid_request_line(self): self.stream.feed_data(b"GET /\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP request line: GET /", + ) async def test_read_request_unsupported_method(self): self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP method: OPTIONS", + ) async def test_read_request_unsupported_version(self): self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) async def test_read_request_invalid_header(self): self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) async def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -73,40 +91,66 @@ async def test_read_response(self): async def test_read_response_empty(self): self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP status line" - ): + with self.assertRaises(EOFError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "connection closed while reading HTTP status line", + ) async def test_read_request_invalid_status_line(self): self.stream.feed_data(b"Hello!\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP status line: Hello!", + ) async def test_read_response_unsupported_version(self): self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) async def test_read_response_invalid_status(self): self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP status code: OMG", + ) async def test_read_response_unsupported_status(self): self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP status code: 007", + ) async def test_read_response_invalid_reason(self): self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP reason phrase: \x7f", + ) async def test_read_response_invalid_header(self): self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) async def test_header_name(self): self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index fa363debf..c403b9632 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -28,12 +28,13 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_response=remove_accept_header) as server: - with self.assertRaisesRegex( - InvalidHandshake, - "missing Sec-WebSocket-Accept header", - ): + with self.assertRaises(InvalidHandshake) as raised: with run_client(server, close_timeout=MS): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) def test_tcp_connection_fails(self): """Client fails to connect to server.""" @@ -107,15 +108,16 @@ def stall_connection(self, request): # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_request=stall_connection) as server: try: - with self.assertRaisesRegex( - TimeoutError, - "timed out during handshake", - ): + with self.assertRaises(TimeoutError) as raised: # While it shouldn't take 50ms to open a connection, this # test becomes flaky in CI when setting a smaller timeout, # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. with run_client(server, open_timeout=5 * MS): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) finally: gate.set() @@ -126,12 +128,13 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaisesRegex( - ConnectionError, - "connection closed during handshake", - ): + with self.assertRaises(ConnectionError) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) class SecureClientTests(unittest.TestCase): @@ -167,24 +170,26 @@ def test_set_server_hostname_explicitly(self): def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaisesRegex( - ssl.SSLCertVerificationError, - r"certificate verify failed: self[ -]signed certificate", - ): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. with run_client(server, secure=True): self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaisesRegex( - ssl.SSLCertVerificationError, - r"certificate verify failed: Hostname mismatch", - ): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -231,45 +236,50 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaisesRegex( - TypeError, - "ssl argument is incompatible with a ws:// URI", - ): + with self.assertRaises(TypeError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaisesRegex( - TypeError, - "missing path argument", - ): + with self.assertRaises(TypeError) as raised: unix_connect() + self.assertEqual( + str(raised.exception), + "missing path argument", + ) def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaisesRegex( - TypeError, - "path and sock arguments are incompatible", - ): + with self.assertRaises(TypeError) as raised: unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock arguments are incompatible", + ) def test_invalid_subprotocol(self): """Client rejects single value of subprotocols.""" - with self.assertRaisesRegex( - TypeError, - "subprotocols must be a list", - ): + with self.assertRaises(TypeError) as raised: connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) def test_unsupported_compression(self): """Client rejects incorrect value of compression.""" - with self.assertRaisesRegex( - ValueError, - "unsupported compression: False", - ): + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) class BackwardsCompatibilityTests(DeprecationTestCase): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index e128425d8..953c8c253 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -177,12 +177,13 @@ def test_recv_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", - ): - self.connection.recv() + ) self.remote_connection.send("") recv_thread.join() @@ -194,12 +195,13 @@ def test_recv_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", - ): - self.connection.recv() + ) self.remote_connection.send("") recv_streaming_thread.join() @@ -257,12 +259,13 @@ def test_recv_streaming_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + list(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), "cannot call recv_streaming while another thread " "is already running recv or recv_streaming", - ): - list(self.connection.recv_streaming()) + ) self.remote_connection.send("") recv_thread.join() @@ -274,12 +277,13 @@ def test_recv_streaming_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + list(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), r"cannot call recv_streaming while another thread " r"is already running recv or recv_streaming", - ): - list(self.connection.recv_streaming()) + ) self.remote_connection.send("") recv_streaming_thread.join() @@ -355,11 +359,12 @@ def fragments(): [b"\x01\x02", b"\xfe\xff"], ]: with self.subTest(message=message): - with self.assertRaisesRegex( - RuntimeError, - "cannot call send while another thread is already running send", - ): + with self.assertRaises(RuntimeError) as raised: self.connection.send(message) + self.assertEqual( + str(raised.exception), + "cannot call send while another thread is already running send", + ) exit_gate.set() send_thread.join() @@ -598,11 +603,12 @@ def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.remote_connection.protocol_mutex: # block response to ping pong_waiter = self.connection.ping("idem") - with self.assertRaisesRegex( - RuntimeError, - "already waiting for a pong with the same data", - ): + with self.assertRaises(RuntimeError) as raised: self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) self.assertTrue(pong_waiter.wait(MS)) self.connection.ping("idem") # doesn't raise an exception diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 5e7e79c52..f9f30baf1 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -41,33 +41,36 @@ def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] with run_server(process_request=remove_key_header) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 400", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: with run_client(server) as client: - with self.assertRaisesRegex( - ConnectionClosedOK, - r"received 1000 \(OK\); then sent 1000 \(OK\)", - ): + with self.assertRaises(ConnectionClosedOK) as raised: client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" with run_server(crash) as server: with run_client(server) as client: - with self.assertRaisesRegex( - ConnectionClosedError, - r"received 1011 \(internal error\); " - r"then sent 1011 \(internal error\)", - ): + with self.assertRaises(ConnectionClosedError) as raised: client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" @@ -100,12 +103,13 @@ def select_subprotocol(ws, subprotocols): raise NegotiationError with run_server(select_subprotocol=select_subprotocol) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 400", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) def test_select_subprotocol_raises_exception(self): """Server returns an error if select_subprotocol raises an exception.""" @@ -114,12 +118,13 @@ def select_subprotocol(ws, subprotocols): raise RuntimeError with run_server(select_subprotocol=select_subprotocol) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_process_request(self): """Server runs process_request before processing the handshake.""" @@ -139,12 +144,13 @@ def process_request(ws, request): return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_server(process_request=process_request) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 403", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" @@ -153,12 +159,13 @@ def process_request(ws, request): raise RuntimeError with run_server(process_request=process_request) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_process_response(self): """Server runs process_response after processing the handshake.""" @@ -193,12 +200,13 @@ def process_response(ws, request, response): raise RuntimeError with run_server(process_response=process_response) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_override_server(self): """Server can override Server header with server_header.""" @@ -334,37 +342,41 @@ def test_connection(self): class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaisesRegex( - TypeError, - "missing path argument", - ): + with self.assertRaises(TypeError) as raised: unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "missing path argument", + ) def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaisesRegex( - TypeError, - "path and sock arguments are incompatible", - ): + with self.assertRaises(TypeError) as raised: unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock arguments are incompatible", + ) def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" - with self.assertRaisesRegex( - TypeError, - "subprotocols must be a list", - ): + with self.assertRaises(TypeError) as raised: serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" - with self.assertRaisesRegex( - ValueError, - "unsupported compression: False", - ): + with self.assertRaises(ValueError) as raised: serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) class WebSocketServerTests(unittest.TestCase): diff --git a/tests/test_server.py b/tests/test_server.py index b6f5e3568..e4460dcba 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -213,7 +213,10 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "BOOM") + self.assertEqual( + str(raised.exception), + "BOOM", + ) def test_missing_connection(self): server = ServerProtocol() @@ -225,7 +228,10 @@ def test_missing_connection(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Connection header") + self.assertEqual( + str(raised.exception), + "missing Connection header", + ) def test_invalid_connection(self): server = ServerProtocol() @@ -238,7 +244,10 @@ def test_invalid_connection(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "invalid Connection header: close") + self.assertEqual( + str(raised.exception), + "invalid Connection header: close", + ) def test_missing_upgrade(self): server = ServerProtocol() @@ -250,7 +259,10 @@ def test_missing_upgrade(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Upgrade header") + self.assertEqual( + str(raised.exception), + "missing Upgrade header", + ) def test_invalid_upgrade(self): server = ServerProtocol() @@ -263,7 +275,10 @@ def test_invalid_upgrade(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + self.assertEqual( + str(raised.exception), + "invalid Upgrade header: h2c", + ) def test_missing_key(self): server = ServerProtocol() @@ -274,7 +289,10 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Key header", + ) def test_multiple_key(self): server = ServerProtocol() @@ -302,7 +320,8 @@ def test_invalid_key(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" + str(raised.exception), + "invalid Sec-WebSocket-Key header: not Base64 data!", ) def test_truncated_key(self): @@ -318,7 +337,8 @@ def test_truncated_key(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" + str(raised.exception), + f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): @@ -330,7 +350,10 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Version header", + ) def test_multiple_version(self): server = ServerProtocol() @@ -358,7 +381,8 @@ def test_invalid_version(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Sec-WebSocket-Version header: 11" + str(raised.exception), + "invalid Sec-WebSocket-Version header: 11", ) def test_no_origin(self): @@ -369,7 +393,10 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Origin header") + self.assertEqual( + str(raised.exception), + "missing Origin header", + ) def test_origin(self): server = ServerProtocol(origins=["https://example.com"]) @@ -390,7 +417,8 @@ def test_unexpected_origin(self): with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Origin header: https://other.example.com" + str(raised.exception), + "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): @@ -435,7 +463,8 @@ def test_unsupported_origin(self): with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Origin header: https://original.example.com" + str(raised.exception), + "invalid Origin header: https://original.example.com", ) def test_no_origin_accepted(self): @@ -574,11 +603,12 @@ def test_no_subprotocol(self): response = server.accept(request) self.assertEqual(response.status_code, 400) - with self.assertRaisesRegex( - NegotiationError, - r"missing subprotocol", - ): + with self.assertRaises(NegotiationError) as raised: raise server.handshake_exc + self.assertEqual( + str(raised.exception), + "missing subprotocol", + ) def test_subprotocol(self): server = ServerProtocol(subprotocols=["chat"]) @@ -628,11 +658,12 @@ def test_unsupported_subprotocol(self): response = server.accept(request) self.assertEqual(response.status_code, 400) - with self.assertRaisesRegex( - NegotiationError, - r"invalid subprotocol; expected one of superchat, chat", - ): + with self.assertRaises(NegotiationError) as raised: raise server.handshake_exc + self.assertEqual( + str(raised.exception), + "invalid subprotocol; expected one of superchat, chat", + ) @staticmethod def optional_chat(protocol, subprotocols): From c06e44d214ca3650b12fbbcaa1a0266dae9432d0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 15:20:39 +0100 Subject: [PATCH 027/109] Support closing while sending a fragmented message. On one hand, it will close the connection with an unfinished fragmented message, which is less than ideal. On the other hand, RFC 6455 implies that it should be legal and it's probably best to let users close the connection if they want to close the connection (rather than force them to call fail() instead). --- src/websockets/protocol.py | 6 ++---- tests/test_protocol.py | 36 ++++++++++++------------------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6851f3b1f..4650cf16d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -354,12 +354,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: reason: close reason. Raises: - ProtocolError: If a fragmented message is being sent, if the code - isn't valid, or if a reason is provided without a code + ProtocolError: If the code isn't valid or if a reason is provided + without a code. """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a64172b53..a1661231f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1364,26 +1364,22 @@ def test_client_send_close_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.send_close(CloseCode.GOING_AWAY) - self.assertEqual(str(raised.exception), "expected a continuation frame") - client.send_continuation(b"Eggs", fin=True) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.send_close() + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + with self.assertRaises(InvalidState): + client.send_continuation(b"Eggs", fin=True) def test_server_send_close_in_fragmented_message(self): - server = Protocol(CLIENT) + server = Protocol(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.send_close(CloseCode.NORMAL_CLOSURE) - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.send_close() + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) + self.assertIs(server.state, CLOSING) + with self.assertRaises(InvalidState): + server.send_continuation(b"Eggs", fin=True) def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) @@ -1392,10 +1388,6 @@ def test_client_receive_close_in_fragmented_message(self): client, Frame(OP_TEXT, b"Spam", fin=False), ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. client.receive_data(b"\x88\x02\x03\xe8") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incomplete fragmented message") @@ -1410,10 +1402,6 @@ def test_server_receive_close_in_fragmented_message(self): server, Frame(OP_TEXT, b"Spam", fin=False), ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incomplete fragmented message") From d28b71dd297da99aad9d644a2f4721707e464707 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 15:41:42 +0100 Subject: [PATCH 028/109] Upgrade to the latest version of black. --- src/websockets/datastructures.py | 6 ++---- tests/test_protocol.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index c2a5acfee..aef11bf23 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -172,11 +172,9 @@ class SupportsKeysAndGetItem(Protocol): # pragma: no cover """ - def keys(self) -> Iterable[str]: - ... + def keys(self) -> Iterable[str]: ... - def __getitem__(self, key: str) -> str: - ... + def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a1661231f..b53c8a1ec 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -37,12 +37,14 @@ def assertFrameSent(self, connection, frame, eof=False): """ frames_sent = [ - None - if write is SEND_EOF - else self.parse( - write, - mask=connection.side is CLIENT, - extensions=connection.extensions, + ( + None + if write is SEND_EOF + else self.parse( + write, + mask=connection.side is CLIENT, + extensions=connection.extensions, + ) ) for write in connection.data_to_send() ] From 705dc85e87bb1184d926ab95a591097780c4b855 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 14:55:44 +0100 Subject: [PATCH 029/109] Allow sending ping and pong after close. Fix #1429. --- src/websockets/protocol.py | 28 ++++++++-- tests/test_protocol.py | 111 +++++++++++++++++++++++++++---------- 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 4650cf16d..0b36202e5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -297,6 +297,8 @@ def send_continuation(self, data: bytes, fin: bool) -> None: """ if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_CONT, data, fin)) @@ -318,6 +320,8 @@ def send_text(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_TEXT, data, fin)) @@ -339,6 +343,8 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) @@ -358,6 +364,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: without a code. """ + # While RFC 6455 doesn't rule out sending more than one close Frame, + # websockets is conservative in what it sends and doesn't allow that. + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") @@ -383,6 +393,9 @@ def send_ping(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: @@ -396,6 +409,9 @@ def send_pong(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: @@ -675,6 +691,8 @@ def recv_frame(self, frame: Frame) -> None: # 1.4. Closing Handshake: "after receiving a control frame # indicating the connection should be closed, a peer discards # any further data received." + # RFC 6455 allows reading Ping and Pong frames after a Close frame. + # However, that doesn't seem useful; websockets doesn't support it. self.parser = self.discard() next(self.parser) # start coroutine @@ -687,15 +705,13 @@ def recv_frame(self, frame: Frame) -> None: # Private methods for sending events. def send_frame(self, frame: Frame) -> None: - if self.state is not OPEN: - raise InvalidState( - f"cannot write to a WebSocket in the {self.state.name} state" - ) - if self.debug: self.logger.debug("> %s", frame) self.writes.append( - frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + frame.serialize( + mask=self.side is CLIENT, + extensions=self.extensions, + ) ) def send_eof(self) -> None: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index b53c8a1ec..1d5dab7a0 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -465,15 +465,17 @@ def test_client_sends_text_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) @@ -679,15 +681,17 @@ def test_client_sends_binary_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) @@ -956,6 +960,37 @@ def test_server_receives_close_with_non_utf8_reason(self): ) self.assertIs(server.state, CLOSING) + def test_client_sends_close_twice(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_server_sends_close_twice(self): + server = Protocol(SERVER) + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_client_sends_close_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_close_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closed") + class PingTests(ProtocolTestCase): """ @@ -1072,35 +1107,23 @@ def test_client_sends_ping_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - server.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1109,9 +1132,24 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_ping_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_ping_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class PongTests(ProtocolTestCase): """ @@ -1212,23 +1250,23 @@ def test_client_sends_pong_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - server.send_pong(b"") + server.send_pong(b"") + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1237,9 +1275,24 @@ def test_server_receives_pong_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_pong_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_pong_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class FailTests(ProtocolTestCase): """ @@ -1370,8 +1423,9 @@ def test_client_send_close_in_fragmented_message(self): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_server_send_close_in_fragmented_message(self): server = Protocol(SERVER) @@ -1380,8 +1434,9 @@ def test_server_send_close_in_fragmented_message(self): server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) From 96fddaf49b5a5af1f3215076bf2a73dfb4b72ca1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Jan 2024 16:47:51 +0100 Subject: [PATCH 030/109] Wording and line wrapping fixes in changelog. --- docs/project/changelog.rst | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e288831be..dc84a5ae2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,7 +34,8 @@ Backwards-incompatible changes .............................. .. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` - and :func:`~sync.server.serve` is renamed to ``ssl``. + and :func:`~sync.server.serve` in the :mod:`threading` implementation is + renamed to ``ssl``. :class: note This aligns the API of the :mod:`threading` implementation with the @@ -140,7 +141,8 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. -.. admonition:: :func:`~server.serve` times out on the opening handshake after 10 seconds by default. +.. admonition:: :func:`~server.serve` times out on the opening handshake after + 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -149,7 +151,7 @@ Backwards-incompatible changes New features ............ -.. admonition:: websockets 11.0 introduces a implementation on top of :mod:`threading`. +.. admonition:: websockets 11.0 introduces a :mod:`threading` implementation. :class: important It may be more convenient if you don't need to manage many connections and @@ -211,7 +213,8 @@ Improvements Backwards-incompatible changes .............................. -.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. +.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and + :class:`~http11.Response` is deprecated. :class: note Use the ``handshake_exc`` attribute of :class:`~server.ServerProtocol` and @@ -565,11 +568,11 @@ Backwards-incompatible changes .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note - If you're passing a ``process_request`` argument to - :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if - you're overriding + If you're passing a ``process_request`` argument to :func:`~server.serve` + or :class:`~server.WebSocketServerProtocol`, or if you're overriding :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. Previously, both were supported. + define it with ``async def`` instead of ``def``. Previously, both were + supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. From 3b7fa7673bf6a96a5e9debd7dcfa65e04f85efbb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Jan 2024 16:49:13 +0100 Subject: [PATCH 031/109] Enable deprecation for second argument of handlers. --- docs/project/changelog.rst | 26 ++++++++++++++++++++------ src/websockets/legacy/server.py | 4 +--- tests/legacy/test_client_server.py | 6 ++---- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index dc84a5ae2..fd186a5fc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,21 @@ Backwards-incompatible changes For backwards compatibility, ``ssl_context`` is still supported. +.. admonition:: Receiving the request path in the second parameter of connection + handlers is deprecated. + :class: note + + If you implemented the connection handler of a server as:: + + async def handler(request, path): + ... + + You should switch to the recommended pattern since 10.1:: + + async def handler(request): + path = request.path # only if handler() uses the path argument + ... + New features ............ @@ -257,20 +272,19 @@ New features * Added a tutorial. -* Made the second parameter of connection handlers optional. It will be - deprecated in the next major release. The request path is available in - the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of - the first argument. +* Made the second parameter of connection handlers optional. The request path is + available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` + attribute of the first argument. If you implemented the connection handler of a server as:: async def handler(request, path): ... - You should replace it by:: + You should replace it with:: async def handler(request): - path = request.path # if handler() uses the path argument + path = request.path # only if handler() uses the path argument ... * Added ``python -m websockets --version``. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e8cf8220f..4659ed9a6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1168,9 +1168,7 @@ def remove_path_argument( pass else: # ws_handler accepts two arguments; activate backwards compatibility. - - # Enable deprecation warning and announce deprecation in 11.0. - # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + warnings.warn("remove second argument of ws_handler", DeprecationWarning) async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: return await cast( diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 4a21f7cea..51a74734b 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -480,8 +480,7 @@ async def handler_with_path(ws, path): with self.temp_server( handler=handler_with_path, - # Enable deprecation warning and announce deprecation in 11.0. - # deprecation_warnings=["remove second argument of ws_handler"], + deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( @@ -497,8 +496,7 @@ async def handler_with_path(ws, path, extra): with self.temp_server( handler=bound_handler_with_path, - # Enable deprecation warning and announce deprecation in 11.0. - # deprecation_warnings=["remove second argument of ws_handler"], + deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( From aa33161cd9498bfca39d64fc36319bc1fbce68f2 Mon Sep 17 00:00:00 2001 From: MtkN1 <51289448+MtkN1@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:14:22 +0900 Subject: [PATCH 032/109] Fix wrong RFC number --- src/websockets/legacy/client.py | 2 +- tests/test_protocol.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b85d22867..255696580 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -599,7 +599,7 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: yield protocol except Exception: # Add a random initial delay between 0 and 5 seconds. - # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1d5dab7a0..e1527525b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -714,7 +714,7 @@ class CloseTests(ProtocolTestCase): """ Test close frames. - See RFC 6544: + See RFC 6455: 5.5.1. Close 7.1.6. The WebSocket Connection Close Reason @@ -994,7 +994,7 @@ def test_server_sends_close_after_connection_is_closed(self): class PingTests(ProtocolTestCase): """ - Test ping. See 5.5.2. Ping in RFC 6544. + Test ping. See 5.5.2. Ping in RFC 6455. """ @@ -1153,7 +1153,7 @@ def test_server_sends_ping_after_connection_is_closed(self): class PongTests(ProtocolTestCase): """ - Test pong frames. See 5.5.3. Pong in RFC 6544. + Test pong frames. See 5.5.3. Pong in RFC 6455. """ @@ -1298,7 +1298,7 @@ class FailTests(ProtocolTestCase): """ Test failing the connection. - See 7.1.7. Fail the WebSocket Connection in RFC 6544. + See 7.1.7. Fail the WebSocket Connection in RFC 6455. """ @@ -1321,7 +1321,7 @@ class FragmentationTests(ProtocolTestCase): """ Test message fragmentation. - See 5.4. Fragmentation in RFC 6544. + See 5.4. Fragmentation in RFC 6455. """ From 87f58c7190025521e5dc380945b0cc536169bd0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 17:25:00 +0100 Subject: [PATCH 033/109] Fix make clean. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cf3b53393..bf8c8dc58 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,6 @@ build: python setup.py build_ext --inplace clean: - find . -name '*.pyc' -o -name '*.so' -delete + find . -name '*.pyc' -delete -o -name '*.so' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 9b5273c68323dd63598dfcba97339f03f61d3d0f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 17:34:57 +0100 Subject: [PATCH 034/109] Move CLIENT/SERVER_CONTEXT to utils. Then we can reuse them for testing other implementations. --- tests/sync/client.py | 8 -------- tests/sync/server.py | 17 ----------------- tests/sync/test_client.py | 12 +++++++++--- tests/sync/test_server.py | 11 ++++++++--- tests/utils.py | 18 ++++++++++++++++++ 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/sync/client.py b/tests/sync/client.py index bb4855c7f..72eb5b8d2 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,23 +1,15 @@ import contextlib -import ssl from websockets.sync.client import * from websockets.sync.server import WebSocketServer -from ..utils import CERTIFICATE - __all__ = [ - "CLIENT_CONTEXT", "run_client", "run_unix_client", ] -CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - - @contextlib.contextmanager def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): diff --git a/tests/sync/server.py b/tests/sync/server.py index a9a77438c..10ab789c2 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,25 +1,8 @@ import contextlib -import ssl import threading from websockets.sync.server import * -from ..utils import CERTIFICATE - - -SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) - -# Work around https://github.com/openssl/openssl/issues/7967 - -# This bug causes connect() to hang in tests for the client. Including this -# workaround acknowledges that the issue could happen outside of the test suite. - -# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it -# happens, we can look for a library-level fix, but it won't be easy. - -SERVER_CONTEXT.num_tickets = 0 - def crash(ws): raise RuntimeError diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c403b9632..bebf68aa5 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,9 +7,15 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, DeprecationTestCase, temp_unix_socket_path -from .client import CLIENT_CONTEXT, run_client, run_unix_client -from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + DeprecationTestCase, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import do_nothing, run_server, run_unix_server class ClientTests(unittest.TestCase): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9f30baf1..490a3f63e 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,10 +14,15 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, DeprecationTestCase, temp_unix_socket_path -from .client import CLIENT_CONTEXT, run_client, run_unix_client -from .server import ( +from ..utils import ( + CLIENT_CONTEXT, + MS, SERVER_CONTEXT, + DeprecationTestCase, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( EvalShellMixin, crash, do_nothing, diff --git a/tests/utils.py b/tests/utils.py index 2937a2f15..bd3b61d7b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import pathlib import platform +import ssl import tempfile import time import unittest @@ -17,6 +18,23 @@ CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) + + +SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +SERVER_CONTEXT.load_cert_chain(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +SERVER_CONTEXT.num_tickets = 0 + DATE = email.utils.formatdate(usegmt=True) From de768cf65e7e2b1a3b67854fb9e08816a5ff7050 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:08:45 +0100 Subject: [PATCH 035/109] Improve tests for sync implementation. --- tests/sync/server.py | 8 +++--- tests/sync/test_client.py | 59 ++++++++++++++++++++------------------- tests/sync/test_server.py | 36 ++++++++++++------------ 3 files changed, 53 insertions(+), 50 deletions(-) diff --git a/tests/sync/server.py b/tests/sync/server.py index 10ab789c2..d5295ccd8 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -25,8 +25,8 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): - with serve(ws_handler, host, port, **kwargs) as server: +def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: @@ -37,8 +37,8 @@ def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): @contextlib.contextmanager -def run_unix_server(path, ws_handler=eval_shell, **kwargs): - with unix_serve(ws_handler, path, **kwargs) as server: +def run_unix_server(path, handler=eval_shell, **kwargs): + with unix_serve(handler, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index bebf68aa5..03f4e972f 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -3,7 +3,7 @@ import threading import unittest -from websockets.exceptions import InvalidHandshake +from websockets.exceptions import InvalidHandshake, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -25,29 +25,6 @@ def test_connection(self): with run_client(server) as client: self.assertEqual(client.protocol.state.name, "OPEN") - def test_connection_fails(self): - """Client connects to server but the handshake fails.""" - - def remove_accept_header(self, request, response): - del response.headers["Sec-WebSocket-Accept"] - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_response=remove_accept_header) as server: - with self.assertRaises(InvalidHandshake) as raised: - with run_client(server, close_timeout=MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "missing Sec-WebSocket-Accept header", - ) - - def test_tcp_connection_fails(self): - """Client fails to connect to server.""" - with self.assertRaises(OSError): - with run_client("ws://localhost:54321"): # invalid port - self.fail("did not raise") - def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: @@ -103,6 +80,35 @@ def create_connection(*args, **kwargs): with run_client(server, create_connection=create_connection) as client: self.assertTrue(client.create_connection_ran) + def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_response=remove_accept_header) as server: + with self.assertRaises(InvalidHandshake) as raised: + with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" gate = threading.Event() @@ -115,10 +121,7 @@ def stall_connection(self, request): with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - # While it shouldn't take 50ms to open a connection, this - # test becomes flaky in CI when setting a smaller timeout, - # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. - with run_client(server, open_timeout=5 * MS): + with run_client(server, open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 490a3f63e..9d509a5c4 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -39,21 +39,6 @@ def test_connection(self): with run_client(server) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") - def test_connection_fails(self): - """Server receives connection from client but the handshake fails.""" - - def remove_key_header(self, request): - del request.headers["Sec-WebSocket-Key"] - - with run_server(process_request=remove_key_header) as server: - with self.assertRaises(InvalidStatus) as raised: - with run_client(server): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 400", - ) - def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: @@ -81,8 +66,8 @@ def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): - # Build WebSocket URI to ensure we connect to the right socket. - with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: + uri = "ws://{}:{}/".format(*sock.getsockname()) + with run_client(uri) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): @@ -185,7 +170,7 @@ def process_response(ws, request, response): self.assertEval(client, "ws.process_response_ran", "True") def test_process_response_override_response(self): - """Server runs process_response after processing the handshake.""" + """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): headers = response.headers.copy() @@ -253,6 +238,21 @@ def create_connection(*args, **kwargs): with run_client(server) as client: self.assertEval(client, "ws.create_connection_ran", "True") + def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" with run_server(open_timeout=MS) as server: From 50b6d20d7a652d39cffc7aea9f8c0abc88fb8f37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:19:24 +0100 Subject: [PATCH 036/109] Various cleanups in sync implementation. --- src/websockets/sync/client.py | 9 +++-- src/websockets/sync/connection.py | 58 +++++++++++++++---------------- src/websockets/sync/server.py | 27 +++++++------- 3 files changed, 45 insertions(+), 49 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 6faca7789..0bb7a76fd 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -25,7 +25,7 @@ class ClientConnection(Connection): """ - Threaded implementation of a WebSocket client connection. + :mod:`threading` implementation of a WebSocket client connection. :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -157,7 +157,7 @@ def connect( :func:`connect` may be used as a context manager:: - async with websockets.sync.client.connect(...) as websocket: + with websockets.sync.client.connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -273,19 +273,18 @@ def connect( sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ClientProtocol( wsuri, origin=origin, extensions=extensions, subprotocols=subprotocols, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection connection = create_connection( sock, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 62aa17ffd..6ac40cd7c 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -21,12 +21,10 @@ __all__ = ["Connection"] -logger = logging.getLogger(__name__) - class Connection: """ - Threaded implementation of a WebSocket connection. + :mod:`threading` implementation of a WebSocket connection. :class:`Connection` provides APIs shared between WebSocket servers and clients. @@ -82,7 +80,7 @@ def __init__( self.close_deadline: Optional[Deadline] = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, threading.Event] = {} + self.ping_waiters: Dict[bytes, threading.Event] = {} # Receiving events from the socket. self.recv_events_thread = threading.Thread(target=self.recv_events) @@ -90,7 +88,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_events_exc: Optional[BaseException] = None + self.recv_exc: Optional[BaseException] = None # Public attributes @@ -198,7 +196,7 @@ def recv(self, timeout: Optional[float] = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv while another thread " @@ -229,9 +227,10 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + for frame in self.recv_messages.get_iter(): + yield frame except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv_streaming while another thread " @@ -273,7 +272,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If a connection is busy sending a fragmented message. + RuntimeError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -449,15 +448,15 @@ def ping(self, data: Optional[Data] = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pings: + if data in self.ping_waiters: raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pings: + while data is None or data in self.ping_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pings[data] = pong_waiter + self.ping_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -504,22 +503,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pings: + if data not in self.ping_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pings.items(): + for ping_id, ping in self.ping_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pings. + # Remove acknowledged pings from self.ping_waiters. for ping_id in ping_ids: - del self.pings[ping_id] + del self.ping_waiters[ping_id] def recv_events(self) -> None: """ @@ -541,10 +540,10 @@ def recv_events(self) -> None: self.logger.debug("error while receiving data", exc_info=True) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. - # In that case, send_context() already set recv_events_exc. - # Calling set_recv_events_exc() avoids overwriting it. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if data == b"": @@ -552,7 +551,7 @@ def recv_events(self) -> None: # Acquire the connection lock. with self.protocol_mutex: - # Feed incoming data to the connection. + # Feed incoming data to the protocol. self.protocol.receive_data(data) # This isn't expected to raise an exception. @@ -568,7 +567,7 @@ def recv_events(self) -> None: # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() # returns above but before send_data() calls send(). - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if self.protocol.close_expected(): @@ -595,7 +594,7 @@ def recv_events(self) -> None: # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. with self.protocol_mutex: - # Feed the end of the data stream to the connection. + # Feed the end of the data stream to the protocol. self.protocol.receive_eof() # This isn't expected to generate events. @@ -609,7 +608,7 @@ def recv_events(self) -> None: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) # We don't know where we crashed. Force protocol state to CLOSED. self.protocol.state = CLOSED finally: @@ -668,7 +667,6 @@ def send_context( wait_for_close = True # If the connection is expected to close soon, set the # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN # (or CONNECTING) and we didn't release protocol_mutex, # it is certain that self.close_deadline is still None. @@ -710,11 +708,11 @@ def send_context( # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") - # Set recv_events_exc before closing the socket in order to get + # Set recv_exc before closing the socket in order to get # proper exception reporting. raise_close_exc = True with self.protocol_mutex: - self.set_recv_events_exc(original_exc) + self.set_recv_exc(original_exc) # If an error occurred, close the socket to terminate the connection and # raise an exception. @@ -745,16 +743,16 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: Optional[BaseException]) -> None: """ - Set recv_events_exc, if not set yet. + Set recv_exc, if not set yet. This method requires holding protocol_mutex. """ assert self.protocol_mutex.locked() - if self.recv_events_exc is None: - self.recv_events_exc = exc + if self.recv_exc is None: + self.recv_exc = exc def close_socket(self) -> None: """ diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fa6087d54..a070edf18 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -30,7 +30,7 @@ class ServerConnection(Connection): """ - Threaded implementation of a WebSocket server connection. + :mod:`threading` implementation of a WebSocket server connection. :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -188,6 +188,8 @@ class WebSocketServer: handler: Handler for one connection. Receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. """ @@ -311,16 +313,16 @@ def serve( Whenever a client connects, the server creates a :class:`ServerConnection`, performs the opening handshake, and delegates to the ``handler``. - The handler receives a :class:`ServerConnection` instance, which you can use - to send and receive messages. + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - :class:`WebSocketServer` mirrors the API of + This function returns a :class:`WebSocketServer` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call the :meth:`~WebSocketServer.serve_forever` - method to serve requests:: + that it will be closed and call :meth:`~WebSocketServer.serve_forever` to + serve requests:: def handler(websocket): ... @@ -454,15 +456,13 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.do_handshake() sock.settimeout(None) - # Create a closure so that select_subprotocol has access to self. - + # Create a closure to give select_subprotocol access to connection. protocol_select_subprotocol: Optional[ Callable[ [ServerProtocol, Sequence[Subprotocol]], Optional[Subprotocol], ] ] = None - if select_subprotocol is not None: def protocol_select_subprotocol( @@ -475,19 +475,18 @@ def protocol_select_subprotocol( assert protocol is connection.protocol return select_subprotocol(connection, subprotocols) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ServerProtocol( origins=origins, extensions=extensions, subprotocols=subprotocols, select_subprotocol=protocol_select_subprotocol, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection assert create_connection is not None # help mypy connection = create_connection( @@ -522,7 +521,7 @@ def protocol_select_subprotocol( def unix_serve( - handler: Callable[[ServerConnection], Any], + handler: Callable[[ServerConnection], None], path: Optional[str] = None, **kwargs: Any, ) -> WebSocketServer: @@ -541,4 +540,4 @@ def unix_serve( path: File system path to the Unix socket. """ - return serve(handler, path=path, unix=True, **kwargs) + return serve(handler, unix=True, path=path, **kwargs) From e217458ef8b692e45ca6f66c5aeb7fad0aee97ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:20:48 +0100 Subject: [PATCH 037/109] Small cleanups in legacy implementation. --- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 255696580..e5da8b13a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -60,7 +60,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` coroutines for receiving and sending messages. - It supports asynchronous iteration to receive incoming messages:: + It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4659ed9a6..0f3c1c150 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -649,9 +649,7 @@ class WebSocketServer: """ WebSocket server returned by :func:`serve`. - This class provides the same interface as :class:`~asyncio.Server`, - notably the :meth:`~asyncio.Server.close` - and :meth:`~asyncio.Server.wait_closed` methods. + This class mirrors the API of :class:`~asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. From 5f24866bfeefbe561fa76f7e5a494996d95a2757 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 08:48:42 +0200 Subject: [PATCH 038/109] Always mark background threads as daemon. Fix #1455. --- src/websockets/sync/connection.py | 9 +++++++-- src/websockets/sync/server.py | 3 +++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 6ac40cd7c..b41202dc9 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -82,8 +82,13 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} - # Receiving events from the socket. - self.recv_events_thread = threading.Thread(target=self.recv_events) + # Receiving events from the socket. This thread explicitly is marked as + # to support creating a connection in a non-daemon thread then using it + # in a daemon thread; this shouldn't block the intpreter from exiting. + self.recv_events_thread = threading.Thread( + target=self.recv_events, + daemon=True, + ) self.recv_events_thread.start() # Exception raised in recv_events, to be chained to ConnectionClosed diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index a070edf18..fd4f5d3bd 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -233,6 +233,9 @@ def serve_forever(self) -> None: sock, addr = self.socket.accept() except OSError: break + # Since there isn't a mechanism for tracking connections and waiting + # for them to terminate, we cannot use daemon threads, or else all + # connections would be terminate brutally when closing the server. thread = threading.Thread(target=self.handler, args=(sock, addr)) thread.start() From 2774fabc13f09311dec345cc8513aa7b93200b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Tue, 16 Apr 2024 16:52:01 +0200 Subject: [PATCH 039/109] docs(nginx): Fix a typo --- docs/howto/nginx.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index 30545fbc7..ff42c3c2b 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -17,9 +17,9 @@ Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/nginx/app.py :emphasize-lines: 21,23 -We'd like to nginx to connect to websockets servers via Unix sockets in order -to avoid the overhead of TCP for communicating between processes running in -the same OS. +We'd like nginx to connect to websockets servers via Unix sockets in order to +avoid the overhead of TCP for communicating between processes running in the +same OS. We start the app with :func:`~websockets.server.unix_serve`. Each server process listens on a different socket thanks to an environment variable set From 0fdc694a980ede0e91286ea5ea1d4f9c62bb42fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 08:55:36 +0200 Subject: [PATCH 040/109] Make it easy to monkey-patch length of frames repr. Fix #1451. --- src/websockets/frames.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 201bc5068..862eef3aa 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -146,6 +146,9 @@ class Frame: rsv2: bool = False rsv3: bool = False + # Monkey-patch if you want to see more in logs. Should be a multiple of 3. + MAX_LOG = 75 + def __str__(self) -> str: """ Return a human-readable representation of a frame. @@ -163,8 +166,9 @@ def __str__(self) -> str: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data - if len(binary) > 25: - binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) + if len(binary) > self.MAX_LOG // 3: + cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: data = str(Close.parse(self.data)) @@ -179,15 +183,17 @@ def __str__(self) -> str: coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data - if len(binary) > 25: - binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) + if len(binary) > self.MAX_LOG // 3: + cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" - if len(data) > 75: - data = data[:48] + "..." + data[-24:] + if len(data) > self.MAX_LOG: + cut = self.MAX_LOG // 3 - 1 # by default cut = 24 + data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) From f0398141d2efd28f64d8e1d6d9adc179a9e5e334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 19:41:14 +0200 Subject: [PATCH 041/109] Bump asyncio_timeout to 4.0.3. This makes type checking pass again. --- src/websockets/legacy/async_timeout.py | 39 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py index 8264094f5..6ffa89969 100644 --- a/src/websockets/legacy/async_timeout.py +++ b/src/websockets/legacy/async_timeout.py @@ -9,12 +9,12 @@ from typing import Optional, Type -# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py -# Licensed under the Python Software Foundation License (PSF-2.0) - if sys.version_info >= (3, 11): from typing import final else: + # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + # Licensed under the Python Software Foundation License (PSF-2.0) + # @final exists in 3.8+, but we backport it for all versions # before 3.11 to keep support for the __final__ attribute. # See https://bugs.python.org/issue46342 @@ -49,10 +49,21 @@ class Other(Leaf): # Error reported by type checker pass return f + # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + + +if sys.version_info >= (3, 11): + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + task.uncancel() + +else: + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + pass -# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py -__version__ = "4.0.2" +__version__ = "4.0.3" __all__ = ("timeout", "timeout_at", "Timeout") @@ -124,7 +135,7 @@ class Timeout: # The purpose is to time out as soon as possible # without waiting for the next await expression. - __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") def __init__( self, deadline: Optional[float], loop: asyncio.AbstractEventLoop @@ -132,6 +143,7 @@ def __init__( self._loop = loop self._state = _State.INIT + self._task: Optional["asyncio.Task[object]"] = None self._timeout_handler = None # type: Optional[asyncio.Handle] if deadline is None: self._deadline = None # type: Optional[float] @@ -187,6 +199,7 @@ def reject(self) -> None: self._reject() def _reject(self) -> None: + self._task = None if self._timeout_handler is not None: self._timeout_handler.cancel() self._timeout_handler = None @@ -234,11 +247,11 @@ def _reschedule(self) -> None: if self._timeout_handler is not None: self._timeout_handler.cancel() - task = asyncio.current_task() + self._task = asyncio.current_task() if deadline <= now: - self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + self._timeout_handler = self._loop.call_soon(self._on_timeout) else: - self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) def _do_enter(self) -> None: if self._state != _State.INIT: @@ -248,15 +261,19 @@ def _do_enter(self) -> None: def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + assert self._task is not None + _uncancel_task(self._task) self._timeout_handler = None + self._task = None raise asyncio.TimeoutError # timeout has not expired self._state = _State.EXIT self._reject() return None - def _on_timeout(self, task: "asyncio.Task[None]") -> None: - task.cancel() + def _on_timeout(self) -> None: + assert self._task is not None + self._task.cancel() self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None From 33997631a04320a5f8d57fac0f2645dc2d654c29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 08:35:58 +0200 Subject: [PATCH 042/109] Update ruff. --- Makefile | 2 +- pyproject.toml | 4 ++-- tox.ini | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index bf8c8dc58..dacfe2a0b 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ default: style types tests style: black src tests - ruff --fix src tests + ruff check --fix src tests types: mypy --strict src diff --git a/pyproject.toml b/pyproject.toml index c4c5412c5..2367849ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ exclude_lines = [ "@unittest.skip", ] -[tool.ruff] +[tool.ruff.lint] select = [ "E", # pycodestyle "F", # Pyflakes @@ -82,6 +82,6 @@ ignore = [ "F405", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 diff --git a/tox.ini b/tox.ini index 538b638d9..b0e4a5931 100644 --- a/tox.ini +++ b/tox.ini @@ -32,7 +32,7 @@ commands = black --check src tests deps = black [testenv:ruff] -commands = ruff src tests +commands = ruff check src tests deps = ruff [testenv:mypy] From 2d195baaa632efd9fb87f09813d01af28464eb8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:01:24 +0200 Subject: [PATCH 043/109] Don't run tests on Python 3.7. Forgotten in 1bf73423. --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index b0e4a5931..06003c85b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] envlist = - py37 py38 py39 py310 From 7f402303fe1703767d9236494aacc3f197fbc708 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:02:07 +0200 Subject: [PATCH 044/109] Switch from typing.Optional to | None. --- src/websockets/client.py | 16 +-- src/websockets/exceptions.py | 19 ++- src/websockets/extensions/base.py | 4 +- .../extensions/permessage_deflate.py | 32 ++--- src/websockets/frames.py | 8 +- src/websockets/headers.py | 6 +- src/websockets/http11.py | 14 +- src/websockets/imports.py | 6 +- src/websockets/legacy/auth.py | 18 +-- src/websockets/legacy/client.py | 77 +++++----- src/websockets/legacy/framing.py | 8 +- src/websockets/legacy/protocol.py | 65 ++++----- src/websockets/legacy/server.py | 135 +++++++++--------- src/websockets/protocol.py | 28 ++-- src/websockets/server.py | 35 ++--- src/websockets/sync/client.py | 44 +++--- src/websockets/sync/connection.py | 28 ++-- src/websockets/sync/messages.py | 12 +- src/websockets/sync/server.py | 92 ++++++------ src/websockets/sync/utils.py | 7 +- src/websockets/typing.py | 2 +- src/websockets/uri.py | 8 +- 22 files changed, 334 insertions(+), 330 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 633b1960b..cfb441fd9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Optional, Sequence +from typing import Any, Generator, List, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -73,12 +73,12 @@ def __init__( self, wsuri: WebSocketURI, *, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=CLIENT, @@ -261,7 +261,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: return accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -274,7 +274,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: Subprotocol, if one was selected. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None subprotocols = headers.get_all("Sec-WebSocket-Protocol") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f7169e3b1..adb66e262 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations import http -from typing import Optional from . import datastructures, frames, http11 from .typing import StatusLike @@ -78,11 +77,11 @@ class ConnectionClosed(WebSocketException): Raised when trying to interact with a closed connection. Attributes: - rcvd (Optional[Close]): if a close frame was received, its code and + rcvd (Close | None): if a close frame was received, its code and reason are available in ``rcvd.code`` and ``rcvd.reason``. - sent (Optional[Close]): if a close frame was sent, its code and reason + sent (Close | None): if a close frame was sent, its code and reason are available in ``sent.code`` and ``sent.reason``. - rcvd_then_sent (Optional[bool]): if close frames were received and + rcvd_then_sent (bool | None): if close frames were received and sent, this attribute tells in which order this happened, from the perspective of this side of the connection. @@ -90,9 +89,9 @@ class ConnectionClosed(WebSocketException): def __init__( self, - rcvd: Optional[frames.Close], - sent: Optional[frames.Close], - rcvd_then_sent: Optional[bool] = None, + rcvd: frames.Close | None, + sent: frames.Close | None, + rcvd_then_sent: bool | None = None, ) -> None: self.rcvd = rcvd self.sent = sent @@ -181,7 +180,7 @@ class InvalidHeader(InvalidHandshake): """ - def __init__(self, name: str, value: Optional[str] = None) -> None: + def __init__(self, name: str, value: str | None = None) -> None: self.name = name self.value = value @@ -221,7 +220,7 @@ class InvalidOrigin(InvalidHeader): """ - def __init__(self, origin: Optional[str]) -> None: + def __init__(self, origin: str | None) -> None: super().__init__("Origin", origin) @@ -301,7 +300,7 @@ class InvalidParameterValue(NegotiationError): """ - def __init__(self, name: str, value: Optional[str]) -> None: + def __init__(self, name: str, value: str | None) -> None: self.name = name self.value = value diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7446c990c..5b5528a09 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Tuple +from typing import List, Sequence, Tuple from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -22,7 +22,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index edccac3ca..e95b1064b 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Optional[Dict[Any, Any]] = None, + compress_settings: Dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -84,7 +84,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. @@ -174,8 +174,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, - server_max_window_bits: Optional[int], - client_max_window_bits: Optional[Union[int, bool]], + server_max_window_bits: int | None, + client_max_window_bits: Union[int, bool] | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]: +) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -207,8 +207,8 @@ def _extract_parameters( """ server_no_context_takeover: bool = False client_no_context_takeover: bool = False - server_max_window_bits: Optional[int] = None - client_max_window_bits: Optional[Union[int, bool]] = None + server_max_window_bits: int | None = None + client_max_window_bits: Union[int, bool] | None = None for name, value in params: if name == "server_no_context_takeover": @@ -286,9 +286,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[Union[int, bool]] = True, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: Union[int, bool] | None = True, + compress_settings: Dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -433,7 +433,7 @@ def process_response_params( def enable_client_permessage_deflate( - extensions: Optional[Sequence[ClientExtensionFactory]], + extensions: Sequence[ClientExtensionFactory] | None, ) -> Sequence[ClientExtensionFactory]: """ Enable Per-Message Deflate with default settings in client extensions. @@ -489,9 +489,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[int] = None, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: int | None = None, + compress_settings: Dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -635,7 +635,7 @@ def process_request_params( def enable_server_permessage_deflate( - extensions: Optional[Sequence[ServerExtensionFactory]], + extensions: Sequence[ServerExtensionFactory] | None, ) -> Sequence[ServerExtensionFactory]: """ Enable Per-Message Deflate with default settings in server extensions. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 862eef3aa..5a304d6a7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Optional, Sequence, Tuple +from typing import Callable, Generator, Sequence, Tuple from . import exceptions, extensions from .typing import Data @@ -205,8 +205,8 @@ def parse( read_exact: Callable[[int], Generator[None, None, bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Generator[None, None, Frame]: """ Parse a WebSocket frame. @@ -280,7 +280,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> bytes: """ Serialize a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 463df3061..3b316e0bf 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast +from typing import Callable, List, Sequence, Tuple, TypeVar, cast from . import exceptions from .typing import ( @@ -63,7 +63,7 @@ def build_host(host: str, port: int, secure: bool) -> str: # https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. -def peek_ahead(header: str, pos: int) -> Optional[str]: +def peek_ahead(header: str, pos: int) -> str | None: """ Return the next character from ``header`` at the given position. @@ -314,7 +314,7 @@ def parse_extension_item_param( name, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) # Extract parameter value, if there is one. - value: Optional[str] = None + value: str | None = None if peek_ahead(header, pos) == "=": pos = parse_OWS(header, pos + 1) if peek_ahead(header, pos) == '"': diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 6fe775eec..a7e9ae682 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -3,7 +3,7 @@ import dataclasses import re import warnings -from typing import Callable, Generator, Optional +from typing import Callable, Generator from . import datastructures, exceptions @@ -62,10 +62,10 @@ class Request: headers: datastructures.Headers # body isn't useful is the context of this library. - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Request.exception is deprecated; " "use ServerProtocol.handshake_exc instead", @@ -164,12 +164,12 @@ class Response: status_code: int reason_phrase: str headers: datastructures.Headers - body: Optional[bytes] = None + body: bytes | None = None - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", @@ -245,7 +245,7 @@ def parse( if 100 <= status_code < 200 or status_code == 204 or status_code == 304: body = None else: - content_length: Optional[int] + content_length: int | None try: # MultipleValuesError is sufficiently unlikely that we don't # attempt to handle it. Instead we document that its parent diff --git a/src/websockets/imports.py b/src/websockets/imports.py index a6a59d4c2..9c05234f5 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable __all__ = ["lazy_import"] @@ -30,8 +30,8 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( namespace: Dict[str, Any], - aliases: Optional[Dict[str, str]] = None, - deprecated_aliases: Optional[Dict[str, str]] = None, + aliases: Dict[str, str] | None = None, + deprecated_aliases: Dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8217afedd..067f9c78c 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -39,14 +39,14 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): encoding of non-ASCII characters is undefined. """ - username: Optional[str] = None + username: str | None = None """Username of the authenticated user.""" def __init__( self, *args: Any, - realm: Optional[str] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + realm: str | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, **kwargs: Any, ) -> None: if realm is not None: @@ -79,7 +79,7 @@ async def process_request( self, path: str, request_headers: Headers, - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Check HTTP Basic Auth and return an HTTP 401 response if needed. @@ -115,10 +115,10 @@ async def process_request( def basic_auth_protocol_factory( - realm: Optional[str] = None, - credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, + realm: str | None = None, + credentials: Union[Credentials, Iterable[Credentials]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, + create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index e5da8b13a..f7464368f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -13,7 +13,6 @@ Callable, Generator, List, - Optional, Sequence, Tuple, Type, @@ -86,12 +85,12 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, + logger: LoggerLike | None = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: @@ -152,7 +151,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ClientExtensionFactory]], + available_extensions: Sequence[ClientExtensionFactory] | None, ) -> List[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -224,8 +223,8 @@ def process_extensions( @staticmethod def process_subprotocol( - headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -234,7 +233,7 @@ def process_subprotocol( Return the selected subprotocol. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -260,10 +259,10 @@ def process_subprotocol( async def handshake( self, wsuri: WebSocketURI, - origin: Optional[Origin] = None, - available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, + origin: Origin | None = None, + available_extensions: Sequence[ClientExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, ) -> None: """ Perform the client side of the opening handshake. @@ -427,26 +426,26 @@ def __init__( self, uri: str, *, - create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketClientProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -456,7 +455,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -469,7 +468,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -516,13 +515,13 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) create_connection = functools.partial( loop.create_unix_connection, factory, path, **kwargs ) else: - host: Optional[str] - port: Optional[int] + host: str | None + port: int | None if kwargs.get("sock") is None: host, port = wsuri.host, wsuri.port else: @@ -630,9 +629,9 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: await self.protocol.close() @@ -679,7 +678,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: def unix_connect( - path: Optional[str] = None, + path: str | None = None, uri: str = "ws://localhost/", **kwargs: Any, ) -> Connect: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index b77b869e3..8a13fa446 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -44,8 +44,8 @@ async def read( reader: Callable[[int], Awaitable[bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Frame: """ Read a WebSocket frame. @@ -122,7 +122,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 26d50a2cc..94d42cfdb 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -22,7 +22,6 @@ Iterable, List, Mapping, - Optional, Tuple, Union, cast, @@ -173,21 +172,21 @@ class WebSocketCommonProtocol(asyncio.Protocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + logger: LoggerLike | None = None, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. - host: Optional[str] = None, - port: Optional[int] = None, - secure: Optional[bool] = None, + host: str | None = None, + port: int | None = None, + secure: bool | None = None, legacy_recv: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, - timeout: Optional[float] = None, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float | None = None, ) -> None: if legacy_recv: # pragma: no cover warnings.warn("legacy_recv is deprecated", DeprecationWarning) @@ -243,7 +242,7 @@ def __init__( # Copied from asyncio.FlowControlMixin self._paused = False - self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_waiter: asyncio.Future[None] | None = None self._drain_lock = asyncio.Lock() @@ -265,13 +264,13 @@ def __init__( # WebSocket protocol parameters. self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` @@ -281,11 +280,11 @@ def __init__( # Queue of received messages. self.messages: Deque[Data] = collections.deque() - self._pop_message_waiter: Optional[asyncio.Future[None]] = None - self._put_message_waiter: Optional[asyncio.Future[None]] = None + self._pop_message_waiter: asyncio.Future[None] | None = None + self._put_message_waiter: asyncio.Future[None] | None = None # Protect sending fragmented messages. - self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None + self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} @@ -306,7 +305,7 @@ def __init__( self.transfer_data_task: asyncio.Task[None] # Exception that occurred during data transfer, if any. - self.transfer_data_exc: Optional[BaseException] = None + self.transfer_data_exc: BaseException | None = None # Task sending keepalive pings. self.keepalive_ping_task: asyncio.Task[None] @@ -363,19 +362,19 @@ def connection_open(self) -> None: self.close_connection_task = self.loop.create_task(self.close_connection()) @property - def host(self) -> Optional[str]: + def host(self) -> str | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) return self._host @property - def port(self) -> Optional[int]: + def port(self) -> int | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) return self._port @property - def secure(self) -> Optional[bool]: + def secure(self) -> bool | None: warnings.warn("don't use secure", DeprecationWarning) return self._secure @@ -447,7 +446,7 @@ def closed(self) -> bool: return self.state is State.CLOSED @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. @@ -465,7 +464,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. @@ -804,7 +803,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -1017,7 +1016,7 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(CloseCode.INTERNAL_ERROR) - async def read_message(self) -> Optional[Data]: + async def read_message(self) -> Data | None: """ Read a single message from the connection. @@ -1090,7 +1089,7 @@ def append(frame: Frame) -> None: return ("" if text else b"").join(fragments) - async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: + async def read_data_frame(self, max_size: int | None) -> Frame | None: """ Read a single data frame from the connection. @@ -1153,7 +1152,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: else: return frame - async def read_frame(self, max_size: Optional[int]) -> Frame: + async def read_frame(self, max_size: int | None) -> Frame: """ Read a single frame from the connection. @@ -1204,9 +1203,7 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame( - self, close: Close, data: Optional[bytes] = None - ) -> None: + async def write_close_frame(self, close: Close, data: bytes | None = None) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1484,7 +1481,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - def connection_lost(self, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: """ 7.1.4. The WebSocket Connection is Closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 0f3c1c150..551115174 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -16,7 +16,6 @@ Generator, Iterable, List, - Optional, Sequence, Set, Tuple, @@ -103,19 +102,19 @@ def __init__( ], ws_server: WebSocketServer, *, - logger: Optional[LoggerLike] = None, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, + logger: LoggerLike | None = None, + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, **kwargs: Any, ) -> None: if logger is None: @@ -293,7 +292,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: return path, headers def write_http_response( - self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None ) -> None: """ Write status line and headers to the HTTP response. @@ -322,7 +321,7 @@ def write_http_response( async def process_request( self, path: str, request_headers: Headers - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Intercept the HTTP request and return an HTTP response if appropriate. @@ -371,8 +370,8 @@ async def process_request( @staticmethod def process_origin( - headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None - ) -> Optional[Origin]: + headers: Headers, origins: Sequence[Origin | None] | None = None + ) -> Origin | None: """ Handle the Origin HTTP request header. @@ -387,9 +386,11 @@ def process_origin( # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -398,8 +399,8 @@ def process_origin( @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ServerExtensionFactory]], - ) -> Tuple[Optional[str], List[Extension]]: + available_extensions: Sequence[ServerExtensionFactory] | None, + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -435,7 +436,7 @@ def process_extensions( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -479,8 +480,8 @@ def process_extensions( # Not @staticmethod because it calls self.select_subprotocol() def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -495,7 +496,7 @@ def process_subprotocol( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -514,7 +515,7 @@ def select_subprotocol( self, client_subprotocols: Sequence[Subprotocol], server_subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those supported by the client and the server. @@ -552,10 +553,10 @@ def select_subprotocol( async def handshake( self, - origins: Optional[Sequence[Optional[Origin]]] = None, - available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, + origins: Sequence[Origin | None] | None = None, + available_extensions: Sequence[ServerExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, ) -> str: """ Perform the server side of the opening handshake. @@ -661,7 +662,7 @@ class WebSocketServer: """ - def __init__(self, logger: Optional[LoggerLike] = None): + def __init__(self, logger: LoggerLike | None = None): if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger @@ -670,7 +671,7 @@ def __init__(self, logger: Optional[LoggerLike] = None): self.websockets: Set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. - self.close_task: Optional[asyncio.Task[None]] = None + self.close_task: asyncio.Task[None] | None = None # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] @@ -869,9 +870,9 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: # pragma: no cover self.close() await self.wait_closed() @@ -941,8 +942,8 @@ class Serve: server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - process_request (Optional[Callable[[str, Headers], \ - Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): + process_request (Callable[[str, Headers], \ + Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -975,35 +976,35 @@ def __init__( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - host: Optional[Union[str, Sequence[str]]] = None, - port: Optional[int] = None, + host: Union[str, Sequence[str]] | None = None, + port: int | None = None, *, - create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketServerProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -1013,7 +1014,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1026,7 +1027,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -1076,7 +1077,7 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) # unix_serve(path) must not specify host and port parameters. assert host is None and port is None create_server = functools.partial( @@ -1098,9 +1099,9 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.ws_server.close() await self.ws_server.wait_closed() @@ -1129,7 +1130,7 @@ def unix_serve( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> Serve: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0b36202e5..8aa222eeb 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Optional, Type, Union +from typing import Generator, List, Type, Union from .exceptions import ( ConnectionClosed, @@ -89,8 +89,8 @@ def __init__( side: Side, *, state: State = OPEN, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ) -> None: # Unique identifier. For logs. self.id: uuid.UUID = uuid.uuid4() @@ -116,24 +116,24 @@ def __init__( # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. - self.cur_size: Optional[int] = None + self.cur_size: int | None = None # True while sending a fragmented message i.e. a data frames with the # FIN bit not set. self.expect_continuation_frame = False # WebSocket protocol parameters. - self.origin: Optional[Origin] = None + self.origin: Origin | None = None self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Track if an exception happened during the handshake. - self.handshake_exc: Optional[Exception] = None + self.handshake_exc: Exception | None = None """ Exception to raise if the opening handshake failed. @@ -150,7 +150,7 @@ def __init__( self.writes: List[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine - self.parser_exc: Optional[Exception] = None + self.parser_exc: Exception | None = None @property def state(self) -> State: @@ -169,7 +169,7 @@ def state(self, state: State) -> None: self._state = state @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ `WebSocket close code`_. @@ -187,7 +187,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ `WebSocket close reason`_. @@ -348,7 +348,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) - def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + def send_close(self, code: int | None = None, reason: str = "") -> None: """ Send a `Close frame`_. diff --git a/src/websockets/server.py b/src/websockets/server.py index 330e54f37..a92541085 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generator, List, Sequence, Tuple, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -77,18 +77,19 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, + | None + ) = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=SERVER, @@ -200,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, Optional[str], Optional[str]]: + ) -> Tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -285,7 +286,7 @@ def process_request( protocol_header, ) - def process_origin(self, headers: Headers) -> Optional[Origin]: + def process_origin(self, headers: Headers) -> Origin | None: """ Handle the Origin HTTP request header. @@ -303,9 +304,11 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if self.origins is not None: if origin not in self.origins: raise InvalidOrigin(origin) @@ -314,7 +317,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: def process_extensions( self, headers: Headers, - ) -> Tuple[Optional[str], List[Extension]]: + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -350,7 +353,7 @@ def process_extensions( InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -392,7 +395,7 @@ def process_extensions( return response_header_value, accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -420,7 +423,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: def select_subprotocol( self, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those offered by the client. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0bb7a76fd..60b49ebc3 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Optional, Sequence, Type +from typing import Any, Sequence, Type from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -52,7 +52,7 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -64,9 +64,9 @@ def __init__( def handshake( self, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -128,25 +128,25 @@ def connect( uri: str, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, - server_hostname: Optional[str] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, # WebSocket - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ClientConnection]] = None, + create_connection: Type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ @@ -219,7 +219,7 @@ def connect( # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if unix: if path is None and sock is None: @@ -307,8 +307,8 @@ def connect( def unix_connect( - path: Optional[str] = None, - uri: Optional[str] = None, + path: str | None = None, + uri: str | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b41202dc9..bb9743181 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -42,7 +42,7 @@ def __init__( socket: socket.socket, protocol: Protocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.socket = socket self.protocol = protocol @@ -62,9 +62,9 @@ def __init__( self.debug = self.protocol.debug # HTTP handshake request and response. - self.request: Optional[Request] = None + self.request: Request | None = None """Opening handshake request.""" - self.response: Optional[Response] = None + self.response: Response | None = None """Opening handshake response.""" # Mutex serializing interactions with the protocol. @@ -77,7 +77,7 @@ def __init__( self.send_in_progress = False # Deadline for the closing handshake. - self.close_deadline: Optional[Deadline] = None + self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} @@ -93,7 +93,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: Optional[BaseException] = None + self.recv_exc: BaseException | None = None # Public attributes @@ -124,7 +124,7 @@ def remote_address(self) -> Any: return self.socket.getpeername() @property - def subprotocol(self) -> Optional[Subprotocol]: + def subprotocol(self) -> Subprotocol | None: """ Subprotocol negotiated during the opening handshake. @@ -140,9 +140,9 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: if exc_type is None: self.close() @@ -166,7 +166,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: Optional[float] = None) -> Data: + def recv(self, timeout: float | None = None) -> Data: """ Receive the next message. @@ -420,7 +420,7 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Optional[Data] = None) -> threading.Event: + def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. @@ -647,7 +647,7 @@ def send_context( # Should we close the socket and raise ConnectionClosed? raise_close_exc = False # What exception should we chain ConnectionClosed to? - original_exc: Optional[BaseException] = None + original_exc: BaseException | None = None # Acquire the protocol lock. with self.protocol_mutex: @@ -748,7 +748,7 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dcba183d9..2c604ba09 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, Optional, cast +from typing import Iterator, List, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -41,7 +41,7 @@ def __init__(self) -> None: self.put_in_progress = False # Decoder for text frames, None for binary frames. - self.decoder: Optional[codecs.IncrementalDecoder] = None + self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. self.chunks: List[Data] = [] @@ -54,12 +54,12 @@ def __init__(self) -> None: # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None + self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: Optional[float] = None) -> Data: + def get(self, timeout: float | None = None) -> Data: """ Read the next message. @@ -151,7 +151,7 @@ def get_iter(self) -> Iterator[Data]: self.chunks = [] self.chunks_queue = cast( # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Optional[Data]]", + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) @@ -164,7 +164,7 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - chunk: Optional[Data] + chunk: Data | None for chunk in chunks: yield chunk while (chunk := self.chunks_queue.get()) is not None: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fd4f5d3bd..b801510b4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Optional, Sequence, Type +from typing import Any, Callable, Sequence, Type from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -57,7 +57,7 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -69,20 +69,22 @@ def __init__( def handshake( self, - process_request: Optional[ + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + | None + ) = None, + server_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -197,7 +199,7 @@ def __init__( self, socket: socket.socket, handler: Callable[[socket.socket, Any], None], - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, ): self.socket = socket self.handler = handler @@ -260,54 +262,57 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.shutdown() def serve( handler: Callable[[ServerConnection], None], - host: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + port: int | None = None, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerConnection, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, - process_request: Optional[ + | None + ) = None, + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ServerConnection]] = None, + create_connection: Type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ @@ -412,7 +417,7 @@ def handler(websocket): # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if sock is None: if unix: @@ -460,18 +465,19 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.settimeout(None) # Create a closure to give select_subprotocol access to connection. - protocol_select_subprotocol: Optional[ + protocol_select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None + | None + ) = None if select_subprotocol is not None: def protocol_select_subprotocol( protocol: ServerProtocol, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: # mypy doesn't know that select_subprotocol is immutable. assert select_subprotocol is not None # Ensure this function is only used in the intended context. @@ -525,7 +531,7 @@ def protocol_select_subprotocol( def unix_serve( handler: Callable[[ServerConnection], None], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 3364bdc2d..00bce2cc6 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from typing import Optional __all__ = ["Deadline"] @@ -16,14 +15,14 @@ class Deadline: """ - def __init__(self, timeout: Optional[float]) -> None: - self.deadline: Optional[float] + def __init__(self, timeout: float | None) -> None: + self.deadline: float | None if timeout is None: self.deadline = None else: self.deadline = time.monotonic() + timeout - def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: + def timeout(self, *, raise_if_elapsed: bool = True) -> float | None: """ Calculate a timeout from a deadline. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 5dfecf66f..7c5b3664d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -53,7 +53,7 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" - +# Change to str | None when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 8cf581743..902716066 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,7 @@ import dataclasses import urllib.parse -from typing import Optional, Tuple +from typing import Tuple from . import exceptions @@ -33,8 +33,8 @@ class WebSocketURI: port: int path: str query: str - username: Optional[str] = None - password: Optional[str] = None + username: str | None = None + password: str | None = None @property def resource_name(self) -> str: @@ -47,7 +47,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Optional[Tuple[str, str]]: + def user_info(self) -> Tuple[str, str] | None: if self.username is None: return None assert self.password is not None From 63a2d8eff62fe487c42a8f3176528730b7eed727 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:16:17 +0200 Subject: [PATCH 045/109] Switch from typing.Union to |. --- src/websockets/datastructures.py | 1 + .../extensions/permessage_deflate.py | 10 ++--- src/websockets/legacy/auth.py | 4 +- src/websockets/legacy/protocol.py | 3 +- src/websockets/legacy/server.py | 37 ++++++++++--------- src/websockets/protocol.py | 1 + src/websockets/sync/connection.py | 4 +- src/websockets/typing.py | 3 ++ 8 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index aef11bf23..5605772d8 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -177,6 +177,7 @@ def keys(self) -> Iterable[str]: ... def __getitem__(self, key: str) -> str: ... +# Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10. HeadersLike = Union[ Headers, Mapping[str, str], diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index e95b1064b..48a6a0833 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -175,7 +175,7 @@ def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, server_max_window_bits: int | None, - client_max_window_bits: Union[int, bool] | None, + client_max_window_bits: int | bool | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: +) -> Tuple[bool, bool, int | None, int | bool | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -208,7 +208,7 @@ def _extract_parameters( server_no_context_takeover: bool = False client_no_context_takeover: bool = False server_max_window_bits: int | None = None - client_max_window_bits: Union[int, bool] | None = None + client_max_window_bits: int | bool | None = None for name, value in params: if name == "server_no_context_takeover": @@ -287,7 +287,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, - client_max_window_bits: Union[int, bool] | None = True, + client_max_window_bits: int | bool | None = True, compress_settings: Dict[str, Any] | None = None, ) -> None: """ diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 067f9c78c..9d685d9f4 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -116,7 +116,7 @@ async def process_request( def basic_auth_protocol_factory( realm: str | None = None, - credentials: Union[Credentials, Iterable[Credentials]] | None = None, + credentials: Credentials | Iterable[Credentials] | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 94d42cfdb..f4c5901dc 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,7 +23,6 @@ List, Mapping, Tuple, - Union, cast, ) @@ -578,7 +577,7 @@ async def recv(self) -> Data: async def send( self, - message: Union[Data, Iterable[Data], AsyncIterable[Data]], + message: Data | Iterable[Data] | AsyncIterable[Data], ) -> None: """ Send a message. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 551115174..13a6f5591 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -54,6 +54,7 @@ __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +# Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] @@ -96,10 +97,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), ws_server: WebSocketServer, *, logger: LoggerLike | None = None, @@ -934,7 +935,7 @@ class Serve: should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): + extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]): Arbitrary HTTP headers to add to the response. This can be a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning @@ -972,11 +973,11 @@ class Serve: def __init__( self, - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], - host: Union[str, Sequence[str]] | None = None, + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), + host: str | Sequence[str] | None = None, port: int | None = None, *, create_protocol: Callable[..., WebSocketServerProtocol] | None = None, @@ -1126,10 +1127,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), path: str | None = None, **kwargs: Any, ) -> Serve: @@ -1152,10 +1153,10 @@ def unix_serve( def remove_path_argument( - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - ] + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ) ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: try: inspect.signature(ws_handler).bind(None) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8aa222eeb..f288a2733 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -38,6 +38,7 @@ "SEND_EOF", ] +# Change to Request | Response | Frame when dropping Python < 3.10. Event = Union[Request, Response, Frame] """Events that :meth:`~Protocol.events_received` may return.""" diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index bb9743181..7a750331d 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -242,7 +242,7 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Union[Data, Iterable[Data]]) -> None: + def send(self, message: Data | Iterable[Data]) -> None: """ Send a message. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 7c5b3664d..73d4a4754 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -19,6 +19,7 @@ # Public types used in the signature of public APIs +# Change to str | bytes when dropping Python < 3.10. Data = Union[str, bytes] """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. @@ -29,6 +30,7 @@ """ +# Change to logging.Logger | ... when dropping Python < 3.10. if typing.TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" @@ -37,6 +39,7 @@ """Types accepted where a :class:`~logging.Logger` is expected.""" +# Change to http.HTTPStatus | int when dropping Python < 3.10. StatusLike = Union[http.HTTPStatus, int] """ Types accepted where an :class:`~http.HTTPStatus` is expected.""" From cd059d5633eed129775572054a7458d8e3f07166 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:28:36 +0200 Subject: [PATCH 046/109] Switch from typing.Dict/List/Tuple/Type/Set to native types. --- src/websockets/client.py | 12 +++---- src/websockets/datastructures.py | 11 +++--- src/websockets/extensions/base.py | 6 ++-- .../extensions/permessage_deflate.py | 18 +++++----- src/websockets/frames.py | 4 +-- src/websockets/headers.py | 34 +++++++++---------- src/websockets/imports.py | 10 +++--- src/websockets/legacy/auth.py | 1 + src/websockets/legacy/client.py | 15 ++++---- src/websockets/legacy/framing.py | 4 +-- src/websockets/legacy/handshake.py | 9 +++-- src/websockets/legacy/http.py | 5 ++- src/websockets/legacy/protocol.py | 9 ++--- src/websockets/legacy/server.py | 28 +++++++-------- src/websockets/protocol.py | 14 ++++---- src/websockets/server.py | 16 ++++----- src/websockets/sync/client.py | 4 +-- src/websockets/sync/connection.py | 6 ++-- src/websockets/sync/messages.py | 4 +-- src/websockets/sync/server.py | 6 ++-- src/websockets/typing.py | 4 ++- src/websockets/uri.py | 3 +- 22 files changed, 107 insertions(+), 116 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cfb441fd9..8f78ac320 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Sequence +from typing import Any, Generator, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -153,7 +153,7 @@ def process_response(self, response: Response) -> None: headers = response.headers - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) @@ -162,7 +162,7 @@ def process_response(self, response: Response) -> None: "Connection", ", ".join(connection) if connection else None ) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -188,7 +188,7 @@ def process_response(self, response: Response) -> None: self.subprotocol = self.process_subprotocol(headers) - def process_extensions(self, headers: Headers) -> List[Extension]: + def process_extensions(self, headers: Headers) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -219,7 +219,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: InvalidHandshake: To abort the handshake. """ - accepted_extensions: List[Extension] = [] + accepted_extensions: list[Extension] = [] extensions = headers.get_all("Sec-WebSocket-Extensions") @@ -227,7 +227,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: if self.available_extensions is None: raise InvalidHandshake("no extensions supported") - parsed_extensions: List[ExtensionHeader] = sum( + parsed_extensions: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in extensions], [] ) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 5605772d8..3d64d951e 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -2,10 +2,8 @@ from typing import ( Any, - Dict, Iterable, Iterator, - List, Mapping, MutableMapping, Protocol, @@ -72,8 +70,8 @@ class Headers(MutableMapping[str, str]): # Like dict, Headers accepts an optional "mapping or iterable" argument. def __init__(self, *args: HeadersLike, **kwargs: str) -> None: - self._dict: Dict[str, List[str]] = {} - self._list: List[Tuple[str, str]] = [] + self._dict: dict[str, list[str]] = {} + self._list: list[tuple[str, str]] = [] self.update(*args, **kwargs) def __str__(self) -> str: @@ -147,7 +145,7 @@ def update(self, *args: HeadersLike, **kwargs: str) -> None: # Methods for handling multiple values - def get_all(self, key: str) -> List[str]: + def get_all(self, key: str) -> list[str]: """ Return the (possibly empty) list of all values for a header. @@ -157,7 +155,7 @@ def get_all(self, key: str) -> List[str]: """ return self._dict.get(key.lower(), []) - def raw_items(self) -> Iterator[Tuple[str, str]]: + def raw_items(self) -> Iterator[tuple[str, str]]: """ Return an iterator of all values as ``(name, value)`` pairs. @@ -181,6 +179,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], + # Change to tuple[str, str] when dropping Python < 3.9. Iterable[Tuple[str, str]], SupportsKeysAndGetItem, ] diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 5b5528a09..a6c76c3d4 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Sequence, Tuple +from typing import Sequence from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -63,7 +63,7 @@ class ClientExtensionFactory: name: ExtensionName """Extension identifier.""" - def get_request_params(self) -> List[ExtensionParameter]: + def get_request_params(self) -> list[ExtensionParameter]: """ Build parameters to send to the server for this extension. @@ -108,7 +108,7 @@ def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], - ) -> Tuple[List[ExtensionParameter], Extension]: + ) -> tuple[list[ExtensionParameter], Extension]: """ Process parameters received from the client. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 48a6a0833..579262f02 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Sequence from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Dict[Any, Any] | None = None, + compress_settings: dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -176,12 +176,12 @@ def _build_parameters( client_no_context_takeover: bool, server_max_window_bits: int | None, client_max_window_bits: int | bool | None, -) -> List[ExtensionParameter]: +) -> list[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. """ - params: List[ExtensionParameter] = [] + params: list[ExtensionParameter] = [] if server_no_context_takeover: params.append(("server_no_context_takeover", None)) if client_no_context_takeover: @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, int | None, int | bool | None]: +) -> tuple[bool, bool, int | None, int | bool | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -288,7 +288,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | bool | None = True, - compress_settings: Dict[str, Any] | None = None, + compress_settings: dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -314,7 +314,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self) -> List[ExtensionParameter]: + def get_request_params(self) -> list[ExtensionParameter]: """ Build request parameters. @@ -491,7 +491,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | None = None, - compress_settings: Dict[str, Any] | None = None, + compress_settings: dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -524,7 +524,7 @@ def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], - ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]: + ) -> tuple[list[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5a304d6a7..0da676432 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Sequence, Tuple +from typing import Callable, Generator, Sequence from . import exceptions, extensions from .typing import Data @@ -353,7 +353,7 @@ def check(self) -> None: raise exceptions.ProtocolError("fragmented control frame") -def prepare_data(data: Data) -> Tuple[int, bytes]: +def prepare_data(data: Data) -> tuple[int, bytes]: """ Convert a string or byte-like object to an opcode and a bytes-like object. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 3b316e0bf..bc42e0b72 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Sequence, Tuple, TypeVar, cast +from typing import Callable, Sequence, TypeVar, cast from . import exceptions from .typing import ( @@ -96,7 +96,7 @@ def parse_OWS(header: str, pos: int) -> int: _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") -def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token from ``header`` at the given position. @@ -120,7 +120,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") -def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a quoted string from ``header`` at the given position. @@ -158,11 +158,11 @@ def build_quoted_string(value: str) -> str: def parse_list( - parse_item: Callable[[str, int, str], Tuple[T, int]], + parse_item: Callable[[str, int, str], tuple[T, int]], header: str, pos: int, header_name: str, -) -> List[T]: +) -> list[T]: """ Parse a comma-separated list from ``header`` at the given position. @@ -227,7 +227,7 @@ def parse_list( def parse_connection_option( header: str, pos: int, header_name: str -) -> Tuple[ConnectionOption, int]: +) -> tuple[ConnectionOption, int]: """ Parse a Connection option from ``header`` at the given position. @@ -241,7 +241,7 @@ def parse_connection_option( return cast(ConnectionOption, item), pos -def parse_connection(header: str) -> List[ConnectionOption]: +def parse_connection(header: str) -> list[ConnectionOption]: """ Parse a ``Connection`` header. @@ -264,7 +264,7 @@ def parse_connection(header: str) -> List[ConnectionOption]: def parse_upgrade_protocol( header: str, pos: int, header_name: str -) -> Tuple[UpgradeProtocol, int]: +) -> tuple[UpgradeProtocol, int]: """ Parse an Upgrade protocol from ``header`` at the given position. @@ -282,7 +282,7 @@ def parse_upgrade_protocol( return cast(UpgradeProtocol, match.group()), match.end() -def parse_upgrade(header: str) -> List[UpgradeProtocol]: +def parse_upgrade(header: str) -> list[UpgradeProtocol]: """ Parse an ``Upgrade`` header. @@ -300,7 +300,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: def parse_extension_item_param( header: str, pos: int, header_name: str -) -> Tuple[ExtensionParameter, int]: +) -> tuple[ExtensionParameter, int]: """ Parse a single extension parameter from ``header`` at the given position. @@ -336,7 +336,7 @@ def parse_extension_item_param( def parse_extension_item( header: str, pos: int, header_name: str -) -> Tuple[ExtensionHeader, int]: +) -> tuple[ExtensionHeader, int]: """ Parse an extension definition from ``header`` at the given position. @@ -359,7 +359,7 @@ def parse_extension_item( return (cast(ExtensionName, name), parameters), pos -def parse_extension(header: str) -> List[ExtensionHeader]: +def parse_extension(header: str) -> list[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. @@ -389,7 +389,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: def build_extension_item( - name: ExtensionName, parameters: List[ExtensionParameter] + name: ExtensionName, parameters: list[ExtensionParameter] ) -> str: """ Build an extension definition. @@ -424,7 +424,7 @@ def build_extension(extensions: Sequence[ExtensionHeader]) -> str: def parse_subprotocol_item( header: str, pos: int, header_name: str -) -> Tuple[Subprotocol, int]: +) -> tuple[Subprotocol, int]: """ Parse a subprotocol from ``header`` at the given position. @@ -438,7 +438,7 @@ def parse_subprotocol_item( return cast(Subprotocol, item), pos -def parse_subprotocol(header: str) -> List[Subprotocol]: +def parse_subprotocol(header: str) -> list[Subprotocol]: """ Parse a ``Sec-WebSocket-Protocol`` header. @@ -498,7 +498,7 @@ def build_www_authenticate_basic(realm: str) -> str: _token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") -def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token68 from ``header`` at the given position. @@ -525,7 +525,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) -def parse_authorization_basic(header: str) -> Tuple[str, str]: +def parse_authorization_basic(header: str) -> tuple[str, str]: """ Parse an ``Authorization`` header for HTTP Basic Auth. diff --git a/src/websockets/imports.py b/src/websockets/imports.py index 9c05234f5..bb80e4eac 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,13 +1,13 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable +from typing import Any, Iterable __all__ = ["lazy_import"] -def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: +def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any: """ Import ``name`` from ``source`` in ``namespace``. @@ -29,9 +29,9 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( - namespace: Dict[str, Any], - aliases: Dict[str, str] | None = None, - deprecated_aliases: Dict[str, str] | None = None, + namespace: dict[str, Any], + aliases: dict[str, str] | None = None, + deprecated_aliases: dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 9d685d9f4..c2d30e4b4 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -13,6 +13,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] +# Change to tuple[str, str] when dropping Python < 3.9. Credentials = Tuple[str, str] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f7464368f..d9d69fdaa 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -12,10 +12,7 @@ AsyncIterator, Callable, Generator, - List, Sequence, - Tuple, - Type, cast, ) @@ -122,7 +119,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.transport.write(request.encode()) - async def read_http_response(self) -> Tuple[int, Headers]: + async def read_http_response(self) -> tuple[int, Headers]: """ Read status line and headers from the HTTP response. @@ -152,7 +149,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: def process_extensions( headers: Headers, available_extensions: Sequence[ClientExtensionFactory] | None, - ) -> List[Extension]: + ) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -179,7 +176,7 @@ def process_extensions( order of extensions, may be implemented by overriding this method. """ - accepted_extensions: List[Extension] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") @@ -187,7 +184,7 @@ def process_extensions( if available_extensions is None: raise InvalidHandshake("no extensions supported") - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) @@ -455,7 +452,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) + klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -629,7 +626,7 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 8a13fa446..1aaca5cc6 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -152,7 +152,7 @@ def write( ) -def parse_close(data: bytes) -> Tuple[int, str]: +def parse_close(data: bytes) -> tuple[int, str]: """ Parse the payload from a close frame. diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 5853c31db..2a39c1b03 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -2,7 +2,6 @@ import base64 import binascii -from typing import List from ..datastructures import Headers, MultipleValuesError from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade @@ -55,14 +54,14 @@ def check_request(headers: Headers) -> str: Then, the server must return a 400 Bad Request error. """ - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -135,14 +134,14 @@ def check_response(headers: Headers, key: str) -> None: InvalidHandshake: If the handshake response is invalid. """ - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 2ac7f7092..9a553e175 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -2,7 +2,6 @@ import asyncio import re -from typing import Tuple from ..datastructures import Headers from ..exceptions import SecurityError @@ -42,7 +41,7 @@ def d(value: bytes) -> str: _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: +async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: """ Read an HTTP/1.1 GET request and return ``(path, headers)``. @@ -91,7 +90,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: return path, headers -async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: +async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]: """ Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f4c5901dc..67161019f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -18,11 +18,8 @@ Awaitable, Callable, Deque, - Dict, Iterable, - List, Mapping, - Tuple, cast, ) @@ -262,7 +259,7 @@ def __init__( """Opening handshake response headers.""" # WebSocket protocol parameters. - self.extensions: List[Extension] = [] + self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" @@ -286,7 +283,7 @@ def __init__( self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} + self.pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ @@ -1042,7 +1039,7 @@ async def read_message(self) -> Data | None: return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation - fragments: List[Data] = [] + fragments: list[Data] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 13a6f5591..c0ea6a764 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -15,11 +15,8 @@ Callable, Generator, Iterable, - List, Sequence, - Set, Tuple, - Type, Union, cast, ) @@ -57,6 +54,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] +# Change to tuple[...] when dropping Python < 3.9. HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] @@ -263,7 +261,7 @@ async def handler(self) -> None: self.ws_server.unregister(self) self.logger.info("connection closed") - async def read_http_request(self) -> Tuple[str, Headers]: + async def read_http_request(self) -> tuple[str, Headers]: """ Read request line and headers from the HTTP request. @@ -349,7 +347,7 @@ async def process_request( request_headers: Request headers. Returns: - Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, @@ -401,7 +399,7 @@ def process_origin( def process_extensions( headers: Headers, available_extensions: Sequence[ServerExtensionFactory] | None, - ) -> Tuple[str | None, List[Extension]]: + ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -439,13 +437,13 @@ def process_extensions( """ response_header_value: str | None = None - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) @@ -502,7 +500,7 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values: List[Subprotocol] = sum( + parsed_header_values: list[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) @@ -669,7 +667,7 @@ def __init__(self, logger: LoggerLike | None = None): self.logger = logger # Keep track of active connections. - self.websockets: Set[WebSocketServerProtocol] = set() + self.websockets: set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. self.close_task: asyncio.Task[None] | None = None @@ -871,7 +869,7 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: # pragma: no cover @@ -944,7 +942,7 @@ class Serve: It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. process_request (Callable[[str, Headers], \ - Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): + Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -1015,7 +1013,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) + klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1100,7 +1098,7 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f288a2733..2f5542f6e 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Type, Union +from typing import Generator, Union from .exceptions import ( ConnectionClosed, @@ -125,7 +125,7 @@ def __init__( # WebSocket protocol parameters. self.origin: Origin | None = None - self.extensions: List[Extension] = [] + self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. @@ -147,8 +147,8 @@ def __init__( # Parser state. self.reader = StreamReader() - self.events: List[Event] = [] - self.writes: List[bytes] = [] + self.events: list[Event] = [] + self.writes: list[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine self.parser_exc: Exception | None = None @@ -222,7 +222,7 @@ def close_exc(self) -> ConnectionClosed: """ assert self.state is CLOSED, "connection isn't closed yet" - exc_type: Type[ConnectionClosed] + exc_type: type[ConnectionClosed] if ( self.close_rcvd is not None and self.close_sent is not None @@ -458,7 +458,7 @@ def fail(self, code: int, reason: str = "") -> None: # Public method for getting incoming events after receiving data. - def events_received(self) -> List[Event]: + def events_received(self) -> list[Event]: """ Fetch events generated from data received from the network. @@ -474,7 +474,7 @@ def events_received(self) -> List[Event]: # Public method for getting outgoing data after receiving data or sending events. - def data_to_send(self) -> List[bytes]: + def data_to_send(self) -> list[bytes]: """ Obtain data to send to the network. diff --git a/src/websockets/server.py b/src/websockets/server.py index a92541085..f976ebad7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Sequence, Tuple, cast +from typing import Any, Callable, Generator, Sequence, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -201,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, str | None, str | None]: + ) -> tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -224,7 +224,7 @@ def process_request( """ headers = request.headers - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) @@ -233,7 +233,7 @@ def process_request( "Connection", ", ".join(connection) if connection else None ) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -317,7 +317,7 @@ def process_origin(self, headers: Headers) -> Origin | None: def process_extensions( self, headers: Headers, - ) -> Tuple[str | None, List[Extension]]: + ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -355,13 +355,13 @@ def process_extensions( """ response_header_value: str | None = None - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and self.available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 60b49ebc3..c97a09402 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence, Type +from typing import Any, Sequence from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -146,7 +146,7 @@ def connect( # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Type[ClientConnection] | None = None, + create_connection: type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 7a750331d..33d8299e2 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Type +from typing import Any, Iterable, Iterator, Mapping from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -80,7 +80,7 @@ def __init__( self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.ping_waiters: Dict[bytes, threading.Event] = {} + self.ping_waiters: dict[bytes, threading.Event] = {} # Receiving events from the socket. This thread explicitly is marked as # to support creating a connection in a non-daemon thread then using it @@ -140,7 +140,7 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 2c604ba09..a6e78e7fd 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, cast +from typing import Iterator, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -44,7 +44,7 @@ def __init__(self) -> None: self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. - self.chunks: List[Data] = [] + self.chunks: list[Data] = [] # When switching from "buffering" to "streaming", we use a thread-safe # queue for transferring frames from the writing thread (library code) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index b801510b4..4f088b63a 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Sequence, Type +from typing import Any, Callable, Sequence from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -262,7 +262,7 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: @@ -312,7 +312,7 @@ def serve( # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Type[ServerConnection] | None = None, + create_connection: type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 73d4a4754..6360c7a0a 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -56,13 +56,15 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to str | None when dropping Python < 3.10. +# Change to tuple[str, Optional[str]] when dropping Python < 3.9. +# Change to tuple[str, str | None] when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types +# Change to tuple[.., list[...]] when dropping Python < 3.9. ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 902716066..5cb38a9cc 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,6 @@ import dataclasses import urllib.parse -from typing import Tuple from . import exceptions @@ -47,7 +46,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Tuple[str, str] | None: + def user_info(self) -> tuple[str, str] | None: if self.username is None: return None assert self.password is not None From f45286b3b2d54f8b79087b060858042b2488688b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:37:09 +0200 Subject: [PATCH 047/109] Pick changes suggested by `pyupgrade --py38-plus`. Other changes were ignored, on purpose. --- src/websockets/sync/messages.py | 3 +-- tests/legacy/test_client_server.py | 2 +- tests/legacy/utils.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index a6e78e7fd..6cbff2595 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -53,8 +53,7 @@ def __init__(self) -> None: # value marking the end of the message, superseding message_complete. # Stream data from frames belonging to the same message. - # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None + self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 51a74734b..c38086572 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -65,7 +65,7 @@ async def default_handler(ws): await ws.wait_closed() await asyncio.sleep(2 * MS) else: - await ws.send((await ws.recv())) + await ws.send(await ws.recv()) async def redirect_request(path, headers, test, status): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 4a21dcaeb..28bc90df3 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -79,6 +79,6 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): for recorded in recorded_warnings: self.assertEqual(type(recorded.message), DeprecationWarning) self.assertEqual( - set(str(recorded.message) for recorded in recorded_warnings), + {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), ) From 650d08caf1c5f84c77a8bf8780a1b407a1432357 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:19:32 +0200 Subject: [PATCH 048/109] Upgrade to mypy 1.11. --- src/websockets/legacy/auth.py | 5 +++++ src/websockets/legacy/client.py | 6 ++++-- src/websockets/legacy/server.py | 3 +++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index c2d30e4b4..8526bad6b 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -178,6 +178,11 @@ async def check_credentials(username: str, password: str) -> bool: if create_protocol is None: create_protocol = BasicAuthWebSocketServerProtocol + # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] | + # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc] + create_protocol = cast( + Callable[..., BasicAuthWebSocketServerProtocol], create_protocol + ) return functools.partial( create_protocol, realm=realm, diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d9d69fdaa..b15eddf75 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -489,6 +489,9 @@ def __init__( if subprotocols is not None: validate_subprotocols(subprotocols) + # Help mypy and avoid this error: "type[WebSocketClientProtocol] | + # Callable[..., WebSocketClientProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol) factory = functools.partial( create_protocol, logger=logger, @@ -641,8 +644,7 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: async def __await_impl__(self) -> WebSocketClientProtocol: async with asyncio_timeout(self.open_timeout): for _redirects in range(self.MAX_REDIRECTS_ALLOWED): - _transport, _protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, _protocol) + _transport, protocol = await self._create_connection() try: await protocol.handshake( self._wsuri, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c0ea6a764..08c82df25 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1045,6 +1045,9 @@ def __init__( if subprotocols is not None: validate_subprotocols(subprotocols) + # Help mypy and avoid this error: "type[WebSocketServerProtocol] | + # Callable[..., WebSocketServerProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol) factory = functools.partial( create_protocol, # For backwards compatibility with 10.0 or earlier. Done here in From e05f6dc83434dae7d91fc0db822ab15aa1e4c00b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:05:24 +0200 Subject: [PATCH 049/109] Support ws:// to wss:// redirects. Fix #1454. --- src/websockets/legacy/client.py | 14 ++++++++++---- tests/legacy/test_client_server.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b15eddf75..b0e15b543 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -558,21 +558,27 @@ def handle_redirect(self, uri: str) -> None: raise SecurityError("redirect from WSS to WS") same_origin = ( - old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + old_wsuri.secure == new_wsuri.secure + and old_wsuri.host == new_wsuri.host + and old_wsuri.port == new_wsuri.port ) - # Rewrite the host and port arguments for cross-origin redirects. + # Rewrite secure, host, and port for cross-origin redirects. # This preserves connection overrides with the host and port # arguments if the redirect points to the same host and port. if not same_origin: - # Replace the host and port argument passed to the protocol factory. factory = self._create_connection.args[0] + # Support TLS upgrade. + if not old_wsuri.secure and new_wsuri.secure: + factory.keywords["secure"] = True + self._create_connection.keywords.setdefault("ssl", True) + # Replace secure, host, and port arguments of the protocol factory. factory = functools.partial( factory.func, *factory.args, **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), ) - # Replace the host and port argument passed to create_connection. + # Replace secure, host, and port arguments of create_connection. self._create_connection = functools.partial( self._create_connection.func, *(factory, new_wsuri.host, new_wsuri.port), diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c38086572..09b3b361a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -75,6 +75,8 @@ async def redirect_request(path, headers, test, status): location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") + elif path == "/force_secure": + location = get_server_uri(test.server, True, "/") elif path == "/force_insecure": location = get_server_uri(test.server, False, "/") elif path == "/missing_location": @@ -1290,7 +1292,16 @@ def test_connection_error_during_closing_handshake(self, close): class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): - pass + + def test_redirect_secure(self): + with temp_test_redirecting_server(self): + # websockets doesn't support serving non-TLS and TLS connections + # from the same server and this test suite makes it difficult to + # run two servers. Therefore, we expect the redirect to create a + # TLS client connection to a non-TLS server, which will fail. + with self.assertRaises(ssl.SSLError): + with self.temp_client("/force_secure"): + self.fail("did not raise") class SecureClientServerTests( From 61b69db60cceff6c46ff308d2c10f7f81480788c Mon Sep 17 00:00:00 2001 From: Antonio Curado Date: Mon, 20 May 2024 18:29:55 +0200 Subject: [PATCH 050/109] Correct handle exceptions in `legacy/broadcast` --- src/websockets/legacy/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 67161019f..3d09440e1 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1630,5 +1630,5 @@ def broadcast( exc_info=True, ) - if raise_exceptions: + if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) From 96d3adf6617fd53fc7a7adcc5a560eeeb8493473 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:30:01 +0200 Subject: [PATCH 051/109] Add tests for previous commit. --- tests/legacy/test_protocol.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f3dcd9ac7..05d2f3795 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1472,10 +1472,24 @@ def test_broadcast_text(self): broadcast([self.protocol], "café") self.assertOneFrameSent(True, OP_TEXT, "café".encode()) + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_text_reports_no_errors(self): + broadcast([self.protocol], "café", raise_exceptions=True) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) + def test_broadcast_binary(self): broadcast([self.protocol], b"tea") self.assertOneFrameSent(True, OP_BINARY, b"tea") + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_binary_reports_no_errors(self): + broadcast([self.protocol], b"tea", raise_exceptions=True) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + def test_broadcast_type_error(self): with self.assertRaises(TypeError): broadcast([self.protocol], ["ca", "fé"]) From ee997c157d3214f758c2422fc44c2a582153f58a Mon Sep 17 00:00:00 2001 From: xuanzhi33 <37460139+xuanzhi33@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:36:11 +0800 Subject: [PATCH 052/109] docs: Correct the example for "Starting a server" in the API reference --- src/websockets/legacy/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 08c82df25..fb91265d8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -899,7 +899,8 @@ class Serve: server = await serve(...) await stop - await server.close() + server.close() + await server.wait_closed() :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: From 1210ee81e470bd8df7700d459ba101263fb7413c Mon Sep 17 00:00:00 2001 From: xuanzhi33 <37460139+xuanzhi33@users.noreply.github.com> Date: Wed, 28 Feb 2024 18:30:01 +0800 Subject: [PATCH 053/109] Update server.rst --- docs/faq/server.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 08b412d30..cba1cd35f 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -300,7 +300,8 @@ Here's how to adapt the example just above:: server = await websockets.serve(echo, "localhost", 8765) await stop - await server.close(close_connections=False) + server.close(close_connections=False) + await server.wait_closed() How do I implement a health check? ---------------------------------- From 41c42b8681dc1245a65e2db8491a573bba1827dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 13:39:14 +0200 Subject: [PATCH 054/109] Make it easier to debug version numbers. --- src/websockets/version.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index f1de3cbf4..145c7a9ed 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -55,7 +55,8 @@ def get_version(tag: str) -> str: else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" match = re.fullmatch(description_re, description) - assert match is not None + if match is None: + raise ValueError(f"Unexpected git description: {description}") distance, remainder = match.groups() remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" @@ -75,7 +76,8 @@ def get_commit(tag: str, version: str) -> str: # Extract commit from version, falling back to tag if not available. version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" match = re.fullmatch(version_re, version) - assert match is not None + if match is None: + raise ValueError(f"Unexpected version: {version}") (commit,) = match.groups() return tag if commit == "unknown" else commit From eaa64c07676a9c28d17c9538242fab4638754584 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 14:41:49 +0200 Subject: [PATCH 055/109] Avoid reading the wrong version. This was causing builds to fail on Read the Docs since sphinx-autobuild added websockets as a dependency. --- src/websockets/version.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 145c7a9ed..46ae34a47 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -34,8 +34,20 @@ def get_version(tag: str) -> str: file_path = pathlib.Path(__file__) root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] - # Read version from git if available. This prevents reading stale - # information from src/websockets.egg-info after building a sdist. + # Read version from package metadata if it is installed. + try: + version = importlib.metadata.version("websockets") + except ImportError: + pass + else: + # Check that this file belongs to the installed package. + files = importlib.metadata.files("websockets") + if files: + version_file = [f for f in files if f.name == file_path.name][0] + if version_file.locate() == file_path: + return version + + # Read version from git if available. try: description = subprocess.run( ["git", "describe", "--dirty", "--tags", "--long"], @@ -61,12 +73,6 @@ def get_version(tag: str) -> str: remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" - # Read version from package metadata if it is installed. - try: - return importlib.metadata.version("websockets") - except ImportError: - pass - # Avoid crashing if the development version cannot be determined. return f"{tag}.dev0+gunknown" From e10eebaec368b04f28102df513e26e933ed5a6fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 15:24:25 +0200 Subject: [PATCH 056/109] Unshallow git clone on RtD. This is required for get_version to find the last tag. --- .readthedocs.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.readthedocs.yml b/.readthedocs.yml index 0369e0656..28c990c5c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,6 +4,9 @@ build: os: ubuntu-20.04 tools: python: "3.10" + jobs: + post_checkout: + - git fetch --unshallow sphinx: configuration: docs/conf.py From c8c0a9bfee962540eb3c9c228e36d4ef7bd7ed42 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 15:56:42 +0200 Subject: [PATCH 057/109] Improve error reporting when header is too long. Refs #1471. --- src/websockets/legacy/server.py | 6 +++++- src/websockets/server.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index fb91265d8..c0c138767 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -199,10 +199,14 @@ async def handler(self) -> None: elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! invalid handshake", exc_info=True) + exc_str = f"{exc}" + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += f"; {exc}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), - f"Failed to open a WebSocket connection: {exc}.\n".encode(), + f"Failed to open a WebSocket connection: {exc_str}.\n".encode(), ) else: self.logger.error("opening handshake failed", exc_info=True) diff --git a/src/websockets/server.py b/src/websockets/server.py index f976ebad7..7f5631230 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -163,9 +163,13 @@ def accept(self, request: Request) -> Response: self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) + exc_str = f"{exc}" + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += f"; {exc}" return self.reject( http.HTTPStatus.BAD_REQUEST, - f"Failed to open a WebSocket connection: {exc}.\n", + f"Failed to open a WebSocket connection: {exc_str}.\n", ) except Exception as exc: # Handle exceptions raised by user-provided select_subprotocol and From d26bac47eac98e6f3b77358b8c836ed02e493fc6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 30 Jul 2024 09:04:22 +0200 Subject: [PATCH 058/109] Make eaa64c07 more robust. This avoids crashing on ossfuzz, which uses a custom loader. --- src/websockets/version.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 46ae34a47..44709a91b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -43,9 +43,11 @@ def get_version(tag: str) -> str: # Check that this file belongs to the installed package. files = importlib.metadata.files("websockets") if files: - version_file = [f for f in files if f.name == file_path.name][0] - if version_file.locate() == file_path: - return version + version_files = [f for f in files if f.name == file_path.name] + if version_files: + version_file = version_files[0] + if version_file.locate() == file_path: + return version # Read version from git if available. try: From d2710227cd2464162861b584f5dd83c472208929 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:04:43 +0200 Subject: [PATCH 059/109] Make mypy happy. --- src/websockets/legacy/server.py | 9 +++++---- src/websockets/server.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c0c138767..93698e1cb 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -199,10 +199,11 @@ async def handler(self) -> None: elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! invalid handshake", exc_info=True) - exc_str = f"{exc}" - while exc.__cause__ is not None: - exc = exc.__cause__ - exc_str += f"; {exc}" + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), diff --git a/src/websockets/server.py b/src/websockets/server.py index 7f5631230..baab400d4 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -163,10 +163,11 @@ def accept(self, request: Request) -> Response: self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) - exc_str = f"{exc}" - while exc.__cause__ is not None: - exc = exc.__cause__ - exc_str += f"; {exc}" + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc_str}.\n", From 309e62fa89311de51083e1a62adf06d4450fc5f2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:09:24 +0200 Subject: [PATCH 060/109] Test against current PyPy versions. --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8161f1cbb..15a45bdfb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,15 +62,15 @@ jobs: - "3.10" - "3.11" - "3.12" - - "pypy-3.8" - "pypy-3.9" + - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.8" - is_main: false - python: "pypy-3.9" is_main: false + - python: "pypy-3.10" + is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From fab77d60b660585bcdd996a10ec904f79c901085 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:38:17 +0200 Subject: [PATCH 061/109] Annotate __init__ methods consistently. --- src/websockets/client.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- src/websockets/sync/server.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 8f78ac320..07d1d34ed 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -79,7 +79,7 @@ def __init__( state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, - ): + ) -> None: super().__init__( side=CLIENT, state=state, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 93698e1cb..39464be6c 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -666,7 +666,7 @@ class WebSocketServer: """ - def __init__(self, logger: LoggerLike | None = None): + def __init__(self, logger: LoggerLike | None = None) -> None: if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger diff --git a/src/websockets/server.py b/src/websockets/server.py index baab400d4..7211d3cbf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -90,7 +90,7 @@ def __init__( state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, - ): + ) -> None: super().__init__( side=SERVER, state=state, diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 4f088b63a..7fb46f5aa 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -200,7 +200,7 @@ def __init__( socket: socket.socket, handler: Callable[[socket.socket, Any], None], logger: LoggerLike | None = None, - ): + ) -> None: self.socket = socket self.handler = handler if logger is None: From 14cca7699971c19c30a38cad260aeb5f26e0c3ca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:46:04 +0200 Subject: [PATCH 062/109] Bugs in coverage were fixed \o/ --- src/websockets/legacy/protocol.py | 3 +-- tests/legacy/test_client_server.py | 1 - tests/legacy/test_protocol.py | 18 ++++++------------ 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3d09440e1..de9ea59b6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1349,8 +1349,7 @@ async def close_transport(self) -> None: # Abort the TCP connection. Buffers are discarded. if self.debug: self.logger.debug("x aborting TCP connection") - # Due to a bug in coverage, this is erroneously reported as not covered. - self.transport.abort() # pragma: no cover + self.transport.abort() # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 09b3b361a..0c3f22156 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1584,7 +1584,6 @@ async def run_client(): else: # Exit block with an exception. raise Exception("BOOM") - pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: with self.assertRaises(Exception) as raised: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 05d2f3795..d6303dcc7 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1613,8 +1613,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1634,8 +1633,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1656,8 +1654,7 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) @@ -1670,8 +1667,7 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) @@ -1689,8 +1685,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1713,8 +1708,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) From 7bb18a6ea84d2651b68ad45f5e9464a47d314b6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 07:53:06 +0200 Subject: [PATCH 063/109] Update references to Python's bug tracker. --- src/websockets/__main__.py | 2 +- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index f2ea5cf4e..8647481d0 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -22,7 +22,7 @@ def win_enable_vt100() -> None: """ Enable VT-100 for console output on Windows. - See also https://bugs.python.org/issue29059. + See also https://github.com/python/cpython/issues/73245. """ import ctypes diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index de9ea59b6..57cb4e770 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1175,9 +1175,9 @@ def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: async def drain(self) -> None: try: - # drain() cannot be called concurrently by multiple coroutines: - # http://bugs.python.org/issue29930. Remove this lock when no - # version of Python where this bugs exists is supported anymore. + # drain() cannot be called concurrently by multiple coroutines. + # See https://github.com/python/cpython/issues/74116 for details. + # This workaround can be removed when dropping Python < 3.10. async with self._drain_lock: # Handle flow control automatically. await self._drain() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 39464be6c..f4442fecc 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -764,7 +764,8 @@ async def _close(self, close_connections: bool) -> None: self.server.close() # Wait until all accepted connections reach connection_made() and call - # register(). See https://bugs.python.org/issue34852 for details. + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) if close_connections: From 273db5bcc4113061bd7d8f0a4edbf6c4d76c4d84 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Aug 2024 15:18:07 +0200 Subject: [PATCH 064/109] Make it easier to enable logs while running tests. --- tests/__init__.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index dd78609f5..bb1866f2d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,14 @@ import logging +import os -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) +format = "%(asctime)s %(levelname)s %(name)s %(message)s" + +if bool(os.environ.get("WEBSOCKETS_DEBUG")): # pragma: no cover + # Display every frame sent or received in debug mode. + level = logging.DEBUG +else: + # Hide stack traces of exceptions. + level = logging.CRITICAL + +logging.basicConfig(format=format, level=level) From cbcb7fd715be0f1efb98102302739e9d9f3ca08c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 15:59:16 +0200 Subject: [PATCH 065/109] Pass WEBSOCKETS_TESTS_TIMEOUT_FACTOR to tox. Previously, despite being declared in .github/workflows/tests.yml, it had no effect because tox insulates test runs from the environment. --- tox.ini | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 06003c85b..b00833e73 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = +env_list = py38 py39 py310 @@ -12,6 +12,7 @@ envlist = [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} +pass_env = WEBSOCKETS_* [testenv:coverage] commands = From 8c4fd9c24a701b7050681f786b2918d205b91338 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 16:04:10 +0200 Subject: [PATCH 066/109] Remove superfluous `coverage erase`. `coverage run` starts clean unless `--append` is specified. --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index b00833e73..1edcfe261 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,6 @@ pass_env = WEBSOCKETS_* [testenv:coverage] commands = - python -m coverage erase python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 deps = coverage From 02b333829e385d4fef42fa0565996adcfea653b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 08:59:51 +0200 Subject: [PATCH 067/109] Make Protocol.receive_eof idempotent. This removes the need for keeping track of whether you called it or not, especially in an asyncio context where it may be called in eof_received or in connection_lost. --- src/websockets/protocol.py | 5 +++-- tests/test_protocol.py | 8 ++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2f5542f6e..7f2b45c74 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -270,10 +270,11 @@ def receive_eof(self) -> None: - You aren't expected to call :meth:`events_received`; it won't return any new events. - Raises: - EOFError: If :meth:`receive_eof` was called earlier. + :meth:`receive_eof` is idempotent. """ + if self.reader.eof: + return self.reader.feed_eof() next(self.parser) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e1527525b..7f1276bb2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1590,18 +1590,14 @@ def test_client_receives_eof_after_eof(self): client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + client.receive_eof() # this is idempotent def test_server_receives_eof_after_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + server.receive_eof() # this is idempotent class TCPCloseTests(ProtocolTestCase): From 3ad92b50515e5c83344f0771a34e8d9e7cd8ff4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 15:28:18 +0200 Subject: [PATCH 068/109] Don't specify the encoding when it's utf-8. --- src/websockets/frames.py | 2 +- src/websockets/legacy/protocol.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 0da676432..af56d3f8f 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -447,7 +447,7 @@ def parse(cls, data: bytes) -> Close: """ if len(data) >= 2: (code,) = struct.unpack("!H", data[:2]) - reason = data[2:].decode("utf-8") + reason = data[2:].decode() close = cls(code, reason) close.check() return close diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 57cb4e770..c28bdcf48 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1036,7 +1036,7 @@ async def read_message(self) -> Data | None: # Shortcut for the common case - no fragmentation if frame.fin: - return frame.data.decode("utf-8") if text else frame.data + return frame.data.decode() if text else frame.data # 5.4. Fragmentation fragments: list[Data] = [] From d0fd9cf61432a8dcf1cd639139ebb57d4b522c01 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Aug 2024 08:08:35 +0200 Subject: [PATCH 069/109] Improve tests for sync implementation slightly. --- src/websockets/sync/connection.py | 3 +- tests/sync/connection.py | 4 +- tests/sync/test_connection.py | 76 +++++++++++++++++++++---------- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 33d8299e2..2bcb3aa0e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -373,6 +373,7 @@ def send(self, message: Data | Iterable[Data]) -> None: except RuntimeError: # We didn't start sending a fragmented message. + # The connection is still usable. raise except Exception: @@ -756,7 +757,7 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ assert self.protocol_mutex.locked() - if self.recv_exc is None: + if self.recv_exc is None: # pragma: no branch self.recv_exc = exc def close_socket(self) -> None: diff --git a/tests/sync/connection.py b/tests/sync/connection.py index 89d4909ee..9c8bacea0 100644 --- a/tests/sync/connection.py +++ b/tests/sync/connection.py @@ -8,7 +8,7 @@ class InterceptingConnection(Connection): """ Connection subclass that can intercept outgoing packets. - By interfacing with this connection, you can simulate network conditions + By interfacing with this connection, we simulate network conditions affecting what the component being tested receives during a test. """ @@ -80,7 +80,7 @@ def drop_eof_sent(self): class InterceptingSocket: """ - Socket wrapper that intercepts calls to sendall and shutdown. + Socket wrapper that intercepts calls to ``sendall()`` and ``shutdown()``. This is coupled to the implementation, which relies on these two methods. diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 953c8c253..88cbcd669 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -246,13 +246,15 @@ def test_recv_streaming_connection_closed_ok(self): """recv_streaming raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_during_recv(self): """recv_streaming raises RuntimeError when called concurrently with recv.""" @@ -260,7 +262,8 @@ def test_recv_streaming_during_recv(self): recv_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot call recv_streaming while another thread " @@ -278,7 +281,8 @@ def test_recv_streaming_during_recv_streaming(self): recv_streaming_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another thread " @@ -374,7 +378,7 @@ def test_send_empty_iterable(self): """send does nothing when called with an empty iterable.""" self.connection.send([]) self.connection.close() - self.assertEqual(list(iter(self.remote_connection)), []) + self.assertEqual(list(self.remote_connection), []) def test_send_mixed_iterable(self): """send raises TypeError when called with an iterable of inconsistent types.""" @@ -437,7 +441,7 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" - with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() with self.assertRaises(ConnectionClosedError) as raised: @@ -464,6 +468,10 @@ def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) def test_close_waits_for_recv(self): + # The sync implementation doesn't have a buffer for incoming messsages. + # It requires reading incoming frames until the close frame is reached. + # This behavior — close() blocks until recv() is called — is less than + # ideal and inconsistent with the asyncio implementation. self.remote_connection.send("😀") close_thread = threading.Thread(target=self.connection.close) @@ -547,6 +555,25 @@ def closer(): close_thread.join() + def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + + def closer(): + time.sleep(MS) + self.connection.close() + + close_thread = threading.Thread(target=closer) + close_thread.start() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + close_thread.join() + def test_close_during_send(self): """close fails the connection when called concurrently with send.""" close_gate = threading.Event() @@ -599,42 +626,45 @@ def test_ping_explicit_binary(self): self.connection.ping(b"ping") self.assertFrameSent(Frame(Opcode.PING, b"ping")) - def test_ping_duplicate_payload(self): - """ping rejects the same payload until receiving the pong.""" - with self.remote_connection.protocol_mutex: # block response to ping - pong_waiter = self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: - self.connection.ping("idem") - self.assertEqual( - str(raised.exception), - "already waiting for a pong with the same data", - ) - self.assertTrue(pong_waiter.wait(MS)) - self.connection.ping("idem") # doesn't raise an exception - def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") - self.assertFalse(pong_waiter.wait(MS)) self.remote_connection.pong("this") self.assertTrue(pong_waiter.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.remote_connection.pong("that") self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong with the same payload as a later ping.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + self.remote_connection.pong("idem") + self.assertTrue(pong_waiter.wait(MS)) + + self.connection.ping("idem") # doesn't raise an exception + # Test pong. def test_pong(self): From c92fba02db87a88af54a47e7f5bae050587490dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:11:11 +0200 Subject: [PATCH 070/109] Move asyncio compatibility to a new package. --- pyproject.toml | 4 ++-- src/websockets/asyncio/__init__.py | 0 src/websockets/{legacy => asyncio}/async_timeout.py | 0 src/websockets/{legacy => asyncio}/compatibility.py | 0 src/websockets/legacy/client.py | 2 +- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- tests/legacy/test_client_server.py | 2 +- tests/maxi_cov.py | 6 ++++-- 9 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 src/websockets/asyncio/__init__.py rename src/websockets/{legacy => asyncio}/async_timeout.py (100%) rename src/websockets/{legacy => asyncio}/compatibility.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 2367849ca..de8acd6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", - "*/websockets/legacy/async_timeout.py", - "*/websockets/legacy/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/asyncio/__init__.py b/src/websockets/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/asyncio/async_timeout.py similarity index 100% rename from src/websockets/legacy/async_timeout.py rename to src/websockets/asyncio/async_timeout.py diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/asyncio/compatibility.py similarity index 100% rename from src/websockets/legacy/compatibility.py rename to src/websockets/asyncio/compatibility.py diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b0e15b543..d1d8d5608 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -16,6 +16,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( InvalidHandshake, @@ -40,7 +41,6 @@ from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .compatibility import asyncio_timeout from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c28bdcf48..120ff8e73 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,6 +23,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -49,7 +50,6 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import asyncio_timeout from .framing import Frame diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index f4442fecc..208ffa780 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -21,6 +21,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -42,7 +43,6 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 0c3f22156..b5c5d726a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -14,6 +14,7 @@ import urllib.request import warnings +from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -30,7 +31,6 @@ from websockets.frames import CloseCode from websockets.http import USER_AGENT from websockets.legacy.client import * -from websockets.legacy.compatibility import asyncio_timeout from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index bc4a44e8c..83686c3d3 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -52,8 +52,9 @@ def get_mapping(src_dir="src"): os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) if "legacy" not in os.path.dirname(src_file) - if os.path.basename(src_file) != "__init__.py" + and os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" + and os.path.basename(src_file) != "async_timeout.py" and os.path.basename(src_file) != "compatibility.py" ] test_files = [ @@ -102,7 +103,8 @@ def get_ignored_files(src_dir="src"): "*/websockets/typing.py", # We don't test compatibility modules with previous versions of Python # or websockets (import locations). - "*/websockets/*/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. From 9f8f2f27218e4dc7ad4126109e6ffe012946b71b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:12:28 +0200 Subject: [PATCH 071/109] Add asyncio message assembler. --- src/websockets/asyncio/compatibility.py | 20 +- src/websockets/asyncio/messages.py | 283 ++++++++++++++ src/websockets/sync/messages.py | 4 +- tests/asyncio/__init__.py | 0 tests/asyncio/test_messages.py | 471 ++++++++++++++++++++++++ tests/asyncio/utils.py | 5 + 6 files changed, 777 insertions(+), 6 deletions(-) create mode 100644 src/websockets/asyncio/messages.py create mode 100644 tests/asyncio/__init__.py create mode 100644 tests/asyncio/test_messages.py create mode 100644 tests/asyncio/utils.py diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 6bd01e70d..390f00ac7 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,10 +3,22 @@ import sys -__all__ = ["asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] if sys.version_info[:2] >= (3, 11): - from asyncio import timeout as asyncio_timeout # noqa: F401 -else: - from .async_timeout import timeout as asyncio_timeout # noqa: F401 + TimeoutError = TimeoutError + aiter = aiter + anext = anext + from asyncio import timeout as asyncio_timeout + +else: # Python < 3.11 + from asyncio import TimeoutError + + def aiter(async_iterable): + return type(async_iterable).__aiter__(async_iterable) + + async def anext(async_iterator): + return await type(async_iterator).__anext__(async_iterator) + + from .async_timeout import timeout as asyncio_timeout diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py new file mode 100644 index 000000000..2a9c4d37d --- /dev/null +++ b/src/websockets/asyncio/messages.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +from typing import ( + Any, + AsyncIterator, + Callable, + Generic, + Iterable, + TypeVar, +) + +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class SimpleQueue(Generic[T]): + """ + Simplified version of :class:`asyncio.Queue`. + + Provides only the subset of functionality needed by :class:`Assembler`. + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_running_loop() + self.get_waiter: asyncio.Future[None] | None = None + self.queue: collections.deque[T] = collections.deque() + + def __len__(self) -> int: + return len(self.queue) + + def put(self, item: T) -> None: + """Put an item into the queue without waiting.""" + self.queue.append(item) + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_result(None) + + async def get(self) -> T: + """Remove and return an item from the queue, waiting if necessary.""" + if not self.queue: + if self.get_waiter is not None: + raise RuntimeError("get is already running") + self.get_waiter = self.loop.create_future() + try: + await self.get_waiter + finally: + self.get_waiter.cancel() + self.get_waiter = None + return self.queue.popleft() + + def reset(self, items: Iterable[T]) -> None: + """Put back items into an empty, idle queue.""" + assert self.get_waiter is None, "cannot reset() while get() is running" + assert not self.queue, "cannot reset() while queue isn't empty" + self.queue.extend(items) + + def abort(self) -> None: + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_exception(EOFError("stream of frames ended")) + # Clear the queue to avoid storing unnecessary data in memory. + self.queue.clear() + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + # coverage reports incorrectly: "line NN didn't jump to the function exit" + def __init__( # pragma: no cover + self, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming messages. Each item is a queue of frames. + self.frames: SimpleQueue[Frame] = SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + self.high = 16 + self.low = 4 + self.paused = False + self.pause = pause + self.resume = resume + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + # We cannot handle asyncio.CancelledError because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise RuntimeError. + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + def get_limits(self) -> tuple[int, int]: + """Return low and high water marks for flow control.""" + return self.low, self.high + + def set_limits(self, low: int = 4, high: int = 16) -> None: + """Configure low and high water marks for flow control.""" + self.low, self.high = low, high + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Check for "> high" to support high = 0 + if len(self.frames) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Check for "<= low" to support low = 0 + if len(self.frames) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.frames.abort() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 6cbff2595..ff90345ac 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -85,7 +85,7 @@ def get(self, timeout: float | None = None) -> Data: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") self.get_in_progress = True @@ -144,7 +144,7 @@ def get_iter(self) -> Iterator[Data]: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") chunks = self.chunks self.chunks = [] diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py new file mode 100644 index 000000000..c8a2d7cd5 --- /dev/null +++ b/tests/asyncio/test_messages.py @@ -0,0 +1,471 @@ +import asyncio +import unittest +import unittest.mock + +from websockets.asyncio.compatibility import aiter, anext +from websockets.asyncio.messages import * +from websockets.asyncio.messages import SimpleQueue +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame + +from .utils import alist + + +class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.queue = SimpleQueue() + + async def test_len(self): + """__len__ returns queue length.""" + self.assertEqual(len(self.queue), 0) + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + await self.queue.get() + self.assertEqual(len(self.queue), 0) + + async def test_put_then_get(self): + """get returns an item that is already put.""" + self.queue.put(42) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_get_then_put(self): + """get returns an item when it is put.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.put(42) + item = await getter_task + self.assertEqual(item, 42) + + async def test_get_concurrently(self): + """get cannot be called concurrently with itself.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + with self.assertRaises(RuntimeError): + await self.queue.get() + getter_task.cancel() + + async def test_reset(self): + """reset sets the content of the queue.""" + self.queue.reset([42]) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_abort(self): + """abort throws an exception in get.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.abort() + with self.assertRaises(EOFError): + await getter_task + + async def test_abort_clears_queue(self): + """abort clears buffered data from the queue.""" + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + self.queue.abort() + self.assertEqual(len(self.queue), 0) + + +class AssemblerTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(pause=self.pause, resume=self.resume) + self.assembler.set_limits(low=1, high=2) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await getter_task + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await getter_task + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + # Test getting and setting limits + + async def test_get_limits(self): + """get_limits returns low and high water marks.""" + low, high = self.assembler.get_limits() + self.assertEqual(low, 1) + self.assertEqual(high, 2) + + async def test_set_limits(self): + """set_limits changes low and high water marks.""" + self.assembler.set_limits(low=2, high=4) + low, high = self.assembler.get_limits() + self.assertEqual(low, 2) + self.assertEqual(high, 4) diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py new file mode 100644 index 000000000..a611bfc4b --- /dev/null +++ b/tests/asyncio/utils.py @@ -0,0 +1,5 @@ +async def alist(async_iterable): + items = [] + async for item in async_iterable: + items.append(item) + return items From 4a981688198f91385281b8c8e1cdfc0197d43bf5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Apr 2023 08:37:40 +0200 Subject: [PATCH 072/109] Add new asyncio-based implementation. --- docs/project/changelog.rst | 15 +- docs/reference/index.rst | 12 + docs/reference/new-asyncio/client.rst | 53 ++ docs/reference/new-asyncio/common.rst | 43 ++ docs/reference/new-asyncio/server.rst | 72 ++ src/websockets/asyncio/client.py | 331 +++++++++ src/websockets/asyncio/compatibility.py | 12 +- src/websockets/asyncio/connection.py | 883 ++++++++++++++++++++++ src/websockets/asyncio/server.py | 772 +++++++++++++++++++ tests/asyncio/client.py | 33 + tests/asyncio/connection.py | 115 +++ tests/asyncio/server.py | 50 ++ tests/asyncio/test_client.py | 306 ++++++++ tests/asyncio/test_connection.py | 948 ++++++++++++++++++++++++ tests/asyncio/test_server.py | 525 +++++++++++++ 15 files changed, 4165 insertions(+), 5 deletions(-) create mode 100644 docs/reference/new-asyncio/client.rst create mode 100644 docs/reference/new-asyncio/common.rst create mode 100644 docs/reference/new-asyncio/server.rst create mode 100644 src/websockets/asyncio/client.py create mode 100644 src/websockets/asyncio/connection.py create mode 100644 src/websockets/asyncio/server.py create mode 100644 tests/asyncio/client.py create mode 100644 tests/asyncio/connection.py create mode 100644 tests/asyncio/server.py create mode 100644 tests/asyncio/test_client.py create mode 100644 tests/asyncio/test_connection.py create mode 100644 tests/asyncio/test_server.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fd186a5fc..108b7c9c0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,7 +52,7 @@ Backwards-incompatible changes async def handler(request, path): ... - You should switch to the recommended pattern since 10.1:: + You should switch to the pattern recommended since version 10.1:: async def handler(request): path = request.path # only if handler() uses the path argument @@ -61,6 +61,16 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 11.0 introduces a new :mod:`asyncio` implementation. + :class: important + + This new implementation is intended to be a drop-in replacement for the + current implementation. It will become the default in a future release. + Please try it and report any issue that you encounter! + + See :func:`websockets.asyncio.client.connect` and + :func:`websockets.asyncio.server.serve` for details. + * Validated compatibility with Python 3.12. 12.0 @@ -175,7 +185,8 @@ New features It is particularly suited to client applications that establish only one connection. It may be used for servers handling few connections. - See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. + See :func:`websockets.sync.client.connect` and + :func:`websockets.sync.server.serve` for details. * Added ``open_timeout`` to :func:`~server.serve`. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 0b80f087a..2486ac564 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -26,6 +26,18 @@ clients concurrently. asyncio/server asyncio/client +:mod:`asyncio` (new) +-------------------- + +This is a rewrite of the :mod:`asyncio` implementation. It will become the +default in the future. + +.. toctree:: + :titlesonly: + + new-asyncio/server + new-asyncio/client + :mod:`threading` ---------------- diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst new file mode 100644 index 000000000..552d83b2f --- /dev/null +++ b/docs/reference/new-asyncio/client.rst @@ -0,0 +1,53 @@ +Client (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: unix_connect + :async: + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst new file mode 100644 index 000000000..ba23552dc --- /dev/null +++ b/docs/reference/new-asyncio/common.rst @@ -0,0 +1,43 @@ +:orphan: + +Both sides (:mod:`asyncio` - new) +================================= + +.. automodule:: websockets.asyncio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst new file mode 100644 index 000000000..f3446fb80 --- /dev/null +++ b/docs/reference/new-asyncio/server.rst @@ -0,0 +1,72 @@ +Server (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.server + +Creating a server +----------------- + +.. autofunction:: serve + :async: + +.. autofunction:: unix_serve + :async: + +Running a server +---------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py new file mode 100644 index 000000000..040d68ece --- /dev/null +++ b/src/websockets/asyncio/client.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import asyncio +from types import TracebackType +from typing import Any, Generator, Sequence + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .compatibility import TimeoutError, asyncio_timeout +from .connection import Connection + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ClientProtocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.response_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + # May raise CancelledError if open_timeout is exceeded. + await self.response_rcvd + + if self.response is None: + raise ConnectionError("connection closed during handshake") + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.response_rcvd.done(): + self.response_rcvd.set_result(None) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + async with websockets.asyncio.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. + When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS + context is created with :func:`~ssl.create_default_context`. + + * You can set ``server_hostname`` to override the host name from ``uri`` in + the TLS handshake. + + * You can set ``host`` and ``port`` to connect to a different host and port + from those found in ``uri``. This only changes the destination of the TCP + connection. The host name from ``uri`` is still used in the TLS handshake + for secure connections and in the ``Host`` header. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_connection` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_connection` method) to create a suitable + client socket and customize it. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to loop.create_connection + **kwargs: Any, + ) -> None: + + wsuri = parse_uri(uri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + + def factory() -> ClientConnection: + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_connection = loop.create_unix_connection(factory, **kwargs) + else: + if kwargs.get("sock") is None: + kwargs.setdefault("host", wsuri.host) + kwargs.setdefault("port", wsuri.port) + self._create_connection = loop.create_connection(factory, **kwargs) + + self._handshake_args = ( + additional_headers, + user_agent_header, + ) + + self._open_timeout = open_timeout + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, ClientConnection]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> ClientConnection: + try: + async with asyncio_timeout(self._open_timeout): + _transport, self.connection = await self._create_connection + try: + await self.connection.handshake(*self._handshake_args) + except (Exception, asyncio.CancelledError): + self.connection.transport.close() + raise + else: + return self.connection + except TimeoutError: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during handshake") from None + + # ... = yield from connect(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> connect: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` argument is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 390f00ac7..e17000069 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,14 +3,17 @@ import sys -__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"] if sys.version_info[:2] >= (3, 11): TimeoutError = TimeoutError aiter = aiter anext = anext - from asyncio import timeout as asyncio_timeout + from asyncio import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) else: # Python < 3.11 from asyncio import TimeoutError @@ -21,4 +24,7 @@ def aiter(async_iterable): async def anext(async_iterator): return await type(async_iterator).__anext__(async_iterator) - from .async_timeout import timeout as asyncio_timeout + from .async_timeout import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py new file mode 100644 index 000000000..550c0ac97 --- /dev/null +++ b/src/websockets/asyncio/connection.py @@ -0,0 +1,883 @@ +from __future__ import annotations + +import asyncio +import collections +import contextlib +import logging +import random +import struct +import uuid +from types import TracebackType +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Iterable, + Mapping, + cast, +) + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(asyncio.Protocol): + """ + :mod:`asyncio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.asyncio.client.ClientConnection` or + :class:`~websockets.asyncio.server.ServerConnection`. + + """ + + def __init__( + self, + protocol: Protocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Event loop running this connection. + self.loop = asyncio.get_running_loop() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages: Assembler # initialized in connection_made + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.fragmented_send_waiter: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + + # Adapted from asyncio.FlowControlMixin + self.paused: bool = False + self.drain_waiters: collections.deque[asyncio.Future[None]] = ( + collections.deque() + ) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.transport.get_extra_info("peername") + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get() + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise :exc:`RuntimeError`, making the + connection unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(): + yield frame + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If the connection busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.fragmented_send_waiter is not None: + await asyncio.shield(self.fragmented_send_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + else: + raise TypeError("data must be bytes, str, iterable, or async iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.fragmented_send_waiter is not None: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Data | None = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pong_waiters: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pong_waiters: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # The event loop's default clock is time.monotonic(). Its resolution + # is a bit low on Windows (~16ms). We cannot use time.perf_counter() + # because it doesn't count time elapsed while the process sleeps. + ping_timestamp = self.loop.time() + self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.protocol.send_ping(data) + return pong_waiter + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pong_waiters: + return + + pong_timestamp = self.loop.time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + ping_ids.append(ping_id) + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pong_waiters. + for ping_id in ping_ids: + del self.pong_waiters[ping_id] + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.protocol.state is CLOSED + exc = self.protocol.close_exc + + for pong_waiter, _ping_timestamp in self.pong_waiters.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + self.pong_waiters.clear() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = self.loop.time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + self.send_data() + await self.drain() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + try: + async with asyncio_timeout_at(self.close_deadline): + await asyncio.shield(self.connection_lost_waiter) + except TimeoutError: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_transport() + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + Raises: + OSError: When a socket operations fails. + + """ + for data in self.protocol.data_to_send(): + if data: + self.transport.write(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # OSError is plausible. uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + def close_transport(self) -> None: + """ + Close transport and message assembler. + + """ + self.transport.close() + self.recv_messages.close() + + # asyncio.Protocol methods + + # Connection callbacks + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.recv_messages = Assembler( + pause=self.transport.pause_reading, + resume=self.transport.resume_reading, + ) + + def connection_lost(self, exc: Exception | None) -> None: + self.protocol.receive_eof() # receive_eof is idempotent + self.recv_messages.close() + self.set_recv_exc(exc) + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + self.abort_pings() + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: # pragma: no cover + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + # Flow control callbacks + + def pause_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert not self.paused + self.paused = True + + def resume_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert self.paused + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + waiter.set_result(None) + + async def drain(self) -> None: # pragma: no cover + # We don't check if the connection is closed because we call drain() + # immediately after write() and write() would fail in that case. + + # Adapted from asyncio.streams.StreamWriter + # Yield to the event loop so that connection_lost() may be called. + if self.transport.is_closing(): + await asyncio.sleep(0) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: + waiter = self.loop.create_future() + self.drain_waiters.append(waiter) + try: + await waiter + finally: + self.drain_waiters.remove(waiter) + + # Streaming protocol callbacks + + def data_received(self, data: bytes) -> None: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the transport. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + self.set_recv_exc(exc) + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + def eof_received(self) -> None: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it shouldn't raise errors. + self.send_data() + + # The WebSocket protocol has its own closing handshake: endpoints close + # the TCP or TLS connection after sending and receiving a close frame. + # As a consequence, they never need to write after receiving EOF, so + # there's no reason to keep the transport open by returning True. + # Besides, that doesn't work on TLS connections. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py new file mode 100644 index 000000000..aa175f775 --- /dev/null +++ b/src/websockets/asyncio/server.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import asyncio +import http +import logging +import socket +import sys +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + Iterable, + Sequence, +) + +from websockets.frames import CloseCode + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import asyncio_timeout +from .connection import Connection + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + server: Server that manages this connection. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ServerProtocol, + server: WebSocketServer, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.server = server + self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + # May raise CancelledError if open_timeout is exceeded. + await self.request_rcvd + + if self.request is None: + raise ConnectionError("connection closed during handshake") + + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + self.server.start_connection_handler(self) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.request_rcvd.done(): + self.request_rcvd.set_result(None) + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + + def wrap(self, server: asyncio.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller that can handle exceptions, + it attempts to log relevant ones. + + It guarantees that the TCP connection is closed before exiting. + + """ + try: + # On failure, handshake() closes the transport, raises an + # exception, and logs it. + async with asyncio_timeout(self.open_timeout): + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + + try: + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(CloseCode.INTERNAL_ERROR) + else: + await connection.close() + + except Exception: + # Don't leak connections on errors. + connection.transport.abort() + + finally: + # Registration is tied to the lifecycle of conn_handler() because + # the server waits for connection handlers to terminate, even if + # all connections are already closed. + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + """ + Register a connection with this server. + + """ + # The connection must be registered in self.handlers immediately. + # If it was registered in conn_handler(), a race condition could + # happen when closing the server after scheduling conn_handler() + # but before it starts executing. + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(connection.close(1001)) + for connection in self.handlers + if connection.protocol.state is not CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.handlers: + await asyncio.wait(self.handlers.values()) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~ServerConnection.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~ServerConnection.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: # pragma: no cover + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +# This is spelled in lower case because it's exposed as a callable in the API. +class serve: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This coroutine returns a :class:`WebSocketServer` whose API mirrors + :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + ensure that the server will be closed:: + + def handler(websocket): + ... + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with websockets.asyncio.server.serve(handler, host, port): + await stop + + Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests + and cancel it to stop the server:: + + server = await websockets.asyncio.server.serve(handler, host, port) + await server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_server` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_server` method) to create a suitable server + socket and customize it. + + * You can set ``start_serving`` to ``False`` to start accepting connections + only after you call :meth:`~WebSocketServer.start_serving()` or + :meth:`~WebSocketServer.serve_forever()`. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + # WebSocket + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Other keyword arguments are passed to loop.create_server + **kwargs: Any, + ) -> None: + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + self.server = WebSocketServer( + handler, + process_request=process_request, + process_response=process_response, + server_header=server_header, + open_timeout=open_timeout, + logger=logger, + ) + + if kwargs.get("ssl") is not None: + kwargs.setdefault("ssl_handshake_timeout", open_timeout) + if sys.version_info[:2] >= (3, 11): # pragma: no branch + kwargs.setdefault("ssl_shutdown_timeout", close_timeout) + + def factory() -> ServerConnection: + """ + Create an asyncio protocol for managing a WebSocket connection. + + """ + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + self.server, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_server = loop.create_unix_server(factory, **kwargs) + else: + # mypy cannot tell that kwargs must provide sock when port is None. + self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + + # async with serve(...) as ...: ... + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.server.close() + await self.server.wait_closed() + + # ... = await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server + self.server.wrap(server) + return self.server + + # ... = yield from serve(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_serve( + handler: Callable[[ServerConnection], Awaitable[None]], + path: str | None = None, + **kwargs: Any, +) -> Awaitable[WebSocketServer]: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py new file mode 100644 index 000000000..e5826add7 --- /dev/null +++ b/tests/asyncio/client.py @@ -0,0 +1,33 @@ +import contextlib + +from websockets.asyncio.client import * +from websockets.asyncio.server import WebSocketServer + +from .server import get_server_host_port + + +__all__ = [ + "run_client", + "run_unix_client", +] + + +@contextlib.asynccontextmanager +async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl" in kwargs + protocol = "wss" if secure else "ws" + host, port = get_server_host_port(wsuri_or_server) + wsuri = f"{protocol}://{host}:{port}{resource_name}" + async with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.asynccontextmanager +async def run_unix_client(path, **kwargs): + async with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py new file mode 100644 index 000000000..ad1c121bf --- /dev/null +++ b/tests/asyncio/connection.py @@ -0,0 +1,115 @@ +import asyncio +import contextlib + +from websockets.asyncio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def connection_made(self, transport): + super().connection_made(InterceptingTransport(transport)) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write is None + self.transport.delay_write = delay + try: + yield + finally: + self.transport.delay_write = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write_eof is None + self.transport.delay_write_eof = delay + try: + yield + finally: + self.transport.delay_write_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write + self.transport.drop_write = True + try: + yield + finally: + self.transport.drop_write = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write_eof + self.transport.drop_write_eof = True + try: + yield + finally: + self.transport.drop_write_eof = False + + +class InterceptingTransport: + """ + Transport wrapper that intercepts calls to ``write()`` and ``write_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + Since ``write()`` and ``write_eof()`` are not coroutines, this effect is + achieved by scheduling writes at a later time, after the methods return. + This can easily result in out-of-order writes, which is unrealistic. + + """ + + def __init__(self, transport): + self.loop = asyncio.get_running_loop() + self.transport = transport + self.delay_write = None + self.delay_write_eof = None + self.drop_write = False + self.drop_write_eof = False + + def __getattr__(self, name): + return getattr(self.transport, name) + + def write(self, data): + if not self.drop_write: + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + else: + self.transport.write(data) + + def write_eof(self): + if not self.drop_write_eof: + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + else: + self.transport.write_eof() diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py new file mode 100644 index 000000000..0fe20dc65 --- /dev/null +++ b/tests/asyncio/server.py @@ -0,0 +1,50 @@ +import asyncio +import contextlib +import socket + +from websockets.asyncio.server import * + + +def get_server_host_port(server): + for sock in server.sockets: + if sock.family == socket.AF_INET: # pragma: no branch + return sock.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +async def eval_shell(ws): + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) + + +async def crash(ws): + raise RuntimeError + + +async def do_nothing(ws): + pass + + +async def keep_running(ws): + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + + +@contextlib.asynccontextmanager +async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + async with serve(handler, host, port, **kwargs) as server: + yield server + + +@contextlib.asynccontextmanager +async def run_unix_server(path, handler=eval_shell, **kwargs): + async with unix_serve(handler, path, **kwargs) as server: + yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py new file mode 100644 index 000000000..aab65cd2e --- /dev/null +++ b/tests/asyncio/test_client.py @@ -0,0 +1,306 @@ +import asyncio +import socket +import ssl +import unittest + +from websockets.asyncio.client import * +from websockets.asyncio.compatibility import TimeoutError +from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.extensions.permessage_deflate import PerMessageDeflate + +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from .client import run_client, run_unix_client +from .server import do_nothing, get_server_host_port, run_server, run_unix_server + + +class ClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with run_server() as server: + with socket.create_connection(get_server_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + async with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with run_client( + server, create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server( + do_nothing, process_response=remove_accept_header + ) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = asyncio.get_running_loop().create_future() + + async def stall_connection(self, request): + await gate + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaises(TimeoutError) as raised: + async with run_client(server, open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + finally: + gate.set_result(None) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_transport() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(ConnectionError) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) + + +class SecureClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate isn't trusted system-wide. + async with run_client(server, secure=True): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # This hostname isn't included in the test certificate. + async with run_client( + server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_secure_uri_without_ssl(self): + """Client rejects no ssl when URI is secure.""" + with self.assertRaises(TypeError) as raised: + await connect("wss://localhost/", ssl=None) + self.assertEqual( + str(raised.exception), + "ssl=None is incompatible with a wss:// URI", + ) + + async def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_connect() + self.assertEqual( + str(raised.exception), + "no path and sock were specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py new file mode 100644 index 000000000..a8b3980b4 --- /dev/null +++ b/tests/asyncio/test_connection.py @@ -0,0 +1,948 @@ +import asyncio +import contextlib +import logging +import socket +import unittest +import uuid +from unittest.mock import patch + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout +from websockets.asyncio.connection import * +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection +from .utils import alist + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + loop = asyncio.get_running_loop() + socket_, remote_socket = socket.socketpair() + self.transport, self.connection = await loop.create_connection( + lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), + sock=socket_, + ) + self.remote_transport, self.remote_connection = await loop.create_connection( + lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), + sock=remote_socket, + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + # Run the event loop twice for consistency with assertFrameSent. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + aiterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(aiterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnnectionClosedError after an error.""" + aiterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(aiterator) + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_cancellation_before_receiving(self): + """recv can be cancelled before receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv cannot be cancelled after receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the complete message. + gate.set_result(None) + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be cancelled before receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be cancelled after receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + gate.set_result(None) + # Running recv_streaming again fails. + with self.assertRaises(RuntimeError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_while_send_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an Iterable. + self.connection.pause_writing() + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_while_send_async_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + self.connection.pause_writing() + + async def fragments(): + yield "⏳" + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + gate.set_result(None) + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.drop_eof_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + async def test_close_does_not_wait_for_recv(self): + # The asyncio implementation has a buffer for incoming messages. Closing + # the connection discards buffered messages. This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. + await self.remote_connection.send("😀") + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(MS) + await self.connection.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await recv_task + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + + asyncio.create_task(self.connection.close()) + await asyncio.sleep(MS) + + gate.set_result(None) + + with self.assertRaises(ConnectionClosedError) as raised: + await send_task + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + wait_closed_task = asyncio.create_task(self.connection.wait_closed()) + await asyncio.sleep(0) # let the event loop start wait_closed_task + self.assertFalse(wait_closed_task.done()) + await self.connection.close() + self.assertTrue(wait_closed_task.done()) + + # Test ping. + + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("this") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + async with asyncio_timeout(MS): + await pong_waiter + + await self.connection.ping("idem") # doesn't raise an exception + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) + ) + async def test_local_address(self, get_extra_info): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + get_extra_info.assert_called_with("sockname") + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) + ) + async def test_remote_address(self, get_extra_info): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + get_extra_info.assert_called_with("peername") + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + async def test_writing_in_data_received_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + async def test_unexpected_failure_in_data_received(self, events_received): + """Unexpected internal error in data_received() is correctly reported.""" + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py new file mode 100644 index 000000000..2e59f49b1 --- /dev/null +++ b/tests/asyncio/test_server.py @@ -0,0 +1,525 @@ +import asyncio +import dataclasses +import http +import logging +import socket +import unittest + +from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout +from websockets.asyncio.server import * +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( + EvalShellMixin, + crash, + do_nothing, + eval_shell, + get_server_host_port, + keep_running, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server(do_nothing) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server(crash) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) + + async def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket.create_server(("localhost", 0)) as sock: + async with run_server(sock=sock, host=None, port=None): + uri = "ws://{}:{}/".format(*sock.getsockname()) + async with run_client(uri) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + async with run_client(server, subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request(self): + """Server runs async process_request before processing the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_abort_handshake(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response(self): + """Server runs async process_response after processing the handshake.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_override_response(self): + """Server runs process_response and overrides the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_async_process_response_override_response(self): + """Server runs async process_response and overrides the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with run_client(server) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while ws.server.is_serving(): + await asyncio.sleep(0) + + async with run_server(process_request=process_request) as server: + asyncio.get_running_loop().call_later(MS, server.close) + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close(close_connections=False) + + # Server cannot receive new connections. + await asyncio.sleep(0) + self.assertFalse(server.sockets) + + # The server waits for the client to close the connection. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + # Once the client closes the connection, the server terminates. + await client.close() + async with asyncio_timeout(MS): + await server.wait_closed() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server(keep_running) as server: + async with run_client(server) as client: + # Delay termination of connection handler. + await client.send(str(2 * MS)) + + server.close() + + # The server waits for the connection handler to terminate. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + async with asyncio_timeout(2 * MS): + await server.wait_closed() + + +SSL_OBJECT = "ws.transport.get_extra_info('ssl_object')" + + +class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "path was not specified, and no sock specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): + async def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) From d2120de4708d09d3cade465541f202d5a8bab722 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 08:31:06 +0200 Subject: [PATCH 073/109] Add an option to disable decoding of text frames. Also support decoding binary frames. Fix #1376. --- src/websockets/asyncio/connection.py | 34 ++++++++++++++++++++++++---- src/websockets/asyncio/messages.py | 10 ++++++++ tests/asyncio/test_connection.py | 26 +++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 550c0ac97..152c6789e 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -172,7 +172,7 @@ async def __aiter__(self) -> AsyncIterator[Data]: except ConnectionClosedOK: return - async def recv(self) -> Data: + async def recv(self, decode: bool | None = None) -> Data: """ Receive the next message. @@ -192,6 +192,10 @@ async def recv(self) -> Data: When the message is fragmented, :meth:`recv` waits until all fragments are received, reassembles them, and returns the whole message. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -199,6 +203,15 @@ async def recv(self) -> Data: .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return a bytestring (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` or @@ -206,7 +219,7 @@ async def recv(self) -> Data: """ try: - return await self.recv_messages.get() + return await self.recv_messages.get(decode) except EOFError: raise self.protocol.close_exc from self.recv_exc except RuntimeError: @@ -215,7 +228,7 @@ async def recv(self) -> Data: "is already running recv or recv_streaming" ) from None - async def recv_streaming(self) -> AsyncIterator[Data]: + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Receive the next message frame by frame. @@ -232,6 +245,10 @@ async def recv_streaming(self) -> AsyncIterator[Data]: iterator in a partially consumed state, making the connection unusable. Instead, you should close the connection with :meth:`close`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -239,6 +256,15 @@ async def recv_streaming(self) -> AsyncIterator[Data]: .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` or @@ -246,7 +272,7 @@ async def recv_streaming(self) -> AsyncIterator[Data]: """ try: - async for frame in self.recv_messages.get_iter(): + async for frame in self.recv_messages.get_iter(decode): yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 2a9c4d37d..bc33df8d7 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -121,6 +121,11 @@ async def get(self, decode: bool | None = None) -> Data: received, then it reassembles the message and returns it. To receive messages frame by frame, use :meth:`get_iter` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` @@ -183,6 +188,11 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a8b3980b4..2efd4e96d 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -167,6 +167,16 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + async def test_recv_encoded_text(self): + """recv receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_decoded_binary(self): + """recv receives an UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + async def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" await self.remote_connection.send(["😀", "😀"]) @@ -271,6 +281,22 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + async def test_recv_streaming_encoded_text(self): + """recv_streaming receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_decoded_binary(self): + """recv_streaming receives a UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + async def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" await self.remote_connection.send(["😀", "😀"]) From e35c15a2a70c347f6f7a3e503ff1181ac35e1298 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 18:45:02 +0200 Subject: [PATCH 074/109] Reduce MS for situations with performance penalties. Nowadays it's tuned with WEBSOCKETS_TESTS_TIMEOUT_FACTOR. --- tests/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index bd3b61d7b..1793f3e8b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,11 +45,11 @@ # PyPy has a performance penalty for this test suite. if platform.python_implementation() == "PyPy": # pragma: no cover - MS *= 5 + MS *= 2 -# asyncio's debug mode has a 10x performance penalty for this test suite. +# asyncio's debug mode has a performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 + MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From bbb316155a5aeb719f262873a5b29a98c19b25d9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 8 Aug 2024 10:07:44 +0200 Subject: [PATCH 075/109] Add env vars for configuring constants. --- docs/project/changelog.rst | 6 ++++ docs/reference/index.rst | 1 + docs/reference/variables.rst | 48 ++++++++++++++++++++++++++++++ docs/topics/logging.rst | 4 +++ docs/topics/security.rst | 38 +++++++++++++++++------ src/websockets/asyncio/client.py | 5 ++-- src/websockets/asyncio/server.py | 11 ++++--- src/websockets/frames.py | 17 ++++++----- src/websockets/http.py | 9 ------ src/websockets/http11.py | 36 +++++++++++++++++----- src/websockets/legacy/client.py | 4 +-- src/websockets/legacy/http.py | 9 +++--- src/websockets/legacy/server.py | 8 ++--- src/websockets/sync/client.py | 3 +- src/websockets/sync/server.py | 9 +++--- tests/legacy/test_client_server.py | 2 +- tests/test_http.py | 8 ----- 17 files changed, 149 insertions(+), 69 deletions(-) create mode 100644 docs/reference/variables.rst delete mode 100644 tests/test_http.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 108b7c9c0..8143e3483 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -73,6 +73,12 @@ New features * Validated compatibility with Python 3.12. +* Added :doc:`environment variables <../reference/variables>` to configure debug + logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. + + If you were monkey-patching constants, be aware that they were renamed, which + will break your configuration. You must switch to the environment variables. + 12.0 ---- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2486ac564..d3a0e935c 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -85,6 +85,7 @@ These low-level APIs are shared by all implementations. datastructures exceptions types + variables API stability ------------- diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst new file mode 100644 index 000000000..4bca112da --- /dev/null +++ b/docs/reference/variables.rst @@ -0,0 +1,48 @@ +Environment variables +===================== + +Logging +------- + +.. envvar:: WEBSOCKETS_MAX_LOG_SIZE + + How much of each frame to show in debug logs. + + The default value is ``75``. + +See the :doc:`logging guide <../topics/logging>` for details. + +Security +........ + +.. envvar:: WEBSOCKETS_SERVER + + Server header sent by websockets. + + The default value uses the format ``"Python/x.y.z websockets/X.Y"``. + +.. envvar:: WEBSOCKETS_USER_AGENT + + User-Agent header sent by websockets. + + The default value uses the format ``"Python/x.y.z websockets/X.Y"``. + +.. envvar:: WEBSOCKETS_MAX_LINE_LENGTH + + Maximum length of the request or status line in the opening handshake. + + The default value is ``8192``. + +.. envvar:: WEBSOCKETS_MAX_NUM_HEADERS + + Maximum number of HTTP headers in the opening handshake. + + The default value is ``128``. + +.. envvar:: WEBSOCKETS_MAX_BODY_SIZE + + Maximum size of the body of an HTTP response in the opening handshake. + + The default value is ``1_048_576`` (1 MiB). + +See the :doc:`security guide <../topics/security>` for details. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index e7abd96ce..873c852c2 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -76,6 +76,10 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) +By default, websockets elides the content of messages to improve readability. +If you want to see more, you can increase the :envvar:`WEBSOCKETS_MAX_LOG_SIZE` +environment variable. The default value is 75. + Furthermore, websockets adds a ``websocket`` attribute to log records, so you can include additional information about the current connection in logs. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index d3dec21bd..83d79e35b 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -1,6 +1,8 @@ Security ======== +.. currentmodule:: websockets + Encryption ---------- @@ -27,15 +29,33 @@ an amplification factor of 1000 between network traffic and memory usage. Configuring a server to :doc:`optimize memory usage ` will improve security in addition to improving performance. -Other limits ------------- +HTTP limits +----------- + +In the opening handshake, websockets applies limits to the amount of data that +it accepts in order to minimize exposure to denial of service attacks. + +The request or status line is limited to 8192 bytes. Each header line, including +the name and value, is limited to 8192 bytes too. No more than 128 HTTP headers +are allowed. When the HTTP response includes a body, it is limited to 1 MiB. + +You may change these limits by setting the :envvar:`WEBSOCKETS_MAX_LINE_LENGTH`, +:envvar:`WEBSOCKETS_MAX_NUM_HEADERS`, and :envvar:`WEBSOCKETS_MAX_BODY_SIZE` +environment variables respectively. + +Identification +-------------- + +By default, websockets identifies itself with a ``Server`` or ``User-Agent`` +header in the format ``"Python/x.y.z websockets/X.Y"``. -websockets implements additional limits on the amount of data it accepts in -order to minimize exposure to security vulnerabilities. +You can set the ``server_header`` argument of :func:`~server.serve` or the +``user_agent_header`` argument of :func:`~client.connect` to configure another +value. Setting them to :obj:`None` removes the header. -In the opening handshake, websockets limits the number of HTTP headers to 256 -and the size of an individual header to 4096 bytes. These limits are 10 to 20 -times larger than what's expected in standard use cases. They're hard-coded. +Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and +:envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them +to an empty string removes the header. -If you need to change these limits, you can monkey-patch the constants in -``websockets.http11``. +If both the argument and the environment variable are set, the argument takes +precedence. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 040d68ece..ac8ded8ca 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -9,8 +9,7 @@ from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Response +from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri @@ -71,7 +70,7 @@ async def handshake( self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) - if user_agent_header is not None: + if user_agent_header: self.request.headers["User-Agent"] = user_agent_header self.protocol.send_request(self.request) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index aa175f775..0c8b8780b 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -20,8 +20,7 @@ from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Request, Response +from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol @@ -88,7 +87,7 @@ async def handshake( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, ) -> None: """ Perform the opening handshake. @@ -131,7 +130,7 @@ async def handshake( assert isinstance(response, Response) # help mypy self.response = response - if server_header is not None: + if server_header: self.response.headers["Server"] = server_header response = None @@ -243,7 +242,7 @@ def __init__( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, open_timeout: float | None = 10, logger: LoggerLike | None = None, ) -> None: @@ -631,7 +630,7 @@ def __init__( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, diff --git a/src/websockets/frames.py b/src/websockets/frames.py index af56d3f8f..819fdd742 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,6 +3,7 @@ import dataclasses import enum import io +import os import secrets import struct from typing import Callable, Generator, Sequence @@ -146,8 +147,8 @@ class Frame: rsv2: bool = False rsv3: bool = False - # Monkey-patch if you want to see more in logs. Should be a multiple of 3. - MAX_LOG = 75 + # Configure if you want to see more in logs. Should be a multiple of 3. + MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75")) def __str__(self) -> str: """ @@ -166,8 +167,8 @@ def __str__(self) -> str: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data - if len(binary) > self.MAX_LOG // 3: - cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: @@ -183,16 +184,16 @@ def __str__(self) -> str: coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data - if len(binary) > self.MAX_LOG // 3: - cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" - if len(data) > self.MAX_LOG: - cut = self.MAX_LOG // 3 - 1 # by default cut = 24 + if len(data) > self.MAX_LOG_SIZE: + cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24 data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) diff --git a/src/websockets/http.py b/src/websockets/http.py index 9f86f6a1f..a24102307 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import typing from .imports import lazy_import -from .version import version as websockets_version # For backwards compatibility: @@ -26,10 +24,3 @@ "read_response": ".legacy.http", }, ) - - -__all__ = ["USER_AGENT"] - - -PYTHON_VERSION = "{}.{}".format(*sys.version_info) -USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" diff --git a/src/websockets/http11.py b/src/websockets/http11.py index a7e9ae682..ed49fcbf9 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,23 +1,43 @@ from __future__ import annotations import dataclasses +import os import re +import sys import warnings from typing import Callable, Generator from . import datastructures, exceptions +from .version import version as websockets_version +__all__ = ["SERVER", "USER_AGENT", "Request", "Response"] + + +PYTHON_VERSION = "{}.{}".format(*sys.version_info) + +# User-Agent header for HTTP requests. +USER_AGENT = os.environ.get( + "WEBSOCKETS_USER_AGENT", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Server header for HTTP responses. +SERVER = os.environ.get( + "WEBSOCKETS_SERVER", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + # Maximum total size of headers is around 128 * 8 KiB = 1 MiB. -MAX_HEADERS = 128 +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) # Limit request line and header lines. 8KiB is the most common default # configuration of popular HTTP servers. -MAX_LINE = 8192 +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. -MAX_BODY = 2**20 # 1 MiB +MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB def d(value: bytes) -> str: @@ -258,12 +278,12 @@ def parse( if content_length is None: try: - body = yield from read_to_eof(MAX_BODY) + body = yield from read_to_eof(MAX_BODY_SIZE) except RuntimeError: raise exceptions.SecurityError( - f"body too large: over {MAX_BODY} bytes" + f"body too large: over {MAX_BODY_SIZE} bytes" ) - elif content_length > MAX_BODY: + elif content_length > MAX_BODY_SIZE: raise exceptions.SecurityError( f"body too large: {content_length} bytes" ) @@ -309,7 +329,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. headers = datastructures.Headers() - for _ in range(MAX_HEADERS + 1): + for _ in range(MAX_NUM_HEADERS + 1): try: line = yield from parse_line(read_line) except EOFError as exc: @@ -355,7 +375,7 @@ def parse_line( """ try: - line = yield from read_line(MAX_LINE) + line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d1d8d5608..b61126c81 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -38,7 +38,7 @@ parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT +from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .handshake import build_request, check_response @@ -307,7 +307,7 @@ async def handshake( if self.extra_headers is not None: request_headers.update(self.extra_headers) - if self.user_agent_header is not None: + if self.user_agent_header: request_headers.setdefault("User-Agent", self.user_agent_header) self.write_http_request(wsuri.resource_name, request_headers) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 9a553e175..b5df7e4c4 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import re from ..datastructures import Headers @@ -9,8 +10,8 @@ __all__ = ["read_request", "read_response"] -MAX_HEADERS = 128 -MAX_LINE = 8192 +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) def d(value: bytes) -> str: @@ -154,7 +155,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: # We don't attempt to support obsolete line folding. headers = Headers() - for _ in range(MAX_HEADERS + 1): + for _ in range(MAX_NUM_HEADERS + 1): try: line = await read_line(stream) except EOFError as exc: @@ -192,7 +193,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this is bounded by the StreamReader's limit (default = 32 KiB). line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 8 KiB) - if len(line) > MAX_LINE: + if len(line) > MAX_LINE_LENGTH: raise SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 208ffa780..cd7980e00 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -40,7 +40,7 @@ parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT +from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .handshake import build_response, check_request @@ -106,7 +106,7 @@ def __init__( extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, @@ -221,7 +221,7 @@ async def handler(self) -> None: ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - if self.server_header is not None: + if self.server_header: headers.setdefault("Server", self.server_header) headers.setdefault("Content-Length", str(len(body))) @@ -992,7 +992,7 @@ def __init__( extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index c97a09402..e33d53f62 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -11,8 +11,7 @@ from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Response +from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, OPEN, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 7fb46f5aa..ebbbd0312 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -16,8 +16,7 @@ from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Request, Response +from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol @@ -83,7 +82,7 @@ def handshake( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, timeout: float | None = None, ) -> None: """ @@ -120,7 +119,7 @@ def handshake( if self.response is None: self.response = self.protocol.accept(self.request) - if server_header is not None: + if server_header: self.response.headers["Server"] = server_header if process_response is not None: @@ -302,7 +301,7 @@ def serve( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b5c5d726a..329f59286 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -29,7 +29,7 @@ ServerPerMessageDeflateFactory, ) from websockets.frames import CloseCode -from websockets.http import USER_AGENT +from websockets.http11 import USER_AGENT from websockets.legacy.client import * from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response diff --git a/tests/test_http.py b/tests/test_http.py deleted file mode 100644 index baaa7d416..000000000 --- a/tests/test_http.py +++ /dev/null @@ -1,8 +0,0 @@ -import unittest - -from websockets.http import * - - -class HTTPTests(unittest.TestCase): - def test_user_agent(self): - USER_AGENT # exists From a7a5042bed89b96fa2e391f3be0e255a59bffb0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 8 Aug 2024 10:21:23 +0200 Subject: [PATCH 076/109] Deprecate fully websockets.http. All public API within this module are deprecated since version 9.0 so there's nothing to document. --- src/websockets/connection.py | 1 - src/websockets/http.py | 31 ++++++++++--------------------- tests/test_http.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 22 deletions(-) create mode 100644 tests/test_http.py diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 88bcda1aa..7942c1a28 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -2,7 +2,6 @@ import warnings -# lazy_import doesn't support this use case. from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 diff --git a/src/websockets/http.py b/src/websockets/http.py index a24102307..3dc560062 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,26 +1,15 @@ from __future__ import annotations -import typing +import warnings -from .imports import lazy_import +from .datastructures import Headers, MultipleValuesError # noqa: F401 +from .legacy.http import read_request, read_response # noqa: F401 -# For backwards compatibility: - - -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .datastructures import Headers, MultipleValuesError # noqa: F401 -else: - lazy_import( - globals(), - # Headers and MultipleValuesError used to be defined in this module. - aliases={ - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - }, - deprecated_aliases={ - "read_request": ".legacy.http", - "read_response": ".legacy.http", - }, - ) +warnings.warn( + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + DeprecationWarning, +) diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 000000000..6e81199fc --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,16 @@ +from websockets.datastructures import Headers + +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + ): + from websockets.http import Headers as OldHeaders + + self.assertIs(OldHeaders, Headers) From 5835da4967e130cd631a7601e75ea5228ab27537 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:25:19 +0200 Subject: [PATCH 077/109] Adjust timings to avoid spurious failures. --- tests/asyncio/test_server.py | 6 +++--- tests/utils.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 2e59f49b1..535083cbc 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -412,16 +412,16 @@ async def test_close_server_keeps_handlers_running(self): async with run_server(keep_running) as server: async with run_client(server) as client: # Delay termination of connection handler. - await client.send(str(2 * MS)) + await client.send(str(3 * MS)) server.close() # The server waits for the connection handler to terminate. with self.assertRaises(TimeoutError): - async with asyncio_timeout(MS): + async with asyncio_timeout(2 * MS): await server.wait_closed() - async with asyncio_timeout(2 * MS): + async with asyncio_timeout(3 * MS): await server.wait_closed() diff --git a/tests/utils.py b/tests/utils.py index 1793f3e8b..960439135 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -43,13 +43,14 @@ # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) -# PyPy has a performance penalty for this test suite. +# PyPy, asyncio's debug mode, and coverage penalize performance of this +# test suite. Increase timeouts to reduce the risk of spurious failures. if platform.python_implementation() == "PyPy": # pragma: no cover MS *= 2 - -# asyncio's debug mode has a performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 2 +if os.environ.get("COVERAGE_RUN"): # pragma: no branch + MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 906592908bb5850a4f78a5d4877fbc2412d611b7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:25:42 +0200 Subject: [PATCH 078/109] Avoid spurious coverage failures due to timing effects. --- tests/asyncio/test_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 535083cbc..4a8a76a21 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -363,7 +363,7 @@ async def test_close_server_rejects_connecting_connections(self): async def process_request(ws, _request): while ws.server.is_serving(): - await asyncio.sleep(0) + await asyncio.sleep(0) # pragma: no cover async with run_server(process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) From 84e8bd879b8dfc528b4e57517f2e1f8b7ad0a378 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:49:55 +0200 Subject: [PATCH 079/109] Fix spurious exception while running tests. Due to a race condition between serve_forever and shutdown, test run logs randomly contained this exception: Exception in thread Thread-NNN (serve_forever): Traceback (most recent call last): ... ValueError: Invalid file descriptor: -1 --- src/websockets/sync/server.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ebbbd0312..10fbe4859 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -222,7 +222,13 @@ def serve_forever(self) -> None: """ poller = selectors.DefaultSelector() - poller.register(self.socket, selectors.EVENT_READ) + try: + poller.register(self.socket, selectors.EVENT_READ) + except ValueError: # pragma: no cover + # If shutdown() is called before poller.register(), + # the socket is closed and poller.register() raises + # ValueError: Invalid file descriptor: -1 + return if sys.platform != "win32": poller.register(self.shutdown_watcher, selectors.EVENT_READ) From a3ed1604b0f331fe91df52641cfab2ae5349eb46 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 10:24:24 +0200 Subject: [PATCH 080/109] Make test_reconnect robust to slower runs. This avoids failures with higher WEBSOCKETS_TESTS_TIMEOUT_FACTOR, notably on PyPy. Refs #1483. --- tests/legacy/test_client_server.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 329f59286..0c5d66c92 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -5,6 +5,7 @@ import logging import platform import random +import re import socket import ssl import sys @@ -1608,16 +1609,19 @@ async def run_client(): ) # Iteration 3 self.assertEqual( - [record.getMessage() for record in logs.records][4:-1], + [ + re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) + for record in logs.records + ][4:-1], [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in 0.0 seconds", + "! connect failed; reconnecting in X seconds", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed again; retrying in 0 seconds", + "! connect failed again; retrying in X seconds", ] * ((len(logs.records) - 8) // 3) + [ From 58787cc6a58a1f1baf4be3f78d868594108afebd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 14:35:38 +0200 Subject: [PATCH 081/109] Confirm support for Python 3.13. --- .github/workflows/release.yml | 2 +- .github/workflows/tests.yml | 2 ++ docs/faq/misc.rst | 3 +++ docs/project/changelog.rst | 2 +- pyproject.toml | 1 + src/websockets/asyncio/connection.py | 3 +-- tox.ini | 1 + 7 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4a00bf8fc..ed52ddd80 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -56,7 +56,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.16.2 + uses: pypa/cibuildwheel@v2.20.0 env: BUILD_EXTENSION: yes - name: Save wheels diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 15a45bdfb..b9172b7fb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,6 +62,7 @@ jobs: - "3.10" - "3.11" - "3.12" + - "3.13" - "pypy-3.9" - "pypy-3.10" is_main: @@ -78,6 +79,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} + allow-prereleases: true - name: Install tox run: pip install tox - name: Run tests diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index ee5ad2372..0e74a784f 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -3,6 +3,9 @@ Miscellaneous .. currentmodule:: websockets +.. Remove this question when dropping Python < 3.13, which provides natively +.. a good error message in this case. + Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8143e3483..00b055dd1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -71,7 +71,7 @@ New features See :func:`websockets.asyncio.client.connect` and :func:`websockets.asyncio.server.serve` for details. -* Validated compatibility with Python 3.12. +* Validated compatibility with Python 3.12 and 3.13. * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/pyproject.toml b/pyproject.toml index de8acd6a3..c1d34c90b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dynamic = ["version", "readme"] diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 152c6789e..4f44d798c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -564,8 +564,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[None]: pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution - # is a bit low on Windows (~16ms). We cannot use time.perf_counter() - # because it doesn't count time elapsed while the process sleeps. + # is a bit low on Windows (~16ms). This is improved in Python 3.13. ping_timestamp = self.loop.time() self.pong_waiters[data] = (pong_waiter, ping_timestamp) self.protocol.send_ping(data) diff --git a/tox.ini b/tox.ini index 1edcfe261..16d9c9f16 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ env_list = py310 py311 py312 + py313 coverage black ruff From 9ec785d6f12cb1a3a3bc43f543f4a831a635472b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 15:07:40 +0200 Subject: [PATCH 082/109] Fix copy-paste error in tests. --- tests/asyncio/test_client.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index aab65cd2e..b74617ef0 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -156,27 +156,27 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" - with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with run_server(ssl=SERVER_CONTEXT) as server: + host, port = get_server_host_port(server) + async with run_client( + "wss://overridden/", + host=host, + port=port, + ssl=CLIENT_CONTEXT, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" - with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client( + server, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" From 1853a9b2d0247573633e2749fe1169f764abe03c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 15:59:59 +0200 Subject: [PATCH 083/109] Ignore ResourceWarning in test. This is expected to prevent a spurious test failure under PyPy. Refs #1483. --- tests/legacy/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 28bc90df3..5bb56b26f 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -2,6 +2,7 @@ import contextlib import functools import logging +import sys import unittest @@ -76,8 +77,19 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): Check recorded deprecation warnings match a list of expected messages. """ + # Work around https://github.com/python/cpython/issues/90476. + if sys.version_info[:2] < (3, 11): # pragma: no cover + recorded_warnings = [ + recorded + for recorded in recorded_warnings + if not ( + type(recorded.message) is ResourceWarning + and str(recorded.message).startswith("unclosed transport") + ) + ] + for recorded in recorded_warnings: - self.assertEqual(type(recorded.message), DeprecationWarning) + self.assertIs(type(recorded.message), DeprecationWarning) self.assertEqual( {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), From 00b63afe7d921d17fd48abee2e25389050a2410c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Aug 2024 09:44:02 +0200 Subject: [PATCH 084/109] Add new asyncio implementation to feature matrices. --- docs/reference/features.rst | 257 ++++++++++++++++++------------------ 1 file changed, 128 insertions(+), 129 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 98b3c0dda..946770fe3 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -14,9 +14,10 @@ Feature support matrices summarize which implementations support which features. .support-matrix-table td:not(:first-child) { text-align: center; } -.. |aio| replace:: :mod:`asyncio` +.. |aio| replace:: :mod:`asyncio` (new) .. |sync| replace:: :mod:`threading` .. |sans| replace:: `Sans-I/O`_ +.. |leg| replace:: :mod:`asyncio` (legacy) .. _Sans-I/O: https://sans-io.readthedocs.io/ Both sides @@ -25,60 +26,58 @@ Both sides .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Perform the opening handshake | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Send a message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Receive a message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Iterate over received messages | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+ - | Send a fragmented message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Receive a fragmented message after | ✅ | ✅ | ❌ | - | reassembly | | | | - +------------------------------------+--------+--------+--------+ - | Receive a fragmented message frame | ❌ | ✅ | ✅ | - | by frame (`#479`_) | | | | - +------------------------------------+--------+--------+--------+ - | Send a ping | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Respond to pings automatically | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Send a pong | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform the closing handshake | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Report close codes and reasons | ❌ | ✅ | ✅ | - | from both sides | | | | - +------------------------------------+--------+--------+--------+ - | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Tune memory usage for compression | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Negotiate extensions | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Implement custom extensions | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Negotiate a subprotocol | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Enforce security limits | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Log events | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Enforce opening timeout | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Enforce closing timeout | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Keepalive | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Heartbeat | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - -.. _#479: https://github.com/python-websockets/websockets/issues/479 + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | by frame | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | — | ✅ | + | reassembly | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Keepalive | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Heartbeat | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | + | from both sides | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ Server ------ @@ -86,39 +85,39 @@ Server .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Listen on a TCP socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Listen on a Unix socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Listen using a preexisting socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close server on context exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close connection on handler exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Shut down server gracefully | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Check ``Origin`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Customize subprotocol selection | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Configure ``Server`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake request | ❌ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake response | ❌ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+ - | Force HTTP response | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Alter opening handshake response | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ❌ | ❌ | ❌ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+--------+ Client ------ @@ -126,41 +125,43 @@ Client .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Connect to a TCP socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Connect to a Unix socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Connect using a preexisting socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close connection on context exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Reconnect automatically | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Configure ``Origin`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake request | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | - | (`#784`_) | | | | - +------------------------------------+--------+--------+--------+ - | Follow HTTP redirects | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Connect via a SOCKS5 proxy | ❌ | ❌ | — | - | (`#475`_) | | | | - +------------------------------------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Reconnect automatically | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Modify opening handshake response | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Follow HTTP redirects | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | + | (`#784`_) | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Connect via a SOCKS5 proxy | ❌ | ❌ | — | ❌ | + | (`#475`_) | | | | | + +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#475: https://github.com/python-websockets/websockets/issues/475 @@ -174,14 +175,12 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 -The server doesn't check the Host header and respond with a HTTP 400 Bad Request -if it is missing or invalid (`#1246`). +The server doesn't check the Host header and doesn't respond with a HTTP 400 Bad +Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right -layer for enforcing this constraint. It's the caller's responsibility. - -.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 +mandated by :rfc:`6455`, section 4.1. However, :func:`~client.connect()` isn't +the right layer for enforcing this constraint. It's the caller's responsibility. From 7345b31edc82abc200ebb58dc0fbe856e65d447b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Aug 2024 09:45:39 +0200 Subject: [PATCH 085/109] Cool URI don't change... except when they do. --- src/websockets/asyncio/connection.py | 18 ++++++------- .../extensions/permessage_deflate.py | 4 +-- src/websockets/headers.py | 18 ++++++------- src/websockets/http11.py | 14 +++++----- src/websockets/legacy/http.py | 10 +++---- src/websockets/legacy/protocol.py | 26 +++++++++---------- src/websockets/legacy/server.py | 2 +- src/websockets/protocol.py | 4 +-- src/websockets/server.py | 2 +- src/websockets/sync/connection.py | 18 ++++++------- src/websockets/typing.py | 4 +-- src/websockets/uri.py | 2 +- 12 files changed, 61 insertions(+), 61 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 4f44d798c..0a3ddb9aa 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -200,8 +200,8 @@ async def recv(self, decode: bool | None = None) -> Data: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: @@ -253,8 +253,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: @@ -290,8 +290,8 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. @@ -299,7 +299,7 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you really want to send the keys of a dict-like object as fragments, @@ -524,7 +524,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[None]: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point @@ -574,7 +574,7 @@ async def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 579262f02..fea14131e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -262,7 +262,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): Parameters behave as described in `section 7.1 of RFC 7692`_. - .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. @@ -462,7 +462,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): Parameters behave as described in `section 7.1 of RFC 7692`_. - .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index bc42e0b72..0ffd65233 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -40,7 +40,7 @@ def build_host(host: str, port: int, secure: bool) -> str: Build a ``Host`` header. """ - # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 + # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 # IPv6 addresses must be enclosed in brackets. try: address = ipaddress.ip_address(host) @@ -59,8 +59,8 @@ def build_host(host: str, port: int, secure: bool) -> str: # To avoid a dependency on a parsing library, we implement manually the ABNF -# described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and -# https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and +# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. def peek_ahead(header: str, pos: int) -> str | None: @@ -183,7 +183,7 @@ def parse_list( InvalidHeaderFormat: On invalid inputs. """ - # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient + # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient # MUST parse and ignore a reasonable number of empty list elements"; # hence while loops that remove extra delimiters. @@ -320,7 +320,7 @@ def parse_extension_item_param( if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(header, pos, header_name) - # https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 says: + # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says: # the value after quoted-string unescaping MUST conform to # the 'token' ABNF. if _token_re.fullmatch(value) is None: @@ -489,7 +489,7 @@ def build_www_authenticate_basic(realm: str) -> str: realm: Identifier of the protection space. """ - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 realm = build_quoted_string(realm) charset = build_quoted_string("UTF-8") return f"Basic realm={realm}, charset={charset}" @@ -539,8 +539,8 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: InvalidHeaderValue: On unsupported inputs. """ - # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": raise exceptions.InvalidHeaderValue( @@ -580,7 +580,7 @@ def build_authorization_basic(username: str, password: str) -> str: This is the reverse of :func:`parse_authorization_basic`. """ - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 assert ":" not in username user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() diff --git a/src/websockets/http11.py b/src/websockets/http11.py index ed49fcbf9..b86c6ca4a 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -48,7 +48,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. # Regex for validating header names. @@ -122,7 +122,7 @@ def parse( ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -146,7 +146,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -227,7 +227,7 @@ def parse( ValueError: If the response isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 try: status_line = yield from parse_line(read_line) @@ -255,7 +255,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -324,7 +324,7 @@ def parse_headers( ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. @@ -378,7 +378,7 @@ def parse_line( line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise exceptions.SecurityError("line too long") - # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index b5df7e4c4..a7c8a927e 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -22,7 +22,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. # Regex for validating header names. @@ -64,7 +64,7 @@ async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -111,7 +111,7 @@ async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers ValueError: If the response isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -150,7 +150,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: Non-ASCII characters are represented with surrogate escapes. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. @@ -195,7 +195,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this guarantees header values are small (hard-coded = 8 KiB) if len(line) > MAX_LINE_LENGTH: raise SecurityError("line too long") - # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 120ff8e73..6f8916576 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -80,14 +80,14 @@ class WebSocketCommonProtocol(asyncio.Protocol): especially in the presence of proxies with short timeouts on inactive connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 If the corresponding Pong_ frame isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to :obj:`None` to disable this behavior. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. @@ -447,7 +447,7 @@ def close_code(self) -> int | None: WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. .. _section 7.1.5 of RFC 6455: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -465,7 +465,7 @@ def close_reason(self) -> str | None: WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. .. _section 7.1.6 of RFC 6455: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. @@ -516,8 +516,8 @@ async def recv(self) -> Data: A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -583,8 +583,8 @@ async def send( bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. @@ -592,7 +592,7 @@ async def send( All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you want to send the keys of a dict-like object as fragments, call @@ -803,7 +803,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive, as a check that the remote endpoint received all messages up to this point, or to measure :attr:`latency`. @@ -862,7 +862,7 @@ async def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. @@ -1559,8 +1559,8 @@ def broadcast( object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even if their write buffers are overflowing. There's no backpressure. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index cd7980e00..d230f009e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -388,7 +388,7 @@ def process_origin( """ # "The user agent MUST NOT include more than one Origin header field" - # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7f2b45c74..917c19163 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -175,7 +175,7 @@ def close_code(self) -> int | None: `WebSocket close code`_. .. _WebSocket close code: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -193,7 +193,7 @@ def close_reason(self) -> str | None: `WebSocket close reason`_. .. _WebSocket close reason: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. diff --git a/src/websockets/server.py b/src/websockets/server.py index 7211d3cbf..1b4c3bf29 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -307,7 +307,7 @@ def process_origin(self, headers: Headers) -> Origin | None: """ # "The user agent MUST NOT include more than one Origin header field" - # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 2bcb3aa0e..a4826c785 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -189,8 +189,8 @@ def recv(self, timeout: float | None = None) -> Data: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -222,8 +222,8 @@ def recv_streaming(self) -> Iterator[Data]: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -250,8 +250,8 @@ def send(self, message: Data | Iterable[Data]) -> None: bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a @@ -259,7 +259,7 @@ def send(self, message: Data | Iterable[Data]) -> None: same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you really want to send the keys of a dict-like object as fragments, @@ -425,7 +425,7 @@ def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point @@ -470,7 +470,7 @@ def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 6360c7a0a..447fe79da 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -24,8 +24,8 @@ """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. -.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 -.. _Binary : https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 +.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 """ diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 5cb38a9cc..82b35f92a 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -23,7 +23,7 @@ class WebSocketURI: username: Available when the URI contains `User Information`_. password: Available when the URI contains `User Information`_. - .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 """ From e2f0385119992317c0f49b32775ef50b2fa40218 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 13 Aug 2024 23:15:55 +0200 Subject: [PATCH 086/109] Add guide for upgrading to the new asyncio implementation. --- docs/howto/index.rst | 8 + docs/howto/upgrade.rst | 357 ++++++++++++++++++++++++++ docs/project/changelog.rst | 66 ++++- docs/reference/asyncio/client.rst | 4 +- docs/reference/asyncio/common.rst | 4 +- docs/reference/asyncio/server.rst | 10 +- docs/reference/new-asyncio/client.rst | 4 +- docs/reference/new-asyncio/common.rst | 4 +- docs/reference/new-asyncio/server.rst | 4 +- docs/spelling_wordlist.txt | 5 +- 10 files changed, 446 insertions(+), 20 deletions(-) create mode 100644 docs/howto/upgrade.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ddbe67d3a..863c1c63c 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -8,6 +8,14 @@ In a hurry? Check out these examples. quickstart +Upgrading from the legacy :mod:`asyncio` implementation to the new one? +Read this. + +.. toctree:: + :titlesonly: + + upgrade + If you're stuck, perhaps you'll find the answer here. .. toctree:: diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst new file mode 100644 index 000000000..bb4c59bc4 --- /dev/null +++ b/docs/howto/upgrade.rst @@ -0,0 +1,357 @@ +Upgrade to the new :mod:`asyncio` implementation +================================================ + +.. currentmodule:: websockets + +The new :mod:`asyncio` implementation is a rewrite of the original +implementation of websockets. + +It provides a very similar API. However, there are a few differences. + +The recommended upgrade process is: + +1. Make sure that your application doesn't use any `deprecated APIs`_. If it + doesn't raise any warnings, you can skip this step. +2. Check if your application depends on `missing features`_. If it does, you + should stick to the original implementation until they're added. +3. `Update import paths`_. For straightforward usage of websockets, this could + be the only step you need to take. Upgrading could be transparent. +4. `Review API changes`_ and adapt your application to preserve its current + functionality or take advantage of improvements in the new implementation. + +In the interest of brevity, only :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` are discussed below but everything also applies +to :func:`~asyncio.client.unix_connect` and :func:`~asyncio.server.unix_serve` +respectively. + +.. admonition:: What will happen to the original implementation? + :class: hint + + The original implementation is now considered legacy. + + The next steps are: + + 1. Deprecating it once the new implementation reaches feature parity. + 2. Maintaining it for five years per the :ref:`backwards-compatibility + policy `. + 3. Removing it. This is expected to happen around 2030. + +.. _deprecated APIs: + +Deprecated APIs +--------------- + +Here's the list of deprecated behaviors that the original implementation still +supports and that the new implementation doesn't reproduce. + +If you're seeing a :class:`DeprecationWarning`, follow upgrade instructions from +the release notes of the version in which the feature was deprecated. + +* The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` + and deprecated in :ref:`13.0`. +* The ``loop`` and ``legacy_recv`` arguments of :func:`~client.connect` and + :func:`~server.serve`, which were removed — deprecated in :ref:`10.0`. +* The ``timeout`` and ``klass`` arguments of :func:`~client.connect` and + :func:`~server.serve`, which were renamed to ``close_timeout`` and + ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. +* An empty string in the ``origins`` argument of :func:`~server.serve` — + deprecated in :ref:`7.0`. +* The ``host``, ``port``, and ``secure`` attributes of connections — deprecated + in :ref:`8.0`. + +.. _missing features: + +Missing features +---------------- + +.. admonition:: All features listed below will be provided in a future release. + :class: tip + + If your application relies on one of them, you should stick to the original + implementation until the new implementation supports it in a future release. + +Broadcast +......... + +The new implementation doesn't support :doc:`broadcasting messages +<../topics/broadcast>` yet. + +Keepalive +......... + +The new implementation doesn't provide a :ref:`keepalive mechanism ` +yet. + +As a consequence, :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` don't accept the ``ping_interval`` and +``ping_timeout`` arguments and the +:attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property doesn't exist. + +HTTP Basic Authentication +......................... + +On the server side, :func:`~asyncio.server.serve` doesn't provide HTTP Basic +Authentication yet. + +For the avoidance of doubt, on the client side, :func:`~asyncio.client.connect` +performs HTTP Basic Authentication. + +Following redirects +................... + +The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP +redirects yet. + +Automatic reconnection +...................... + +The new implementation of :func:`~asyncio.client.connect` doesn't provide +automatic reconnection yet. + +In other words, the following pattern isn't supported:: + + from websockets.asyncio.client import connect + + async for websocket in connect(...): # this doesn't work yet + ... + +Configuring buffers +................... + +The new implementation doesn't provide a way to configure read and write buffers +yet. + +In practice, :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +don't accept the ``max_queue``, ``read_limit``, and ``write_limit`` arguments. + +Here's the most likely outcome: + +* ``max_queue`` will be implemented but its semantics will change from "maximum + number of messages" to "maximum number of frames", which makes a difference + when messages are fragmented. +* ``read_limit`` won't be implemented because the buffer that it configured was + removed from the new implementation. The queue that ``max_queue`` configures + is the only read buffer now. +* ``write_limit`` will be implemented as in the original implementation. + Alternatively, the same functionality could be exposed with a different API. + +.. _Update import paths: + +Import paths +------------ + +For context, the ``websockets`` package is structured as follows: + +* The new implementation is found in the ``websockets.asyncio`` package. +* The original implementation was moved to the ``websockets.legacy`` package. +* The ``websockets`` package provides aliases for convenience. +* The ``websockets.client`` and ``websockets.server`` packages provide aliases + for backwards-compatibility with earlier versions of websockets. +* Currently, all aliases point to the original implementation. In the future, + they will point to the new implementation or they will be deprecated. + +To upgrade to the new :mod:`asyncio` implementation, change import paths as +shown in the tables below. + +.. |br| raw:: html + +
+ +Client APIs +........... + ++-------------------------------------------------------------------+-----------------------------------------------------+ +| Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | ++===================================================================+=====================================================+ +| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | +| :func:`websockets.client.connect` |br| | | +| ``websockets.legacy.client.connect()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | +| :func:`websockets.client.unix_connect` |br| | | +| ``websockets.legacy.client.unix_connect()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | +| :class:`websockets.client.WebSocketClientProtocol` |br| | | +| ``websockets.legacy.client.WebSocketClientProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ + +Server APIs +........... + ++-------------------------------------------------------------------+-----------------------------------------------------+ +| Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | ++===================================================================+=====================================================+ +| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | +| :func:`websockets.server.serve` |br| | | +| ``websockets.legacy.server.serve()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | +| :func:`websockets.server.unix_serve` |br| | | +| ``websockets.legacy.server.unix_serve()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | +| :class:`websockets.server.WebSocketServer` |br| | | +| ``websockets.legacy.server.WebSocketServer`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | +| :class:`websockets.server.WebSocketServerProtocol` |br| | | +| ``websockets.legacy.server.WebSocketServerProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| :func:`websockets.broadcast` |br| | *not available yet* | +| ``websockets.legacy.protocol.broadcast()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | +| :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | +| ``websockets.legacy.auth.BasicAuthWebSocketServerProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | +| :func:`websockets.auth.basic_auth_protocol_factory` |br| | | +| ``websockets.legacy.auth.basic_auth_protocol_factory()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ + +.. _Review API changes: + +API changes +----------- + +Controlling UTF-8 decoding +.......................... + +The new implementation of the :meth:`~asyncio.connection.Connection.recv` method +provides the ``decode`` argument to control UTF-8 decoding of messages. This +didn't exist in the original implementation. + +If you're calling :meth:`~str.encode` on a :class:`str` object returned by +:meth:`~asyncio.connection.Connection.recv`, using ``decode=False`` and removing +:meth:`~str.encode` saves a round-trip of UTF-8 decoding and encoding for text +messages. + +You can also force UTF-8 decoding of binary messages with ``decode=True``. This +is rarely useful and has no performance benefits over decoding a :class:`bytes` +object returned by :meth:`~asyncio.connection.Connection.recv`. + +Receiving fragmented messages +............................. + +The new implementation provides the +:meth:`~asyncio.connection.Connection.recv_streaming` method for receiving a +fragmented message frame by frame. There was no way to do this in the original +implementation. + +Depending on your use case, adopting this method may improve performance when +streaming large messages. Specifically, it could reduce memory usage. + +Customizing the opening handshake +................................. + +On the client side, if you're adding headers to the handshake request sent by +:func:`~client.connect` with the ``extra_headers`` argument, you must rename it +to ``additional_headers``. + +On the server side, if you're customizing how :func:`~server.serve` processes +the opening handshake with the ``process_request``, ``extra_headers``, or +``select_subprotocol``, you must update your code. ``process_response`` and +``select_subprotocol`` have new signatures; ``process_response`` replaces +``extra_headers`` and provides more flexibility. + +``process_request`` +~~~~~~~~~~~~~~~~~~~ + +The signature of ``process_request`` changed. This is easiest to illustrate with +an example:: + + import http + + # Original implementation + + def process_request(path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + serve(..., process_request=process_request, ...) + + # New implementation + + def process_request(connection, request): + return connection.protocol.reject(http.HTTPStatus.OK, "OK\n") + + serve(..., process_request=process_request, ...) + +``connection`` is always available in ``process_request``. In the original +implementation, you had to write a subclass of +:class:`~server.WebSocketServerProtocol` and pass it in the ``create_protocol`` +argument to make the connection object available in a ``process_request`` +method. This pattern isn't useful anymore; you can replace it with a +``process_request`` function or coroutine. + +``path`` and ``headers`` are available as attributes of the ``request`` object. + +``process_response`` +~~~~~~~~~~~~~~~~~~~~ + +``process_request`` replaces ``extra_headers`` and provides more flexibility. +In the most basic case, you would adapt your code as follows:: + + # Original implementation + + serve(..., extra_headers=HEADERS, ...) + + # New implementation + + def process_response(connection, request, response): + response.headers.update(HEADERS) + return response + + serve(..., process_response=process_response, ...) + +``connection`` is always available in ``process_response``, similar to +``process_request``. In the original implementation, there was no way to make +the connection object available. + +In addition, the ``request`` and ``response`` objects are available, which +enables a broader range of use cases (e.g., logging) and makes +``process_response`` more useful than ``extra_headers``. + +``select_subprotocol`` +~~~~~~~~~~~~~~~~~~~~~~ + +The signature of ``select_subprotocol`` changed. Here's an example:: + + # Original implementation + + def select_subprotocol(client_subprotocols, server_subprotocols): + if "chat" in client_subprotocols: + return "chat" + + # New implementation + + def select_subprotocol(connection, subprotocols): + if "chat" in subprotocols + return "chat" + + serve(..., select_subprotocol=select_subprotocol, ...) + +``connection`` is always available in ``select_subprotocol``. This brings the +same benefits as in ``process_request``. It may remove the need to subclass of +:class:`~server.WebSocketServerProtocol`. + +The ``subprotocols`` argument contains the list of subprotocols offered by the +client. The list of subprotocols supported by the server was removed because +``select_subprotocols`` already knows which subprotocols it may select and under +which conditions. + +Miscellaneous changes +..................... + +The first argument of :func:`~asyncio.server.serve` is called ``handler`` instead +of ``ws_handler``. It's usually passed as a positional argument, making this +change transparent. If you're passing it as a keyword argument, you must update +its name. + +The keyword argument of :func:`~asyncio.server.serve` for customizing the +creation of the connection object is called ``create_connection`` instead of +``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` +instead of a :class:`~server.WebSocketServerProtocol`. If you were customizing +connection objects, you should check the new implementation and possibly redo +your customization. Keep in mind that the changes to ``process_request`` and +``select_subprotocol`` remove most use cases for ``create_connection``. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 00b055dd1..f033f5632 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,8 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _13.0: + 13.0 ---- @@ -66,10 +68,10 @@ New features This new implementation is intended to be a drop-in replacement for the current implementation. It will become the default in a future release. - Please try it and report any issue that you encounter! - See :func:`websockets.asyncio.client.connect` and - :func:`websockets.asyncio.server.serve` for details. + Please try it and report any issue that you encounter! The :doc:`upgrade + guide <../howto/upgrade>` explains everything you need to know about the + upgrade process. * Validated compatibility with Python 3.12 and 3.13. @@ -79,6 +81,8 @@ New features If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. +.. _12.0: + 12.0 ---- @@ -135,6 +139,8 @@ Bug fixes * Restored the C extension in the source distribution. +.. _11.0: + 11.0 ---- @@ -211,6 +217,8 @@ Improvements * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. +.. _10.4: + 10.4 ---- @@ -237,6 +245,8 @@ Improvements * Improved FAQ. +.. _10.3: + 10.3 ---- @@ -259,6 +269,8 @@ Improvements * Reduced noise in logs when :mod:`ssl` or :mod:`zlib` raise exceptions. +.. _10.2: + 10.2 ---- @@ -279,6 +291,8 @@ Bug fixes * Avoided leaking open sockets when :func:`~client.connect` is canceled. +.. _10.1: + 10.1 ---- @@ -328,6 +342,8 @@ Bug fixes * Avoided half-closing TCP connections that are already closed. +.. _10.0: + 10.0 ---- @@ -434,6 +450,8 @@ Bug fixes * Avoided a crash when receiving a ping while the connection is closing. +.. _9.1: + 9.1 --- @@ -472,6 +490,8 @@ Bug fixes * Fixed issues with the packaging of the 9.0 release. +.. _9.0: + 9.0 --- @@ -549,6 +569,8 @@ Bug fixes * Ensured cancellation always propagates, even on Python versions where :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. +.. _8.1: + 8.1 --- @@ -583,6 +605,8 @@ Bug fixes * Restored the ability to import ``WebSocketProtocolError`` from ``websockets``. +.. _8.0: + 8.0 --- @@ -692,6 +716,8 @@ Bug fixes * Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. +.. _7.0: + 7.0 --- @@ -786,6 +812,8 @@ Bug fixes :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: canceling it at the wrong time could result in messages being dropped. +.. _6.0: + 6.0 --- @@ -840,6 +868,8 @@ Bug fixes * Fixed a regression in 5.0 that broke some invocations of :func:`~server.serve` and :func:`~client.connect`. +.. _5.0: + 5.0 --- @@ -925,6 +955,8 @@ Bug fixes * Fixed issues with the packaging of the 4.0 release. +.. _4.0: + 4.0 --- @@ -984,6 +1016,8 @@ Bug fixes * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. +.. _3.4: + 3.4 --- @@ -1027,6 +1061,8 @@ Bug fixes * Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. +.. _3.3: + 3.3 --- @@ -1047,6 +1083,8 @@ Bug fixes * Avoided crashing on concurrent writes on slow connections. +.. _3.2: + 3.2 --- @@ -1063,6 +1101,8 @@ Improvements * Made server shutdown more robust. +.. _3.1: + 3.1 --- @@ -1078,6 +1118,8 @@ Bug fixes * Avoided a warning when closing a connection before the opening handshake. +.. _3.0: + 3.0 --- @@ -1135,6 +1177,8 @@ Improvements * Improved documentation. +.. _2.7: + 2.7 --- @@ -1150,6 +1194,8 @@ Improvements * Refreshed documentation. +.. _2.6: + 2.6 --- @@ -1167,6 +1213,8 @@ Bug fixes * Avoided TCP fragmentation of small frames. +.. _2.5: + 2.5 --- @@ -1200,6 +1248,8 @@ Bug fixes * Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. +.. _2.4: + 2.4 --- @@ -1213,6 +1263,8 @@ New features * Added ``loop`` argument to :func:`~client.connect` and :func:`~server.serve`. +.. _2.3: + 2.3 --- @@ -1223,6 +1275,8 @@ Improvements * Improved compliance of close codes. +.. _2.2: + 2.2 --- @@ -1233,6 +1287,8 @@ New features * Added support for limiting message size. +.. _2.1: + 2.1 --- @@ -1247,6 +1303,8 @@ New features .. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _2.0: + 2.0 --- @@ -1275,6 +1333,8 @@ New features * Added flow control for outgoing data. +.. _1.0: + 1.0 --- diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index 5086015b7..f9ce2f2d8 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,5 +1,5 @@ -Client (:mod:`asyncio`) -======================= +Client (legacy :mod:`asyncio`) +============================== .. automodule:: websockets.client diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index dc7a54ee1..aee774479 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (:mod:`asyncio`) -=========================== +Both sides (legacy :mod:`asyncio`) +================================== .. automodule:: websockets.legacy.protocol diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 106317916..4bd52b40b 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,15 +1,15 @@ -Server (:mod:`asyncio`) -======================= +Server (legacy :mod:`asyncio`) +============================== .. automodule:: websockets.server Starting a server ----------------- -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -34,7 +34,7 @@ Stopping a server Using a connection ------------------ -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index 552d83b2f..196bda2b7 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -1,5 +1,5 @@ -Client (:mod:`asyncio` - new) -============================= +Client (new :mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.client diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index ba23552dc..4fa97dcf2 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (:mod:`asyncio` - new) -================================= +Both sides (new :mod:`asyncio`) +=============================== .. automodule:: websockets.asyncio.connection diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index f3446fb80..c43673d33 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -1,5 +1,5 @@ -Server (:mod:`asyncio` - new) -============================= +Server (new :mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.server diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dfa7065e7..a1ba59a37 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -21,8 +21,8 @@ cryptocurrency css ctrl deserialize -django dev +django Dockerfile dyno formatter @@ -44,6 +44,7 @@ linkerd liveness lookups MiB +middleware mutex mypy nginx @@ -77,8 +78,8 @@ uple uvicorn uvloop virtualenv -WebSocket websocket +WebSocket websockets ws wsgi From 8385cf02fccd5e171e1ee5b8949df11773c0f954 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 14 Aug 2024 08:40:46 +0200 Subject: [PATCH 087/109] Add write_limit parameter to the new asyncio API. --- docs/howto/upgrade.rst | 84 ++++++++++++++++++---------- src/websockets/asyncio/client.py | 16 ++++++ src/websockets/asyncio/connection.py | 16 +++++- src/websockets/asyncio/messages.py | 21 ++++--- src/websockets/asyncio/server.py | 16 ++++++ tests/asyncio/test_connection.py | 38 ++++++++++++- tests/asyncio/test_messages.py | 35 +++++++----- 7 files changed, 166 insertions(+), 60 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index bb4c59bc4..10e8967d8 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -115,26 +115,6 @@ In other words, the following pattern isn't supported:: async for websocket in connect(...): # this doesn't work yet ... -Configuring buffers -................... - -The new implementation doesn't provide a way to configure read and write buffers -yet. - -In practice, :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` -don't accept the ``max_queue``, ``read_limit``, and ``write_limit`` arguments. - -Here's the most likely outcome: - -* ``max_queue`` will be implemented but its semantics will change from "maximum - number of messages" to "maximum number of frames", which makes a difference - when messages are fragmented. -* ``read_limit`` won't be implemented because the buffer that it configured was - removed from the new implementation. The queue that ``max_queue`` configures - is the only read buffer now. -* ``write_limit`` will be implemented as in the original implementation. - Alternatively, the same functionality could be exposed with a different API. - .. _Update import paths: Import paths @@ -340,18 +320,60 @@ client. The list of subprotocols supported by the server was removed because ``select_subprotocols`` already knows which subprotocols it may select and under which conditions. -Miscellaneous changes -..................... +Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +.............................................................................. + +``ws_handler`` → ``handler`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The first argument of :func:`~asyncio.server.serve` is called ``handler`` instead -of ``ws_handler``. It's usually passed as a positional argument, making this -change transparent. If you're passing it as a keyword argument, you must update -its name. +The first argument of :func:`~asyncio.server.serve` is now called ``handler`` +instead of ``ws_handler``. It's usually passed as a positional argument, making +this change transparent. If you're passing it as a keyword argument, you must +update its name. + +``create_protocol`` → ``create_connection`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The keyword argument of :func:`~asyncio.server.serve` for customizing the -creation of the connection object is called ``create_connection`` instead of +creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~server.WebSocketServerProtocol`. If you were customizing -connection objects, you should check the new implementation and possibly redo -your customization. Keep in mind that the changes to ``process_request`` and -``select_subprotocol`` remove most use cases for ``create_connection``. +instead of a :class:`~server.WebSocketServerProtocol`. + +If you were customizing connection objects, you should check the new +implementation and possibly redo your customization. Keep in mind that the +changes to ``process_request`` and ``select_subprotocol`` remove most use cases +for ``create_connection``. + +``max_queue`` +~~~~~~~~~~~~~ + +The ``max_queue`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` has a new meaning but achieves a similar effect. + +It is now the high-water mark of a buffer of incoming frames. It defaults to 16 +frames. It used to be the size of a buffer of incoming messages that refilled as +soon as a message was read. It used to default to 32 messages. + +This can make a difference when messages are fragmented in several frames. In +that case, you may want to increase ``max_queue``. If you're writing a high +performance server and you know that you're receiving fragmented messages, +probably you should adopt :meth:`~asyncio.connection.Connection.recv_streaming` +and optimize the performance of reads again. In all other cases, given how +uncommon fragmentation is, you shouldn't worry about this change. + +``read_limit`` +~~~~~~~~~~~~~~ + +The ``read_limit`` argument doesn't exist in the new implementation because it +doesn't buffer data received from the network in a +:class:`~asyncio.StreamReader`. With a better design, this buffer could be +removed. + +The buffer of incoming frames configured by ``max_queue`` is the only read +buffer now. + +``write_limit`` +~~~~~~~~~~~~~~~ + +The ``write_limit`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ac8ded8ca..b2eaf9a65 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -49,11 +49,15 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol super().__init__( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) self.response_rcvd: asyncio.Future[None] = self.loop.create_future() @@ -146,6 +150,14 @@ class connect: :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -199,6 +211,8 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -243,6 +257,8 @@ def factory() -> ClientConnection: connection = create_connection( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) return connection diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 0a3ddb9aa..1c4424f0d 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -48,9 +48,17 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue + if isinstance(write_limit, int): + write_limit = (write_limit, None) + self.write_limit = write_limit # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -803,11 +811,13 @@ def close_transport(self) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) - self.transport = transport self.recv_messages = Assembler( - pause=self.transport.pause_reading, - resume=self.transport.resume_reading, + *self.max_queue, + pause=transport.pause_reading, + resume=transport.resume_reading, ) + transport.set_write_buffer_limits(*self.write_limit) + self.transport = transport def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index bc33df8d7..33ab6a5e9 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -89,6 +89,8 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, + high: int = 16, + low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: @@ -99,11 +101,16 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - self.high = 16 - self.low = 4 - self.paused = False + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low self.pause = pause self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False @@ -254,14 +261,6 @@ def put(self, frame: Frame) -> None: self.frames.put(frame) self.maybe_pause() - def get_limits(self) -> tuple[int, int]: - """Return low and high water marks for flow control.""" - return self.low, self.high - - def set_limits(self, low: int = 4, high: int = 16) -> None: - """Configure low and high water marks for flow control.""" - self.low, self.high = low, high - def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" # Check for "> high" to support high = 0 diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 0c8b8780b..4feea13c4 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -62,11 +62,15 @@ def __init__( server: WebSocketServer, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol super().__init__( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() @@ -574,6 +578,14 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -637,6 +649,8 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -709,6 +723,8 @@ def protocol_select_subprotocol( protocol, self.server, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) return connection diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 2efd4e96d..02029b754 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -4,7 +4,7 @@ import socket import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * @@ -867,6 +867,42 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + # Test parameters. + + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + self.assertEqual(connection.close_timeout, 42 * MS) + + async def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=4) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, 4) + + async def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + + async def test_write_limit(self): + """write_limit parameter configures high-water mark of write buffer.""" + connection = Connection(Protocol(self.LOCAL), write_limit=4096) + transport = Mock() + connection.connection_made(transport) + transport.set_write_buffer_limits.assert_called_once_with(4096, None) + + async def test_write_limits(self): + """write_limit parameter configures high and low-water marks of write buffer.""" + connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + transport = Mock() + connection.connection_made(transport) + transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) + # Test attributes. async def test_id(self): diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index c8a2d7cd5..615b1f3a8 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -70,8 +70,7 @@ class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.pause = unittest.mock.Mock() self.resume = unittest.mock.Mock() - self.assembler = Assembler(pause=self.pause, resume=self.resume) - self.assembler.set_limits(low=1, high=2) + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get @@ -455,17 +454,25 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate - # Test getting and setting limits + # Test setting limits - async def test_get_limits(self): - """get_limits returns low and high water marks.""" - low, high = self.assembler.get_limits() - self.assertEqual(low, 1) - self.assertEqual(high, 2) + async def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - async def test_set_limits(self): - """set_limits changes low and high water marks.""" - self.assembler.set_limits(low=2, high=4) - low, high = self.assembler.get_limits() - self.assertEqual(low, 2) - self.assertEqual(high, 4) + async def test_set_high_and_low_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) From 5eafbe466b909f21dc7e74b1350583b4d5ae0606 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 15 Aug 2024 16:25:51 +0200 Subject: [PATCH 088/109] Rewrite documentation of buffers. Describe all implementations. Also update documentation of compression. --- .gitignore | 2 +- docs/topics/compression.rst | 173 ++++++++++-------- docs/topics/design.rst | 49 ----- docs/topics/memory.rst | 156 +++++++++++++--- experiments/compression/benchmark.py | 74 ++------ experiments/compression/client.py | 18 +- experiments/compression/corpus.py | 52 ++++++ experiments/compression/server.py | 10 +- .../extensions/permessage_deflate.py | 6 +- 9 files changed, 316 insertions(+), 224 deletions(-) create mode 100644 experiments/compression/corpus.py diff --git a/.gitignore b/.gitignore index 324e77069..d8e6697a8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ .tox build/ compliance/reports/ -experiments/compression/corpus.pkl +experiments/compression/corpus/ dist/ docs/_build/ htmlcov/ diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index eaf99070d..be263e56f 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -7,37 +7,36 @@ Most WebSocket servers exchange JSON messages because they're convenient to parse and serialize in a browser. These messages contain text data and tend to be repetitive. -This makes the stream of messages highly compressible. Enabling compression +This makes the stream of messages highly compressible. Compressing messages can reduce network traffic by more than 80%. -There's a standard for compressing messages. :rfc:`7692` defines WebSocket -Per-Message Deflate, a compression extension based on the Deflate_ algorithm. +websockets implements WebSocket Per-Message Deflate, a compression extension +based on the Deflate_ algorithm specified in :rfc:`7692`. .. _Deflate: https://en.wikipedia.org/wiki/Deflate -Configuring compression ------------------------ +:func:`~websockets.asyncio.client.connect` and +:func:`~websockets.asyncio.server.serve` enable compression by default because +the reduction in network bandwidth is usually worth the additional memory and +CPU cost. -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -compression by default because the reduction in network bandwidth is usually -worth the additional memory and CPU cost. -If you want to disable compression, set ``compression=None``:: +Configuring compression +----------------------- - import websockets +To disable compression, set ``compression=None``:: - websockets.connect(..., compression=None) + connect(..., compression=None, ...) - websockets.serve(..., compression=None) + serve(..., compression=None, ...) -If you want to customize compression settings, you can enable the Per-Message -Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or +To customize compression settings, enable the Per-Message Deflate extension +explicitly with :class:`ClientPerMessageDeflateFactory` or :class:`ServerPerMessageDeflateFactory`:: - import websockets from websockets.extensions import permessage_deflate - websockets.connect( + connect( ..., extensions=[ permessage_deflate.ClientPerMessageDeflateFactory( @@ -46,9 +45,10 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], + ..., ) - websockets.serve( + serve( ..., extensions=[ permessage_deflate.ServerPerMessageDeflateFactory( @@ -57,13 +57,14 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], + ..., ) The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. -Compression settings --------------------- +Compression parameters +---------------------- When a client and a server enable the Per-Message Deflate extension, they negotiate two parameters to guarantee compatibility between compression and @@ -81,9 +82,9 @@ and memory usage for both sides. This requires retaining the compression context and state between messages, which increases the memory footprint of a connection. -* **Window Bits** controls the size of the compression context. It must be - an integer between 9 (lowest memory usage) and 15 (best compression). - Setting it to 8 is possible but rejected by some versions of zlib. +* **Window Bits** controls the size of the compression context. It must be an + integer between 9 (lowest memory usage) and 15 (best compression). Setting it + to 8 is possible but rejected by some versions of zlib and not very useful. On the server side, websockets defaults to 12. Specifically, the compression window size (server to client) is always 12 while the decompression window @@ -94,9 +95,8 @@ and memory usage for both sides. has the same effect as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control -the trade-off between compression rate, memory usage, and CPU usage only for -compressing. They're transparent for decompressing. Unless mentioned -otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. +the trade-off between compression rate, memory usage, and CPU usage for +compressing. They're transparent for decompressing. * **Memory Level** controls the size of the compression state. It must be an integer between 1 (lowest memory usage) and 9 (best compression). @@ -108,87 +108,82 @@ otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. * **Compression Level** controls the effort to optimize compression. It must be an integer between 1 (lowest CPU usage) and 9 (best compression). + websockets relies on the default value chosen by :func:`~zlib.compressobj`, + ``Z_DEFAULT_COMPRESSION``. + * **Strategy** selects the compression strategy. The best choice depends on the type of data being compressed. + websockets relies on the default value chosen by :func:`~zlib.compressobj`, + ``Z_DEFAULT_STRATEGY``. -Tuning compression ------------------- +To customize these parameters, add keyword arguments for +:func:`~zlib.compressobj` in ``compress_settings``. -For servers -........... +Default settings for servers +---------------------------- By default, websockets enables compression with conservative settings that optimize memory usage at the cost of a slightly worse compression rate: -Window Bits = 12 and Memory Level = 5. This strikes a good balance for small +Window Bits = 12 and Memory Level = 5. This strikes a good balance for small messages that are typical of WebSocket servers. -Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark of compressed size and -compression time for a corpus of small JSON documents. +Here's an example of how compression settings affect memory usage per +connection, compressed size, and compression time for a corpus of JSON +documents. =========== ============ ============ ================ ================ Window Bits Memory Level Memory usage Size vs. default Time vs. default =========== ============ ============ ================ ================ -15 8 322 KiB -4.0% +15% -14 7 178 KiB -2.6% +10% -13 6 106 KiB -1.4% +5% -**12** **5** **70 KiB** **=** **=** -11 4 52 KiB +3.7% -5% -10 3 43 KiB +90% +50% -9 2 39 KiB +160% +100% -— — 19 KiB +452% — +15 8 316 KiB -10% +10% +14 7 172 KiB -7% +5% +13 6 100 KiB -3% +2% +**12** **5** **64 KiB** **=** **=** +11 4 46 KiB +10% +4% +10 3 37 KiB +70% +40% +9 2 33 KiB +130% +90% +— — 14 KiB +350% — =========== ============ ============ ================ ================ Window Bits and Memory Level don't have to move in lockstep. However, other combinations don't yield significantly better results than those shown above. -Compressed size and compression time depend heavily on the kind of messages -exchanged by the application so this example may not apply to your use case. - -You can adapt `compression/benchmark.py`_ by creating a list of typical -messages and passing it to the ``_run`` function. - -Window Bits = 11 and Memory Level = 4 looks like the sweet spot in this table. - -websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from -Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts -on what could happen at Window Bits = 11 and Memory Level = 4 on a different +websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from +Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts +on what could happen at Window Bits = 11 and Memory Level = 4 on a different corpus. Defaults must be safe for all applications, hence a more conservative choice. -.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py +Optimizing settings +------------------- -The benchmark focuses on compression because it's more expensive than -decompression. Indeed, leaving aside small allocations, theoretical memory -usage is: +Compressed size and compression time depend on the structure of messages +exchanged by your application. As a consequence, default settings may not be +optimal for your use case. -* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; -* ``1 << windowBits`` for decompression. +To compare how various compression settings perform for your use case: -CPU usage is also higher for compression than decompression. +1. Create a corpus of typical messages in a directory, one message per file. +2. Run the `compression/benchmark.py`_ script, passing the directory in + argument. -While it's always possible for a server to use a smaller window size for -compressing outgoing messages, using a smaller window size for decompressing -incoming messages requires collaboration from clients. +The script measures compressed size and compression time for all combinations of +Window Bits and Memory Level. It outputs two tables with absolute values and two +tables with values relative to websockets' default settings. -When a client doesn't support configuring the size of its compression window, -websockets enables compression with the largest possible decompression window. -In most use cases, this is more efficient than disabling compression both ways. +Pick your favorite settings in these tables and configure them as shown above. -If you are very sensitive to memory usage, you can reverse this behavior by -setting the ``require_client_max_window_bits`` parameter of -:class:`ServerPerMessageDeflateFactory` to ``True``. +.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py -For clients -........... +Default settings for clients +---------------------------- -By default, websockets enables compression with Memory Level = 5 but leaves +By default, websockets enables compression with Memory Level = 5 but leaves the Window Bits setting up to the server. -There's two good reasons and one bad reason for not optimizing the client side -like the server side: +There's two good reasons and one bad reason for not optimizing Window Bits on +the client side as on the server side: 1. If the maintainers of a server configured some optimized settings, we don't want to override them with more restrictive settings. @@ -196,8 +191,9 @@ like the server side: 2. Optimizing memory usage doesn't matter very much for clients because it's uncommon to open thousands of client connections in a program. -3. On a more pragmatic note, some servers misbehave badly when a client - configures compression settings. `AWS API Gateway`_ is the worst offender. +3. On a more pragmatic and annoying note, some servers misbehave badly when a + client configures compression settings. `AWS API Gateway`_ is the worst + offender. .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 @@ -207,6 +203,29 @@ like the server side: Until the ecosystem levels up, interoperability with buggy servers seems more valuable than optimizing memory usage. +Decompression +------------- + +The discussion above focuses on compression because it's more expensive than +decompression. Indeed, leaving aside small allocations, theoretical memory +usage is: + +* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; +* ``1 << windowBits`` for decompression. + +CPU usage is also higher for compression than decompression. + +While it's always possible for a server to use a smaller window size for +compressing outgoing messages, using a smaller window size for decompressing +incoming messages requires collaboration from clients. + +When a client doesn't support configuring the size of its compression window, +websockets enables compression with the largest possible decompression window. +In most use cases, this is more efficient than disabling compression both ways. + +If you are very sensitive to memory usage, you can reverse this behavior by +setting the ``require_client_max_window_bits`` parameter of +:class:`ServerPerMessageDeflateFactory` to ``True``. Further reading --------------- @@ -216,7 +235,7 @@ settings affect memory usage and how to optimize them. .. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ -This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory -Level = 4 for optimizing memory usage. +This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory +Level = 4 for optimizing memory usage. .. _experiment by Peter Thorson: https://mailarchive.ietf.org/arch/msg/hybi/F9t4uPufVEy8KBLuL36cZjCmM_Y/ diff --git a/docs/topics/design.rst b/docs/topics/design.rst index f164d2990..cc65e6a70 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -488,55 +488,6 @@ they're drained. That's why all APIs that write frames are asynchronous. Of course, it's still possible for an application to create its own unbounded buffers and break the backpressure. Be careful with queues. - -.. _buffers: - -Buffers -------- - -.. note:: - - This section discusses buffers from the perspective of a server but it - applies to clients as well. - -An asynchronous systems works best when its buffers are almost always empty. - -For example, if a client sends data too fast for a server, the queue of -incoming messages will be constantly full. The server will always be 32 -messages (by default) behind the client. This consumes memory and increases -latency for no good reason. The problem is called bufferbloat. - -If buffers are almost always full and that problem cannot be solved by adding -capacity — typically because the system is bottlenecked by the output and -constantly regulated by backpressure — reducing the size of buffers minimizes -negative consequences. - -By default websockets has rather high limits. You can decrease them according -to your application's characteristics. - -Bufferbloat can happen at every level in the stack where there is a buffer. -For each connection, the receiving side contains these buffers: - -- OS buffers: tuning them is an advanced optimization. -- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. - You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. -- Incoming messages :class:`~collections.deque`: its size depends both on - the size and the number of messages it contains. By default the maximum - UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, - after UTF-8 decoding, a single message could take up to 4 MiB of memory and - the overall memory consumption could reach 128 MiB. You should adjust these - limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~client.connect()` or :func:`~server.serve` according to your - application's requirements. - -For each connection, the sending side contains these buffers: - -- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. - You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. -- OS buffers: tuning them is an advanced optimization. - Concurrency ----------- diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index e44247a77..efbcbb83f 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -1,5 +1,5 @@ -Memory usage -============ +Memory and buffers +================== .. currentmodule:: websockets @@ -9,40 +9,148 @@ memory usage can become a bottleneck. Memory usage of a single connection is the sum of: -1. the baseline amount of memory websockets requires for each connection, -2. the amount of data held in buffers before the application processes it, -3. any additional memory allocated by the application itself. +1. the baseline amount of memory that websockets uses for each connection; +2. the amount of memory needed by your application code; +3. the amount of data held in buffers. -Baseline --------- +Connection +---------- -Compression settings are the main factor affecting the baseline amount of -memory used by each connection. +Compression settings are the primary factor affecting how much memory each +connection uses. -With websockets' defaults, on the server side, a single connections uses -70 KiB of memory. +The :mod:`asyncio` implementation with default settings uses 64 KiB of memory +for each connection. + +You can reduce memory usage to 14 KiB per connection if you disable compression +entirely. Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. +Application +----------- + +Your application will allocate memory for its data structures. Memory usage +depends on your use case and your implementation. + +Make sure that you don't keep references to data that you don't need anymore +because this prevents garbage collection. + Buffers ------- -Under normal circumstances, buffers are almost always empty. +Typical WebSocket applications exchange small messages at a rate that doesn't +saturate the CPU or the network. Buffers are almost always empty. This is the +optimal situation. Buffers absorb bursts of incoming or outgoing messages +without having to pause reading or writing. + +If the application receives messages faster than it can process them, receive +buffers will fill up when. If the application sends messages faster than the +network can transmit them, send buffers will fill up. + +When buffers are almost always full, not only does the additional memory usage +fail to bring any benefit, but latency degrades as well. This problem is called +bufferbloat_. If it cannot be resolved by adding capacity, typically because the +system is bottlenecked by its output and constantly regulated by +:ref:`backpressure `, then buffers should be kept small to ensure +that backpressure kicks in quickly. + +.. _bufferbloat: https://en.wikipedia.org/wiki/Bufferbloat + +To sum up, buffers should be sized to absorb bursts of messages. Making them +larger than necessary often causes more harm than good. + +There are three levels of buffering in an application built with websockets. + +TCP buffers +........... + +The operating system allocates buffers for each TCP connection. The receive +buffer stores data received from the network until the application reads it. +The send buffer stores data written by the application until it's sent to +the network and acknowledged by the recipient. + +Modern operating systems adjust the size of TCP buffers automatically to match +network conditions. Overall, you shouldn't worry about TCP buffers. Just be +aware that they exist. + +In very high throughput scenarios, TCP buffers may grow to several megabytes +to store the data in flight. Then, they can make up the bulk of the memory +usage of a connection. + +I/O library buffers +................... + +I/O libraries like :mod:`asyncio` may provide read and write buffers to reduce +the frequency of system calls or the need to pause reading or writing. + +You should keep these buffers small. Increasing them can help with spiky +workloads but it can also backfire because it delays backpressure. + +* In the new :mod:`asyncio` implementation, there is no library-level read + buffer. + + There is a write buffer. The ``write_limit`` argument of + :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` controls its + size. When the write buffer grows above the high-water mark, + :meth:`~asyncio.connection.Connection.send` waits until it drains under the + low-water mark to return. This creates backpressure on coroutines that send + messages. + +* In the legacy :mod:`asyncio` implementation, there is a library-level read + buffer. The ``read_limit`` argument of :func:`~client.connect` and + :func:`~server.serve` controls its size. When the read buffer grows above the + high-water mark, the connection stops reading from the network until it drains + under the low-water mark. This creates backpressure on the TCP connection. + + There is a write buffer. It as controlled by ``write_limit``. It behaves like + the new :mod:`asyncio` implementation described above. + +* In the :mod:`threading` implementation, there are no library-level buffers. + All I/O operations are performed directly on the :class:`~socket.socket`. + +websockets' buffers +................... + +Incoming messages are queued in a buffer after they have been received from the +network and parsed. A larger buffer may help a slow applications handle bursts +of messages while remaining responsive to control frames. + +The memory footprint of this buffer is bounded by the product of ``max_size``, +which controls the size of items in the queue, and ``max_queue``, which controls +the number of items. + +The ``max_size`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` defaults to 1 MiB. Most applications never receive +such large messages. Configuring a smaller value puts a tighter boundary on +memory usage. This can make your application more resilient to denial of service +attacks. + +The behavior of the ``max_queue`` argument of :func:`~asyncio.client.connect` +and :func:`~asyncio.server.serve` varies across implementations. -Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory usage. +* In the new :mod:`asyncio` implementation, ``max_queue`` is the high-water mark + of a queue of incoming frames. It defaults to 16 frames. If the queue grows + larger, the connection stops reading from the network until the application + consumes messages and the queue goes below the low-water mark. This creates + backpressure on the TCP connection. -By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~server.serve`: + Each item in the queue is a frame. A frame can be a message or a message + fragment. Either way, it must be smaller than ``max_size``, the maximum size + of a message. The queue may use up to ``max_size * max_queue`` bytes of + memory. By default, this is 16 MiB. -- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of - messages your application generates. -- Set ``max_queue`` (default: 32) to the maximum number of messages your - application expects to receive faster than it can process them. The queue - provides burst tolerance without slowing down the TCP connection. +* In the legacy :mod:`asyncio` implementation, ``max_queue`` is the maximum + size of a queue of incoming messages. It defaults to 32 messages. If the queue + fills up, the connection stops reading from the library-level read buffer + described above. If that buffer fills up as well, it will create backpressure + on the TCP connection. -Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: -64 KiB) to reduce the size of buffers for incoming and outgoing data. + Text messages are decoded before they're added to the queue. Since Python can + use up to 4 bytes of memory per character, the queue may use up to ``4 * + max_size * max_queue`` bytes of memory. By default, this is 128 MiB. -The design document provides :ref:`more details about buffers `. +* In the :mod:`threading` implementation, there is no queue of incoming + messages. The ``max_queue`` argument doesn't exist. The connection keeps at + most one message in memory at a time. diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index 4fbdf6220..86ebece31 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -1,72 +1,32 @@ #!/usr/bin/env python -import getpass -import json -import pickle -import subprocess +import collections +import pathlib import sys import time import zlib -CORPUS_FILE = "corpus.pkl" - REPEAT = 10 WB, ML = 12, 5 # defaults used as a reference -def _corpus(): - OAUTH_TOKEN = getpass.getpass("OAuth Token? ") - COMMIT_API = ( - f'curl -H "Authorization: token {OAUTH_TOKEN}" ' - f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" - ) - - commits = [] - - head = subprocess.check_output("git rev-parse HEAD", shell=True).decode().strip() - todo = [head] - seen = set() - - while todo: - sha = todo.pop(0) - commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) - commits.append(commit) - seen.add(sha) - for parent in json.loads(commit)["parents"]: - sha = parent["sha"] - if sha not in seen and sha not in todo: - todo.append(sha) - time.sleep(1) # rate throttling - - return commits - - -def corpus(): - data = _corpus() - with open(CORPUS_FILE, "wb") as handle: - pickle.dump(data, handle) - - -def _run(data): - size = {} - duration = {} +def benchmark(data): + size = collections.defaultdict(dict) + duration = collections.defaultdict(dict) for wbits in range(9, 16): - size[wbits] = {} - duration[wbits] = {} - for memLevel in range(1, 10): encoder = zlib.compressobj(wbits=-wbits, memLevel=memLevel) encoded = [] + print(f"Compressing {REPEAT} times with {wbits=} and {memLevel=}") + t0 = time.perf_counter() for _ in range(REPEAT): for item in data: - if isinstance(item, str): - item = item.encode() # Taken from PerMessageDeflate.encode item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) if item.endswith(b"\x00\x00\xff\xff"): @@ -75,7 +35,7 @@ def _run(data): t1 = time.perf_counter() - size[wbits][memLevel] = sum(len(item) for item in encoded) + size[wbits][memLevel] = sum(len(item) for item in encoded) / REPEAT duration[wbits][memLevel] = (t1 - t0) / REPEAT raw_size = sum(len(item) for item in data) @@ -149,15 +109,13 @@ def _run(data): print() -def run(): - with open(CORPUS_FILE, "rb") as handle: - data = pickle.load(handle) - _run(data) +def main(corpus): + data = [file.read_bytes() for file in corpus.iterdir()] + benchmark(data) -try: - run = globals()[sys.argv[1]] -except (KeyError, IndexError): - print(f"Usage: {sys.argv[0]} [corpus|run]") -else: - run() +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [directory]") + sys.exit(2) + main(pathlib.Path(sys.argv[1])) diff --git a/experiments/compression/client.py b/experiments/compression/client.py index 3ee19ddc5..69bfd5e7c 100644 --- a/experiments/compression/client.py +++ b/experiments/compression/client.py @@ -4,8 +4,8 @@ import statistics import tracemalloc -import websockets -from websockets.extensions import permessage_deflate +from websockets.asyncio.client import connect +from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory CLIENTS = 20 @@ -16,16 +16,16 @@ MEM_SIZE = [] -async def client(client): +async def client(num): # Space out connections to make them sequential. - await asyncio.sleep(client * INTERVAL) + await asyncio.sleep(num * INTERVAL) tracemalloc.start() - async with websockets.connect( + async with connect( "ws://localhost:8765", extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( + ClientPerMessageDeflateFactory( server_max_window_bits=WB, client_max_window_bits=WB, compress_settings={"memLevel": ML}, @@ -42,11 +42,13 @@ async def client(client): tracemalloc.stop() # Hold connection open until the end of the test. - await asyncio.sleep(CLIENTS * INTERVAL) + await asyncio.sleep((CLIENTS + 1 - num) * INTERVAL) async def clients(): - await asyncio.gather(*[client(client) for client in range(CLIENTS + 1)]) + # Start one more client than necessary because we will ignore + # non-representative results from the first connection. + await asyncio.gather(*[client(num) for num in range(CLIENTS + 1)]) asyncio.run(clients()) diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py new file mode 100644 index 000000000..da5661dfa --- /dev/null +++ b/experiments/compression/corpus.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +import getpass +import json +import pathlib +import subprocess +import sys +import time + + +def github_commits(): + OAUTH_TOKEN = getpass.getpass("OAuth Token? ") + COMMIT_API = ( + f'curl -H "Authorization: token {OAUTH_TOKEN}" ' + f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" + ) + + commits = [] + + head = subprocess.check_output( + "git rev-parse origin/main", + shell=True, + text=True, + ).strip() + todo = [head] + seen = set() + + while todo: + sha = todo.pop(0) + commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) + commits.append(commit) + seen.add(sha) + for parent in json.loads(commit)["parents"]: + sha = parent["sha"] + if sha not in seen and sha not in todo: + todo.append(sha) + time.sleep(1) # rate throttling + + return commits + + +def main(corpus): + data = github_commits() + for num, content in enumerate(reversed(data)): + (corpus / f"{num:04d}.json").write_bytes(content) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [directory]") + sys.exit(2) + main(pathlib.Path(sys.argv[1])) diff --git a/experiments/compression/server.py b/experiments/compression/server.py index 8d1ee3cd7..1c28f7355 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -6,8 +6,8 @@ import statistics import tracemalloc -import websockets -from websockets.extensions import permessage_deflate +from websockets.asyncio.server import serve +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory CLIENTS = 20 @@ -44,12 +44,12 @@ async def server(): print() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( handler, "localhost", 8765, extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( + ServerPerMessageDeflateFactory( server_max_window_bits=WB, client_max_window_bits=WB, compress_settings={"memLevel": ML}, @@ -63,7 +63,7 @@ async def server(): asyncio.run(server()) -# First connection may incur non-representative setup costs. +# First connection incurs non-representative setup costs. del MEM_SIZE[0] print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index fea14131e..5b907b79f 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -62,7 +62,8 @@ def __init__( if not self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings + wbits=-self.local_max_window_bits, + **self.compress_settings, ) # To handle continuation frames properly, we must keep track of @@ -156,7 +157,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings + wbits=-self.local_max_window_bits, + **self.compress_settings, ) # Compress data. From c3b162d05c3788b9367eb3cce8c5001c37a3e6fa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 08:53:45 +0200 Subject: [PATCH 089/109] Add broadcast to the new asyncio implementation. --- docs/faq/server.rst | 2 +- docs/howto/upgrade.rst | 10 +- docs/intro/tutorial2.rst | 16 +-- docs/project/changelog.rst | 4 +- docs/reference/asyncio/server.rst | 2 +- docs/reference/new-asyncio/server.rst | 5 + docs/topics/broadcast.rst | 69 ++++++------ docs/topics/logging.rst | 2 +- docs/topics/performance.rst | 6 +- experiments/broadcast/server.py | 21 ++-- src/websockets/asyncio/connection.py | 100 ++++++++++++++++- src/websockets/legacy/protocol.py | 14 +-- src/websockets/sync/connection.py | 2 +- tests/asyncio/test_connection.py | 156 ++++++++++++++++++++++++++ tests/legacy/test_protocol.py | 12 +- 15 files changed, 341 insertions(+), 80 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index cba1cd35f..53e34632f 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -102,7 +102,7 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.broadcast`:: +Then, call :func:`~asyncio.connection.broadcast`:: import websockets diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 10e8967d8..6efaf0f56 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -70,12 +70,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -Broadcast -......... - -The new implementation doesn't support :doc:`broadcasting messages -<../topics/broadcast>` yet. - Keepalive ......... @@ -178,8 +172,8 @@ Server APIs | :class:`websockets.server.WebSocketServerProtocol` |br| | | | ``websockets.legacy.server.WebSocketServerProtocol`` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| :func:`websockets.broadcast` |br| | *not available yet* | -| ``websockets.legacy.protocol.broadcast()`` | | +| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | +| :func:`websockets.legacy.protocol.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | | :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 5ac4ae9dd..b8e35f292 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -482,7 +482,7 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`broadcast` helper for this purpose: +the :func:`~legacy.protocol.broadcast` helper for this purpose: .. code-block:: python @@ -494,13 +494,14 @@ the :func:`broadcast` helper for this purpose: ... -Calling :func:`broadcast` once is more efficient than +Calling :func:`legacy.protocol.broadcast` once is more efficient than calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. -However, there's a subtle difference in behavior. Did you notice that there's -no ``await`` in the second version? Indeed, :func:`broadcast` is a function, -not a coroutine like :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -or :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. +However, there's a subtle difference in behavior. Did you notice that there's no +``await`` in the second version? Indeed, :func:`legacy.protocol.broadcast` is a +function, not a coroutine like +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` or +:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` is a coroutine. When you want to receive the next message, you have to wait @@ -521,7 +522,8 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`broadcast` doesn't wait until write buffers drain. +That's why :func:`legacy.protocol.broadcast` doesn't wait until write buffers +drain. For our Connect Four game, there's no difference in practice: the total amount of data sent on a connection for a game of Connect Four is less than 64 KB, diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f033f5632..eaabb2e9f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -212,7 +212,7 @@ Improvements * Added platform-independent wheels. -* Improved error handling in :func:`~websockets.broadcast`. +* Improved error handling in :func:`~legacy.protocol.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. @@ -402,7 +402,7 @@ New features * Added compatibility with Python 3.10. -* Added :func:`~websockets.broadcast` to send a message to many clients. +* Added :func:`~legacy.protocol.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~client.connect` as an asynchronous iterator. diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 4bd52b40b..3636f0b33 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -110,4 +110,4 @@ websockets supports HTTP Basic Authentication according to Broadcast --------- -.. autofunction:: websockets.broadcast +.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index c43673d33..7f9de6148 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -70,3 +70,8 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + +Broadcast +--------- + +.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index b6ddda734..671319136 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -1,21 +1,22 @@ -Broadcasting messages -===================== +Broadcasting +============ .. currentmodule:: websockets - -.. admonition:: If you just want to send a message to all connected clients, - use :func:`broadcast`. +.. admonition:: If you want to send a message to all connected clients, + use :func:`~asyncio.connection.broadcast`. :class: tip - If you want to learn about its design in depth, continue reading this - document. + If you want to learn about its design, continue reading this document. + + For the legacy :mod:`asyncio` implementation, use + :func:`~legacy.protocol.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. -Let's explore options for broadcasting a message, explain the design -of :func:`broadcast`, and discuss alternatives. +Let's explore options for broadcasting a message, explain the design of +:func:`~asyncio.connection.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -24,7 +25,7 @@ connected clients. Integrating them is left as an exercise for the reader. You could start with:: import asyncio - import websockets + from websockets.asyncio.server import serve async def handler(websocket): ... @@ -39,7 +40,7 @@ Integrating them is left as an exercise for the reader. You could start with:: await broadcast(message) async def main(): - async with websockets.serve(handler, "localhost", 8765): + async with serve(handler, "localhost", 8765): await broadcast_messages() # runs forever if __name__ == "__main__": @@ -82,11 +83,13 @@ to:: Here's a coroutine that broadcasts a message to all clients:: + from websockets import ConnectionClosed + async def broadcast(message): for websocket in CLIENTS.copy(): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass There are two tricks in this version of ``broadcast()``. @@ -117,11 +120,11 @@ which is usually outside of the control of the server. If you know for sure that you will never write more than ``write_limit`` bytes within ``ping_interval + ping_timeout``, then websockets will terminate slow -connections before the write buffer has time to fill up. +connections before the write buffer can fill up. -Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` -values to ensure that this condition holds. Set reasonable values and use the -built-in :func:`broadcast` function instead. +Don't set extreme values of ``write_limit``, ``ping_interval``, or +``ping_timeout`` to ensure that this condition holds! Instead, set reasonable +values and use the built-in :func:`~asyncio.connection.broadcast` function. The concurrent way ------------------ @@ -134,7 +137,7 @@ Let's modify ``broadcast()`` to send messages concurrently:: async def send(websocket, message): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass def broadcast(message): @@ -179,20 +182,20 @@ doesn't work well when broadcasting a message to thousands of clients. When you're sending messages to a single client, you don't want to send them faster than the network can transfer them and the client accept them. This is -why :meth:`~server.WebSocketServerProtocol.send` checks if the write buffer -is full and, if it is, waits until it drain, giving the network and the -client time to catch up. This provides backpressure. +why :meth:`~asyncio.server.ServerConnection.send` checks if the write buffer is +above the high-water mark and, if it is, waits until it drains, giving the +network and the client time to catch up. This provides backpressure. Without backpressure, you could pile up data in the write buffer until the server process runs out of memory and the operating system kills it. -The :meth:`~server.WebSocketServerProtocol.send` API is designed to enforce +The :meth:`~asyncio.server.ServerConnection.send` API is designed to enforce backpressure by default. This helps users of websockets write robust programs even if they never heard about backpressure. For comparison, :class:`asyncio.StreamWriter` requires users to understand -backpressure and to await :meth:`~asyncio.StreamWriter.drain` explicitly -after each :meth:`~asyncio.StreamWriter.write`. +backpressure and to await :meth:`~asyncio.StreamWriter.drain` after each +:meth:`~asyncio.StreamWriter.write` — or at least sufficiently frequently. When broadcasting messages, backpressure consists in slowing down all clients in an attempt to let the slowest client catch up. With thousands of clients, @@ -203,14 +206,14 @@ How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by -``ping_timeout`` and websockets terminates the connection. You can read the -discussion of :doc:`keepalive and timeouts <./timeouts>` for details. +``ping_timeout`` and websockets terminates the connection. You can refer to +the discussion of :doc:`keepalive and timeouts ` for details. -How :func:`broadcast` works ---------------------------- +How :func:`~asyncio.connection.broadcast` works +----------------------------------------------- -The built-in :func:`broadcast` function is similar to the naive way. The main -difference is that it doesn't apply backpressure. +The built-in :func:`~asyncio.connection.broadcast` function is similar to the +naive way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -321,9 +324,9 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`broadcast` function sends all messages without yielding -control to the event loop. So does the naive way when the network and clients -are fast and reliable. +The built-in :func:`~asyncio.connection.broadcast` function sends all messages +without yielding control to the event loop. So does the naive way when the +network and clients are fast and reliable. For each client, a WebSocket frame is prepared and sent to the network. This is the minimum amount of work required to broadcast a message. @@ -343,7 +346,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`broadcast` function. +than the built-in :func:`~asyncio.connection.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 873c852c2..765278360 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -220,7 +220,7 @@ Here's what websockets logs at each level. ``WARNING`` ........... -* Failures in :func:`~websockets.broadcast` +* Failures in :func:`~asyncio.connection.broadcast` ``INFO`` ........ diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst index 45e23b239..b226cec43 100644 --- a/docs/topics/performance.rst +++ b/docs/topics/performance.rst @@ -1,6 +1,8 @@ Performance =========== +.. currentmodule:: websockets + Here are tips to optimize performance. uvloop @@ -16,5 +18,5 @@ application.) broadcast --------- -:func:`~websockets.broadcast` is the most efficient way to send a message to -many clients. +:func:`~asyncio.connection.broadcast` is the most efficient way to send a +message to many clients. diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index b0407ba34..0a5c82b3c 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,7 +6,9 @@ import sys import time -import websockets +from websockets import ConnectionClosed +from websockets.asyncio.server import serve +from websockets.asyncio.connection import broadcast CLIENTS = set() @@ -15,7 +17,7 @@ async def send(websocket, message): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass @@ -43,9 +45,6 @@ async def subscribe(self): __aiter__ = subscribe -PUBSUB = PubSub() - - async def handler(websocket, method=None): if method in ["default", "naive", "task", "wait"]: CLIENTS.add(websocket) @@ -63,14 +62,18 @@ async def handler(websocket, method=None): CLIENTS.remove(queue) relay_task.cancel() elif method == "pubsub": + global PUBSUB async for message in PUBSUB: await websocket.send(message) else: raise NotImplementedError(f"unsupported method: {method}") -async def broadcast(method, size, delay): +async def broadcast_messages(method, size, delay): """Broadcast messages at regular intervals.""" + if method == "pubsub": + global PUBSUB + PUBSUB = PubSub() load_average = 0 time_average = 0 pc1, pt1 = time.perf_counter_ns(), time.process_time_ns() @@ -90,7 +93,7 @@ async def broadcast(method, size, delay): message = str(time.time_ns()).encode() + b" " + os.urandom(size - 20) if method == "default": - websockets.broadcast(CLIENTS, message) + broadcast(CLIENTS, message) elif method == "naive": # Since the loop can yield control, make a copy of CLIENTS # to avoid: RuntimeError: Set changed size during iteration @@ -128,14 +131,14 @@ async def broadcast(method, size, delay): async def main(method, size, delay): - async with websockets.serve( + async with serve( functools.partial(handler, method=method), "localhost", 8765, compression=None, ping_timeout=None, ): - await broadcast(method, size, delay) + await broadcast_messages(method, size, delay) if __name__ == "__main__": diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 1c4424f0d..9d2f087da 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -6,6 +6,7 @@ import logging import random import struct +import sys import uuid from types import TracebackType from typing import ( @@ -27,7 +28,7 @@ from .messages import Assembler -__all__ = ["Connection"] +__all__ = ["Connection", "broadcast"] class Connection(asyncio.Protocol): @@ -338,7 +339,6 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If the connection busy sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -488,7 +488,7 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No self.fragmented_send_waiter = None else: - raise TypeError("data must be bytes, str, iterable, or async iterable") + raise TypeError("data must be str, bytes, iterable, or async iterable") async def close(self, code: int = 1000, reason: str = "") -> None: """ @@ -673,7 +673,7 @@ async def send_context( On entry, :meth:`send_context` checks that the connection is open; on exit, it writes outgoing data to the socket:: - async async with self.send_context(): + async with self.send_context(): self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected @@ -916,3 +916,95 @@ def eof_received(self) -> None: # As a consequence, they never need to write after receiving EOF, so # there's no reason to keep the transport open by returning True. # Besides, that doesn't work on TLS connections. + + +def broadcast( + connections: Iterable[Connection], + message: Data, + raise_exceptions: bool = False, +) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers are overflowing. There's no backpressure. + + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. + + Unlike :meth:`~Connection.send`, :func:`broadcast` doesn't support sending + fragmented messages. Indeed, fragmentation is useful for sending large + messages without buffering them in memory, while :func:`broadcast` buffers + one copy per connection as fast as possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. + + Args: + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. + + Raises: + TypeError: If ``message`` doesn't have a supported type. + + """ + if isinstance(message, str): + send_method = "send_text" + message = message.encode() + elif isinstance(message, BytesLike): + send_method = "send_binary" + else: + raise TypeError("data must be str or bytes") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + + for connection in connections: + if connection.protocol.state is not OPEN: + continue + + if connection.fragmented_send_waiter is not None: + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + continue + + try: + # Call connection.protocol.send_text or send_binary. + # Either way, message is already converted to bytes. + getattr(connection.protocol, send_method)(message) + connection.send_data() + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: failed to write message", + exc_info=True, + ) + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 6f8916576..b948257e0 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1570,18 +1570,17 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, - :func:`broadcast` doesn't support sending fragmented messages. Indeed, - fragmentation is useful for sending large messages without buffering them in - memory, while :func:`broadcast` buffers one copy per connection as fast as - possible. + Unlike :meth:`~WebSocketCommonProtocol.send`, :func:`broadcast` doesn't + support sending fragmented messages. Indeed, fragmentation is useful for + sending large messages without buffering them in memory, while + :func:`broadcast` buffers one copy per connection as fast as possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. :func:`broadcast` ignores failures to write the message on some connections. - It continues writing to other connections. On Python 3.11 and above, you - may set ``raise_exceptions`` to :obj:`True` to record failures and raise all + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. Args: @@ -1615,6 +1614,7 @@ def broadcast( websocket.logger.warning( "skipped broadcast: sending a fragmented message", ) + continue try: websocket.write_frame_sync(True, opcode, data) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index a4826c785..88d6aee1f 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -387,7 +387,7 @@ def send(self, message: Data | Iterable[Data]) -> None: raise else: - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str, bytes, or iterable") def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: """ diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 02029b754..1cf382a01 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -2,6 +2,7 @@ import contextlib import logging import socket +import sys import unittest import uuid from unittest.mock import Mock, patch @@ -1004,6 +1005,161 @@ async def test_unexpected_failure_in_send_context(self, send_text): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) + # Test broadcast. + + async def test_broadcast_text(self): + """broadcast broadcasts a text message.""" + broadcast([self.connection], "😀") + await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_text_reports_no_errors(self): + """broadcast broadcasts a text message without raising exceptions.""" + broadcast([self.connection], "😀", raise_exceptions=True) + await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) + + async def test_broadcast_binary(self): + """broadcast broadcasts a binary message.""" + broadcast([self.connection], b"\x01\x02\xfe\xff") + await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_binary_reports_no_errors(self): + """broadcast broadcasts a binary message without raising exceptions.""" + broadcast([self.connection], b"\x01\x02\xfe\xff", raise_exceptions=True) + await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) + + async def test_broadcast_no_clients(self): + """broadcast does nothing when called with an empty list of clients.""" + broadcast([], "😀") + await self.assertNoFrameSent() + + async def test_broadcast_two_clients(self): + """broadcast broadcasts a message to several clients.""" + broadcast([self.connection, self.connection], "😀") + await self.assertFramesSent( + [ + Frame(Opcode.TEXT, "😀".encode()), + Frame(Opcode.TEXT, "😀".encode()), + ] + ) + + async def test_broadcast_skips_closed_connection(self): + """broadcast ignores closed connections.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + with self.assertNoLogs(): + broadcast([self.connection], "😀") + await self.assertNoFrameSent() + + async def test_broadcast_skips_closing_connection(self): + """broadcast ignores closing connections.""" + async with self.delay_frames_rcvd(MS): + close_task = asyncio.create_task(self.connection.close()) + await asyncio.sleep(0) + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + with self.assertNoLogs(): + broadcast([self.connection], "😀") + await self.assertNoFrameSent() + + await close_task + + async def test_broadcast_skips_connection_with_send_blocked(self): + """broadcast logs a warning when a connection is blocked in send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) + + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.connection], "😀") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: sending a fragmented message"], + ) + + gate.set_result(None) + await send_task + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_reports_connection_with_send_blocked(self): + """broadcast raises exceptions for connections blocked in send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.connection], "😀", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + exc = raised.exception.exceptions[0] + self.assertEqual(str(exc), "sending a fragmented message") + self.assertIsInstance(exc, RuntimeError) + + gate.set_result(None) + await send_task + + async def test_broadcast_skips_connection_failing_to_send(self): + """broadcast logs a warning when a connection fails to send.""" + # Inject a fault by shutting down the transport for writing. + self.transport.write_eof() + + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.connection], "😀") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: failed to write message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_reports_connection_failing_to_send(self): + """broadcast raises exceptions for connections failing to send.""" + # Inject a fault by shutting down the transport for writing. + self.transport.write_eof() + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.connection], "😀", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + exc = raised.exception.exceptions[0] + self.assertEqual(str(exc), "failed to write message") + self.assertIsInstance(exc, RuntimeError) + cause = exc.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_broadcast_type_error(self): + """broadcast raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + broadcast([self.connection], ["⏳", "⌛️"]) + class ServerConnectionTests(ClientConnectionTests): LOCAL = SERVER diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index d6303dcc7..ccea34719 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1473,7 +1473,8 @@ def test_broadcast_text(self): self.assertOneFrameSent(True, OP_TEXT, "café".encode()) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_text_reports_no_errors(self): broadcast([self.protocol], "café", raise_exceptions=True) @@ -1484,7 +1485,8 @@ def test_broadcast_binary(self): self.assertOneFrameSent(True, OP_BINARY, b"tea") @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_binary_reports_no_errors(self): broadcast([self.protocol], b"tea", raise_exceptions=True) @@ -1536,7 +1538,8 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): ) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() @@ -1565,7 +1568,8 @@ def test_broadcast_skips_connection_failing_to_send(self): ) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. From 8d9f9a1cc791df01d7995693551cd9cf83e154c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 11:15:51 +0200 Subject: [PATCH 090/109] Expose connection state in the new asyncio implementation. --- docs/reference/new-asyncio/client.rst | 2 ++ docs/reference/new-asyncio/common.rst | 2 ++ docs/reference/new-asyncio/server.rst | 2 ++ src/websockets/asyncio/connection.py | 12 ++++++++++++ src/websockets/protocol.py | 2 +- tests/asyncio/test_connection.py | 6 +++++- 6 files changed, 24 insertions(+), 2 deletions(-) diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index 196bda2b7..efd143f14 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -43,6 +43,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index 4fa97dcf2..60ea6bb37 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -33,6 +33,8 @@ Both sides (new :mod:`asyncio`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index 7f9de6148..b163e0fcd 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -62,6 +62,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9d2f087da..a323376ca 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -137,6 +137,18 @@ def remote_address(self) -> Any: """ return self.transport.get_extra_info("peername") + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 917c19163..de065c544 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -156,7 +156,7 @@ def __init__( @property def state(self) -> State: """ - WebSocket connection state. + State of the WebSocket connection. Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 1cf382a01..239b5312e 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -11,7 +11,7 @@ from websockets.asyncio.connection import * from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol from ..utils import MS @@ -930,6 +930,10 @@ async def test_remote_address(self, get_extra_info): self.assertEqual(self.connection.remote_address, ("peer", 1234)) get_extra_info.assert_called_with("peername") + async def test_state(self): + """Connection has a state attribute.""" + self.assertEqual(self.connection.state, State.OPEN) + async def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) From 7c8e0b9d6246cd7bdd304f630f719fc55620f89a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 10:20:37 +0200 Subject: [PATCH 091/109] Document removal of open and closed properties. They won't be added to the new asyncio implementation. --- docs/howto/upgrade.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 6efaf0f56..8ff18c594 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -371,3 +371,29 @@ buffer now. The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. + +Attributes of connections +......................... + +``open`` and ``closed`` +~~~~~~~~~~~~~~~~~~~~~~~ + +The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. +Using them was discouraged. + +Instead, you should call :meth:`~asyncio.connection.Connection.recv` or +:meth:`~asyncio.connection.Connection.send` and handle +:exc:`~exceptions.ConnectionClosed` exceptions. + +If your code relies on them, you can replace:: + + connection.open + connection.closed + +with:: + + from websockets.protocol import State + + connection.state is State.OPEN + connection.state is State.CLOSED From 7c1d1d9b97fa698034d2b3651eb5a757e42b3dfb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 11:53:37 +0200 Subject: [PATCH 092/109] Add a respond method to server connections. It's an alias of the reject method of the underlying server protocol. It makes it easier to write process_request It's called respond because the semantics are "consider the request as an HTTP request and create an HTTP response". There isn't a similar alias for accept because process_request should just return and websockets will call accept. --- docs/howto/upgrade.rst | 2 +- docs/reference/new-asyncio/server.rst | 2 ++ docs/reference/sync/server.rst | 2 ++ src/websockets/asyncio/server.py | 23 ++++++++++++++++++++++- src/websockets/server.py | 27 ++++++++++++++------------- src/websockets/sync/server.py | 23 ++++++++++++++++++++++- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_server.py | 2 +- 8 files changed, 66 insertions(+), 19 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 8ff18c594..fe95a6517 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -247,7 +247,7 @@ an example:: # New implementation def process_request(connection, request): - return connection.protocol.reject(http.HTTPStatus.OK, "OK\n") + return connection.respond(http.HTTPStatus.OK, "OK\n") serve(..., process_request=process_request, ...) diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index b163e0fcd..5ffcff843 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -52,6 +52,8 @@ Using a connection .. automethod:: pong + .. automethod:: respond + WebSocket connection objects also provide these attributes: .. autoattribute:: id diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 7ed744df2..26ab872c8 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -40,6 +40,8 @@ Using a connection .. automethod:: pong + .. automethod:: respond + WebSocket connection objects also provide these attributes: .. autoattribute:: id diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 4feea13c4..cc2f46216 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -23,7 +23,7 @@ from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout from .connection import Connection @@ -75,6 +75,27 @@ def __init__( self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + async def handshake( self, process_request: ( diff --git a/src/websockets/server.py b/src/websockets/server.py index 1b4c3bf29..2ab9102f7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -113,18 +113,22 @@ def accept(self, request: Request) -> Response: """ Create a handshake response to accept the connection. - If the connection cannot be established, the handshake response - actually rejects the handshake. + If the handshake request is valid and the handshake successful, + :meth:`accept` returns an HTTP response with status code 101. + + Else, it returns an HTTP response with another status code. This rejects + the connection, like :meth:`reject` would. You must send the handshake response with :meth:`send_response`. - You may modify it before sending it, for example to add HTTP headers. + You may modify the response before sending it, typically by adding HTTP + headers. Args: - request: WebSocket handshake request event received from the client. + request: WebSocket handshake request received from the client. Returns: - WebSocket handshake response event to send to the client. + WebSocket handshake response or HTTP response to send to the client. """ try: @@ -485,11 +489,7 @@ def select_subprotocol(protocol, subprotocols): + ", ".join(self.available_subprotocols) ) - def reject( - self, - status: StatusLike, - text: str, - ) -> Response: + def reject(self, status: StatusLike, text: str) -> Response: """ Create a handshake response to reject the connection. @@ -498,14 +498,15 @@ def reject( You must send the handshake response with :meth:`send_response`. - You can modify it before sending it, for example to alter HTTP headers. + You may modify the response before sending it, for example by changing + HTTP headers. Args: status: HTTP status code. - text: HTTP response body; will be encoded to UTF-8. + text: HTTP response body; it will be encoded to UTF-8. Returns: - WebSocket handshake response event to send to the client. + HTTP response to send to the client. """ # If a user passes an int instead of a HTTPStatus, fix it automatically. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 10fbe4859..b381908ca 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -19,7 +19,7 @@ from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection from .utils import Deadline @@ -66,6 +66,27 @@ def __init__( close_timeout=close_timeout, ) + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + def handshake( self, process_request: ( diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 4a8a76a21..fa590210f 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -144,7 +144,7 @@ async def test_process_request_abort_handshake(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: @@ -159,7 +159,7 @@ async def test_async_process_request_abort_handshake(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 9d509a5c4..4e04a39d5 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -131,7 +131,7 @@ def test_process_request_abort_handshake(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: From 7b19e790ce766dadb0b90b040be68694074a7e0d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 23:12:36 +0200 Subject: [PATCH 093/109] Add keepalive to the new asyncio implementation. --- docs/faq/common.rst | 4 +- docs/howto/upgrade.rst | 11 --- docs/reference/features.rst | 4 +- docs/reference/new-asyncio/client.rst | 2 + docs/reference/new-asyncio/common.rst | 2 + docs/reference/new-asyncio/server.rst | 2 + docs/topics/broadcast.rst | 4 +- docs/topics/index.rst | 2 +- docs/topics/{timeouts.rst => keepalive.rst} | 16 ++-- src/websockets/asyncio/client.py | 21 ++++- src/websockets/asyncio/connection.py | 89 ++++++++++++++++++-- src/websockets/asyncio/server.py | 21 ++++- src/websockets/legacy/protocol.py | 16 ++-- tests/asyncio/test_client.py | 15 ++++ tests/asyncio/test_connection.py | 91 ++++++++++++++++++++- tests/asyncio/test_server.py | 21 +++++ tests/legacy/test_protocol.py | 4 +- 17 files changed, 274 insertions(+), 51 deletions(-) rename docs/topics/{timeouts.rst => keepalive.rst} (90%) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 2c63c4f36..84256fdfe 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -97,7 +97,7 @@ There are two main reasons why latency may increase: * Poor network connectivity. * More traffic than the recipient can handle. -See the discussion of :doc:`timeouts <../topics/timeouts>` for details. +See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. @@ -146,7 +146,7 @@ It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. -See :doc:`../topics/timeouts` for details. +See :doc:`../topics/keepalive` for details. How do I respond to pings? -------------------------- diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index fe95a6517..16b010aca 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -70,17 +70,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -Keepalive -......... - -The new implementation doesn't provide a :ref:`keepalive mechanism ` -yet. - -As a consequence, :func:`~asyncio.client.connect` and -:func:`~asyncio.server.serve` don't accept the ``ping_interval`` and -``ping_timeout`` arguments and the -:attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property doesn't exist. - HTTP Basic Authentication ......................... diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 946770fe3..45fa79c48 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -53,9 +53,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ❌ | ❌ | — | ✅ | + | Keepalive | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ❌ | ❌ | — | ✅ | + | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index efd143f14..77a3c5d53 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -43,6 +43,8 @@ Using a connection .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index 60ea6bb37..a58325fb9 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -33,6 +33,8 @@ Both sides (new :mod:`asyncio`) .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index 5ffcff843..7bceca5a0 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -64,6 +64,8 @@ Using a connection .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 671319136..ec358bbd2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -206,8 +206,8 @@ How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by -``ping_timeout`` and websockets terminates the connection. You can refer to -the discussion of :doc:`keepalive and timeouts ` for details. +``ping_timeout`` and websockets terminates the connection. You can refer to the +discussion of :doc:`keepalive ` for details. How :func:`~asyncio.connection.broadcast` works ----------------------------------------------- diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 120a3dd32..a2b8ca879 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -11,7 +11,7 @@ Get a deeper understanding of how websockets is built and why. authentication broadcast compression - timeouts + keepalive design memory security diff --git a/docs/topics/timeouts.rst b/docs/topics/keepalive.rst similarity index 90% rename from docs/topics/timeouts.rst rename to docs/topics/keepalive.rst index 633fc1ab4..1c7a43264 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/keepalive.rst @@ -1,5 +1,5 @@ -Timeouts -======== +Keepalive and latency +===================== .. currentmodule:: websockets @@ -49,9 +49,9 @@ This mechanism serves two purposes: application gets a :exc:`~exceptions.ConnectionClosed` exception. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~client.connect` and :func:`~server.serve`. Shorter values -will detect connection drops faster but they will increase network traffic and -they will be more sensitive to latency. +arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. +Shorter values will detect connection drops faster but they will increase +network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and heartbeat mechanism. @@ -111,6 +111,6 @@ Latency between a client and a server may increase for two reasons: than the client can accept. The latency measured during the last exchange of Ping and Pong frames is -available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` -attribute. Alternatively, you can measure the latency at any time with the -:attr:`~legacy.protocol.WebSocketCommonProtocol.ping` method. +available in the :attr:`~asyncio.connection.Connection.latency` attribute. +Alternatively, you can measure the latency at any time with the +:attr:`~asyncio.connection.Connection.ping` method. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b2eaf9a65..632d3ac2b 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -37,10 +37,11 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments the same meaning as in :func:`connect`. + Args: protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. - :obj:`None` disables the timeout. """ @@ -48,6 +49,8 @@ def __init__( self, protocol: ClientProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, @@ -55,6 +58,8 @@ def __init__( self.protocol: ClientProtocol super().__init__( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, @@ -84,7 +89,9 @@ async def handshake( if self.response is None: raise ConnectionError("connection closed during handshake") - if self.protocol.handshake_exc is not None: + if self.protocol.handshake_exc is None: + self.start_keepalive() + else: try: async with asyncio_timeout(self.close_timeout): await self.connection_lost_waiter @@ -146,6 +153,10 @@ class connect: :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -208,6 +219,8 @@ def __init__( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -256,6 +269,8 @@ def factory() -> ClientConnection: # This is a connection in websockets and a protocol in asyncio. connection = create_connection( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index a323376ca..b232b7956 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -24,7 +24,13 @@ from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .compatibility import ( + TimeoutError, + aiter, + anext, + asyncio_timeout, + asyncio_timeout_at, +) from .messages import Assembler @@ -48,11 +54,15 @@ def __init__( self, protocol: Protocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int): max_queue = (max_queue, None) @@ -95,6 +105,21 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + This value is updated after sending a ping frame and receiving a + matching pong frame. Before the first ping, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends ping frames automatically at regular intervals. You can also + send ping frames and measure latency with :meth:`ping`. + """ + + # Task that sends keepalive pings. None when ping_interval is None. + self.keepalive_task: asyncio.Task[None] | None = None + # Exception raised while reading from the connection, to be chained to # ConnectionClosed in order to show why the TCP connection dropped. self.recv_exc: BaseException | None = None @@ -144,7 +169,8 @@ def state(self) -> State: This attribute is provided for completeness. Typical applications shouldn't check its value. Instead, they should call :meth:`~recv` or - :meth:`send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. """ return self.protocol.state @@ -540,7 +566,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Data | None = None) -> Awaitable[None]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -643,8 +669,10 @@ def acknowledge_pings(self, data: bytes) -> None: ping_ids = [] for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) - pong_waiter.set_result(pong_timestamp - ping_timestamp) + latency = pong_timestamp - ping_timestamp + pong_waiter.set_result(latency) if ping_id == data: + self.latency = latency break else: raise AssertionError("solicited pong not found in pings") @@ -664,7 +692,8 @@ def abort_pings(self) -> None: exc = self.protocol.close_exc for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - pong_waiter.set_exception(exc) + if not pong_waiter.done(): + pong_waiter.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does @@ -673,6 +702,50 @@ def abort_pings(self) -> None: self.pong_waiters.clear() + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + latency = 0.0 + try: + while True: + # If self.ping_timeout > latency > self.ping_interval, pings + # will be sent immediately after receiving pongs. The period + # will be longer than self.ping_interval. + await asyncio.sleep(self.ping_interval - latency) + + self.logger.debug("% sending keepalive ping") + pong_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + async with asyncio_timeout(self.ping_timeout): + latency = await pong_waiter + self.logger.debug("% received keepalive pong") + except asyncio.TimeoutError: + if self.debug: + self.logger.debug("! timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except ConnectionClosed: + pass + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.keepalive_task = self.loop.create_task(self.keepalive()) + @contextlib.asynccontextmanager async def send_context( self, @@ -835,11 +908,15 @@ def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent self.recv_messages.close() self.set_recv_exc(exc) + self.abort_pings() + # If keepalive() was waiting for a pong, abort_pings() terminated it. + # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: + self.keepalive_task.cancel() # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) - self.abort_pings() # Adapted from asyncio.streams.FlowControlMixin if self.paused: # pragma: no cover diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index cc2f46216..1f55502bb 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -48,11 +48,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments the same meaning as in :func:`serve`. + Args: protocol: Sans-I/O connection. server: Server that manages this connection. - close_timeout: Timeout for closing connections in seconds. - :obj:`None` disables the timeout. """ @@ -61,6 +62,8 @@ def __init__( protocol: ServerProtocol, server: WebSocketServer, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, @@ -68,6 +71,8 @@ def __init__( self.protocol: ServerProtocol super().__init__( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, @@ -182,7 +187,9 @@ async def handshake( self.protocol.send_response(self.response) - if self.protocol.handshake_exc is not None: + if self.protocol.handshake_exc is None: + self.start_keepalive() + else: try: async with asyncio_timeout(self.close_timeout): await self.connection_lost_waiter @@ -595,6 +602,10 @@ def handler(websocket): :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -667,6 +678,8 @@ def __init__( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -743,6 +756,8 @@ def protocol_select_subprotocol( connection = create_connection( protocol, self.server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b948257e0..191350de3 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -89,7 +89,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. + See the discussion of :doc:`timeouts <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -144,8 +144,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): logger: Logger for this server. It defaults to ``logging.getLogger("websockets.protocol")``. See the :doc:`logging guide <../../topics/logging>` for details. - ping_interval: Delay between keepalive pings in seconds. - :obj:`None` disables keepalive pings. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. @@ -1242,18 +1242,16 @@ async def keepalive_ping(self) -> None: while True: await asyncio.sleep(self.ping_interval) - # ping() raises CancelledError if the connection is closed, - # when close_connection() cancels self.keepalive_ping_task. - - # ping() raises ConnectionClosed if the connection is lost, - # when connection_lost() calls abort_pings(). - self.logger.debug("% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # Raises CancelledError if the connection is closed, + # when close_connection() cancels keepalive_ping(). + # Raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index b74617ef0..0bd2af4f1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -63,6 +63,21 @@ async def test_disable_compression(self): async with run_client(server, compression=None) as client: self.assertEqual(client.protocol.extensions, []) + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with run_server() as server: + async with run_client(server, ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await asyncio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server() as server: + async with run_client(server, ping_interval=None) as client: + await asyncio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 239b5312e..9b84a6b81 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -868,6 +868,93 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + # Test keepalive. + + @patch("random.getrandbits") + async def test_keepalive(self, getrandbits): + """keepalive sends pings.""" + self.connection.ping_interval = 2 * MS + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + @patch("random.getrandbits") + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for MS. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") + ) + + @patch("random.getrandbits") + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for MS. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + await asyncio.sleep(MS) + await self.connection.close() + self.assertTrue(self.connection.keepalive_task.done()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for MS. + await self.connection.close() + self.assertTrue(self.connection.keepalive_task.done()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + # Inject a fault by raising an exception in a pending pong waiter. + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for MS. + pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + with self.assertLogs("websockets", logging.ERROR) as logs: + pong_waiter.set_exception(Exception("BOOM")) + await asyncio.sleep(0) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + # Test parameters. async def test_close_timeout(self): @@ -1092,7 +1179,7 @@ async def fragments(): broadcast([self.connection], "😀") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @@ -1135,7 +1222,7 @@ async def test_broadcast_skips_connection_failing_to_send(self): broadcast([self.connection], "😀") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message"], ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fa590210f..b3023434b 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -312,6 +312,27 @@ async def test_disable_compression(self): async with run_client(server) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") + async def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + async with run_server(ping_interval=MS) as server: + async with run_client(server) as client: + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + await asyncio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertGreater(latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server(ping_interval=None) as server: + async with run_client(server) as client: + await asyncio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ccea34719..8751b9ac6 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1533,7 +1533,7 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): broadcast([self.protocol], "café") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @@ -1563,7 +1563,7 @@ def test_broadcast_skips_connection_failing_to_send(self): broadcast([self.protocol], "café") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message"], ) From 60381d2566b55126f0f89f0c8380cf44ddc51aa1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 07:58:31 +0200 Subject: [PATCH 094/109] Fix exception chaining for ConnectionClosed. --- src/websockets/asyncio/connection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index b232b7956..284fe2124 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -906,13 +906,17 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent - self.recv_messages.close() + + # Abort recv() and pending pings with a ConnectionClosed exception. + # Set recv_exc first to get proper exception reporting. self.set_recv_exc(exc) + self.recv_messages.close() self.abort_pings() # If keepalive() was waiting for a pong, abort_pings() terminated it. # If it was sleeping until the next ping, we need to cancel it now if self.keepalive_task is not None: self.keepalive_task.cancel() + # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. From a78b5546074ed9e89a265eec1b54292a628d9b25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 08:00:16 +0200 Subject: [PATCH 095/109] Document legacy implementation in websockets.legacy. Until now it was documented directly in the websockets package. Also update most examples to use the new asyncio implementation. Some drive-by documentation improvements too. --- compliance/test_client.py | 10 +- compliance/test_server.py | 6 +- docs/conf.py | 28 +- docs/faq/asyncio.rst | 44 ++-- docs/faq/client.rst | 62 +++-- docs/faq/common.rst | 53 ++-- docs/faq/misc.rst | 17 +- docs/faq/server.rst | 111 ++++---- docs/howto/cheatsheet.rst | 57 ++-- docs/howto/django.rst | 2 +- docs/howto/heroku.rst | 2 +- docs/howto/nginx.rst | 6 +- docs/howto/patterns.rst | 10 +- docs/howto/quickstart.rst | 14 +- docs/howto/upgrade.rst | 91 ++++--- docs/intro/tutorial1.rst | 34 +-- docs/intro/tutorial3.rst | 6 +- docs/project/changelog.rst | 117 ++++----- docs/reference/asyncio/client.rst | 37 ++- docs/reference/asyncio/common.rst | 33 +-- docs/reference/asyncio/server.rst | 70 ++--- docs/reference/features.rst | 5 +- docs/reference/index.rst | 19 +- docs/reference/legacy/client.rst | 64 +++++ docs/reference/legacy/common.rst | 54 ++++ docs/reference/legacy/server.rst | 113 ++++++++ docs/reference/new-asyncio/client.rst | 57 ---- docs/reference/new-asyncio/common.rst | 47 ---- docs/reference/new-asyncio/server.rst | 83 ------ docs/topics/authentication.rst | 22 +- docs/topics/deployment.rst | 13 +- docs/topics/design.rst | 286 ++++++++++----------- docs/topics/logging.rst | 11 +- docs/topics/memory.rst | 9 +- docs/topics/security.rst | 6 +- example/deployment/fly/app.py | 10 +- example/deployment/haproxy/app.py | 4 +- example/deployment/heroku/app.py | 4 +- example/deployment/kubernetes/app.py | 18 +- example/deployment/kubernetes/benchmark.py | 5 +- example/deployment/nginx/app.py | 4 +- example/deployment/render/app.py | 10 +- example/deployment/supervisor/app.py | 4 +- example/django/authentication.py | 4 +- example/django/notifications.py | 7 +- example/echo.py | 2 +- example/faq/health_check_server.py | 15 +- example/faq/shutdown_client.py | 8 +- example/faq/shutdown_server.py | 5 +- example/legacy/basic_auth_client.py | 5 +- example/legacy/basic_auth_server.py | 8 +- example/legacy/unix_client.py | 5 +- example/legacy/unix_server.py | 5 +- example/logging/json_log_formatter.py | 2 +- example/quickstart/client.py | 5 +- example/quickstart/client_secure.py | 5 +- example/quickstart/counter.py | 13 +- example/quickstart/server.py | 5 +- example/quickstart/server_secure.py | 5 +- example/quickstart/show_time.py | 5 +- example/quickstart/show_time_2.py | 8 +- example/tutorial/step1/app.py | 4 +- example/tutorial/step2/app.py | 4 +- example/tutorial/step3/app.py | 4 +- experiments/authentication/app.py | 19 +- experiments/broadcast/clients.py | 4 +- src/websockets/exceptions.py | 4 +- src/websockets/legacy/auth.py | 6 +- src/websockets/legacy/client.py | 4 +- src/websockets/legacy/protocol.py | 16 +- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 9 +- tests/sync/test_connection.py | 2 +- tests/sync/test_server.py | 5 +- 74 files changed, 951 insertions(+), 902 deletions(-) create mode 100644 docs/reference/legacy/client.rst create mode 100644 docs/reference/legacy/common.rst create mode 100644 docs/reference/legacy/server.rst delete mode 100644 docs/reference/new-asyncio/client.rst delete mode 100644 docs/reference/new-asyncio/common.rst delete mode 100644 docs/reference/new-asyncio/server.rst diff --git a/compliance/test_client.py b/compliance/test_client.py index 1ed4d711e..8e22569fd 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -1,9 +1,9 @@ +import asyncio import json import logging import urllib.parse -import asyncio -import websockets +from websockets.asyncio.client import connect logging.basicConfig(level=logging.WARNING) @@ -18,21 +18,21 @@ async def get_case_count(server): uri = f"{server}/getCaseCount" - async with websockets.connect(uri) as ws: + async with connect(uri) as ws: msg = ws.recv() return json.loads(msg) async def run_case(server, case, agent): uri = f"{server}/runCase?case={case}&agent={agent}" - async with websockets.connect(uri, max_size=2 ** 25, max_queue=1) as ws: + async with connect(uri, max_size=2 ** 25, max_queue=1) as ws: async for msg in ws: await ws.send(msg) async def update_reports(server, agent): uri = f"{server}/updateReports?agent={agent}" - async with websockets.connect(uri): + async with connect(uri): pass diff --git a/compliance/test_server.py b/compliance/test_server.py index 5701e4485..39176e902 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -1,7 +1,7 @@ +import asyncio import logging -import asyncio -import websockets +from websockets.asyncio.server import serve logging.basicConfig(level=logging.WARNING) @@ -19,7 +19,7 @@ async def echo(ws): async def main(): - with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): + with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): try: await asyncio.get_running_loop().create_future() # run forever except KeyboardInterrupt: diff --git a/docs/conf.py b/docs/conf.py index 9d61dc717..2c621bf41 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,20 +36,20 @@ # topics/design.rst discusses undocumented APIs ("py:meth", "client.WebSocketClientProtocol.handshake"), ("py:meth", "server.WebSocketServerProtocol.handshake"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.is_client"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.messages"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.close_connection"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.close_connection_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.transfer_data"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.transfer_data_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_open"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.ensure_open"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.fail_connection"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_lost"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.read_message"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.write_frame"), + ("py:attr", "protocol.WebSocketCommonProtocol.is_client"), + ("py:attr", "protocol.WebSocketCommonProtocol.messages"), + ("py:meth", "protocol.WebSocketCommonProtocol.close_connection"), + ("py:attr", "protocol.WebSocketCommonProtocol.close_connection_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.keepalive_ping"), + ("py:attr", "protocol.WebSocketCommonProtocol.keepalive_ping_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.transfer_data"), + ("py:attr", "protocol.WebSocketCommonProtocol.transfer_data_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.connection_open"), + ("py:meth", "protocol.WebSocketCommonProtocol.ensure_open"), + ("py:meth", "protocol.WebSocketCommonProtocol.fail_connection"), + ("py:meth", "protocol.WebSocketCommonProtocol.connection_lost"), + ("py:meth", "protocol.WebSocketCommonProtocol.read_message"), + ("py:meth", "protocol.WebSocketCommonProtocol.write_frame"), ] # Add any Sphinx extension module names here, as strings. They can be diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index e77f50add..3bc381cfd 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -1,7 +1,12 @@ Using asyncio ============= -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.connection + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. How do I run two coroutines in parallel? ---------------------------------------- @@ -9,8 +14,8 @@ How do I run two coroutines in parallel? You must start two tasks, which the event loop will run concurrently. You can achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. -Keep track of the tasks and make sure they terminate or you cancel them when -the connection terminates. +Keep track of the tasks and make sure that they terminate or that you cancel +them when the connection terminates. Why does my program never receive any messages? ----------------------------------------------- @@ -22,13 +27,12 @@ Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough to yield control. Awaiting a coroutine may yield control, but there's no guarantee that it will. -For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields -control when send buffers are full, which never happens in most practical -cases. +For example, :meth:`~Connection.send` only yields control when send buffers are +full, which never happens in most practical cases. -If you run a loop that contains only synchronous operations and -a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield -control explicitly with :func:`asyncio.sleep`:: +If you run a loop that contains only synchronous operations and a +:meth:`~Connection.send` call, you must yield control explicitly with +:func:`asyncio.sleep`:: async def producer(websocket): message = generate_next_message() @@ -46,16 +50,19 @@ See `issue 867`_. Why am I having problems with threads? -------------------------------------- -If you choose websockets' default implementation based on :mod:`asyncio`, then -you shouldn't use threads. Indeed, choosing :mod:`asyncio` to handle concurrency -is mutually exclusive with :mod:`threading`. +If you choose websockets' :mod:`asyncio` implementation, then you shouldn't use +threads. Indeed, choosing :mod:`asyncio` to handle concurrency is mutually +exclusive with :mod:`threading`. If you believe that you need to run websockets in a thread and some logic in another thread, you should run that logic in a :class:`~asyncio.Task` instead. -If it blocks the event loop, :meth:`~asyncio.loop.run_in_executor` will help. -This question is really about :mod:`asyncio`. Please review the advice about -:ref:`asyncio-multithreading` in the Python documentation. +If it has to run in another thread because it would block the event loop, +:func:`~asyncio.to_thread` or :meth:`~asyncio.loop.run_in_executor` is the way +to go. + +Please review the advice about :ref:`asyncio-multithreading` in the Python +documentation. Why does my simple program misbehave mysteriously? -------------------------------------------------- @@ -63,7 +70,6 @@ Why does my simple program misbehave mysteriously? You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which blocks the event loop and prevents asyncio from operating normally. -This may lead to messages getting send but not received, to connection -timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -makes the program functional. +This may lead to messages getting send but not received, to connection timeouts, +and to unexpected results of shotgun debugging e.g. adding an unnecessary call +to a coroutine makes the program functional. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c590ac107..0dfc84253 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -1,7 +1,16 @@ Client ====== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.client + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. + + They translate to the :mod:`threading` implementation by removing ``await`` + and ``async`` keywords and by using a :class:`~threading.Thread` instead of + a :class:`~asyncio.Task` for concurrent execution. Why does the client close the connection prematurely? ----------------------------------------------------- @@ -22,46 +31,47 @@ change it to:: How do I access HTTP headers? ----------------------------- -Once the connection is established, HTTP headers are available in -:attr:`~client.WebSocketClientProtocol.request_headers` and -:attr:`~client.WebSocketClientProtocol.response_headers`. +Once the connection is established, HTTP headers are available in the +:attr:`~ClientConnection.request` and :attr:`~ClientConnection.response` +objects:: + + async with connect(...) as websocket: + websocket.request.headers + websocket.response.headers How do I set HTTP headers? -------------------------- To set the ``Origin``, ``Sec-WebSocket-Extensions``, or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the -``origin``, ``extensions``, or ``subprotocols`` arguments of -:func:`~client.connect`. +``origin``, ``extensions``, or ``subprotocols`` arguments of :func:`~connect`. To override the ``User-Agent`` header, use the ``user_agent_header`` argument. Set it to :obj:`None` to remove the header. To set other HTTP headers, for example the ``Authorization`` header, use the -``extra_headers`` argument:: +``additional_headers`` argument:: - async with connect(..., extra_headers={"Authorization": ...}) as websocket: + async with connect(..., additional_headers={"Authorization": ...}) as websocket: ... -In the :mod:`threading` API, this argument is named ``additional_headers``:: - - with connect(..., additional_headers={"Authorization": ...}) as websocket: - ... +In the legacy :mod:`asyncio` API, this argument is named ``extra_headers``. How do I force the IP address that the client connects to? ---------------------------------------------------------- -Use the ``host`` argument of :meth:`~asyncio.loop.create_connection`:: +Use the ``host`` argument :func:`~connect`:: - await websockets.connect("ws://example.com", host="192.168.0.1") + async with connect(..., host="192.168.0.1") as websocket: + ... -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. +:func:`~connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection` and passes them through. How do I close a connection? ---------------------------- -The easiest is to use :func:`~client.connect` as a context manager:: +The easiest is to use :func:`~connect` as a context manager:: async with connect(...) as websocket: ... @@ -71,9 +81,17 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -Use :func:`~client.connect` as an asynchronous iterator:: +.. admonition:: This feature is only supported by the legacy :mod:`asyncio` + implementation. + :class: warning + + It will be added to the new :mod:`asyncio` implementation soon. + +Use :func:`~websockets.legacy.client.connect` as an asynchronous iterator:: + + from websockets.legacy.client import connect - async for websocket in websockets.connect(...): + async for websocket in connect(...): try: ... except websockets.ConnectionClosed: @@ -90,12 +108,12 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_client.py - :emphasize-lines: 10-13 + :emphasize-lines: 11-13 How do I disable TLS/SSL certificate verification? -------------------------------------------------- Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. +:func:`~connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection` and passes them through. diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 84256fdfe..0dc4a3aeb 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -1,7 +1,7 @@ Both sides ========== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.connection What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -11,12 +11,6 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent @@ -25,12 +19,6 @@ or if a client crashes with this traceback: .. code-block:: pytb - Traceback (most recent call last): - ... - ConnectionResetError: [Errno 54] Connection reset by peer - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent @@ -39,8 +27,8 @@ it means that the TCP connection was lost. As a consequence, the WebSocket connection was closed without receiving and sending a close frame, which is abnormal. -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. +You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to +prevent it from being logged. There are several reasons why long-lived connections may be lost: @@ -62,12 +50,6 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received @@ -76,12 +58,6 @@ or if a client crashes with this traceback: .. code-block:: pytb - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received @@ -89,8 +65,8 @@ or if a client crashes with this traceback: it means that the WebSocket connection suffered from excessive latency and was closed after reaching the timeout of websockets' keepalive mechanism. -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. +You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to +prevent it from being logged. There are two main reasons why latency may increase: @@ -102,8 +78,8 @@ See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. -How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? --------------------------------------------------------------------------------- +How do I set a timeout on :meth:`~Connection.recv`? +--------------------------------------------------- On Python ≥ 3.11, use :func:`asyncio.timeout`:: @@ -117,23 +93,24 @@ On older versions of Python, use :func:`asyncio.wait_for`:: This technique works for most APIs. When it doesn't, for example with asynchronous context managers, websockets provides an ``open_timeout`` argument. -How can I pass arguments to a custom protocol subclass? -------------------------------------------------------- +How can I pass arguments to a custom connection subclass? +--------------------------------------------------------- -You can bind additional arguments to the protocol factory with +You can bind additional arguments to the connection factory with :func:`functools.partial`:: import asyncio import functools - import websockets + from websockets.asyncio.server import ServerConnection, serve - class MyServerProtocol(websockets.WebSocketServerProtocol): + class MyServerConnection(ServerConnection): def __init__(self, *args, extra_argument=None, **kwargs): super().__init__(*args, **kwargs) # do something with extra_argument - create_protocol = functools.partial(MyServerProtocol, extra_argument=42) - start_server = websockets.serve(..., create_protocol=create_protocol) + create_connection = functools.partial(ServerConnection, extra_argument=42) + async with serve(..., create_connection=create_connection): + ... This example was for a server. The same pattern applies on a client. diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 0e74a784f..4936aa6f3 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -13,27 +13,12 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. -.. _real-import-paths: - -Why is the default implementation located in ``websockets.legacy``? -................................................................... - -This is an artifact of websockets' history. For its first eight years, only the -:mod:`asyncio` implementation existed. Then, the Sans-I/O implementation was -added. Moving the code in a ``legacy`` submodule eased this refactoring and -optimized maintainability. - -All public APIs were kept at their original locations. ``websockets.legacy`` -isn't a public API. It's only visible in the source code and in stack traces. -There is no intent to deprecate this implementation — at least until a superior -alternative exists. - Why is websockets slower than another library in my benchmark? .............................................................. Not all libraries are as feature-complete as websockets. For a fair benchmark, you should disable features that the other library doesn't provide. Typically, -you may need to disable: +you must disable: * Compression: set ``compression=None`` * Keepalive: set ``ping_interval=None`` diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 53e34632f..e6b068316 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -1,7 +1,16 @@ Server ====== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.server + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. + + They translate to the :mod:`threading` implementation by removing ``await`` + and ``async`` keywords and by using a :class:`~threading.Thread` instead of + a :class:`~asyncio.Task` for concurrent execution. Why does the server close the connection prematurely? ----------------------------------------------------- @@ -36,8 +45,13 @@ change it like this:: async for message in websocket: print(message) -*Don't feel bad if this happens to you — it's the most common question in -websockets' issue tracker :-)* +If you have prior experience with an API that relies on callbacks, you may +assume that ``handler()`` is executed every time a message is received. The API +of websockets relies on coroutines instead. + +The handler coroutine is started when a new connection is established. Then, it +is responsible for receiving or sending messages throughout the lifetime of that +connection. Why can only one client connect at a time? ------------------------------------------ @@ -69,9 +83,9 @@ continuously:: while True: await websocket.send("firehose!") -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` completes synchronously as -long as there's space in send buffers. The event loop never runs. (This pattern -is uncommon in real-world applications. It occurs mostly in toy programs.) +:meth:`~ServerConnection.send` completes synchronously as long as there's space +in send buffers. The event loop never runs. (This pattern is uncommon in +real-world applications. It occurs mostly in toy programs.) You can avoid the issue by yielding control to the event loop explicitly:: @@ -102,12 +116,12 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~asyncio.connection.broadcast`:: +Then, call :func:`~websockets.asyncio.connection.broadcast`:: - import websockets + from websockets.asyncio.connection import broadcast def message_all(message): - websockets.broadcast(CONNECTIONS, message) + broadcast(CONNECTIONS, message) If you're running multiple server processes, make sure you call ``message_all`` in each process. @@ -129,7 +143,7 @@ Record connections in a global variable, keyed by user identifier:: finally: del CONNECTIONS[user_id] -Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: +Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected @@ -178,15 +192,12 @@ How do I pass arguments to the connection handler? You can bind additional arguments to the connection handler with :func:`functools.partial`:: - import asyncio import functools - import websockets async def handler(websocket, extra_argument): ... bound_handler = functools.partial(handler, extra_argument=42) - start_server = websockets.serve(bound_handler, ...) Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it @@ -195,14 +206,14 @@ through an argument. How do I access the request path? --------------------------------- -It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. +It is available in the :attr:`~ServerConnection.request` object. You may route a connection to different handlers depending on the request path:: async def handler(websocket): - if websocket.path == "/blue": + if websocket.request.path == "/blue": await blue_handler(websocket) - elif websocket.path == "/green": + elif websocket.request.path == "/green": await green_handler(websocket) else: # No handler for this path; close the connection. @@ -219,35 +230,46 @@ it may ignore the request path entirely. How do I access HTTP headers? ----------------------------- -To access HTTP headers during the WebSocket handshake, you can override -:attr:`~server.WebSocketServerProtocol.process_request`:: +You can access HTTP headers during the WebSocket handshake by providing a +``process_request`` callable or coroutine:: - async def process_request(self, path, request_headers): - authorization = request_headers["Authorization"] + def process_request(connection, request): + authorization = request.headers["Authorization"] + ... + + async with serve(handler, process_request=process_request): + ... -Once the connection is established, HTTP headers are available in -:attr:`~server.WebSocketServerProtocol.request_headers` and -:attr:`~server.WebSocketServerProtocol.response_headers`:: +Once the connection is established, HTTP headers are available in the +:attr:`~ServerConnection.request` and :attr:`~ServerConnection.response` +objects:: async def handler(websocket): - authorization = websocket.request_headers["Authorization"] + authorization = websocket.request.headers["Authorization"] How do I set HTTP headers? -------------------------- To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` -arguments of :func:`~server.serve`. +arguments of :func:`~serve`. To override the ``Server`` header, use the ``server_header`` argument. Set it to :obj:`None` to remove the header. -To set other HTTP headers, use the ``extra_headers`` argument. +To set other HTTP headers, provide a ``process_response`` callable or +coroutine:: + + def process_response(connection, request, response): + response.headers["X-Blessing"] = "May the network be with you" + + async with serve(handler, process_response=process_response): + ... How do I get the IP address of the client? ------------------------------------------ -It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: +It's available in :attr:`~ServerConnection.remote_address`:: async def handler(websocket): remote_ip = websocket.remote_address[0] @@ -255,18 +277,19 @@ It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address How do I set the IP addresses that my server listens on? -------------------------------------------------------- -Use the ``host`` argument of :meth:`~asyncio.loop.create_server`:: +Use the ``host`` argument of :meth:`~serve`:: - await websockets.serve(handler, host="192.168.0.1", port=8080) + async with serve(handler, host="192.168.0.1", port=8080): + ... -:func:`~server.serve` accepts the same arguments as -:meth:`~asyncio.loop.create_server`. +:func:`~serve` accepts the same arguments as +:meth:`~asyncio.loop.create_server` and passes them through. What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? -------------------------------------------------------------------------------------------------------------------------- -You are calling :func:`~server.serve` without a ``host`` argument in a context -where IPv6 isn't available. +You are calling :func:`~serve` without a ``host`` argument in a context where +IPv6 isn't available. To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. @@ -280,17 +303,17 @@ websockets takes care of closing the connection when the handler exits. How do I stop a server? ----------------------- -Exit the :func:`~server.serve` context manager. +Exit the :func:`~serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 12-15,18 + :emphasize-lines: 13-16,19 How do I stop a server while keeping existing connections open? --------------------------------------------------------------- -Call the server's :meth:`~server.WebSocketServer.close` method with +Call the server's :meth:`~WebSocketServer.close` method with ``close_connections=False``. Here's how to adapt the example just above:: @@ -298,7 +321,7 @@ Here's how to adapt the example just above:: async def server(): ... - server = await websockets.serve(echo, "localhost", 8765) + server = await serve(echo, "localhost", 8765) await stop server.close(close_connections=False) await server.wait_closed() @@ -306,14 +329,14 @@ Here's how to adapt the example just above:: How do I implement a health check? ---------------------------------- -Intercept WebSocket handshake requests with the -:meth:`~server.WebSocketServerProtocol.process_request` hook. - -When a request is sent to the health check endpoint, treat is as an HTTP request -and return a ``(status, headers, body)`` tuple, as in this example: +Intercept requests with the ``process_request`` hook. When a request is sent to +the health check endpoint, treat is as an HTTP request and return a response: .. literalinclude:: ../../example/faq/health_check_server.py - :emphasize-lines: 7-9,18 + :emphasize-lines: 7-9,16 + +:meth:`~ServerConnection.respond` makes it easy to send a plain text response. +You can also construct a :class:`~websockets.http11.Response` object directly. How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- @@ -327,7 +350,7 @@ Providing an HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the -:attr:`~server.WebSocketServerProtocol.process_request` hook. +``process_request`` hook. If you need more, pick an HTTP server and run it separately. diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index 95b551f67..8df2f234b 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -9,24 +9,24 @@ Server * Write a coroutine that handles a single connection. It receives a WebSocket protocol instance and the URI path in argument. - * Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send - messages at any time. + * Call :meth:`~asyncio.connection.Connection.recv` and + :meth:`~asyncio.connection.Connection.send` to receive and send messages at + any time. - * When :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises - :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started - other :class:`asyncio.Task`, terminate them before exiting. + * When :meth:`~asyncio.connection.Connection.recv` or + :meth:`~asyncio.connection.Connection.send` raises + :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started other + :class:`asyncio.Task`, terminate them before exiting. - * If you aren't awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, - consider awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` - to detect quickly when the connection is closed. + * If you aren't awaiting :meth:`~asyncio.connection.Connection.recv`, consider + awaiting :meth:`~asyncio.connection.Connection.wait_closed` to detect + quickly when the connection is closed. - * You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't - needed in general. + * You may :meth:`~asyncio.connection.Connection.ping` or + :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed + in general. -* Create a server with :func:`~server.serve` which is similar to asyncio's +* Create a server with :func:`~asyncio.server.serve` which is similar to asyncio's :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous context manager. @@ -35,30 +35,30 @@ Server handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~server.WebSocketServerProtocol` and pass either this subclass or - a factory function as the ``create_protocol`` argument. + :class:`~asyncio.server.ServerConnection` and pass either this subclass or a + factory function as the ``create_connection`` argument. Client ------ -* Create a client with :func:`~client.connect` which is similar to asyncio's - :meth:`~asyncio.loop.create_connection`. You can also use it as an +* Create a client with :func:`~asyncio.client.connect` which is similar to + asyncio's :meth:`~asyncio.loop.create_connection`. You can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~client.WebSocketClientProtocol` and pass either this subclass or - a factory function as the ``create_protocol`` argument. + :class:`~asyncio.client.ClientConnection` and pass either this subclass or + a factory function as the ``create_connection`` argument. -* Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send messages - at any time. +* Call :meth:`~asyncio.connection.Connection.recv` and + :meth:`~asyncio.connection.Connection.send` to receive and send messages at + any time. -* You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't - needed in general. +* You may :meth:`~asyncio.connection.Connection.ping` or + :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed in + general. -* If you aren't using :func:`~client.connect` as a context manager, call - :meth:`~legacy.protocol.WebSocketCommonProtocol.close` to terminate the connection. +* If you aren't using :func:`~asyncio.client.connect` as a context manager, call + :meth:`~asyncio.connection.Connection.close` to terminate the connection. .. _debugging: @@ -84,4 +84,3 @@ particular. Fortunately Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html - diff --git a/docs/howto/django.rst b/docs/howto/django.rst index e3da0a878..dada9c5e4 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -124,7 +124,7 @@ support asynchronous I/O. It would block the event loop if it didn't run in a separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -Finally, we start a server with :func:`~websockets.server.serve`. +Finally, we start a server with :func:`~websockets.asyncio.server.serve`. We're ready to test! diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index a97d2e7ce..b335e14c5 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -42,7 +42,7 @@ Here's the implementation of the app, an echo server. Save it in a file called Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to -:func:`~websockets.server.serve`. +:func:`~websockets.asyncio.server.serve`. .. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index ff42c3c2b..872353cad 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -21,9 +21,9 @@ We'd like nginx to connect to websockets servers via Unix sockets in order to avoid the overhead of TCP for communicating between processes running in the same OS. -We start the app with :func:`~websockets.server.unix_serve`. Each server -process listens on a different socket thanks to an environment variable set -by Supervisor to a different value. +We start the app with :func:`~websockets.asyncio.server.unix_serve`. Each server +process listens on a different socket thanks to an environment variable set by +Supervisor to a different value. Save this configuration to ``supervisord.conf``: diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index c6f325d21..60bc8ab42 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -8,7 +8,7 @@ client. You will certainly implement some of them in your application. This page gives examples of connection handlers for a server. However, they're also applicable to a client, simply by assuming that ``websocket`` is a -connection created with :func:`~client.connect`. +connection created with :func:`~asyncio.client.connect`. WebSocket connections are long-lived. You will usually write a loop to process several messages during the lifetime of a connection. @@ -42,10 +42,10 @@ In this example, ``producer()`` is a coroutine implementing your business logic for generating the next message to send on the WebSocket connection. Each message must be :class:`str` or :class:`bytes`. -Iteration terminates when the client disconnects -because :meth:`~server.WebSocketServerProtocol.send` raises a -:exc:`~exceptions.ConnectionClosed` exception, -which breaks out of the ``while True`` loop. +Iteration terminates when the client disconnects because +:meth:`~asyncio.server.ServerConnection.send` raises a +:exc:`~exceptions.ConnectionClosed` exception, which breaks out of the ``while +True`` loop. Consumer and producer --------------------- diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst index ab870952c..e6bd362a4 100644 --- a/docs/howto/quickstart.rst +++ b/docs/howto/quickstart.rst @@ -17,9 +17,9 @@ It receives a name from the client, sends a greeting, and closes the connection. :language: python :linenos: -:func:`~server.serve` executes the connection handler coroutine ``hello()`` -once for each WebSocket connection. It closes the WebSocket connection when -the handler returns. +:func:`~asyncio.server.serve` executes the connection handler coroutine +``hello()`` once for each WebSocket connection. It closes the WebSocket +connection when the handler returns. Here's a corresponding WebSocket client. @@ -30,8 +30,8 @@ It sends a name to the server, receives a greeting, and closes the connection. :language: python :linenos: -Using :func:`~client.connect` as an asynchronous context manager ensures the -WebSocket connection is closed. +Using :func:`~asyncio.client.connect` as an asynchronous context manager ensures +the WebSocket connection is closed. .. _secure-server-example: @@ -73,8 +73,8 @@ In this example, the client needs a TLS context because the server uses a self-signed certificate. When connecting to a secure WebSocket server with a valid certificate — any -certificate signed by a CA that your Python installation trusts — you can -simply pass ``ssl=True`` to :func:`~client.connect`. +certificate signed by a CA that your Python installation trusts — you can simply +pass ``ssl=True`` to :func:`~asyncio.client.connect`. .. admonition:: Configure the TLS context securely :class: attention diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 16b010aca..40c8c5ec9 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -49,12 +49,13 @@ the release notes of the version in which the feature was deprecated. * The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` and deprecated in :ref:`13.0`. -* The ``loop`` and ``legacy_recv`` arguments of :func:`~client.connect` and - :func:`~server.serve`, which were removed — deprecated in :ref:`10.0`. -* The ``timeout`` and ``klass`` arguments of :func:`~client.connect` and - :func:`~server.serve`, which were renamed to ``close_timeout`` and +* The ``loop`` and ``legacy_recv`` arguments of :func:`~legacy.client.connect` + and :func:`~legacy.server.serve`, which were removed — deprecated in + :ref:`10.0`. +* The ``timeout`` and ``klass`` arguments of :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`, which were renamed to ``close_timeout`` and ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. -* An empty string in the ``origins`` argument of :func:`~server.serve` — +* An empty string in the ``origins`` argument of :func:`~legacy.server.serve` — deprecated in :ref:`7.0`. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. @@ -127,16 +128,16 @@ Client APIs | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| :func:`websockets.client.connect` |br| | | -| ``websockets.legacy.client.connect()`` | | +| ``websockets.client.connect()`` |br| | | +| :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| :func:`websockets.client.unix_connect` |br| | | -| ``websockets.legacy.client.unix_connect()`` | | +| ``websockets.client.unix_connect()`` |br| | | +| :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | -| :class:`websockets.client.WebSocketClientProtocol` |br| | | -| ``websockets.legacy.client.WebSocketClientProtocol`` | | +| ``websockets.client.WebSocketClientProtocol`` |br| | | +| :class:`websockets.legacy.client.WebSocketClientProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ Server APIs @@ -146,31 +147,31 @@ Server APIs | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| :func:`websockets.server.serve` |br| | | -| ``websockets.legacy.server.serve()`` | | +| ``websockets.server.serve()`` |br| | | +| :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| :func:`websockets.server.unix_serve` |br| | | -| ``websockets.legacy.server.unix_serve()`` | | +| ``websockets.server.unix_serve()`` |br| | | +| :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | -| :class:`websockets.server.WebSocketServer` |br| | | -| ``websockets.legacy.server.WebSocketServer`` | | +| ``websockets.server.WebSocketServer`` |br| | | +| :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | -| :class:`websockets.server.WebSocketServerProtocol` |br| | | -| ``websockets.legacy.server.WebSocketServerProtocol`` | | +| ``websockets.server.WebSocketServerProtocol`` |br| | | +| :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | | :func:`websockets.legacy.protocol.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | -| :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | -| ``websockets.legacy.auth.BasicAuthWebSocketServerProtocol`` | | +| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | +| :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | -| :func:`websockets.auth.basic_auth_protocol_factory` |br| | | -| ``websockets.legacy.auth.basic_auth_protocol_factory()`` | | +| ``websockets.auth.basic_auth_protocol_factory()`` |br| | | +| :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ .. _Review API changes: @@ -209,12 +210,12 @@ Customizing the opening handshake ................................. On the client side, if you're adding headers to the handshake request sent by -:func:`~client.connect` with the ``extra_headers`` argument, you must rename it -to ``additional_headers``. +:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must +rename it to ``additional_headers``. -On the server side, if you're customizing how :func:`~server.serve` processes -the opening handshake with the ``process_request``, ``extra_headers``, or -``select_subprotocol``, you must update your code. ``process_response`` and +On the server side, if you're customizing how :func:`~legacy.server.serve` +processes the opening handshake with the ``process_request``, ``extra_headers``, +or ``select_subprotocol``, you must update your code. ``process_response`` and ``select_subprotocol`` have new signatures; ``process_response`` replaces ``extra_headers`` and provides more flexibility. @@ -242,10 +243,10 @@ an example:: ``connection`` is always available in ``process_request``. In the original implementation, you had to write a subclass of -:class:`~server.WebSocketServerProtocol` and pass it in the ``create_protocol`` -argument to make the connection object available in a ``process_request`` -method. This pattern isn't useful anymore; you can replace it with a -``process_request`` function or coroutine. +:class:`~legacy.server.WebSocketServerProtocol` and pass it in the +``create_protocol`` argument to make the connection object available in a +``process_request`` method. This pattern isn't useful anymore; you can replace +it with a ``process_request`` function or coroutine. ``path`` and ``headers`` are available as attributes of the ``request`` object. @@ -296,7 +297,7 @@ The signature of ``select_subprotocol`` changed. Here's an example:: ``connection`` is always available in ``select_subprotocol``. This brings the same benefits as in ``process_request``. It may remove the need to subclass of -:class:`~server.WebSocketServerProtocol`. +:class:`~legacy.server.WebSocketServerProtocol`. The ``subprotocols`` argument contains the list of subprotocols offered by the client. The list of subprotocols supported by the server was removed because @@ -320,7 +321,7 @@ update its name. The keyword argument of :func:`~asyncio.server.serve` for customizing the creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~server.WebSocketServerProtocol`. +instead of a :class:`~legacy.server.WebSocketServerProtocol`. If you were customizing connection objects, you should check the new implementation and possibly redo your customization. Keep in mind that the @@ -364,6 +365,28 @@ The ``write_limit`` argument of :func:`~asyncio.client.connect` and Attributes of connections ......................... +``path``, ``request_headers`` and ``response_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, +:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are +replaced by :attr:`~asyncio.connection.Connection.request` and +:attr:`~asyncio.connection.Connection.response`, which provide a ``headers`` +attribute. + +If your code relies on them, you can replace:: + + connection.path + connection.request_headers + connection.response_headers + +with:: + + connection.request.path + connection.request.headers + connection.response.headers + ``open`` and ``closed`` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6b32d47f6..74f5f79a3 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -184,7 +184,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: import asyncio - import websockets + from websockets.asyncio.server import serve async def handler(websocket): @@ -194,7 +194,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever @@ -204,8 +204,9 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: The entry point of this program is ``asyncio.run(main())``. It creates an asyncio event loop, runs the ``main()`` coroutine, and shuts down the loop. -The ``main()`` coroutine calls :func:`~server.serve` to start a websockets -server. :func:`~server.serve` takes three positional arguments: +The ``main()`` coroutine calls :func:`~asyncio.server.serve` to start a +websockets server. :func:`~asyncio.server.serve` takes three positional +arguments: * ``handler`` is a coroutine that manages a connection. When a client connects, websockets calls ``handler`` with the connection in argument. @@ -215,7 +216,7 @@ server. :func:`~server.serve` takes three positional arguments: on the same local network can connect. * The third argument is the port on which the server listens. -Invoking :func:`~server.serve` as an asynchronous context manager, in an +Invoking :func:`~asyncio.server.serve` as an asynchronous context manager, in an ``async with`` block, ensures that the server shuts down properly when terminating the program. @@ -258,11 +259,11 @@ stack trace of an exception: ... websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK) -Indeed, the server was waiting for the next message -with :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` when the client -disconnected. When this happens, websockets raises -a :exc:`~exceptions.ConnectionClosedOK` exception to let you know that you -won't receive another message on this connection. +Indeed, the server was waiting for the next message with +:meth:`~asyncio.server.ServerConnection.recv` when the client disconnected. +When this happens, websockets raises a :exc:`~exceptions.ConnectionClosedOK` +exception to let you know that you won't receive another message on this +connection. This exception creates noise in the server logs, making it more difficult to spot real errors when you add functionality to the server. Catch it in the @@ -551,13 +552,12 @@ Summary In this first part of the tutorial, you learned how to: -* build and run a WebSocket server in Python with :func:`~server.serve`; -* receive a message in a connection handler - with :meth:`~server.WebSocketServerProtocol.recv`; -* send a message in a connection handler - with :meth:`~server.WebSocketServerProtocol.send`; -* iterate over incoming messages with ``async for - message in websocket: ...``; +* build and run a WebSocket server in Python with :func:`~asyncio.server.serve`; +* receive a message in a connection handler with + :meth:`~asyncio.server.ServerConnection.recv`; +* send a message in a connection handler with + :meth:`~asyncio.server.ServerConnection.send`; +* iterate over incoming messages with ``async for message in websocket: ...``; * open a WebSocket connection in JavaScript with the ``WebSocket`` API; * send messages in a browser with ``WebSocket.send()``; * receive messages in a browser by listening to ``message`` events; diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 6fdec113b..21d51371b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -93,9 +93,9 @@ called ``stop`` and registers a signal handler that sets the result of this future. The value of the future doesn't matter; it's only for waiting for ``SIGTERM``. -Then, by using :func:`~server.serve` as a context manager and exiting the -context when ``stop`` has a result, ``main()`` ensures that the server closes -connections cleanly and exits on ``SIGTERM``. +Then, by using :func:`~asyncio.server.serve` as a context manager and exiting +the context when ``stop`` has a result, ``main()`` ensures that the server +closes connections cleanly and exits on ``SIGTERM``. The app is now fully compatible with Heroku. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index eaabb2e9f..df5af54f4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -178,8 +178,8 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. -.. admonition:: :func:`~server.serve` times out on the opening handshake after - 10 seconds by default. +.. admonition:: :func:`~legacy.server.serve` times out on the opening handshake + after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -200,7 +200,7 @@ New features See :func:`websockets.sync.client.connect` and :func:`websockets.sync.server.serve` for details. -* Added ``open_timeout`` to :func:`~server.serve`. +* Added ``open_timeout`` to :func:`~legacy.server.serve`. * Made it possible to close a server without closing existing connections. @@ -289,7 +289,7 @@ Bug fixes * Fixed backwards-incompatibility in 10.1 for connection handlers created with :func:`functools.partial`. -* Avoided leaking open sockets when :func:`~client.connect` is canceled. +* Avoided leaking open sockets when :func:`~legacy.client.connect` is canceled. .. _10.1: @@ -330,8 +330,8 @@ Improvements .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 -* Mirrored the entire :class:`~asyncio.Server` API - in :class:`~server.WebSocketServer`. +* Mirrored the entire :class:`~asyncio.Server` API in + :class:`~legacy.server.WebSocketServer`. * Improved performance for large messages on ARM processors. @@ -364,9 +364,9 @@ Backwards-incompatible changes Python 3.10 for details. The ``loop`` parameter is also removed - from :class:`~server.WebSocketServer`. This should be transparent. + from :class:`~legacy.server.WebSocketServer`. This should be transparent. -.. admonition:: :func:`~client.connect` times out after 10 seconds by default. +.. admonition:: :func:`~legacy.client.connect` times out after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -405,9 +405,9 @@ New features * Added :func:`~legacy.protocol.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using - :func:`~client.connect` as an asynchronous iterator. + :func:`~legacy.client.connect` as an asynchronous iterator. -* Added ``open_timeout`` to :func:`~client.connect`. +* Added ``open_timeout`` to :func:`~legacy.client.connect`. * Documented how to integrate with `Django `_. @@ -427,12 +427,12 @@ Improvements * Optimized processing of client-to-server messages when the C extension isn't available. -* Supported relative redirects in :func:`~client.connect`. +* Supported relative redirects in :func:`~legacy.client.connect`. * Handled TCP connection drops during the opening handshake. * Made it easier to customize authentication with - :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. + :meth:`~legacy.auth.BasicAuthWebSocketServerProtocol.check_credentials`. * Provided additional information in :exc:`~exceptions.ConnectionClosed` exceptions. @@ -590,7 +590,7 @@ Bug fixes ......... * Restored the ability to pass a socket with the ``sock`` parameter of - :func:`~server.serve`. + :func:`~legacy.server.serve`. * Removed an incorrect assertion when a connection drops. @@ -623,11 +623,12 @@ Backwards-incompatible changes .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note - If you're passing a ``process_request`` argument to :func:`~server.serve` - or :class:`~server.WebSocketServerProtocol`, or if you're overriding - :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. Previously, both were - supported. + If you're passing a ``process_request`` argument to + :func:`~legacy.server.serve` or + :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding + :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a + subclass, define it with ``async def`` instead of ``def``. Previously, both + were supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. @@ -661,15 +662,15 @@ Backwards-incompatible changes New features ............ -* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP - Basic Auth on the server side. +* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP Basic + Auth on the server side. -* :func:`~client.connect` handles redirects from the server during the +* :func:`~legacy.client.connect` handles redirects from the server during the handshake. -* :func:`~client.connect` supports overriding ``host`` and ``port``. +* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. -* Added :func:`~client.unix_connect` for connecting to Unix sockets. +* Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. * Added support for asynchronous generators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send` @@ -699,9 +700,8 @@ Improvements :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake - instead of failing the connection. +* Changed :meth:`WebSocketServer.close() ` + to perform a proper closing handshake instead of failing the connection. * Improved error messages when HTTP parsing fails. @@ -734,7 +734,7 @@ Backwards-incompatible changes See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. .. admonition:: Termination of connections by :meth:`WebSocketServer.close() - ` changes. + ` changes. :class: caution Previously, connections handlers were canceled. Now, connections are @@ -758,15 +758,16 @@ Backwards-incompatible changes Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. -.. admonition:: The ``timeout`` argument of :func:`~server.serve` - and :func:`~client.connect` is renamed to ``close_timeout`` . +.. admonition:: The ``timeout`` argument of :func:`~legacy.server.serve` + and :func:`~legacy.client.connect` is renamed to ``close_timeout`` . :class: note This prevents confusion with ``ping_timeout``. For backwards compatibility, ``timeout`` is still supported. -.. admonition:: The ``origins`` argument of :func:`~server.serve` changes. +.. admonition:: The ``origins`` argument of :func:`~legacy.server.serve` + changes. :class: note Include :obj:`None` in the list rather than ``''`` to allow requests that @@ -786,10 +787,10 @@ New features ............ * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~server.serve` and - :class:`~server.WebSocketServerProtocol` to facilitate customization of - :meth:`~server.WebSocketServerProtocol.process_request` and - :meth:`~server.WebSocketServerProtocol.select_subprotocol`. + :func:`~legacy.server.serve` and + :class:`~legacy.server.WebSocketServerProtocol` to facilitate customization of + :meth:`~legacy.server.WebSocketServerProtocol.process_request` and + :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol`. * Added support for sending fragmented messages. @@ -826,10 +827,10 @@ Backwards-incompatible changes several APIs are updated to use it. :class: caution - * The ``request_headers`` argument - of :meth:`~server.WebSocketServerProtocol.process_request` is now - a :class:`~datastructures.Headers` instead of - an ``http.client.HTTPMessage``. + * The ``request_headers`` argument of + :meth:`~legacy.server.WebSocketServerProtocol.process_request` is now a + :class:`~datastructures.Headers` instead of an + ``http.client.HTTPMessage``. * The ``request_headers`` and ``response_headers`` attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are now @@ -866,7 +867,7 @@ Bug fixes ......... * Fixed a regression in 5.0 that broke some invocations of - :func:`~server.serve` and :func:`~client.connect`. + :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. .. _5.0: @@ -900,10 +901,10 @@ Backwards-incompatible changes New features ............ -* :func:`~client.connect` performs HTTP Basic Auth when the URI contains +* :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains credentials. -* :func:`~server.unix_serve` can be used as an asynchronous context +* :func:`~legacy.server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property @@ -979,7 +980,7 @@ Backwards-incompatible changes Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~server.serve` or :func:`~client.connect`. + :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. .. admonition:: The ``state_name`` attribute of protocols is deprecated. :class: note @@ -992,10 +993,10 @@ New features * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~server.unix_serve` for listening on Unix sockets. +* Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~server.WebSocketServer.sockets` attribute to the - return value of :func:`~server.serve`. +* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the + return value of :func:`~legacy.server.serve`. * Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. @@ -1030,24 +1031,24 @@ Backwards-incompatible changes by :class:`~exceptions.InvalidStatusCode`. :class: note - This exception is raised when :func:`~client.connect` receives an invalid + This exception is raised when :func:`~legacy.client.connect` receives an invalid response status code from the server. New features ............ -* :func:`~server.serve` can be used as an asynchronous context manager +* :func:`~legacy.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~server.WebSocketServerProtocol.process_request`. + :meth:`~legacy.server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. Improvements ............ -* Renamed :func:`~server.serve` and :func:`~client.connect`'s +* Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. @@ -1058,7 +1059,7 @@ Improvements Bug fixes ......... -* Providing a ``sock`` argument to :func:`~client.connect` no longer +* Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. .. _3.3: @@ -1094,7 +1095,7 @@ New features ............ * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~client.connect` and :func:`~server.serve`. + :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. Improvements ............ @@ -1151,15 +1152,15 @@ Backwards-incompatible changes In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~server.serve`, :func:`~client.connect`, - :class:`~server.WebSocketServerProtocol`, or - :class:`~client.WebSocketClientProtocol`. + :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, + :class:`~legacy.server.WebSocketServerProtocol`, or + :class:`~legacy.client.WebSocketClientProtocol`. New features ............ -* :func:`~client.connect` can be used as an asynchronous context - manager on Python ≥ 3.5.1. +* :func:`~legacy.client.connect` can be used as an asynchronous context manager + on Python ≥ 3.5.1. * :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as @@ -1260,8 +1261,8 @@ New features * Added support for subprotocols. -* Added ``loop`` argument to :func:`~client.connect` and - :func:`~server.serve`. +* Added ``loop`` argument to :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`. .. _2.3: diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index f9ce2f2d8..77a3c5d53 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,24 +1,28 @@ -Client (legacy :mod:`asyncio`) -============================== +Client (new :mod:`asyncio`) +=========================== -.. automodule:: websockets.client +.. automodule:: websockets.asyncio.client Opening a connection -------------------- -.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: connect :async: -.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_connect :async: Using a connection ------------------ -.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -39,26 +43,15 @@ Using a connection .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: request - .. autoproperty:: close_code + .. autoattribute:: response - .. autoproperty:: close_reason + .. autoproperty:: subprotocol diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index aee774479..a58325fb9 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,14 +1,18 @@ :orphan: -Both sides (legacy :mod:`asyncio`) -================================== +Both sides (new :mod:`asyncio`) +=============================== -.. automodule:: websockets.legacy.protocol +.. automodule:: websockets.asyncio.connection -.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: Connection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -29,26 +33,15 @@ Both sides (legacy :mod:`asyncio`) .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: request - .. autoproperty:: close_code + .. autoattribute:: response - .. autoproperty:: close_reason + .. autoproperty:: subprotocol diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 3636f0b33..7bceca5a0 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,19 +1,19 @@ -Server (legacy :mod:`asyncio`) -============================== +Server (new :mod:`asyncio`) +=========================== -.. automodule:: websockets.server +.. automodule:: websockets.asyncio.server -Starting a server +Creating a server ----------------- -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve :async: -Stopping a server ------------------ +Running a server +---------------- .. autoclass:: WebSocketServer @@ -34,10 +34,14 @@ Stopping a server Using a connection ------------------ -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -48,11 +52,7 @@ Using a connection .. automethod:: pong - You can customize the opening handshake in a subclass by overriding these methods: - - .. automethod:: process_request - - .. automethod:: select_subprotocol + .. automethod:: respond WebSocket connection objects also provide these attributes: @@ -64,50 +64,20 @@ Using a connection .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - -Basic authentication --------------------- - -.. automodule:: websockets.auth - -websockets supports HTTP Basic Authentication according to -:rfc:`7235` and :rfc:`7617`. - -.. autofunction:: basic_auth_protocol_factory - -.. autoclass:: BasicAuthWebSocketServerProtocol - - .. autoattribute:: realm + .. autoattribute:: request - .. autoattribute:: username + .. autoattribute:: response - .. automethod:: check_credentials + .. autoproperty:: subprotocol Broadcast --------- -.. autofunction:: websockets.legacy.protocol.broadcast +.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 45fa79c48..6840fe15b 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -182,5 +182,6 @@ Request if it is missing or invalid (`#1246`). The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -mandated by :rfc:`6455`, section 4.1. However, :func:`~client.connect()` isn't -the right layer for enforcing this constraint. It's the caller's responsibility. +mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` +isn't the right layer for enforcing this constraint. It's the caller's +responsibility. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index d3a0e935c..77b538b78 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features +:mod:`asyncio` (new) +-------------------- -:mod:`asyncio` --------------- +It's ideal for servers that handle many clients concurrently. -This is the default implementation. It's ideal for servers that handle many -clients concurrently. +It's a rewrite of the legacy :mod:`asyncio` implementation. .. toctree:: :titlesonly: @@ -26,17 +26,16 @@ clients concurrently. asyncio/server asyncio/client -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` (legacy) +----------------------- -This is a rewrite of the :mod:`asyncio` implementation. It will become the -default in the future. +This is the historical implementation. .. toctree:: :titlesonly: - new-asyncio/server - new-asyncio/client + legacy/server + legacy/client :mod:`threading` ---------------- diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst new file mode 100644 index 000000000..fca45d218 --- /dev/null +++ b/docs/reference/legacy/client.rst @@ -0,0 +1,64 @@ +Client (legacy :mod:`asyncio`) +============================== + +.. automodule:: websockets.legacy.client + +Opening a connection +-------------------- + +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +Using a connection +------------------ + +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst new file mode 100644 index 000000000..aee774479 --- /dev/null +++ b/docs/reference/legacy/common.rst @@ -0,0 +1,54 @@ +:orphan: + +Both sides (legacy :mod:`asyncio`) +================================== + +.. automodule:: websockets.legacy.protocol + +.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst new file mode 100644 index 000000000..c2758f5a2 --- /dev/null +++ b/docs/reference/legacy/server.rst @@ -0,0 +1,113 @@ +Server (legacy :mod:`asyncio`) +============================== + +.. automodule:: websockets.legacy.server + +Starting a server +----------------- + +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +Stopping a server +----------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + You can customize the opening handshake in a subclass by overriding these methods: + + .. automethod:: process_request + + .. automethod:: select_subprotocol + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + +Basic authentication +-------------------- + +.. automodule:: websockets.legacy.auth + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: basic_auth_protocol_factory + +.. autoclass:: BasicAuthWebSocketServerProtocol + + .. autoattribute:: realm + + .. autoattribute:: username + + .. automethod:: check_credentials + +Broadcast +--------- + +.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst deleted file mode 100644 index 77a3c5d53..000000000 --- a/docs/reference/new-asyncio/client.rst +++ /dev/null @@ -1,57 +0,0 @@ -Client (new :mod:`asyncio`) -=========================== - -.. automodule:: websockets.asyncio.client - -Opening a connection --------------------- - -.. autofunction:: connect - :async: - -.. autofunction:: unix_connect - :async: - -Using a connection ------------------- - -.. autoclass:: ClientConnection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst deleted file mode 100644 index a58325fb9..000000000 --- a/docs/reference/new-asyncio/common.rst +++ /dev/null @@ -1,47 +0,0 @@ -:orphan: - -Both sides (new :mod:`asyncio`) -=============================== - -.. automodule:: websockets.asyncio.connection - -.. autoclass:: Connection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst deleted file mode 100644 index 7bceca5a0..000000000 --- a/docs/reference/new-asyncio/server.rst +++ /dev/null @@ -1,83 +0,0 @@ -Server (new :mod:`asyncio`) -=========================== - -.. automodule:: websockets.asyncio.server - -Creating a server ------------------ - -.. autofunction:: serve - :async: - -.. autofunction:: unix_serve - :async: - -Running a server ----------------- - -.. autoclass:: WebSocketServer - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: get_loop - - .. automethod:: is_serving - - .. automethod:: start_serving - - .. automethod:: serve_forever - - .. autoattribute:: sockets - -Using a connection ------------------- - -.. autoclass:: ServerConnection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - .. automethod:: respond - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol - -Broadcast ---------- - -.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 31bc8e6da..86d2e2587 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -212,7 +212,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class QueryParamProtocol(websockets.WebSocketServerProtocol): + from websockets.legacy.server import WebSocketServerProtocol + + class QueryParamProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): token = get_query_parameter(path, "token") if token is None: @@ -258,7 +260,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class CookieProtocol(websockets.WebSocketServerProtocol): + from websockets.legacy.server import WebSocketServerProtocol + + class CookieProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): # Serve iframe on non-WebSocket requests ... @@ -299,7 +303,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + from websockets.legacy.auth import BasicAuthWebSocketServerProtocol + + class UserInfoProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): if username != "token": return False @@ -328,8 +334,10 @@ To authenticate a websockets client with HTTP Basic Authentication .. code-block:: python - async with websockets.connect( - f"wss://{username}:{password}@example.com", + from websockets.legacy.client import connect + + async with connect( + f"wss://{username}:{password}@example.com" ) as websocket: ... @@ -341,7 +349,9 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - async with websockets.connect( + from websockets.legacy.client import connect + + async with connect( "wss://example.com", extra_headers={"Authorization": f"Bearer {token}"} ) as websocket: diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index 2a1fe9a78..48ef72b56 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -78,7 +78,7 @@ Option 2 almost always combines with option 3. How do I start a process? ......................... -Run a Python program that invokes :func:`~server.serve`. That's it. +Run a Python program that invokes :func:`~asyncio.server.serve`. That's it. Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're alternatives to websockets, not complements. @@ -98,18 +98,19 @@ signal and exit the server to ensure a graceful shutdown. Here's an example: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 12-15,18 + :emphasize-lines: 13-16,19 -When exiting the context manager, :func:`~server.serve` closes all connections +When exiting the context manager, :func:`~asyncio.server.serve` closes all +connections with code 1001 (going away). As a consequence: * If the connection handler is awaiting - :meth:`~server.WebSocketServerProtocol.recv`, it receives a + :meth:`~asyncio.server.ServerConnection.recv`, it receives a :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception and clean up before exiting. * Otherwise, it should be waiting on - :meth:`~server.WebSocketServerProtocol.wait_closed`, so it can receive the + :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the :exc:`~exceptions.ConnectionClosedOK` exception and exit. This example is easily adapted to handle other signals. @@ -173,7 +174,7 @@ Load balancers need a way to check whether server processes are up and running to avoid routing connections to a non-functional backend. websockets provide minimal support for responding to HTTP requests with the -:meth:`~server.WebSocketServerProtocol.process_request` hook. +``process_request`` hook. Here's an example: diff --git a/docs/topics/design.rst b/docs/topics/design.rst index cc65e6a70..d2fd18d0c 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,10 +1,11 @@ -Design -====== +Design (legacy :mod:`asyncio`) +============================== -.. currentmodule:: websockets +.. currentmodule:: websockets.legacy -This document describes the design of websockets. It assumes familiarity with -the specification of the WebSocket protocol in :rfc:`6455`. +This document describes the design of the legacy implementation of websockets. +It assumes familiarity with the specification of the WebSocket protocol in +:rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. @@ -32,21 +33,19 @@ WebSocket connections go through a trivial state machine: Transitions happen in the following places: - ``CONNECTING -> OPEN``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` which runs when - the :ref:`opening handshake ` completes and the WebSocket + :meth:`~protocol.WebSocketCommonProtocol.connection_open` which runs when the + :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with - :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP connection - is established; -- ``OPEN -> CLOSING``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.write_frame` immediately before - sending a close frame; since receiving a close frame triggers sending a - close frame, this does the right thing regardless of which side started the - :ref:`closing handshake `; also in - :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` which duplicates - a few lines of code from ``write_close_frame()`` and ``write_frame()``; -- ``* -> CLOSED``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_lost` which is always - called exactly once when the TCP connection is closed. + :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP + connection is established; +- ``OPEN -> CLOSING``: in :meth:`~protocol.WebSocketCommonProtocol.write_frame` + immediately before sending a close frame; since receiving a close frame + triggers sending a close frame, this does the right thing regardless of which + side started the :ref:`closing handshake `; also in + :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates a + few lines of code from ``write_close_frame()`` and ``write_frame()``; +- ``* -> CLOSED``: in :meth:`~protocol.WebSocketCommonProtocol.connection_lost` + which is always called exactly once when the TCP connection is closed. Coroutines .......... @@ -57,38 +56,38 @@ connection lifecycle on the client side. .. image:: lifecycle.svg :target: _images/lifecycle.svg -The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~client.connect` implicit. +The lifecycle is identical on the server side, except inversion of control makes +the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. Coroutines shown in gray manage the connection. When the opening handshake -succeeds, :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` starts -two tasks: - -- :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.transfer_data` which handles - incoming data and lets :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` - consume it. It may be canceled to terminate the connection. It never exits - with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data - transfer ` below. - -- :attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping +succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open` starts two +tasks: + +- :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs + :meth:`~protocol.WebSocketCommonProtocol.transfer_data` which handles incoming + data and lets :meth:`~protocol.WebSocketCommonProtocol.recv` consume it. It + may be canceled to terminate the connection. It never exits with an exception + other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer + ` below. + +- :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs + :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are - received. It is canceled when the connection terminates. It never exits - with an exception other than :exc:`~asyncio.CancelledError`. + received. It is canceled when the connection terminates. It never exits with + an exception other than :exc:`~asyncio.CancelledError`. -- :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.close_connection` which waits for - the data transfer to terminate, then takes care of closing the TCP - connection. It must not be canceled. It never exits with an exception. See - :ref:`connection termination ` below. +- :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs + :meth:`~protocol.WebSocketCommonProtocol.close_connection` which waits for the + data transfer to terminate, then takes care of closing the TCP connection. It + must not be canceled. It never exits with an exception. See :ref:`connection + termination ` below. -Besides, :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` starts -the same :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` when -the opening handshake fails, in order to close the TCP connection. +Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection` starts the +same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when the +opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee that websockets can terminate connections: @@ -99,11 +98,11 @@ that websockets can terminate connections: regardless of whether the connection terminates normally or abnormally. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` completes when no +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` completes when no more data will be received on the connection. Under normal circumstances, it exits after exchanging close frames. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` completes when +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` completes when the TCP connection is closed. @@ -113,10 +112,9 @@ Opening handshake ----------------- websockets performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~client.connect` executes it -before returning the protocol to the caller. On the server side, it's executed -before passing the protocol to the ``ws_handler`` coroutine handling the -connection. +connection. On the client side, :meth:`~client.connect` executes it before +returning the protocol to the caller. On the server side, it's executed before +passing the protocol to the ``ws_handler`` coroutine handling the connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — @@ -136,9 +134,9 @@ On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads an HTTP request from the network; -- calls :meth:`~server.WebSocketServerProtocol.process_request` which may - abort the WebSocket handshake and return an HTTP response instead; this - hook only makes sense on the server side; +- calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort + the WebSocket handshake and return an HTTP response instead; this hook only + makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds an HTTP response based on the above and parameters passed to @@ -178,13 +176,13 @@ differences between a server and a client: These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented -in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. +in the same class, :class:`~protocol.WebSocketCommonProtocol`. .. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 .. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 .. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 -The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which +The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the :attr:`~server.WebSocketServerProtocol` and :attr:`~client.WebSocketClientProtocol` classes. @@ -211,11 +209,11 @@ The left side of the diagram shows how websockets receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, which is started +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which is started when the WebSocket connection is established, processes this data. When it receives data frames, it reassembles fragments and puts the resulting -messages in the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue. +messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. When it encounters a control frame: @@ -227,11 +225,11 @@ When it encounters a control frame: Running this process in a task guarantees that control frames are processed promptly. Without such a task, websockets would depend on the application to drive the connection by having exactly one coroutine awaiting -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` at any time. While this -happens naturally in many use cases, it cannot be relied upon. +:meth:`~protocol.WebSocketCommonProtocol.recv` at any time. While this happens +naturally in many use cases, it cannot be relied upon. -Then :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` fetches the next message -from the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue, with some +Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message +from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some complexity added for handling backpressure and termination correctly. Sending data @@ -239,19 +237,19 @@ Sending data The right side of the diagram shows how websockets sends data. -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` writes one or several data -frames containing the message. While sending a fragmented message, concurrent -calls to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` are put on hold until -all fragments are sent. This makes concurrent calls safe. +:meth:`~protocol.WebSocketCommonProtocol.send` writes one or several data frames +containing the message. While sending a fragmented message, concurrent calls to +:meth:`~protocol.WebSocketCommonProtocol.send` are put on hold until all +fragments are sent. This makes concurrent calls safe. -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping` writes a ping frame and -yields a :class:`~asyncio.Future` which will be completed when a matching pong -frame is received. +:meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a +:class:`~asyncio.Future` which will be completed when a matching pong frame is +received. -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` writes a pong frame. +:meth:`~protocol.WebSocketCommonProtocol.pong` writes a pong frame. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` writes a close frame and -waits for the TCP connection to terminate. +:meth:`~protocol.WebSocketCommonProtocol.close` writes a close frame and waits +for the TCP connection to terminate. Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to implement flow control and provide backpressure from the TCP connection. @@ -262,17 +260,17 @@ Closing handshake ................. When the other side of the connection initiates the closing handshake, -:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives a close -frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a -close frame, and returns :obj:`None`, causing -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives a close frame +while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a close +frame, and returns :obj:`None`, causing +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with -:meth:`~legacy.protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` +:meth:`~protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, -:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives it in the +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives it in the ``CLOSING`` state and returns :obj:`None`, also causing -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close timeout, websockets :ref:`fails the connection `. @@ -289,33 +287,33 @@ Then websockets terminates the TCP connection. Connection termination ---------------------- -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which is +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which is started when the WebSocket connection is established, is responsible for eventually closing the TCP connection. -First :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` waits -for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, -which may happen as a result of: +First :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` waits for +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, which +may happen as a result of: - a successful closing handshake: as explained above, this exits the infinite - loop in :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; + loop in :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a timeout while waiting for the closing handshake to complete: this cancels - :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, - :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the connection ` with a suitable code and exits. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` is separate -from :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to make it -easier to implement the timeout on the closing handshake. Canceling -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk -of canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` -and failing to close the TCP connection, thus leaking resources. +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate from +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it easier +to implement the timeout on the closing handshake. Canceling +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk of +canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and +failing to close the TCP connection, thus leaking resources. -Then :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` cancels -:meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no -protocol compliance responsibilities. Terminating it to avoid leaking it is -the only concern. +Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels +:meth:`~protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +protocol compliance responsibilities. Terminating it to avoid leaking it is the +only concern. Terminating the TCP connection can take up to ``2 * close_timeout`` on the server side and ``3 * close_timeout`` on the client side. Clients start by @@ -335,11 +333,11 @@ If the opening handshake doesn't complete successfully, websockets fails the connection by closing the TCP connection. Once the opening handshake has completed, websockets fails the connection by -canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -and sending a close frame if appropriate. +canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +sending a close frame if appropriate. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which closes +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which closes the TCP connection. @@ -348,7 +346,7 @@ the TCP connection. Server shutdown --------------- -:class:`~websockets.server.WebSocketServer` closes asynchronously like +:class:`~server.WebSocketServer` closes asynchronously like :class:`asyncio.Server`. The shutdown happen in two steps: 1. Stop listening and accepting new connections; @@ -356,10 +354,10 @@ Server shutdown the opening handshake is still in progress, with HTTP status code 503 (Service Unavailable). -The first call to :class:`~websockets.server.WebSocketServer.close` starts a -task that performs this sequence. Further calls are ignored. This is the -easiest way to make :class:`~websockets.server.WebSocketServer.close` and -:class:`~websockets.server.WebSocketServer.wait_closed` idempotent. +The first call to :class:`~server.WebSocketServer.close` starts a task that +performs this sequence. Further calls are ignored. This is the easiest way to +make :class:`~server.WebSocketServer.close` and +:class:`~server.WebSocketServer.wait_closed` idempotent. .. _cancellation: @@ -415,45 +413,45 @@ happen on the client side. On the server side, the opening handshake is managed by websockets and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` mustn't get +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get accidentally canceled if a coroutine that awaits them is canceled. In other words, they must be shielded from cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` waits for the next message in -the queue or for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for -waiting on two futures in parallel. As a consequence, even though it's waiting -on a :class:`~asyncio.Future` signaling the next message and on -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't +:meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in the +queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to +terminate, whichever comes first. It relies on :func:`~asyncio.wait` for waiting +on two futures in parallel. As a consequence, even though it's waiting on a +:class:`~asyncio.Future` signaling the next message and on +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't propagate cancellation to them. -:meth:`~legacy.protocol.WebSocketCommonProtocol.ensure_open` is called by -:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong`. When the connection state is +:meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.ping`, and +:meth:`~protocol.WebSocketCommonProtocol.pong`. When the connection state is ``CLOSING``, it waits for -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to prevent cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` waits for the data transfer -task to terminate with :func:`~asyncio.timeout`. If it's canceled or if the -timeout elapses, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -is canceled, which is correct at this point. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` then waits for -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` but shields it +:meth:`~protocol.WebSocketCommonProtocol.close` waits for the data transfer task +to terminate with :func:`~asyncio.timeout`. If it's canceled or if the timeout +elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is +canceled, which is correct at this point. +:meth:`~protocol.WebSocketCommonProtocol.close` then waits for +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` and -:meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only -places where :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` may -be canceled. +:meth:`~protocol.WebSocketCommonProtocol.close` and +:meth:`~protocol.WebSocketCommonProtocol.fail_connection` are the only places +where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may be +canceled. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` starts by -waiting for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`. It +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` starts by +waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` from propagating -to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`. +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` from propagating to +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`. .. _backpressure: @@ -491,28 +489,28 @@ buffers and break the backpressure. Be careful with queues. Concurrency ----------- -Awaiting any combination of :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, or -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` concurrently is safe, including +Awaiting any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.close` +:meth:`~protocol.WebSocketCommonProtocol.ping`, or +:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, including multiple calls to the same method, with one exception and one limitation. -* **Only one coroutine can receive messages at a time.** This constraint - avoids non-deterministic behavior (and simplifies the implementation). If a - coroutine is awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, - awaiting it again in another coroutine raises :exc:`RuntimeError`. +* **Only one coroutine can receive messages at a time.** This constraint avoids + non-deterministic behavior (and simplifies the implementation). If a coroutine + is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, awaiting it again + in another coroutine raises :exc:`RuntimeError`. * **Sending a fragmented message forces serialization.** Indeed, the WebSocket protocol doesn't support multiplexing messages. If a coroutine is awaiting - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to send a fragmented message, + :meth:`~protocol.WebSocketCommonProtocol.send` to send a fragmented message, awaiting it again in another coroutine waits until the first call completes. - This will be transparent in many cases. It may be a concern if the - fragmented message is generated slowly by an asynchronous iterator. + This will be transparent in many cases. It may be a concern if the fragmented + message is generated slowly by an asynchronous iterator. Receiving frames is independent from sending frames. This isolates -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, which receives frames, from -the other methods, which send frames. +:meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from the +other methods, which send frames. While the connection is open, each frame is sent with a single write. Combined with the concurrency model of :mod:`asyncio`, this enforces serialization. The diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 765278360..cad49ba55 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -101,9 +101,10 @@ However, this technique runs into two problems: * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. -There's a better way. :func:`~client.connect` and :func:`~server.serve` accept -a ``logger`` argument to override the default :class:`~logging.Logger`. You -can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. +There's a better way. :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` accept a ``logger`` argument to override the +default :class:`~logging.Logger`. You can set ``logger`` to a +:class:`~logging.LoggerAdapter` that enriches logs. For example, if the server is behind a reverse proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives @@ -128,7 +129,7 @@ Here's how to include them in logs, assuming they're in the xff = websocket.request_headers.get("X-Forwarded-For") return f"{websocket.id} {xff} {msg}", kwargs - async with websockets.serve( + async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), @@ -170,7 +171,7 @@ a :class:`~logging.LoggerAdapter`:: } return msg, kwargs - async with websockets.serve( + async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index efbcbb83f..61b1113e2 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -99,10 +99,11 @@ workloads but it can also backfire because it delays backpressure. messages. * In the legacy :mod:`asyncio` implementation, there is a library-level read - buffer. The ``read_limit`` argument of :func:`~client.connect` and - :func:`~server.serve` controls its size. When the read buffer grows above the - high-water mark, the connection stops reading from the network until it drains - under the low-water mark. This creates backpressure on the TCP connection. + buffer. The ``read_limit`` argument of :func:`~legacy.client.connect` and + :func:`~legacy.server.serve` controls its size. When the read buffer grows + above the high-water mark, the connection stops reading from the network until + it drains under the low-water mark. This creates backpressure on the TCP + connection. There is a write buffer. It as controlled by ``write_limit``. It behaves like the new :mod:`asyncio` implementation described above. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 83d79e35b..a22b752c7 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -49,9 +49,9 @@ Identification By default, websockets identifies itself with a ``Server`` or ``User-Agent`` header in the format ``"Python/x.y.z websockets/X.Y"``. -You can set the ``server_header`` argument of :func:`~server.serve` or the -``user_agent_header`` argument of :func:`~client.connect` to configure another -value. Setting them to :obj:`None` removes the header. +You can set the ``server_header`` argument of :func:`~asyncio.server.serve` or +the ``user_agent_header`` argument of :func:`~asyncio.client.connect` to +configure another value. Setting them to :obj:`None` removes the header. Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and :envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py index 4ca34d23b..c8e6af4f9 100644 --- a/example/deployment/fly/app.py +++ b/example/deployment/fly/app.py @@ -4,7 +4,7 @@ import http import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -12,9 +12,9 @@ async def echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): @@ -23,7 +23,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index 360479b8e..ef6d9c42d 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -4,7 +4,7 @@ import os import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="localhost", port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index d4ba3edb5..17ad09d26 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -4,7 +4,7 @@ import signal import os -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=int(os.environ["PORT"]), diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index a8bcef688..387f0ade1 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -6,7 +6,7 @@ import sys import time -import websockets +from websockets.asyncio.server import serve async def slow_echo(websocket): @@ -17,17 +17,17 @@ async def slow_echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" - if path == "/inemuri": +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + if request.path == "/inemuri": loop = asyncio.get_running_loop() loop.call_later(1, time.sleep, 10) - return http.HTTPStatus.OK, [], b"Sleeping for 10s\n" - if path == "/seppuku": + return connection.respond(http.HTTPStatus.OK, "Sleeping for 10s\n") + if request.path == "/seppuku": loop = asyncio.get_running_loop() loop.call_later(1, sys.exit, 69) - return http.HTTPStatus.OK, [], b"Terminating\n" + return connection.respond(http.HTTPStatus.OK, "Terminating\n") async def main(): @@ -36,7 +36,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( slow_echo, host="", port=80, diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py index 22ee4c5bd..11a452d55 100755 --- a/example/deployment/kubernetes/benchmark.py +++ b/example/deployment/kubernetes/benchmark.py @@ -2,14 +2,15 @@ import asyncio import sys -import websockets + +from websockets.asyncio.client import connect URI = "ws://localhost:32080" async def run(client_id, messages): - async with websockets.connect(URI) as websocket: + async with connect(URI) as websocket: for message_id in range(messages): await websocket.send(f"{client_id}:{message_id}") await websocket.recv() diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index 24e608975..134070f61 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -4,7 +4,7 @@ import os import signal -import websockets +from websockets.asyncio.server import unix_serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.unix_serve( + async with unix_serve( echo, path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", ): diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py index 4ca34d23b..c8e6af4f9 100644 --- a/example/deployment/render/app.py +++ b/example/deployment/render/app.py @@ -4,7 +4,7 @@ import http import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -12,9 +12,9 @@ async def echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): @@ -23,7 +23,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index bf61983ef..5e69f16a6 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -3,7 +3,7 @@ import asyncio import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -17,7 +17,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/django/authentication.py b/example/django/authentication.py index 83e128f07..c4f12a3f8 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -3,11 +3,11 @@ import asyncio import django -import websockets django.setup() from sesame.utils import get_user +from websockets.asyncio.server import serve from websockets.frames import CloseCode @@ -22,7 +22,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "localhost", 8888): + async with serve(handler, "localhost", 8888): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/django/notifications.py b/example/django/notifications.py index 3a9ed10cf..445438d2d 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -5,12 +5,13 @@ import aioredis import django -import websockets django.setup() from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve from websockets.frames import CloseCode @@ -61,11 +62,11 @@ async def process_events(): for websocket, connection in CONNECTIONS.items() if event["content_type_id"] in connection["content_type_ids"] ) - websockets.broadcast(recipients, payload) + broadcast(recipients, payload) async def main(): - async with websockets.serve(handler, "localhost", 8888): + async with serve(handler, "localhost", 8888): await process_events() # runs forever diff --git a/example/echo.py b/example/echo.py index d11b33527..b952a5cfb 100755 --- a/example/echo.py +++ b/example/echo.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import asyncio -from websockets.server import serve +from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 6c7681e8a..c0fa4327f 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -1,22 +1,19 @@ #!/usr/bin/env python import asyncio -import http -import websockets +from http import HTTPStatus +from websockets.asyncio.server import serve -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(HTTPStatus.OK, b"OK\n") async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): - async with websockets.serve( - echo, "localhost", 8765, - process_request=health_check, - ): + async with serve(echo, "localhost", 8765, process_request=health_check): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 539dd0304..5c8bd8cbe 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -2,15 +2,15 @@ import asyncio import signal -import websockets + +from websockets.asyncio.client import connect async def client(): uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() - loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close()) + loop.add_signal_handler(signal.SIGTERM, loop.create_task, websocket.close()) # Process messages received on the connection. async for message in websocket: diff --git a/example/faq/shutdown_server.py b/example/faq/shutdown_server.py index 1bcc9c90b..3f7bc5732 100755 --- a/example/faq/shutdown_server.py +++ b/example/faq/shutdown_server.py @@ -2,7 +2,8 @@ import asyncio import signal -import websockets + +from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: @@ -14,7 +15,7 @@ async def server(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve(echo, "localhost", 8765): + async with serve(echo, "localhost", 8765): await stop asyncio.run(server()) diff --git a/example/legacy/basic_auth_client.py b/example/legacy/basic_auth_client.py index 164732152..0252894b7 100755 --- a/example/legacy/basic_auth_client.py +++ b/example/legacy/basic_auth_client.py @@ -3,11 +3,12 @@ # WS client example with HTTP Basic Authentication import asyncio -import websockets + +from websockets.legacy.client import connect async def hello(): uri = "ws://mary:p@ssw0rd@localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: greeting = await websocket.recv() print(greeting) diff --git a/example/legacy/basic_auth_server.py b/example/legacy/basic_auth_server.py index 6f6020253..fc45a0270 100755 --- a/example/legacy/basic_auth_server.py +++ b/example/legacy/basic_auth_server.py @@ -3,16 +3,18 @@ # Server example with HTTP Basic Authentication over TLS import asyncio -import websockets + +from websockets.legacy.auth import basic_auth_protocol_factory +from websockets.legacy.server import serve async def hello(websocket): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) async def main(): - async with websockets.serve( + async with serve( hello, "localhost", 8765, - create_protocol=websockets.basic_auth_protocol_factory( + create_protocol=basic_auth_protocol_factory( realm="example", credentials=("mary", "p@ssw0rd") ), ): diff --git a/example/legacy/unix_client.py b/example/legacy/unix_client.py index 926156730..87201c9e4 100755 --- a/example/legacy/unix_client.py +++ b/example/legacy/unix_client.py @@ -4,11 +4,12 @@ import asyncio import os.path -import websockets + +from websockets.legacy.client import unix_connect async def hello(): socket_path = os.path.join(os.path.dirname(__file__), "socket") - async with websockets.unix_connect(socket_path) as websocket: + async with unix_connect(socket_path) as websocket: name = input("What's your name? ") await websocket.send(name) print(f">>> {name}") diff --git a/example/legacy/unix_server.py b/example/legacy/unix_server.py index 5bfb66072..8a4981f5f 100755 --- a/example/legacy/unix_server.py +++ b/example/legacy/unix_server.py @@ -4,7 +4,8 @@ import asyncio import os.path -import websockets + +from websockets.legacy.server import unix_serve async def hello(websocket): name = await websocket.recv() @@ -17,7 +18,7 @@ async def hello(websocket): async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") - async with websockets.unix_serve(hello, socket_path): + async with unix_serve(hello, socket_path): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/logging/json_log_formatter.py b/example/logging/json_log_formatter.py index b8fc8d6dc..ff7fce8b5 100644 --- a/example/logging/json_log_formatter.py +++ b/example/logging/json_log_formatter.py @@ -1,6 +1,6 @@ +import datetime import json import logging -import datetime class JSONFormatter(logging.Formatter): """ diff --git a/example/quickstart/client.py b/example/quickstart/client.py index 8d588c2b0..934af69e3 100755 --- a/example/quickstart/client.py +++ b/example/quickstart/client.py @@ -1,11 +1,12 @@ #!/usr/bin/env python import asyncio -import websockets + +from websockets.asyncio.client import connect async def hello(): uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/quickstart/client_secure.py b/example/quickstart/client_secure.py index f4b39f2b8..a1449587a 100755 --- a/example/quickstart/client_secure.py +++ b/example/quickstart/client_secure.py @@ -3,7 +3,8 @@ import asyncio import pathlib import ssl -import websockets + +from websockets.asyncio.client import connect ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") @@ -11,7 +12,7 @@ async def hello(): uri = "wss://localhost:8765" - async with websockets.connect(uri, ssl=ssl_context) as websocket: + async with connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 414919e04..d42069e64 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -3,7 +3,8 @@ import asyncio import json import logging -import websockets +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve logging.basicConfig() @@ -22,7 +23,7 @@ async def counter(websocket): try: # Register user USERS.add(websocket) - websockets.broadcast(USERS, users_event()) + broadcast(USERS, users_event()) # Send current state to user await websocket.send(value_event()) # Manage state changes @@ -30,19 +31,19 @@ async def counter(websocket): event = json.loads(message) if event["action"] == "minus": VALUE -= 1 - websockets.broadcast(USERS, value_event()) + broadcast(USERS, value_event()) elif event["action"] == "plus": VALUE += 1 - websockets.broadcast(USERS, value_event()) + broadcast(USERS, value_event()) else: logging.error("unsupported event: %s", event) finally: # Unregister user USERS.remove(websocket) - websockets.broadcast(USERS, users_event()) + broadcast(USERS, users_event()) async def main(): - async with websockets.serve(counter, "localhost", 6789): + async with serve(counter, "localhost", 6789): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/server.py b/example/quickstart/server.py index 64d7adeb6..bde5e6126 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -1,7 +1,8 @@ #!/usr/bin/env python import asyncio -import websockets + +from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() @@ -13,7 +14,7 @@ async def hello(websocket): print(f">>> {greeting}") async def main(): - async with websockets.serve(hello, "localhost", 8765): + async with serve(hello, "localhost", 8765): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index 11db5fb3a..8b456ed6e 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -3,7 +3,8 @@ import asyncio import pathlib import ssl -import websockets + +from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() @@ -19,7 +20,7 @@ async def hello(websocket): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): + async with serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index add226869..8aeb811db 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -3,7 +3,8 @@ import asyncio import datetime import random -import websockets + +from websockets.asyncio.server import serve async def show_time(websocket): while True: @@ -12,7 +13,7 @@ async def show_time(websocket): await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with websockets.serve(show_time, "localhost", 5678): + async with serve(show_time, "localhost", 5678): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py index 08e87f593..4fa244a23 100755 --- a/example/quickstart/show_time_2.py +++ b/example/quickstart/show_time_2.py @@ -3,7 +3,9 @@ import asyncio import datetime import random -import websockets + +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve CONNECTIONS = set() @@ -17,11 +19,11 @@ async def register(websocket): async def show_time(): while True: message = datetime.datetime.utcnow().isoformat() + "Z" - websockets.broadcast(CONNECTIONS, message) + broadcast(CONNECTIONS, message) await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with websockets.serve(register, "localhost", 5678): + async with serve(register, "localhost", 5678): await show_time() if __name__ == "__main__": diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 6ec1c60b8..db69070a1 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -4,7 +4,7 @@ import itertools import json -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -57,7 +57,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index db3e36374..feaf223a0 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,7 +4,7 @@ import json import secrets -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -182,7 +182,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index c2ee020d2..a428e29e7 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,7 +6,7 @@ import secrets import signal -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -190,7 +190,7 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) port = int(os.environ.get("PORT", "8001")) - async with websockets.serve(handler, "", port): + async with serve(handler, "", port): await stop diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index 039e21174..e3b2cf1f6 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -8,8 +8,9 @@ import urllib.parse import uuid -import websockets from websockets.frames import CloseCode +from websockets.legacy.auth import BasicAuthWebSocketServerProtocol +from websockets.legacy.server import WebSocketServerProtocol, serve # User accounts database @@ -107,7 +108,7 @@ async def first_message_handler(websocket): # Add credentials to the WebSocket URI in a query parameter -class QueryParamProtocol(websockets.WebSocketServerProtocol): +class QueryParamProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): token = get_query_param(path, "token") if token is None: @@ -131,7 +132,7 @@ async def query_param_handler(websocket): # Set a cookie on the domain of the WebSocket URI -class CookieProtocol(websockets.WebSocketServerProtocol): +class CookieProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): if "Upgrade" not in headers: template = pathlib.Path(__file__).with_name(path[1:]) @@ -161,7 +162,7 @@ async def cookie_handler(websocket): # Adding credentials to the WebSocket URI in user information -class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): +class UserInfoProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): if username != "token": return False @@ -192,26 +193,26 @@ async def main(): loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( noop_handler, host="", port=8000, process_request=serve_html, - ), websockets.serve( + ), serve( first_message_handler, host="", port=8001, - ), websockets.serve( + ), serve( query_param_handler, host="", port=8002, create_protocol=QueryParamProtocol, - ), websockets.serve( + ), serve( cookie_handler, host="", port=8003, create_protocol=CookieProtocol, - ), websockets.serve( + ), serve( user_info_handler, host="", port=8004, diff --git a/experiments/broadcast/clients.py b/experiments/broadcast/clients.py index fe39dfe05..64334f20f 100644 --- a/experiments/broadcast/clients.py +++ b/experiments/broadcast/clients.py @@ -5,7 +5,7 @@ import sys import time -import websockets +from websockets.asyncio.client import connect LATENCIES = {} @@ -26,7 +26,7 @@ async def log_latency(interval): async def client(): try: - async with websockets.connect( + async with connect( "ws://localhost:8765", ping_timeout=None, ) as websocket: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index adb66e262..52cc48898 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -319,8 +319,8 @@ class AbortHandshake(InvalidHandshake): This exception is an implementation detail. - The public API - is :meth:`~websockets.server.WebSocketServerProtocol.process_request`. + The public API is + :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. Attributes: status (~http.HTTPStatus): HTTP status code. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8526bad6b..4d030e5e2 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -125,11 +125,11 @@ def basic_auth_protocol_factory( Protocol factory that enforces HTTP Basic Auth. :func:`basic_auth_protocol_factory` is designed to integrate with - :func:`~websockets.server.serve` like this:: + :func:`~websockets.legacy.server.serve` like this:: - websockets.serve( + serve( ..., - create_protocol=websockets.basic_auth_protocol_factory( + create_protocol=basic_auth_protocol_factory( realm="my dev server", credentials=("hello", "iloveyou"), ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b61126c81..256bee14c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -342,7 +342,7 @@ class Connect: :func:`connect` can be used as a asynchronous context manager:: - async with websockets.connect(...) as websocket: + async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -350,7 +350,7 @@ class Connect: :func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: - async for websocket in websockets.connect(...): + async for websocket in connect(...): try: ... except websockets.ConnectionClosed: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 191350de3..66eb94199 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -67,13 +67,13 @@ class WebSocketCommonProtocol(asyncio.Protocol): :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket servers and clients. You shouldn't use it directly. Instead, use - :class:`~websockets.client.WebSocketClientProtocol` or - :class:`~websockets.server.WebSocketServerProtocol`. + :class:`~websockets.legacy.client.WebSocketClientProtocol` or + :class:`~websockets.legacy.server.WebSocketServerProtocol`. This documentation focuses on low-level details that aren't covered in the - documentation of :class:`~websockets.client.WebSocketClientProtocol` and - :class:`~websockets.server.WebSocketServerProtocol` for the sake of - simplicity. + documentation of :class:`~websockets.legacy.client.WebSocketClientProtocol` + and :class:`~websockets.legacy.server.WebSocketServerProtocol` for the sake + of simplicity. Once the connection is open, a Ping_ frame is sent every ``ping_interval`` seconds. This serves as a keepalive. It helps keeping the connection open, @@ -89,7 +89,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - See the discussion of :doc:`timeouts <../../topics/keepalive>` for details. + See the discussion of :doc:`keepalive <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -99,8 +99,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): ``close_timeout`` is a parameter of the protocol because websockets usually calls :meth:`close` implicitly upon exit: - * on the client side, when using :func:`~websockets.client.connect` as a - context manager; + * on the client side, when using :func:`~websockets.legacy.client.connect` + as a context manager; * on the server side, when the connection handler terminates. To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 9b84a6b81..3cdfeb21b 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -150,7 +150,7 @@ async def test_aiter_connection_closed_ok(self): await anext(aiterator) async def test_aiter_connection_closed_error(self): - """__aiter__ raises ConnnectionClosedError after an error.""" + """__aiter__ raises ConnectionClosedError after an error.""" aiterator = aiter(self.connection) await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index b3023434b..5d4f0e2f8 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,5 +1,4 @@ import asyncio -import dataclasses import http import logging import socket @@ -228,9 +227,7 @@ async def test_process_response_override_response(self): """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" async with run_server(process_response=process_response) as server: async with run_client(server) as client: @@ -242,9 +239,7 @@ async def test_async_process_response_override_response(self): """Server runs async process_response and overrides the handshake response.""" async def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" async with run_server(process_response=process_response) as server: async with run_client(server) as client: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 88cbcd669..877adc4bf 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -132,7 +132,7 @@ def test_iter_connection_closed_ok(self): next(iterator) def test_iter_connection_closed_error(self): - """__iter__ raises ConnnectionClosedError after an error.""" + """__iter__ raises ConnectionClosedError after an error.""" iterator = iter(self.connection) self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 4e04a39d5..c0a5f01e6 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,4 +1,3 @@ -import dataclasses import http import logging import socket @@ -173,9 +172,7 @@ def test_process_response_override_response(self): """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" with run_server(process_response=process_response) as server: with run_client(server) as client: From 472f9517b0f8d1f190ae5961fe10a064ef016972 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 19:34:38 +0200 Subject: [PATCH 096/109] Explain new asyncio implementation in docs index page. --- docs/index.rst | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index d9737db12..218a489a3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,23 +28,42 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms: -1. The default implementation builds upon :mod:`asyncio`, Python's standard +1. The primary implementation builds upon :mod:`asyncio`, Python's standard asynchronous I/O framework. It provides an elegant coroutine-based API. It's ideal for servers that handle many clients concurrently. + + .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` + implementation. + :class: important + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. Although it's stable and robust, it is now + considered legacy. + + The new implementation in ``websockets.asyncio`` is a rewrite on top of + the Sans-I/O implementation. It adds a few features that were impossible + to implement within the original design. + + The new implementation will become the default as soon as it reaches + feature parity. If you're using the historical implementation, you should + :doc:`ugrade to the new implementation `. It's usually + straightforward. + 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used for servers that don't need to serve many clients. + 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally by websockets. .. _Sans-I/O: https://sans-io.readthedocs.io/ -Here's an echo server with the :mod:`asyncio` API: +Here's an echo server using the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py -Here's how a client sends and receives messages with the :mod:`threading` API: +Here's a client using the :mod:`threading` API: .. literalinclude:: ../example/hello.py From 14ca557f53cf19084eb64aef2e4563e5630c211b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 19:35:02 +0200 Subject: [PATCH 097/109] Proof-read tutorial. Switch to the new asyncio implementation. Change exception to ValueError when arguments have incorrect values. --- docs/intro/tutorial1.rst | 4 ++-- docs/intro/tutorial2.rst | 38 ++++++++++++++++-------------- example/tutorial/start/connect4.py | 6 ++--- example/tutorial/step1/app.py | 2 +- example/tutorial/step2/app.py | 7 +++--- example/tutorial/step3/app.py | 7 +++--- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 74f5f79a3..6e91867c8 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -123,7 +123,7 @@ wins. Here's its API. :param player: :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2`. :param column: between ``0`` and ``6``. :returns: Row where the checker lands, between ``0`` and ``5``. - :raises RuntimeError: if the move is illegal. + :raises ValueError: if the move is illegal. .. attribute:: moves @@ -520,7 +520,7 @@ Then, you're going to iterate over incoming messages and take these steps: interface sends; * play the move in the board with the :meth:`~connect4.Connect4.play` method, alternating between the two players; -* if :meth:`~connect4.Connect4.play` raises :exc:`RuntimeError` because the +* if :meth:`~connect4.Connect4.play` raises :exc:`ValueError` because the move is illegal, send an event of type ``"error"``; * else, send an event of type ``"play"`` to tell the user interface where the checker lands; diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index b8e35f292..b5d3a3dc8 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -84,7 +84,7 @@ When the second player joins the game, look it up: async def handler(websocket): ... - join_key = ... # TODO + join_key = ... # Find the Connect Four game. game, connected = JOIN[join_key] @@ -434,7 +434,7 @@ Once the initialization sequence is done, watching a game is as simple as registering the WebSocket connection in the ``connected`` set in order to receive game events and doing nothing until the spectator disconnects. You can wait for a connection to terminate with -:meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed`: +:meth:`~asyncio.server.ServerConnection.wait_closed`: .. code-block:: python @@ -482,38 +482,40 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`~legacy.protocol.broadcast` helper for this purpose: +the :func:`~asyncio.connection.broadcast` helper for this purpose: .. code-block:: python + from websockets.asyncio.connection import broadcast + async def handler(websocket): ... - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) ... -Calling :func:`legacy.protocol.broadcast` once is more efficient than -calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. +Calling :func:`~asyncio.connection.broadcast` once is more efficient than +calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no -``await`` in the second version? Indeed, :func:`legacy.protocol.broadcast` is a -function, not a coroutine like -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` or -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. +``await`` in the second version? Indeed, :func:`~asyncio.connection.broadcast` +is a function, not a coroutine like +:meth:`~asyncio.server.ServerConnection.send` or +:meth:`~asyncio.server.ServerConnection.recv`. -It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` +It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` is a coroutine. When you want to receive the next message, you have to wait until the client sends it and the network transmits it. -It's less obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.send` is +It's less obvious why :meth:`~asyncio.server.ServerConnection.send` is a coroutine. If you send many messages or large messages, you could write data faster than the network can transmit it or the client can read it. Then, outgoing data will pile up in buffers, which will consume memory and may crash your application. -To avoid this problem, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +To avoid this problem, :meth:`~asyncio.server.ServerConnection.send` waits until the write buffer drains. By slowing down the application as necessary, this ensures that the server doesn't send data too quickly. This is called backpressure and it's useful for building robust systems. @@ -522,12 +524,12 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`legacy.protocol.broadcast` doesn't wait until write buffers -drain. +That's why :func:`~asyncio.connection.broadcast` doesn't wait until write +buffers drain and therefore doesn't need to be a coroutine. -For our Connect Four game, there's no difference in practice: the total amount -of data sent on a connection for a game of Connect Four is less than 64 KB, -so the write buffer never fills up and backpressure never kicks in anyway. +For our Connect Four game, there's no difference in practice. The total amount +of data sent on a connection for a game of Connect Four is so small that the +write buffer cannot fill up. As a consequence, backpressure never kicks in. Summary ------- diff --git a/example/tutorial/start/connect4.py b/example/tutorial/start/connect4.py index 0a61e7c7e..104476962 100644 --- a/example/tutorial/start/connect4.py +++ b/example/tutorial/start/connect4.py @@ -43,15 +43,15 @@ def play(self, player, column): Returns the row where the checker lands. - Raises :exc:`RuntimeError` if the move is illegal. + Raises :exc:`ValueError` if the move is illegal. """ if player == self.last_player: - raise RuntimeError("It isn't your turn.") + raise ValueError("It isn't your turn.") row = self.top[column] if row == 6: - raise RuntimeError("This slot is full.") + raise ValueError("This slot is full.") self.moves.append((player, column, row)) self.top[column] += 1 diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index db69070a1..595a10dc7 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -26,7 +26,7 @@ async def handler(websocket): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. event = { "type": "error", diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index feaf223a0..86b2c88c3 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,6 +4,7 @@ import json import secrets +from websockets.asyncio.connection import broadcast from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -59,7 +60,7 @@ async def play(websocket, game, player, connected): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue @@ -71,7 +72,7 @@ async def play(websocket, game, player, connected): "column": column, "row": row, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: @@ -79,7 +80,7 @@ async def play(websocket, game, player, connected): "type": "win", "player": game.winner, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) async def start(websocket): diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index a428e29e7..34024d087 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,6 +6,7 @@ import secrets import signal +from websockets.asyncio.connection import broadcast from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -61,7 +62,7 @@ async def play(websocket, game, player, connected): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue @@ -73,7 +74,7 @@ async def play(websocket, game, player, connected): "column": column, "row": row, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: @@ -81,7 +82,7 @@ async def play(websocket, game, player, connected): "type": "win", "player": game.winner, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) async def start(websocket): From 2a17e1dac6a4514f3663c4e15546ebe33bf90e4b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 21:49:46 +0200 Subject: [PATCH 098/109] Move broadcast() to the server module. Rationale: * It is only useful for servers. (Maybe there's a use case for client but I couldn't picture it.) * It was already documented in the page covering the server module. * Within this page, it was the only API from the connection or protocol module. The implementation remains in the connection or protocol module because moving it would require refactoring tests. I'd rather keep them simple. (And I'm lazy.) This change doesn't require a backwards compatibility shim because the documentated location of the legacy implementation of broadcast was websockets.broadcast, it's changing with the introduction of the new asyncio API, and the changes are already documented. --- docs/faq/server.rst | 4 ++-- docs/howto/patterns.rst | 2 +- docs/howto/upgrade.rst | 4 ++-- docs/intro/tutorial2.rst | 15 +++++++-------- docs/project/changelog.rst | 4 ++-- docs/reference/asyncio/server.rst | 2 +- docs/reference/features.rst | 2 ++ docs/reference/legacy/server.rst | 10 +++++----- docs/topics/broadcast.rst | 20 ++++++++++---------- docs/topics/logging.rst | 2 +- docs/topics/performance.rst | 4 ++-- example/django/notifications.py | 3 +-- example/quickstart/counter.py | 3 +-- example/quickstart/show_time_2.py | 3 +-- example/tutorial/step2/app.py | 3 +-- example/tutorial/step3/app.py | 3 +-- experiments/broadcast/server.py | 3 +-- src/websockets/__init__.py | 7 ++++--- src/websockets/asyncio/connection.py | 25 ++++++++++++++++++++----- src/websockets/asyncio/server.py | 4 ++-- src/websockets/legacy/protocol.py | 25 ++++++++++++++++++++----- src/websockets/legacy/server.py | 10 ++++++++-- tests/asyncio/test_connection.py | 1 + 23 files changed, 96 insertions(+), 63 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index e6b068316..66e81edfe 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -116,9 +116,9 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.asyncio.connection.broadcast`:: +Then, call :func:`~websockets.asyncio.server.broadcast`:: - from websockets.asyncio.connection import broadcast + from websockets.asyncio.server import broadcast def message_all(message): broadcast(CONNECTIONS, message) diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index 60bc8ab42..bfb78b6ca 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -90,7 +90,7 @@ connect and unregister them when they disconnect:: connected.add(websocket) try: # Broadcast a message to all connected clients. - websockets.broadcast(connected, "Hello!") + broadcast(connected, "Hello!") await asyncio.sleep(10) finally: # Unregister. diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 40c8c5ec9..c5320155e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -162,8 +162,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | -| :func:`websockets.legacy.protocol.broadcast()` | | +| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | +| :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index b5d3a3dc8..0211615d1 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -482,11 +482,11 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`~asyncio.connection.broadcast` helper for this purpose: +the :func:`~asyncio.server.broadcast` helper for this purpose: .. code-block:: python - from websockets.asyncio.connection import broadcast + from websockets.asyncio.server import broadcast async def handler(websocket): @@ -496,13 +496,12 @@ the :func:`~asyncio.connection.broadcast` helper for this purpose: ... -Calling :func:`~asyncio.connection.broadcast` once is more efficient than +Calling :func:`~asyncio.server.broadcast` once is more efficient than calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no -``await`` in the second version? Indeed, :func:`~asyncio.connection.broadcast` -is a function, not a coroutine like -:meth:`~asyncio.server.ServerConnection.send` or +``await`` in the second version? Indeed, :func:`~asyncio.server.broadcast` is a +function, not a coroutine like :meth:`~asyncio.server.ServerConnection.send` or :meth:`~asyncio.server.ServerConnection.recv`. It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` @@ -524,8 +523,8 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`~asyncio.connection.broadcast` doesn't wait until write -buffers drain and therefore doesn't need to be a coroutine. +That's why :func:`~asyncio.server.broadcast` doesn't wait until write buffers +drain and therefore doesn't need to be a coroutine. For our Connect Four game, there's no difference in practice. The total amount of data sent on a connection for a game of Connect Four is so small that the diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index df5af54f4..e85c3a395 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -212,7 +212,7 @@ Improvements * Added platform-independent wheels. -* Improved error handling in :func:`~legacy.protocol.broadcast`. +* Improved error handling in :func:`~legacy.server.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. @@ -402,7 +402,7 @@ New features * Added compatibility with Python 3.10. -* Added :func:`~legacy.protocol.broadcast` to send a message to many clients. +* Added :func:`~legacy.server.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~legacy.client.connect` as an asynchronous iterator. diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 7bceca5a0..541c9952c 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -80,4 +80,4 @@ Using a connection Broadcast --------- -.. autofunction:: websockets.asyncio.connection.broadcast +.. autofunction:: websockets.asyncio.server.broadcast diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 6840fe15b..cb0e564f9 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -35,6 +35,8 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ + | Broadcast a message | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Receive a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Iterate over received messages | ✅ | ✅ | — | ✅ | diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index c2758f5a2..b6c383ce7 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -89,6 +89,11 @@ Using a connection .. autoproperty:: close_reason +Broadcast +--------- + +.. autofunction:: websockets.legacy.server.broadcast + Basic authentication -------------------- @@ -106,8 +111,3 @@ websockets supports HTTP Basic Authentication according to .. autoattribute:: username .. automethod:: check_credentials - -Broadcast ---------- - -.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index ec358bbd2..c9699feb2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -4,19 +4,19 @@ Broadcasting .. currentmodule:: websockets .. admonition:: If you want to send a message to all connected clients, - use :func:`~asyncio.connection.broadcast`. + use :func:`~asyncio.server.broadcast`. :class: tip If you want to learn about its design, continue reading this document. For the legacy :mod:`asyncio` implementation, use - :func:`~legacy.protocol.broadcast`. + :func:`~legacy.server.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design of -:func:`~asyncio.connection.broadcast`, and discuss alternatives. +:func:`~asyncio.server.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -124,7 +124,7 @@ connections before the write buffer can fill up. Don't set extreme values of ``write_limit``, ``ping_interval``, or ``ping_timeout`` to ensure that this condition holds! Instead, set reasonable -values and use the built-in :func:`~asyncio.connection.broadcast` function. +values and use the built-in :func:`~asyncio.server.broadcast` function. The concurrent way ------------------ @@ -209,11 +209,11 @@ If a client gets too far behind, eventually it reaches the limit defined by ``ping_timeout`` and websockets terminates the connection. You can refer to the discussion of :doc:`keepalive ` for details. -How :func:`~asyncio.connection.broadcast` works ------------------------------------------------ +How :func:`~asyncio.server.broadcast` works +------------------------------------------- -The built-in :func:`~asyncio.connection.broadcast` function is similar to the -naive way. The main difference is that it doesn't apply backpressure. +The built-in :func:`~asyncio.server.broadcast` function is similar to the naive +way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -324,7 +324,7 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`~asyncio.connection.broadcast` function sends all messages +The built-in :func:`~asyncio.server.broadcast` function sends all messages without yielding control to the event loop. So does the naive way when the network and clients are fast and reliable. @@ -346,7 +346,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`~asyncio.connection.broadcast` function. +than the built-in :func:`~asyncio.server.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index cad49ba55..9580b4c50 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -221,7 +221,7 @@ Here's what websockets logs at each level. ``WARNING`` ........... -* Failures in :func:`~asyncio.connection.broadcast` +* Failures in :func:`~asyncio.server.broadcast` ``INFO`` ........ diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst index b226cec43..b0828fe0d 100644 --- a/docs/topics/performance.rst +++ b/docs/topics/performance.rst @@ -18,5 +18,5 @@ application.) broadcast --------- -:func:`~asyncio.connection.broadcast` is the most efficient way to send a -message to many clients. +:func:`~asyncio.server.broadcast` is the most efficient way to send a message to +many clients. diff --git a/example/django/notifications.py b/example/django/notifications.py index 445438d2d..76ce9c2d7 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -10,8 +10,7 @@ from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from websockets.frames import CloseCode diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index d42069e64..91eedc56a 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -3,8 +3,7 @@ import asyncio import json import logging -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve logging.basicConfig() diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py index 4fa244a23..9c9659d14 100755 --- a/example/quickstart/show_time_2.py +++ b/example/quickstart/show_time_2.py @@ -4,8 +4,7 @@ import datetime import random -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve CONNECTIONS = set() diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index 86b2c88c3..ef3dd9483 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,8 +4,7 @@ import json import secrets -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 34024d087..261057f9a 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,8 +6,7 @@ import secrets import signal -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 0a5c82b3c..d5b50bd71 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -7,8 +7,7 @@ import time from websockets import ConnectionClosed -from websockets.asyncio.server import serve -from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import broadcast, serve CLIENTS = set() diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index fdb028f4c..b618a6dff 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -48,10 +48,10 @@ "unix_connect", # .legacy.protocol "WebSocketCommonProtocol", - "broadcast", # .legacy.server "WebSocketServer", "WebSocketServerProtocol", + "broadcast", "serve", "unix_serve", # .server @@ -102,10 +102,11 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.protocol import WebSocketCommonProtocol, broadcast + from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, WebSocketServerProtocol, + broadcast, serve, unix_serve, ) @@ -164,10 +165,10 @@ "unix_connect": ".legacy.client", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", - "broadcast": ".legacy.protocol", # .legacy.server "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", "serve": ".legacy.server", "unix_serve": ".legacy.server", # .server diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 284fe2124..a6b909c72 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -34,7 +34,7 @@ from .messages import Assembler -__all__ = ["Connection", "broadcast"] +__all__ = ["Connection"] class Connection(asyncio.Protocol): @@ -1011,6 +1011,12 @@ def eof_received(self) -> None: # Besides, that doesn't work on TLS connections. +# broadcast() is defined in the connection module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# Connection class. + + def broadcast( connections: Iterable[Connection], message: Data, @@ -1034,10 +1040,11 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~Connection.send`, :func:`broadcast` doesn't support sending - fragmented messages. Indeed, fragmentation is useful for sending large - messages without buffering them in memory, while :func:`broadcast` buffers - one copy per connection as fast as possible. + Unlike :meth:`~websockets.asyncio.connection.Connection.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1047,6 +1054,10 @@ def broadcast( set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. @@ -1101,3 +1112,7 @@ def broadcast( if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.asyncio.server" diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 1f55502bb..35637a18f 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -25,10 +25,10 @@ from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout -from .connection import Connection +from .connection import Connection, broadcast -__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "WebSocketServer"] class ServerConnection(Connection): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 66eb94199..e83e146f9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ from .framing import Frame -__all__ = ["WebSocketCommonProtocol", "broadcast"] +__all__ = ["WebSocketCommonProtocol"] # In order to ensure consistency, the code always checks the current value of @@ -1545,6 +1545,12 @@ def eof_received(self) -> None: self.reader.feed_eof() +# broadcast() is defined in the protocol module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# WebSocketCommonProtocol class. + + def broadcast( websockets: Iterable[WebSocketCommonProtocol], message: Data, @@ -1568,10 +1574,11 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~WebSocketCommonProtocol.send`, :func:`broadcast` doesn't - support sending fragmented messages. Indeed, fragmentation is useful for - sending large messages without buffering them in memory, while - :func:`broadcast` buffers one copy per connection as fast as possible. + Unlike :meth:`~websockets.legacy.protocol.WebSocketCommonProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1581,6 +1588,10 @@ def broadcast( set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. @@ -1629,3 +1640,7 @@ def broadcast( if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.legacy.server" diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index d230f009e..43136db3e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -45,10 +45,16 @@ from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .handshake import build_response, check_request from .http import read_request -from .protocol import WebSocketCommonProtocol +from .protocol import WebSocketCommonProtocol, broadcast -__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "WebSocketServerProtocol", + "WebSocketServer", +] # Change to HeadersLike | ... when dropping Python < 3.10. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 3cdfeb21b..52e4fc5c8 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -9,6 +9,7 @@ from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * +from websockets.asyncio.connection import broadcast from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State From b05fa2cceefcc5cfaba4e0e06e40d588505c8334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 22:17:14 +0200 Subject: [PATCH 099/109] Rename WebSocketServer to Server. The shorter name is better. Representation remains unambiguous: The change isn't applied to the legacy implementation because it has longer names for other API too. --- docs/faq/server.rst | 3 +-- docs/howto/upgrade.rst | 2 +- docs/project/changelog.rst | 30 ++++++++++++++-------- docs/reference/asyncio/server.rst | 2 +- docs/reference/sync/server.rst | 2 +- src/websockets/asyncio/server.py | 42 +++++++++++++++---------------- src/websockets/sync/server.py | 28 ++++++++++++++------- tests/asyncio/client.py | 4 +-- tests/asyncio/test_server.py | 2 +- tests/sync/client.py | 4 +-- tests/sync/test_server.py | 11 +++++--- 11 files changed, 77 insertions(+), 53 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 66e81edfe..63eb5ffc6 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -313,8 +313,7 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: How do I stop a server while keeping existing connections open? --------------------------------------------------------------- -Call the server's :meth:`~WebSocketServer.close` method with -``close_connections=False``. +Call the server's :meth:`~Server.close` method with ``close_connections=False``. Here's how to adapt the example just above:: diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index c5320155e..8d0895638 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -154,7 +154,7 @@ Server APIs | ``websockets.server.unix_serve()`` |br| | | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | +| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | | ``websockets.server.WebSocketServer`` |br| | | | :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e85c3a395..f4ae76702 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,16 +35,6 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` - and :func:`~sync.server.serve` in the :mod:`threading` implementation is - renamed to ``ssl``. - :class: note - - This aligns the API of the :mod:`threading` implementation with the - :mod:`asyncio` implementation. - - For backwards compatibility, ``ssl_context`` is still supported. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -60,6 +50,26 @@ Backwards-incompatible changes path = request.path # only if handler() uses the path argument ... +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` in the :mod:`threading` implementation is + renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + +.. admonition:: The ``WebSocketServer`` class in the :mod:`threading` + implementation is renamed to :class:`~sync.server.Server`. + :class: note + + This class isn't designed to be imported or instantiated directly. + :func:`~sync.server.serve` returns an instance. For this reason, + the change should be transparent. + + Regardless, an alias provides backwards compatibility. + New features ............ diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 541c9952c..bd5a34b19 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -15,7 +15,7 @@ Creating a server Running a server ---------------- -.. autoclass:: WebSocketServer +.. autoclass:: Server .. automethod:: close diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 26ab872c8..23cb04097 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -13,7 +13,7 @@ Creating a server Running a server ---------------- -.. autoclass:: WebSocketServer +.. autoclass:: Server .. automethod:: serve_forever diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 35637a18f..8ebbddb67 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -28,7 +28,7 @@ from .connection import Connection, broadcast -__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"] class ServerConnection(Connection): @@ -60,7 +60,7 @@ class ServerConnection(Connection): def __init__( self, protocol: ServerProtocol, - server: WebSocketServer, + server: Server, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, @@ -223,11 +223,11 @@ def connection_lost(self, exc: Exception | None) -> None: self.request_rcvd.set_result(None) -class WebSocketServer: +class Server: """ WebSocket server returned by :func:`serve`. - This class mirrors the API of :class:`~asyncio.Server`. + This class mirrors the API of :class:`asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. @@ -299,16 +299,16 @@ def __init__( def wrap(self, server: asyncio.Server) -> None: """ - Attach to a given :class:`~asyncio.Server`. + Attach to a given :class:`asyncio.Server`. Since :meth:`~asyncio.loop.create_server` doesn't support injecting a custom ``Server`` class, the easiest solution that doesn't rely on private :mod:`asyncio` APIs is to: - - instantiate a :class:`WebSocketServer` + - instantiate a :class:`Server` - give the protocol factory a reference to that instance - call :meth:`~asyncio.loop.create_server` with the factory - - attach the resulting :class:`~asyncio.Server` with this method + - attach the resulting :class:`asyncio.Server` with this method """ self.server = server @@ -378,7 +378,7 @@ def close(self, close_connections: bool = True) -> None: """ Close the server. - * Close the underlying :class:`~asyncio.Server`. + * Close the underlying :class:`asyncio.Server`. * When ``close_connections`` is :obj:`True`, which is the default, close existing connections. Specifically: @@ -402,7 +402,7 @@ async def _close(self, close_connections: bool) -> None: Implementation of :meth:`close`. This calls :meth:`~asyncio.Server.close` on the underlying - :class:`~asyncio.Server` object to stop accepting new connections and + :class:`asyncio.Server` object to stop accepting new connections and then closes open connections with close code 1001. """ @@ -516,7 +516,7 @@ def sockets(self) -> Iterable[socket.socket]: """ return self.server.sockets - async def __aenter__(self) -> WebSocketServer: # pragma: no cover + async def __aenter__(self) -> Server: # pragma: no cover return self async def __aexit__( @@ -543,8 +543,8 @@ class serve: Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - This coroutine returns a :class:`WebSocketServer` whose API mirrors - :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + This coroutine returns a :class:`Server` whose API mirrors + :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: def handler(websocket): @@ -556,8 +556,8 @@ def handler(websocket): async with websockets.asyncio.server.serve(handler, host, port): await stop - Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests - and cancel it to stop the server:: + Alternatively, call :meth:`~Server.serve_forever` to serve requests and + cancel it to stop the server:: server = await websockets.asyncio.server.serve(handler, host, port) await server.serve_forever() @@ -638,8 +638,8 @@ def handler(websocket): socket and customize it. * You can set ``start_serving`` to ``False`` to start accepting connections - only after you call :meth:`~WebSocketServer.start_serving()` or - :meth:`~WebSocketServer.serve_forever()`. + only after you call :meth:`~Server.start_serving()` or + :meth:`~Server.serve_forever()`. """ @@ -704,7 +704,7 @@ def __init__( if create_connection is None: create_connection = ServerConnection - self.server = WebSocketServer( + self.server = Server( handler, process_request=process_request, process_response=process_response, @@ -773,7 +773,7 @@ def protocol_select_subprotocol( # async with serve(...) as ...: ... - async def __aenter__(self) -> WebSocketServer: + async def __aenter__(self) -> Server: return await self async def __aexit__( @@ -787,11 +787,11 @@ async def __aexit__( # ... = await serve(...) - def __await__(self) -> Generator[Any, None, WebSocketServer]: + def __await__(self) -> Generator[Any, None, Server]: # Create a suitable iterator by calling __await__ on a coroutine. return self.__await_impl__().__await__() - async def __await_impl__(self) -> WebSocketServer: + async def __await_impl__(self) -> Server: server = await self._create_server self.server.wrap(server) return self.server @@ -805,7 +805,7 @@ def unix_serve( handler: Callable[[ServerConnection], Awaitable[None]], path: str | None = None, **kwargs: Any, -) -> Awaitable[WebSocketServer]: +) -> Awaitable[Server]: """ Create a WebSocket server listening on a Unix socket. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index b381908ca..85a7e9907 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -24,7 +24,7 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["serve", "unix_serve", "ServerConnection", "Server"] class ServerConnection(Connection): @@ -196,7 +196,7 @@ def recv_events(self) -> None: self.request_rcvd.set() -class WebSocketServer: +class Server: """ WebSocket server returned by :func:`serve`. @@ -283,7 +283,7 @@ def fileno(self) -> int: """ return self.socket.fileno() - def __enter__(self) -> WebSocketServer: + def __enter__(self) -> Server: return self def __exit__( @@ -295,6 +295,16 @@ def __exit__( self.shutdown() +def __getattr__(name: str) -> Any: + if name == "WebSocketServer": + warnings.warn( + "WebSocketServer was renamed to Server", + DeprecationWarning, + ) + return Server + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def serve( handler: Callable[[ServerConnection], None], host: str | None = None, @@ -340,7 +350,7 @@ def serve( # Escape hatch for advanced customization create_connection: type[ServerConnection] | None = None, **kwargs: Any, -) -> WebSocketServer: +) -> Server: """ Create a WebSocket server listening on ``host`` and ``port``. @@ -353,10 +363,10 @@ def serve( Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - This function returns a :class:`WebSocketServer` whose API mirrors + This function returns a :class:`Server` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call :meth:`~WebSocketServer.serve_forever` to - serve requests:: + that it will be closed and call :meth:`~Server.serve_forever` to serve + requests:: def handler(websocket): ... @@ -552,14 +562,14 @@ def protocol_select_subprotocol( # Initialize server - return WebSocketServer(sock, conn_handler, logger) + return Server(sock, conn_handler, logger) def unix_serve( handler: Callable[[ServerConnection], None], path: str | None = None, **kwargs: Any, -) -> WebSocketServer: +) -> Server: """ Create a WebSocket server listening on a Unix socket. diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py index e5826add7..a73079c6e 100644 --- a/tests/asyncio/client.py +++ b/tests/asyncio/client.py @@ -1,7 +1,7 @@ import contextlib from websockets.asyncio.client import * -from websockets.asyncio.server import WebSocketServer +from websockets.asyncio.server import Server from .server import get_server_host_port @@ -17,7 +17,7 @@ async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): wsuri = wsuri_or_server else: - assert isinstance(wsuri_or_server, WebSocketServer) + assert isinstance(wsuri_or_server, Server) if secure is None: secure = "ssl" in kwargs protocol = "wss" if secure else "ws" diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 5d4f0e2f8..4b637f3af 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -535,7 +535,7 @@ async def test_unsupported_compression(self): class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): async def test_logger(self): - """WebSocketServer accepts a logger argument.""" + """Server accepts a logger argument.""" logger = logging.getLogger("test") async with run_server(logger=logger) as server: self.assertIs(server.logger, logger) diff --git a/tests/sync/client.py b/tests/sync/client.py index 72eb5b8d2..acbf97fa7 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,7 +1,7 @@ import contextlib from websockets.sync.client import * -from websockets.sync.server import WebSocketServer +from websockets.sync.server import Server __all__ = [ @@ -15,7 +15,7 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): wsuri = wsuri_or_server else: - assert isinstance(wsuri_or_server, WebSocketServer) + assert isinstance(wsuri_or_server, Server) if secure is None: # Backwards compatibility: ssl used to be called ssl_context. secure = "ssl" in kwargs or "ssl_context" in kwargs diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index c0a5f01e6..315601eca 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -383,18 +383,18 @@ def test_unsupported_compression(self): class WebSocketServerTests(unittest.TestCase): def test_logger(self): - """WebSocketServer accepts a logger argument.""" + """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: self.assertIs(server.logger, logger) def test_fileno(self): - """WebSocketServer provides a fileno attribute.""" + """Server provides a fileno attribute.""" with run_server() as server: self.assertIsInstance(server.fileno(), int) def test_shutdown(self): - """WebSocketServer provides a shutdown method.""" + """Server provides a shutdown method.""" with run_server() as server: server.shutdown() # Check that the server socket is closed. @@ -409,3 +409,8 @@ def test_ssl_context_argument(self): with run_server(ssl_context=SERVER_CONTEXT) as server: with run_client(server, ssl=CLIENT_CONTEXT): pass + + def test_web_socket_server_class(self): + with self.assertDeprecationWarning("WebSocketServer was renamed to Server"): + from websockets.sync.server import WebSocketServer + self.assertIs(WebSocketServer, Server) From 09b1d8d4d585ed6d4e2c0db6e200e48f176215b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 22:26:56 +0200 Subject: [PATCH 100/109] Fix tests on Python < 3.10. --- tests/asyncio/test_connection.py | 17 +++++++++++++++++ tests/legacy/utils.py | 31 ++++++++++++++++--------------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 52e4fc5c8..29bb00418 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -44,6 +44,23 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5bb56b26f..5b56050d5 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -56,21 +56,22 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - # Remove when dropping Python < 3.10 - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ From 8eaa5a26b667fccb1b9d75034a77b2d6906e2b2e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 07:49:47 +0200 Subject: [PATCH 101/109] Document & test process_response modifying the response. a78b5546 inadvertently changed the test from "returning a new response" to "modifying the existing response". Both are supported.. --- src/websockets/asyncio/server.py | 11 +++--- src/websockets/sync/server.py | 13 ++++--- tests/asyncio/test_server.py | 65 +++++++++++++++++++++----------- tests/sync/test_server.py | 33 ++++++++++------ 4 files changed, 78 insertions(+), 44 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 8ebbddb67..8f04ec318 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -236,15 +236,16 @@ class Server: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. process_request: Intercept the request during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to + Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. ``process_request`` may be a function or a coroutine. process_response: Intercept the response during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, the - handshake is successful. Else, the connection is aborted. - ``process_response`` may be a function or a coroutine. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. ``process_response`` may be a function or a + coroutine. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 85a7e9907..86c162af3 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -401,13 +401,14 @@ def handler(websocket): :meth:`ServerProtocol.select_subprotocol ` method. process_request: Intercept the request during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, - the handshake is successful. Else, the connection is aborted. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. process_response: Intercept the response during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, - the handshake is successful. Else, the connection is aborted. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 4b637f3af..b899998f4 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import http import logging import socket @@ -117,8 +118,8 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) - async def test_process_request(self): - """Server runs process_request before processing the handshake.""" + async def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) @@ -128,8 +129,8 @@ def process_request(ws, request): async with run_client(server) as client: await self.assertEval(client, "ws.process_request_ran", "True") - async def test_async_process_request(self): - """Server runs async process_request before processing the handshake.""" + async def test_async_process_request_returns_none(self): + """Server runs async process_request and continues the handshake.""" async def process_request(ws, request): self.assertIsInstance(request, Request) @@ -139,7 +140,7 @@ async def process_request(ws, request): async with run_client(server) as client: await self.assertEval(client, "ws.process_request_ran", "True") - async def test_process_request_abort_handshake(self): + async def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): @@ -154,7 +155,7 @@ def process_request(ws, request): "server rejected WebSocket connection: HTTP 403", ) - async def test_async_process_request_abort_handshake(self): + async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): @@ -199,8 +200,8 @@ async def process_request(ws, request): "server rejected WebSocket connection: HTTP 500", ) - async def test_process_response(self): - """Server runs process_response after processing the handshake.""" + async def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -211,8 +212,8 @@ def process_response(ws, request, response): async with run_client(server) as client: await self.assertEval(client, "ws.process_response_ran", "True") - async def test_async_process_response(self): - """Server runs async process_response after processing the handshake.""" + async def test_async_process_response_returns_none(self): + """Server runs async process_response but keeps the handshake response.""" async def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -223,29 +224,49 @@ async def process_response(ws, request, response): async with run_client(server) as client: await self.assertEval(client, "ws.process_response_ran", "True") - async def test_process_response_override_response(self): - """Server runs process_response and overrides the handshake response.""" + async def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: async with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") - async def test_async_process_response_override_response(self): - """Server runs async process_response and overrides the handshake response.""" + async def test_async_process_response_modifies_response(self): + """Server runs async process_response and modifies the handshake response.""" async def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: async with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_replaces_response(self): + """Server runs async process_response and replaces the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 315601eca..e3dfeb271 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,3 +1,4 @@ +import dataclasses import http import logging import socket @@ -115,8 +116,8 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) - def test_process_request(self): - """Server runs process_request before processing the handshake.""" + def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) @@ -126,7 +127,7 @@ def process_request(ws, request): with run_client(server) as client: self.assertEval(client, "ws.process_request_ran", "True") - def test_process_request_abort_handshake(self): + def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): @@ -156,8 +157,8 @@ def process_request(ws, request): "server rejected WebSocket connection: HTTP 500", ) - def test_process_response(self): - """Server runs process_response after processing the handshake.""" + def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -168,17 +169,27 @@ def process_response(ws, request, response): with run_client(server) as client: self.assertEval(client, "ws.process_response_ran", "True") - def test_process_response_override_response(self): - """Server runs process_response and overrides the handshake response.""" + def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" with run_server(process_response=process_response) as server: with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" From 9e5b91bf8f9039de0af85e597e7fd643cfd1a139 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 08:26:44 +0200 Subject: [PATCH 102/109] Improve documentation of latency. Also fix #1414. --- docs/reference/features.rst | 2 ++ docs/topics/keepalive.rst | 28 ++++++++++++++++++++-------- src/websockets/asyncio/connection.py | 9 +++++---- src/websockets/legacy/protocol.py | 9 +++++---- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index cb0e564f9..a380f4555 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -59,6 +59,8 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ + | Measure latency | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 1c7a43264..91f11fb11 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -40,13 +40,16 @@ It loops through these steps: If the Pong frame isn't received, websockets considers the connection broken and closes it. -This mechanism serves two purposes: +This mechanism serves three purposes: 1. It creates a trickle of traffic so that the TCP connection isn't idle and network infrastructure along the path keeps it open ("keepalive"). 2. It detects if the connection drops or becomes so slow that it's unusable in practice ("heartbeat"). In that case, it terminates the connection and your application gets a :exc:`~exceptions.ConnectionClosed` exception. +3. It measures the :attr:`~asyncio.connection.Connection.latency` of the + connection. The time between sending a Ping frame and receiving a matching + Pong frame approximates the round-trip time. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. @@ -54,7 +57,7 @@ Shorter values will detect connection drops faster but they will increase network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and -heartbeat mechanism. +heartbeat mechanism, including measurement of latency. Setting ``ping_timeout`` to :obj:`None` disables only timeouts. This enables keepalive, to keep idle connections open, and disables heartbeat, to support large @@ -85,9 +88,23 @@ Unfortunately, the WebSocket API in browsers doesn't expose the native Ping and Pong functionality in the WebSocket protocol. You have to roll your own in the application layer. +Read this `blog post `_ for +a complete walk-through of this issue. + Latency issues -------------- +The :attr:`~asyncio.connection.Connection.latency` attribute stores latency +measured during the last exchange of Ping and Pong frames:: + + latency = websocket.latency + +Alternatively, you can measure the latency at any time by calling +:attr:`~asyncio.connection.Connection.ping` and awaiting its result:: + + pong_waiter = await websocket.ping() + latency = await pong_waiter + Latency between a client and a server may increase for two reasons: * Network connectivity is poor. When network packets are lost, TCP attempts to @@ -97,7 +114,7 @@ Latency between a client and a server may increase for two reasons: * Traffic is high. For example, if a client sends messages on the connection faster than a server can process them, this manifests as latency as well, - because data is waiting in flight, mostly in OS buffers. + because data is waiting in :doc:`buffers `. If the server is more than 20 seconds behind, it doesn't see the Pong before the default timeout elapses. As a consequence, it closes the connection. @@ -109,8 +126,3 @@ Latency between a client and a server may increase for two reasons: The same reasoning applies to situations where the server sends more traffic than the client can accept. - -The latency measured during the last exchange of Ping and Pong frames is -available in the :attr:`~asyncio.connection.Connection.latency` attribute. -Alternatively, you can measure the latency at any time with the -:attr:`~asyncio.connection.Connection.ping` method. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index a6b909c72..9e7ea3d8c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -109,12 +109,13 @@ def __init__( """ Latency of the connection, in seconds. - This value is updated after sending a ping frame and receiving a - matching pong frame. Before the first ping, :attr:`latency` is ``0``. + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism - that sends ping frames automatically at regular intervals. You can also - send ping frames and measure latency with :meth:`ping`. + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ # Task that sends keepalive pings. None when ping_interval is None. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index e83e146f9..3b9a8c4aa 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -289,12 +289,13 @@ def __init__( """ Latency of the connection, in seconds. - This value is updated after sending a ping frame and receiving a - matching pong frame. Before the first ping, :attr:`latency` is ``0``. + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism - that sends ping frames automatically at regular intervals. You can also - send ping frames and measure latency with :meth:`ping`. + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ # Task running the data transfer. From 453e55ac2a20a50bfd30c0b4c011c50d01e7bb0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 09:52:31 +0200 Subject: [PATCH 103/109] Standardize on raise AssertionError(...). --- experiments/optimization/parse_frames.py | 2 +- experiments/optimization/parse_handshake.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/optimization/parse_frames.py b/experiments/optimization/parse_frames.py index e3acbe3c2..9ea71c58e 100644 --- a/experiments/optimization/parse_frames.py +++ b/experiments/optimization/parse_frames.py @@ -33,7 +33,7 @@ def parse_frame(data, count, mask, extensions): except StopIteration: pass else: - assert False, "parser should return frame" + raise AssertionError("parser should return frame") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" diff --git a/experiments/optimization/parse_handshake.py b/experiments/optimization/parse_handshake.py index af5a4ecae..393e0215c 100644 --- a/experiments/optimization/parse_handshake.py +++ b/experiments/optimization/parse_handshake.py @@ -71,7 +71,7 @@ def parse_handshake(handshake): except StopIteration: pass else: - assert False, "parser should return request" + raise AssertionError("parser should return request") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" From 9d355bfeb784e41b2a879645113c73c4560c9a91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 09:53:05 +0200 Subject: [PATCH 104/109] Remove unnecessary code paths in keepalive(). Also add comments in tests to clarify the intended sequence. --- src/websockets/asyncio/connection.py | 14 +++++--- tests/asyncio/test_connection.py | 50 ++++++++++++++++++---------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9e7ea3d8c..005e9b4bb 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -723,6 +723,10 @@ async def keepalive(self) -> None: if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # connection_lost cancels keepalive immediately + # after setting a ConnectionClosed exception on + # pong_waiter. A CancelledError is raised here, + # not a ConnectionClosed exception. latency = await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: @@ -733,9 +737,10 @@ async def keepalive(self) -> None: CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) - break - except ConnectionClosed: - pass + raise AssertionError( + "send_context() should wait for connection_lost(), " + "which cancels keepalive()" + ) except Exception: self.logger.error("keepalive ping failed", exc_info=True) @@ -913,8 +918,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.set_recv_exc(exc) self.recv_messages.close() self.abort_pings() - # If keepalive() was waiting for a pong, abort_pings() terminated it. - # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: self.keepalive_task.cancel() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 29bb00418..59218de4b 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -890,12 +890,25 @@ async def test_pong_explicit_binary(self): @patch("random.getrandbits") async def test_keepalive(self, getrandbits): - """keepalive sends pings.""" + """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS getrandbits.return_value = 1918987876 self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. await asyncio.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() @patch("random.getrandbits") async def test_keepalive_times_out(self, getrandbits): @@ -905,13 +918,14 @@ async def test_keepalive_times_out(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") - ) + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits") async def test_keepalive_ignores_timeout(self, getrandbits): @@ -921,18 +935,14 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertNoFrameSent() - - async def test_disable_keepalive(self): - """keepalive is disabled when ping_interval is None.""" - self.connection.ping_interval = None - self.connection.start_keepalive() - await asyncio.sleep(3 * MS) - await self.assertNoFrameSent() + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" @@ -945,21 +955,27 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = 2 * MS + self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - # Inject a fault by raising an exception in a pending pong waiter. async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: pong_waiter.set_exception(Exception("BOOM")) From 12fa8bc8fcc03a120ceb05700905b6e4698df563 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:24:25 +0200 Subject: [PATCH 105/109] Complete changelog with changes since 12.0. --- docs/project/changelog.rst | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f4ae76702..06d8a7774 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,11 @@ notice. Backwards-incompatible changes .............................. +.. admonition:: websockets 13.0 requires Python ≥ 3.8. + :class: tip + + websockets 12.0 is the last version supporting Python 3.7. + .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -64,9 +69,8 @@ Backwards-incompatible changes implementation is renamed to :class:`~sync.server.Server`. :class: note - This class isn't designed to be imported or instantiated directly. - :func:`~sync.server.serve` returns an instance. For this reason, - the change should be transparent. + This change should be transparent because this class shouldn't be + instantiated directly; :func:`~sync.server.serve` returns an instance. Regardless, an alias provides backwards compatibility. @@ -91,6 +95,22 @@ New features If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. +Improvements +............ + +* The error message in server logs when a header is too long is more explicit. + +Bug fixes +......... + +* Fixed a bug in the :mod:`threading` implementation that could prevent the + program from exiting when a connection wasn't closed properly. + +* Redirecting from a ``ws://`` URI to a ``wss://`` URI now works. + +* ``broadcast(raise_exceptions=True)`` no longer crashes when there isn't any + exception. + .. _12.0: 12.0 From 0019943e551d285ec27c29315a23dcc959a2ec29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:39:43 +0200 Subject: [PATCH 106/109] Release version 13.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 06d8a7774..7c5998288 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 13.0 ---- -*In development* +*August 20, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 44709a91b..56c321940 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True -tag = version = commit = "12.1" +tag = version = commit = "13.0" if not released: # pragma: no cover From 4d0e0e10c6ebb779780a9e590667661381df78dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:55:32 +0200 Subject: [PATCH 107/109] Build sdist and arch-independent wheel with build. This removes the dependency on setuptools, which isn't installed by default anymore, causing the build to fail. --- .github/workflows/release.yml | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ed52ddd80..184444e56 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,22 +17,18 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.x - - name: Build sdist - run: python setup.py sdist - - name: Save sdist - uses: actions/upload-artifact@v4 - with: - path: dist/*.tar.gz - - name: Install wheel - run: pip install wheel - - name: Build wheel + - name: Install build + run: pip install build + - name: Build sdist & wheel + run: python -m build env: BUILD_EXTENSION: no - run: python setup.py bdist_wheel - - name: Save wheel + - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: - path: dist/*.whl + path: | + dist/*.tar.gz + dist/*.whl wheels: name: Build architecture-specific wheels on ${{ matrix.os }} From f9c20d0e4c9a25b66d4643879bc4594137036793 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 15:06:07 +0200 Subject: [PATCH 108/109] Avoid deleting .so files in .direnv or equivalent. --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index dacfe2a0b..a69248b6e 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,7 @@ build: python setup.py build_ext --inplace clean: - find . -name '*.pyc' -delete -o -name '*.so' -delete + find src -name '*.so' -delete + find . -name '*.pyc' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 323adef1f3000cf07617d7ee649c27c0801126e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 15:19:28 +0200 Subject: [PATCH 109/109] Migrate to actions/upload-artifact@v4. The version number was increased without accounting for backwards-incompatible changes. --- .github/workflows/release.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 184444e56..4d2b5b75e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,7 @@ jobs: - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: + name: dist-architecture-independent path: | dist/*.tar.gz dist/*.whl @@ -58,6 +59,7 @@ jobs: - name: Save wheels uses: actions/upload-artifact@v4 with: + name: dist-${{ matrix.os }} path: wheelhouse/*.whl upload: @@ -74,7 +76,8 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4 with: - name: artifact + pattern: dist-* + merge-multiple: true path: dist - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1