Skip to content

Commit

Permalink
Allow feature_shape to be passed to fx plan initialize (#983)
Browse files Browse the repository at this point in the history
* Remove need for aggregator to have dataset for model weight initialization

Signed-off-by: Patrick Foley <[email protected]>

* Remove extra print arguments

Signed-off-by: Patrick Foley <[email protected]>

* Address review comments

Signed-off-by: Patrick Foley <[email protected]>

* Remove extraneous print statement

Signed-off-by: Patrick Foley <[email protected]>

---------

Signed-off-by: Patrick Foley <[email protected]>
  • Loading branch information
psfoley authored and manuelhsantana committed Jul 10, 2024
1 parent c72bcf3 commit 4826951
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 22 deletions.
2 changes: 1 addition & 1 deletion openfl/interface/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def entry():
cli.add_command(command_group.__getattribute__(module))

try:
cli()
cli(max_content_width=120)
except Exception as e:
error_handler(e)

Expand Down
41 changes: 22 additions & 19 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from click import Path as ClickPath

from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.click_types import InputSpec
from openfl.utilities.mocks import MockDataLoader

logger = getLogger(__name__)

Expand Down Expand Up @@ -40,15 +42,19 @@ def plan(context):
default='plan/data.yaml', type=ClickPath(exists=True))
@option('-a', '--aggregator_address', required=False,
help='The FQDN of the federation agregator')
@option('-f', '--feature_shape', required=False,
help='The input shape to the model')
@option('-f', '--input_shape', cls=InputSpec, required=False,
help="The input shape to the model. May be provided as a list:\n\n"
"--input_shape [1,28,28]\n\n"
"or as a dictionary for multihead models (must be passed in quotes):\n\n"
"--input_shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ")
@option('-g', '--gandlf_config', required=False,
help='GaNDLF Configuration File Path')
def initialize(context, plan_config, cols_config, data_config,
aggregator_address, feature_shape, gandlf_config):
"""Initialize Data Science plan.
Create a protocol buffer file of the initial model weights for the federation.
Create a protocol buffer file of the initial model weights for the
federation.
Args:
context (click.core.Context): Click context.
Expand Down Expand Up @@ -84,18 +90,15 @@ def initialize(context, plan_config, cols_config, data_config,

init_state_path = plan.config['aggregator']['settings']['init_state_path']

# TODO: Is this part really needed? Why would we need to collaborator
# name to know the input shape to the model?

# if feature_shape is None:
# if cols_config is None:
# exit('You must specify either a feature
# shape or authorized collaborator
# list in order for the script to determine the input layer shape')

collaborator_cname = list(plan.cols_data_paths)[0]

data_loader = plan.get_data_loader(collaborator_cname)
# This is needed to bypass data being locally available
if input_shape is not None:
logger.info('Attempting to generate initial model weights with'
f' custom input shape {input_shape}')
data_loader = MockDataLoader(input_shape)
else:
# If feature shape is not provided, data is assumed to be present
collaborator_cname = list(plan.cols_data_paths)[0]
data_loader = plan.get_data_loader(collaborator_cname)
task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()

Expand Down Expand Up @@ -186,7 +189,7 @@ def freeze(plan_config):

def switch_plan(name):
"""Switch the FL plan to this one.
Args:
name (str): Name of the Federated learning plan.
Expand Down Expand Up @@ -231,7 +234,7 @@ def switch_plan(name):
default='default', type=str)
def switch_(name):
"""Switch the current plan to this plan.
Args:
name (str): Name of the Federated learning plan.
"""
Expand All @@ -244,7 +247,7 @@ def switch_(name):
default='default', type=str)
def save_(name):
"""Save the current plan to this plan and switch.
Args:
name (str): Name of the Federated learning plan.
"""
Expand All @@ -267,7 +270,7 @@ def save_(name):
default='default', type=str)
def remove_(name):
"""Remove this plan.
Args:
name (str): Name of the Federated learning plan.
"""
Expand Down
17 changes: 15 additions & 2 deletions openfl/utilities/click_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (C) 2020-2023 Intel Corporation
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Click types module."""
"""Custom input types definition for Click"""

import click
import ast

from openfl.utilities import utils

Expand Down Expand Up @@ -67,5 +68,17 @@ def convert(self, value, param, ctx):
return value


class InputSpec(click.Option):
"""List or dictionary that corresponds to the input shape for a model"""
def type_cast_value(self, ctx, value):
try:
if value is None:
return None
else:
return ast.literal_eval(value)
except Exception:
raise click.BadParameter(value)


FQDN = FqdnParamType()
IP_ADDRESS = IpAddressParamType()
12 changes: 12 additions & 0 deletions openfl/utilities/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Mock objects to eliminate extraneous dependencies"""


class MockDataLoader:
"""Placeholder dataloader for when data is not available"""
def __init__(self, feature_shape):
self.feature_shape = feature_shape

def get_feature_shape(self):
return self.feature_shape

0 comments on commit 4826951

Please sign in to comment.