Skip to content

Commit

Permalink
Fix/SK-1178 | Set client_id in ModelRequest + single hearbeat function (
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede authored Nov 7, 2024
1 parent 70a6486 commit 60926be
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 56 deletions.
22 changes: 8 additions & 14 deletions fedn/network/clients/client_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, toke
logger.error("Error: Could not initialize GRPC connection")
return False


def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
    """Forward a heartbeat request to the gRPC handler.

    All three arguments are passed on by keyword to
    ``self.grpc_handler.send_heartbeats``; whatever the handler returns is
    discarded, so this method always returns ``None``.

    :param client_name: Forwarded as the handler's ``client_name``.
    :param client_id: Forwarded as the handler's ``client_id``.
    :param update_frequency: Forwarded as the handler's ``update_frequency``.
    """
    heartbeat_kwargs = {
        "client_name": client_name,
        "client_id": client_id,
        "update_frequency": update_frequency,
    }
    self.grpc_handler.send_heartbeats(**heartbeat_kwargs)

Expand All @@ -220,24 +219,25 @@ def _task_stream_callback(self, request):
elif request.type == fedn.StatusType.MODEL_VALIDATION:
self.validate(request)

def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
    """Ask the gRPC handler for a model and hand back its result.

    :param id: Forwarded as the handler's ``id``.
    :param client_name: Forwarded as the handler's ``client_name``.
    :param timeout: Forwarded as the handler's ``timeout``.
    :return: Whatever ``self.grpc_handler.get_model_from_combiner`` returns.
    """
    handler = self.grpc_handler
    fetched = handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout)
    return fetched
def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
    """Fetch a model through the gRPC handler.

    :param id: Forwarded as the handler's ``id``.
    :param client_id: Forwarded as the handler's ``client_id``.
    :param timeout: Forwarded as the handler's ``timeout``.
    :return: Whatever ``self.grpc_handler.get_model_from_combiner`` returns.
    """
    # The handler's parameter is named ``client_id``; forwarding the value under
    # the old ``client_name`` keyword raises TypeError (unexpected keyword argument).
    return self.grpc_handler.get_model_from_combiner(id=id, client_id=client_id, timeout=timeout)

def send_model_to_combiner(self, model: BytesIO, id: str):
    """Pass a model on to the gRPC handler for upload.

    :param model: Forwarded positionally as the handler's first argument.
    :param id: Forwarded positionally as the handler's second argument.
    :return: Whatever ``self.grpc_handler.send_model_to_combiner`` returns.
    """
    upload = self.grpc_handler.send_model_to_combiner
    return upload(model, id)

def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
    """Relay a status message to the gRPC handler.

    Every argument is forwarded positionally, in declaration order, to
    ``self.grpc_handler.send_status``.

    :param msg: Status text.
    :param log_level: Defaults to ``fedn.Status.INFO``.
    :param type: Forwarded unchanged (name shadows the builtin inside this method).
    :param request: Forwarded unchanged.
    :param sesssion_id: Forwarded unchanged. The spelling is part of the
        keyword interface used by callers, so it must not be "corrected" here.
    :param sender_name: Forwarded unchanged.
    :return: Whatever ``self.grpc_handler.send_status`` returns.
    """
    status_args = (msg, log_level, type, request, sesssion_id, sender_name)
    return self.grpc_handler.send_status(*status_args)

