diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index c30fc6679f..4bb927d995 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -135,7 +135,9 @@ from __future__ import annotations +import asyncio import collections +import random import time import uuid from collections.abc import Mapping as _Mapping @@ -471,6 +473,8 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 +_BACKOFF_MAX = 1 +_BACKOFF_INITIAL = 0.050 # 50ms initial backoff def _within_time_limit(start_time: float) -> bool: @@ -700,7 +704,13 @@ async def callback(session, custom_arg, custom_kwarg=None): https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback """ start_time = time.monotonic() + retry = 0 while True: + if retry: # Implement exponential backoff on retry. + jitter = random.random() # noqa: S311 + backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX) + await asyncio.sleep(backoff) + retry += 1 await self.start_transaction( read_concern, write_concern, read_preference, max_commit_time_ms ) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 741c11e551..6ff62e9fe3 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -58,6 +58,7 @@ AsyncCursor, AsyncRawBatchCursor, ) +from pymongo.asynchronous.helpers import _retry_overload from pymongo.collation import validate_collation_or_none from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.errors import ( @@ -252,6 +253,7 @@ def __init__( unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: @@ -2227,6 +2229,7 @@ async def create_indexes( return await self._create_indexes(indexes, session, **kwargs) @_csot.apply + @_retry_overload async def _create_indexes( self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any ) -> list[str]: @@ -2422,7 +2425,6 @@ async def drop_indexes( kwargs["comment"] = comment await self._drop_index("*", session=session, **kwargs) - @_csot.apply async def drop_index( self, index_or_name: _IndexKeyHint, @@ -2472,6 +2474,7 @@ async def drop_index( await self._drop_index(index_or_name, session, comment, **kwargs) @_csot.apply + @_retry_overload async def _drop_index( self, index_or_name: _IndexKeyHint, @@ -3079,6 +3082,7 @@ async def aggregate_raw_batches( ) @_csot.apply + @_retry_overload async def rename( self, new_name: str, diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index f70c2b403f..8abc7059d0 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -38,6 +38,7 @@ from pymongo.asynchronous.change_stream import AsyncDatabaseChangeStream from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import _retry_overload from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation @@ -135,6 +136,7 
@@ def __init__( self._name = name self._client: AsyncMongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> AsyncMongoClient[_DocumentType]: @@ -477,6 +479,7 @@ async def watch( return change_stream @_csot.apply + @_retry_overload async def create_collection( self, name: str, @@ -816,6 +819,7 @@ async def command( ... @_csot.apply + @_retry_overload async def command( self, command: Union[str, MutableMapping[str, Any]], @@ -947,6 +951,7 @@ async def command( ) @_csot.apply + @_retry_overload async def cursor_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1264,6 +1269,7 @@ async def _drop_helper( ) @_csot.apply + @_retry_overload async def drop_collection( self, name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 54fd64f74a..6ef3beacf5 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -17,8 +17,11 @@ import asyncio import builtins +import functools +import random import socket import sys +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -26,10 +29,13 @@ cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, + PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _async_create_lock _IS_SYNC = False @@ -38,6 +44,7 @@ def _handle_reauth(func: F) -> F: + @functools.wraps(func) async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.asynchronous.pool import AsyncConnection @@ -70,6 +77,123 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +_MAX_RETRIES = 3 +_BACKOFF_INITIAL = 0.05 +_BACKOFF_MAX = 10 +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 + + +def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX +) -> float: + jitter = random.random() # noqa: S311 + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _async_create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + async def consume(self) -> bool: + """Consume a token from the bucket if available.""" + async with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + async def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + async with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. 
+ """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + async def record_success(self, retry: bool) -> None: + """Record a successful operation.""" + await self.token_bucket.deposit(retry) + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given .""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + async def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have budget to retry and how long to backoff.""" + if attempt > self.attempts: + return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + # Check token bucket last since we only want to consume a token if we actually retry. + if not await self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True + + +def _retry_overload(func: F) -> F: + @functools.wraps(func) + async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy + attempt = 0 + while True: + try: + res = await func(self, *args, **kwargs) + await retry_policy.record_success(retry=attempt > 0) + return res + except PyMongoError as exc: + if not exc.has_error_label("Retryable"): + raise + attempt += 1 + delay = 0 + if exc.has_error_label("SystemOverloaded"): + delay = retry_policy.backoff(attempt) + if not await retry_policy.should_retry(attempt, delay): + raise + + # Implement exponential backoff on retry. + if delay: + await asyncio.sleep(delay) + continue + + return cast(F, inner) + + async def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b616647791..d9994e9902 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -67,6 +68,11 @@ from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -773,6 +779,7 @@ def __init__( self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. 
@@ -2398,6 +2405,7 @@ async def list_database_names( return [doc["name"] async for doc in res] @_csot.apply + @_retry_overload async def drop_database( self, name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]], @@ -2735,9 +2743,10 @@ def __init__( ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2772,7 +2781,9 @@ async def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return await self._read() if self._is_read else await self._write() + res = await self._read() if self._is_read else await self._write() + await self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2783,14 +2794,22 @@ async def run(self) -> T: # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + always_retryable = exc.has_error_label("Retryable") + overloaded = exc.has_error_label("SystemOverloaded") + if not always_retryable and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) ): raise self._retrying = True @@ -2801,19 +2820,22 @@ async def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + always_retryable = exc_to_check.has_error_label("Retryable") + overloaded = exc_to_check.has_error_label("SystemOverloaded") + if not self._retryable and not always_retryable: raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session await self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2822,7 +2844,7 @@ async def run(self) -> T: self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc @@ -2830,6 +2852,17 @@ async def run(self) -> T: if 
self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded: self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if always_retryable: + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not await self._retry_policy.should_retry(self._attempt_number, delay): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if overloaded: + await asyncio.sleep(delay) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2891,7 +2924,7 @@ async def _write(self) -> T: and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() @@ -2923,7 +2956,7 @@ async def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 68a01dd7e7..a8f03fac74 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -136,6 +136,7 @@ from __future__ import annotations import collections +import random import time import uuid from collections.abc import Mapping as _Mapping @@ -470,6 +471,8 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 +_BACKOFF_MAX = 1 +_BACKOFF_INITIAL = 0.050 # 50ms initial backoff def _within_time_limit(start_time: float) -> bool: @@ -699,7 +702,13 @@ def callback(session, custom_arg, custom_kwarg=None): https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback """ start_time = time.monotonic() + retry = 0 while True: + if retry: # Implement exponential backoff on retry. 
+ jitter = random.random() # noqa: S311 + backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX) + time.sleep(backoff) + retry += 1 self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 9f32deb765..324139d40a 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -89,6 +89,7 @@ Cursor, RawBatchCursor, ) +from pymongo.synchronous.helpers import _retry_overload from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean @@ -255,6 +256,7 @@ def __init__( unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: @@ -2224,6 +2226,7 @@ def create_indexes( return self._create_indexes(indexes, session, **kwargs) @_csot.apply + @_retry_overload def _create_indexes( self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any ) -> list[str]: @@ -2419,7 +2422,6 @@ def drop_indexes( kwargs["comment"] = comment self._drop_index("*", session=session, **kwargs) - @_csot.apply def drop_index( self, index_or_name: _IndexKeyHint, @@ -2469,6 +2471,7 @@ def drop_index( self._drop_index(index_or_name, session, comment, **kwargs) @_csot.apply + @_retry_overload def _drop_index( self, index_or_name: _IndexKeyHint, @@ -3072,6 +3075,7 @@ def aggregate_raw_batches( ) @_csot.apply + @_retry_overload def rename( self, new_name: str, diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index e30f97817c..62f8f69067 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -43,6 +43,7 @@ from pymongo.synchronous.change_stream import DatabaseChangeStream from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import _retry_overload from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: @@ -135,6 +136,7 @@ def __init__( self._name = name self._client: MongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> MongoClient[_DocumentType]: @@ -477,6 +479,7 @@ def watch( return change_stream @_csot.apply + @_retry_overload def create_collection( self, name: str, @@ -816,6 +819,7 @@ def command( ... 
@_csot.apply + @_retry_overload def command( self, command: Union[str, MutableMapping[str, Any]], @@ -945,6 +949,7 @@ def command( ) @_csot.apply + @_retry_overload def cursor_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1257,6 +1262,7 @@ def _drop_helper( ) @_csot.apply + @_retry_overload def drop_collection( self, name_or_collection: Union[str, Collection[_DocumentTypeArg]], diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index bc69a49e80..0a2cd71062 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -17,8 +17,11 @@ import asyncio import builtins +import functools +import random import socket import sys +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -26,10 +29,13 @@ cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, + PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _create_lock _IS_SYNC = True @@ -38,6 +44,7 @@ def _handle_reauth(func: F) -> F: + @functools.wraps(func) def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.message import _BulkWriteContext @@ -70,6 +77,123 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +_MAX_RETRIES = 3 +_BACKOFF_INITIAL = 0.05 +_BACKOFF_MAX = 10 +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 + + +def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX +) -> float: + jitter = random.random() # noqa: S311 + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + def consume(self) -> bool: + """Consume a token from the bucket if available.""" + with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. 
+ """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + def record_success(self, retry: bool) -> None: + """Record a successful operation.""" + self.token_bucket.deposit(retry) + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given .""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have budget to retry and how long to backoff.""" + if attempt > self.attempts: + return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + # Check token bucket last since we only want to consume a token if we actually retry. + if not self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True + + +def _retry_overload(func: F) -> F: + @functools.wraps(func) + def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy + attempt = 0 + while True: + try: + res = func(self, *args, **kwargs) + retry_policy.record_success(retry=attempt > 0) + return res + except PyMongoError as exc: + if not exc.has_error_label("Retryable"): + raise + attempt += 1 + delay = 0 + if exc.has_error_label("SystemOverloaded"): + delay = retry_policy.backoff(attempt) + if not retry_policy.should_retry(attempt, delay): + raise + + # Implement exponential backoff on retry. + if delay: + time.sleep(delay) + continue + + return cast(F, inner) + + def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index ef0663584c..9beda807ef 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -110,6 +111,11 @@ from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -773,6 +779,7 @@ def __init__( self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. 
@@ -2388,6 +2395,7 @@ def list_database_names( return [doc["name"] for doc in res] @_csot.apply + @_retry_overload def drop_database( self, name_or_database: Union[str, database.Database[_DocumentTypeArg]], @@ -2725,9 +2733,10 @@ def __init__( ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2762,7 +2771,9 @@ def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return self._read() if self._is_read else self._write() + res = self._read() if self._is_read else self._write() + self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2773,14 +2784,22 @@ def run(self) -> T: # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + always_retryable = exc.has_error_label("Retryable") + overloaded = exc.has_error_label("SystemOverloaded") + if not always_retryable and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) ): raise self._retrying = True @@ -2791,19 +2810,22 @@ def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + always_retryable = exc_to_check.has_error_label("Retryable") + overloaded = exc_to_check.has_error_label("SystemOverloaded") + if not self._retryable and not always_retryable: raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2812,7 +2834,7 @@ def run(self) -> T: self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc @@ -2820,6 +2842,17 @@ def run(self) -> T: if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded: self._deprioritized_servers.append(self._server) + 
self._always_retryable = always_retryable + if always_retryable: + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not self._retry_policy.should_retry(self._attempt_number, delay): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if overloaded: + time.sleep(delay) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2881,7 +2914,7 @@ def _write(self) -> T: and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() @@ -2913,7 +2946,7 @@ def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py new file mode 100644 index 0000000000..598236dbfe --- /dev/null +++ b/test/asynchronous/test_backpressure.py @@ -0,0 +1,230 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import asyncio +import sys + +import pymongo + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + unittest, +) + +from pymongo.asynchronous import helpers +from pymongo.asynchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket +from pymongo.errors import PyMongoError + +_IS_SYNC = False + +# Mock a system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, +} + + +class TestBackpressure(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_command(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.command("find", "t") + + # Ensure command stops retrying after _MAX_RETRIES.
+ fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_find(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.find_one() + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_insert_one(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.insert_one({"x": 2}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.insert_one({"x": 2}) + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "Retryable" error label. + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_getMore(self): + coll = self.db.t + await coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": _MAX_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, + } + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_many): + await cursor.to_list() + + # Ensure command stops retrying after _MAX_RETRIES.
+ fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await cursor.to_list() + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_limit_retry_command(self): + client = await self.async_rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + await db.t.insert_one({"x": 1}) + + # Ensure command is retried once on an overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + async with self.fail_point(fail_many): + await db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + +class TestRetryPolicy(AsyncPyMongoTestCase): + async def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(await retry_policy.should_retry(i, 0)) + self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(await retry_policy.should_retry(1, 0)) + # No tokens left, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 0)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + await retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) + + # Recording a successful retry should return 1 additional token. + await retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + async def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(await retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertTrue(await retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 1.0)) + self.assertTrue(await retry_policy.should_retry(1, 1.0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py new file mode 100644 index 0000000000..182ce424a9 --- /dev/null +++ b/test/test_backpressure.py @@ -0,0 +1,230 @@ +# Copyright 2025-present MongoDB, Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import asyncio +import sys + +import pymongo + +sys.path[0:0] = [""] + +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) + +from pymongo.errors import PyMongoError +from pymongo.synchronous import helpers +from pymongo.synchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket + +_IS_SYNC = True + +# Mock a system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, +} + + +class TestBackpressure(IntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @client_context.require_failCommand_appName + def test_retry_overload_error_command(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.command("find", "t") + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_find(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.find_one() + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_insert_one(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.insert_one({"x": 2}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.insert_one({"x": 2}) + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "Retryable" error label.
+ self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_getMore(self): + coll = self.db.t + coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": _MAX_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, + } + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_many): + cursor.to_list() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + cursor.to_list() + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_limit_retry_command(self): + client = self.rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + db.t.insert_one({"x": 1}) + + # Ensure command is retried once on an overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + with self.fail_point(fail_many): + db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + +class TestRetryPolicy(PyMongoTestCase): + def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(retry_policy.should_retry(i, 0)) + self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(retry_policy.should_retry(1, 0)) + # No tokens left, should not retry. + self.assertFalse(retry_policy.should_retry(1, 0)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) + + # Recording a successful retry should return 1 additional token.
+ retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertTrue(retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(retry_policy.should_retry(1, 1.0)) + self.assertTrue(retry_policy.should_retry(1, 1.0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 9a760c0ad7..44698134cd 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -208,6 +208,7 @@ def async_only_test(f: str) -> bool: "test_auth_oidc.py", "test_auth_spec.py", + "test_backpressure.py", "test_bulk.py", "test_change_stream.py", "test_client.py", "test_client_bulk_write.py",
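For reviewers who want to poke at the new behavior by hand, here is a hypothetical smoke test (not part of the patch) modeled on the test files above. It assumes a standalone mongod started with `--setParameter enableTestCommands=1` on localhost; the fail point, error code, and labels are the same ones the new tests use:

```python
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")
client.admin.command(
    {
        "configureFailPoint": "failCommand",
        "mode": {"times": 2},
        "data": {
            "failCommands": ["find"],
            "errorCode": 462,  # IngressRequestRateLimitExceeded
            "errorLabels": ["Retryable", "SystemOverloaded"],
        },
    }
)
try:
    # With this patch applied, the two induced failures are retried with
    # jittered backoff and the find ultimately succeeds.
    print(client.pymongo_test.t.find_one())
finally:
    client.admin.command({"configureFailPoint": "failCommand", "mode": "off"})
```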