diff --git a/scripts/run_training.py b/scripts/run_training.py
index 92fd68e9..98e126ed 100755
--- a/scripts/run_training.py
+++ b/scripts/run_training.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 import argparse
 import logging
-import os
 import sys
 from pathlib import Path
 from typing import Any
diff --git a/src/delphi/train/config/training_config.py b/src/delphi/train/config/training_config.py
index 2de91cce..c68896d4 100644
--- a/src/delphi/train/config/training_config.py
+++ b/src/delphi/train/config/training_config.py
@@ -89,9 +89,18 @@ class TrainingConfig:
         metadata={"help": "specify training and validation data"},
     )
 
+    tokenizer: str = field(
+        default="",
+        metadata={
+            "help": "HF repo id or local directory containing the tokenizer. Used only to upload it to HF with the model, not for training"
+        },
+    )
+
     # third party
     wandb: Optional[WandbConfig] = None
-    out_repo_id: str
+    out_repo_id: str = field(
+        metadata={"help": "set to empty string to not push to repo"},
+    )
 
     # debug
     debug_config: DebugConfig = field(default_factory=DebugConfig)
diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index 85fda188..7db8dfef 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -2,9 +2,11 @@
 import os
 import time
 from dataclasses import fields
+from pathlib import Path
 
 import torch
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint
 from .config import TrainingConfig
@@ -35,6 +37,10 @@ def setup_training(config: TrainingConfig):
     if config.wandb:
         init_wandb(config=config)
 
+    if config.tokenizer:
+        tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
+        tokenizer.save_pretrained(Path(config.output_dir) / "tokenizer")
+
 
 def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]:
     setup_training(config)
diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py
index 4966aa8c..e0d5777d 100644
--- a/src/delphi/train/utils.py
+++ b/src/delphi/train/utils.py
@@ -2,6 +2,7 @@
 import logging
 import math
 import os
+import shutil
 import time
 from collections.abc import Iterator
 from dataclasses import asdict, dataclass, field
@@ -212,25 +213,19 @@ def save_results(
     config, context (e.g. hardware), training step, etc
     """
     iter_name = "main" if final else f"iter{train_results.iter_num}"
-    results_path = os.path.join(config.output_dir, iter_name)
+    output_dir = Path(config.output_dir)
+    results_path = output_dir / iter_name
     logging.info(f"saving checkpoint to {results_path}")
-    os.makedirs(results_path, exist_ok=True)
-    with open(os.path.join(results_path, "training_config.json"), "w") as file:
+    results_path.mkdir(parents=True, exist_ok=True)
+    with open(results_path / "training_config.json", "w") as file:
         json.dump(asdict(config), file, indent=2)
-    model = train_results.model
-    if isinstance(model, PreTrainedModel):
-        model.save_pretrained(
-            save_directory=results_path,
-        )
-    else:
-        st.save_model(
-            model,
-            os.path.join(results_path, "model.safetensors"),
-        )
+    train_results.model.save_pretrained(
+        save_directory=results_path,
+    )
     if config.save_optimizer:
-        with open(os.path.join(results_path, "optimizer.pt"), "wb") as f:
+        with open(results_path / "optimizer.pt", "wb") as f:
             torch.save(train_results.optimizer.state_dict(), f)
-    with open(os.path.join(results_path, "training_state.json"), "w") as file:
+    with open(results_path / "training_state.json", "w") as file:
         training_state_dict = {
             "iter_num": train_results.iter_num,
             "lr": train_results.lr,
@@ -238,8 +233,13 @@ def save_results(
             "step": train_results.step,
         }
         json.dump(training_state_dict, file, indent=2)
-    with open(os.path.join(results_path, "run_context.json"), "w") as file:
+    with open(results_path / "run_context.json", "w") as file:
         json.dump(run_context.asdict(), file, indent=2)
+    if (tokenizer_dir := output_dir / "tokenizer").exists():
+        for src_file in tokenizer_dir.iterdir():
+            if src_file.is_file():
+                dest_file = results_path / src_file.name
+                shutil.copy2(src_file, dest_file)
     if config.out_repo_id:
         try:
             api = HfApi()
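A minimal runnable sketch of the tokenizer flow these hunks add, with the two pieces inlined outside the training code. The repo id "gpt2" and the out/demo-run paths are placeholders, not part of this change: setup_training caches the tokenizer once under output_dir/tokenizer, and save_results later copies those files into each checkpoint directory.

    import shutil
    from pathlib import Path

    from transformers import AutoTokenizer

    output_dir = Path("out/demo-run")      # placeholder for config.output_dir
    results_path = output_dir / "iter100"  # one checkpoint dir, as built in save_results
    results_path.mkdir(parents=True, exist_ok=True)

    # setup_training: fetch the tokenizer once and cache it under output_dir/tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder for config.tokenizer
    tokenizer.save_pretrained(output_dir / "tokenizer")

    # save_results: copy the cached tokenizer files into the checkpoint directory
    if (tokenizer_dir := output_dir / "tokenizer").exists():
        for src_file in tokenizer_dir.iterdir():
            if src_file.is_file():
                shutil.copy2(src_file, results_path / src_file.name)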
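Because save_pretrained writes the model files into the same directory that now also receives the tokenizer files, each checkpoint should be self-contained. A hedged sketch of reloading one, assuming the trained model is a causal LM that transformers' Auto classes can resolve ("main" is the final-checkpoint name used by save_results; the path is a placeholder):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint_dir = "out/demo-run/main"  # placeholder output_dir + final checkpoint name
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)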