-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Training script #31
Training script #31
Changes from 125 commits
5aa939a
27e7df4
f9dbff2
1b394e0
9fba402
e11de88
6287121
e42a8db
6642256
b9cdc78
294e792
c95c28b
27891b5
cd2c5f7
145f8aa
fe834b0
f7eacd7
a88be1f
0d356af
646b3e7
b04975a
2b056ca
23f8a55
f7cc6b7
e863604
a00425b
2aa2ea6
ce4a6ca
83496e2
30a1b30
4098036
96f2361
584c55d
d79d50c
9e1e9d8
27f5d43
e542237
0524030
29e986e
d8831bd
95d534d
8716bf7
7352eec
c3e5ef7
7cb4ca7
7213aa4
a8f7143
e038f31
59fce94
734b92e
2ed386c
c928d3b
54b095a
fee0497
2665633
2f65438
4c64774
a9ac3dd
3d7711a
c4e69d2
656228c
27fdc79
398f1de
366b4b5
d4a81e8
5dc23e6
2f1a0a4
1f37228
adfd4b4
e3b326c
551a8de
cd9a5b1
ab19879
bc8b43d
859ae09
fc3f021
70e82ee
697f729
a8f7a4f
997ec3a
f9bd899
89cee7c
698365f
7f6c180
662555d
0699026
4007b6a
594033e
cc010da
bfb28c1
f85d015
3c0010b
4870d92
9eeb960
02ec1a1
6cb9d52
35cb7c4
5e10db2
af6c0db
2a9d2c2
1225d93
a9a791b
10b1a36
fbfeaa5
5a54078
19e779f
a98aa42
53a6adf
0d69ede
e24fec4
b9f682c
13db6eb
15983a3
ac8fb6a
43dd1b2
6a21f2c
be2354f
c128f96
28076ed
384b140
5f64802
e4ecd21
a1a150d
ec04579
fb0c509
129587d
ab208aa
076cd8b
be352fe
10d718e
77c15ef
a8d03a8
2d86063
bed480c
f168310
6a67a02
52745bc
9e759a3
5dfdb94
12dc6c4
a81f9e5
63cb45b
be3069d
d15f467
9c27671
c259e97
0a04b99
083cb1b
580d3c6
14dc55e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ jobs: | |
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
submodules: recursive | ||
- name: setup python | ||
uses: actions/setup-python@v5 | ||
with: | ||
|
@@ -31,11 +33,11 @@ jobs: | |
- name: dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r requirements.txt | ||
pip install -r requirements-nocuda.txt | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mamba really wants to run on CUDA, and there are optional-but-preferred dependencies for doing so. We want to use that when we can, because otherwise training mamba models would take foreverrrrrrrrr - but only on platforms that support it. We split out requirements into those that don't require cuda (-nocuda) and those that do (still in requirements.txt, which automatically includes -nocuda requirements). Github CI doesn't run in a CUDA env because why would Microsoft give away that much GPU compute? So we use the non-cuda requirements.txt for CI. |
||
pip install -e . | ||
- name: black | ||
run: black --check . | ||
- name: isort | ||
run: isort --profile black --check . | ||
run: isort --check . | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. --black config is now implicit in pyproject.toml |
||
- name: pytest | ||
run: pytest |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,11 @@ __pycache__/ | |
# C extensions | ||
*.so | ||
|
||
bin | ||
include | ||
lib64 | ||
pyvenv.cfg | ||
Comment on lines
+9
to
+12
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these are llama2c artifacts? Also no technically longer needed, but on the other hand these are generally things we'd want to exclude from git if anything with these names ever showed up. |
||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
|
@@ -158,3 +163,15 @@ cython_debug/ | |
# and can be added to the global gitignore or merged into this file. For a more nuclear | ||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
#.idea/ | ||
|
||
# ignore wandb files | ||
**/wandb/* | ||
**/*.wandb | ||
**/wandb-summary.json | ||
**/wandb-metadata.json | ||
Comment on lines
+167
to
+171
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Debugging wandb integration involved a lot of wandb artifacts being created and I was too lazy to change directories. |
||
|
||
# scratch notebook | ||
notebooks/scratch.ipynb | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like having a notebook in my environment for messy testing, ad hoc experiments, and running snippets. It's (by design) horrible ugly nastiness that should never see the light of day. |
||
|
||
# dsstore | ||
.DS_Store | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MacOS seems to make this sometimes for reasons that are probably my fault but I haven't actually entirely figured out. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,5 +8,3 @@ repos: | |
rev: 5.13.2 | ||
hooks: | ||
- id: isort | ||
name: isort (python) | ||
args: ["--profile", "black"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
{ | ||
// Use IntelliSense to learn about possible attributes. | ||
// Hover to view descriptions of existing attributes. | ||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | ||
"version": "0.2.0", | ||
"configurations": [ | ||
{ | ||
"name": "run_training 256", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"program": "scripts/run_training.py", | ||
"console": "integratedTerminal", | ||
"args": "--debug --train_sample_limit=256" | ||
//"args": "${command:pickArgs}" | ||
}, | ||
{ | ||
"name": "run_training --help", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"program": "scripts/run_training.py", | ||
"console": "integratedTerminal", | ||
"args": "--help" | ||
//"args": "${command:pickArgs}" | ||
}, | ||
{ | ||
"name": "run training with debug plus custom args", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"program": "scripts/run_training.py", | ||
"console": "integratedTerminal", | ||
"args": "--debug ${command:pickArgs}" | ||
} | ||
] | ||
} | ||
Comment on lines
+1
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are configurations for vscode's integrated debugging tool |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,9 +7,5 @@ | |
"source.organizeImports": "explicit" | ||
}, | ||
"python.analysis.typeCheckingMode": "basic", | ||
"isort.args": [ | ||
"--profile", | ||
"black" | ||
], | ||
Comment on lines
9
to
-13
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. config consolidated and moved to pyproject.toml |
||
"black-formatter.importStrategy": "fromEnvironment", | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,3 +49,6 @@ When you save a file vscode should automatically format it. Otherwise, pre-commi | |
- comment important sections of the code in _Files changed_ tab | ||
- when it's ready, add the relevant stakeholders as reviewers | ||
4. after the comments are resolved and PR is approved, merge it using _Squash and merge_ | ||
|
||
# Incrementing Versions | ||
When making a new release, increment the version in `delphi/__init__.py` | ||
Comment on lines
+52
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The canonical package version is now stored in |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a notebook to demo a basic training run. I'm also noticing that it's actually broken at the moment, I should fix this. |
||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from delphi.train.config.utils import get_presets_by_name\n", | ||
"from delphi.train.training import run_training\n", | ||
"from delphi.train.utils import ModelTrainingState\n", | ||
"from delphi.train.run_context import RunContext\n", | ||
"\n", | ||
"\n", | ||
"def train() -> tuple[ModelTrainingState, RunContext]:\n", | ||
" config = get_presets_by_name()[\"v0-llama2-100k\"]\n", | ||
" config.wandb_config.entity = \"jaiwithani\"\n", | ||
" return run_training(config)\n", | ||
"\n", | ||
"model_train_result = train()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "tinyevals", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[project] | ||
name = "delphi" | ||
dynamic = ["version"] | ||
|
||
[tool.setuptools.dynamic] | ||
version = {attr = "delphi.__version__"} | ||
Comment on lines
+1
to
+6
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This handles automatically propagating the version from init.py to setup.py |
||
|
||
|
||
[tool.black] | ||
extend-exclude = 'src/llama2c' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When we were using llama2c we wanted to avoid tocuhing it as much as possible, including any formatting changes. We can probably remove this? |
||
|
||
[tool.isort] | ||
profile = 'black' | ||
known_third_party = ['llama2c', 'wandb'] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isort will assume that llama2c and wandb are first-party packages if directories with those names are present at the top level (alongside delphi). This tells isort that, no, llama2c and wandb should be treated as third-party packages. This otherwise causes a conflict during CI, when llama2c and wandb dirs are not present, which makes the githubCI isort complain that third-party imports are being formatted as first-party imports. |
||
extend_skip = ['src/llama2c'] | ||
|
||
[tool.pytest.ini_options] | ||
testpaths = ["tests"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Woo consolidated configuration courtesy of @jettjaniak |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# this is a separate requirements.txt file for use in github actions | ||
# this omits packages that cannot be installed in github actions due | ||
# to hardware limitations (e.g. no GPU). All packages here are automatically | ||
# included when installing from requirements.txt | ||
torch==2.1.2 | ||
datasets==2.16.1 | ||
tqdm==4.66.1 | ||
ipywidgets==8.1.1 | ||
nbformat==5.9.2 | ||
pytest==7.4.4 | ||
black==23.12.1 | ||
jaxtyping==0.2.25 | ||
beartype==0.16.4 | ||
pre-commit==3.6.0 | ||
isort==5.13.2 | ||
chardet==5.2.0 | ||
sentencepiece==0.1.99 | ||
protobuf==4.25.2 | ||
plotly==5.18.0 | ||
wandb==0.16.3 | ||
spacy==3.7.2 | ||
pandas==1.3.4 | ||
dacite==1.8.1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dacite adds a |
||
|
||
# temporarily installing transformers from main until 4.39.0 comes out (for mamba support) | ||
transformers @ git+https://github.com/huggingface/transformers@main | ||
# transformers==4.39.0 TODO: use this once 4.39.0 releases | ||
|
||
# spacy-transformers requires transformers <= 4.37.0, temporarily disabling | ||
# spacy-transformers>=1.3.4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,8 @@ | ||
torch==2.1.2 | ||
datasets==2.16.1 | ||
transformers==4.36.2 | ||
tqdm==4.66.1 | ||
ipywidgets==8.1.1 | ||
nbformat==5.9.2 | ||
pytest==7.4.4 | ||
black==23.12.1 | ||
jaxtyping==0.2.25 | ||
beartype==0.16.4 | ||
pre-commit==3.6.0 | ||
isort==5.13.2 | ||
spacy==3.7.2 | ||
chardet==5.2.0 | ||
sentencepiece==0.1.99 | ||
protobuf==4.25.2 | ||
plotly==5.18.0 | ||
spacy-transformers==1.3.4 | ||
pandas==1.3.4 | ||
# most packages are specified in requirements-gh.txt, and new packages should be placed | ||
# there UNLESS they cannot be installed without CUDA support, in which case they should go here. | ||
-r requirements-nocuda.txt | ||
|
||
# these libs support better mamba implementations in transformers, | ||
# but require CUDA/nvcc, so they won't work on MacOS. | ||
mamba_ssm==1.2.0.post1; sys_platform != 'darwin' | ||
causal-conv1d==1.2.0.post2; sys_platform != 'darwin' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import argparse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the script for training delphi models from the command-line. |
||
import logging | ||
import os | ||
from dataclasses import fields, is_dataclass | ||
from itertools import chain | ||
from pathlib import Path | ||
from typing import Any, Optional, _GenericAlias # type: ignore | ||
|
||
from delphi.constants import CONFIG_PRESETS_DIR | ||
from delphi.train.config import ( | ||
GigaConfig, | ||
build_config_from_files_and_overrides, | ||
get_preset_paths, | ||
get_user_config_path, | ||
) | ||
from delphi.train.training import run_training | ||
from delphi.train.utils import save_results | ||
|
||
|
||
def _unoptionalize(t: type) -> type: | ||
if isinstance(t, _GenericAlias): | ||
return t.__args__[0] | ||
else: | ||
return t | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do a lot of programmatic argument construction by inspecting config dataclasses. We want to expose config parameters for fields that are |
||
|
||
|
||
def get_preset_args(args: argparse.Namespace) -> list[Path]: | ||
cands = [] | ||
for preset in get_preset_paths(): | ||
if hasattr(args, preset.stem) and getattr(args, preset.stem): | ||
cands.append(preset) | ||
return cands | ||
Comment on lines
+37
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This gets all user arguments that correspond to config presets (e.g. |
||
|
||
|
||
def get_config_files(args: argparse.Namespace) -> list[Path]: | ||
user_config_path = get_user_config_path() | ||
cands = [user_config_path] if user_config_path.exists() else [] | ||
cands += get_preset_args(args) | ||
config_files = list(chain(*args.config_file)) if args.config_file else [] | ||
cands += map(Path, config_files) | ||
configs = [] | ||
for candpath in cands: | ||
if candpath.exists(): | ||
configs.append(candpath) | ||
logging.info(f"Found config file {candpath}...") | ||
else: | ||
raise FileNotFoundError(candpath, f"Config file {candpath} does not exist.") | ||
return configs | ||
Comment on lines
+45
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks for three different kind of preset jsons:
|
||
|
||
|
||
def add_dataclass_args_recursively( | ||
parser: argparse.ArgumentParser, | ||
dc: type[object], | ||
default_group: argparse._ArgumentGroup, | ||
group: Optional[argparse._ArgumentGroup] = None, | ||
prefix: str = "", | ||
): | ||
for field in fields(dc): # type: ignore | ||
# if field is an Optional type, strip it to the actual underlying type | ||
_type = _unoptionalize(field.type) | ||
if is_dataclass(_type): | ||
_group = group or parser.add_argument_group(f"{field.name}") | ||
add_dataclass_args_recursively( | ||
parser, | ||
_type, | ||
default_group, | ||
_group, | ||
prefix=f"{prefix}{field.name}.", | ||
) | ||
else: | ||
_group = group or default_group | ||
_group.add_argument( | ||
f"--{prefix}{field.name}", | ||
type=_type, | ||
required=False, | ||
help=f"Default: {field.default}" | ||
if field.default != field.default_factory | ||
else f"Must be specified as part of {_group.title}", | ||
) | ||
|
||
|
||
def setup_parser() -> argparse.ArgumentParser: | ||
# Setup argparse | ||
parser = argparse.ArgumentParser(description="Train a delphi model") | ||
parser.add_argument( | ||
"--config_file", | ||
help=( | ||
"Path to json file(s) containing config values. Specific values can be overridden with --arguments. " | ||
"e.g. `--config_file primary_config.json secondary_config.json --log_interval 42`. " | ||
'If passing multiple configs with overlapping args, use "priority" key to specify precedence, e.g. {"priority": 100} ' | ||
f'overrides {{"priority": 99}} See preset configs in {CONFIG_PRESETS_DIR}' | ||
), | ||
action="append", | ||
nargs="*", | ||
required=False, | ||
type=str, | ||
) | ||
config_arg_group = parser.add_argument_group("Config arguments") | ||
add_dataclass_args_recursively(parser, GigaConfig, config_arg_group) | ||
preset_arg_group = parser.add_argument_group("Preset configs") | ||
for preset in sorted(get_preset_paths()): | ||
preset_arg_group.add_argument( | ||
f"--{preset.stem}", | ||
help=f"Use {preset.stem} preset config", | ||
action="store_true", | ||
) | ||
return parser | ||
|
||
|
||
def var_args_to_dict(config_vars: dict[str, Any]) -> dict[str, Any]: | ||
# {"a.b.c" = 4} to {"a": {"b": {"c": 4}}} | ||
d = {} | ||
for k, v in config_vars.items(): | ||
if v is None: | ||
continue | ||
cur = d | ||
subkeys = k.split(".") | ||
for subkey in subkeys[:-1]: | ||
if subkey not in cur: | ||
cur[subkey] = {} | ||
cur = cur[subkey] | ||
cur[subkeys[-1]] = v | ||
return d | ||
Comment on lines
+209
to
+222
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We let users specified nested arguments on the commandline, this converts them a nested dictionary that will be used to build the actual config object. |
||
|
||
|
||
def args_to_dict(args: argparse.Namespace) -> dict[str, Any]: | ||
# at the toplevel, filter for args corresponding to field names in GigaConfig | ||
field_names = set(field.name for field in fields(GigaConfig)) | ||
config_vars = { | ||
k: v for k, v in vars(args).items() if k.split(".")[0] in field_names | ||
} | ||
return var_args_to_dict(config_vars) | ||
Comment on lines
+225
to
+231
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This just wraps |
||
|
||
|
||
def main(): | ||
parser = setup_parser() | ||
args = parser.parse_args() | ||
|
||
config_files = get_config_files(args) | ||
args_dict = args_to_dict(args) | ||
config = build_config_from_files_and_overrides(config_files, args_dict) | ||
|
||
# run training | ||
results, run_context = run_training(config) | ||
final_out_dir = os.path.join(config.output_dir, "final") | ||
save_results(config, results, run_context, final_out_dir) | ||
print(f"Saved results to {final_out_dir}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We added this when we added the llama2c submodule. Ironically, we also removed llama2c in the course of developing this PR. Technically we don't need this anymore, but it's not a bad idea to have it around for any submodules we add in the future.