Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow feature_shape to be passed to fx plan initialize #983

Merged
merged 4 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openfl/interface/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def entry():
cli.add_command(command_group.__getattribute__(module))

try:
cli()
cli(max_content_width=120)
except Exception as e:
error_handler(e)

Expand Down
32 changes: 17 additions & 15 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from click import Path as ClickPath

from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.click_types import InputSpec
from openfl.utilities.mocks import MockDataLoader

logger = getLogger(__name__)

Expand All @@ -36,12 +38,15 @@ def plan(context):
default='plan/data.yaml', type=ClickPath(exists=True))
@option('-a', '--aggregator_address', required=False,
help='The FQDN of the federation agregator')
@option('-f', '--feature_shape', required=False,
help='The input shape to the model')
@option('-f', '--input_shape', cls=InputSpec, required=False,
help="The input shape to the model. May be provided as a list:\n\n"
"--input_shape [1,28,28]\n\n"
"or as a dictionary for multihead models (must be passed in quotes):\n\n"
"--input_shape \"{'input_0': [1, 240, 240, 4],'output_1': [1, 240, 240, 1]}\"\n\n ")
@option('-g', '--gandlf_config', required=False,
help='GaNDLF Configuration File Path')
def initialize(context, plan_config, cols_config, data_config,
aggregator_address, feature_shape, gandlf_config):
aggregator_address, input_shape, gandlf_config):
"""
Initialize Data Science plan.

Expand Down Expand Up @@ -73,18 +78,15 @@ def initialize(context, plan_config, cols_config, data_config,

init_state_path = plan.config['aggregator']['settings']['init_state_path']

# TODO: Is this part really needed? Why would we need to collaborator
# name to know the input shape to the model?

# if feature_shape is None:
# if cols_config is None:
# exit('You must specify either a feature
# shape or authorized collaborator
# list in order for the script to determine the input layer shape')

collaborator_cname = list(plan.cols_data_paths)[0]

data_loader = plan.get_data_loader(collaborator_cname)
# This is needed to bypass data being locally available
if input_shape is not None:
logger.info('Attempting to generate initial model weights with'
f' custom input shape {input_shape}')
data_loader = MockDataLoader(input_shape)
else:
# If feature shape is not provided, data is assumed to be present
collaborator_cname = list(plan.cols_data_paths)[0]
data_loader = plan.get_data_loader(collaborator_cname)
task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()

Expand Down
17 changes: 15 additions & 2 deletions openfl/utilities/click_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (C) 2020-2023 Intel Corporation
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Click types module."""
"""Custom input types definition for Click"""

import click
import ast

from openfl.utilities import utils

Expand Down Expand Up @@ -31,5 +32,17 @@ def convert(self, value, param, ctx):
return value


class InputSpec(click.Option):
psfoley marked this conversation as resolved.
Show resolved Hide resolved
"""List or dictionary that corresponds to the input shape for a model"""
def type_cast_value(self, ctx, value):
    """Parse the raw option text into a Python literal.

    Args:
        ctx: Click context of the current invocation.
        value: Raw option value; normally the command-line string, or
            ``None`` when the option was not supplied.

    Returns:
        The object produced by ``ast.literal_eval`` for string input
        (for example a list or a dict). Non-string input, including
        ``None``, is returned unchanged.

    Raises:
        click.BadParameter: If the text is not a valid Python literal.
    """
    # Handles a missing option (None) as well as a value that is already
    # a Python object; ast.literal_eval only accepts strings or AST nodes
    # and would otherwise reject it.
    if not isinstance(value, str):
        return value
    try:
        return ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError,
            MemoryError, RecursionError) as err:
        # These are the failures ast.literal_eval documents for malformed
        # input; chain the cause so the parse error stays visible.
        raise click.BadParameter(
            f"{value!r} could not be parsed as a Python literal, "
            "e.g. [1,28,28] or \"{'input_0': [1, 240, 240, 4]}\"",
            ctx=ctx,
            param=self,
        ) from err


FQDN = FqdnParamType()
IP_ADDRESS = IpAddressParamType()
12 changes: 12 additions & 0 deletions openfl/utilities/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Mock objects to eliminate extraneous dependencies"""


class MockDataLoader:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you don't extend DataLoader, because only the get_feature_shape() method is required during plan initialization? In this case, would it make sense to call this class FeatureShapeLoader?

"""Placeholder dataloader for when data is not available"""
def __init__(self, feature_shape):
self.feature_shape = feature_shape

def get_feature_shape(self):
    """Return the feature shape stored on this instance."""
    shape = self.feature_shape
    return shape
Loading