Skip to content

Commit

Permalink
Adding support for config files
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Mar 8, 2024
1 parent 83496e2 commit 30a1b30
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
22 changes: 19 additions & 3 deletions scripts/run_training.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import argparse
import copy
import json
from dataclasses import fields
from typing import Any

from delphi.train.gigaconfig import GigaConfig, debug_config
from delphi.train.training import run_training


def update_config(config: GigaConfig, new_vals: dict[str, Any]):
    """Overwrite fields on *config* with matching non-None entries from *new_vals*.

    Only names that are actual dataclass fields of the config are considered;
    extra keys in *new_vals* are silently ignored. A value of None means
    "not provided" (argparse leaves unset options as None), so None never
    overwrites an existing config value — but falsy values like 0 or False do.
    """
    for cfg_field in fields(config):
        value = new_vals.get(cfg_field.name)
        if value is not None:
            setattr(config, cfg_field.name, value)


def main():
# Setup argparse
parser = argparse.ArgumentParser(description="Train a delphi model")
Expand All @@ -17,6 +25,12 @@ def main():
required=False,
help=f"Default: {field.default}",
)
parser.add_argument(
"--config_file",
help="Path to a json file containing config values (see sample_config.json).",
required=False,
type=str,
)
parser.add_argument(
"--debug",
help="Use debug config values (can still override with other arguments)",
Expand All @@ -29,9 +43,11 @@ def main():
config = copy.copy(debug_config)
else:
config = GigaConfig()
for field in fields(GigaConfig):
if getattr(args, field.name) is not None:
setattr(config, field.name, getattr(args, field.name))
update_config(config, vars(args))
if args.config_file is not None:
with open(args.config_file, "r") as f:
config_dict = json.load(f)
update_config(config, config_dict)
run_training(config)


Expand Down
34 changes: 34 additions & 0 deletions scripts/sample_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"out_dir": "out",
"eval_interval": 500,
"log_interval": 1,
"eval_iters": 10,
"eval_only": false,
"always_save_checkpoint": false,
"init_from": "scratch",
"wandb_log": true,
"wandb_entity": "jaiwithani",
"wandb_project": "delphi",
"wandb_run_name": "2024_03_07_17_43_09",
"batch_size": 64,
"max_seq_len": 512,
"vocab_size": 4096,
"dim": 48,
"n_layers": 2,
"n_heads": 2,
"n_kv_heads": 2,
"multiple_of": 32,
"dropout": 0.0,
"gradient_accumulation_steps": 4,
"learning_rate": 0.0005,
"max_epochs": 2,
"weight_decay": 0.1,
"beta1": 0.9,
"beta2": 0.95,
"grad_clip": 1.0,
"decay_lr": true,
"warmup_iters": 1000,
"min_lr": 0.0,
"train_sample_limit": 256,
"val_sample_limit": -1
}
6 changes: 6 additions & 0 deletions src/delphi/train/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
from dataclasses import fields

import torch
from torch.utils.data import DataLoader
Expand All @@ -18,6 +19,11 @@


def run_training(config: GigaConfig):
print("Starting training...")
print()
print("Config:")
for field in fields(config):
print(f" {field.name}: {getattr(config, field.name)}")
# system
device = get_device()

Expand Down

0 comments on commit 30a1b30

Please sign in to comment.