Skip to content

Commit

Permalink
Use exponential backoff for connection retries
Browse files Browse the repository at this point in the history
Calls to socket.connect() are non-blocking, hence all subsequent calls
to socket.sendall() will fail if the target KDC service is temporarily
or indefinitely unreachable. Since the kdcproxy task uses busy-looping,
it results in the journal to be flooded with warning logs.

This commit introduces a per-socket reactivation delay which increases
exponentially as the number of reties is incremented, until timeout is
reached (i.e. 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, ...).

Signed-off-by: Julien Rische <[email protected]>
  • Loading branch information
jrisc committed Nov 21, 2024
1 parent f61979e commit bac3c99
Showing 1 changed file with 60 additions and 3 deletions.
63 changes: 60 additions & 3 deletions kdcproxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,37 @@ def __str__(self):
return "%d %s" % (self.code, httplib.responses[self.code])


class SocketException(Exception):

def __init__(self, message, sock):
super(Exception, self).__init__(message)
self.sockfno = sock.fileno()


class Application:
MAX_LENGTH = 128 * 1024
SOCKTYPES = {
"tcp": socket.SOCK_STREAM,
"udp": socket.SOCK_DGRAM,
}

def addr2socktypename(self, addr):
ret = None
for name in self.SOCKTYPES:
if self.SOCKTYPES[name] == addr[1]:
ret = name
break
return ret

def __init__(self):
self.__resolver = MetaResolver()

def __await_reply(self, pr, rsocks, wsocks, timeout):
starting_time = time.time()
send_error = None
recv_error = None
failing_sock = None
reactivations = {}
extra = 0
read_buffers = {}
while (timeout + extra) > time.time():
Expand All @@ -92,6 +112,12 @@ def __await_reply(self, pr, rsocks, wsocks, timeout):
pass

for sock in w:
# Fetch reactivation tuple:
# 1st element: reactivation index (-1 = first activation)
# 2nd element: planned reactivation time (0.0 = now)
(rn, rt) = reactivations.get(sock, (-1, 0.0))
if rt > time.time():
continue
try:
if self.sock_type(sock) == socket.SOCK_DGRAM:
# If we proxy over UDP, remove the 4-byte length
Expand All @@ -101,23 +127,44 @@ def __await_reply(self, pr, rsocks, wsocks, timeout):
sock.sendall(pr.request)
extra = 10 # New connections get 10 extra seconds
except Exception as e:
logging.warning("Conection broken while writing (%s)", e)
send_error = e
failing_sock = sock
reactivations[sock] = (rn + 1,
time.time() + 2.0**(rn + 1) / 10)
continue
if sock in reactivations:
del reactivations[sock]
rsocks.append(sock)
wsocks.remove(sock)

for sock in r:
try:
reply = self.__handle_recv(sock, read_buffers)
except Exception as e:
logging.warning("Connection broken while reading (%s)", e)
recv_error = e
failing_sock = sock
if self.sock_type(sock) == socket.SOCK_STREAM:
# Remove broken TCP socket from readers
rsocks.remove(sock)
else:
if reply is not None:
return reply

if reactivations:
raise SocketException("Timeout while sending packets after %.2fs "
"and %d tries: %s" % (
(timeout + extra) - starting_time,
sum(map(lambda r: r[0],
reactivations.values())),
send_error),
failing_sock)
elif recv_error is not None:
raise SocketException("Timeout while receiving packets after "
"%.2fs: %s" % (
(timeout + extra) - starting_time,
recv_error),
failing_sock)

return None

def __handle_recv(self, sock, read_buffers):
Expand Down Expand Up @@ -215,6 +262,7 @@ def __call__(self, env, start_response):
reply = None
wsocks = []
rsocks = []
sockfno2addr = {}
for server in map(urlparse.urlparse, servers):
# Enforce valid, supported URIs
scheme = server.scheme.lower().split("+", 1)
Expand Down Expand Up @@ -261,6 +309,7 @@ def __call__(self, env, start_response):
continue
except io.BlockingIOError:
pass
sockfno2addr[sock.fileno()] = addr
wsocks.append(sock)

# Resend packets to UDP servers
Expand All @@ -271,7 +320,15 @@ def __call__(self, env, start_response):

# Call select()
timeout = time.time() + (15 if addr is None else 2)
reply = self.__await_reply(pr, rsocks, wsocks, timeout)
try:
reply = self.__await_reply(pr, rsocks, wsocks, timeout)
except SocketException as e:
fail_addr = sockfno2addr[e.sockfno]
fail_socktype = self.addr2socktypename(fail_addr)
fail_ip = fail_addr[4][0]
fail_port = fail_addr[4][1]
logging.warning("Exchange with %s:[%s]:%d failed: %s",
fail_socktype, fail_ip, fail_port, e)
if reply is not None:
break

Expand Down

0 comments on commit bac3c99

Please sign in to comment.