diff --git a/fedn/network/combiner/hooks/hook_client.py b/fedn/network/combiner/hooks/hook_client.py index 340305881..3219e5a17 100644 --- a/fedn/network/combiner/hooks/hook_client.py +++ b/fedn/network/combiner/hooks/hook_client.py @@ -17,19 +17,22 @@ class CombinerHookInterface: def __init__(self): """Initialize CombinerHookInterface client.""" - self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081") - self.channel = grpc.insecure_channel( - self.hook_service_host, - options=[ - ("grpc.keepalive_time_ms", 30000), # 30 seconds ping interval - ("grpc.keepalive_timeout_ms", 5000), # 5 seconds timeout for a response - ("grpc.keepalive_permit_without_calls", 1), # allow keepalives even with no active calls - ("grpc.enable_retries", 1), # automatic retries - ("grpc.initial_reconnect_backoff_ms", 1000), # initial delay before retrying - ("grpc.max_reconnect_backoff_ms", 5000), # maximum delay before retrying - ], - ) - self.stub = rpc.FunctionServiceStub(self.channel) + try: + self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081") + self.channel = grpc.insecure_channel( + self.hook_service_host, + options=[ + ("grpc.keepalive_time_ms", 30000), # 30 seconds ping interval + ("grpc.keepalive_timeout_ms", 5000), # 5 seconds timeout for a response + ("grpc.keepalive_permit_without_calls", 1), # allow keepalives even with no active calls + ("grpc.enable_retries", 1), # automatic retries + ("grpc.initial_reconnect_backoff_ms", 1000), # initial delay before retrying + ("grpc.max_reconnect_backoff_ms", 5000), # maximum delay before retrying + ], + ) + self.stub = rpc.FunctionServiceStub(self.channel) + except Exception as e: + logger.warning(f"Failed to initialize connection to hooks container with error {e}") def provided_functions(self, server_functions: str): """Communicates to hook container and asks which functions are available. @@ -39,10 +42,14 @@ def provided_functions(self, server_functions: str): :return: dictionary specifing which functions are implemented. :rtype: dict """ - request = fedn.ProvidedFunctionsRequest(function_code=server_functions) + try: + request = fedn.ProvidedFunctionsRequest(function_code=server_functions) - response = self.stub.HandleProvidedFunctions(request) - return response.available_functions + response = self.stub.HandleProvidedFunctions(request) + return response.available_functions + except Exception as e: + logger.warning(f"Was not able to communicate to hooks container due to: {e}") + return {} def client_settings(self, global_model) -> dict: """Communicates to hook container to get a client config. diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 5eb5387d8..fa3d83e8f 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -121,7 +121,7 @@ def push_round_config(self, round_config: RoundConfig) -> str: raise return round_config["_job_id"] - def _training_round(self, config, clients, provided_functions): + def _training_round(self, config: dict, clients: list, provided_functions: dict): """Send model update requests to clients and aggregate results. :param config: The round config object (passed to the client). @@ -141,7 +141,7 @@ def _training_round(self, config, clients, provided_functions): session_id = config["session_id"] model_id = config["model_id"] - if provided_functions["client_settings"]: + if provided_functions.get("client_settings", False): global_model_bytes = self.modelservice.temp_model_storage.get(model_id) client_settings = self.hook_interface.client_settings(global_model_bytes) config["client_settings"] = client_settings @@ -172,7 +172,7 @@ def _training_round(self, config, clients, provided_functions): parameters = Parameters(dict_parameters) else: parameters = None - if provided_functions["aggregate"]: + if provided_functions.get("aggregate", False): previous_model_bytes = self.modelservice.temp_model_storage.get(model_id) model, data = self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models) else: @@ -326,7 +326,7 @@ def execute_training_round(self, config): provided_functions = self.hook_interface.provided_functions(self.server_functions) - if provided_functions["client_selection"]: + if provided_functions.get("client_selection", False): clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers()) else: clients = self._assign_round_clients(self.server.max_clients)