diff --git a/delphi/train/training.py b/delphi/train/training.py index b56dea3b..575f9844 100644 --- a/delphi/train/training.py +++ b/delphi/train/training.py @@ -5,6 +5,7 @@ from pathlib import Path import torch +from huggingface_hub import HfApi from tqdm import tqdm from transformers import AutoTokenizer @@ -27,14 +28,15 @@ def setup_training(config: TrainingConfig): logging.info("Setting up training...") os.makedirs(config.out_dir, exist_ok=True) - # torch misc - TODO: check if this is actually needed torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - # determinism setup_determinism(config.torch_seed) - # wandb setup + if config.out_repo: + api = HfApi() + api.create_repo(config.out_repo, exist_ok=True) + if config.wandb: init_wandb(config) diff --git a/delphi/train/utils.py b/delphi/train/utils.py index b76d2058..240d2dcd 100644 --- a/delphi/train/utils.py +++ b/delphi/train/utils.py @@ -19,7 +19,6 @@ from .config import TrainingConfig from .run_context import RunContext -from .shuffle import shuffle_list @dataclass @@ -228,7 +227,6 @@ def save_results( if config.out_repo: try: api = HfApi() - api.create_repo(config.out_repo, exist_ok=True) api.create_branch(config.out_repo, branch=iter_name, exist_ok=True) api.upload_folder( folder_path=results_path,