(TEST) Apply Florian's patch to 2.30.x #4204

Closed
2 changes: 1 addition & 1 deletion continuous_integration/environment-windows.yml
@@ -17,7 +17,7 @@ dependencies:
- prometheus_client
- psutil
- pytest
- pytest-asyncio
- pytest-asyncio<0.14.0
- pytest-repeat
- pytest-timeout
- pytest-faulthandler
2 changes: 1 addition & 1 deletion continuous_integration/travis/install.sh
@@ -50,7 +50,7 @@ conda create -n dask-distributed -c conda-forge -c defaults \
prometheus_client \
psutil \
'pytest>=4' \
pytest-asyncio \
'pytest-asyncio<0.14.0' \
pytest-faulthandler \
pytest-repeat \
pytest-timeout \
136 changes: 67 additions & 69 deletions distributed/comm/core.py
@@ -213,9 +213,15 @@ async def _():

async def on_connection(self, comm: Comm, handshake_overrides=None):
local_info = {**comm.handshake_info(), **(handshake_overrides or {})}

timeout = dask.config.get("distributed.comm.timeouts.connect")
timeout = parse_timedelta(timeout, default="seconds")
try:
write = await asyncio.wait_for(comm.write(local_info), 1)
handshake = await asyncio.wait_for(comm.read(), 1)
# The timeout ensures that we eventually terminate stuck connections.
# The connector side employs smaller timeouts, so we should only
# reach this if the comm is dead anyway.
write = await asyncio.wait_for(comm.write(local_info), timeout=timeout)
handshake = await asyncio.wait_for(comm.read(), timeout=timeout)
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
@@ -262,79 +268,71 @@ async def connect(
comm = None

start = time()
deadline = start + timeout
error = None

def _raise(error):
error = error or "connect() didn't finish in time"
msg = "Timed out trying to connect to %r after %s s: %s" % (
addr,
timeout,
error,
)
raise IOError(msg)

backoff = 0.01
if timeout and timeout / 20 < backoff:
backoff = timeout / 20
def time_left():
deadline = start + timeout
return max(0, deadline - time())

retry_timeout_backoff = random.randrange(140, 160) / 100
backoff_base = 0.01
attempt = 0

# This starts a thread
while True:
# Prefer multiple small attempts over one long attempt. This should
# protect primarily against DNS race conditions
# (gh3104, gh4176, gh4167)
intermediate_cap = timeout / 5
active_exception = None
while time_left() > 0:
try:
while deadline - time() > 0:

async def _():
comm = await connector.connect(
loc, deserialize=deserialize, **connection_args
)
local_info = {
**comm.handshake_info(),
**(handshake_overrides or {}),
}
try:
handshake = await asyncio.wait_for(comm.read(), 1)
write = await asyncio.wait_for(comm.write(local_info), 1)
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
with suppress(Exception):
await comm.close()
raise CommClosedError() from e

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.local_info = local_info
comm.local_info["address"] = comm._local_addr

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
)
return comm

with suppress(TimeoutError):
comm = await asyncio.wait_for(
_(), timeout=min(deadline - time(), retry_timeout_backoff)
)
break
if not comm:
_raise(error)
comm = await asyncio.wait_for(
connector.connect(loc, deserialize=deserialize, **connection_args),
timeout=min(intermediate_cap, time_left()),
)
break
except FatalCommClosedError:
raise
except EnvironmentError as e:
error = str(e)
if time() < deadline:
logger.debug("Could not connect, waiting before retrying")
await asyncio.sleep(backoff)
backoff *= random.randrange(140, 160) / 100
retry_timeout_backoff *= random.randrange(140, 160) / 100
backoff = min(backoff, 1) # wait at most one second
else:
_raise(error)
else:
break

# CommClosed, EnvironmentError inherit from OSError
except (TimeoutError, OSError) as exc:
active_exception = exc

# The intermediate capping is mostly relevant for the initial
# connect. Afterwards, we should be more forgiving.
intermediate_cap = intermediate_cap * 1.5
# Full Jitter; see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

upper_cap = min(time_left(), backoff_base * (2 ** attempt))
backoff = random.uniform(0, upper_cap)
attempt += 1
logger.debug("Could not connect, waiting for %s before retrying", backoff)
await asyncio.sleep(backoff)
else:
raise IOError(
f"Timed out trying to connect to {addr} after {timeout} s"
) from active_exception

local_info = {
**comm.handshake_info(),
**(handshake_overrides or {}),
}
try:
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
handshake = await asyncio.wait_for(comm.read(), time_left())
await asyncio.wait_for(comm.write(local_info), time_left())
except Exception as exc:
with suppress(Exception):
await comm.close()
raise IOError(
f"Timed out during handshake while connecting to {addr} after {timeout} s"
) from exc

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.local_info = local_info
comm.local_info["address"] = comm._local_addr

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
)
return comm
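The on_connection change above replaces the hard-coded 1 s handshake timeout with the configured distributed.comm.timeouts.connect value. A minimal sketch of the same pattern, assuming a Comm-like object with read(), write(), and handshake_info() (the function name here is illustrative, not part of the patch):

import asyncio
import dask
from dask.utils import parse_timedelta

async def handshake_with_config_timeout(comm):
    # The configured value may be a string such as "10s"; normalize to seconds.
    timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, default="seconds")
    # The connector side uses smaller per-attempt timeouts, so hitting this
    # deadline means the peer is almost certainly gone.
    await asyncio.wait_for(comm.write(comm.handshake_info()), timeout=timeout)
    return await asyncio.wait_for(comm.read(), timeout=timeout)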


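The rewritten connect() combines two ideas: each attempt gets a per-attempt timeout capped by intermediate_cap (growing 1.5x per attempt, so early attempts fail fast on DNS races), and sleeps between attempts use Full Jitter, per the AWS article linked in the comment. A condensed, self-contained sketch of that loop under those assumptions (connect_once and the parameter defaults are illustrative):

import asyncio
import random
from time import time

async def connect_with_retries(connect_once, timeout, backoff_base=0.01):
    start = time()

    def time_left():
        return max(0, start + timeout - time())

    intermediate_cap = timeout / 5  # keep the first attempts short
    attempt = 0
    active_exception = None
    while time_left() > 0:
        try:
            return await asyncio.wait_for(
                connect_once(), timeout=min(intermediate_cap, time_left())
            )
        except (asyncio.TimeoutError, OSError) as exc:
            active_exception = exc
            intermediate_cap *= 1.5  # be more forgiving after the first attempt
            # Full Jitter: sleep uniformly in [0, backoff_base * 2 ** attempt],
            # never past the overall deadline.
            backoff = random.uniform(0, min(time_left(), backoff_base * 2 ** attempt))
            attempt += 1
            await asyncio.sleep(backoff)
    raise IOError(f"could not connect within {timeout} s") from active_exception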
4 changes: 3 additions & 1 deletion distributed/comm/tcp.py
@@ -1,6 +1,7 @@
import errno
import logging
import socket
from ssl import SSLError
import struct
import sys
from tornado import gen
@@ -349,7 +350,6 @@ async def connect(self, address, deserialize=True, **connection_args):
stream = await self.client.connect(
ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs
)

# Under certain circumstances Tornado will have a closed connection with an error and not raise
# a StreamClosedError.
#
@@ -360,6 +360,8 @@ async def connect(self, address, deserialize=True, **connection_args):
except StreamClosedError as e:
# The socket connect() call failed
convert_stream_closed_error(self, e)
except SSLError as err:
raise FatalCommClosedError() from err

local_address = self.prefix + get_stream_address(stream)
comm = self.comm_class(
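The new SSLError handler above matters because of how connect() treats exception types: FatalCommClosedError is re-raised immediately (a certificate failure will not get better on retry), while other OSErrors go through the backoff loop. A toy illustration of that split; the local FatalCommClosedError class stands in for the one in distributed.comm.core:

class FatalCommClosedError(OSError):
    pass

async def attempt_until_fatal(connect_once, attempts=3):
    for _ in range(attempts):
        try:
            return await connect_once()
        except FatalCommClosedError:
            raise  # e.g. TLS certificate failure: retrying cannot help
        except OSError:
            continue  # transient failure: worth another attempt
    raise IOError("all attempts failed")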
140 changes: 110 additions & 30 deletions distributed/comm/tests/test_comms.py
@@ -1,45 +1,42 @@
import asyncio
import types
from functools import partial
import os
import sys
import threading
import types
import warnings
from functools import partial

import distributed
import pkg_resources
import pytest

from tornado import ioloop
from tornado.concurrent import Future

import distributed
from distributed.metrics import time
from distributed.utils import get_ip, get_ipv6
from distributed.utils_test import (
requires_ipv6,
has_ipv6,
get_cert,
get_server_ssl_context,
get_client_ssl_context,
)
from distributed.utils_test import loop # noqa: F401

from distributed.protocol import to_serialize, Serialized, serialize, deserialize

from distributed.comm.registry import backends, get_backend
from distributed.comm import (
tcp,
inproc,
CommClosedError,
connect,
get_address_host,
get_local_address_for,
inproc,
listen,
CommClosedError,
parse_address,
parse_host_port,
unparse_host_port,
resolve_address,
get_address_host,
get_local_address_for,
tcp,
unparse_host_port,
)
from distributed.comm.registry import backends, get_backend
from distributed.comm.tcp import TCP, TCPBackend, TCPConnector
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.utils import get_ip, get_ipv6
from distributed.utils_test import loop # noqa: F401
from distributed.utils_test import (
get_cert,
get_client_ssl_context,
get_server_ssl_context,
has_ipv6,
requires_ipv6,
)
from tornado import ioloop
from tornado.concurrent import Future

EXTERNAL_IP4 = get_ip()
if has_ipv6():
@@ -218,7 +215,7 @@ async def handle_comm(comm):
await comm.write(msg)
await comm.close()

listener = await tcp.TCPListener("localhost", handle_comm)
listener = await tcp.TCPListener("127.0.0.1", handle_comm)
host, port = listener.get_host_port()
assert host in ("localhost", "127.0.0.1", "::1")
assert port > 0
@@ -264,7 +261,7 @@ async def handle_comm(comm):
server_ctx = get_server_ssl_context()
client_ctx = get_client_ssl_context()

listener = await tcp.TLSListener("localhost", handle_comm, ssl_context=server_ctx)
listener = await tcp.TLSListener("127.0.0.1", handle_comm, ssl_context=server_ctx)
host, port = listener.get_host_port()
assert host in ("localhost", "127.0.0.1", "::1")
assert port > 0
@@ -665,7 +662,8 @@ async def handle_comm(comm):

with pytest.raises(EnvironmentError) as excinfo:
await connect(listener.contact_address, timeout=2, ssl_context=cli_ctx)
assert "certificate verify failed" in str(excinfo.value)

assert "certificate verify failed" in str(excinfo.value.__cause__)


#
@@ -797,6 +795,88 @@ async def handle_comm(comm):
#


async def echo(comm):
message = await comm.read()
await comm.write(message)


@pytest.mark.asyncio
async def test_retry_connect(monkeypatch):
class UnreliableConnector(TCPConnector):
def __init__(self):
self.num_failures = 2
self.failures = 0
super().__init__()

async def connect(self, address, deserialize=True, **connection_args):
if self.failures > self.num_failures:
return await super().connect(address, deserialize, **connection_args)
else:
self.failures += 1
raise IOError()

class UnreliableBackend(TCPBackend):
_connector_class = UnreliableConnector

monkeypatch.setitem(backends, "tcp", UnreliableBackend())

listener = await listen("tcp://127.0.0.1:1234", echo)
try:
comm = await connect(listener.contact_address)
await comm.write(b"test")
msg = await comm.read()
assert msg == b"test"
finally:
listener.stop()


@pytest.mark.asyncio
async def test_handshake_slow_comm(monkeypatch):
class SlowComm(TCP):
def __init__(self, *args, delay_in_comm=0.5, **kwargs):
super().__init__(*args, **kwargs)
self.delay_in_comm = delay_in_comm

async def read(self, *args, **kwargs):
await asyncio.sleep(self.delay_in_comm)
return await super().read(*args, **kwargs)

async def write(self, *args, **kwargs):
await asyncio.sleep(self.delay_in_comm)
return await super().write(*args, **kwargs)

class SlowConnector(TCPConnector):
comm_class = SlowComm

class SlowBackend(TCPBackend):
_connector_class = SlowConnector

monkeypatch.setitem(backends, "tcp", SlowBackend())

listener = await listen("tcp://127.0.0.1:1234", echo)
try:
comm = await connect(listener.contact_address)
await comm.write(b"test")
msg = await comm.read()
assert msg == b"test"

import dask

with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}):
with pytest.raises(
IOError, match="Timed out during handshake while connecting to"
):
await connect(listener.contact_address)
finally:
listener.stop()


async def check_connect_timeout(addr):
t1 = time()
with pytest.raises(IOError):