From 30a1b30bb870acfce77a6427c67559e5d1ccbe03 Mon Sep 17 00:00:00 2001
From: JaiDhyani
Date: Thu, 7 Mar 2024 17:48:21 -0800
Subject: [PATCH] Adding support for config files

---
 scripts/run_training.py      | 22 +++++++++++++++++++---
 scripts/sample_config.json   | 34 ++++++++++++++++++++++++++++++++++
 src/delphi/train/training.py |  6 ++++++
 3 files changed, 59 insertions(+), 3 deletions(-)
 create mode 100644 scripts/sample_config.json

diff --git a/scripts/run_training.py b/scripts/run_training.py
index 819a67e4..53de4ed2 100644
--- a/scripts/run_training.py
+++ b/scripts/run_training.py
@@ -1,11 +1,19 @@
 import argparse
 import copy
+import json
 from dataclasses import fields
+from typing import Any
 
 from delphi.train.gigaconfig import GigaConfig, debug_config
 from delphi.train.training import run_training
 
 
+def update_config(config: GigaConfig, new_vals: dict[str, Any]):
+    for field in fields(config):
+        if new_vals.get(field.name) is not None:
+            setattr(config, field.name, new_vals[field.name])
+
+
 def main():
     # Setup argparse
     parser = argparse.ArgumentParser(description="Train a delphi model")
@@ -17,6 +25,12 @@ def main():
             required=False,
             help=f"Default: {field.default}",
         )
+    parser.add_argument(
+        "--config_file",
+        help="Path to a json file containing config values (see sample_config.json).",
+        required=False,
+        type=str,
+    )
     parser.add_argument(
         "--debug",
         help="Use debug config values (can still override with other arguments)",
@@ -29,9 +43,11 @@ def main():
         config = copy.copy(debug_config)
     else:
         config = GigaConfig()
-    for field in fields(GigaConfig):
-        if getattr(args, field.name) is not None:
-            setattr(config, field.name, getattr(args, field.name))
+    update_config(config, vars(args))
+    if args.config_file is not None:
+        with open(args.config_file, "r") as f:
+            config_dict = json.load(f)
+        update_config(config, config_dict)
 
     run_training(config)
 
diff --git a/scripts/sample_config.json b/scripts/sample_config.json
new file mode 100644
index 00000000..067c4da8
--- /dev/null
+++ b/scripts/sample_config.json
@@ -0,0 +1,34 @@
+{
+    "out_dir": "out",
+    "eval_interval": 500,
+    "log_interval": 1,
+    "eval_iters": 10,
+    "eval_only": false,
+    "always_save_checkpoint": false,
+    "init_from": "scratch",
+    "wandb_log": true,
+    "wandb_entity": "jaiwithani",
+    "wandb_project": "delphi",
+    "wandb_run_name": "2024_03_07_17_43_09",
+    "batch_size": 64,
+    "max_seq_len": 512,
+    "vocab_size": 4096,
+    "dim": 48,
+    "n_layers": 2,
+    "n_heads": 2,
+    "n_kv_heads": 2,
+    "multiple_of": 32,
+    "dropout": 0.0,
+    "gradient_accumulation_steps": 4,
+    "learning_rate": 0.0005,
+    "max_epochs": 2,
+    "weight_decay": 0.1,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "grad_clip": 1.0,
+    "decay_lr": true,
+    "warmup_iters": 1000,
+    "min_lr": 0.0,
+    "train_sample_limit": 256,
+    "val_sample_limit": -1
+}
\ No newline at end of file
diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index d0a236ad..2d6b6198 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -1,5 +1,6 @@
 import os
 import time
+from dataclasses import fields
 
 import torch
 from torch.utils.data import DataLoader
@@ -18,6 +19,11 @@ def run_training(config: GigaConfig):
+    print("Starting training...")
+    print()
+    print("Config:")
+    for field in fields(config):
+        print(f"  {field.name}: {getattr(config, field.name)}")
     # system
     device = get_device()