From bc2c8a233a93ba355cb19867206c5bb6b326d1be Mon Sep 17 00:00:00 2001 From: "Shah, Karan" Date: Thu, 28 Nov 2024 22:12:37 +0530 Subject: [PATCH] Code adaptations for new rules Signed-off-by: Shah, Karan --- openfl/component/director/director.py | 2 +- .../component/aggregator/aggregator.py | 6 +- .../component/collaborator/collaborator.py | 4 +- .../workflow/interface/cli/cli_helper.py | 98 +-------------- .../workflow/interface/cli/collaborator.py | 2 +- .../workflow/interface/cli/workspace.py | 3 +- .../workflow/placement/placement.py | 4 +- .../workflow/runtime/federated_runtime.py | 3 +- .../workflow/runtime/local_runtime.py | 2 +- .../workflow/utilities/runtime_utils.py | 2 +- openfl/experimental/workflow/utilities/ui.py | 4 +- .../workflow/workspace_export/export.py | 9 +- openfl/federated/task/runner_xgb.py | 32 +++-- .../aggregation_functions/fed_bagging.py | 3 +- openfl/interface/cli_helper.py | 117 +----------------- openfl/interface/collaborator.py | 34 +++-- .../interface/interactive_api/experiment.py | 8 +- openfl/interface/workspace.py | 11 +- openfl/native/fastestimator.py | 4 +- openfl/native/native.py | 2 +- openfl/pipelines/eden_pipeline.py | 2 +- .../frameworks_adapters/flax_adapter.py | 2 +- openfl/transport/grpc/aggregator_client.py | 9 +- openfl/transport/grpc/aggregator_server.py | 9 +- openfl/utilities/fed_timer.py | 2 +- 25 files changed, 82 insertions(+), 292 deletions(-) diff --git a/openfl/component/director/director.py b/openfl/component/director/director.py index f4ca3cc731..79ac85d961 100644 --- a/openfl/component/director/director.py +++ b/openfl/component/director/director.py @@ -368,7 +368,7 @@ def update_envoy_status( if not shard_info: raise ShardNotFoundError(f"Unknown shard {envoy_name}") - shard_info["is_online"]: True + shard_info["is_online"] = True shard_info["is_experiment_running"] = is_experiment_running shard_info["valid_duration"] = 2 * self.envoy_health_check_period shard_info["last_updated"] = time.time() diff --git a/openfl/experimental/workflow/component/aggregator/aggregator.py b/openfl/experimental/workflow/component/aggregator/aggregator.py index 1e4d7b48e4..cd15b51570 100644 --- a/openfl/experimental/workflow/component/aggregator/aggregator.py +++ b/openfl/experimental/workflow/component/aggregator/aggregator.py @@ -47,8 +47,8 @@ def __init__( rounds_to_train: int = 1, checkpoint: bool = False, private_attributes_callable: Callable = None, - private_attributes_kwargs: Dict = {}, - private_attributes: Dict = {}, + private_attributes_kwargs: Dict = None, + private_attributes: Dict = None, single_col_cert_common_name: str = None, log_metric_callback: Callable = None, **kwargs, @@ -232,7 +232,7 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None) -> f = pickle.loads(f) if isinstance(stream_buffer, bytes): # Set stream buffer as function parameter - setattr(f.__func__, "_stream_buffer", pickle.loads(stream_buffer)) + f.__func__._stream_buffer = pickle.loads(stream_buffer) checkpoint(ctx, f) diff --git a/openfl/experimental/workflow/component/collaborator/collaborator.py b/openfl/experimental/workflow/component/collaborator/collaborator.py index d8e8a7fcd7..e753180158 100644 --- a/openfl/experimental/workflow/component/collaborator/collaborator.py +++ b/openfl/experimental/workflow/component/collaborator/collaborator.py @@ -35,8 +35,8 @@ def __init__( federation_uuid: str, client: Any, private_attributes_callable: Any = None, - private_attributes_kwargs: Dict = {}, - private_attributes: Dict = {}, + 
private_attributes_kwargs: Dict = None, + private_attributes: Dict = None, **kwargs, ) -> None: self.name = collaborator_name diff --git a/openfl/experimental/workflow/interface/cli/cli_helper.py b/openfl/experimental/workflow/interface/cli/cli_helper.py index 65692d6c22..d782c3b7e6 100644 --- a/openfl/experimental/workflow/interface/cli/cli_helper.py +++ b/openfl/experimental/workflow/interface/cli/cli_helper.py @@ -5,9 +5,8 @@ """Module with auxiliary CLI helper functions.""" import os import re -import shutil from itertools import islice -from os import environ, stat +from os import environ from pathlib import Path from sys import argv @@ -31,20 +30,6 @@ def pretty(o): echo(style(f"{k:<{m}} : ", fg="blue") + style(f"{v}", fg="cyan")) -def tree(path): - """Print current directory file tree.""" - echo(f"+ {path}") - - for path in sorted(path.rglob("*")): - depth = len(path.relative_to(path).parts) - space = " " * depth - - if path.is_file(): - echo(f"{space}f {path.name}") - else: - echo(f"{space}d {path.name}") - - def print_tree( dir_path: Path, level: int = -1, @@ -91,87 +76,6 @@ def inner(dir_path: Path, prefix: str = "", level=-1): echo(f"\n{directories} directories" + (f", {files} files" if files else "")) -def copytree( - src, - dst, - symlinks=False, - ignore=None, - ignore_dangling_symlinks=False, - dirs_exist_ok=False, -): - """From Python 3.8 'shutil' which include 'dirs_exist_ok' option.""" - - with os.scandir(src) as itr: - entries = list(itr) - - copy_function = shutil.copy2 - - def _copytree(): - if ignore is not None: - ignored_names = ignore(os.fspath(src), [x.name for x in entries]) - else: - ignored_names = set() - - os.makedirs(dst, exist_ok=dirs_exist_ok) - errors = [] - use_srcentry = copy_function is shutil.copy2 or copy_function is shutil.copy - - for srcentry in entries: - if srcentry.name in ignored_names: - continue - srcname = os.path.join(src, srcentry.name) - dstname = os.path.join(dst, srcentry.name) - srcobj = srcentry if use_srcentry else srcname - try: - is_symlink = srcentry.is_symlink() - if is_symlink and os.name == "nt": - lstat = srcentry.stat(follow_symlinks=False) - if lstat.st_reparse_tag == stat.IO_REPARSE_TAG_MOUNT_POINT: - is_symlink = False - if is_symlink: - linkto = os.readlink(srcname) - if symlinks: - os.symlink(linkto, dstname) - shutil.copystat(srcobj, dstname, follow_symlinks=not symlinks) - else: - if not os.path.exists(linkto) and ignore_dangling_symlinks: - continue - if srcentry.is_dir(): - copytree( - srcobj, - dstname, - symlinks, - ignore, - dirs_exist_ok=dirs_exist_ok, - ) - else: - copy_function(srcobj, dstname) - elif srcentry.is_dir(): - copytree( - srcobj, - dstname, - symlinks, - ignore, - dirs_exist_ok=dirs_exist_ok, - ) - else: - copy_function(srcobj, dstname) - except OSError as why: - errors.append((srcname, dstname, str(why))) - except Exception as err: - errors.extend(err.args[0]) - try: - shutil.copystat(src, dst) - except OSError as why: - if getattr(why, "winerror", None) is None: - errors.append((src, dst, str(why))) - if errors: - raise Exception(errors) - return dst - - return _copytree() - - def get_workspace_parameter(name): """Get a parameter from the workspace config file (.workspace).""" # Update the .workspace file to show the current workspace plan diff --git a/openfl/experimental/workflow/interface/cli/collaborator.py b/openfl/experimental/workflow/interface/cli/collaborator.py index 7398f01536..52f679fa23 100644 --- a/openfl/experimental/workflow/interface/cli/collaborator.py +++ 
b/openfl/experimental/workflow/interface/cli/collaborator.py @@ -246,7 +246,7 @@ def certify_(collaborator_name, silent, request_pkg, import_): certify(collaborator_name, silent, request_pkg, import_) -def certify(collaborator_name, silent, request_pkg=None, import_=False): +def certify(collaborator_name, silent, request_pkg=None, import_=False): # noqa C901 """Sign/certify collaborator certificate key pair.""" common_name = f"{collaborator_name}" diff --git a/openfl/experimental/workflow/interface/cli/workspace.py b/openfl/experimental/workflow/interface/cli/workspace.py index b2f7f54a73..d89d4b239d 100644 --- a/openfl/experimental/workflow/interface/cli/workspace.py +++ b/openfl/experimental/workflow/interface/cli/workspace.py @@ -16,9 +16,8 @@ from tempfile import mkdtemp from typing import Tuple -from click import Choice +from click import Choice, confirm, echo, group, option, pass_context, style from click import Path as ClickPath -from click import confirm, echo, group, option, pass_context, style from cryptography.hazmat.primitives import serialization from openfl.cryptography.ca import generate_root_cert, generate_signing_csr, sign_certificate diff --git a/openfl/experimental/workflow/placement/placement.py b/openfl/experimental/workflow/placement/placement.py index 67458a0fca..348716e24f 100644 --- a/openfl/experimental/workflow/placement/placement.py +++ b/openfl/experimental/workflow/placement/placement.py @@ -41,7 +41,7 @@ def wrapper(*args, **kwargs): print(f"\nCalling {f.__name__}") with RedirectStdStreamContext() as context_stream: # context_stream capture stdout and stderr for the function f.__name__ - setattr(wrapper, "_stream_buffer", context_stream) + wrapper._stream_buffer = context_stream f(*args, **kwargs) return wrapper @@ -92,7 +92,7 @@ def wrapper(*args, **kwargs): print(f"\nCalling {f.__name__}") with RedirectStdStreamContext() as context_stream: # context_stream capture stdout and stderr for the function f.__name__ - setattr(wrapper, "_stream_buffer", context_stream) + wrapper._stream_buffer = context_stream f(*args, **kwargs) return wrapper diff --git a/openfl/experimental/workflow/runtime/federated_runtime.py b/openfl/experimental/workflow/runtime/federated_runtime.py index d07d477580..885dbe261c 100644 --- a/openfl/experimental/workflow/runtime/federated_runtime.py +++ b/openfl/experimental/workflow/runtime/federated_runtime.py @@ -11,8 +11,7 @@ from openfl.experimental.workflow.runtime.runtime import Runtime if TYPE_CHECKING: - from openfl.experimental.workflow.interface import Aggregator - from openfl.experimental.workflow.interface import Collaborator + from openfl.experimental.workflow.interface import Aggregator, Collaborator from typing import List, Type diff --git a/openfl/experimental/workflow/runtime/local_runtime.py b/openfl/experimental/workflow/runtime/local_runtime.py index 978a684d65..beb2818e7c 100644 --- a/openfl/experimental/workflow/runtime/local_runtime.py +++ b/openfl/experimental/workflow/runtime/local_runtime.py @@ -420,7 +420,7 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any: ) interface_module = importlib.import_module("openfl.experimental.workflow.interface") - aggregator_class = getattr(interface_module, "Aggregator") + aggregator_class = interface_module.Aggregator aggregator_actor = ray.remote(aggregator_class).options( num_cpus=agg_cpus, num_gpus=agg_gpus diff --git a/openfl/experimental/workflow/utilities/runtime_utils.py b/openfl/experimental/workflow/utilities/runtime_utils.py index 
bd4ce92cb0..c4c33acdf4 100644 --- a/openfl/experimental/workflow/utilities/runtime_utils.py +++ b/openfl/experimental/workflow/utilities/runtime_utils.py @@ -176,7 +176,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage): # buffer to cycle though since need_assigned will change sizes as we # assign participants current_dict = need_assigned.copy() - for i, (participant_name, participant_gpu_usage) in enumerate(current_dict.items()): + for (participant_name, participant_gpu_usage) in current_dict.items(): if gpu == 0: break if gpu < participant_gpu_usage: diff --git a/openfl/experimental/workflow/utilities/ui.py b/openfl/experimental/workflow/utilities/ui.py index 1ac59d1f93..d9b0c19207 100644 --- a/openfl/experimental/workflow/utilities/ui.py +++ b/openfl/experimental/workflow/utilities/ui.py @@ -27,7 +27,7 @@ def __init__( flow_obj, run_id, show_html=False, - ds_root=f"{Path.home()}/.metaflow", + ds_root=None, ): """Initializes the InspectFlow with a flow object, run ID, an optional flag to show the UI in a web browser, and an optional root directory @@ -41,7 +41,7 @@ def __init__( ds_root (str, optional): The root directory for the data store. Defaults to "~/.metaflow". """ - self.ds_root = ds_root + self.ds_root = ds_root or f"{Path.home()}/.metaflow" self.show_html = show_html self.run_id = run_id self.flow_name = flow_obj.__class__.__name__ diff --git a/openfl/experimental/workflow/workspace_export/export.py b/openfl/experimental/workflow/workspace_export/export.py index e0cc893f17..576a6c9204 100644 --- a/openfl/experimental/workflow/workspace_export/export.py +++ b/openfl/experimental/workflow/workspace_export/export.py @@ -294,9 +294,7 @@ def generate_plan_yaml(self): """ Generates plan.yaml """ - flspec = getattr( - importlib.import_module("openfl.experimental.workflow.interface"), "FLSpec" - ) + flspec = importlib.import_module("openfl.experimental.workflow.interface").FLSpec # Get flow classname _, self.flow_class_name = self.__get_class_name_and_sourcecode_from_parent_class(flspec) # Get expected arguments of flow class @@ -343,10 +341,7 @@ def generate_data_yaml(self): # If flow classname is not yet found if not hasattr(self, "flow_class_name"): - flspec = getattr( - importlib.import_module("openfl.experimental.workflow.interface"), - "FLSpec", - ) + flspec = importlib.import_module("openfl.experimental.workflow.interface").FLSpec _, self.flow_class_name = self.__get_class_name_and_sourcecode_from_parent_class(flspec) # Import flow class diff --git a/openfl/federated/task/runner_xgb.py b/openfl/federated/task/runner_xgb.py index ae44210ce2..a5f5101b2e 100644 --- a/openfl/federated/task/runner_xgb.py +++ b/openfl/federated/task/runner_xgb.py @@ -48,7 +48,8 @@ def __init__(self, **kwargs): Attributes: global_model (xgb.Booster): The global XGBoost model. - required_tensorkeys_for_function (dict): A dictionary to store required tensor keys for each function. + required_tensorkeys_for_function (dict): A dictionary to store required tensor keys + for each function. """ super().__init__(**kwargs) self.global_model = None @@ -58,11 +59,13 @@ def rebuild_model(self, input_tensor_dict): """ Rebuilds the model using the provided input tensor dictionary. - This method checks if the 'local_tree' key in the input tensor dictionary is either a non-empty numpy array - If this condition is met, it updates the internal tensor dictionary with the provided input. + This method checks if the 'local_tree' key in the input tensor dictionary is either a + non-empty numpy array. 
If this condition is met, it updates the internal tensor dictionary + with the provided input. Parameters: - input_tensor_dict (dict): A dictionary containing tensor data. It must include the key 'local_tree' + input_tensor_dict (dict): A dictionary containing tensor data. + It must include the key 'local_tree' Returns: None @@ -90,11 +93,13 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs): """ data = self.data_loader.get_valid_dmatrix() - # during agg validation, self.bst will still be None. during local validation, it will have a value - no need to rebuild + # during agg validation, self.bst will still be None. during local validation, + # it will have a value - no need to rebuild if self.bst is None: self.rebuild_model(input_tensor_dict) - # if self.bst is still None after rebuilding, then there was no initial global model, so set metric to 0 + # if self.bst is still None after rebuilding, then there was no initial global model, so + # set metric to 0 if self.bst is None: # for first round agg validation, there is no model so set metric to 0 # TODO: this is not robust, especially if using a loss metric @@ -188,16 +193,18 @@ def get_tensor_dict(self, with_opt_vars=False): """ Retrieves the tensor dictionary containing the model's tree structure. - This method returns a dictionary with the key 'local_tree', which contains the model's tree structure as a numpy array. - If the model has not been initialized (`self.bst` is None), it returns an empty numpy array. - If the global model is not set or is empty, it returns the entire model as a numpy array. - Otherwise, it returns only the trees added in the latest training session. + This method returns a dictionary with the key 'local_tree', which contains the model's tree + structure as a numpy array. If the model has not been initialized (`self.bst` is None), it + returns an empty numpy array. If the global model is not set or is empty, it returns the + entire model as a numpy array. Otherwise, it returns only the trees added in the latest + training session. Parameters: with_opt_vars (bool): N/A for XGBoost (Default=False). Returns: - dict: A dictionary with the key 'local_tree' containing the model's tree structure as a numpy array. + dict: A dictionary with the key 'local_tree' containing the model's tree structure as a + numpy array. """ if self.bst is None: @@ -377,7 +384,8 @@ def validate_(self, data) -> Metric: Validate the XGBoost model. Args: - validation_dataloader (dict): A dictionary containing the validation data with keys 'dmatrix' and 'labels'. + validation_dataloader (dict): A dictionary containing the validation data with keys + 'dmatrix' and 'labels'. Returns: Metric: A Metric object containing the validation accuracy. diff --git a/openfl/interface/aggregation_functions/fed_bagging.py b/openfl/interface/aggregation_functions/fed_bagging.py index 2e42072c66..2f09481179 100644 --- a/openfl/interface/aggregation_functions/fed_bagging.py +++ b/openfl/interface/aggregation_functions/fed_bagging.py @@ -35,7 +35,8 @@ def append_trees(global_model, local_trees): Parameters: global_model (dict): A dictionary representing the global model. - local_trees (list): A list of dictionaries representing the local trees to be appended to the global model. + local_trees (list): A list of dictionaries representing the local trees to be appended to the + global model. Returns: dict: The updated global model with the local trees appended. 
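The recurring pattern in this patch, visible in the Aggregator and Collaborator constructors above and in CustomThread (fed_timer.py) below, is replacing mutable default arguments such as Dict = {} with None sentinels, as the new rules require. The hunks only show the signature side of the change, so the body-side resolution sketched here is an assumption for illustration; the attribute names are borrowed from the Collaborator constructor.

class Collaborator:
    def __init__(self, private_attributes_kwargs=None, private_attributes=None, **kwargs):
        # A default of {} is created once at function definition time and shared
        # by every call, which the new lint rules flag. Accepting None and
        # resolving it here gives each instance a fresh dict instead.
        self.private_attributes_kwargs = private_attributes_kwargs or {}
        self.private_attributes = private_attributes or {}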
diff --git a/openfl/interface/cli_helper.py b/openfl/interface/cli_helper.py index a326527158..cfabe77a0b 100644 --- a/openfl/interface/cli_helper.py +++ b/openfl/interface/cli_helper.py @@ -5,9 +5,8 @@ """Module with auxiliary CLI helper functions.""" import os import re -import shutil from itertools import islice -from os import environ, stat +from os import environ from pathlib import Path from sys import argv @@ -35,25 +34,6 @@ def pretty(o): echo(style(f"{k:<{m}} : ", fg="blue") + style(f"{v}", fg="cyan")) -def tree(path): - """ - Print current directory file tree. - - Args: - path (str): The path of the directory. - """ - echo(f"+ {path}") - - for path in sorted(path.rglob("*")): - depth = len(path.relative_to(path).parts) - space = " " * depth - - if path.is_file(): - echo(f"{space}f {path.name}") - else: - echo(f"{space}d {path.name}") - - def print_tree( dir_path: Path, level: int = -1, @@ -108,101 +88,6 @@ def inner(dir_path: Path, prefix: str = "", level=-1): echo(f"\n{directories} directories" + (f", {files} files" if files else "")) -def copytree( - src, - dst, - symlinks=False, - ignore=None, - ignore_dangling_symlinks=False, - dirs_exist_ok=False, -): - """From Python 3.8 'shutil' which include 'dirs_exist_ok' option. - - Args: - src (str): The source directory. - dst (str): The destination directory. - symlinks (bool, optional): Whether to copy symlinks. Defaults to False. - ignore (callable, optional): A function that takes a directory name - and filenames as input parameters and returns a list of names to - ignore. Defaults to None. - ignore_dangling_symlinks (bool, optional): Whether to ignore dangling - symlinks. Defaults to False. - dirs_exist_ok (bool, optional): Whether to raise an exception in case - dst or any missing parent directory already exists. Defaults to - False. 
- """ - - with os.scandir(src) as itr: - entries = list(itr) - - copy_function = shutil.copy2 - - def _copytree(): - if ignore is not None: - ignored_names = ignore(os.fspath(src), [x.name for x in entries]) - else: - ignored_names = set() - - os.makedirs(dst, exist_ok=dirs_exist_ok) - errors = [] - use_srcentry = copy_function is shutil.copy2 or copy_function is shutil.copy - - for srcentry in entries: - if srcentry.name in ignored_names: - continue - srcname = os.path.join(src, srcentry.name) - dstname = os.path.join(dst, srcentry.name) - srcobj = srcentry if use_srcentry else srcname - try: - is_symlink = srcentry.is_symlink() - if is_symlink and os.name == "nt": - lstat = srcentry.stat(follow_symlinks=False) - if lstat.st_reparse_tag == stat.IO_REPARSE_TAG_MOUNT_POINT: - is_symlink = False - if is_symlink: - linkto = os.readlink(srcname) - if symlinks: - os.symlink(linkto, dstname) - shutil.copystat(srcobj, dstname, follow_symlinks=not symlinks) - else: - if not os.path.exists(linkto) and ignore_dangling_symlinks: - continue - if srcentry.is_dir(): - copytree( - srcobj, - dstname, - symlinks, - ignore, - dirs_exist_ok=dirs_exist_ok, - ) - else: - copy_function(srcobj, dstname) - elif srcentry.is_dir(): - copytree( - srcobj, - dstname, - symlinks, - ignore, - dirs_exist_ok=dirs_exist_ok, - ) - else: - copy_function(srcobj, dstname) - except OSError as why: - errors.append((srcname, dstname, str(why))) - except Exception as err: - errors.extend(err.args[0]) - try: - shutil.copystat(src, dst) - except OSError as why: - if getattr(why, "winerror", None) is None: - errors.append((src, dst, str(why))) - if errors: - raise Exception(errors) - return dst - - return _copytree() - - def get_workspace_parameter(name): """Get a parameter from the workspace config file (.workspace). 
diff --git a/openfl/interface/collaborator.py b/openfl/interface/collaborator.py index 9bad1e9716..6716fe338c 100644 --- a/openfl/interface/collaborator.py +++ b/openfl/interface/collaborator.py @@ -403,16 +403,11 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): signing_crt = read_crt(CERT_DIR / signing_crt_path) - echo( - "The CSR Hash for file " - + style(f"{file_name}.csr", fg="green") - + " = " - + style(f"{csr_hash}", fg="red") - ) + echo(f"The CSR Hash for file {file_name}.csr is {csr_hash}") if silent: - echo(" Signing COLLABORATOR certificate") - echo(" Warning: manual check of certificate hashes is bypassed in silent mode.") + echo("Signing COLLABORATOR certificate, " + "Warning: manual check of certificate hashes is bypassed in silent mode.") signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) write_crt(signed_col_cert, f"{cert_name}.crt") register_collaborator(CERT_DIR / "client" / f"{file_name}.crt") @@ -458,13 +453,16 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): rmtree(tmp_dir) else: - # Copy the signed certificate and cert chain into PKI_DIR - previous_crts = glob(f"{CERT_DIR}/client/*.crt") - unpack_archive(import_, extract_dir=CERT_DIR) - updated_crts = glob(f"{CERT_DIR}/client/*.crt") - cert_difference = list(set(updated_crts) - set(previous_crts)) - if len(cert_difference) != 0: - crt = basename(cert_difference[0]) - echo(f"Certificate {crt} installed to PKI directory") - else: - echo("Certificate updated in the PKI directory") + _import_certificates(import_) + +def _import_certificates(archive: str): + # Copy the signed certificate and cert chain into PKI_DIR + previous_crts = glob(f"{CERT_DIR}/client/*.crt") + unpack_archive(archive, extract_dir=CERT_DIR) + updated_crts = glob(f"{CERT_DIR}/client/*.crt") + cert_difference = list(set(updated_crts) - set(previous_crts)) + if len(cert_difference) != 0: + crt = basename(cert_difference[0]) + echo(f"Certificate {crt} installed to PKI directory") + else: + echo("Certificate updated in the PKI directory") diff --git a/openfl/interface/interactive_api/experiment.py b/openfl/interface/interactive_api/experiment.py index ce970abaaf..afaad32938 100644 --- a/openfl/interface/interactive_api/experiment.py +++ b/openfl/interface/interactive_api/experiment.py @@ -355,7 +355,7 @@ def start( else: self.logger.info("Experiment could not be submitted to the director.") - def define_task_assigner(self, task_keeper, rounds_to_train): + def define_task_assigner(self, task_keeper, rounds_to_train): # noqa: C901 """Define task assigner by registered tasks. This method defines a task assigner based on the registered tasks. 
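The experiment.py hunks around this point are mostly blank-line and # noqa: C901 adjustments inside define_task_assigner, so the shape of the closures being reshuffled is easy to lose in the diff context. A condensed, self-contained sketch of the train-plus-validate branch follows; only tasks["aggregated_model_validate"] is visible in the hunks, and the other task names and the placeholder tasks mapping are assumptions added for illustration.

# Placeholder for the registered-task mapping that define_task_assigner builds
# from task_keeper earlier in the method; the values here are hypothetical.
tasks = {
    "train": "train_task",
    "locally_tuned_model_validate": "local_validate_task",
    "aggregated_model_validate": "agg_validate_task",
}

def assigner(collaborators, round_number, **kwargs):
    # Give every collaborator the same task list each round: train, validate
    # the locally tuned model, then validate the aggregated model.
    tasks_by_collaborator = {}
    for collaborator in collaborators:
        tasks_by_collaborator[collaborator] = [
            tasks["train"],
            tasks["locally_tuned_model_validate"],
            tasks["aggregated_model_validate"],
        ]
    return tasks_by_collaborator

# Example: two collaborators in round 0.
print(assigner(["col1", "col2"], round_number=0))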
@@ -396,7 +396,6 @@ def define_task_assigner(self, task_keeper, rounds_to_train): "because only validation tasks were given" ) if is_train_task_exist and self.is_validate_task_exist: - def assigner(collaborators, round_number, **kwargs): tasks_by_collaborator = {} for collaborator in collaborators: @@ -406,10 +405,9 @@ def assigner(collaborators, round_number, **kwargs): tasks["aggregated_model_validate"], ] return tasks_by_collaborator - return assigner - elif not is_train_task_exist and self.is_validate_task_exist: + elif not is_train_task_exist and self.is_validate_task_exist: def assigner(collaborators, round_number, **kwargs): tasks_by_collaborator = {} for collaborator in collaborators: @@ -417,8 +415,8 @@ def assigner(collaborators, round_number, **kwargs): tasks["aggregated_model_validate"], ] return tasks_by_collaborator - return assigner + elif is_train_task_exist and not self.is_validate_task_exist: raise Exception("You should define validate task!") else: diff --git a/openfl/interface/workspace.py b/openfl/interface/workspace.py index b138ad67db..567a58f37d 100644 --- a/openfl/interface/workspace.py +++ b/openfl/interface/workspace.py @@ -15,9 +15,8 @@ from sys import executable from typing import Union -from click import Choice +from click import Choice, echo, group, option, pass_context from click import Path as ClickPath -from click import echo, group, option, pass_context from cryptography.hazmat.primitives import serialization from openfl.cryptography.ca import generate_root_cert, generate_signing_csr, sign_certificate @@ -395,9 +394,9 @@ def export_() -> str: type=str, required=False, help=( - "Path to an enclave signing key. If not provided, a key will be auto-generated in the workspace. " - "Note that this command builds a TEE-ready image, key is NOT packaged along with the image. " - "You have the flexibility to not run inside a TEE later." + "Path to an enclave signing key. If not provided, a key will be auto-generated in the " + "workspace. Note that this command builds a TEE-ready image, key is NOT packaged along " + "with the image. You have the flexibility to not run inside a TEE later." ), ) @option( @@ -510,7 +509,7 @@ def _execute(cmd: str, verbose=True) -> None: """ logging.info(f"Executing: {cmd}") process = subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT, stdout=subprocess.PIPE) - stdout_log = list() + stdout_log = [] for line in process.stdout: msg = line.rstrip().decode("utf-8") stdout_log.append(msg) diff --git a/openfl/native/fastestimator.py b/openfl/native/fastestimator.py index b794d30387..b99765fabd 100644 --- a/openfl/native/fastestimator.py +++ b/openfl/native/fastestimator.py @@ -44,14 +44,12 @@ def __init__(self, estimator, override_config: dict = None, **kwargs): if override_config: fx.update_plan(override_config) - def fit(self): + def fit(self): # noqa: C901 """Runs the estimator in federated mode.""" - file = Path(__file__).resolve() # interface root, containing command modules root = file.parent.resolve() work = Path.cwd().resolve() - path.append(str(root)) path.insert(0, str(work)) diff --git a/openfl/native/native.py b/openfl/native/native.py index 88b8c5f426..14c9cda9c9 100644 --- a/openfl/native/native.py +++ b/openfl/native/native.py @@ -85,7 +85,7 @@ def flatten(config, return_complete=False): return flattened_config -def update_plan(override_config, plan=None, resolve=True): +def update_plan(override_config, plan=None, resolve=True): # noqa: C901 """Updates the plan with the provided override and saves it to disk. 
For a list of available override options, call `fx.get_plan()` diff --git a/openfl/pipelines/eden_pipeline.py b/openfl/pipelines/eden_pipeline.py index 61c51cd8d9..fe05bd3af5 100644 --- a/openfl/pipelines/eden_pipeline.py +++ b/openfl/pipelines/eden_pipeline.py @@ -438,7 +438,7 @@ def rand_diag(self, size, seed): res = torch.zeros(size_scaled * bools_in_float32, device=self.device) s = 0 - for i in range(bools_in_float32): + for _ in range(bools_in_float32): res[s : s + size_scaled] = r & mask s += size_scaled r >>= shift diff --git a/openfl/plugins/frameworks_adapters/flax_adapter.py b/openfl/plugins/frameworks_adapters/flax_adapter.py index 1e6fd1e203..332b73a1fc 100644 --- a/openfl/plugins/frameworks_adapters/flax_adapter.py +++ b/openfl/plugins/frameworks_adapters/flax_adapter.py @@ -131,7 +131,7 @@ def _update_weights(state_dict, tensor_dict, prefix, suffix=None): """ dict_prefix = f"{prefix}_{suffix}" if suffix is not None else f"{prefix}" for layer_name, param_obj in state_dict.items(): - for param_name, value in param_obj.items(): + for param_name, _ in param_obj.items(): key = "*".join([dict_prefix, layer_name, param_name]) if key in tensor_dict: state_dict[layer_name][param_name] = tensor_dict[key] diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index 6713107c2b..4b61cb888b 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -173,9 +173,12 @@ class AggregatorGRPCClient: use_tls (bool): Whether to use TLS for the connection. require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS. Ignored if `use_tls=False`. - root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`. - certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`. - private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`. + root_certificate (str): The path to the root certificate for the TLS connection, ignored if + `use_tls=False`. + certificate (str): The path to the client's certificate for the TLS connection, ignored if + `use_tls=False`. + private_key (str): The path to the client's private key for the TLS connection, ignored if + `use_tls=False`. aggregator_uuid (str): The UUID of the aggregator. federation_uuid (str): The UUID of the federation. single_col_cert_common_name (str): The common name on the diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index 19d156338f..bfae10351b 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -31,9 +31,12 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer): use_tls (bool): Whether to use TLS for the connection. require_client_auth (bool): Whether to enable client-side authentication, i.e. mTLS. Ignored if `use_tls=False`. - root_certificate (str): The path to the root certificate for the TLS connection, ignored if `use_tls=False`. - certificate (str): The path to the client's certificate for the TLS connection, ignored if `use_tls=False`. - private_key (str): The path to the client's private key for the TLS connection, ignored if `use_tls=False`. + root_certificate (str): The path to the root certificate for the TLS connection, ignored if + `use_tls=False`. + certificate (str): The path to the client's certificate for the TLS connection, ignored if + `use_tls=False`. 
+ private_key (str): The path to the client's private key for the TLS connection, ignored if + `use_tls=False`. server (grpc.Server): The gRPC server. server_credentials (grpc.ServerCredentials): The server's credentials. """ diff --git a/openfl/utilities/fed_timer.py b/openfl/utilities/fed_timer.py index a8ecb9e45f..3d8770ec24 100644 --- a/openfl/utilities/fed_timer.py +++ b/openfl/utilities/fed_timer.py @@ -29,7 +29,7 @@ class CustomThread(Thread): kwargs (dict): The keyword arguments to pass to the target function. """ - def __init__(self, group=None, target=None, name=None, args=(), kwargs={}): + def __init__(self, group=None, target=None, name=None, args=None, kwargs=None): """Initialize a CustomThread object. Args: