From f395e6dd7d2ee3e54bfffe6bb37438d2d52cfca8 Mon Sep 17 00:00:00 2001
From: Jett
Date: Mon, 20 May 2024 16:59:35 +0200
Subject: [PATCH] HF upload: consistent iter name, catch Exception

Use one name for both the local checkpoint folder and the HuggingFace
branch ("main" for the final save, "iterN" for intermediate ones), build
it inside save_results instead of at every call site, and wrap the upload
in try/except so a failed upload cannot kill a training run.
---
 scripts/run_training.py             |  7 ++++---
 src/delphi/train/checkpoint_step.py |  9 +--------
 src/delphi/train/utils.py           | 22 ++++++++++------------
 3 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/scripts/run_training.py b/scripts/run_training.py
index ff48b9eb..92fd68e9 100755
--- a/scripts/run_training.py
+++ b/scripts/run_training.py
@@ -99,9 +99,10 @@ def main():
     config = build_config_from_files_and_overrides(config_files, args_dict)
     # run training
     results, run_context = run_training(config)
-    final_out_dir = os.path.join(config.output_dir, "final")
-    save_results(config, results, run_context, final_out_dir, final=True)
-    print(f"Saved results to {final_out_dir}")
+    # save & upload the last checkpoint to its iterN folder/branch
+    save_results(config, results, run_context, final=False)
+    # save & upload the final results to the main folder/branch
+    save_results(config, results, run_context, final=True)


 if __name__ == "__main__":
diff --git a/src/delphi/train/checkpoint_step.py b/src/delphi/train/checkpoint_step.py
index a86a8680..afd84702 100644
--- a/src/delphi/train/checkpoint_step.py
+++ b/src/delphi/train/checkpoint_step.py
@@ -44,14 +44,7 @@ def log_and_save_checkpoint(
     logging.info(
         f"step {mts.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
     )
-    results_path = os.path.join(config.output_dir, f"iter_{mts.iter_num:06d}")
-    logging.info(f"saving checkpoint to {results_path}")
-    save_results(
-        config=config,
-        train_results=mts,
-        run_context=run_context,
-        results_path=results_path,
-    )
+    save_results(config=config, train_results=mts, run_context=run_context)
     if config.wandb:
         log_to_wandb(
             mts=mts,
diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py
index 2201e759..4966aa8c 100644
--- a/src/delphi/train/utils.py
+++ b/src/delphi/train/utils.py
@@ -203,7 +203,6 @@ def save_results(
     config: TrainingConfig,
     train_results: ModelTrainingState,
     run_context: RunContext,
-    results_path: str,
     final: bool = False,
 ):
     """
@@ -212,6 +211,9 @@ def save_results(
     Saves everything required to replicate the current state of training, including optimizer state,
     config, context (e.g. hardware), training step, etc
     """
+    iter_name = "main" if final else f"iter{train_results.iter_num}"
+    results_path = os.path.join(config.output_dir, iter_name)
+    logging.info(f"saving checkpoint to {results_path}")
     os.makedirs(results_path, exist_ok=True)
     with open(os.path.join(results_path, "training_config.json"), "w") as file:
         json.dump(asdict(config), file, indent=2)
@@ -239,21 +241,17 @@ def save_results(
     with open(os.path.join(results_path, "run_context.json"), "w") as file:
         json.dump(run_context.asdict(), file, indent=2)
     if config.out_repo_id:
-        api = HfApi()
-        api.create_repo(config.out_repo_id, exist_ok=True)
-        branch_name = f"iter{train_results.iter_num}"
-        api.create_branch(config.out_repo_id, branch=branch_name)
-        api.upload_folder(
-            folder_path=results_path,
-            repo_id=config.out_repo_id,
-            revision=branch_name,
-        )
-        if final:
+        try:
+            api = HfApi()
+            api.create_repo(config.out_repo_id, exist_ok=True)
+            api.create_branch(config.out_repo_id, branch=iter_name, exist_ok=True)
             api.upload_folder(
                 folder_path=results_path,
                 repo_id=config.out_repo_id,
-                revision="main",
+                revision=iter_name,
             )
+        except Exception as e:
+            logging.error(f"Failed to upload to huggingface: {e}")


 def count_tokens_so_far(config: TrainingConfig, mts: ModelTrainingState) -> int:
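
Not part of the patch, for reviewers: with the folder and branch now
sharing one name ("iter100", ..., "main"), any intermediate checkpoint can
be pulled straight from the Hub by revision. A minimal sketch, assuming
huggingface_hub is installed; the repo id and iteration number below are
made-up placeholders, while snapshot_download itself is a real
huggingface_hub function:

    from huggingface_hub import snapshot_download

    # Every intermediate save lands on its own "iterN" branch and the final
    # save on "main", so a single training step can be fetched by passing
    # the branch name as the revision.
    local_dir = snapshot_download(
        repo_id="delphi-suite/example-model",  # placeholder repo id
        revision="iter2000",                   # placeholder iteration branch
    )
    print(f"checkpoint files downloaded to {local_dir}")

snapshot_download returns the local path holding everything upload_folder
pushed for that revision: training_config.json, run_context.json, and the
rest of the saved training state.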