Skip to content

Commit

Permalink
Working on #18 : more configurability for autoencoders
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-may committed Feb 3, 2022
1 parent de2cf28 commit 854f5dd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 6 additions & 2 deletions autodqm_ml/algorithms/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from autodqm_ml import utils

DEFAULT_OPT = {
"batch_size" : 16,
"batch_size" : 16,
"val_batch_size" : 1024,
"n_epochs" : 1000,
"early_stopping" : True,
Expand Down Expand Up @@ -51,10 +51,14 @@ def __init__(self, **kwargs):

self.mode = kwargs.get('autoencoder_mode', 'individual')
if not self.mode in ["individual", "simultaneous"]:
logger.exception("AutoEncoder : __init__] mode '%s' is not a recognized option for AutoEncoder. Currently available modes are 'individual' (default) and 'simultaneous'." % (self.mode))
logger.exception("[AutoEncoder : __init__] mode '%s' is not a recognized option for AutoEncoder. Currently available modes are 'individual' (default) and 'simultaneous'." % (self.mode))
raise ValueError()
self.models = {}

logger.debug("[AutoEncoder : __init__] Constructing AutoEncoder with the following training options and hyperparameters:")
for param, value in self.config.items():
logger.debug("\t %s : %s" % (param, str(value)))


def load_model(self, model_file):
"""
Expand Down
8 changes: 7 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@
if not os.path.exists(args.algorithm):
algorithm_config_file = expand_path(args.algorithm)
else:
algorithm_config_file = algo
algorithm_config_file = args.algorithm

with open(algorithm_config_file, "r") as f_in:
config = json.load(f_in)

# Add command line arguments to config
for k,v in vars(args).items():
if v is not None:
config[k] = v # note: if you specify an argument both through command line argument and json, we give precedence to the version from command line arguments

else:
config = vars(args)
config["name"] = args.algorithm.lower()
Expand Down

0 comments on commit 854f5dd

Please sign in to comment.