Skip to content

Commit

Permalink
Deserialize JSON and create class objects when loading
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed May 22, 2024
1 parent 31147b2 commit 384e2b4
Showing 1 changed file with 76 additions and 4 deletions.
80 changes: 76 additions & 4 deletions generic_trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, Callable, Optional, Union
import json
import os
import re
import importlib

import torch
from torch.utils.data import Dataset
Expand All @@ -15,6 +17,10 @@

@dataclasses.dataclass
class OptionContainer:
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer and create an empty ``globals`` dictionary."""
    super().__init__(*args, **kwargs)
    # NOTE(review): subclasses decorated with @dataclasses.dataclass appear to get a generated
    # __init__ that does not call this one, so `globals` may be missing on them — confirm.
    self.globals = {}

def __str__(self):
s = ''
for key in self.__dict__.keys():
Expand All @@ -40,6 +46,10 @@ def get_serializable_dict(self):
d[key] = v
return d

def deserizalize_dict(self, d):
    """
    Populate this container's attributes from a plain dictionary.

    Every value is passed through `string_to_object` and the result is stored in the instance
    dictionary under the same key. Attributes whose names are not in `d` are left untouched.

    :param d: dict. Maps attribute names to their serialized values.
    """
    for name, serialized in d.items():
        self.__dict__[name] = self.string_to_object(name, serialized)

def dump_to_json(self, filename):
try:
f = open(filename, 'w')
Expand All @@ -49,14 +59,16 @@ def dump_to_json(self, filename):
except:
print('Failed to dump json.')

def load_from_json(self, filename, namespace=None):
    """
    Load settings from a JSON file into this container.

    This function only overwrites entries contained in the JSON file. Unspecified entries are unaffected.

    :param filename: str. Path of the JSON file to read.
    :param namespace: dict, optional. Names to copy into this module's global namespace before the
        file is deserialized, e.g. ``globals()`` of the calling script. Existing module-level names
        with the same key are overwritten.
    """
    if namespace is not None:
        # Make the caller's names visible at module level for the conversion step.
        globals().update(namespace)
    # The context manager closes the file even when parsing fails.
    with open(filename, 'r') as f:
        d = json.load(f)
    self.deserizalize_dict(d)

def string_to_object(self, key, value):
    """
    Convert a value read from JSON back into the Python object it represents.

    Lists and tuples are converted element by element and returned as a list. Strings are
    interpreted in this order: ``"<class 'pkg.mod.Name'>"`` is resolved to the class object by
    importing its module; ``'True'`` / ``'False'`` become booleans; numeric text becomes an ``int``
    or, failing that, a ``float``. Any other string, and any value that is not a string, list or
    tuple, is returned unchanged.

    :param key: str. Name of the entry being converted; passed along when converting list elements.
    :param value: str, list, tuple, or any value produced by the JSON parser.
    :return: object.
    """
    if isinstance(value, (list, tuple)):
        return [self.string_to_object(key, v) for v in value]
    if not isinstance(value, str):
        # Native JSON values are already typed. Pushing them through int() would turn True into 1
        # and truncate 1.5 to 1, so only text is ever cast.
        return value
    # Value is a class handle
    if res := re.match(r"<class '(.+)'>", value):
        module_path, _, class_name = res.group(1).rpartition('.')
        return getattr(importlib.import_module(module_path), class_name)
    if value in ('True', 'False'):
        return value == 'True'
    for caster in (int, float):
        try:
            return caster(value)
        except ValueError:
            pass
    return value

def object_to_string(self, key, value):
Expand All @@ -77,8 +104,12 @@ def object_to_string(self, key, value):
"""
if isinstance(value, OptionContainer):
value = value.get_serializable_dict()
elif isinstance(value, dict):
value = value
elif isinstance(value, (tuple, list)):
value = [self.object_to_string(key, x) for x in value]
elif value is None:
value = None
else:
value = str(value)
return value
Expand All @@ -90,7 +121,14 @@ def object_to_string(self, key, value):

