Training script #31

Merged Mar 19, 2024 · 150 commits
Changes from 1 commit

Commits (150)
5aa939a
add llama2.c submodule
jannik-brinkmann Feb 4, 2024
27e7df4
rename submodule to avoid import errors
jannik-brinkmann Feb 4, 2024
f9dbff2
add llama2.c wrapper
jannik-brinkmann Feb 4, 2024
1b394e0
draft training.py
jannik-brinkmann Feb 5, 2024
9fba402
updated draft
jannik-brinkmann Feb 11, 2024
e11de88
Adding a Mamba Class
SrGonao Feb 13, 2024
6287121
Moving stuff to the correct place
SrGonao Feb 13, 2024
e42a8db
Not ready but idea there
SrGonao Feb 13, 2024
6642256
updated training script
jannik-brinkmann Feb 24, 2024
b9cdc78
formatting
jettjaniak Feb 26, 2024
294e792
remove gitmodules
jettjaniak Feb 26, 2024
c95c28b
moved llama2c submodule
jettjaniak Feb 26, 2024
27891b5
llama2c update
jettjaniak Feb 26, 2024
cd2c5f7
fix import
jettjaniak Feb 26, 2024
145f8aa
remove unused files
jettjaniak Feb 26, 2024
fe834b0
rename training_old -> training
jettjaniak Feb 26, 2024
f7eacd7
Moved Mamba
SrGonao Mar 1, 2024
a88be1f
Added type hinting
SrGonao Mar 1, 2024
0d356af
Removed not needed file
SrGonao Mar 1, 2024
646b3e7
Removed compile, amp train fp32
SrGonao Mar 1, 2024
b04975a
fixing black and isort
SrGonao Mar 2, 2024
2b056ca
add submodules to checkout in CI
jettjaniak Mar 2, 2024
23f8a55
pyproject.toml, moved isort cfg, excl. llama2c
jettjaniak Mar 3, 2024
f7cc6b7
isort: llama2c known_third_party
jettjaniak Mar 3, 2024
e863604
limit pytest to tests/ directory
jettjaniak Mar 3, 2024
a00425b
Training_script_refactor (#54)
jaidhyani Mar 8, 2024
2aa2ea6
It's actually a script now
jaidhyani Mar 8, 2024
ce4a6ca
lol copypasting
jaidhyani Mar 8, 2024
83496e2
cleanup
jaidhyani Mar 8, 2024
30a1b30
Adding support for config files
jaidhyani Mar 8, 2024
4098036
comments
jaidhyani Mar 8, 2024
96f2361
flag arguments take priority over config file values
jaidhyani Mar 8, 2024
584c55d
comments
jaidhyani Mar 8, 2024
d79d50c
gitignore .DS_Store file on macos
jaidhyani Mar 8, 2024
9e1e9d8
remove training.sh
jaidhyani Mar 8, 2024
27f5d43
meeting notes and tweaks
jaidhyani Mar 8, 2024
e542237
configurable device
jaidhyani Mar 8, 2024
0524030
Adding mamba implementation
SrGonao Mar 8, 2024
29e986e
mamba hacks, please forgive me
jaidhyani Mar 8, 2024
d8831bd
experimenting with cuda support in gh actions
jaidhyani Mar 8, 2024
95d534d
welp, that didn't work
jaidhyani Mar 8, 2024
8716bf7
remove tokenized_chunks_dataset
jaidhyani Mar 8, 2024
7352eec
separate batch ordering and torch seeds
jaidhyani Mar 8, 2024
c3e5ef7
remove mamba.py
jaidhyani Mar 8, 2024
7cb4ca7
refactoring
jaidhyani Mar 8, 2024
7213aa4
rm TODO
jaidhyani Mar 8, 2024
a8f7143
refactoring
jaidhyani Mar 8, 2024
e038f31
bughunt
jaidhyani Mar 8, 2024
59fce94
debugger config
jaidhyani Mar 9, 2024
734b92e
typing improvements and bugfixes
jaidhyani Mar 9, 2024
2ed386c
add support for "x.y.z = val" style config
jaidhyani Mar 9, 2024
c928d3b
first steps towards Llama2HF support
jaidhyani Mar 9, 2024
54b095a
more debugging stuff
jaidhyani Mar 9, 2024
fee0497
initial HF llama2 support
jaidhyani Mar 9, 2024
2665633
debug more
jaidhyani Mar 10, 2024
2f65438
Add support for preset model configs in script, specifying multiple c…
jaidhyani Mar 10, 2024
4c64774
bughunt
jaidhyani Mar 10, 2024
a9ac3dd
fix beartype Callalble deprecation warning
jaidhyani Mar 10, 2024
3d7711a
rm llamaconfig json accidentally added before
jaidhyani Mar 10, 2024
c4e69d2
asdf
jaidhyani Mar 10, 2024
656228c
script tweaks
jaidhyani Mar 10, 2024
27fdc79
better gigaconfig defaults
jaidhyani Mar 10, 2024
398f1de
debug config is now just another preset; better documentation for sub…
jaidhyani Mar 10, 2024
366b4b5
fix imports
jaidhyani Mar 10, 2024
d4a81e8
remove upload_tokens
jaidhyani Mar 10, 2024
5dc23e6
Whoops. I should probably test things more before pushing them.
jaidhyani Mar 10, 2024
2f1a0a4
cleanup
jaidhyani Mar 10, 2024
1f37228
script tweaks
jaidhyani Mar 10, 2024
adfd4b4
added support for prioritizing configs
jaidhyani Mar 10, 2024
e3b326c
refactoring (config_utils) to support notebook use
jaidhyani Mar 10, 2024
551a8de
fix Llama2ConfigData bug in gigaconfig (use default_factory)
jaidhyani Mar 11, 2024
cd9a5b1
make run_training return ModelTrainingState
jaidhyani Mar 11, 2024
ab19879
more config_utils
jaidhyani Mar 11, 2024
bc8b43d
cleanup run_training script
jaidhyani Mar 11, 2024
859ae09
training_demo notebook (for colab)
jaidhyani Mar 11, 2024
fc3f021
static files tweak
jaidhyani Mar 11, 2024
70e82ee
estimate_mfu for llama2hf
jaidhyani Mar 11, 2024
697f729
Don't break if model export not available
jaidhyani Mar 11, 2024
a8f7a4f
100k quick config
jaidhyani Mar 12, 2024
997ec3a
torch.use_deterministic_algorithms for training
jaidhyani Mar 12, 2024
f9bd899
import Callable from collections.abc
jaidhyani Mar 12, 2024
89cee7c
Move up torch.manual_seed before calling anything in torch
jaidhyani Mar 12, 2024
698365f
add wandb to requirements
jaidhyani Mar 12, 2024
7f6c180
factor out training config package + wandb_config
jaidhyani Mar 12, 2024
662555d
unused import
jaidhyani Mar 12, 2024
0699026
isort
jaidhyani Mar 12, 2024
4007b6a
initial mamba support
jaidhyani Mar 12, 2024
594033e
pip install wheel
jaidhyani Mar 12, 2024
cc010da
pip install packaging
jaidhyani Mar 12, 2024
bfb28c1
come on, mamba_ssm, get it together
jaidhyani Mar 12, 2024
f85d015
requirements-nocuda.txt for gh actions
jaidhyani Mar 12, 2024
3c0010b
Merge branch 'main' into training-script
jaidhyani Mar 13, 2024
4870d92
mv ModelTypes to constants
jaidhyani Mar 14, 2024
9eeb960
deprecate llama2c support
jaidhyani Mar 14, 2024
02ec1a1
clear out more llama2c stuff
jaidhyani Mar 14, 2024
6cb9d52
we still need max_seq_len
jaidhyani Mar 14, 2024
35cb7c4
factoring out optimizer params from config
jaidhyani Mar 14, 2024
5e10db2
fix broken test
jaidhyani Mar 15, 2024
af6c0db
model_args overhaul
jaidhyani Mar 15, 2024
2a9d2c2
rm llama2c
jaidhyani Mar 15, 2024
1225d93
replace DataLoader
jaidhyani Mar 15, 2024
a9a791b
run_dir to gigaconfig; output_run_dir; fix Generator type warning
jaidhyani Mar 15, 2024
10b1a36
save results when training is done
jaidhyani Mar 15, 2024
fbfeaa5
save step in results
jaidhyani Mar 15, 2024
5a54078
include architecture and priority in llama preset configs
jaidhyani Mar 15, 2024
19e779f
Merge branch 'training-script' into mamba_dev
jaidhyani Mar 15, 2024
a98aa42
update training demo
jaidhyani Mar 15, 2024
53a6adf
mamba expectedly imports correctly
jaidhyani Mar 15, 2024
0d69ede
rm export_model
jaidhyani Mar 15, 2024
e24fec4
estimate_loss no longer depends on architecture
jaidhyani Mar 15, 2024
b9f682c
add combine_configs (working towards frozen config for type safety)
jaidhyani Mar 15, 2024
13db6eb
renaming/simplification
jaidhyani Mar 15, 2024
15983a3
model_config refactor to approach type safety + frozen dataclasses
jaidhyani Mar 15, 2024
ac8fb6a
rm architectures.py
jaidhyani Mar 15, 2024
43dd1b2
new config system with type safety!
jaidhyani Mar 16, 2024
6a21f2c
Support for optional config types (mamba and llama)
jaidhyani Mar 16, 2024
be2354f
fix sample configs
jaidhyani Mar 16, 2024
c128f96
remove some unused model config args
jaidhyani Mar 16, 2024
28076ed
remove unused mamba.py
jaidhyani Mar 16, 2024
384b140
I thought I already deleted this?
jaidhyani Mar 16, 2024
5f64802
rename to "initialize_model_training_state"
jaidhyani Mar 16, 2024
e4ecd21
Support for mandatory fields in run_training
jaidhyani Mar 17, 2024
a1a150d
ModelTypes
jaidhyani Mar 17, 2024
ec04579
output_dir is output_dir
jaidhyani Mar 17, 2024
fb0c509
cleaner imports
jaidhyani Mar 17, 2024
129587d
error if ModelConfig doesn't include config for chosen model type
jaidhyani Mar 17, 2024
ab208aa
no-op renames for clarity
jaidhyani Mar 17, 2024
076cd8b
log levels
jaidhyani Mar 17, 2024
be352fe
shebang & chmod +x on run_training.py
jettjaniak Mar 18, 2024
10d718e
renamed corpus dataset
jettjaniak Mar 18, 2024
77c15ef
removed llama2c references from pyproject.toml
jettjaniak Mar 18, 2024
a8d03a8
removed .gitmodules
jettjaniak Mar 18, 2024
2d86063
removed scripts/upload_stories.py
jettjaniak Mar 18, 2024
bed480c
test wandb_utils
jaidhyani Mar 17, 2024
f168310
no llama2c, no .view, no need for enforcing contigious tensors
jaidhyani Mar 18, 2024
6a67a02
Fix _unoptionalize
jaidhyani Mar 18, 2024
52745bc
run_training.py --help when no args
jaidhyani Mar 18, 2024
9e759a3
script improvements: model-specific args moved to their own --help; f…
jaidhyani Mar 19, 2024
5dfdb94
rename llama to llama2
jaidhyani Mar 19, 2024
12dc6c4
unused imports
jaidhyani Mar 19, 2024
a81f9e5
set run name from config file
jaidhyani Mar 19, 2024
63cb45b
set default output_dir based on run_name
jaidhyani Mar 19, 2024
be3069d
remove in-progress testing file added by mistake
jaidhyani Mar 19, 2024
d15f467
add huggingface config
jaidhyani Mar 19, 2024
9c27671
fix config json that got broken somehow
jaidhyani Mar 19, 2024
c259e97
save/load fix + huggingface uploading
jaidhyani Mar 19, 2024
0a04b99
fix test that broken when renaming llama to llama2
jaidhyani Mar 19, 2024
083cb1b
unused import
jaidhyani Mar 19, 2024
580d3c6
fix validation sampling
jaidhyani Mar 19, 2024
14dc55e
remove eval_only
jaidhyani Mar 19, 2024
fix validation sampling
jaidhyani committed Mar 19, 2024

commit 580d3c60b18a133bc2d99fe5e5e83448b8ea7092
1 change: 1 addition & 0 deletions src/delphi/train/train_step.py
@@ -51,6 +51,7 @@ def train_step(
         batch_size=config.batch_size,
         split_to_ds={"train": train_ds, "val": validation_ds},
         device=run_context.device,
+        epoch=model_training_state.epoch,
     )
     new_best_val_loss = False
     if losses["val"] < model_training_state.best_val_loss:
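
For orientation, here is a minimal sketch of the two fields of ModelTrainingState that this hunk relies on. The names epoch and best_val_loss come from the diff itself; the defaults and everything else about the class are assumptions, since the real class in delphi.train tracks more state (model, optimizer, step, ...):

from dataclasses import dataclass

@dataclass
class ModelTrainingState:
    """Hypothetical sketch: only the fields the hunk above touches."""
    epoch: int = 0                        # current pass over the training split
    best_val_loss: float = float("inf")   # lowest validation loss seen so far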
4 changes: 2 additions & 2 deletions src/delphi/train/utils.py
@@ -249,13 +249,13 @@ def estimate_loss(
     batch_size: int,
     split_to_ds: dict[str, Dataset],
     device: torch.device,
+    epoch: int,
 ) -> dict[str, float]:
     """helps estimate an arbitrarily accurate loss over either split using many batches"""
     out = {}
     model.eval()
     for split, ds in split_to_ds.items():
-        # TODO: actually sample from val!!!!!
-        batch_iter = iter(batch_generator(ds, batch_size, 0, 0))
+        batch_iter = iter(batch_generator(ds, batch_size, epoch, 1234))
         losses = torch.zeros(eval_iters)  # keep on CPU
         for k in range(min(eval_iters, len(ds) // batch_size)):  # type: ignore
             X, Y = get_next_xy(batch_iter, device)
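
The fix is easier to see against a sketch of what an epoch-seeded batch generator can look like. This is a hypothetical stand-in, not delphi's actual batch_generator; only the call shape batch_generator(ds, batch_size, epoch, seed) is taken from the diff. Before the change, estimate_loss always called it with epoch=0 and seed=0, so every evaluation walked the same leading validation batches; passing the current epoch (with a fixed ordering seed of 1234) lets each epoch's evaluation sample a different, but still reproducible, permutation of the validation set.

from collections.abc import Iterator
from typing import Sequence

import torch


def batch_generator(
    ds: Sequence, batch_size: int, epoch: int, ordering_seed: int
) -> Iterator[list]:
    """Hypothetical sketch: yield batches in an order determined by (ordering_seed, epoch)."""
    # Seed a dedicated generator so the permutation is reproducible but changes each epoch.
    rng = torch.Generator().manual_seed(ordering_seed + epoch)
    perm = torch.randperm(len(ds), generator=rng).tolist()
    for start in range(0, len(perm), batch_size):
        yield [ds[i] for i in perm[start : start + batch_size]]

Under this sketch, batch_generator(val_ds, batch_size, epoch=0, ordering_seed=1234) and the same call with epoch=1 produce different batch orders, which is exactly the behavior the estimate_loss change above relies on.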