Skip to content

Commit

Permalink
add keepalive to client
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Oct 18, 2024
1 parent c45477b commit 2bd2cac
Showing 1 changed file with 25 additions and 33 deletions.
58 changes: 25 additions & 33 deletions fedn/network/clients/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, key):
def __call__(self, context, callback):
callback((("authorization", f"{FEDN_AUTH_SCHEME} {self._key}"),), None)


def _get_ssl_certificate(domain, port=443):
context = SSL.Context(SSL.TLSv1_2_METHOD)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -39,6 +40,7 @@ def _get_ssl_certificate(domain, port=443):
cert = cert.to_cryptography().public_bytes(Encoding.PEM).decode()
return cert


class GrpcHandler:
def __init__(self, host: str, port: int, name: str, token: str, combiner_name: str):
self.metadata = [
Expand All @@ -59,6 +61,11 @@ def _init_secure_channel(self, host: str, port: int, token: str):
url = f"{host}:{port}"
logger.info(f"Connecting (GRPC) to {url}")

# Keepalive settings: these help keep the connection open for long-lived clients
KEEPALIVE_TIME_MS = 60 * 1000 # send keepalive ping every 60 seconds
KEEPALIVE_TIMEOUT_MS = 20 * 1000 # wait 20 seconds for keepalive ping ack before considering connection dead
KEEPALIVE_PERMIT_WITHOUT_CALLS = True # allow keepalive pings even when there are no RPCs

if os.getenv("FEDN_GRPC_ROOT_CERT_PATH"):
logger.info("Using root certificate from environment variable for GRPC channel.")
with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f:
Expand All @@ -70,7 +77,16 @@ def _init_secure_channel(self, host: str, port: int, token: str):
cert = _get_ssl_certificate(host, port)
credentials = grpc.ssl_channel_credentials(cert.encode("utf-8"))
auth_creds = grpc.metadata_call_credentials(GrpcAuth(token))
self.channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds))
self.channel = grpc.secure_channel(
"{}:{}".format(host, str(port)),
grpc.composite_channel_credentials(credentials, auth_creds),
options=[
("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS),
("grpc.http2.max_pings_without_data", 0), # unlimited pings without data
],
)

def _init_insecure_channel(self, host: str, port: int):
url = f"{host}:{port}"
Expand Down Expand Up @@ -115,7 +131,7 @@ def listen_to_task_stream(self, client_name: str, client_id: str, callback: Call
type=fedn.StatusType.MODEL_UPDATE_REQUEST,
request=request,
sesssion_id=request.session_id,
sender_name=client_name
sender_name=client_name,
)

logger.info(f"Received task request of type {request.type} for model_id {request.model_id}")
Expand Down Expand Up @@ -234,15 +250,16 @@ def send_model_to_combiner(self, model: BytesIO, id: str):

return result

def send_model_update(self,
def send_model_update(
self,
sender_name: str,
sender_role: fedn.Role,
client_id: str,
model_id: str,
model_update_id: str,
receiver_name: str,
receiver_role: fedn.Role,
meta: dict
meta: dict,
):
update = fedn.ModelUpdate()
update.sender.name = sender_name
Expand All @@ -260,32 +277,16 @@ def send_model_update(self,
_ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata)
except grpc.RpcError as e:
return self._handle_grpc_error(
e,
"SendModelUpdate",
lambda: self.send_model_update(
sender_name,
sender_role,
model_id,
model_update_id,
receiver_name,
receiver_role,
meta
)
e, "SendModelUpdate", lambda: self.send_model_update(sender_name, sender_role, model_id, model_update_id, receiver_name, receiver_role, meta)
)
except Exception as e:
logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}")
self._disconnect()

return True

def send_model_validation(self,
sender_name: str,
receiver_name: str,
receiver_role: fedn.Role,
model_id: str,
metrics: str,
correlation_id: str,
session_id: str
def send_model_validation(
self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: str, correlation_id: str, session_id: str
) -> bool:
validation = fedn.ModelValidation()
validation.sender.name = sender_name
Expand All @@ -298,23 +299,14 @@ def send_model_validation(self,
validation.correlation_id = correlation_id
validation.session_id = session_id


try:
logger.info("Sending model validation to combiner.")
_ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata)
except grpc.RpcError as e:
return self._handle_grpc_error(
e,
"SendModelValidation",
lambda: self.send_model_validation(
sender_name,
receiver_name,
receiver_role,
model_id,
metrics,
correlation_id,
session_id
)
lambda: self.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id),
)
except Exception as e:
logger.error(f"GRPC (SendModelValidation): An error occurred: {e}")
Expand Down

0 comments on commit 2bd2cac

Please sign in to comment.