Skip to content

Commit

Permalink
Allow feature_shape to be passed to fx plan initialize (#983)
Browse files Browse the repository at this point in the history
* Remove need for aggregator to have dataset for model weight initialization

Signed-off-by: Patrick Foley <[email protected]>

* Remove extra print arguments

Signed-off-by: Patrick Foley <[email protected]>

* Address review comments

Signed-off-by: Patrick Foley <[email protected]>

* Remove extraneous print statement

Signed-off-by: Patrick Foley <[email protected]>

---------

Signed-off-by: Patrick Foley <[email protected]>
  • Loading branch information
psfoley authored Jun 24, 2024
1 parent 41b175e commit eeafaf5
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 18 deletions.
2 changes: 1 addition & 1 deletion openfl/interface/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def entry():
cli.add_command(command_group.__getattribute__(module))

try:
cli()
cli(max_content_width=120)
except Exception as e:
error_handler(e)

Expand Down
32 changes: 17 additions & 15 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from click import Path as ClickPath

from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.click_types import InputSpec
from openfl.utilities.mocks import MockDataLoader

logger = getLogger(__name__)

Expand All @@ -36,12 +38,15 @@ def plan(context):
default='plan/data.yaml', type=ClickPath(exists=True))
@option('-a', '--aggregator_address', required=False,
help='The FQDN of the federation agregator')
@option('-f', '--feature_shape', required=False,
help='The input shape to the model')
@option('-f', '--input_shape', cls=InputSpec, required=False,
help="The input shape to the model. May be provided as a list:\n\n"
"--input_shape [1,28,28]\n\n"
"or as a dictionary for multihead models (must be passed in quotes):\n\n"
"--input_shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ")
@option('-g', '--gandlf_config', required=False,
help='GaNDLF Configuration File Path')
def initialize(context, plan_config, cols_config, data_config,
aggregator_address, feature_shape, gandlf_config):
aggregator_address, input_shape, gandlf_config):
"""
Initialize Data Science plan.
Expand Down Expand Up @@ -73,18 +78,15 @@ def initialize(context, plan_config, cols_config, data_config,

init_state_path = plan.config['aggregator']['settings']['init_state_path']

# TODO: Is this part really needed? Why would we need to collaborator
# name to know the input shape to the model?

# if feature_shape is None:
# if cols_config is None:
# exit('You must specify either a feature
# shape or authorized collaborator
# list in order for the script to determine the input layer shape')

collaborator_cname = list(plan.cols_data_paths)[0]

data_loader = plan.get_data_loader(collaborator_cname)
# This is needed to bypass data being locally available
if input_shape is not None:
logger.info('Attempting to generate initial model weights with'
f' custom input shape {input_shape}')
data_loader = MockDataLoader(input_shape)
else:
# If feature shape is not provided, data is assumed to be present
collaborator_cname = list(plan.cols_data_paths)[0]
data_loader = plan.get_data_loader(collaborator_cname)
task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()

Expand Down
17 changes: 15 additions & 2 deletions openfl/utilities/click_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (C) 2020-2023 Intel Corporation
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Click types module."""
"""Custom input types definition for Click"""

import click
import ast

from openfl.utilities import utils

Expand Down Expand Up @@ -31,5 +32,17 @@ def convert(self, value, param, ctx):
return value


class InputSpec(click.Option):
"""List or dictionary that corresponds to the input shape for a model"""
def type_cast_value(self, ctx, value):
try:
if value is None:
return None
else:
return ast.literal_eval(value)
except Exception:
raise click.BadParameter(value)


FQDN = FqdnParamType()
IP_ADDRESS = IpAddressParamType()
12 changes: 12 additions & 0 deletions openfl/utilities/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Mock objects to eliminate extraneous dependencies"""


class MockDataLoader:
"""Placeholder dataloader for when data is not available"""
def __init__(self, feature_shape):
self.feature_shape = feature_shape

def get_feature_shape(self):
return self.feature_shape

0 comments on commit eeafaf5

Please sign in to comment.