From eeafaf5b008f704c34163d78c788ee4904fbbdf8 Mon Sep 17 00:00:00 2001 From: Patrick Foley Date: Mon, 24 Jun 2024 13:42:27 -0700 Subject: [PATCH] Allow `feature_shape` to be passed to `fx plan initialize` (#983) * Remove need for aggregator to have dataset for model weight initialization Signed-off-by: Patrick Foley * Remove extra print arguments Signed-off-by: Patrick Foley * Address review comments Signed-off-by: Patrick Foley * Remove extraneous print statement Signed-off-by: Patrick Foley --------- Signed-off-by: Patrick Foley --- openfl/interface/cli.py | 2 +- openfl/interface/plan.py | 32 +++++++++++++++++--------------- openfl/utilities/click_types.py | 17 +++++++++++++++-- openfl/utilities/mocks.py | 12 ++++++++++++ 4 files changed, 45 insertions(+), 18 deletions(-) create mode 100644 openfl/utilities/mocks.py diff --git a/openfl/interface/cli.py b/openfl/interface/cli.py index 3e4d0876d2..25a0eed1eb 100755 --- a/openfl/interface/cli.py +++ b/openfl/interface/cli.py @@ -254,7 +254,7 @@ def entry(): cli.add_command(command_group.__getattribute__(module)) try: - cli() + cli(max_content_width=120) except Exception as e: error_handler(e) diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py index 6d207944d8..9e1618f742 100644 --- a/openfl/interface/plan.py +++ b/openfl/interface/plan.py @@ -12,6 +12,8 @@ from click import Path as ClickPath from openfl.utilities.path_check import is_directory_traversal +from openfl.utilities.click_types import InputSpec +from openfl.utilities.mocks import MockDataLoader logger = getLogger(__name__) @@ -36,12 +38,15 @@ def plan(context): default='plan/data.yaml', type=ClickPath(exists=True)) @option('-a', '--aggregator_address', required=False, help='The FQDN of the federation agregator') -@option('-f', '--feature_shape', required=False, - help='The input shape to the model') +@option('-f', '--input_shape', cls=InputSpec, required=False, + help="The input shape to the model. May be provided as a list:\n\n" + "--input_shape [1,28,28]\n\n" + "or as a dictionary for multihead models (must be passed in quotes):\n\n" + "--input_shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ") @option('-g', '--gandlf_config', required=False, help='GaNDLF Configuration File Path') def initialize(context, plan_config, cols_config, data_config, - aggregator_address, feature_shape, gandlf_config): + aggregator_address, input_shape, gandlf_config): """ Initialize Data Science plan. @@ -73,18 +78,15 @@ def initialize(context, plan_config, cols_config, data_config, init_state_path = plan.config['aggregator']['settings']['init_state_path'] - # TODO: Is this part really needed? Why would we need to collaborator - # name to know the input shape to the model? - - # if feature_shape is None: - # if cols_config is None: - # exit('You must specify either a feature - # shape or authorized collaborator - # list in order for the script to determine the input layer shape') - - collaborator_cname = list(plan.cols_data_paths)[0] - - data_loader = plan.get_data_loader(collaborator_cname) + # This is needed to bypass data being locally available + if input_shape is not None: + logger.info('Attempting to generate initial model weights with' + f' custom input shape {input_shape}') + data_loader = MockDataLoader(input_shape) + else: + # If feature shape is not provided, data is assumed to be present + collaborator_cname = list(plan.cols_data_paths)[0] + data_loader = plan.get_data_loader(collaborator_cname) task_runner = plan.get_task_runner(data_loader) tensor_pipe = plan.get_tensor_pipe() diff --git a/openfl/utilities/click_types.py b/openfl/utilities/click_types.py index d9b8a0535f..4847ff5ed1 100644 --- a/openfl/utilities/click_types.py +++ b/openfl/utilities/click_types.py @@ -1,8 +1,9 @@ -# Copyright (C) 2020-2023 Intel Corporation +# Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Click types module.""" +"""Custom input types definition for Click""" import click +import ast from openfl.utilities import utils @@ -31,5 +32,17 @@ def convert(self, value, param, ctx): return value +class InputSpec(click.Option): + """List or dictionary that corresponds to the input shape for a model""" + def type_cast_value(self, ctx, value): + try: + if value is None: + return None + else: + return ast.literal_eval(value) + except Exception: + raise click.BadParameter(value) + + FQDN = FqdnParamType() IP_ADDRESS = IpAddressParamType() diff --git a/openfl/utilities/mocks.py b/openfl/utilities/mocks.py new file mode 100644 index 0000000000..a6b6206b71 --- /dev/null +++ b/openfl/utilities/mocks.py @@ -0,0 +1,12 @@ +# Copyright (C) 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Mock objects to eliminate extraneous dependencies""" + + +class MockDataLoader: + """Placeholder dataloader for when data is not available""" + def __init__(self, feature_shape): + self.feature_shape = feature_shape + + def get_feature_shape(self): + return self.feature_shape