Training script #31

Merged
merged 150 commits on Mar 19, 2024
Changes from 118 commits
Commits
150 commits
5aa939a
add llama2.c submodule
jannik-brinkmann Feb 4, 2024
27e7df4
rename submodule to avoid import errors
jannik-brinkmann Feb 4, 2024
f9dbff2
add llama2.c wrapper
jannik-brinkmann Feb 4, 2024
1b394e0
draft training.py
jannik-brinkmann Feb 5, 2024
9fba402
updated draft
jannik-brinkmann Feb 11, 2024
e11de88
Adding a Mamba Class
SrGonao Feb 13, 2024
6287121
Moving stuff to the correct place
SrGonao Feb 13, 2024
e42a8db
Not ready but idea there
SrGonao Feb 13, 2024
6642256
updated training script
jannik-brinkmann Feb 24, 2024
b9cdc78
formatting
jettjaniak Feb 26, 2024
294e792
remove gitmodules
jettjaniak Feb 26, 2024
c95c28b
moved llama2c submodule
jettjaniak Feb 26, 2024
27891b5
llama2c update
jettjaniak Feb 26, 2024
cd2c5f7
fix import
jettjaniak Feb 26, 2024
145f8aa
remove unused files
jettjaniak Feb 26, 2024
fe834b0
rename training_old -> training
jettjaniak Feb 26, 2024
f7eacd7
Moved Mamba
SrGonao Mar 1, 2024
a88be1f
Added type hinting
SrGonao Mar 1, 2024
0d356af
Removed not needed file
SrGonao Mar 1, 2024
646b3e7
Removed compile, amp train fp32
SrGonao Mar 1, 2024
b04975a
fixing black and isort
SrGonao Mar 2, 2024
2b056ca
add submodules to checkout in CI
jettjaniak Mar 2, 2024
23f8a55
pyproject.toml, moved isort cfg, excl. llama2c
jettjaniak Mar 3, 2024
f7cc6b7
isort: llama2c known_third_party
jettjaniak Mar 3, 2024
e863604
limit pytest to tests/ directory
jettjaniak Mar 3, 2024
a00425b
Training_script_refactor (#54)
jaidhyani Mar 8, 2024
2aa2ea6
It's actually a script now
jaidhyani Mar 8, 2024
ce4a6ca
lol copypasting
jaidhyani Mar 8, 2024
83496e2
cleanup
jaidhyani Mar 8, 2024
30a1b30
Adding support for config files
jaidhyani Mar 8, 2024
4098036
comments
jaidhyani Mar 8, 2024
96f2361
flag arguments take priority over config file values
jaidhyani Mar 8, 2024
584c55d
comments
jaidhyani Mar 8, 2024
d79d50c
gitignore .DS_Store file on macos
jaidhyani Mar 8, 2024
9e1e9d8
remove training.sh
jaidhyani Mar 8, 2024
27f5d43
meeting notes and tweaks
jaidhyani Mar 8, 2024
e542237
configurable device
jaidhyani Mar 8, 2024
0524030
Adding mamba implementation
SrGonao Mar 8, 2024
29e986e
mamba hacks, please forgive me
jaidhyani Mar 8, 2024
d8831bd
experimenting with cuda support in gh actions
jaidhyani Mar 8, 2024
95d534d
welp, that didn't work
jaidhyani Mar 8, 2024
8716bf7
remove tokenized_chunks_dataset
jaidhyani Mar 8, 2024
7352eec
separate batch ordering and torch seeds
jaidhyani Mar 8, 2024
c3e5ef7
remove mamba.py
jaidhyani Mar 8, 2024
7cb4ca7
refactoring
jaidhyani Mar 8, 2024
7213aa4
rm TODO
jaidhyani Mar 8, 2024
a8f7143
refactoring
jaidhyani Mar 8, 2024
e038f31
bughunt
jaidhyani Mar 8, 2024
59fce94
debugger config
jaidhyani Mar 9, 2024
734b92e
typing improvements and bugfixes
jaidhyani Mar 9, 2024
2ed386c
add support for "x.y.z = val" style config
jaidhyani Mar 9, 2024
c928d3b
first steps towards Llama2HF support
jaidhyani Mar 9, 2024
54b095a
more debugging stuff
jaidhyani Mar 9, 2024
fee0497
initial HF llama2 support
jaidhyani Mar 9, 2024
2665633
debug more
jaidhyani Mar 10, 2024
2f65438
Add support for preset model configs in script, specifying multiple c…
jaidhyani Mar 10, 2024
4c64774
bughunt
jaidhyani Mar 10, 2024
a9ac3dd
fix beartype Callalble deprecation warning
jaidhyani Mar 10, 2024
3d7711a
rm llamaconfig json accidentally added before
jaidhyani Mar 10, 2024
c4e69d2
asdf
jaidhyani Mar 10, 2024
656228c
script tweaks
jaidhyani Mar 10, 2024
27fdc79
better gigaconfig defaults
jaidhyani Mar 10, 2024
398f1de
debug config is now just another preset; better documentation for sub…
jaidhyani Mar 10, 2024
366b4b5
fix imports
jaidhyani Mar 10, 2024
d4a81e8
remove upload_tokens
jaidhyani Mar 10, 2024
5dc23e6
Whoops. I should probably test things more before pushing them.
jaidhyani Mar 10, 2024
2f1a0a4
cleanup
jaidhyani Mar 10, 2024
1f37228
script tweaks
jaidhyani Mar 10, 2024
adfd4b4
added support for prioritizing configs
jaidhyani Mar 10, 2024
e3b326c
refactoring (config_utils) to support notebook use
jaidhyani Mar 10, 2024
551a8de
fix Llama2ConfigData bug in gigaconfig (use default_factory)
jaidhyani Mar 11, 2024
cd9a5b1
make run_training return ModelTrainingState
jaidhyani Mar 11, 2024
ab19879
more config_utils
jaidhyani Mar 11, 2024
bc8b43d
cleanup run_training script
jaidhyani Mar 11, 2024
859ae09
training_demo notebook (for colab)
jaidhyani Mar 11, 2024
fc3f021
static files tweak
jaidhyani Mar 11, 2024
70e82ee
estimate_mfu for llama2hf
jaidhyani Mar 11, 2024
697f729
Don't break if model export not available
jaidhyani Mar 11, 2024
a8f7a4f
100k quick config
jaidhyani Mar 12, 2024
997ec3a
torch.use_deterministic_algorithms for training
jaidhyani Mar 12, 2024
f9bd899
import Callable from collections.abc
jaidhyani Mar 12, 2024
89cee7c
Move up torch.manual_seed before calling anything in torch
jaidhyani Mar 12, 2024
698365f
add wandb to requirements
jaidhyani Mar 12, 2024
7f6c180
factor out training config package + wandb_config
jaidhyani Mar 12, 2024
662555d
unused import
jaidhyani Mar 12, 2024
0699026
isort
jaidhyani Mar 12, 2024
4007b6a
initial mamba support
jaidhyani Mar 12, 2024
594033e
pip install wheel
jaidhyani Mar 12, 2024
cc010da
pip install packaging
jaidhyani Mar 12, 2024
bfb28c1
come on, mamba_ssm, get it together
jaidhyani Mar 12, 2024
f85d015
requirements-nocuda.txt for gh actions
jaidhyani Mar 12, 2024
3c0010b
Merge branch 'main' into training-script
jaidhyani Mar 13, 2024
4870d92
mv ModelTypes to constants
jaidhyani Mar 14, 2024
9eeb960
deprecate llama2c support
jaidhyani Mar 14, 2024
02ec1a1
clear out more llama2c stuff
jaidhyani Mar 14, 2024
6cb9d52
we still need max_seq_len
jaidhyani Mar 14, 2024
35cb7c4
factoring out optimizer params from config
jaidhyani Mar 14, 2024
5e10db2
fix broken test
jaidhyani Mar 15, 2024
af6c0db
model_args overhaul
jaidhyani Mar 15, 2024
2a9d2c2
rm llama2c
jaidhyani Mar 15, 2024
1225d93
replace DataLoader
jaidhyani Mar 15, 2024
a9a791b
run_dir to gigaconfig; output_run_dir; fix Generator type warning
jaidhyani Mar 15, 2024
10b1a36
save results when training is done
jaidhyani Mar 15, 2024
fbfeaa5
save step in results
jaidhyani Mar 15, 2024
5a54078
include architecture and priority in llama preset configs
jaidhyani Mar 15, 2024
19e779f
Merge branch 'training-script' into mamba_dev
jaidhyani Mar 15, 2024
a98aa42
update training demo
jaidhyani Mar 15, 2024
53a6adf
mamba expectedly imports correctly
jaidhyani Mar 15, 2024
0d69ede
rm export_model
jaidhyani Mar 15, 2024
e24fec4
estimate_loss no longer depends on architecture
jaidhyani Mar 15, 2024
b9f682c
add combine_configs (working towards frozen config for type safety)
jaidhyani Mar 15, 2024
13db6eb
renaming/simplification
jaidhyani Mar 15, 2024
15983a3
model_config refactor to approach type safety + frozen dataclasses
jaidhyani Mar 15, 2024
ac8fb6a
rm architectures.py
jaidhyani Mar 15, 2024
43dd1b2
new config system with type safety!
jaidhyani Mar 16, 2024
6a21f2c
Support for optional config types (mamba and llama)
jaidhyani Mar 16, 2024
be2354f
fix sample configs
jaidhyani Mar 16, 2024
c128f96
remove some unused model config args
jaidhyani Mar 16, 2024
28076ed
remove unused mamba.py
jaidhyani Mar 16, 2024
384b140
I thought I already deleted this?
jaidhyani Mar 16, 2024
5f64802
rename to "initialize_model_training_state"
jaidhyani Mar 16, 2024
e4ecd21
Support for mandatory fields in run_training
jaidhyani Mar 17, 2024
a1a150d
ModelTypes
jaidhyani Mar 17, 2024
ec04579
output_dir is output_dir
jaidhyani Mar 17, 2024
fb0c509
cleaner imports
jaidhyani Mar 17, 2024
129587d
error if ModelConfig doesn't include config for chosen model type
jaidhyani Mar 17, 2024
ab208aa
no-op renames for clarity
jaidhyani Mar 17, 2024
076cd8b
log levels
jaidhyani Mar 17, 2024
be352fe
shebang & chmod +x on run_training.py
jettjaniak Mar 18, 2024
10d718e
renamed corpus dataset
jettjaniak Mar 18, 2024
77c15ef
removed llama2c references from pyproject.toml
jettjaniak Mar 18, 2024
a8d03a8
removed .gitmodules
jettjaniak Mar 18, 2024
2d86063
removed scripts/upload_stories.py
jettjaniak Mar 18, 2024
bed480c
test wandb_utils
jaidhyani Mar 17, 2024
f168310
no llama2c, no .view, no need for enforcing contigious tensors
jaidhyani Mar 18, 2024
6a67a02
Fix _unoptionalize
jaidhyani Mar 18, 2024
52745bc
run_training.py --help when no args
jaidhyani Mar 18, 2024
9e759a3
script improvements: model-specific args moved to their own --help; f…
jaidhyani Mar 19, 2024
5dfdb94
rename llama to llama2
jaidhyani Mar 19, 2024
12dc6c4
unused imports
jaidhyani Mar 19, 2024
a81f9e5
set run name from config file
jaidhyani Mar 19, 2024
63cb45b
set default output_dir based on run_name
jaidhyani Mar 19, 2024
be3069d
remove in-progress testing file added by mistake
jaidhyani Mar 19, 2024
d15f467
add huggingface config
jaidhyani Mar 19, 2024
9c27671
fix config json that got broken somehow
jaidhyani Mar 19, 2024
c259e97
save/load fix + huggingface uploading
jaidhyani Mar 19, 2024
0a04b99
fix test that broken when renaming llama to llama2
jaidhyani Mar 19, 2024
083cb1b
unused import
jaidhyani Mar 19, 2024
580d3c6
fix validation sampling
jaidhyani Mar 19, 2024
14dc55e
remove eval_only
jaidhyani Mar 19, 2024
6 changes: 4 additions & 2 deletions .github/workflows/checks.yml
@@ -17,6 +17,8 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
Comment on lines +20 to +21
Collaborator:

We added this when we added the llama2c submodule. Ironically, we also removed llama2c in the course of developing this PR. Technically we don't need this anymore, but it's not a bad idea to have it around for any submodules we add in the future.

      - name: setup python
        uses: actions/setup-python@v5
        with:
@@ -31,11 +33,11 @@ jobs:
      - name: dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-nocuda.txt
Collaborator:

mamba really wants to run on CUDA, and there are optional-but-preferred dependencies for doing so. We want to use that when we can, because otherwise training mamba models would take foreverrrrrrrrr - but only on platforms that support it. We split out requirements into those that don't require CUDA (-nocuda) and those that do (still in requirements.txt, which automatically includes the -nocuda requirements). GitHub CI doesn't run in a CUDA env because why would Microsoft give away that much GPU compute? So we use requirements-nocuda.txt for CI.

          pip install -e .
      - name: black
        run: black --check .
      - name: isort
        run: isort --profile black --check .
        run: isort --check .
Collaborator:

The --profile black config is now implicit via pyproject.toml.

      - name: pytest
        run: pytest
17 changes: 17 additions & 0 deletions .gitignore
@@ -6,6 +6,11 @@ __pycache__/
# C extensions
*.so

bin
include
lib64
pyvenv.cfg
Comment on lines +9 to +12
Collaborator:

I think these are llama2c artifacts? Also technically no longer needed, but on the other hand these are generally things we'd want to exclude from git if anything with these names ever showed up.


# Distribution / packaging
.Python
build/
@@ -158,3 +163,15 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# ignore wandb files
**/wandb/*
**/*.wandb
**/wandb-summary.json
**/wandb-metadata.json
Comment on lines +167 to +171
Collaborator:

Debugging wandb integration involved a lot of wandb artifacts being created and I was too lazy to change directories.


# scratch notebook
notebooks/scratch.ipynb
Collaborator:

I like having a notebook in my environment for messy testing, ad hoc experiments, and running snippets. It's (by design) horrible ugly nastiness that should never see the light of day.


# dsstore
.DS_Store
Collaborator:

MacOS seems to make this sometimes, for reasons that are probably my fault but that I haven't entirely figured out.

Empty file added .gitmodules
Empty file.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
@@ -8,5 +8,3 @@ repos:
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile", "black"]
34 changes: 34 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,34 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "run_training 256",
"type": "debugpy",
"request": "launch",
"program": "scripts/run_training.py",
"console": "integratedTerminal",
"args": "--debug --train_sample_limit=256"
//"args": "${command:pickArgs}"
},
{
"name": "run_training --help",
"type": "debugpy",
"request": "launch",
"program": "scripts/run_training.py",
"console": "integratedTerminal",
"args": "--help"
//"args": "${command:pickArgs}"
},
{
"name": "run training with debug plus custom args",
"type": "debugpy",
"request": "launch",
"program": "scripts/run_training.py",
"console": "integratedTerminal",
"args": "--debug ${command:pickArgs}"
}
]
}
Comment on lines +1 to +34
Collaborator:

These are configurations for vscode's integrated debugging tool

4 changes: 0 additions & 4 deletions .vscode/settings.json
@@ -7,9 +7,5 @@
"source.organizeImports": "explicit"
},
"python.analysis.typeCheckingMode": "basic",
"isort.args": [
"--profile",
"black"
],
Comment on lines 9 to -13
Collaborator:

config consolidated and moved to pyproject.toml

"black-formatter.importStrategy": "fromEnvironment",
}
3 changes: 3 additions & 0 deletions README.md
@@ -49,3 +49,6 @@ When you save a file vscode should automatically format it. Otherwise, pre-commi
- comment important sections of the code in _Files changed_ tab
- when it's ready, add the relevant stakeholders as reviewers
4. after the comments are resolved and PR is approved, merge it using _Squash and merge_

# Incrementing Versions
When making a new release, increment the version in `delphi/__init__.py`
Comment on lines +52 to +54
Collaborator:

The canonical package version is now stored in delphi.__init__.py, and automatically propagated to setup.py. Version tracking is important for reproducibility, so we should try to actually keep track of this.

45 changes: 45 additions & 0 deletions notebooks/training_demo.ipynb
@@ -0,0 +1,45 @@
{
Collaborator:

This is a notebook to demo a basic training run. I'm also noticing that it's actually broken at the moment; I should fix this.

"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from delphi.train.config.utils import get_presets_by_name\n",
"from delphi.train.training import run_training\n",
"from delphi.train.utils import ModelTrainingState\n",
"from delphi.train.run_context import RunContext\n",
"\n",
"\n",
"def train() -> tuple[ModelTrainingState, RunContext]:\n",
" config = get_presets_by_name()[\"v0-llama2-100k\"]\n",
" config.wandb_config.entity = \"jaiwithani\"\n",
" return run_training(config)\n",
"\n",
"model_train_result = train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tinyevals",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
18 changes: 18 additions & 0 deletions pyproject.toml
@@ -0,0 +1,18 @@
[project]
name = "delphi"
dynamic = ["version"]

[tool.setuptools.dynamic]
version = {attr = "delphi.__version__"}
Comment on lines +1 to +6
Collaborator:

This handles automatically propagating the version from __init__.py to setup.py.
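
For reference, the attribute that `version = {attr = "delphi.__version__"}` points at is just a module-level string; a minimal sketch of delphi/__init__.py (the version value here is illustrative, not the repo's actual number):

```python
# delphi/__init__.py -- single source of truth for the package version
__version__ = "0.1.0"  # illustrative value; bump this when cutting a release
```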



[tool.black]
extend-exclude = 'src/llama2c'
Collaborator:

When we were using llama2c we wanted to avoid touching it as much as possible, including any formatting changes. We can probably remove this?


[tool.isort]
profile = 'black'
known_third_party = ['llama2c', 'wandb']
Collaborator:

isort will assume that llama2c and wandb are first-party packages if directories with those names are present at the top level (alongside delphi). This tells isort that, no, llama2c and wandb should be treated as third-party packages. Without this, CI hits a conflict: the llama2c and wandb dirs are not present there, so isort in GitHub CI complains that third-party imports are being formatted as first-party imports.

extend_skip = ['src/llama2c']

[tool.pytest.ini_options]
testpaths = ["tests"]
Collaborator:

Woo consolidated configuration courtesy of @jettjaniak

30 changes: 30 additions & 0 deletions requirements-nocuda.txt
@@ -0,0 +1,30 @@
# this is a separate requirements.txt file for use in github actions
# this omits packages that cannot be installed in github actions due
# to hardware limitations (e.g. no GPU). All packages here are automatically
# included when installing from requirements.txt
torch==2.1.2
datasets==2.16.1
tqdm==4.66.1
ipywidgets==8.1.1
nbformat==5.9.2
pytest==7.4.4
black==23.12.1
jaxtyping==0.2.25
beartype==0.16.4
pre-commit==3.6.0
isort==5.13.2
chardet==5.2.0
sentencepiece==0.1.99
protobuf==4.25.2
plotly==5.18.0
wandb==0.16.3
spacy==3.7.2
pandas==1.3.4
dacite==1.8.1
Collaborator:

dacite adds a from_dict function for deserializing dataclasses (the inverse of dataclasses' asdict). This is extremely useful for deserializing JSON-derived config dicts into config objects.
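
A minimal sketch of the round trip dacite enables (the dataclasses and values below are hypothetical, not the real GigaConfig):

```python
from dataclasses import asdict, dataclass

from dacite import from_dict


@dataclass
class WandbConfig:
    entity: str
    project: str


@dataclass
class TrainConfig:
    max_steps: int
    wandb: WandbConfig


raw = {"max_steps": 1000, "wandb": {"entity": "someone", "project": "delphi"}}
config = from_dict(data_class=TrainConfig, data=raw)  # nested dataclasses are reconstructed
assert asdict(config) == raw  # and asdict round-trips back to the original dict
```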


# temporarily installing transformers from main until 4.39.0 comes out (for mamba support)
transformers @ git+https://github.com/huggingface/transformers@main
# transformers==4.39.0 TODO: use this once 4.39.0 releases

# spacy-transformers requires transformers <= 4.37.0, temporarily disabling
# spacy-transformers>=1.3.4
27 changes: 8 additions & 19 deletions requirements.txt
@@ -1,19 +1,8 @@
torch==2.1.2
datasets==2.16.1
transformers==4.36.2
tqdm==4.66.1
ipywidgets==8.1.1
nbformat==5.9.2
pytest==7.4.4
black==23.12.1
jaxtyping==0.2.25
beartype==0.16.4
pre-commit==3.6.0
isort==5.13.2
spacy==3.7.2
chardet==5.2.0
sentencepiece==0.1.99
protobuf==4.25.2
plotly==5.18.0
spacy-transformers==1.3.4
pandas==1.3.4
# most packages are specified in requirements-nocuda.txt, and new packages should be placed
# there UNLESS they cannot be installed without CUDA support, in which case they should go here.
-r requirements-nocuda.txt

# these libs support better mamba implementations in transformers,
# but require CUDA/nvcc, so they won't work on MacOS.
mamba_ssm==1.2.0.post1; sys_platform != 'darwin'
causal-conv1d==1.2.0.post2; sys_platform != 'darwin'
151 changes: 151 additions & 0 deletions scripts/run_training.py
@@ -0,0 +1,151 @@
import argparse
Collaborator:

This is the script for training delphi models from the command-line.

import logging
import os
from dataclasses import fields, is_dataclass
from itertools import chain
from pathlib import Path
from typing import Any, Optional, _GenericAlias # type: ignore

from delphi.constants import CONFIG_PRESETS_DIR
from delphi.train.config.gigaconfig import GigaConfig
from delphi.train.config.utils import (
    build_config_from_files_and_overrides,
    get_preset_paths,
    get_user_config_path,
)
from delphi.train.training import run_training
from delphi.train.utils import get_run_output_dir, save_results


def _unoptionalize(t: type) -> type:
    if isinstance(t, _GenericAlias):
        return t.__args__[0]
    else:
        return t
Collaborator:

We do a lot of programmatic argument construction by inspecting config dataclasses. We want to expose config parameters for fields that are Optional, and doing this requires accessing the underlying type that Optional wraps. This arcane Python invocation takes Optional[T] and returns T (or just T if it wasn't Optional to begin with).
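
A tiny illustration of that behavior, using the _unoptionalize defined above (the asserts are for exposition only, not part of the script):

```python
from typing import Optional

assert _unoptionalize(Optional[int]) is int  # Optional[int] is Union[int, None]; return the wrapped type
assert _unoptionalize(str) is str            # non-Optional types pass through unchanged
```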



def get_preset_args(args: argparse.Namespace) -> list[Path]:
    cands = []
    for preset in get_preset_paths():
        if hasattr(args, preset.stem) and getattr(args, preset.stem):
            cands.append(preset)
    return cands
Comment on lines +37 to +42
Collaborator:

This gets all user arguments that correspond to config presets (e.g. --some_preset for $PRESETS/some_preset.json). This is technically inefficient since we've already checked the presets dir to build the arguments in the first place, but whatever.



def get_config_files(args: argparse.Namespace) -> list[Path]:
    user_config_path = get_user_config_path()
    cands = [user_config_path] if user_config_path.exists() else []
    cands += get_preset_args(args)
    config_files = list(chain(*args.config_file)) if args.config_file else []
    cands += map(Path, config_files)
    configs = []
    for candpath in cands:
        if candpath.exists():
            configs.append(candpath)
            logging.info(f"Found config file {candpath}...")
        else:
            raise FileNotFoundError(candpath, f"Config file {candpath} does not exist.")
    return configs
Comment on lines +45 to +58
Collaborator:

This looks for three different kinds of config JSONs (see the sketch after this list):

  1. The user-specific config (at platformdirs.user_config_dir(appname="delphi") / config.json)
  2. Any user-specified presets
  3. Any general user-specified config JSON paths
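
For illustration, the same helpers can also be driven directly from Python; a rough sketch (the file paths and entity value are made up, and the preset name comes from the training demo notebook):

```python
from pathlib import Path

from delphi.train.config.utils import build_config_from_files_and_overrides

# Merge two config files, then apply an override dict shaped like args_to_dict's output.
config = build_config_from_files_and_overrides(
    [Path("v0-llama2-100k.json"), Path("my_overrides.json")],
    {"wandb_config": {"entity": "my-wandb-entity"}},
)
```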



def add_dataclass_args_recursively(
    parser: argparse.ArgumentParser,
    dc: type[object],
    default_group: argparse._ArgumentGroup,
    group: Optional[argparse._ArgumentGroup] = None,
    prefix: str = "",
):
    for field in fields(dc):  # type: ignore
        # if field is an Optional type, strip it to the actual underlying type
        _type = _unoptionalize(field.type)
        if is_dataclass(_type):
            _group = group or parser.add_argument_group(
                f"{field.name.capitalize()} arguments"
            )
            add_dataclass_args_recursively(
                parser,
                _type,
                default_group,
                _group,
                prefix=f"{prefix}{field.name}.",
            )
        else:
            _group = group or default_group
            _group.add_argument(
                f"--{prefix}{field.name}",
                type=_type,
                required=False,
                help=f"Default: {field.default}",
            )
Collaborator:

This is how we build argparse arguments from a dataclass (specifically, GigaConfig). There are two basic rules for how we want to do this (see the sketch after this list):

  1. Any top-level fields (e.g. not dataclasses) should be added to a pre-specified default argument group.
  2. Any dataclass fields should create their own argument group for their subfields (but these subfields should not create further argument groups).
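
A toy illustration of those rules (hypothetical fields, not the real GigaConfig):

```python
import argparse
from dataclasses import dataclass, field


@dataclass
class OptimizerConfig:
    lr: float = 3e-4
    weight_decay: float = 0.1


@dataclass
class ToyConfig:
    max_steps: int = 1000  # top-level field -> default "Config arguments" group
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)  # nested -> its own group


parser = argparse.ArgumentParser()
default_group = parser.add_argument_group("Config arguments")
add_dataclass_args_recursively(parser, ToyConfig, default_group)
# Resulting flags: --max_steps (default group),
# --optimizer.lr and --optimizer.weight_decay ("Optimizer arguments" group)
```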



def setup_parser() -> argparse.ArgumentParser:
    # Setup argparse
    parser = argparse.ArgumentParser(description="Train a delphi model")
    parser.add_argument(
        "--config_file",
        help=(
            "Path to json file(s) containing config values. Specific values can be overridden with --arguments. "
            "e.g. `--config_file primary_config.json secondary_config.json --log_interval 42`. "
            'If passing multiple configs with overlapping args, use "priority" key to specify precedence, e.g. {"priority": 100} '
            f'overrides {{"priority": 99}} See preset configs in {CONFIG_PRESETS_DIR}'
        ),
        action="append",
        nargs="*",
        required=False,
        type=str,
    )
    config_arg_group = parser.add_argument_group("Config arguments")
    add_dataclass_args_recursively(parser, GigaConfig, config_arg_group)
    preset_arg_group = parser.add_argument_group("Preset configs")
    for preset in sorted(get_preset_paths()):
        preset_arg_group.add_argument(
            f"--{preset.stem}",
            help=f"Use {preset.stem} preset config",
            action="store_true",
        )
    return parser


def var_args_to_dict(config_vars: dict[str, Any]) -> dict[str, Any]:
    # {"a.b.c" = 4} to {"a": {"b": {"c": 4}}}
    d = {}
    for k, v in config_vars.items():
        if v is None:
            continue
        cur = d
        subkeys = k.split(".")
        for subkey in subkeys[:-1]:
            if subkey not in cur:
                cur[subkey] = {}
            cur = cur[subkey]
        cur[subkeys[-1]] = v
    return d
Comment on lines +209 to +222
Collaborator:

We let users specify nested arguments on the command line; this converts them into a nested dictionary that will be used to build the actual config object.
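
Concretely (the dotted key below is a hypothetical argument name):

```python
flat = {
    "model_config.mamba.hidden_dim": 64,  # dotted keys become nested dicts
    "log_interval": 42,
    "unset_field": None,  # None values (flags the user didn't pass) are skipped
}
assert var_args_to_dict(flat) == {
    "model_config": {"mamba": {"hidden_dim": 64}},
    "log_interval": 42,
}
```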



def args_to_dict(args: argparse.Namespace) -> dict[str, Any]:
    # at the toplevel, filter for args corresponding to field names in GigaConfig
    field_names = set(field.name for field in fields(GigaConfig))
    config_vars = {
        k: v for k, v in vars(args).items() if k.split(".")[0] in field_names
    }
    return var_args_to_dict(config_vars)
Comment on lines +225 to +231
Collaborator:

This just wraps var_args_to_dict and restricts it to arguments corresponding to actual config fields (excluding non-config params like priority).



def main():
    parser = setup_parser()
    args = parser.parse_args()

    config_files = get_config_files(args)
    args_dict = args_to_dict(args)
    config = build_config_from_files_and_overrides(config_files, args_dict)

    # run training
    results, run_context = run_training(config)
    final_out_dir = os.path.join(get_run_output_dir(config), "final")
    save_results(config, results, run_context, final_out_dir)
    print(f"Saved results to {final_out_dir}")


if __name__ == "__main__":
    main()