Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save/push tokenizer when training #148

Merged
merged 1 commit on
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configs/stories/llama2/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@
"torch_seed": 42,
"dataset": {
"name": "delphi-suite/stories-tokenized"
}
},
"tokenizer": "delphi-suite/stories-tokenizer"
}
3 changes: 2 additions & 1 deletion configs/stories/mamba/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@
"torch_seed": 42,
"dataset": {
"name": "delphi-suite/stories-tokenized"
}
},
"tokenizer": "delphi-suite/stories-tokenizer"
}
1 change: 0 additions & 1 deletion scripts/run_training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
Expand Down
11 changes: 10 additions & 1 deletion src/delphi/train/config/training_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,18 @@ class TrainingConfig:
metadata={"help": "specify training and validation data"},
)

tokenizer: str = field(
default="",
metadata={
"help": "HF repo id or local directory containing the tokenizer. Used only to upload it to HF with the model, not for training"
},
)

# third party
wandb: Optional[WandbConfig] = None
out_repo_id: str
out_repo_id: str = field(
metadata={"help": "set to empty string to not push to repo"},
)

# debug
debug_config: DebugConfig = field(default_factory=DebugConfig)
6 changes: 6 additions & 0 deletions src/delphi/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import time
from dataclasses import fields
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint
from .config import TrainingConfig
Expand Down Expand Up @@ -35,6 +37,10 @@ def setup_training(config: TrainingConfig):
if config.wandb:
init_wandb(config=config)

if config.tokenizer:
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
tokenizer.save_pretrained(Path(config.output_dir) / "tokenizer")


def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]:
setup_training(config)
Expand Down
32 changes: 16 additions & 16 deletions src/delphi/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import math
import os
import shutil
import time
from collections.abc import Iterator
from dataclasses import asdict, dataclass, field
Expand Down Expand Up @@ -212,34 +213,33 @@ def save_results(
config, context (e.g. hardware), training step, etc
"""
iter_name = "main" if final else f"iter{train_results.iter_num}"
results_path = os.path.join(config.output_dir, iter_name)
output_dir = Path(config.output_dir)
results_path = output_dir / iter_name
logging.info(f"saving checkpoint to {results_path}")
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, "training_config.json"), "w") as file:
results_path.mkdir(parents=True, exist_ok=True)
with open(results_path / "training_config.json", "w") as file:
json.dump(asdict(config), file, indent=2)
model = train_results.model
if isinstance(model, PreTrainedModel):
model.save_pretrained(
save_directory=results_path,
)
else:
st.save_model(
model,
os.path.join(results_path, "model.safetensors"),
)
train_results.model.save_pretrained(
save_directory=results_path,
)
if config.save_optimizer:
with open(os.path.join(results_path, "optimizer.pt"), "wb") as f:
with open(results_path / "optimizer.pt", "wb") as f:
torch.save(train_results.optimizer.state_dict(), f)
with open(os.path.join(results_path, "training_state.json"), "w") as file:
with open(results_path / "training_state.json", "w") as file:
training_state_dict = {
"iter_num": train_results.iter_num,
"lr": train_results.lr,
"epoch": train_results.epoch,
"step": train_results.step,
}
json.dump(training_state_dict, file, indent=2)
with open(os.path.join(results_path, "run_context.json"), "w") as file:
with open(results_path / "run_context.json", "w") as file:
json.dump(run_context.asdict(), file, indent=2)
if (tokenizer_dir := output_dir / "tokenizer").exists():
for src_file in tokenizer_dir.iterdir():
if src_file.is_file():
dest_file = results_path / src_file.name
shutil.copy2(src_file, dest_file)
if config.out_repo_id:
try:
api = HfApi()
Expand Down
Loading