@dataclasses.dataclass
class ModelParameters(OptionContainer):
    """Option container for model parameters; nested parameter dictionaries are loaded as `ModelParameters`."""

    def string_to_object(self, key, value):
        """
        Convert a JSON value, rebuilding nested model-parameter containers.

        The parent-class conversion runs first. If the result is a dictionary and `key` contains
        ``'model_params'``, the dictionary is loaded into a new `ModelParameters` object.

        :param key: str. Name of the entry being converted.
        :param value: value read from JSON.
        :return: object.
        """
        converted = super().string_to_object(key, value)
        if not ('model_params' in key and isinstance(converted, dict)):
            return converted
        nested = ModelParameters()
        nested.deserizalize_dict(converted)
        return nested


# =============================
Expand Down Expand Up @@ -161,6 +199,18 @@ class Config(OptionContainer):
Task type. Can be 'classification', 'regression'. Currently this only affects the logging of the loss tracker.
"""

def string_to_object(self, key, value):
    """
    Convert a JSON value, rebuilding nested configuration objects.

    The parent-class conversion runs first. A resulting dictionary is then loaded into a
    `ModelParameters` object when `key` contains ``'model_params'``, or into a
    `ParallelizationConfig` object when `key` equals ``'parallelization_params'``.

    :param key: str. Name of the entry being converted.
    :param value: value read from JSON.
    :return: object.
    """
    obj = super().string_to_object(key, value)
    if 'model_params' in key and isinstance(obj, dict):
        model_params = ModelParameters()
        model_params.deserizalize_dict(obj)
        obj = model_params
    if key == 'parallelization_params' and isinstance(obj, dict):
        parallelization_params = ParallelizationConfig()
        parallelization_params.deserizalize_dict(obj)
        obj = parallelization_params
    return obj


@dataclasses.dataclass
class InferenceConfig(Config):
Expand All @@ -185,6 +235,10 @@ class InferenceConfig(Config):
processed variables.
"""

def string_to_object(self, key, value):
    """
    Convert a JSON value and derive the pretrained model path from ``model_save_dir``.

    When `key` is ``'model_save_dir'`` and the raw value is a string, `pretrained_model_path` is
    set to ``<model_save_dir>/best_model.pth``. A non-string value (e.g. ``null`` in the file)
    leaves `pretrained_model_path` untouched. The value itself is always converted by the parent
    class and returned, so that the caller stores the converted entry rather than ``None``.

    :param key: str. Name of the entry being converted.
    :param value: value read from JSON.
    :return: object.
    """
    # Use the raw text for the path: the parent conversion may cast numeric-looking strings.
    if key == 'model_save_dir' and isinstance(value, str):
        self.pretrained_model_path = os.path.join(value, 'best_model.pth')
    return super().string_to_object(key, value)


@dataclasses.dataclass
class TrainingConfig(Config):
Expand Down Expand Up @@ -289,6 +343,24 @@ class TrainingConfig(Config):
save_onnx: bool = False
"""If True, ONNX models are saved along with state dicts."""

def string_to_object(self, key, value):
    """
    Convert a JSON value, rebuilding the loss function and loss tracker parameters.

    The parent-class conversion runs first. Then:

    - For ``'loss_function'``, a value that is still a string is evaluated as a Python expression
      in this module's namespace. If evaluation fails, a hint is printed and the string is kept.
      Values that are not strings (lists, ``None``, already-resolved classes) are left as they are.
    - For ``'loss_tracker_params'``, a dictionary is loaded into a `LossTrackerParameters` object.

    :param key: str. Name of the entry being converted.
    :param value: value read from JSON.
    :return: object.
    """
    value = super().string_to_object(key, value)
    if key == 'loss_function' and isinstance(value, str):
        # SECURITY: eval() executes arbitrary Python taken from the JSON file. Only load
        # configuration files from trusted sources.
        try:
            value = eval(value)
        except Exception as e:
            print(
                "When loading loss_function from JSON, the following error occurred:'\n{}\n"
                "To create a loss function object, its class name must be in the global namespace. You can "
                "import the proper classes in your driver script using from ... import ..., and pass "
                "globals() to load_from_json:\n"
                "    configs.load_from_json(filename, namespace=globals())\n".format(e)
            )
    elif key == 'loss_tracker_params' and isinstance(value, dict):
        # Only dictionaries can be deserialized; a null entry is kept as None.
        configs = LossTrackerParameters()
        configs.deserizalize_dict(value)
        value = configs
    return value

@dataclasses.dataclass
class PretrainingConfig(TrainingConfig):
Expand Down

0 comments on commit 384e2b4

Please sign in to comment.