Skip to content

Commit

Permalink
fix a runaway weights folder
Browse files Browse the repository at this point in the history
  • Loading branch information
ivolazy committed Sep 16, 2023
1 parent 3cf3af2 commit c8f44fb
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 1 addition & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def get_weights_file_path(config, epoch: str):
# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
model_folder = f"{config['datasource']}_{config['model_folder']}"
model_basename = config["model_basename"]
model_filename = f"{config['model_basename']}*.pt"
model_filename = f"{config['model_basename']}*"
weights_files = list(Path(model_folder).glob(model_filename))
if len(weights_files) == 0:
return None
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def train_model(config):
print("Using device:", device)

# Make sure the weights folder exists
Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
Expand Down

0 comments on commit c8f44fb

Please sign in to comment.