Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for mbedtls_ssl_conf_read_timeout #102

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[next]

* tls: Add `mbedtls_ssl_conf_read_timeout`, for the read timeout
configuration

[2.8.0] - 2023-11-28

* ci: Update wheels to mbedtls 2.28.6
Expand Down
8 changes: 7 additions & 1 deletion src/mbedtls/_tls.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ cdef extern from "mbedtls/ssl.h" nogil:
# set_handshake_timeout
unsigned int hs_timeout_min
unsigned int hs_timeout_max
# set_read_timeout
unsigned int read_timeout

unsigned int endpoint
unsigned int transport
Expand Down Expand Up @@ -247,7 +249,10 @@ cdef extern from "mbedtls/ssl.h" nogil:
void *p_dbg
)

# mbedtls_ssl_conf_read_timeout
void mbedtls_ssl_conf_read_timeout(
mbedtls_ssl_config *conf,
unsigned int timeout
)
# mbedtls_ssl_conf_session_tickets_cb
# mbedtls_ssl_conf_export_keys_cb

Expand Down Expand Up @@ -427,6 +432,7 @@ cdef class MbedTLSConfiguration:
cdef _set_max_fragmentation_length(self, object mfl)
cdef _set_anti_replay(self, mode)
cdef _set_handshake_timeout(self, minimum, maximum)
cdef _set_read_timeout(self, timeout)
cdef _set_cookie(self, _DTLSCookie cookie)
cdef _set_sni_callback(self, callback)
cdef _set_pre_shared_key(self, psk)
Expand Down
34 changes: 34 additions & 0 deletions src/mbedtls/_tls.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ cdef class MbedTLSConfiguration:
# badmac_limit
handshake_timeout_min,
handshake_timeout_max,
read_timeout,
sni_callback,
pre_shared_key,
pre_shared_key_store,
Expand All @@ -423,6 +424,7 @@ cdef class MbedTLSConfiguration:
self._set_handshake_timeout(
handshake_timeout_min, handshake_timeout_max
)
self._set_read_timeout(read_timeout)
self._set_sni_callback(sni_callback)
self._set_pre_shared_key(pre_shared_key)
self._set_pre_shared_key_store(pre_shared_key_store)
Expand Down Expand Up @@ -479,6 +481,7 @@ cdef class MbedTLSConfiguration:
self.anti_replay,
self.handshake_timeout_min,
self.handshake_timeout_max,
self.read_timeout,
self.sni_callback,
self.pre_shared_key,
self.pre_shared_key_store,
Expand Down Expand Up @@ -802,6 +805,34 @@ cdef class MbedTLSConfiguration:

return float(self._ctx.hs_timeout_max) / 1000.0

cdef _set_read_timeout(self, timeout):
"""Set TLS/DTLS read timeout.
Use 0 for no timeout.

Args:
timeout (float, optional): read timeout in seconds.

"""
if timeout is None:
return

def validate(extremum, *, default: float) -> float:
if extremum is None:
return default
if extremum < 0.0:
raise ValueError(extremum)
return extremum

_tls.mbedtls_ssl_conf_read_timeout(
&self._ctx,
int(1000.0 * validate(timeout, default=0))
)

@property
def read_timeout(self):
"""Read timeout in seconds. Use 0 for no timeout. (default 0)."""
return float(self._ctx.read_timeout) / 1000.0

cdef _set_sni_callback(self, callback):
# PEP 543, optional, server-side only
if callback is None:
Expand Down Expand Up @@ -916,6 +947,7 @@ cdef class _BaseContext:
anti_replay=None,
handshake_timeout_min=None,
handshake_timeout_max=None,
read_timeout=None,
sni_callback=configuration.sni_callback,
pre_shared_key=configuration.pre_shared_key,
pre_shared_key_store=configuration.pre_shared_key_store,
Expand All @@ -938,6 +970,7 @@ cdef class _BaseContext:
anti_replay=configuration.anti_replay,
handshake_timeout_min=configuration.handshake_timeout_min,
handshake_timeout_max=configuration.handshake_timeout_max,
read_timeout=configuration.read_timeout,
sni_callback=configuration.sni_callback,
pre_shared_key=configuration.pre_shared_key,
pre_shared_key_store=configuration.pre_shared_key_store,
Expand Down Expand Up @@ -991,6 +1024,7 @@ cdef class _BaseContext:
anti_replay=self._conf.anti_replay,
handshake_timeout_min=self._conf.handshake_timeout_min,
handshake_timeout_max=self._conf.handshake_timeout_max,
read_timeout=self._conf.read_timeout,
sni_callback=self._conf.sni_callback,
pre_shared_key=self._conf.pre_shared_key,
pre_shared_key_store=self._conf.pre_shared_key_store,
Expand Down
11 changes: 11 additions & 0 deletions src/mbedtls/_tlsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class TLSConfiguration:
highest_supported_version: TLSVersion = TLSVersion.MAXIMUM_SUPPORTED
trust_store: Optional[TrustStore] = None
max_fragmentation_length: Optional[MaxFragmentLength] = None
read_timeout: float = 0.0
sni_callback: Optional[ServerNameCallback] = None
pre_shared_key: Optional[Tuple[str, bytes]] = None
pre_shared_key_store: Mapping[str, bytes] = field(default_factory=dict)
Expand Down Expand Up @@ -250,6 +251,7 @@ def __eq__(self, other: object) -> bool:
self.sni_callback == other.sni_callback,
self.pre_shared_key == other.pre_shared_key,
self.pre_shared_key_store == other.pre_shared_key_store,
self.read_timeout == other.read_timeout,
)
)