def send_model_update(self,
def send_model_update(
self,
sender_name: str,
sender_role: fedn.Role,
client_id: str,
model_id: str,
model_update_id: str,
receiver_name: str,
receiver_role: fedn.Role,
meta: dict
meta: dict,
) -> bool:
return self.grpc_handler.send_model_update(
sender_name=sender_name,
Expand All @@ -247,17 +247,11 @@ def send_model_update(self,
model_update_id=model_update_id,
receiver_name=receiver_name,
receiver_role=receiver_role,
meta=meta
meta=meta,
)

def send_model_validation(self,
sender_name: str,
receiver_name: str,
receiver_role: fedn.Role,
model_id: str,
metrics: dict,
correlation_id: str,
session_id: str
def send_model_validation(
self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: dict, correlation_id: str, session_id: str
) -> bool:
return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id)

Expand Down
36 changes: 17 additions & 19 deletions fedn/network/clients/client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ def to_json(self):


class Client:
def __init__(self,
api_url: str,
api_port: int,
client_obj: ClientOptions,
combiner_host: str = None,
combiner_port: int = None,
token: str = None,
package_checksum: str = None,
helper_type: str = None
):
def __init__(
self,
api_url: str,
api_port: int,
client_obj: ClientOptions,
combiner_host: str = None,
combiner_port: int = None,
token: str = None,
package_checksum: str = None,
helper_type: str = None,
):
self.api_url = api_url
self.api_port = api_port
self.combiner_host = combiner_host
Expand Down Expand Up @@ -149,7 +150,6 @@ def on_validation(self, request):
logger.info("Received validation request")
self._process_validation_request(request)


def _process_training_request(self, request) -> Tuple[str, dict]:
"""Process a training (model update) request.
Expand All @@ -164,16 +164,14 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
session_id: str = request.session_id

self.client_api.send_status(
f"\t Starting processing of training request for model_id {model_id}",
sesssion_id=session_id,
sender_name=self.client_obj.name
f"\t Starting processing of training request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name
)

try:
meta = {}
tic = time.time()

model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)

if model is None:
logger.error("Could not retrieve model from combiner. Aborting training request.")
Expand Down Expand Up @@ -246,7 +244,7 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
type=fedn.StatusType.MODEL_UPDATE,
request=request,
sesssion_id=session_id,
sender_name=self.client_obj.name
sender_name=self.client_obj.name,
)

def _process_validation_request(self, request):
Expand All @@ -266,7 +264,7 @@ def _process_validation_request(self, request):
self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name)

try:
model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)
if model is None:
logger.error("Could not retrieve model from combiner. Aborting validation request.")
return
Expand Down Expand Up @@ -318,13 +316,13 @@ def _process_validation_request(self, request):
type=fedn.StatusType.MODEL_VALIDATION,
request=validation,
sesssion_id=request.session_id,
sender_name=self.client_obj.name
sender_name=self.client_obj.name,
)
else:
self.client_api.send_status(
"Client {} failed to complete model validation.".format(self.name),
log_level=fedn.Status.WARNING,
request=request,
sesssion_id=request.session_id,
sender_name=self.client_obj.name
sender_name=self.client_obj.name,
)
71 changes: 49 additions & 22 deletions fedn/network/clients/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@
from fedn.common.log_config import logger
from fedn.network.combiner.modelservice import upload_request_generator

# Keepalive settings: these help keep the connection open for long-lived clients.
# NOTE(review): 1 s between pings is very aggressive compared with the previous
# 60 s value used in _init_secure_channel — confirm this is intentional.
KEEPALIVE_TIME_MS = 1 * 1000  # send keepalive ping every 1 second
KEEPALIVE_TIMEOUT_MS = 30 * 1000  # wait 30 seconds for keepalive ping ack before considering connection dead
KEEPALIVE_PERMIT_WITHOUT_CALLS = True  # allow keepalive pings even when there are no RPCs
MAX_CONNECTION_IDLE_MS = 30000
# Int-valued channel arguments must be passed as integers; the literal string
# "INT_MAX" is not interpreted by gRPC. Use the numeric INT_MAX (2**31 - 1),
# which gRPC documents as "unlimited" for this argument.
MAX_CONNECTION_AGE_GRACE_MS = 2**31 - 1  # keep connection open indefinitely
CLIENT_IDLE_TIMEOUT_MS = 30000

# Channel options shared by the secure and the insecure channel.
# NOTE(review): grpc.max_connection_idle_ms and grpc.max_connection_age_grace_ms
# are documented as server-side settings — confirm they have any effect on a
# client channel.
GRPC_OPTIONS = [
    ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
    ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
    ("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS),
    ("grpc.http2.max_pings_without_data", 0),  # unlimited pings without data
    ("grpc.max_connection_idle_ms", MAX_CONNECTION_IDLE_MS),
    ("grpc.max_connection_age_grace_ms", MAX_CONNECTION_AGE_GRACE_MS),
    ("grpc.client_idle_timeout_ms", CLIENT_IDLE_TIMEOUT_MS),
]


class GrpcAuth(grpc.AuthMetadataPlugin):
def __init__(self, key):
Expand Down Expand Up @@ -61,11 +79,6 @@ def _init_secure_channel(self, host: str, port: int, token: str):
url = f"{host}:{port}"
logger.info(f"Connecting (GRPC) to {url}")

# Keepalive settings: these help keep the connection open for long-lived clients
KEEPALIVE_TIME_MS = 60 * 1000 # send keepalive ping every 60 seconds
KEEPALIVE_TIMEOUT_MS = 20 * 1000 # wait 20 seconds for keepalive ping ack before considering connection dead
KEEPALIVE_PERMIT_WITHOUT_CALLS = True # allow keepalive pings even when there are no RPCs

if os.getenv("FEDN_GRPC_ROOT_CERT_PATH"):
logger.info("Using root certificate from environment variable for GRPC channel.")
with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f:
Expand All @@ -80,34 +93,48 @@ def _init_secure_channel(self, host: str, port: int, token: str):
self.channel = grpc.secure_channel(
"{}:{}".format(host, str(port)),
grpc.composite_channel_credentials(credentials, auth_creds),
options=[
("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS),
("grpc.http2.max_pings_without_data", 0), # unlimited pings without data
],
options=GRPC_OPTIONS,
)

def _init_insecure_channel(self, host: str, port: int):
url = f"{host}:{port}"
logger.info(f"Connecting (GRPC) to {url}")
self.channel = grpc.insecure_channel(url)
self.channel = grpc.insecure_channel(
url,
options=GRPC_OPTIONS,
)

def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
def heartbeat(self, client_name: str, client_id: str):
    """Send a single heartbeat message to the combiner.

    Builds a ``fedn.Heartbeat`` whose sender is a ``fedn.Client`` with role
    ``fedn.WORKER`` and submits it via ``self.connectorStub.SendHeartbeat``.

    :param client_name: Used as the sender's ``name``.
    :param client_id: Used as the sender's ``client_id``.
    :return: Response from the combiner.
    :rtype: fedn.Response
    :raises grpc.RpcError: Re-raised untouched so the caller can handle it.
    :raises Exception: Any other error is logged, the handler disconnects,
        and the error is re-raised.
    """
    sender = fedn.Client(name=client_name, role=fedn.WORKER, client_id=client_id)
    message = fedn.Heartbeat(sender=sender)

    try:
        logger.info("Sending heartbeat to combiner")
        return self.connectorStub.SendHeartbeat(message, metadata=self.metadata)
    except grpc.RpcError as rpc_error:
        # Listed first so gRPC errors bypass the generic handler below.
        raise rpc_error
    except Exception as error:
        logger.error(f"GRPC (SendHeartbeat): An error occurred: {error}")
        self._disconnect()
        raise error

def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
send_hearbeat = True
while send_hearbeat:
try:
logger.info("Sending heartbeat to combiner")
self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
response = self.heartbeat(client_name, client_id)
except grpc.RpcError as e:
return self._handle_grpc_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
except Exception as e:
logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
self._disconnect()
if isinstance(response, fedn.Response):
logger.info("Heartbeat successful.")
else:
logger.error("Heartbeat failed.")
send_hearbeat = False

time.sleep(update_frequency)

def listen_to_task_stream(self, client_name: str, client_id: str, callback: Callable[[Any], None]):
Expand Down Expand Up @@ -179,7 +206,7 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N
logger.error(f"GRPC (SendStatus): An error occurred: {e}")
self._disconnect()

def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
"""Fetch a model from the assigned combiner.
Downloads the model update object via a gRPC streaming channel.
Expand All @@ -191,7 +218,7 @@ def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20)
data = BytesIO()
time_start = time.time()
request = fedn.ModelRequest(id=id)
request.sender.name = client_name
request.sender.client_id = client_id
request.sender.role = fedn.WORKER

try:
Expand All @@ -211,7 +238,7 @@ def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20)
return None
continue
except grpc.RpcError as e:
return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_name, timeout))
return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
except Exception as e:
logger.error(f"GRPC (Download): An error occurred: {e}")
self._disconnect()
Expand Down
2 changes: 1 addition & 1 deletion fedn/network/combiner/modelservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def Download(self, request, context):
:return: A model response iterator.
:rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse`
"""
logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}")
logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.client_id} requested model {request.id}")
try:
status = self.temp_model_storage.get_model_metadata(request.id)
if status != fedn.ModelStatus.OK:
Expand Down

0 comments on commit 60926be

Please sign in to comment.