From c8f44fb34e5b34891d49a978ea204e524a6f5df6 Mon Sep 17 00:00:00 2001 From: Ivan Lazarov Date: Sat, 16 Sep 2023 00:06:40 -0700 Subject: [PATCH] fix a runaway weights folder --- config.py | 3 +-- train.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/config.py b/config.py index e7df4e5..fac2178 100644 --- a/config.py +++ b/config.py @@ -25,8 +25,7 @@ def get_weights_file_path(config, epoch: str): # Find the latest weights file in the weights folder def latest_weights_file_path(config): model_folder = f"{config['datasource']}_{config['model_folder']}" - model_basename = config["model_basename"] - model_filename = f"{config['model_basename']}*.pt" + model_filename = f"{config['model_basename']}*" weights_files = list(Path(model_folder).glob(model_filename)) if len(weights_files) == 0: return None diff --git a/train.py b/train.py index 3d01143..f9988b1 100644 --- a/train.py +++ b/train.py @@ -183,7 +183,7 @@ def train_model(config): print("Using device:", device) # Make sure the weights folder exists - Path(config['model_folder']).mkdir(parents=True, exist_ok=True) + Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True) train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config) model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)