Expand All @@ -263,6 +265,7 @@ def update(
highest_supported_version: _Wrap[TLSVersion] = _DEFAULT_VALUE,
trust_store: _Wrap[TrustStore] = _DEFAULT_VALUE,
max_fragmentation_length: _Wrap[MaxFragmentLength] = _DEFAULT_VALUE,
read_timeout: _Wrap[float] = _DEFAULT_VALUE,
sni_callback: _Wrap[Optional[ServerNameCallback]] = _DEFAULT_VALUE,
pre_shared_key: _Wrap[Tuple[str, bytes]] = _DEFAULT_VALUE,
pre_shared_key_store: _Wrap[Mapping[str, bytes]] = _DEFAULT_VALUE,
Expand Down Expand Up @@ -296,6 +299,7 @@ def update(
self.max_fragmentation_length,
),
sni_callback=_unwrap(sni_callback, self.sni_callback),
read_timeout=_unwrap(read_timeout, self.read_timeout),
pre_shared_key=_unwrap(pre_shared_key, self.pre_shared_key),
pre_shared_key_store=_unwrap(
pre_shared_key_store,
Expand All @@ -318,6 +322,7 @@ class DTLSConfiguration:
anti_replay: bool = True
handshake_timeout_min: float = 1.0
handshake_timeout_max: float = 60.0
read_timeout: float = 0.0
sni_callback: Optional[ServerNameCallback] = None
pre_shared_key: Optional[Tuple[str, bytes]] = None
pre_shared_key_store: Mapping[str, bytes] = field(default_factory=dict)
Expand Down Expand Up @@ -375,6 +380,7 @@ def __eq__(self, other: object) -> bool:
self.anti_replay == other.anti_replay,
self.handshake_timeout_min == other.handshake_timeout_min,
self.handshake_timeout_max == other.handshake_timeout_max,
self.read_timeout == other.read_timeout,
self.sni_callback == other.sni_callback,
self.pre_shared_key == other.pre_shared_key,
self.pre_shared_key_store == other.pre_shared_key_store,
Expand All @@ -394,6 +400,7 @@ def update(
anti_replay: _Wrap[bool] = _DEFAULT_VALUE,
handshake_timeout_min: _Wrap[float] = _DEFAULT_VALUE,
handshake_timeout_max: _Wrap[float] = _DEFAULT_VALUE,
read_timeout: _Wrap[float] = _DEFAULT_VALUE,
sni_callback: _Wrap[ServerNameCallback] = _DEFAULT_VALUE,
pre_shared_key: _Wrap[Tuple[str, bytes]] = _DEFAULT_VALUE,
pre_shared_key_store: _Wrap[Mapping[str, bytes]] = _DEFAULT_VALUE,
Expand Down Expand Up @@ -435,6 +442,10 @@ def update(
handshake_timeout_max,
self.handshake_timeout_max,
),
read_timeout=_unwrap(
read_timeout,
self.read_timeout,
),
sni_callback=_unwrap(sni_callback, self.sni_callback),
pre_shared_key=_unwrap(pre_shared_key, self.pre_shared_key),
pre_shared_key_store=_unwrap(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,20 @@ def test_handshake_timeout_default(
assert_conf_invariant(conf, handshake_timeout_min=hs_min)
assert_conf_invariant(conf, handshake_timeout_max=hs_max)

@pytest.mark.parametrize("timeout", [1, 10, 5.3, 300])
def test_read_timeout_default(
self,
conf: Union[TLSConfiguration, DTLSConfiguration],
timeout: float,
) -> None:
assert conf.read_timeout == 0
conf_ = conf.update(
read_timeout=timeout,
)
assert conf_.read_timeout == timeout

assert_conf_invariant(conf, read_timeout=timeout)


class TestContext:
@pytest.fixture(params=[Purpose.SERVER_AUTH, Purpose.CLIENT_AUTH])
Expand Down