From 262972bd6570faa605802d95a2b4f4099fe6989f Mon Sep 17 00:00:00 2001 From: Goncalo Paulo <30472805+SrGonao@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:37:02 +0200 Subject: [PATCH] llama2 & mamba training configs (#113) * Llama 2 example scripts * Mamba example scripts * Added all training config sizes Added an example Llama2 base, Mamba base and all sizes in their respective folders. * Moving and reformating configs * Llama 2 example scripts * Mamba example scripts * Cleaning extra files * remove src/llama2c * moved/renamed configs * bos/eos token ids * updated configs following the meeting * static stuff * updated base configs and simplified config structure * grad_acc_steps & batch_size fix * add @beartype to config classes that don't have it yet * don't ignore incorrect config keys * gradient_accumulation_steps fix * fix broken test config * Updating test configs to work with recent changes * config testing and fixes * re-adding accidentally deleted test config * fix tests that broke after recent changes * remove minibatch divisibility requirement * estimate_loss returns float, not tensor * log train/validation dataset size when training starts * fix incorrect train split default * Updating log spaced checkpoints and checkpointing intervals * update llama2 configs * stories cfgs: checkpoints, evals, bos & eos --------- Co-authored-by: Jett Co-authored-by: JaiDhyani Co-authored-by: Jannik Brinkmann --- .../static/configs => configs}/debug.json | 10 +-- .../sample_config.json | 19 +++--- .../sample_mamba.json | 10 +-- .../sample_transformers_bloom.json | 17 +++-- configs/stories/llama2/100k.json | 9 +++ configs/stories/llama2/10m.json | 9 +++ configs/stories/llama2/1m.json | 9 +++ configs/stories/llama2/2.5m.json | 9 +++ configs/stories/llama2/250k.json | 9 +++ configs/stories/llama2/25m.json | 9 +++ configs/stories/llama2/500k.json | 9 +++ configs/stories/llama2/50k.json | 9 +++ configs/stories/llama2/50m.json | 9 +++ configs/stories/llama2/5m.json | 9 +++ configs/stories/llama2/README.md | 7 ++ configs/stories/llama2/base.json | 52 +++++++++++++++ configs/stories/mamba/100k.json | 6 ++ configs/stories/mamba/10m.json | 6 ++ configs/stories/mamba/1m.json | 6 ++ configs/stories/mamba/2.5m.json | 6 ++ configs/stories/mamba/250k.json | 6 ++ configs/stories/mamba/25m.json | 6 ++ configs/stories/mamba/500k.json | 6 ++ configs/stories/mamba/50k.json | 6 ++ configs/stories/mamba/50m.json | 6 ++ configs/stories/mamba/5m.json | 6 ++ configs/stories/mamba/README.md | 10 +++ configs/stories/mamba/base.json | 53 +++++++++++++++ scripts/map_tokens.py | 1 - scripts/spacy_label_all_tokens.py | 1 - .../sample_mamba.json | 58 ----------------- scripts/training_run.sh | 6 ++ scripts/validate_configs.py | 57 ++++++++++++++++ setup.py | 2 +- src/delphi/constants.py | 3 +- src/delphi/static/configs/v0-llama2-1.6m.json | 24 ------- .../static/configs/v0-llama2-100k-quick.json | 24 ------- .../static/configs/v0-llama2-12.8m.json | 24 ------- src/delphi/static/configs/v0-llama2-200k.json | 24 ------- .../static/configs/v0-llama2-25.6m.json | 24 ------- src/delphi/static/configs/v0-llama2-3.2m.json | 24 ------- src/delphi/static/configs/v0-llama2-400k.json | 24 ------- src/delphi/static/configs/v0-llama2-6.4m.json | 24 ------- src/delphi/static/configs/v0-llama2-800k.json | 24 ------- .../{static => test_configs}/__init__.py | 0 src/delphi/test_configs/debug.json | 20 ++++++ .../debug_transformers_bloom.json | 10 +-- .../v0-llama2-100k.json | 6 ++ src/delphi/train/checkpoint_step.py | 10 +-- 
src/delphi/train/config/__init__.py | 2 - src/delphi/train/config/adam_config.py | 3 + src/delphi/train/config/data_config.py | 58 ----------------- src/delphi/train/config/dataset_config.py | 46 +++++++++++++ src/delphi/train/config/huggingface_config.py | 10 ++- src/delphi/train/config/training_config.py | 27 ++++---- src/delphi/train/config/utils.py | 17 +---- src/delphi/train/config/wandb_config.py | 9 +-- src/delphi/train/train_step.py | 2 +- src/delphi/train/training.py | 23 ++----- src/delphi/train/utils.py | 61 +++--------------- src/delphi/train/wandb_utils.py | 7 +- {src/delphi/static => static}/README.md | 0 .../static => static}/all_tokens_list.txt | Bin .../labelled_token_ids_dict.pkl | Bin .../static => static}/model_group_stats.pkl | Bin {src/delphi/static => static}/token_map.pkl | Bin tests/train/config/test_config_utils.py | 12 +++- tests/train/test_train_step.py | 17 +++-- tests/train/test_wandb_utils.py | 35 ++++------ 69 files changed, 531 insertions(+), 516 deletions(-) rename {src/delphi/static/configs => configs}/debug.json (72%) rename {scripts/training_config_examples => configs}/sample_config.json (83%) rename src/delphi/static/configs/debug_mamba.json => configs/sample_mamba.json (72%) rename {scripts/training_config_examples => configs}/sample_transformers_bloom.json (65%) create mode 100644 configs/stories/llama2/100k.json create mode 100644 configs/stories/llama2/10m.json create mode 100644 configs/stories/llama2/1m.json create mode 100644 configs/stories/llama2/2.5m.json create mode 100644 configs/stories/llama2/250k.json create mode 100644 configs/stories/llama2/25m.json create mode 100644 configs/stories/llama2/500k.json create mode 100644 configs/stories/llama2/50k.json create mode 100644 configs/stories/llama2/50m.json create mode 100644 configs/stories/llama2/5m.json create mode 100644 configs/stories/llama2/README.md create mode 100644 configs/stories/llama2/base.json create mode 100644 configs/stories/mamba/100k.json create mode 100644 configs/stories/mamba/10m.json create mode 100644 configs/stories/mamba/1m.json create mode 100644 configs/stories/mamba/2.5m.json create mode 100644 configs/stories/mamba/250k.json create mode 100644 configs/stories/mamba/25m.json create mode 100644 configs/stories/mamba/500k.json create mode 100644 configs/stories/mamba/50k.json create mode 100644 configs/stories/mamba/50m.json create mode 100644 configs/stories/mamba/5m.json create mode 100644 configs/stories/mamba/README.md create mode 100644 configs/stories/mamba/base.json delete mode 100644 scripts/training_config_examples/sample_mamba.json create mode 100644 scripts/training_run.sh create mode 100755 scripts/validate_configs.py delete mode 100644 src/delphi/static/configs/v0-llama2-1.6m.json delete mode 100644 src/delphi/static/configs/v0-llama2-100k-quick.json delete mode 100644 src/delphi/static/configs/v0-llama2-12.8m.json delete mode 100644 src/delphi/static/configs/v0-llama2-200k.json delete mode 100644 src/delphi/static/configs/v0-llama2-25.6m.json delete mode 100644 src/delphi/static/configs/v0-llama2-3.2m.json delete mode 100644 src/delphi/static/configs/v0-llama2-400k.json delete mode 100644 src/delphi/static/configs/v0-llama2-6.4m.json delete mode 100644 src/delphi/static/configs/v0-llama2-800k.json rename src/delphi/{static => test_configs}/__init__.py (100%) create mode 100644 src/delphi/test_configs/debug.json rename src/delphi/{static/configs => test_configs}/debug_transformers_bloom.json (81%) rename src/delphi/{static/configs => 
test_configs}/v0-llama2-100k.json (79%) delete mode 100644 src/delphi/train/config/data_config.py create mode 100644 src/delphi/train/config/dataset_config.py rename {src/delphi/static => static}/README.md (100%) rename {src/delphi/static => static}/all_tokens_list.txt (100%) rename {src/delphi/static => static}/labelled_token_ids_dict.pkl (100%) rename {src/delphi/static => static}/model_group_stats.pkl (100%) rename {src/delphi/static => static}/token_map.pkl (100%) diff --git a/src/delphi/static/configs/debug.json b/configs/debug.json similarity index 72% rename from src/delphi/static/configs/debug.json rename to configs/debug.json index 0aa6abed..bdfd6308 100644 --- a/src/delphi/static/configs/debug.json +++ b/configs/debug.json @@ -1,12 +1,9 @@ { - "vocab_size": 4096, "max_seq_len": 512, "max_epochs": 2, - "eval_interval": 1, "eval_iters": 1, - "data_config": { - "train_sample_limit": 256 - }, + "batch_ordering_seed": 42, + "torch_seed": 1337, "batch_size": 64, "model_config": { "model_class": "LlamaForCausalLM", @@ -16,5 +13,8 @@ "num_hidden_layers": 2, "num_key_value_heads": 2, "vocab_size": 4096 + }, + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" } } \ No newline at end of file diff --git a/scripts/training_config_examples/sample_config.json b/configs/sample_config.json similarity index 83% rename from scripts/training_config_examples/sample_config.json rename to configs/sample_config.json index 3f326687..ac538399 100644 --- a/scripts/training_config_examples/sample_config.json +++ b/configs/sample_config.json @@ -1,15 +1,13 @@ { "run_name": "2024_03_15_17_28_14", "output_dir": "/Users/jaidhyani/Library/Application Support/delphi", + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" + }, "device": "auto", - "eval_interval": 2000, "log_interval": 1, "eval_iters": 100, - "eval_only": false, - "always_save_checkpoint": false, - "init_from": "scratch", - "wandb_config": { - "log": false, + "wandb": { "project": "delphi", "entity": "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" }, @@ -38,16 +36,17 @@ "vocab_size": 4096 }, "max_epochs": 10, + "gradient_accumulation_steps": 1, "grad_clip": 1.0, - "optimizer": { - "gradient_accumulation_steps": 4, + "adam": { "learning_rate": 0.0005, "weight_decay": 0.1, "beta1": 0.9, "beta2": 0.95, - "grad_clip": 1.0, "decay_lr": true, "warmup_iters": 1000, "min_lr": 0.0 - } + }, + "batch_ordering_seed": 42, + "torch_seed": 1337 } \ No newline at end of file diff --git a/src/delphi/static/configs/debug_mamba.json b/configs/sample_mamba.json similarity index 72% rename from src/delphi/static/configs/debug_mamba.json rename to configs/sample_mamba.json index 8f502135..7fddcb26 100644 --- a/src/delphi/static/configs/debug_mamba.json +++ b/configs/sample_mamba.json @@ -1,13 +1,8 @@ { - "vocab_size": 4096, "max_seq_len": 512, "max_epochs": 2, - "eval_interval": 1, "log_interval": 1, "eval_iters": 10, - "data_config": { - "train_sample_limit": 64 - }, "batch_size": 8, "model_config": { "model_class": "MambaForCausalLM", @@ -18,5 +13,10 @@ "conv_kernel": 2, "expand": 2, "time_step_rank": 2 + }, + "batch_ordering_seed": 42, + "torch_seed": 1337, + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" } } \ No newline at end of file diff --git a/scripts/training_config_examples/sample_transformers_bloom.json b/configs/sample_transformers_bloom.json similarity index 65% rename from scripts/training_config_examples/sample_transformers_bloom.json rename to 
configs/sample_transformers_bloom.json index 9a81ac89..793f6a8a 100644 --- a/scripts/training_config_examples/sample_transformers_bloom.json +++ b/configs/sample_transformers_bloom.json @@ -1,9 +1,7 @@ { - "vocab_size": 4096, "max_seq_len": 512, - "max_epochs": 10, - "eval_interval": 10, - "eval_iters": 8, + "max_epochs": 2, + "eval_iters": 1, "batch_size": 64, "model_config": { "model_class": "BloomForCausalLM", @@ -12,14 +10,19 @@ "bos_token_id": 1, "eos_token_id": 2, "hidden_dropout": 0.0, - "hidden_size": 64, + "hidden_size": 8, "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, - "n_head": 8, - "n_layer": 10, + "n_head": 2, + "n_layer": 2, "pretraining_tp": 1, "slow_but_exact": false, "use_cache": true, "vocab_size": 4096 + }, + "batch_ordering_seed": 42, + "torch_seed": 1337, + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" } } \ No newline at end of file diff --git a/configs/stories/llama2/100k.json b/configs/stories/llama2/100k.json new file mode 100644 index 00000000..601c3809 --- /dev/null +++ b/configs/stories/llama2/100k.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 12, + "intermediate_size": 48, + "num_attention_heads": 2, + "num_hidden_layers": 1, + "num_key_value_heads": 1 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/10m.json b/configs/stories/llama2/10m.json new file mode 100644 index 00000000..224b4674 --- /dev/null +++ b/configs/stories/llama2/10m.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 332, + "intermediate_size": 896, + "num_attention_heads": 12, + "num_hidden_layers": 6, + "num_key_value_heads": 6 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/1m.json b/configs/stories/llama2/1m.json new file mode 100644 index 00000000..52f1c893 --- /dev/null +++ b/configs/stories/llama2/1m.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 84, + "intermediate_size": 256, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/2.5m.json b/configs/stories/llama2/2.5m.json new file mode 100644 index 00000000..4d55904c --- /dev/null +++ b/configs/stories/llama2/2.5m.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 168, + "intermediate_size": 384, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/250k.json b/configs/stories/llama2/250k.json new file mode 100644 index 00000000..7a4ed066 --- /dev/null +++ b/configs/stories/llama2/250k.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 28, + "intermediate_size": 96, + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 2 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/25m.json b/configs/stories/llama2/25m.json new file mode 100644 index 00000000..813d2b63 --- /dev/null +++ b/configs/stories/llama2/25m.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 484, + "intermediate_size": 1332, + "num_attention_heads": 16, + "num_hidden_layers": 8, + "num_key_value_heads": 8 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/500k.json b/configs/stories/llama2/500k.json new file mode 100644 index 00000000..c4e0ec8e --- /dev/null +++ b/configs/stories/llama2/500k.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 52, + "intermediate_size": 184, + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 2 + } 
+} \ No newline at end of file diff --git a/configs/stories/llama2/50k.json b/configs/stories/llama2/50k.json new file mode 100644 index 00000000..53afb500 --- /dev/null +++ b/configs/stories/llama2/50k.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 6, + "intermediate_size": 24, + "num_attention_heads": 2, + "num_hidden_layers": 1, + "num_key_value_heads": 1 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/50m.json b/configs/stories/llama2/50m.json new file mode 100644 index 00000000..3fa95022 --- /dev/null +++ b/configs/stories/llama2/50m.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 708, + "intermediate_size": 1896, + "num_attention_heads": 16, + "num_hidden_layers": 8, + "num_key_value_heads": 8 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/5m.json b/configs/stories/llama2/5m.json new file mode 100644 index 00000000..839221f6 --- /dev/null +++ b/configs/stories/llama2/5m.json @@ -0,0 +1,9 @@ +{ + "model_config": { + "hidden_size": 232, + "intermediate_size": 512, + "num_attention_heads": 12, + "num_hidden_layers": 6, + "num_key_value_heads": 6 + } +} \ No newline at end of file diff --git a/configs/stories/llama2/README.md b/configs/stories/llama2/README.md new file mode 100644 index 00000000..be1f976e --- /dev/null +++ b/configs/stories/llama2/README.md @@ -0,0 +1,7 @@ +not using padding, so pad_token_id not set +use_cache - using default +pretraining_tp - experimental parallelization we're not using, which is the default +tie_word_embeddings - llama2 used False and this is better for interpretability, note that llama2.c is using True by default, which is probably more efficient use of parameters for very small models +rope settings are widely used defaults +attention_bias - no biases on QKV and output projection is the default and that's what we're using +attention_dropout - this is the only dropout llama2 can use, it's set to prob=0 by default and that's what we're using \ No newline at end of file diff --git a/configs/stories/llama2/base.json b/configs/stories/llama2/base.json new file mode 100644 index 00000000..427e124f --- /dev/null +++ b/configs/stories/llama2/base.json @@ -0,0 +1,52 @@ +{ + "model_config": { + "model_class": "LlamaForCausalLM", + "vocab_size": 4096, + "hidden_act": "silu", + "max_position_embeddings": 512, + "initializer_range": 0.02, + "rms_norm_eps": 1e-06, + "bos_token_id": 0, + "eos_token_id": 1, + "tie_word_embeddings": false, + "rope_theta": 10000.0, + "rope_scaling": null, + "attention_bias": false, + "attention_dropout": 0.0 + }, + "max_seq_len": 512, + "device": "auto", + "checkpoint_interval": 400, + "extra_checkpoint_iters": [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512 + ], + "log_interval": 40, + "eval_iters": 10, + "batch_size": 256, + "max_epochs": 10, + "grad_clip": 1.0, + "gradient_accumulation_steps": 1, + "adam": { + "learning_rate": 0.0005, + "weight_decay": 0.1, + "beta1": 0.9, + "beta2": 0.95, + "decay_lr": true, + "warmup_iters": 1000, + "min_lr": 0.0 + }, + "batch_ordering_seed": 1337, + "torch_seed": 42, + "dataset": { + "name": "delphi-suite/stories-tokenized" + } +} \ No newline at end of file diff --git a/configs/stories/mamba/100k.json b/configs/stories/mamba/100k.json new file mode 100644 index 00000000..56d3d232 --- /dev/null +++ b/configs/stories/mamba/100k.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 24, + "num_hidden_layers": 2 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/10m.json 
b/configs/stories/mamba/10m.json new file mode 100644 index 00000000..3afe882d --- /dev/null +++ b/configs/stories/mamba/10m.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 400, + "num_hidden_layers": 8 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/1m.json b/configs/stories/mamba/1m.json new file mode 100644 index 00000000..24e23c84 --- /dev/null +++ b/configs/stories/mamba/1m.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 112, + "num_hidden_layers": 6 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/2.5m.json b/configs/stories/mamba/2.5m.json new file mode 100644 index 00000000..fcc76cde --- /dev/null +++ b/configs/stories/mamba/2.5m.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 204, + "num_hidden_layers": 6 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/250k.json b/configs/stories/mamba/250k.json new file mode 100644 index 00000000..a9ef4141 --- /dev/null +++ b/configs/stories/mamba/250k.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 36, + "num_hidden_layers": 4 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/25m.json b/configs/stories/mamba/25m.json new file mode 100644 index 00000000..50dd4029 --- /dev/null +++ b/configs/stories/mamba/25m.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 664, + "num_hidden_layers": 8 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/500k.json b/configs/stories/mamba/500k.json new file mode 100644 index 00000000..5d16a05f --- /dev/null +++ b/configs/stories/mamba/500k.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 76, + "num_hidden_layers": 4 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/50k.json b/configs/stories/mamba/50k.json new file mode 100644 index 00000000..6e146429 --- /dev/null +++ b/configs/stories/mamba/50k.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 12, + "num_hidden_layers": 2 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/50m.json b/configs/stories/mamba/50m.json new file mode 100644 index 00000000..7e230f98 --- /dev/null +++ b/configs/stories/mamba/50m.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 952, + "num_hidden_layers": 8 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/5m.json b/configs/stories/mamba/5m.json new file mode 100644 index 00000000..488d283f --- /dev/null +++ b/configs/stories/mamba/5m.json @@ -0,0 +1,6 @@ +{ + "model_config": { + "hidden_size": 308, + "num_hidden_layers": 6 + } +} \ No newline at end of file diff --git a/configs/stories/mamba/README.md b/configs/stories/mamba/README.md new file mode 100644 index 00000000..3e83bccc --- /dev/null +++ b/configs/stories/mamba/README.md @@ -0,0 +1,10 @@ +pad_token_id - we're not using pad tokens, so we don't set it +layer_norm_eps - different than rms norm eps in mamba +initializer_range - different in mamba & llama +residual_in_fp32 - mamba specific parameter +time_step_* - mamba specific, sane defaults +there is no way to untie embeddings and unembeddings in mamba, they're tied by default +https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/mamba/modeling_mamba.py#L602-L610 +rescale_prenorm_residual was True in the original paper, so we set it to True, despite the HF default being False +using default for use_cache +state_size is default \ No newline at end of file diff --git a/configs/stories/mamba/base.json b/configs/stories/mamba/base.json new file mode 100644
index 00000000..ede565ef --- /dev/null +++ b/configs/stories/mamba/base.json @@ -0,0 +1,53 @@ +{ + "model_config": { + "model_class": "MambaForCausalLM", + "vocab_size": 4096, + "state_size": 16, + "layer_norm_epsilon": 1e-5, + "bos_token_id": 0, + "eos_token_id": 1, + "expand": 2, + "conv_kernel": 4, + "use_bias": false, + "use_conv_bias": true, + "hidden_act": "silu", + "initializer_range": 0.1, + "residual_in_fp32": true, + "rescale_prenorm_residual": true + }, + "max_seq_len": 512, + "device": "auto", + "checkpoint_interval": 400, + "extra_checkpoint_iters": [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512 + ], + "log_interval": 40, + "eval_iters": 10, + "batch_size": 256, + "max_epochs": 10, + "grad_clip": 1.0, + "gradient_accumulation_steps": 1, + "adam": { + "learning_rate": 0.0005, + "weight_decay": 0.1, + "beta1": 0.9, + "beta2": 0.95, + "decay_lr": true, + "warmup_iters": 1000, + "min_lr": 0.0 + }, + "batch_ordering_seed": 1337, + "torch_seed": 42, + "dataset": { + "name": "delphi-suite/stories-tokenized" + } +} \ No newline at end of file diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py index 5bafbffe..b0393fa5 100755 --- a/scripts/map_tokens.py +++ b/scripts/map_tokens.py @@ -5,7 +5,6 @@ import pandas as pd from datasets import Dataset -from delphi.constants import STATIC_ASSETS_DIR from delphi.eval.token_map import token_map from delphi.eval.utils import load_validation_dataset diff --git a/scripts/spacy_label_all_tokens.py b/scripts/spacy_label_all_tokens.py index 22a3b7f4..3ff8ccda 100644 --- a/scripts/spacy_label_all_tokens.py +++ b/scripts/spacy_label_all_tokens.py @@ -6,7 +6,6 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from delphi.constants import STATIC_ASSETS_DIR from delphi.eval import spacy_token_labelling diff --git a/scripts/training_config_examples/sample_mamba.json b/scripts/training_config_examples/sample_mamba.json deleted file mode 100644 index da089c4f..00000000 --- a/scripts/training_config_examples/sample_mamba.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "run_name": "2024_03_15_21_56_35", - "output_dir": "/Users/jaidhyani/Library/Application Support/delphi", - "device": "auto", - "eval_interval": 2000, - "log_interval": 1, - "eval_iters": 100, - "eval_only": false, - "always_save_checkpoint": false, - "init_from": "scratch", - "wandb_config": { - "log": false, - "project": "delphi", - "entity": "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" - }, - "batch_size": 64, - "max_seq_len": 512, - "model_config": { - "model_class": "MambaForCausalLM", - "vocab_size": 4096, - "hidden_size": 768, - "state_size": 16, - "num_hidden_layers": 32, - "conv_kernel": 4, - "expand": 2, - "use_bias": false, - "use_conv_bias": true, - "bos_token_id": 0, - "eos_token_id": 0, - "pad_token_id": 0, - "time_step_rank": "auto", - "time_step_scale": 1.0, - "time_step_min": 0.001, - "time_step_max": 0.1, - "time_step_init_scheme": "random", - "time_step_floor": 0.0001, - "layer_norm_epsilon": 1e-05, - "hidden_act": "silu", - "initializer_range": 0.1, - "residual_in_fp32": true, - "rescale_prenorm_residual": false, - "use_cache": true, - "tie_word_embeddings": true - }, - "max_epochs": 10, - "grad_clip": 1.0, - "optimizer": { - "gradient_accumulation_steps": 4, - "learning_rate": 0.0005, - "weight_decay": 0.1, - "beta1": 0.9, - "beta2": 0.95, - "grad_clip": 1.0, - "decay_lr": true, - "warmup_iters": 1000, - "min_lr": 0.0 - } -} \ No newline at end of file diff --git 
a/scripts/training_run.sh b/scripts/training_run.sh new file mode 100644 index 00000000..7d1b2fe8 --- /dev/null +++ b/scripts/training_run.sh @@ -0,0 +1,6 @@ +counter=1 +for config in 4-76.json 6-112 6-204 +do + CUDA_VISIBLE_DEVICES=$counter CUBLAS_WORKSPACE_CONFIG=:4096:8 python scripts/run_training.py --config scripts/$config > $config.log & + counter=$((counter+1)) +done diff --git a/scripts/validate_configs.py b/scripts/validate_configs.py new file mode 100755 index 00000000..46d3f5d1 --- /dev/null +++ b/scripts/validate_configs.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +import argparse +import pathlib + +from delphi.train.config import build_config_from_files_and_overrides + + +def get_config_path_with_base(config_path: pathlib.Path) -> list[pathlib.Path]: + """If the config path is in a directory which includes base.json, include that as the first config.""" + if (config_path.parent / "base.json").exists(): + return [config_path.parent / "base.json", config_path] + return [config_path] + + +def get_config_paths(config_path: str) -> list[list[pathlib.Path]]: + """If the config path is a directory, recursively glob all json files in it. Otherwise, just use the path as a single-element list.""" + paths = ( + list(pathlib.Path(config_path).rglob("*.json")) + if pathlib.Path(config_path).is_dir() + else [pathlib.Path(config_path)] + ) + # exclude base.json files + paths = [path for path in paths if not path.name.startswith("base")] + # supplement non-base configs with base.json if it exists in same dir + return [get_config_path_with_base(path) for path in paths] + + +def main(): + parser = argparse.ArgumentParser() + # we take one positional argument, a path to a directory or config + parser.add_argument( + "config_path", + type=str, + help="path to a training config json or directory of training config jsons", + ) + args = parser.parse_args() + config_paths = get_config_paths(args.config_path) + print( + f"validating configs: {' | '.join(str(config_path[-1]) for config_path in config_paths)}" + ) + errors = [] + for config_path in config_paths: + try: + build_config_from_files_and_overrides(config_path, {}) + except Exception as e: + errors.append((config_path, e)) + continue + if errors: + print("errors:") + for config_path, e in errors: + print(f" {config_path[-1]}: {e}") + else: + print("all configs loaded successfully") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index fcfffc0d..4a92f04d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ packages=find_packages(where="src"), package_dir={"": "src"}, package_data={ - "delphi": ["static/**/*"], + "delphi": ["test_configs/**/*"], }, include_package_data=True, ) diff --git a/src/delphi/constants.py b/src/delphi/constants.py index b18c50af..c86e97fb 100644 --- a/src/delphi/constants.py +++ b/src/delphi/constants.py @@ -2,8 +2,7 @@ from pathlib import Path from typing import cast -STATIC_ASSETS_DIR = files("delphi.static") -CONFIG_PRESETS_DIR = cast(Path, STATIC_ASSETS_DIR / "configs") +TEST_CONFIGS_DIR = cast(Path, files("delphi.test_configs")) CORPUS_DATASET = "delphi-suite/stories" TINYSTORIES_TOKENIZED_HF_DATASET = "delphi-suite/v0-tinystories-v2-clean-tokenized" diff --git a/src/delphi/static/configs/v0-llama2-1.6m.json b/src/delphi/static/configs/v0-llama2-1.6m.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-1.6m.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, -
"bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-100k-quick.json b/src/delphi/static/configs/v0-llama2-100k-quick.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-100k-quick.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-12.8m.json b/src/delphi/static/configs/v0-llama2-12.8m.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-12.8m.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-200k.json b/src/delphi/static/configs/v0-llama2-200k.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-200k.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-25.6m.json b/src/delphi/static/configs/v0-llama2-25.6m.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-25.6m.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - 
"num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-3.2m.json b/src/delphi/static/configs/v0-llama2-3.2m.json deleted file mode 100644 index 7a2c9689..00000000 --- a/src/delphi/static/configs/v0-llama2-3.2m.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 0.00001, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-400k.json b/src/delphi/static/configs/v0-llama2-400k.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-400k.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-6.4m.json b/src/delphi/static/configs/v0-llama2-6.4m.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-6.4m.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-800k.json b/src/delphi/static/configs/v0-llama2-800k.json deleted file mode 100644 index 95261935..00000000 --- a/src/delphi/static/configs/v0-llama2-800k.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "model_config": { - "model_class": "LlamaForCausalLM", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 48, - "initializer_range": 0.02, - "intermediate_size": 128, - "max_position_embeddings": 512, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": true, - "use_cache": true, - "vocab_size": 4096 - } -} \ No newline at end of file diff --git 
a/src/delphi/static/__init__.py b/src/delphi/test_configs/__init__.py similarity index 100% rename from src/delphi/static/__init__.py rename to src/delphi/test_configs/__init__.py diff --git a/src/delphi/test_configs/debug.json b/src/delphi/test_configs/debug.json new file mode 100644 index 00000000..bdfd6308 --- /dev/null +++ b/src/delphi/test_configs/debug.json @@ -0,0 +1,20 @@ +{ + "max_seq_len": 512, + "max_epochs": 2, + "eval_iters": 1, + "batch_ordering_seed": 42, + "torch_seed": 1337, + "batch_size": 64, + "model_config": { + "model_class": "LlamaForCausalLM", + "hidden_size": 48, + "intermediate_size": 48, + "num_attention_heads": 2, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "vocab_size": 4096 + }, + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/debug_transformers_bloom.json b/src/delphi/test_configs/debug_transformers_bloom.json similarity index 81% rename from src/delphi/static/configs/debug_transformers_bloom.json rename to src/delphi/test_configs/debug_transformers_bloom.json index 8ec0f5b0..793f6a8a 100644 --- a/src/delphi/static/configs/debug_transformers_bloom.json +++ b/src/delphi/test_configs/debug_transformers_bloom.json @@ -1,12 +1,7 @@ { - "vocab_size": 4096, "max_seq_len": 512, "max_epochs": 2, - "eval_interval": 1, "eval_iters": 1, - "data_config": { - "train_sample_limit": 256 - }, "batch_size": 64, "model_config": { "model_class": "BloomForCausalLM", @@ -24,5 +19,10 @@ "slow_but_exact": false, "use_cache": true, "vocab_size": 4096 + }, + "batch_ordering_seed": 42, + "torch_seed": 1337, + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" } } \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-100k.json b/src/delphi/test_configs/v0-llama2-100k.json similarity index 79% rename from src/delphi/static/configs/v0-llama2-100k.json rename to src/delphi/test_configs/v0-llama2-100k.json index 95261935..b89872ce 100644 --- a/src/delphi/static/configs/v0-llama2-100k.json +++ b/src/delphi/test_configs/v0-llama2-100k.json @@ -1,4 +1,5 @@ { + "max_seq_len": 512, "model_config": { "model_class": "LlamaForCausalLM", "attention_bias": false, @@ -20,5 +21,10 @@ "tie_word_embeddings": true, "use_cache": true, "vocab_size": 4096 + }, + "batch_ordering_seed": 42, + "torch_seed": 1337, + "dataset": { + "name": "delphi-suite/v0-tinystories-v2-clean-tokenized" } } \ No newline at end of file diff --git a/src/delphi/train/checkpoint_step.py b/src/delphi/train/checkpoint_step.py index df0e5c1b..a86a8680 100644 --- a/src/delphi/train/checkpoint_step.py +++ b/src/delphi/train/checkpoint_step.py @@ -39,13 +39,7 @@ def log_and_save_checkpoint( split_to_ds={"train": train_ds, "val": validation_ds}, device=run_context.device, epoch=mts.epoch, - feature_names={ - "train": config.data_config.train_feature, - "val": ( - config.data_config.validation_feature - or config.data_config.train_feature - ), - }, + feature_name=config.dataset.feature, ) logging.info( f"step {mts.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}" @@ -58,7 +52,7 @@ def log_and_save_checkpoint( run_context=run_context, results_path=results_path, ) - if config.wandb_config.log: + if config.wandb: log_to_wandb( mts=mts, losses=losses, diff --git a/src/delphi/train/config/__init__.py b/src/delphi/train/config/__init__.py index 65e65a9b..fe0e825b 100644 --- a/src/delphi/train/config/__init__.py +++ b/src/delphi/train/config/__init__.py @@ -4,8 +4,6 @@ 
build_config_dict_from_files, build_config_from_files_and_overrides, dot_notation_to_dict, - get_preset_paths, get_user_config_path, - load_preset, ) from .wandb_config import WandbConfig diff --git a/src/delphi/train/config/adam_config.py b/src/delphi/train/config/adam_config.py index b448b5ba..480f1165 100644 --- a/src/delphi/train/config/adam_config.py +++ b/src/delphi/train/config/adam_config.py @@ -1,6 +1,9 @@ from dataclasses import dataclass +from beartype import beartype + +@beartype @dataclass class AdamConfig: # adamw optimizer diff --git a/src/delphi/train/config/data_config.py b/src/delphi/train/config/data_config.py deleted file mode 100644 index 10fa303c..00000000 --- a/src/delphi/train/config/data_config.py +++ /dev/null @@ -1,58 +0,0 @@ -from dataclasses import dataclass, field -from typing import Optional - -from beartype import beartype - -from delphi import constants - - -@beartype -@dataclass(frozen=True) -class DataConfig: - train_dataset: str = field( - # TODO: remove default after updating configs to include this field - default=constants.TINYSTORIES_TOKENIZED_HF_DATASET, - metadata={"help": "tokenized dataset on huggingface to use for train"}, - ) - train_split: str = field( - default="train", - metadata={"help": "split of the train dataset to use for train"}, - ) - train_feature: str = field( - default="tokens", - metadata={ - "help": "feature in the train dataset to use for train; should be a list of max_seq_len+1 token ints" - }, - ) - train_sample_limit: Optional[int] = field( - default=None, - metadata={"help": "limit the number of train samples to use"}, - ) - - validation_dataset: Optional[str] = field( - default=None, - metadata={ - "help": ( - "tokenized dataset on huggingface to use for validation. " - "If not set, validation defaults to using train_dataset" - ) - }, - ) - validation_split: str = field( - default="validation", - metadata={"help": "split of the validation dataset to use for validation"}, - ) - validation_feature: Optional[str] = field( - default=None, - metadata={ - "help": ( - "feature in the validation dataset to use for validation; " - "should be a list of max_seq_len+1 token ints. " - "If not set, validation defaults to using train_feature." 
- ) - }, - ) - validation_sample_limit: Optional[int] = field( - default=None, - metadata={"help": "limit the number of validation samples to use"}, - ) diff --git a/src/delphi/train/config/dataset_config.py b/src/delphi/train/config/dataset_config.py new file mode 100644 index 00000000..bcc65a1a --- /dev/null +++ b/src/delphi/train/config/dataset_config.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass, field +from typing import cast + +import datasets +from beartype import beartype +from datasets import Dataset, load_dataset + + +@beartype +@dataclass(frozen=True) +class DatasetConfig: + name: str = field( + metadata={"help": "tokenized dataset on huggingface to use for train"}, + ) + feature: str = field( + default="tokens", + metadata={ + "help": "feature in the train dataset to use for train; should be a list of max_seq_len+1 token ints" + }, + ) + train_split: str = field( + default="train", + metadata={"help": "split of the dataset to use for training"}, + ) + validation_split: str = field( + default="validation", + metadata={"help": "split of the dataset to use for validation"}, + ) + + def _load(self, split) -> Dataset: + ds = load_dataset( + self.name, + split=split, + features=datasets.Features( + {self.feature: datasets.Sequence(datasets.Value("int32"))} + ), + ) + ds = cast(Dataset, ds) + ds.set_format("torch") + return ds + + def load_train(self) -> Dataset: + return self._load(self.train_split) + + def load_validation(self) -> Dataset: + return self._load(self.validation_split) diff --git a/src/delphi/train/config/huggingface_config.py b/src/delphi/train/config/huggingface_config.py index a0164fd5..d51bd6a3 100644 --- a/src/delphi/train/config/huggingface_config.py +++ b/src/delphi/train/config/huggingface_config.py @@ -1,11 +1,19 @@ -from dataclasses import dataclass +import os +from dataclasses import dataclass, field from typing import Optional from beartype import beartype +def get_hf_token(): + token = os.getenv("HF_TOKEN", "") + assert token, "HF_TOKEN env variable must be set or specified manually" + return token + + @beartype @dataclass(frozen=True) class HuggingfaceConfig: repo_id: Optional[str] = None push_checkpoints_to_hub: bool = False + token: str = field(default_factory=get_hf_token) diff --git a/src/delphi/train/config/training_config.py b/src/delphi/train/config/training_config.py index ce59bffa..0dcf6714 100644 --- a/src/delphi/train/config/training_config.py +++ b/src/delphi/train/config/training_config.py @@ -7,20 +7,21 @@ from beartype import beartype from .adam_config import AdamConfig -from .data_config import DataConfig +from .dataset_config import DatasetConfig from .debug_config import DebugConfig from .huggingface_config import HuggingfaceConfig from .wandb_config import WandbConfig @beartype -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class TrainingConfig: model_config: dict[str, Any] = field( metadata={ "help": "model config; class_name=name of model class in transformers, everything else is kwargs for the corresponding model config" }, ) + max_seq_len: int = field(metadata={"help": "max sequence length"}) # meta run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") output_dir: str = field( @@ -56,13 +57,10 @@ class TrainingConfig: batch_size: int = field( default=64, metadata={ - "help": "if gradient_accumulation_steps > 1, this is the micro-batch size" + "help": "number of samples used to compute the gradient for a single optimizer step" }, ) - # model config - max_seq_len: int =
field(default=512, metadata={"help": "max sequence length"}) - # training max_epochs: int = field( default=10, metadata={"help": "total number of training epochs"} @@ -71,26 +69,29 @@ class TrainingConfig: default=1.0, metadata={"help": "clip gradients at this value, or disable if == 0.0"}, ) - gradient_accumulation_steps: int = 4 # used to simulate larger batch sizes + gradient_accumulation_steps: int = field( + default=1, + metadata={ + "help": "if > 1 reduces memory usage by computing gradient in microbatches" + }, + ) # (adamw) optimizer adam: AdamConfig = field(default_factory=AdamConfig) # reproducibility batch_ordering_seed: int = field( - default=1337, metadata={"help": "seed used for pseudorandomly sampling data during training"}, ) - torch_seed: int = field(default=42, metadata={"help": "seed used for torch"}) + torch_seed: int = field(metadata={"help": "seed used for torch"}) # data - data_config: DataConfig = field( - default_factory=DataConfig, + dataset: DatasetConfig = field( metadata={"help": "specify training and validation data"}, ) # third party - wandb_config: WandbConfig = field(default_factory=WandbConfig) - huggingface: HuggingfaceConfig = field(default_factory=HuggingfaceConfig) + wandb: Optional[WandbConfig] = None + hf: Optional[HuggingfaceConfig] = None # debug debug_config: DebugConfig = field(default_factory=DebugConfig) diff --git a/src/delphi/train/config/utils.py b/src/delphi/train/config/utils.py index c94875b2..b645cd5e 100644 --- a/src/delphi/train/config/utils.py +++ b/src/delphi/train/config/utils.py @@ -2,7 +2,6 @@ import json import logging import os -from collections.abc import Iterable from dataclasses import fields, is_dataclass from datetime import datetime from pathlib import Path @@ -10,10 +9,9 @@ from typing import Any, Type, TypeVar, Union import platformdirs +from dacite import Config as dacite_config from dacite import from_dict -from delphi.constants import CONFIG_PRESETS_DIR - from .training_config import TrainingConfig T = TypeVar("T") @@ -42,11 +40,6 @@ def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]: return merged -def get_preset_paths() -> Iterable[Path]: - """This gets all the paths to the preset config files in the static preset config dir.""" - return CONFIG_PRESETS_DIR.glob("*.json") - - def get_user_config_path() -> Path: """ This enables a user-specific config to always be included in the training config. @@ -132,13 +125,7 @@ def build_config_from_files_and_overrides( cast_types(overrides, TrainingConfig) merge_two_dicts(merge_into=combined_config, merge_from=overrides) set_backup_vals(combined_config, config_files) - return from_dict(TrainingConfig, combined_config) - - -def load_preset(preset_name: str) -> TrainingConfig: - """Load a preset config by name, e.g. 
`load_preset("debug")`.""" - preset_path = CONFIG_PRESETS_DIR / f"{preset_name}.json" - return build_config_from_files_and_overrides([preset_path], {}) + return from_dict(TrainingConfig, combined_config, config=dacite_config(strict=True)) def dot_notation_to_dict(vars: dict[str, Any]) -> dict[str, Any]: diff --git a/src/delphi/train/config/wandb_config.py b/src/delphi/train/config/wandb_config.py index 23fbc3a0..9b4e3c55 100644 --- a/src/delphi/train/config/wandb_config.py +++ b/src/delphi/train/config/wandb_config.py @@ -1,10 +1,11 @@ from dataclasses import dataclass -from datetime import datetime +from beartype import beartype + +@beartype @dataclass class WandbConfig: - log: bool = False - project: str = "delphi" - entity: str = "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" + project: str + entity: str silence: bool = False diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py index 49c4d1d9..d0fb50ee2 100644 --- a/src/delphi/train/train_step.py +++ b/src/delphi/train/train_step.py @@ -38,7 +38,7 @@ def train_step( num_minibatches=config.gradient_accumulation_steps, step=model_training_state.step, device=device, - feature_name=config.data_config.train_feature, + feature_name=config.dataset.feature, ) total_loss = accumulate_gradients( model=model, diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index 23a19515..a20132c3 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -18,7 +18,6 @@ get_device, get_indices_for_epoch, initialize_model_training_state, - load_tokens_dataset_from_huggingface, set_lr, setup_determinism, ) @@ -37,7 +36,7 @@ def setup_training(config: TrainingConfig): setup_determinism(config.torch_seed) # wandb setup - if config.wandb_config.log: + if config.wandb: init_wandb(config=config) @@ -59,22 +58,10 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext # load data logging.info("Loading data...") - train_ds = load_tokens_dataset_from_huggingface( - hf_dataset_id=config.data_config.train_dataset, - split=config.data_config.train_split, - tokens_feature=config.data_config.train_feature, - limit=config.data_config.train_sample_limit, - ) - validation_ds = load_tokens_dataset_from_huggingface( - hf_dataset_id=( - config.data_config.validation_dataset or config.data_config.train_dataset - ), - split=config.data_config.validation_split, - tokens_feature=( - config.data_config.validation_feature or config.data_config.train_feature - ), - limit=config.data_config.validation_sample_limit, - ) + train_ds = config.dataset.load_train() + validation_ds = config.dataset.load_validation() + logging.info(f"Train dataset: {len(train_ds)} samples") + logging.info(f"Validation dataset: {len(validation_ds)} samples") # derive iteration params steps_per_epoch = len(train_ds) // config.batch_size diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 7dd5956d..11ec0059 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -6,13 +6,12 @@ from collections.abc import Generator from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Optional, Type, cast +from typing import Any, Type, cast -import datasets import safetensors.torch as st import torch import transformers -from datasets import Dataset, load_dataset +from datasets import Dataset from huggingface_hub import HfApi from torch.optim import AdamW from transformers import PreTrainedModel @@ -163,11 +162,7 @@ def get_xy_batch( 
feature_name: str, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get a batch of data from a dataset given a batch number and indices - - Args: - """ + """Get a batch of data from a dataset given a batch number and indices""" start = batch_num * batch_size end = (batch_num + 1) * batch_size batch_indices = indices[start:end] @@ -187,9 +182,6 @@ def gen_minibatches( """ Generate minibatches from a dataset given a step and indices """ - assert ( - batch_size % num_minibatches == 0 - ), "batch_size must be divisible by num_minibatches" minibatch_size = batch_size // num_minibatches first_minibatch_num = num_minibatches * step for i in range(num_minibatches): @@ -211,7 +203,7 @@ def estimate_loss( split_to_ds: dict[str, Dataset], device: torch.device, epoch: int, - feature_names: dict[str, str], + feature_name: str, ) -> dict[str, float]: """helps estimate an arbitrarily accurate loss over either split using many batches""" out = {} @@ -232,12 +224,12 @@ def estimate_loss( step=0, indices=indices, device=device, - feature_name=feature_names[split], + feature_name=feature_name, ) for k, (X, Y) in enumerate(minibatches): loss = model(X, labels=Y, return_dict=True).loss losses[k] = loss.item() - out[split] = losses.mean() + out[split] = losses.mean().item() model.train() return out @@ -282,50 +274,17 @@ def save_results( run_context_dict = asdict(run_context) run_context_dict["device"] = str(run_context.device) json.dump(run_context_dict, file, indent=2) - if config.huggingface.push_checkpoints_to_hub: - api = HfApi() + if config.hf and config.hf.push_checkpoints_to_hub: + api = HfApi(token=config.hf.token) api.upload_folder( folder_path=results_path, - repo_id=str(config.huggingface.repo_id), + repo_id=str(config.hf.repo_id), revision=f"iter_{train_results.iter_num}", ) -def load_tokens_dataset_from_huggingface( - hf_dataset_id: str, - split: str, - tokens_feature: str, - limit: Optional[int] = None, -) -> Dataset: - """Load a dataset from huggingface - - Args: - hf_dataset_id (str): huggingface dataset id e.g. "delphi-suite/v0-tinystories-v2-clean-tokenized" - split (str): split to load, e.g. "train" or "validation" - tokens_feature (str): feature name for tokens, e.g. "tokens" - limit (Optional[int], optional): limit the number of samples. 
None (default) means no limit (use full dataset split) - """ - ds = cast( - Dataset, - load_dataset( - hf_dataset_id, - split=split, - features=datasets.Features( - {tokens_feature: datasets.Sequence(datasets.Value("int32"))} - ), - ), - ) - if limit is not None and limit > 0: - ds = ds.select(range(limit)) - ds.set_format("torch") - return ds - - def count_tokens_so_far(config: TrainingConfig, mts: ModelTrainingState) -> int: - tokens_per_iter = ( - config.batch_size * config.gradient_accumulation_steps * config.max_seq_len - ) - + tokens_per_iter = config.batch_size * config.max_seq_len return mts.iter_num * tokens_per_iter diff --git a/src/delphi/train/wandb_utils.py b/src/delphi/train/wandb_utils.py index 805f817c..83c53912 100644 --- a/src/delphi/train/wandb_utils.py +++ b/src/delphi/train/wandb_utils.py @@ -15,11 +15,12 @@ def silence_wandb(): def init_wandb(config: TrainingConfig): # if log level < debug, silence wandb - if logging.getLogger().level > logging.INFO or config.wandb_config.silence: + assert config.wandb is not None + if logging.getLogger().level > logging.INFO or config.wandb.silence: silence_wandb() wandb.init( - entity=config.wandb_config.entity, - project=config.wandb_config.project, + entity=config.wandb.entity, + project=config.wandb.project, name=config.run_name, config=asdict(config), ) diff --git a/src/delphi/static/README.md b/static/README.md similarity index 100% rename from src/delphi/static/README.md rename to static/README.md diff --git a/src/delphi/static/all_tokens_list.txt b/static/all_tokens_list.txt similarity index 100% rename from src/delphi/static/all_tokens_list.txt rename to static/all_tokens_list.txt diff --git a/src/delphi/static/labelled_token_ids_dict.pkl b/static/labelled_token_ids_dict.pkl similarity index 100% rename from src/delphi/static/labelled_token_ids_dict.pkl rename to static/labelled_token_ids_dict.pkl diff --git a/src/delphi/static/model_group_stats.pkl b/static/model_group_stats.pkl similarity index 100% rename from src/delphi/static/model_group_stats.pkl rename to static/model_group_stats.pkl diff --git a/src/delphi/static/token_map.pkl b/static/token_map.pkl similarity index 100% rename from src/delphi/static/token_map.pkl rename to static/token_map.pkl diff --git a/tests/train/config/test_config_utils.py b/tests/train/config/test_config_utils.py index 452adef2..710aa404 100644 --- a/tests/train/config/test_config_utils.py +++ b/tests/train/config/test_config_utils.py @@ -2,7 +2,7 @@ import pytest -from delphi.constants import CONFIG_PRESETS_DIR +from delphi.constants import TEST_CONFIGS_DIR from delphi.train.config.utils import ( _unoptionalize, build_config_from_files_and_overrides, @@ -12,6 +12,12 @@ ) +def test_configs(): + test_configs = list(TEST_CONFIGS_DIR.glob("*.json")) + for config in test_configs: + build_config_from_files_and_overrides([config], {}) + + def test_merge_two_dicts(): dict1 = {"a": 1, "b": 2, "c": {"d": 3, "e": 4}} dict2 = {"a": 5, "c": {"d": 6}} @@ -34,7 +40,7 @@ def test_dot_notation_to_dict(): def test_build_config_from_files_and_overrides(): - config_files = [CONFIG_PRESETS_DIR / "debug.json"] + config_files = [TEST_CONFIGS_DIR / "debug.json"] overrides = {"model_config": {"hidden_size": 128}, "eval_iters": 5} config = build_config_from_files_and_overrides(config_files, overrides) # check overrides @@ -42,7 +48,7 @@ def test_build_config_from_files_and_overrides(): assert config.eval_iters == 5 # check base values assert config.max_epochs == 2 - assert config.data_config.train_sample_limit == 256 
+ assert config.dataset.name == "delphi-suite/v0-tinystories-v2-clean-tokenized" def test_unoptionalize(): diff --git a/tests/train/test_train_step.py b/tests/train/test_train_step.py index 178f47be..c66ccdc4 100644 --- a/tests/train/test_train_step.py +++ b/tests/train/test_train_step.py @@ -6,8 +6,9 @@ from datasets import Dataset from jaxtyping import Float +from delphi.constants import TEST_CONFIGS_DIR from delphi.train.config import TrainingConfig -from delphi.train.config.utils import load_preset +from delphi.train.config.utils import build_config_from_files_and_overrides from delphi.train.train_step import accumulate_gradients, train_step from delphi.train.utils import ( ModelTrainingState, @@ -17,6 +18,12 @@ ) +def load_test_config(preset_name: str) -> TrainingConfig: + """Load a test config by name, e.g. `load_preset("debug")`.""" + preset_path = TEST_CONFIGS_DIR / f"{preset_name}.json" + return build_config_from_files_and_overrides([preset_path], {}) + + @pytest.fixture def dataset(): ds = Dataset.from_dict( @@ -71,7 +78,9 @@ def test_basic_reproducibility(dataset, model): ) device = torch.device("cpu") indices = list(range(len(dataset))) - train_step(model_training_state, dataset, load_preset("debug"), device, indices) + train_step( + model_training_state, dataset, load_test_config("debug"), device, indices + ) params = get_params(model) @@ -230,7 +239,7 @@ def test_train_step_no_training(dataset, model): Test train_step when no_training is set to True """ # setup - config_dict = asdict(load_preset("debug")) + config_dict = asdict(load_test_config("debug")) config_dict["debug_config"] = {"no_training": True} config = dacite.from_dict(TrainingConfig, config_dict) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -252,7 +261,7 @@ def test_train_step_with_training(dataset, model): Test train_step when training is performed """ # setup - config_dict = asdict(load_preset("debug")) + config_dict = asdict(load_test_config("debug")) config_dict["debug_config"] = {"no_training": False} config_dict["batch_size"] = 16 config_dict["optimizer"] = {"gradient_accumulation_steps": 4} diff --git a/tests/train/test_wandb_utils.py b/tests/train/test_wandb_utils.py index c4d60725..4ca89670 100644 --- a/tests/train/test_wandb_utils.py +++ b/tests/train/test_wandb_utils.py @@ -7,37 +7,24 @@ import transformers from dacite import from_dict +from delphi.constants import TEST_CONFIGS_DIR from delphi.train.config import TrainingConfig +from delphi.train.config.utils import build_config_from_files_and_overrides from delphi.train.utils import ModelTrainingState, initialize_model_training_state from delphi.train.wandb_utils import init_wandb, log_to_wandb, silence_wandb @pytest.fixture -def mock_training_config(): - config = from_dict( - TrainingConfig, - { - "run_name": "test_run", - "device": "cpu", - "model_config": { - "model_type": "LlamaForCausalLM", - "model_params": { - "hidden_size": 48, - "intermediate_size": 48, - "num_attention_heads": 2, - "num_hidden_layers": 2, - "num_key_value_heads": 2, - "vocab_size": 4096, - }, - }, - "wandb_config": { - "log": True, - "entity": "test_entity", - "project": "test_project", - }, +def mock_training_config() -> TrainingConfig: + preset_path = TEST_CONFIGS_DIR / "debug.json" + overrides = { + "run_name": "test_run", + "wandb": { + "entity": "test_entity", + "project": "test_project", }, - ) - return config + } + return build_config_from_files_and_overrides([preset_path], overrides) @pytest.fixture