Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-4927 - Add missing CSOT prose tests #1987

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion pymongo/_csot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@

from __future__ import annotations

import contextlib
import functools
import inspect
import time
from collections import deque
from contextlib import AbstractContextManager
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
Generator,
MutableMapping,
Optional,
TypeVar,
cast,
)

if TYPE_CHECKING:
from pymongo.write_concern import WriteConcern
Expand Down Expand Up @@ -54,6 +65,17 @@ def remaining() -> Optional[float]:
return DEADLINE.get() - time.monotonic()


@contextlib.contextmanager
def reset() -> Generator:
timeout = get_timeout()
if timeout is None:
deadline_token = DEADLINE.set(DEADLINE.get())
else:
deadline_token = DEADLINE.set(DEADLINE.get() + timeout)
yield
DEADLINE.reset(deadline_token)


def clamp_remaining(max_timeout: float) -> float:
"""Return the remaining timeout clamped to a max value."""
timeout = remaining()
Expand Down
12 changes: 10 additions & 2 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:

def _within_time_limit(start_time: float) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
timeout = _csot.get_timeout()
if timeout:
return time.monotonic() - start_time < timeout
else:
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT


_T = TypeVar("_T")
Expand Down Expand Up @@ -512,6 +516,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._timeout = client.options.timeout

async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand Down Expand Up @@ -597,6 +602,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
return parent_val
return getattr(self.client, name)

@_csot.apply
async def with_transaction(
self,
callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]],
Expand Down Expand Up @@ -697,7 +703,8 @@ async def callback(session, custom_arg, custom_kwarg=None):
ret = await callback(self)
except Exception as exc:
if self.in_transaction:
await self.abort_transaction()
with _csot.reset():
await self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
Expand Down Expand Up @@ -816,6 +823,7 @@ async def commit_transaction(self) -> None:
finally:
self._transaction.state = _TxnState.COMMITTED

@_csot.apply
async def abort_transaction(self) -> None:
"""Abort a multi-statement transaction.

Expand Down
3 changes: 2 additions & 1 deletion pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float:
timeout = _csot.remaining()
if timeout is None:
return self._settings.server_selection_timeout
return timeout
else:
return min(timeout, self._settings.server_selection_timeout)

async def select_servers(
self,
Expand Down
12 changes: 10 additions & 2 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:

def _within_time_limit(start_time: float) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
timeout = _csot.get_timeout()
if timeout:
return time.monotonic() - start_time < timeout
else:
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT


_T = TypeVar("_T")
Expand Down Expand Up @@ -511,6 +515,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._timeout = client.options.timeout

def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand Down Expand Up @@ -596,6 +601,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
return parent_val
return getattr(self.client, name)

@_csot.apply
def with_transaction(
self,
callback: Callable[[ClientSession], _T],
Expand Down Expand Up @@ -694,7 +700,8 @@ def callback(session, custom_arg, custom_kwarg=None):
ret = callback(self)
except Exception as exc:
if self.in_transaction:
self.abort_transaction()
with _csot.reset():
self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
Expand Down Expand Up @@ -813,6 +820,7 @@ def commit_transaction(self) -> None:
finally:
self._transaction.state = _TxnState.COMMITTED

@_csot.apply
def abort_transaction(self) -> None:
"""Abort a multi-statement transaction.

Expand Down
3 changes: 2 additions & 1 deletion pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float:
timeout = _csot.remaining()
if timeout is None:
return self._settings.server_selection_timeout
return timeout
else:
return min(timeout, self._settings.server_selection_timeout)

