feat: added more options to the trainer
percevalw committed Nov 29, 2024
1 parent 14d1728 commit c9e2a63
Showing 2 changed files with 31 additions and 8 deletions.
4 changes: 4 additions & 0 deletions changelog.md
@@ -5,6 +5,10 @@
 ### Added

 - `edsnlp.data.read_parquet` now accepts a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of by row. When this is enabled, each worker reads all rows of 1/n of the fragments, rather than reading every fragment and skipping all but one row in n, which should be faster.
+- Accept no validation data in the `edsnlp.train` script
+- Log the training config at the beginning of training
+- Support a specific model output directory path (`output_model_dir`), and an option to skip saving the model (`save_model`)
+- Specify whether to log the validation results or not (`logger=False`)

 ### Fixed
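Taken together, these changelog entries correspond to new keyword arguments of the `train` function patched below. A minimal sketch of a call using them, assuming a pipeline and training data are already built; `nlp` and `train_data` are placeholders and the values are illustrative, not part of this commit:

```python
import edsnlp
from edsnlp.training.trainer import train  # module patched below

nlp = edsnlp.blank("eds")  # placeholder pipeline; a real run would add trainable components
train_data = []            # placeholder; the signature expects AsList[TrainingData]

train(
    nlp=nlp,
    train_data=train_data,
    # val_data now defaults to [] -- validation data is optional
    max_steps=1000,
    logger=False,                  # silence the rich validation-metrics table
    save_model=False,              # skip dumping model weights to disk
    output_model_dir="artifacts/model-last",  # only honored when save_model=True
    # config_meta is used for config logging and, per the diff below, has no
    # default; it is presumably injected when train() is driven by a config
    # file, so a direct call like this would need to supply it as well.
)
```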
35 changes: 27 additions & 8 deletions edsnlp/training/trainer.py
@@ -300,7 +300,7 @@ def train(
     *,
     nlp: Pipeline,
     train_data: AsList[TrainingData],
-    val_data: AsList[Stream],
+    val_data: AsList[Stream] = [],
     seed: int = 42,
     max_steps: int = 1000,
     optimizer: Union[ScheduledOptimizer, torch.optim.Optimizer] = None,
@@ -313,6 +313,10 @@
     cpu: bool = False,
     mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no",
     output_dir: Union[Path, str] = Path("artifacts"),
+    output_model_dir: Optional[Union[Path, str]] = None,
+    save_model: bool = True,
+    logger: bool = True,
+    config_meta: Dict,
     **kwargs,
 ):
     """
@@ -385,6 +389,15 @@
         The output directory, which will contain a `model-last` directory
         with the last model, and a `train_metrics.json` file with the
         training metrics and stats.
+    output_model_dir: Optional[Union[Path, str]]
+        The directory in which to save the model. If None, defaults to
+        `output_dir / "model-last"`.
+    save_model: bool
+        Whether to save the model or not. This can be useful if you are only
+        interested in the metrics, not the model, and want to avoid spending
+        time dumping the model weights to disk.
+    logger: bool
+        Whether to log the validation metrics in a rich table.
     kwargs: Dict
         Additional keyword arguments.
@@ -398,13 +411,14 @@
     # accelerator.register_for_checkpointing(dataset)
     is_main_process = accelerator.is_main_process
     device = accelerator.device
-    print("Starting training on device:", device)
+    accelerator.print(config_meta["unresolved_config"].to_yaml_str())

     output_dir = Path(output_dir or Path.cwd() / "artifacts")
-    model_path = output_dir / "model-last"
+    output_model_dir = output_model_dir or output_dir / "model-last"
     train_metrics_path = output_dir / "train_metrics.json"
     if is_main_process:
         os.makedirs(output_dir, exist_ok=True)
+        config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")

     validation_interval = validation_interval or max_steps // 10
     checkpoint_interval = checkpoint_interval or validation_interval
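This hunk makes the model directory overridable and dumps the unresolved config at startup. A small self-contained sketch of the path-defaulting rule; the helper name `resolve_model_dir` is hypothetical, not part of the commit:

```python
from pathlib import Path

def resolve_model_dir(output_dir=None, output_model_dir=None) -> Path:
    # Mirrors the logic above: output_dir falls back to ./artifacts,
    # and output_model_dir falls back to <output_dir>/model-last.
    output_dir = Path(output_dir or Path.cwd() / "artifacts")
    return Path(output_model_dir) if output_model_dir else output_dir / "model-last"

print(resolve_model_dir())                        # .../artifacts/model-last
print(resolve_model_dir("out"))                   # out/model-last
print(resolve_model_dir("out", "out/best-model")) # out/best-model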
@@ -501,7 +515,7 @@
     set_seed(seed)
     with (
         RichTablePrinter(LOGGER_FIELDS, auto_refresh=False)
-        if is_main_process
+        if is_main_process and logger
         else nullcontext()
     ) as logger:
         # Training loop
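The `with (A if cond else nullcontext()) as logger:` pattern above swaps in a no-op context manager when logging is disabled; `nullcontext()` yields `None` on entry, which is why the body below re-checks `if logger:` before calling `log_metrics`. A self-contained sketch of the pattern, with a stand-in printer class rather than the real `RichTablePrinter`:

```python
from contextlib import nullcontext

class FakePrinter:
    """Stand-in for a rich table printer (illustration only)."""
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False
    def log_metrics(self, metrics):
        print(metrics)

def run(enable_logging: bool):
    with (FakePrinter() if enable_logging else nullcontext()) as logger:
        # nullcontext() enters as None, so `logger` is falsy when disabled
        if logger:
            logger.log_metrics({"step": 1, "loss": 0.5})

run(True)   # prints the metrics
run(False)  # silent
```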
@@ -526,10 +540,15 @@
                    )
                    cumulated_data.clear()
                    train_metrics_path.write_text(json.dumps(all_metrics, indent=2))
-                    logger.log_metrics(flatten_dict(all_metrics[-1]))
+                    if logger:
+                        logger.log_metrics(flatten_dict(all_metrics[-1]))

-                if is_main_process and (step % checkpoint_interval) == 0:
-                    nlp.to_disk(model_path)
+                if (
+                    save_model
+                    and is_main_process
+                    and (step % checkpoint_interval) == 0
+                ):
+                    nlp.to_disk(output_model_dir)

                 if step == max_steps:
                     break
@@ -572,7 +591,7 @@ def train(
                res[f"{name}_loss"] = res["loss"]
                for k, v in res.items():
                    if (
-                        isinstance(v, float)
+                        isinstance(v, (float, int))
                        or isinstance(v, torch.Tensor)
                        and v.ndim == 0
                    ):
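The widened check `isinstance(v, (float, int))` now keeps integer metrics as well. Note that the expression relies on `and` binding tighter than `or`: it reads as `number OR (tensor AND ndim == 0)`, so only scalar tensors pass. A quick self-contained check of that behavior, written to mirror the shape of the diff:

```python
import torch

def is_scalar_metric(v):
    # Same precedence as above: `and` groups before `or`, so only plain
    # numbers or zero-dimensional tensors count as scalar metrics.
    return isinstance(v, (float, int)) or (
        isinstance(v, torch.Tensor) and v.ndim == 0
    )

assert is_scalar_metric(3)                  # int now accepted
assert is_scalar_metric(0.5)                # float, as before
assert is_scalar_metric(torch.tensor(1.0))  # 0-dim tensor passes
assert not is_scalar_metric(torch.ones(3))  # non-scalar tensor filtered out
```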