Skip to content

Commit

Permalink
HF upload: consistent iter name, catch Exception
Browse files: browse the repository at this point in the history
  • Loading branch information
jettjaniak committed May 20, 2024
1 parent 898df99 commit f395e6d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
7 changes: 4 additions & 3 deletions scripts/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ def main():
config = build_config_from_files_and_overrides(config_files, args_dict)
# run training
results, run_context = run_training(config)
final_out_dir = os.path.join(config.output_dir, "final")
save_results(config, results, run_context, final_out_dir, final=True)
print(f"Saved results to {final_out_dir}")
# to save & upload to iterX folder/branch
save_results(config, results, run_context, final=False)
# to save & upload to main folder/branch
save_results(config, results, run_context, final=True)


if __name__ == "__main__":
Expand Down
9 changes: 1 addition & 8 deletions src/delphi/train/checkpoint_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,7 @@ def log_and_save_checkpoint(
logging.info(
f"step {mts.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
)
results_path = os.path.join(config.output_dir, f"iter_{mts.iter_num:06d}")
logging.info(f"saving checkpoint to {results_path}")
save_results(
config=config,
train_results=mts,
run_context=run_context,
results_path=results_path,
)
save_results(config=config, train_results=mts, run_context=run_context)
if config.wandb:
log_to_wandb(
mts=mts,
Expand Down
22 changes: 10 additions & 12 deletions src/delphi/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def save_results(
config: TrainingConfig,
train_results: ModelTrainingState,
run_context: RunContext,
results_path: str,
final: bool = False,
):
"""
Expand All @@ -212,6 +211,9 @@ def save_results(
Saves everything required to replicate the current state of training, including optimizer state,
config, context (e.g. hardware), training step, etc
"""
iter_name = "main" if final else f"iter{train_results.iter_num}"
results_path = os.path.join(config.output_dir, iter_name)
logging.info(f"saving checkpoint to {results_path}")
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, "training_config.json"), "w") as file:
json.dump(asdict(config), file, indent=2)
Expand Down Expand Up @@ -239,21 +241,17 @@ def save_results(
with open(os.path.join(results_path, "run_context.json"), "w") as file:
json.dump(run_context.asdict(), file, indent=2)
if config.out_repo_id:
api = HfApi()
api.create_repo(config.out_repo_id, exist_ok=True)
branch_name = f"iter{train_results.iter_num}"
api.create_branch(config.out_repo_id, branch=branch_name)
api.upload_folder(
folder_path=results_path,
repo_id=config.out_repo_id,
revision=branch_name,
)
if final:
try:
api = HfApi()
api.create_repo(config.out_repo_id, exist_ok=True)
api.create_branch(config.out_repo_id, branch=iter_name, exist_ok=True)
api.upload_folder(
folder_path=results_path,
repo_id=config.out_repo_id,
revision="main",
revision=iter_name,
)
except Exception as e:
logging.error(f"Failed to upload to huggingface: {e}")


def count_tokens_so_far(config: TrainingConfig, mts: ModelTrainingState) -> int:
Expand Down

0 comments on commit f395e6d

Please sign in to comment.