def select_servers(
self,
Expand Down
48 changes: 47 additions & 1 deletion test/asynchronous/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from test.utils import (
NTHREADS,
CMAPListener,
EventListener,
FunctionCallRecorder,
async_get_pool,
async_wait_until,
Expand Down Expand Up @@ -114,7 +115,13 @@
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionClosedEvent,
ConnectionCreatedEvent,
ConnectionReadyEvent,
ServerHeartbeatListener,
ServerHeartbeatStartedEvent,
)
from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions
from pymongo.read_preferences import ReadPreference
from pymongo.server_description import ServerDescription
Expand Down Expand Up @@ -2585,5 +2592,44 @@ async def test_direct_client_maintains_pool_to_arbiter(self):
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#4-background-connection-pooling
class TestClientCSOTProse(AsyncIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command
@async_client_context.require_auth
@async_client_context.require_version_min(4, 4, -1)
@async_client_context.require_failCommand_appName
async def test_02_timeoutMS_refreshed_for_each_handshake_command(self):
listener = CMAPListener()

async with self.fail_point(
{
"mode": {"times": 1},
"data": {
"failCommands": ["hello", "isMaster", "saslContinue"],
"blockConnection": True,
"blockTimeMS": 15,
"appName": "refreshTimeoutBackgroundPoolTest",
},
}
):
_ = await self.async_single_client(
minPoolSize=1,
timeoutMS=20,
appname="refreshTimeoutBackgroundPoolTest",
event_listeners=[listener],
)

async def predicate():
return (
listener.event_count(ConnectionCreatedEvent) == 1
and listener.event_count(ConnectionReadyEvent) == 1
)

await async_wait_until(
predicate,
"didn't ever see a ConnectionCreatedEvent and a ConnectionReadyEvent",
)


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions test/asynchronous/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
InvalidDocument,
InvalidName,
InvalidOperation,
NetworkTimeout,
OperationFailure,
WriteConcernError,
)
Expand Down Expand Up @@ -2277,6 +2278,32 @@ async def afind(*args, **kwargs):
for helper, args in helpers:
await helper(*args, let={}) # type: ignore

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts
@async_client_context.require_standalone
@async_client_context.require_version_min(4, 4, -1)
@async_client_context.require_failCommand_fail_point
async def test_01_multi_batch_inserts(self):
client = await self.async_single_client(read_preference=ReadPreference.PRIMARY_PREFERRED)
await client.db.coll.drop()

async with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 1010},
}
):
listener = OvertCommandListener()
client2 = await self.async_single_client(
timeoutMS=2000,
read_preference=ReadPreference.PRIMARY_PREFERRED,
event_listeners=[listener],
)
docs = [{"a": "b" * 1000000} for _ in range(50)]
with self.assertRaises(NetworkTimeout):
await client2.db.coll.insert_many(docs)

self.assertEqual(2, len(listener.started_events))


if __name__ == "__main__":
unittest.main()
109 changes: 109 additions & 0 deletions test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
NetworkTimeout,
OperationFailure,
ServerSelectionTimeoutError,
WriteError,
Expand Down Expand Up @@ -3133,5 +3134,113 @@ async def test_explicit_session_errors_when_unsupported(self):
await self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s)


# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#3-clientencryption
class TestCSOTProse(AsyncEncryptionIntegrationTest):
mongocryptd_client: AsyncMongoClient
MONGOCRYPTD_PORT = 27020
LOCAL_MASTERKEY = Binary(
base64.b64decode(
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
),
UUID_SUBTYPE,
)

async def asyncSetUp(self) -> None:
self.listener = OvertCommandListener()
self.client = await self.async_single_client(
read_preference=ReadPreference.PRIMARY_PREFERRED, event_listeners=[self.listener]
)
await self.client.keyvault.datakeys.drop()
self.key_vault_client = await self.async_rs_or_single_client(
timeoutMS=50, event_listeners=[self.listener]
)
self.client_encryption = self.create_client_encryption(
key_vault_namespace="keyvault.datakeys",
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
key_vault_client=self.key_vault_client,
codec_options=OPTS,
)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, -1)
async def test_01_create_data_key(self):
async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
await self.client_encryption.create_data_key("local")

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("insert", events[0].command_name)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, -1)
async def test_02_encrypt(self):
data_key_id = await self.client_encryption.create_data_key("local")
self.assertEqual(4, data_key_id.subtype)
async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
await self.client_encryption.encrypt(
"hello",
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
key_id=data_key_id,
)

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("find", events[0].command_name)

# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, -1)
async def test_03_decrypt(self):
data_key_id = await self.client_encryption.create_data_key("local")
self.assertEqual(4, data_key_id.subtype)

encrypted = await self.client_encryption.encrypt(
"hello", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id
)
self.assertEqual(6, encrypted.subtype)

await self.key_vault_client.close()
self.key_vault_client = await self.async_rs_or_single_client(
timeoutMS=50, event_listeners=[self.listener]
)
await self.client_encryption.close()
self.client_encryption = self.create_client_encryption(
key_vault_namespace="keyvault.datakeys",
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
key_vault_client=self.key_vault_client,
codec_options=OPTS,
)

async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
}
):
self.listener.reset()
with self.assertRaisesRegex(EncryptionError, "timed out"):
await self.client_encryption.decrypt(encrypted)

events = self.listener.started_events
self.assertEqual(1, len(events))
self.assertEqual("find", events[0].command_name)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.errors import ConfigurationError, InvalidOperation, NetworkTimeout, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
from pymongo.read_concern import ReadConcern

Expand Down
Loading
Loading