Skip to content

Commit

Permalink
save/push tokenizer when training
Browse files Browse the repository at this point in the history
  • Loading branch information
jettjaniak committed May 21, 2024
1 parent 67df789 commit 3cb823f
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 20 deletions.
3 changes: 2 additions & 1 deletion configs/stories/llama2/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@
"torch_seed": 42,
"dataset": {
"name": "delphi-suite/stories-tokenized"
}
},
"tokenizer": "delphi-suite/stories-tokenizer"
}
3 changes: 2 additions & 1 deletion configs/stories/mamba/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@
"torch_seed": 42,
"dataset": {
"name": "delphi-suite/stories-tokenized"
}
},
"tokenizer": "delphi-suite/stories-tokenizer"
}
1 change: 0 additions & 1 deletion scripts/run_training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
Expand Down
11 changes: 10 additions & 1 deletion src/delphi/train/config/training_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,18 @@ class TrainingConfig:
metadata={"help": "specify training and validation data"},
)

tokenizer: str = field(
default="",
metadata={
"help": "HF repo id or local directory containing the tokenizer. Used only to upload it to HF with the model, not for training"
},
)

# third party
wandb: Optional[WandbConfig] = None
out_repo_id: str
out_repo_id: str = field(
metadata={"help": "set to empty string to not push to repo"},
)

# debug
debug_config: DebugConfig = field(default_factory=DebugConfig)
6 changes: 6 additions & 0 deletions src/delphi/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import time
from dataclasses import fields
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint
from .config import TrainingConfig
Expand Down Expand Up @@ -35,6 +37,10 @@ def setup_training(config: TrainingConfig):
if config.wandb:
init_wandb(config=config)

if config.tokenizer:
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
tokenizer.save_pretrained(Path(config.output_dir) / "tokenizer")


def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]:
setup_training(config)
Expand Down
32 changes: 16 additions & 16 deletions src/delphi/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import math
import os
import shutil
import time
from collections.abc import Iterator
from dataclasses import asdict, dataclass, field
Expand Down Expand Up @@ -212,34 +213,33 @@ def save_results(
config, context (e.g. hardware), training step, etc
"""
iter_name = "main" if final else f"iter{train_results.iter_num}"
results_path = os.path.join(config.output_dir, iter_name)
output_dir = Path(config.output_dir)
results_path = output_dir / iter_name
logging.info(f"saving checkpoint to {results_path}")
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, "training_config.json"), "w") as file:
results_path.mkdir(parents=True, exist_ok=True)
with open(results_path / "training_config.json", "w") as file:
json.dump(asdict(config), file, indent=2)
model = train_results.model
if isinstance(model, PreTrainedModel):
model.save_pretrained(
save_directory=results_path,
)
else:
st.save_model(
model,
os.path.join(results_path, "model.safetensors"),
)
train_results.model.save_pretrained(
save_directory=results_path,
)
if config.save_optimizer:
with open(os.path.join(results_path, "optimizer.pt"), "wb") as f:
with open(results_path / "optimizer.pt", "wb") as f:
torch.save(train_results.optimizer.state_dict(), f)
with open(os.path.join(results_path, "training_state.json"), "w") as file:
with open(results_path / "training_state.json", "w") as file:
training_state_dict = {
"iter_num": train_results.iter_num,
"lr": train_results.lr,
"epoch": train_results.epoch,
"step": train_results.step,
}
json.dump(training_state_dict, file, indent=2)
with open(os.path.join(results_path, "run_context.json"), "w") as file:
with open(results_path / "run_context.json", "w") as file:
json.dump(run_context.asdict(), file, indent=2)
if (tokenizer_dir := output_dir / "tokenizer").exists():
for src_file in tokenizer_dir.iterdir():
if src_file.is_file():
dest_file = results_path / src_file.name
shutil.copy2(src_file, dest_file)
if config.out_repo_id:
try:
api = HfApi()
Expand Down

0 comments on commit 3cb823f

Please sign in to comment.