From 9589b1f9c935aeac42dac5e081ffcbd021114ff0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 31 Oct 2024 16:16:32 -0400 Subject: [PATCH 1/5] PYTHON-4927 - Add missing CSOT prose tests --- pymongo/_csot.py | 20 ++++- pymongo/asynchronous/client_session.py | 12 ++- pymongo/asynchronous/topology.py | 3 +- pymongo/synchronous/client_session.py | 12 ++- pymongo/synchronous/topology.py | 3 +- test/asynchronous/test_client.py | 47 ++++++++++- test/asynchronous/test_collection.py | 26 +++++++ test/asynchronous/test_encryption.py | 103 +++++++++++++++++++++++++ test/asynchronous/test_session.py | 2 +- test/asynchronous/test_transactions.py | 37 +++++++++ test/test_client.py | 47 ++++++++++- test/test_collection.py | 26 +++++++ test/test_csot.py | 9 ++- test/test_encryption.py | 103 +++++++++++++++++++++++++ test/test_gridfs_bucket.py | 74 ++++++++++++++++++ test/test_server_selection.py | 93 +++++++++++++++++++++- test/test_session.py | 2 +- test/test_transactions.py | 37 +++++++++ test/utils.py | 4 + 19 files changed, 645 insertions(+), 15 deletions(-) diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 06c6b68ac9..fb77a69ca1 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -16,13 +16,24 @@ from __future__ import annotations +import contextlib import functools import inspect import time from collections import deque from contextlib import AbstractContextManager from contextvars import ContextVar, Token -from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Generator, + MutableMapping, + Optional, + TypeVar, + cast, +) if TYPE_CHECKING: from pymongo.write_concern import WriteConcern @@ -54,6 +65,13 @@ def remaining() -> Optional[float]: return DEADLINE.get() - time.monotonic() +@contextlib.contextmanager +def reset() -> Generator: + deadline_token = DEADLINE.set(DEADLINE.get() + get_timeout()) # type: ignore[operator] + yield + DEADLINE.reset(deadline_token) + + def clamp_remaining(max_timeout: float) -> float: """Return the remaining timeout clamped to a max value.""" timeout = remaining() diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index d80495d804..e3a7d92fb1 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -473,7 +473,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: def _within_time_limit(start_time: float) -> bool: """Are we within the with_transaction retry limit?""" - return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + timeout = _csot.get_timeout() + if timeout: + return time.monotonic() - start_time < timeout + else: + return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT _T = TypeVar("_T") @@ -512,6 +516,7 @@ def __init__( # Is this an implicitly created session? self._implicit = implicit self._transaction = _Transaction(None, client) + self._timeout = client.options.timeout async def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. @@ -597,6 +602,7 @@ def _inherit_option(self, name: str, val: _T) -> _T: return parent_val return getattr(self.client, name) + @_csot.apply async def with_transaction( self, callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]], @@ -697,7 +703,8 @@ async def callback(session, custom_arg, custom_kwarg=None): ret = await callback(self) except Exception as exc: if self.in_transaction: - await self.abort_transaction() + with _csot.reset(): + await self.abort_transaction() if ( isinstance(exc, PyMongoError) and exc.has_error_label("TransientTransactionError") @@ -816,6 +823,7 @@ async def commit_transaction(self) -> None: finally: self._transaction.state = _TxnState.COMMITTED + @_csot.apply async def abort_transaction(self) -> None: """Abort a multi-statement transaction. diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 82af4257ba..a51dd8a98f 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float: timeout = _csot.remaining() if timeout is None: return self._settings.server_selection_timeout - return timeout + else: + return min(timeout, self._settings.server_selection_timeout) async def select_servers( self, diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index f1d680fc0a..ade46a9799 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -472,7 +472,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: def _within_time_limit(start_time: float) -> bool: """Are we within the with_transaction retry limit?""" - return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + timeout = _csot.get_timeout() + if timeout: + return time.monotonic() - start_time < timeout + else: + return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT _T = TypeVar("_T") @@ -511,6 +515,7 @@ def __init__( # Is this an implicitly created session? self._implicit = implicit self._transaction = _Transaction(None, client) + self._timeout = client.options.timeout def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. @@ -596,6 +601,7 @@ def _inherit_option(self, name: str, val: _T) -> _T: return parent_val return getattr(self.client, name) + @_csot.apply def with_transaction( self, callback: Callable[[ClientSession], _T], @@ -694,7 +700,8 @@ def callback(session, custom_arg, custom_kwarg=None): ret = callback(self) except Exception as exc: if self.in_transaction: - self.abort_transaction() + with _csot.reset(): + self.abort_transaction() if ( isinstance(exc, PyMongoError) and exc.has_error_label("TransientTransactionError") @@ -813,6 +820,7 @@ def commit_transaction(self) -> None: finally: self._transaction.state = _TxnState.COMMITTED + @_csot.apply def abort_transaction(self) -> None: """Abort a multi-statement transaction. diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index a350c1702e..9e5284c9d0 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float: timeout = _csot.remaining() if timeout is None: return self._settings.server_selection_timeout - return timeout + else: + return min(timeout, self._settings.server_selection_timeout) def select_servers( self, diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 590154b857..6015aa452e 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -64,6 +64,7 @@ from test.utils import ( NTHREADS, CMAPListener, + EventListener, FunctionCallRecorder, async_get_pool, async_wait_until, @@ -114,7 +115,13 @@ ServerSelectionTimeoutError, WriteConcernError, ) -from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent +from pymongo.monitoring import ( + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + ServerHeartbeatListener, + ServerHeartbeatStartedEvent, +) from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription @@ -2585,5 +2592,43 @@ async def test_direct_client_maintains_pool_to_arbiter(self): self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1) +# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#4-background-connection-pooling +class TestClientCSOTProse(AsyncIntegrationTest): + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command + @async_client_context.require_auth + @async_client_context.require_version_min(4, 4, -1) + async def test_02_timeoutMS_refreshed_for_each_handshake_command(self): + listener = CMAPListener() + + async with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "failCommands": ["hello", "isMaster", "saslContinue"], + "blockConnection": True, + "blockTimeMS": 15, + "appName": "refreshTimeoutBackgroundPoolTest", + }, + } + ): + _ = await self.async_single_client( + minPoolSize=1, + timeoutMS=20, + appname="refreshTimeoutBackgroundPoolTest", + event_listeners=[listener], + ) + + async def predicate(): + return ( + listener.event_count(ConnectionCreatedEvent) == 1 + and listener.event_count(ConnectionReadyEvent) == 1 + ) + + await async_wait_until( + predicate, + "didn't ever see a ConnectionCreatedEvent and a ConnectionReadyEvent", + ) + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index db52bad4ac..f048b9acd1 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -64,6 +64,7 @@ InvalidDocument, InvalidName, InvalidOperation, + NetworkTimeout, OperationFailure, WriteConcernError, ) @@ -2277,6 +2278,31 @@ async def afind(*args, **kwargs): for helper, args in helpers: await helper(*args, let={}) # type: ignore + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts + @async_client_context.require_standalone + @async_client_context.require_version_min(4, 4, -1) + async def test_01_multi_batch_inserts(self): + client = await self.async_single_client(read_preference=ReadPreference.PRIMARY_PREFERRED) + await client.db.coll.drop() + + async with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 1010}, + } + ): + listener = OvertCommandListener() + client2 = await self.async_single_client( + timeoutMS=2000, + read_preference=ReadPreference.PRIMARY_PREFERRED, + event_listeners=[listener], + ) + docs = [{"a": "b" * 1000000} for _ in range(50)] + with self.assertRaises(NetworkTimeout): + await client2.db.coll.insert_many(docs) + + self.assertEqual(2, len(listener.started_events)) + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 40f1acd32d..c1a4ebcac0 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -86,6 +86,7 @@ EncryptedCollectionError, EncryptionError, InvalidOperation, + NetworkTimeout, OperationFailure, ServerSelectionTimeoutError, WriteError, @@ -3133,5 +3134,107 @@ async def test_explicit_session_errors_when_unsupported(self): await self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s) +# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#3-clientencryption +class TestCSOTProse(AsyncEncryptionIntegrationTest): + mongocryptd_client: AsyncMongoClient + MONGOCRYPTD_PORT = 27020 + LOCAL_MASTERKEY = Binary( + base64.b64decode( + b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" + ), + UUID_SUBTYPE, + ) + + async def asyncSetUp(self) -> None: + self.listener = OvertCommandListener() + self.client = await self.async_single_client( + read_preference=ReadPreference.PRIMARY_PREFERRED, event_listeners=[self.listener] + ) + await self.client.keyvault.datakeys.drop() + self.key_vault_client = await self.async_rs_or_single_client( + timeoutMS=50, event_listeners=[self.listener] + ) + self.client_encryption = self.create_client_encryption( + key_vault_namespace="keyvault.datakeys", + kms_providers={"local": {"key": self.LOCAL_MASTERKEY}}, + key_vault_client=self.key_vault_client, + codec_options=OPTS, + ) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey + async def test_01_create_data_key(self): + async with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 100}, + } + ): + self.listener.reset() + with self.assertRaisesRegex(EncryptionError, "timed out"): + await self.client_encryption.create_data_key("local") + + events = self.listener.started_events + self.assertEqual(1, len(events)) + self.assertEqual("insert", events[0].command_name) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt + async def test_02_encrypt(self): + data_key_id = await self.client_encryption.create_data_key("local") + self.assertEqual(4, data_key_id.subtype) + async with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100}, + } + ): + self.listener.reset() + with self.assertRaisesRegex(EncryptionError, "timed out"): + await self.client_encryption.encrypt( + "hello", + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, + key_id=data_key_id, + ) + + events = self.listener.started_events + self.assertEqual(1, len(events)) + self.assertEqual("find", events[0].command_name) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt + async def test_03_decrypt(self): + data_key_id = await self.client_encryption.create_data_key("local") + self.assertEqual(4, data_key_id.subtype) + + encrypted = await self.client_encryption.encrypt( + "hello", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id + ) + self.assertEqual(6, encrypted.subtype) + + await self.key_vault_client.close() + self.key_vault_client = await self.async_rs_or_single_client( + timeoutMS=50, event_listeners=[self.listener] + ) + await self.client_encryption.close() + self.client_encryption = self.create_client_encryption( + key_vault_namespace="keyvault.datakeys", + kms_providers={"local": {"key": self.LOCAL_MASTERKEY}}, + key_vault_client=self.key_vault_client, + codec_options=OPTS, + ) + + async with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100}, + } + ): + self.listener.reset() + with self.assertRaisesRegex(EncryptionError, "timed out"): + await self.client_encryption.decrypt(encrypted) + + events = self.listener.started_events + self.assertEqual(1, len(events)) + self.assertEqual("find", events[0].command_name) + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index b432621798..ba991a6ee8 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -48,7 +48,7 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.helpers import anext from pymongo.common import _MAX_END_SESSIONS -from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure +from pymongo.errors import ConfigurationError, InvalidOperation, NetworkTimeout, OperationFailure from pymongo.operations import IndexModel, InsertOne, UpdateOne from pymongo.read_concern import ReadConcern diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index b5d0686417..3b533114df 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -43,6 +43,7 @@ ConfigurationError, ConnectionFailure, InvalidOperation, + NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne @@ -386,6 +387,42 @@ async def find_raw_batches(*args, **kwargs): if isinstance(res, (AsyncCommandCursor, AsyncCursor)): await res.to_list() + @async_client_context.require_transactions + async def test_10_convenient_transactions_csot(self): + await self.client.db.coll.drop() + + listener = OvertCommandListener() + + async with self.fail_point( + { + "mode": {"times": 2}, + "data": { + "failCommands": ["insert", "abortTransaction"], + "blockConnection": True, + "blockTimeMS": 200, + }, + } + ): + client = await self.async_rs_or_single_client( + timeoutMS=150, + event_listeners=[listener], + ) + session = client.start_session() + + async def callback(s): + await client.db.coll.insert_one({"_id": 1}, session=s) + + with self.assertRaises(NetworkTimeout): + await session.with_transaction(callback) + + started = listener.started_command_names() + failed = listener.failed_command_names() + + self.assertIn("insert", started) + self.assertIn("abortTransaction", started) + self.assertIn("insert", failed) + self.assertIn("abortTransaction", failed) + class PatchSessionTimeout: """Patches the client_session's with_transaction timeout for testing.""" diff --git a/test/test_client.py b/test/test_client.py index 5bbb5bd751..f043fdbe71 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -63,6 +63,7 @@ from test.utils import ( NTHREADS, CMAPListener, + EventListener, FunctionCallRecorder, assertRaisesExactly, delay, @@ -102,7 +103,13 @@ ServerSelectionTimeoutError, WriteConcernError, ) -from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent +from pymongo.monitoring import ( + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + ServerHeartbeatListener, + ServerHeartbeatStartedEvent, +) from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription @@ -2541,5 +2548,43 @@ def test_direct_client_maintains_pool_to_arbiter(self): self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1) +# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#4-background-connection-pooling +class TestClientCSOTProse(IntegrationTest): + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command + @client_context.require_auth + @client_context.require_version_min(4, 4, -1) + def test_02_timeoutMS_refreshed_for_each_handshake_command(self): + listener = CMAPListener() + + with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "failCommands": ["hello", "isMaster", "saslContinue"], + "blockConnection": True, + "blockTimeMS": 15, + "appName": "refreshTimeoutBackgroundPoolTest", + }, + } + ): + _ = self.single_client( + minPoolSize=1, + timeoutMS=20, + appname="refreshTimeoutBackgroundPoolTest", + event_listeners=[listener], + ) + + def predicate(): + return ( + listener.event_count(ConnectionCreatedEvent) == 1 + and listener.event_count(ConnectionReadyEvent) == 1 + ) + + wait_until( + predicate, + "didn't ever see a ConnectionCreatedEvent and a ConnectionReadyEvent", + ) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_collection.py b/test/test_collection.py index 84a900d45b..a392c733d2 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -59,6 +59,7 @@ InvalidDocument, InvalidName, InvalidOperation, + NetworkTimeout, OperationFailure, WriteConcernError, ) @@ -2254,6 +2255,31 @@ def afind(*args, **kwargs): for helper, args in helpers: helper(*args, let={}) # type: ignore + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts + @client_context.require_standalone + @client_context.require_version_min(4, 4, -1) + def test_01_multi_batch_inserts(self): + client = self.single_client(read_preference=ReadPreference.PRIMARY_PREFERRED) + client.db.coll.drop() + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 1010}, + } + ): + listener = OvertCommandListener() + client2 = self.single_client( + timeoutMS=2000, + read_preference=ReadPreference.PRIMARY_PREFERRED, + event_listeners=[listener], + ) + docs = [{"a": "b" * 1000000} for _ in range(50)] + with self.assertRaises(NetworkTimeout): + client2.db.coll.insert_many(docs) + + self.assertEqual(2, len(listener.started_events)) + if __name__ == "__main__": unittest.main() diff --git a/test/test_csot.py b/test/test_csot.py index 64210b4d64..dc7c999e08 100644 --- a/test/test_csot.py +++ b/test/test_csot.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test the CSOT unified spec tests.""" +"""Test the CSOT unified spec and prose tests.""" from __future__ import annotations import os import sys +from test.utils import OvertCommandListener + +from pymongo.read_concern import ReadConcern sys.path[0:0] = [""] @@ -24,8 +27,8 @@ from test.unified_format import generate_test_classes import pymongo -from pymongo import _csot -from pymongo.errors import PyMongoError +from pymongo import ReadPreference, WriteConcern, _csot +from pymongo.errors import NetworkTimeout, PyMongoError # Location of JSON test specifications. TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "csot") diff --git a/test/test_encryption.py b/test/test_encryption.py index 373981b1d2..569e4422b8 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -83,6 +83,7 @@ EncryptedCollectionError, EncryptionError, InvalidOperation, + NetworkTimeout, OperationFailure, ServerSelectionTimeoutError, WriteError, @@ -3115,5 +3116,107 @@ def test_explicit_session_errors_when_unsupported(self): self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s) +# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#3-clientencryption +class TestCSOTProse(EncryptionIntegrationTest): + mongocryptd_client: MongoClient + MONGOCRYPTD_PORT = 27020 + LOCAL_MASTERKEY = Binary( + base64.b64decode( + b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" + ), + UUID_SUBTYPE, + ) + + def setUp(self) -> None: + self.listener = OvertCommandListener() + self.client = self.single_client( + read_preference=ReadPreference.PRIMARY_PREFERRED, event_listeners=[self.listener] + ) + self.client.keyvault.datakeys.drop() + self.key_vault_client = self.rs_or_single_client( + timeoutMS=50, event_listeners=[self.listener] + ) + self.client_encryption = self.create_client_encryption( + key_vault_namespace="keyvault.datakeys", + kms_providers={"local": {"key": self.LOCAL_MASTERKEY}}, + key_vault_client=self.key_vault_client, + codec_options=OPTS, + ) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey + def test_01_create_data_key(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 100}, + } + ): + self.listener.reset() + with self.assertRaisesRegex(EncryptionError, "timed out"): + self.client_encryption.create_data_key("local") + + events = self.listener.started_events + self.assertEqual(1, len(events)) + self.assertEqual("insert", events[0].command_name) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt + def test_02_encrypt(self): + data_key_id = self.client_encryption.create_data_key("local") + self.assertEqual(4, data_key_id.subtype) + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100}, + } + ): + self.listener.reset() + with self.assertRaisesRegex(EncryptionError, "timed out"): + self.client_encryption.encrypt( + "hello", + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, + key_id=data_key_id, + ) + + events = self.listener.started_events + self.assertEqual(1, len(events)) + self.assertEqual("find", events[0].command_name) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt + def test_03_decrypt(self): + data_key_id = self.client_encryption.create_data_key("local") + self.assertEqual(4, data_key_id.subtype) + + encrypted = self.client_encryption.encrypt( + "hello", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id + ) + self.assertEqual(6, encrypted.subtype) + + self.key_vault_client.close() + self.key_vault_client = self.rs_or_single_client( + timeoutMS=50, event_listeners=[self.listener] + ) + self.client_encryption.close() + self.client_encryption = self.create_client_encryption( + key_vault_namespace="keyvault.datakeys", + kms_providers={"local": {"key": self.LOCAL_MASTERKEY}}, + key_vault_client=self.key_vault_client, + codec_options=OPTS, + ) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100}, + } + ): + self.listener.reset() + with self.assertRaisesRegex(EncryptionError, "timed out"): + self.client_encryption.decrypt(encrypted) + + events = self.listener.started_events + self.assertEqual(1, len(events)) + self.assertEqual("find", events[0].command_name) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 28adb7051a..6c8b1a1c8d 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -37,6 +37,7 @@ from gridfs.errors import CorruptGridFile, NoFile from pymongo.errors import ( ConfigurationError, + NetworkTimeout, NotPrimaryError, ServerSelectionTimeoutError, WriteConcernError, @@ -525,5 +526,78 @@ def test_gridfs_secondary_lazy(self): self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data") +# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#6-gridfs---upload +class TestGridFsCSOT(IntegrationTest): + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#uploads-via-openuploadstream-can-be-timed-out + def test_06_01_uploads_via_open_upload_stream_can_be_timed_out(self): + self.client.db.fs.files.drop() + self.client.db.fs.chunks.drop() + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 200}, + } + ): + client = self.single_client(timeoutMS=150) + bucket = gridfs.GridFSBucket(client.db) + upload_stream = bucket.open_upload_stream("filename") + upload_stream.write(b"0x12") + with self.assertRaises(NetworkTimeout): + upload_stream.close() + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#aborting-an-upload-stream-can-be-timed-out + def test_06_02_aborting_an_upload_stream_can_be_timed_out(self): + self.client.db.fs.files.drop() + self.client.db.fs.chunks.drop() + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["delete"], "blockConnection": True, "blockTimeMS": 200}, + } + ): + client = self.single_client(timeoutMS=150) + bucket = gridfs.GridFSBucket(client.db, chunk_size_bytes=2) + upload_stream = bucket.open_upload_stream("filename") + upload_stream.write(b"0x010x020x030x04") + with self.assertRaises(NetworkTimeout): + upload_stream.abort() + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#7-gridfs---download + def test_07_gridfs_download_csot(self): + self.client.db.fs.files.drop() + self.client.db.fs.chunks.drop() + + id = ObjectId("000000000000000000000005") + + self.client.db.fs.files.insert_one( + { + "_id": id, + "length": 10, + "chunkSize": 4, + "uploadDate": {"$date": "1970-01-01T00:00:00.000Z"}, + "md5": "57d83cd477bfb1ccd975ab33d827a92b", + "filename": "length-10", + "contentType": "application/octet-stream", + "aliases": [], + "metadata": {}, + } + ) + + client = self.single_client(timeoutMS=150) + bucket = gridfs.GridFSBucket(client.db) + download_stream = bucket.open_download_stream(id) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 200}, + } + ): + with self.assertRaises(NetworkTimeout): + download_stream.read() + + if __name__ == "__main__": unittest.main() diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 984b967f50..dc7be00692 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -17,9 +17,10 @@ import os import sys +import time from pymongo import MongoClient, ReadPreference -from pymongo.errors import ServerSelectionTimeoutError +from pymongo.errors import NetworkTimeout, ServerSelectionTimeoutError from pymongo.hello import HelloCompat from pymongo.operations import _Op from pymongo.server_selectors import writable_server_selector @@ -200,5 +201,95 @@ def test_server_selector_bypassed(self): self.assertEqual(selector.call_count, 0) +# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#8-server-selection +class TestServerSelectionCSOT(IntegrationTest): + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-if-timeoutms-is-not-set + def test_08_01_server_selection_timeoutMS_honored(self): + client = self.single_client("mongodb://invalid/?serverSelectionTimeoutMS=10") + with self.assertRaises(ServerSelectionTimeoutError): + start = time.time_ns() * 1000 + client.admin.command({"ping": 1}) + + end = time.time_ns() * 1000 + + self.assertLessEqual(start - end, 15) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-honored-for-server-selection-if-its-lower-than-serverselectiontimeoutms + def test_08_02_timeoutMS_honored_for_server_selection_if_lower(self): + client = self.single_client("mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20") + with self.assertRaises(ServerSelectionTimeoutError): + start = time.time_ns() * 1_000_000 + client.admin.command({"ping": 1}) + end = time.time_ns() * 1_000_000 + + self.assertLessEqual(start - end, 15) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-server-selection-if-its-lower-than-timeoutms + def test_08_03_serverselectiontimeoutms_honored_for_server_selection_if_lower(self): + client = self.single_client("mongodb://invalid/?timeoutMS=20&serverSelectionTimeoutMS=10") + with self.assertRaises(ServerSelectionTimeoutError): + start = time.time_ns() * 1_000_000 + client.admin.command({"ping": 1}) + + end = time.time_ns() * 1_000_000 + + self.assertLessEqual(start - end, 15) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-server-selection-if-timeoutms0 + def test_08_04_serverselectiontimeoutms_honored_for_server_selection_if_zero_timeoutms(self): + client = self.single_client("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=10") + with self.assertRaises(ServerSelectionTimeoutError): + start = time.time_ns() * 1_000_000 + client.admin.command({"ping": 1}) + + end = time.time_ns() * 1_000_000 + + self.assertLessEqual(start - end, 15) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-honored-for-connection-handshake-commands-if-its-lower-than-serverselectiontimeoutms + @client_context.require_auth + def test_08_05_timeoutms_honored_for_handshake_if_lower(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "failCommands": ["saslContinue"], + "blockConnection": True, + "blockTimeMS": 15, + }, + } + ): + client = self.single_client(timeoutMS=10, serverSelectionTimeoutMS=20) + with self.assertRaises(NetworkTimeout): + start = time.time_ns() * 1_000_000 + client.db.coll.insert_one({"x": 1}) + + end = time.time_ns() * 1_000_000 + + self.assertLessEqual(start - end, 15) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-connection-handshake-commands-if-its-lower-than-timeoutms + @client_context.require_auth + def test_08_06_serverSelectionTimeoutMS_honored_for_handshake_if_lower(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "failCommands": ["saslContinue"], + "blockConnection": True, + "blockTimeMS": 15, + }, + } + ): + client = self.single_client(timeoutMS=20, serverSelectionTimeoutMS=10) + with self.assertRaises(NetworkTimeout): + start = time.time_ns() * 1_000_000 + client.db.coll.insert_one({"x": 1}) + + end = time.time_ns() * 1_000_000 + + self.assertLessEqual(start - end, 15) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_session.py b/test/test_session.py index d0bbb075a8..5349e9604a 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -44,7 +44,7 @@ from gridfs.synchronous.grid_file import GridFS, GridFSBucket from pymongo import ASCENDING, MongoClient, monitoring from pymongo.common import _MAX_END_SESSIONS -from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure +from pymongo.errors import ConfigurationError, InvalidOperation, NetworkTimeout, OperationFailure from pymongo.operations import IndexModel, InsertOne, UpdateOne from pymongo.read_concern import ReadConcern from pymongo.synchronous.command_cursor import CommandCursor diff --git a/test/test_transactions.py b/test/test_transactions.py index 3cecbe9d38..4480df0edf 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -38,6 +38,7 @@ ConfigurationError, ConnectionFailure, InvalidOperation, + NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne @@ -378,6 +379,42 @@ def find_raw_batches(*args, **kwargs): if isinstance(res, (CommandCursor, Cursor)): res.to_list() + @client_context.require_transactions + def test_10_convenient_transactions_csot(self): + self.client.db.coll.drop() + + listener = OvertCommandListener() + + with self.fail_point( + { + "mode": {"times": 2}, + "data": { + "failCommands": ["insert", "abortTransaction"], + "blockConnection": True, + "blockTimeMS": 200, + }, + } + ): + client = self.rs_or_single_client( + timeoutMS=150, + event_listeners=[listener], + ) + session = client.start_session() + + def callback(s): + client.db.coll.insert_one({"_id": 1}, session=s) + + with self.assertRaises(NetworkTimeout): + session.with_transaction(callback) + + started = listener.started_command_names() + failed = listener.failed_command_names() + + self.assertIn("insert", started) + self.assertIn("abortTransaction", started) + self.assertIn("insert", failed) + self.assertIn("abortTransaction", failed) + class PatchSessionTimeout: """Patches the client_session's with_transaction timeout for testing.""" diff --git a/test/utils.py b/test/utils.py index 3eac4fa509..177a2940fb 100644 --- a/test/utils.py +++ b/test/utils.py @@ -178,6 +178,10 @@ def started_command_names(self) -> List[str]: """Return list of command names started.""" return [event.command_name for event in self.started_events] + def failed_command_names(self) -> List[str]: + """Return list of command names failed.""" + return [event.command_name for event in self.failed_events] + def reset(self) -> None: """Reset the state of this listener.""" self.results.clear() From 907ffbc5fb70d9353de29f12d7be9963b6c1aa3d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 1 Nov 2024 10:20:51 -0400 Subject: [PATCH 2/5] Fix test reqs --- test/asynchronous/test_client.py | 1 + test/asynchronous/test_collection.py | 1 + test/asynchronous/test_encryption.py | 6 ++++++ test/test_client.py | 1 + test/test_collection.py | 1 + test/test_encryption.py | 6 ++++++ test/test_gridfs_bucket.py | 6 ++++++ test/test_server_selection.py | 8 ++++++++ 8 files changed, 30 insertions(+) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 6015aa452e..623dcd086e 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2597,6 +2597,7 @@ class TestClientCSOTProse(AsyncIntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command @async_client_context.require_auth @async_client_context.require_version_min(4, 4, -1) + @async_client_context.require_failCommand_appName async def test_02_timeoutMS_refreshed_for_each_handshake_command(self): listener = CMAPListener() diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index f048b9acd1..3b8da88602 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -2281,6 +2281,7 @@ async def afind(*args, **kwargs): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts @async_client_context.require_standalone @async_client_context.require_version_min(4, 4, -1) + @async_client_context.require_failCommand_fail_point async def test_01_multi_batch_inserts(self): client = await self.async_single_client(read_preference=ReadPreference.PRIMARY_PREFERRED) await client.db.coll.drop() diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index c1a4ebcac0..dd7295b6ce 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -3162,6 +3162,8 @@ async def asyncSetUp(self) -> None: ) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey + @async_client_context.require_failCommand_fail_point + @async_client_context.require_version_min(4, 4, -1) async def test_01_create_data_key(self): async with self.fail_point( { @@ -3178,6 +3180,8 @@ async def test_01_create_data_key(self): self.assertEqual("insert", events[0].command_name) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt + @async_client_context.require_failCommand_fail_point + @async_client_context.require_version_min(4, 4, -1) async def test_02_encrypt(self): data_key_id = await self.client_encryption.create_data_key("local") self.assertEqual(4, data_key_id.subtype) @@ -3200,6 +3204,8 @@ async def test_02_encrypt(self): self.assertEqual("find", events[0].command_name) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt + @async_client_context.require_failCommand_fail_point + @async_client_context.require_version_min(4, 4, -1) async def test_03_decrypt(self): data_key_id = await self.client_encryption.create_data_key("local") self.assertEqual(4, data_key_id.subtype) diff --git a/test/test_client.py b/test/test_client.py index f043fdbe71..63fd843bed 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2553,6 +2553,7 @@ class TestClientCSOTProse(IntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command @client_context.require_auth @client_context.require_version_min(4, 4, -1) + @client_context.require_failCommand_appName def test_02_timeoutMS_refreshed_for_each_handshake_command(self): listener = CMAPListener() diff --git a/test/test_collection.py b/test/test_collection.py index a392c733d2..87ed4d459a 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -2258,6 +2258,7 @@ def afind(*args, **kwargs): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts @client_context.require_standalone @client_context.require_version_min(4, 4, -1) + @client_context.require_failCommand_fail_point def test_01_multi_batch_inserts(self): client = self.single_client(read_preference=ReadPreference.PRIMARY_PREFERRED) client.db.coll.drop() diff --git a/test/test_encryption.py b/test/test_encryption.py index 569e4422b8..d8c7414a04 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -3144,6 +3144,8 @@ def setUp(self) -> None: ) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, -1) def test_01_create_data_key(self): with self.fail_point( { @@ -3160,6 +3162,8 @@ def test_01_create_data_key(self): self.assertEqual("insert", events[0].command_name) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, -1) def test_02_encrypt(self): data_key_id = self.client_encryption.create_data_key("local") self.assertEqual(4, data_key_id.subtype) @@ -3182,6 +3186,8 @@ def test_02_encrypt(self): self.assertEqual("find", events[0].command_name) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, -1) def test_03_decrypt(self): data_key_id = self.client_encryption.create_data_key("local") self.assertEqual(4, data_key_id.subtype) diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 6c8b1a1c8d..4322b6b74a 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -529,6 +529,8 @@ def test_gridfs_secondary_lazy(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#6-gridfs---upload class TestGridFsCSOT(IntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#uploads-via-openuploadstream-can-be-timed-out + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, -1) def test_06_01_uploads_via_open_upload_stream_can_be_timed_out(self): self.client.db.fs.files.drop() self.client.db.fs.chunks.drop() @@ -547,6 +549,8 @@ def test_06_01_uploads_via_open_upload_stream_can_be_timed_out(self): upload_stream.close() # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#aborting-an-upload-stream-can-be-timed-out + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, -1) def test_06_02_aborting_an_upload_stream_can_be_timed_out(self): self.client.db.fs.files.drop() self.client.db.fs.chunks.drop() @@ -565,6 +569,8 @@ def test_06_02_aborting_an_upload_stream_can_be_timed_out(self): upload_stream.abort() # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#7-gridfs---download + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, -1) def test_07_gridfs_download_csot(self): self.client.db.fs.files.drop() self.client.db.fs.chunks.drop() diff --git a/test/test_server_selection.py b/test/test_server_selection.py index dc7be00692..ab6f7ae14e 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -204,6 +204,7 @@ def test_server_selector_bypassed(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#8-server-selection class TestServerSelectionCSOT(IntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-if-timeoutms-is-not-set + @client_context.require_version_min(4, 4, -1) def test_08_01_server_selection_timeoutMS_honored(self): client = self.single_client("mongodb://invalid/?serverSelectionTimeoutMS=10") with self.assertRaises(ServerSelectionTimeoutError): @@ -215,6 +216,7 @@ def test_08_01_server_selection_timeoutMS_honored(self): self.assertLessEqual(start - end, 15) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-honored-for-server-selection-if-its-lower-than-serverselectiontimeoutms + @client_context.require_version_min(4, 4, -1) def test_08_02_timeoutMS_honored_for_server_selection_if_lower(self): client = self.single_client("mongodb://invalid/?timeoutMS=10&serverSelectionTimeoutMS=20") with self.assertRaises(ServerSelectionTimeoutError): @@ -225,6 +227,7 @@ def test_08_02_timeoutMS_honored_for_server_selection_if_lower(self): self.assertLessEqual(start - end, 15) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-server-selection-if-its-lower-than-timeoutms + @client_context.require_version_min(4, 4, -1) def test_08_03_serverselectiontimeoutms_honored_for_server_selection_if_lower(self): client = self.single_client("mongodb://invalid/?timeoutMS=20&serverSelectionTimeoutMS=10") with self.assertRaises(ServerSelectionTimeoutError): @@ -236,6 +239,7 @@ def test_08_03_serverselectiontimeoutms_honored_for_server_selection_if_lower(se self.assertLessEqual(start - end, 15) # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-server-selection-if-timeoutms0 + @client_context.require_version_min(4, 4, -1) def test_08_04_serverselectiontimeoutms_honored_for_server_selection_if_zero_timeoutms(self): client = self.single_client("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=10") with self.assertRaises(ServerSelectionTimeoutError): @@ -248,6 +252,8 @@ def test_08_04_serverselectiontimeoutms_honored_for_server_selection_if_zero_tim # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-honored-for-connection-handshake-commands-if-its-lower-than-serverselectiontimeoutms @client_context.require_auth + @client_context.require_version_min(4, 4, -1) + @client_context.require_failCommand_fail_point def test_08_05_timeoutms_honored_for_handshake_if_lower(self): with self.fail_point( { @@ -270,6 +276,8 @@ def test_08_05_timeoutms_honored_for_handshake_if_lower(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#serverselectiontimeoutms-honored-for-connection-handshake-commands-if-its-lower-than-timeoutms @client_context.require_auth + @client_context.require_version_min(4, 4, -1) + @client_context.require_failCommand_fail_point def test_08_06_serverSelectionTimeoutMS_honored_for_handshake_if_lower(self): with self.fail_point( { From 9c0c54b164d73f48bd4c9926c58037c75bdc866e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 1 Nov 2024 12:35:07 -0400 Subject: [PATCH 3/5] Fix csot reset --- pymongo/_csot.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymongo/_csot.py b/pymongo/_csot.py index fb77a69ca1..ac9ccd8ead 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -67,7 +67,11 @@ def remaining() -> Optional[float]: @contextlib.contextmanager def reset() -> Generator: - deadline_token = DEADLINE.set(DEADLINE.get() + get_timeout()) # type: ignore[operator] + timeout = get_timeout() + if timeout is None: + deadline_token = DEADLINE.set(DEADLINE.get()) + else: + deadline_token = DEADLINE.set(DEADLINE.get() + timeout) # type: ignore[operator] yield DEADLINE.reset(deadline_token) From b98b99af6774dac0002fe0aade8f0f761592e70e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 1 Nov 2024 13:22:10 -0400 Subject: [PATCH 4/5] Fix csot reset --- pymongo/_csot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/_csot.py b/pymongo/_csot.py index ac9ccd8ead..fc91f57436 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -71,7 +71,7 @@ def reset() -> Generator: if timeout is None: deadline_token = DEADLINE.set(DEADLINE.get()) else: - deadline_token = DEADLINE.set(DEADLINE.get() + timeout) # type: ignore[operator] + deadline_token = DEADLINE.set(DEADLINE.get() + timeout) yield DEADLINE.reset(deadline_token) From a364a91d93bc95a143f25b4187f9c6c30f0688f3 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 1 Nov 2024 14:07:08 -0400 Subject: [PATCH 5/5] Upadte runOnReqs for test_10_convenient_transactions_csot --- test/asynchronous/test_transactions.py | 3 +++ test/test_transactions.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 3b533114df..f2bd94e25c 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -387,7 +387,10 @@ async def find_raw_batches(*args, **kwargs): if isinstance(res, (AsyncCommandCursor, AsyncCursor)): await res.to_list() + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#10-convenient-transactions @async_client_context.require_transactions + @async_client_context.require_version_min(4, 4, -1) + @async_client_context.require_failCommand_fail_point async def test_10_convenient_transactions_csot(self): await self.client.db.coll.drop() diff --git a/test/test_transactions.py b/test/test_transactions.py index 4480df0edf..a7cab34d78 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -379,7 +379,10 @@ def find_raw_batches(*args, **kwargs): if isinstance(res, (CommandCursor, Cursor)): res.to_list() + # https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#10-convenient-transactions @client_context.require_transactions + @client_context.require_version_min(4, 4, -1) + @client_context.require_failCommand_fail_point def test_10_convenient_transactions_csot(self): self.client.db.coll.drop()