Skip to content

Commit

Permalink
Add class names of nested config objects to JSON and create the objec…
Browse files Browse the repository at this point in the history
…ts when loaded
  • Loading branch information
mdw771 committed May 24, 2024
1 parent 54dcb8c commit dfee9cb
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions generic_trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,26 @@ def string_to_object(self, key, value):
:return: object.
"""
# Value is a class handle
if isinstance(value, (list, tuple)):
if isinstance(value, dict):
# Only convert the dict to an OptionContainer object if they are supposed to. Otherwise, leave it as a dict.
if 'model_params' in key or key in ['loss_tracker_params', 'parallelization_params']:
assert 'config_class' in value.keys(), ('The value of {} is supposed to be an object of a subclass '
'of OptionContainer, but I cannot find the '
'class name.'.format(key))
try:
config = globals()[value['config_class']]()
config.deserizalize_dict(value)
value = config
except KeyError as e:
raise ModuleNotFoundError(
"When loading {} from JSON, the following error occurred when attempting to create the "
"OptionContainer object for it:'\n{}\n"
"To create an OptionContainer object, its class name must be in the global namespace. You can "
"import the proper classes in your driver script using from ... import ..., and pass "
"globals() to load_from_json:\n"
" configs.load_from_json(filename, namespace=globals())\n".format(key, e)
)
elif isinstance(value, (list, tuple)):
value = [self.string_to_object(key, v) for v in value]
elif isinstance(value, str) and (res := re.match(r"<class '(.+)'>", value)):
class_import_path = res.groups()[0].split('.')
Expand All @@ -105,7 +124,9 @@ def object_to_string(self, key, value):
:return: str.
"""
if isinstance(value, OptionContainer):
config_class_name = value.__class__.__name__
value = value.get_serializable_dict()
value['config_class'] = config_class_name
elif isinstance(value, (dict, int, float, bool)):
value = value
elif isinstance(value, (tuple, list)):
Expand All @@ -123,14 +144,7 @@ def object_to_string(self, key, value):

@dataclasses.dataclass
class ModelParameters(OptionContainer):

def string_to_object(self, key, value):
value = super().string_to_object(key, value)
if 'model_params' in key and isinstance(value, dict):
config = ModelParameters()
config.deserizalize_dict(value)
value = config
return value
pass


# =============================
Expand Down Expand Up @@ -201,18 +215,6 @@ class Config(OptionContainer):
Task type. Can be 'classification', 'regression'. Currently this only affects the logging of the loss tracker.
"""

def string_to_object(self, key, value):
value = super().string_to_object(key, value)
if 'model_params' in key and isinstance(value, dict):
config = ModelParameters()
config.deserizalize_dict(value)
value = config
elif key == 'parallelization_params' and isinstance(value, dict):
config = ParallelizationConfig()
config.deserizalize_dict(value)
value = config
return value


@dataclasses.dataclass
class InferenceConfig(Config):
Expand Down Expand Up @@ -355,17 +357,13 @@ def string_to_object(self, key, value):
try:
value = eval(value)
except Exception as e:
print(
raise ModuleNotFoundError(
"When loading loss_function from JSON, the following error occurred:'\n{}\n"
"To create a loss function object, its class name must be in the global namespace. You can "
"import the proper classes in your driver script using from ... import ..., and pass "
"globals() to load_from_json:\n"
" configs.load_from_json(filename, namespace=globals())\n".format(e)
)
elif key == 'loss_tracker_params':
configs = LossTrackerParameters()
configs.deserizalize_dict(value)
value = configs
return value

@dataclasses.dataclass
Expand Down

0 comments on commit dfee9cb

Please sign in to comment.