Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow feature_shape to be passed to fx plan initialize #983

Merged
merged 4 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from click import Path as ClickPath

from openfl.utilities.path_check import is_directory_traversal
from openfl.utilities.click_types import ListOption

logger = getLogger(__name__)

Expand All @@ -36,8 +37,8 @@ def plan(context):
default='plan/data.yaml', type=ClickPath(exists=True))
@option('-a', '--aggregator_address', required=False,
help='The FQDN of the federation agregator')
@option('-f', '--feature_shape', required=False,
help='The input shape to the model')
@option('-f', '--feature_shape', cls=ListOption, required=False,
psfoley marked this conversation as resolved.
Show resolved Hide resolved
help='The input shape to the model (i.e. [1,28,28])')
@option('-g', '--gandlf_config', required=False,
help='GaNDLF Configuration File Path')
def initialize(context, plan_config, cols_config, data_config,
Expand All @@ -54,6 +55,7 @@ def initialize(context, plan_config, cols_config, data_config,
from openfl.protocols import utils
from openfl.utilities.split import split_tensor_dict_for_holdouts
from openfl.utilities.utils import getfqdn_env
from openfl.utilities.mocks import MockDataLoader
psfoley marked this conversation as resolved.
Show resolved Hide resolved

for p in [plan_config, cols_config, data_config]:
if is_directory_traversal(p):
Expand All @@ -73,18 +75,15 @@ def initialize(context, plan_config, cols_config, data_config,

init_state_path = plan.config['aggregator']['settings']['init_state_path']

# TODO: Is this part really needed? Why would we need to collaborator
# name to know the input shape to the model?

# if feature_shape is None:
# if cols_config is None:
# exit('You must specify either a feature
# shape or authorized collaborator
# list in order for the script to determine the input layer shape')

collaborator_cname = list(plan.cols_data_paths)[0]

data_loader = plan.get_data_loader(collaborator_cname)
# This is needed to bypass data being locally available
if feature_shape is not None:
logger.info('Attempting to generate initial model weights with' \
f' custom feature shape {feature_shape}')
data_loader = MockDataLoader(feature_shape)
else:
# If feature shape is not provided, data is assumed to be present
collaborator_cname = list(plan.cols_data_paths)[0]
data_loader = plan.get_data_loader(collaborator_cname)
task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()

Expand Down
12 changes: 12 additions & 0 deletions openfl/utilities/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Click types module."""
psfoley marked this conversation as resolved.
Show resolved Hide resolved

import click
import ast

from openfl.utilities import utils

Expand Down Expand Up @@ -30,6 +31,17 @@ def convert(self, value, param, ctx):
self.fail(f'{value} is not a valid ip adress name', param, ctx)
return value

class ListOption(click.Option):

psfoley marked this conversation as resolved.
Show resolved Hide resolved
def type_cast_value(self, ctx, value):
try:
if value is None:
return None
else:
return ast.literal_eval(value)
except:
raise click.BadParameter(value)


FQDN = FqdnParamType()
IP_ADDRESS = IpAddressParamType()
Loading