diff --git a/train/models/constants.py b/train/models/constants.py index 680432e..196aeb1 100644 --- a/train/models/constants.py +++ b/train/models/constants.py @@ -7,7 +7,7 @@ DEF_learning_rate = 1e-4 DEF_init_std = 1e-2 DEF_epochs = 100 -DEF_validate_step = 10 +DEF_validate_step = 100 DEF_chunk_size = 512 DEF_basis_size = 100 diff --git a/train/train.py b/train/train.py index dff4f50..2fabee1 100644 --- a/train/train.py +++ b/train/train.py @@ -171,14 +171,14 @@ def _process_lambdas(lambda_param, lambda_params, possible_models=possible_model model_formats = _match_model(opts.domains[0], possible_models) model_type = models.hsm_id.HSMIndependentDomainsModel input_directories = os.path.join(opts.input_directory, model_formats.directory) - elif len(opts.domains) == 1: - model_formats = _match_model(opts.domains[0], possible_models) - model_type = models.hsm_d_singledomain.HSMSingleDomainsModel - input_directories = os.path.join(opts.input_directory, model_formats.directory) elif opts.include_all_domains: model_formats = model_formats_ifile model_type = models.hsm_d.HSMDomainsModel input_directories = opts.input_directory + elif len(opts.domains) == 1: + model_formats = _match_model(opts.domains[0], possible_models) + model_type = models.hsm_d_singledomain.HSMSingleDomainsModel + input_directories = os.path.join(opts.input_directory, model_formats.directory) else: model_formats = [_match_model(d, possible_models) for d in opts.domains] model_type = models.hsm_d.HSMDomainsModel