diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 89a889fa..36bb390d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -17,6 +17,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: recursive - name: setup python uses: actions/setup-python@v5 with: @@ -31,11 +33,11 @@ jobs: - name: dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -r requirements-nocuda.txt pip install -e . - name: black run: black --check . - name: isort - run: isort --profile black --check . + run: isort --check . - name: pytest run: pytest diff --git a/.gitignore b/.gitignore index 68bc17f9..0bce421a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,11 @@ __pycache__/ # C extensions *.so +bin +include +lib64 +pyvenv.cfg + # Distribution / packaging .Python build/ @@ -158,3 +163,15 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# ignore wandb files +**/wandb/* +**/*.wandb +**/wandb-summary.json +**/wandb-metadata.json + +# scratch notebook +notebooks/scratch.ipynb + +# dsstore +.DS_Store \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a841252d..fbe11bd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,5 +8,3 @@ repos: rev: 5.13.2 hooks: - id: isort - name: isort (python) - args: ["--profile", "black"] diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..cc1d9f99 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,34 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "run_training 256", + "type": "debugpy", + "request": "launch", + "program": "scripts/run_training.py", + "console": "integratedTerminal", + "args": "--debug --train_sample_limit=256" + //"args": "${command:pickArgs}" + }, + { + "name": "run_training --help", + "type": "debugpy", + "request": "launch", + "program": "scripts/run_training.py", + "console": "integratedTerminal", + "args": "--help" + //"args": "${command:pickArgs}" + }, + { + "name": "run training with debug plus custom args", + "type": "debugpy", + "request": "launch", + "program": "scripts/run_training.py", + "console": "integratedTerminal", + "args": "--debug ${command:pickArgs}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 8f9d001f..5a69a6b6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,9 +7,5 @@ "source.organizeImports": "explicit" }, "python.analysis.typeCheckingMode": "basic", - "isort.args": [ - "--profile", - "black" - ], "black-formatter.importStrategy": "fromEnvironment", } \ No newline at end of file diff --git a/README.md b/README.md index 105d7ca7..d8aae486 100644 --- a/README.md +++ b/README.md @@ -49,3 +49,6 @@ When you save a file vscode should automatically format it. Otherwise, pre-commi - comment important sections of the code in _Files changed_ tab - when it's ready, add the relevant stakeholders as reviewers 4. 
after the comments are resolved and PR is approved, merge it using _Squash and merge_ + +# Incrementing Versions +When making a new release, increment the version in `delphi/__init__.py` \ No newline at end of file diff --git a/notebooks/training_demo.ipynb b/notebooks/training_demo.ipynb new file mode 100644 index 00000000..44b997c6 --- /dev/null +++ b/notebooks/training_demo.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from delphi.train.config.utils import get_presets_by_name\n", + "from delphi.train.training import run_training\n", + "from delphi.train.utils import ModelTrainingState\n", + "from delphi.train.run_context import RunContext\n", + "\n", + "\n", + "def train() -> tuple[ModelTrainingState, RunContext]:\n", + " config = get_presets_by_name()[\"v0-llama2-100k\"]\n", + " config.wandb_config.entity = \"jaiwithani\"\n", + " return run_training(config)\n", + "\n", + "model_train_result = train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tinyevals", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..15591ce4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "delphi" +dynamic = ["version"] + +[tool.setuptools.dynamic] +version = {attr = "delphi.__version__"} + +[tool.isort] +profile = 'black' +known_third_party = ['wandb'] + +[tool.pytest.ini_options] +testpaths = ["tests"] \ No newline at end of file diff --git a/requirements-nocuda.txt b/requirements-nocuda.txt new file mode 100644 index 00000000..95816526 --- /dev/null +++ b/requirements-nocuda.txt @@ -0,0 +1,30 @@ +# this is a separate requirements.txt file for use in github actions +# this omits packages that cannot be installed in github actions due +# to hardware limitations (e.g. no GPU). 
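+# note: the CI checks workflow installs this file directly (see .github/workflows/checks.yml):
+#   pip install -r requirements-nocuda.txt
+#   pip install -e .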
All packages here are automatically +# included when installing from requirements.txt +torch==2.1.2 +datasets==2.16.1 +tqdm==4.66.1 +ipywidgets==8.1.1 +nbformat==5.9.2 +pytest==7.4.4 +black==23.12.1 +jaxtyping==0.2.25 +beartype==0.16.4 +pre-commit==3.6.0 +isort==5.13.2 +chardet==5.2.0 +sentencepiece==0.1.99 +protobuf==4.25.2 +plotly==5.18.0 +wandb==0.16.3 +spacy==3.7.2 +pandas==1.3.4 +dacite==1.8.1 + +# temporarily installing transformers from main until 4.39.0 comes out (for mamba support) +transformers @ git+https://github.com/huggingface/transformers@main +# transformers==4.39.0 TODO: use this once 4.39.0 releases + +# spacy-transformers requires transformers <= 4.37.0, temporarily disabling +# spacy-transformers>=1.3.4 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8130d532..9f95b321 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,8 @@ -torch==2.1.2 -datasets==2.16.1 -transformers==4.36.2 -tqdm==4.66.1 -ipywidgets==8.1.1 -nbformat==5.9.2 -pytest==7.4.4 -black==23.12.1 -jaxtyping==0.2.25 -beartype==0.16.4 -pre-commit==3.6.0 -isort==5.13.2 -spacy==3.7.2 -chardet==5.2.0 -sentencepiece==0.1.99 -protobuf==4.25.2 -plotly==5.18.0 -spacy-transformers==1.3.4 -pandas==1.3.4 +# most packages are specified in requirements-gh.txt, and new packages should be placed +# there UNLESS they cannot be installed without CUDA support, in which case they should go here. +-r requirements-nocuda.txt + +# these libs support better mamba implementations in transformers, +# but require CUDA/nvcc, so they won't work on MacOS. +mamba_ssm==1.2.0.post1; sys_platform != 'darwin' +causal-conv1d==1.2.0.post2; sys_platform != 'darwin' \ No newline at end of file diff --git a/scripts/run_training.py b/scripts/run_training.py new file mode 100755 index 00000000..5586a2bb --- /dev/null +++ b/scripts/run_training.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +import argparse +import logging +import os +from dataclasses import fields, is_dataclass +from datetime import datetime +from itertools import chain +from pathlib import Path +from typing import Any, Type, Union + +import platformdirs + +from delphi.constants import CONFIG_PRESETS_DIR +from delphi.train.config import ( + GigaConfig, + build_config_from_files_and_overrides, + get_preset_paths, + get_user_config_path, +) +from delphi.train.training import run_training +from delphi.train.utils import save_results + + +def _unoptionalize(t: Type) -> Type: + """unwrap `Optional[T]` to T""" + # Under the hood, `Optional` is really `Union[T, None]`. 
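+    # e.g. (illustrative) _unoptionalize(Optional[int]) returns int, while a
+    # non-Optional type such as str is passed through unchanged.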
So we + # just check if this is a Union over two types including None, and + # return the other + if hasattr(t, "__origin__") and t.__origin__ is Union: + args = t.__args__ + # Check if one of the Union arguments is type None + if len(args) == 2 and type(None) in args: + return args[0] if args[1] is type(None) else args[1] + return t + + +def get_preset_args(args: argparse.Namespace) -> list[Path]: + cands = [] + for preset in get_preset_paths(): + if hasattr(args, preset.stem) and getattr(args, preset.stem): + cands.append(preset) + return cands + + +def get_config_files(args: argparse.Namespace) -> list[Path]: + user_config_path = get_user_config_path() + cands = [user_config_path] if user_config_path.exists() else [] + cands += get_preset_args(args) + config_files = list(chain(*args.config_file)) if args.config_file else [] + cands += map(Path, config_files) + configs = [] + for candpath in cands: + if candpath.exists(): + configs.append(candpath) + logging.info(f"Found config file {candpath}...") + else: + raise FileNotFoundError(candpath, f"Config file {candpath} does not exist.") + return configs + + +def add_preset_args(parser: argparse.ArgumentParser): + preset_arg_group = parser.add_argument_group("Preset configs") + for preset in sorted(get_preset_paths()): + preset_arg_group.add_argument( + f"--{preset.stem}", + help=f"Use {preset.stem} preset config {'***and set log level to DEBUG***' if preset.stem == 'debug' else ''}", + action="store_true", + ) + + +def add_dataclass_args_recursively( + parser: argparse.ArgumentParser, + dc: type[object], + group: argparse._ArgumentGroup, + help_parsers: dict[str, argparse.ArgumentParser], + prefix: str = "", + depth: int = 0, + max_help_depth=1, +): + """Recursively add arguments to an argparse parser from a dataclass + + + To keep --help sane, once we reach max_help_depth we start hiding options + from --help and instead add a --_help option to see config options + below that level (e.g. model_config.llama config) + """ + for field in fields(dc): # type: ignore + # if field is an Optional type, strip it to the actual underlying type + _type = _unoptionalize(field.type) + name = f"{prefix}{field.name}" + if is_dataclass(_type): + # at max-depth, + if depth == max_help_depth: + help_name = f"{name}_help" + group.add_argument( + f"--{help_name}", + help=f"***Print help for {name} options***", + default=False, + action="store_true", + ) + help_parser = argparse.ArgumentParser(help_name) + help_group = help_parser.add_argument_group(name) + help_parsers[help_name] = help_parser + add_dataclass_args_recursively( + help_parser, + _type, + help_group, + help_parsers, + prefix=f"{name}.", + depth=depth + 1, + max_help_depth=999, + ) + _group = parser.add_argument_group(f"{name}") + add_dataclass_args_recursively( + parser, + _type, + _group, + help_parsers, + prefix=f"{name}.", + depth=depth + 1, + ) + else: + if depth > max_help_depth: + help_str = argparse.SUPPRESS + elif field.default != field.default_factory: + help_str = f"Default: {field.default}" + else: + help_str = f"Must be specified as part of {group.title}" + group.add_argument( + f"--{name}", + type=_type, + required=False, + help=help_str, + ) + + +def add_logging_args(parser: argparse.ArgumentParser): + logging_group = parser.add_mutually_exclusive_group() + logging_group.add_argument( + "-v", + "--verbose", + action="count", + default=None, + help="Increase verbosity level, repeatable (e.g. -vvv). 
Mutually exclusive with --silent, --loglevel", + ) + logging_group.add_argument( + "-s", + "--silent", + action="store_true", + help="Silence all logging. Mutually exclusive with --verbose, --loglevel", + default=False, + ) + logging_group.add_argument( + "--loglevel", + type=int, + help="Logging level. 10=DEBUG, 50=CRITICAL. Mutually exclusive with --verbose, --silent", + default=None, + ) + + +def set_logging(args: argparse.Namespace): + logging.basicConfig(format="%(message)s") + logging.getLogger().setLevel(logging.INFO) + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + if args.verbose is not None: + if args.verbose == 1: + loglevel = logging.DEBUG + elif args.verbose >= 2: + loglevel = 0 + logging.getLogger().setLevel(loglevel) + if args.loglevel is not None: + logging.getLogger().setLevel(args.loglevel) + if args.silent: + logging.getLogger().setLevel(logging.CRITICAL) + else: + logging_level_str = logging.getLevelName( + logging.getLogger().getEffectiveLevel() + ) + print(f"set logging level to {logging_level_str}") + + +def setup_parser() -> ( + tuple[argparse.ArgumentParser, dict[str, argparse.ArgumentParser]] +): + # Setup argparse + parser = argparse.ArgumentParser(description="Train a delphi model") + parser.add_argument( + "--config_file", + help=( + "Path to json file(s) containing config values. Specific values can be overridden with --arguments. " + "e.g. `--config_file primary_config.json secondary_config.json --log_interval 42`. " + 'If passing multiple configs with overlapping args, use "priority" key to specify precedence, e.g. {"priority": 100} ' + f'overrides {{"priority": 99}} See preset configs in {CONFIG_PRESETS_DIR}' + ), + action="append", + nargs="*", + required=False, + type=str, + ) + config_arg_group = parser.add_argument_group("Config arguments") + help_parsers = dict() + add_dataclass_args_recursively(parser, GigaConfig, config_arg_group, help_parsers) + add_preset_args(parser) + add_logging_args(parser) + return parser, help_parsers + + +def var_args_to_dict(config_vars: dict[str, Any]) -> dict[str, Any]: + # {"a.b.c" = 4} to {"a": {"b": {"c": 4}}} + d = {} + for k, v in config_vars.items(): + if v is None: + continue + cur = d + subkeys = k.split(".") + for subkey in subkeys[:-1]: + if subkey not in cur: + cur[subkey] = {} + cur = cur[subkey] + cur[subkeys[-1]] = v + return d + + +def args_to_dict(args: argparse.Namespace) -> dict[str, Any]: + # at the toplevel, filter for args corresponding to field names in GigaConfig + field_names = set(field.name for field in fields(GigaConfig)) + config_vars = { + k: v for k, v in vars(args).items() if k.split(".")[0] in field_names + } + return var_args_to_dict(config_vars) + + +def print_subhelp_if_invoked(args: argparse.Namespace, help_parsers: dict[str, Any]): + for name, parser in help_parsers.items(): + if hasattr(args, name) and getattr(args, name): + parser.print_help() + exit(0) + + +def set_name_from_config_file(args: argparse.Namespace, config_files: list[Path]): + """if no run_name is specified + exactly one config file is, use the name of the config file""" + if args.run_name is None: + configs = [c for c in config_files if c != get_user_config_path()] + if len(configs) == 1: + run_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + args.run_name = f"{configs[0].stem}__{run_time}" + + +def set_output_dir(args: argparse.Namespace): + """if output_dir not set, set based on run name""" + if args.output_dir is None: + args.output_dir = os.path.join( + 
platformdirs.user_data_dir(appname="delphi"), args.run_name + ) + + +def main(): + parser, help_parsers = setup_parser() + args = parser.parse_args() + print_subhelp_if_invoked(args, help_parsers) + set_logging(args) + + config_files = get_config_files(args) + set_name_from_config_file(args, config_files) + set_output_dir(args) + args_dict = args_to_dict(args) + config = build_config_from_files_and_overrides(config_files, args_dict) + # run training + results, run_context = run_training(config) + final_out_dir = os.path.join(config.output_dir, "final") + save_results(config, results, run_context, final_out_dir) + print(f"Saved results to {final_out_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/sample_config.json b/scripts/sample_config.json new file mode 100644 index 00000000..55073fd7 --- /dev/null +++ b/scripts/sample_config.json @@ -0,0 +1,58 @@ +{ + "run_name": "2024_03_15_17_28_14", + "output_dir": "/Users/jaidhyani/Library/Application Support/delphi", + "device": "auto", + "eval_interval": 2000, + "log_interval": 1, + "eval_iters": 100, + "eval_only": false, + "always_save_checkpoint": false, + "init_from": "scratch", + "wandb_config": { + "log": false, + "project": "delphi", + "entity": "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" + }, + "batch_size": 64, + "max_seq_len": 512, + "model_config": { + "model_type": "llama", + "mamba": null, + "llama": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": -1, + "eos_token_id": -2, + "hidden_act": "silu", + "hidden_size": 288, + "initializer_range": 0.02, + "intermediate_size": 288, + "max_position_embeddings": 513, + "num_attention_heads": 6, + "num_hidden_layers": 6, + "num_key_value_heads": 6, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": false, + "use_cache": true, + "vocab_size": 4096 + } + }, + "max_epochs": 10, + "grad_clip": 1.0, + "optimizer": { + "gradient_accumulation_steps": 4, + "learning_rate": 0.0005, + "weight_decay": 0.1, + "beta1": 0.9, + "beta2": 0.95, + "grad_clip": 1.0, + "decay_lr": true, + "warmup_iters": 1000, + "min_lr": 0.0 + }, + "train_sample_limit": -1, + "val_sample_limit": -1 +} \ No newline at end of file diff --git a/scripts/sample_mamba.json b/scripts/sample_mamba.json new file mode 100644 index 00000000..a55dd1c0 --- /dev/null +++ b/scripts/sample_mamba.json @@ -0,0 +1,63 @@ +{ + "run_name": "2024_03_15_21_56_35", + "output_dir": "/Users/jaidhyani/Library/Application Support/delphi", + "device": "auto", + "eval_interval": 2000, + "log_interval": 1, + "eval_iters": 100, + "eval_only": false, + "always_save_checkpoint": false, + "init_from": "scratch", + "wandb_config": { + "log": false, + "project": "delphi", + "entity": "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" + }, + "batch_size": 64, + "max_seq_len": 512, + "model_config": { + "model_type": "mamba", + "mamba": { + "vocab_size": 4096, + "hidden_size": 768, + "state_size": 16, + "num_hidden_layers": 32, + "conv_kernel": 4, + "expand": 2, + "use_bias": false, + "use_conv_bias": true, + "bos_token_id": 0, + "eos_token_id": 0, + "pad_token_id": 0, + "time_step_rank": "auto", + "time_step_scale": 1.0, + "time_step_min": 0.001, + "time_step_max": 0.1, + "time_step_init_scheme": "random", + "time_step_floor": 0.0001, + "layer_norm_epsilon": 1e-05, + "hidden_act": "silu", + "initializer_range": 0.1, + "residual_in_fp32": true, + "rescale_prenorm_residual": false, + "use_cache": true, + 
"tie_word_embeddings": true + }, + "llama": null + }, + "max_epochs": 10, + "grad_clip": 1.0, + "optimizer": { + "gradient_accumulation_steps": 4, + "learning_rate": 0.0005, + "weight_decay": 0.1, + "beta1": 0.9, + "beta2": 0.95, + "grad_clip": 1.0, + "decay_lr": true, + "warmup_iters": 1000, + "min_lr": 0.0 + }, + "train_sample_limit": -1, + "val_sample_limit": -1 +} \ No newline at end of file diff --git a/setup.py b/setup.py index a4156702..fcfffc0d 100644 --- a/setup.py +++ b/setup.py @@ -2,8 +2,10 @@ setup( name="delphi", - version="0.1", packages=find_packages(where="src"), package_dir={"": "src"}, - package_data={"delphi.static": ["*"]}, + package_data={ + "delphi": ["static/**/*"], + }, + include_package_data=True, ) diff --git a/src/delphi/__init__.py b/src/delphi/__init__.py index b9b115cf..a0ea3bb4 100644 --- a/src/delphi/__init__.py +++ b/src/delphi/__init__.py @@ -1,3 +1,5 @@ from beartype.claw import beartype_this_package # <-- hype comes beartype_this_package() # <-- hype goes + +__version__ = "0.1.1" diff --git a/src/delphi/constants.py b/src/delphi/constants.py index 5216566c..4ede491e 100644 --- a/src/delphi/constants.py +++ b/src/delphi/constants.py @@ -1,3 +1,7 @@ from importlib.resources import files STATIC_ASSETS_DIR = files("delphi.static") +CONFIG_PRESETS_DIR = STATIC_ASSETS_DIR / "configs" + +CORPUS_DATASET = "delphi-suite/stories" +TOKENIZED_CORPUS_DATASET = "delphi-suite/v0-tinystories-v2-clean-tokenized" diff --git a/src/delphi/eval/spacy_token_labelling.py b/src/delphi/eval/spacy_token_labelling.py index 40cb9f4a..a9a82193 100644 --- a/src/delphi/eval/spacy_token_labelling.py +++ b/src/delphi/eval/spacy_token_labelling.py @@ -1,6 +1,6 @@ -import pickle +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional +from typing import Optional import pandas as pd import spacy diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index 0c4a8a6f..2d052974 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -66,21 +66,36 @@ def get_next_and_top_k_probs( return next_probs, top_k -def load_validation_dataset(dataset_name: str, split_slice: str = "") -> Dataset: +def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Dataset: + # check that split is either "train" or "validation" + if split not in ["train", "validation"]: + raise ValueError(f"Split must be either 'train' or 'validation', not {split}") if "/" not in dataset_name: dataset_name = f"delphi-suite/{dataset_name}" - data_str = f"data/validation-*.parquet" + data_files_str = f"data/{split}-*.parquet" dataset = load_dataset( dataset_name, - data_files=data_str, + data_files=data_files_str, verification_mode="no_checks", - # this seems to be the only split when using data_files - # regardless of the files we're actually loading - split=f"train{split_slice}", + # Currently, load_dataset returns a dataset dict *unless* a split is specified, + # EVEN IF NO SPLIT WITHIN THE DATA FILES SPECIFIED. If there's no split arg, + # huggingface just just says everything is in the "train" split and returns {"train": dataset}. + # In our case the data_files glob already specifies just the validation files, so we + # shouldn't need to specify a split. But we do need to specify a split to get a dataset object, + # or we'd get a Dataset dict. 
See https://github.com/huggingface/datasets/issues/5189 + split=f"train{slice}", ) return cast(Dataset, dataset) +def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset: + return load_delphi_dataset(dataset_name, "validation", slice) + + +def load_train_dataset(dataset_name: str, slice: str = "") -> Dataset: + return load_delphi_dataset(dataset_name, "train", slice) + + def tokenize( tokenizer: PreTrainedTokenizerBase, sample_txt: str ) -> Int[torch.Tensor, "seq"]: diff --git a/src/delphi/static/README.md b/src/delphi/static/README.md index 815b0c42..e08206a3 100644 --- a/src/delphi/static/README.md +++ b/src/delphi/static/README.md @@ -1,7 +1,7 @@ -# TODO: move this to delphi/static # Static Data Files + ## `token_map.pkl` pickle file: All locations of all tokens. dict of token to list of (doc, pos) pairs. diff --git a/src/delphi/static/configs/debug.json b/src/delphi/static/configs/debug.json new file mode 100644 index 00000000..5d40717a --- /dev/null +++ b/src/delphi/static/configs/debug.json @@ -0,0 +1,21 @@ +{ + "priority": -1, + "vocab_size": 4096, + "max_seq_len": 512, + "max_epochs": 2, + "eval_interval": 1, + "eval_iters": 1, + "train_sample_limit": 256, + "batch_size": 64, + "model_config": { + "model_type": "llama2", + "llama2": { + "hidden_size": 48, + "intermediate_size": 48, + "num_attention_heads": 2, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/debug_mamba.json b/src/delphi/static/configs/debug_mamba.json new file mode 100644 index 00000000..2a619e59 --- /dev/null +++ b/src/delphi/static/configs/debug_mamba.json @@ -0,0 +1,23 @@ +{ + "priority": -1, + "vocab_size": 4096, + "max_seq_len": 512, + "max_epochs": 2, + "eval_interval": 1, + "log_interval": 1, + "eval_iters": 10, + "train_sample_limit": 64, + "batch_size": 8, + "model_config": { + "model_type": "mamba", + "mamba": { + "vocab_size": 4096, + "hidden_size": 48, + "state_size": 16, + "num_hidden_layers": 2, + "conv_kernel": 2, + "expand": 2, + "time_step_rank": 2 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-1.6m.json b/src/delphi/static/configs/v0-llama2-1.6m.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-1.6m.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-100k-quick.json b/src/delphi/static/configs/v0-llama2-100k-quick.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-100k-quick.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + 
"num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-100k.json b/src/delphi/static/configs/v0-llama2-100k.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-100k.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-12.8m.json b/src/delphi/static/configs/v0-llama2-12.8m.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-12.8m.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-200k.json b/src/delphi/static/configs/v0-llama2-200k.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-200k.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-25.6m.json b/src/delphi/static/configs/v0-llama2-25.6m.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-25.6m.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of 
file diff --git a/src/delphi/static/configs/v0-llama2-3.2m.json b/src/delphi/static/configs/v0-llama2-3.2m.json new file mode 100644 index 00000000..909a898c --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-3.2m.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 0.00001, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-400k.json b/src/delphi/static/configs/v0-llama2-400k.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-400k.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-6.4m.json b/src/delphi/static/configs/v0-llama2-6.4m.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-6.4m.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/static/configs/v0-llama2-800k.json b/src/delphi/static/configs/v0-llama2-800k.json new file mode 100644 index 00000000..376c2876 --- /dev/null +++ b/src/delphi/static/configs/v0-llama2-800k.json @@ -0,0 +1,26 @@ +{ + "model_config": { + "model_type": "llama2", + "llama2": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 48, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 512, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 4096 + } + } +} \ No newline at end of file diff --git a/src/delphi/train/config/__init__.py b/src/delphi/train/config/__init__.py new file mode 100644 index 00000000..ce698e55 --- /dev/null +++ b/src/delphi/train/config/__init__.py @@ -0,0 +1,13 @@ +from .gigaconfig import GigaConfig +from 
.optimizer_config import OptimizerConfig +from .utils import ( + build_config_dict_from_files, + build_config_from_files, + build_config_from_files_and_overrides, + get_config_dicts_from_files, + get_preset_paths, + get_presets_by_name, + get_user_config_path, + load_preset, +) +from .wandb_config import WandbConfig diff --git a/src/delphi/train/config/gigaconfig.py b/src/delphi/train/config/gigaconfig.py new file mode 100644 index 00000000..91b4bb2d --- /dev/null +++ b/src/delphi/train/config/gigaconfig.py @@ -0,0 +1,54 @@ +import os +from dataclasses import dataclass, field +from datetime import datetime + +import platformdirs +from beartype import beartype + +from .huggingface_config import HuggingfaceConfig +from .models import ModelConfig +from .optimizer_config import OptimizerConfig +from .wandb_config import WandbConfig + + +@beartype +@dataclass(frozen=True) +class GigaConfig: + model_config: ModelConfig + # meta + run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + output_dir: str = os.path.join( + platformdirs.user_data_dir(appname="delphi"), run_name + ) + huggingface: HuggingfaceConfig = field(default_factory=HuggingfaceConfig) + + # device + device: str = "auto" + + # I/O + eval_interval: int = 2000 + log_interval: int = 1 + eval_iters: int = 100 + always_save_checkpoint: bool = ( + False # if True, always save a checkpoint after each eval + ) + init_from: str = "scratch" # 'scratch' or 'resume' + # wandb logging + wandb_config: WandbConfig = field(default_factory=WandbConfig) + # data + batch_size: int = ( + 64 # if gradient_accumulation_steps > 1, this is the micro-batch size + ) + # model config + max_seq_len: int = 512 + # training + max_epochs: int = 10 # total number of training epochs + grad_clip: float = 1.0 # clip gradients at this value, or disable if == 0.0 + # (adamw) optimizer + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + # reproducibility + batch_ordering_seed = 1337 + torch_seed = 42 + # debugging + train_sample_limit: int = -1 # -1 implies no limit + val_sample_limit: int = -1 diff --git a/src/delphi/train/config/huggingface_config.py b/src/delphi/train/config/huggingface_config.py new file mode 100644 index 00000000..a0164fd5 --- /dev/null +++ b/src/delphi/train/config/huggingface_config.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Optional + +from beartype import beartype + + +@beartype +@dataclass(frozen=True) +class HuggingfaceConfig: + repo_id: Optional[str] = None + push_checkpoints_to_hub: bool = False diff --git a/src/delphi/train/config/models/__init__.py b/src/delphi/train/config/models/__init__.py new file mode 100644 index 00000000..567055aa --- /dev/null +++ b/src/delphi/train/config/models/__init__.py @@ -0,0 +1,4 @@ +from .model_config import ModelConfig, get_delphi_config +from .model_types import ModelTypes +from .typed_llama_config import TypedLlamaConfig +from .typed_mamba_config import TypedMambaConfig diff --git a/src/delphi/train/config/models/model_config.py b/src/delphi/train/config/models/model_config.py new file mode 100644 index 00000000..de84de69 --- /dev/null +++ b/src/delphi/train/config/models/model_config.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional + +from beartype import beartype + +from .model_types import ModelTypes +from .typed_llama_config import TypedLlamaConfig +from .typed_mamba_config import TypedMambaConfig +from .typed_model_config import TypedModelConfig + + +@beartype +@dataclass(frozen=True) +class 
ModelConfig: + model_type: str + mamba: Optional[TypedMambaConfig] = None + llama2: Optional[TypedLlamaConfig] = None + + def __post_init__(self): + if get_delphi_config(self) is None: + raise ValueError( + f"Model config specifies model_type = {self.model_type} " + "but doesn't provide a corresponding config." + ) + + +def get_delphi_config(config: ModelConfig) -> TypedModelConfig: + # get delphi config corresponding to model_type in model config + # e.g. {model_type: "llama2", llama2: my_delphi_llama_config} -> + # my_delphi_llama_config + delphi_config = getattr(config, config.model_type) + return delphi_config diff --git a/src/delphi/train/config/models/model_types.py b/src/delphi/train/config/models/model_types.py new file mode 100644 index 00000000..e8f27010 --- /dev/null +++ b/src/delphi/train/config/models/model_types.py @@ -0,0 +1,79 @@ +""" +For any given model we use, there are three associated types: +- TypedModelConfig: a typed dataclass that defines the arguments to the model. + We use this to enforce some semblance of type safety in configs and code in general. +- PretrainedConfig: a transformers config that defines the model architecture. + The arguments for this are defined in TypedModelConfig. +- PreTrainedModel: a transformers model that implements the model architecture. + Configured by PretrainedConfig. + + This file defines a ModelType dataclass that associated these three types for a given model, + and a ModelTypes container class that defines all the models we use in Delphi along with a + helpful ModelTypes.get() method for getting ModelType from a string. +""" +from dataclasses import dataclass + +from beartype import beartype +from beartype.typing import Type +from transformers import ( + LlamaConfig, + LlamaForCausalLM, + MambaConfig, + MambaForCausalLM, + PretrainedConfig, + PreTrainedModel, +) + +from .typed_llama_config import TypedLlamaConfig +from .typed_mamba_config import TypedMambaConfig +from .typed_model_config import TypedModelConfig + + +@beartype +@dataclass(frozen=True) +class ModelType: + name: str + delphi_config: type[TypedModelConfig] + config: type[PretrainedConfig] + model: type[PreTrainedModel] + + # Allow for ModelType == 'llama2' + def __eq__(self, other): + if isinstance(other, str): + return self.name == other + else: + return super().__eq__(other) + + def __post_init__(self): + # register the ModelType so ModelTypes.get(model_type_name) works + _model_name_to_model_type[self.name.lower()] = self + + +_model_name_to_model_type: dict[str, ModelType] = {} + + +# define new model types here +class ModelTypes: + MAMBA = ModelType( + name="mamba", + delphi_config=TypedMambaConfig, + config=MambaConfig, + model=MambaForCausalLM, + ) + LLAMA2 = ModelType( + name="llama2", + delphi_config=TypedLlamaConfig, + config=LlamaConfig, + model=LlamaForCausalLM, + ) + + # NEWMODEL = ModelType( # var name should match name + # name="newmodel", # string that will be associated with model in configs, etc + # typed_config=TypedNewModelConfig, # typed dataclass for args to config + # config=NewModelConfig, # transformers config + # model=NewModelForCausalLM, # transformers model + # ) + + @classmethod + def get(cls: Type["ModelTypes"], name: str) -> ModelType: + return _model_name_to_model_type[name.lower()] diff --git a/src/delphi/train/config/models/typed_llama_config.py b/src/delphi/train/config/models/typed_llama_config.py new file mode 100644 index 00000000..6b551fc5 --- /dev/null +++ b/src/delphi/train/config/models/typed_llama_config.py @@ -0,0 +1,30 @@ 
+from dataclasses import dataclass +from typing import Any, Optional + +from beartype import beartype + +from .typed_model_config import TypedModelConfig + + +@beartype +@dataclass(frozen=True) +class TypedLlamaConfig(TypedModelConfig): + attention_bias: bool = False + attention_dropout: float = 0.0 + bos_token_id: int = -1 + eos_token_id: int = -2 + hidden_act: str = "silu" + hidden_size: int = 288 + initializer_range: float = 0.02 + intermediate_size: int = 288 + max_position_embeddings: int = 513 + num_attention_heads: int = 6 + num_hidden_layers: int = 6 + num_key_value_heads: int = 6 + pretraining_tp: int = 1 + rms_norm_eps: float = 1e-06 + rope_scaling: Optional[dict[str, Any]] = None + rope_theta: float = 10000.0 + tie_word_embeddings: bool = False + use_cache: bool = True + vocab_size: int = 4096 diff --git a/src/delphi/train/config/models/typed_mamba_config.py b/src/delphi/train/config/models/typed_mamba_config.py new file mode 100644 index 00000000..8b01a932 --- /dev/null +++ b/src/delphi/train/config/models/typed_mamba_config.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Union + +from beartype import beartype + +from .typed_model_config import TypedModelConfig + + +@beartype +@dataclass(frozen=True) +class TypedMambaConfig(TypedModelConfig): + # model shape + vocab_size: int = 4096 + hidden_size: int = 768 + state_size: int = 16 + num_hidden_layers: int = 32 + conv_kernel: int = 4 + expand: int = 2 + use_bias: bool = False + use_conv_bias: bool = True + # tokens + bos_token_id: int = 0 + eos_token_id: int = 0 + pad_token_id: int = 0 + # time step + time_step_rank: Union[int, str] = "auto" + time_step_scale: float = 1.0 + time_step_min: float = 0.001 + time_step_max: float = 0.1 + time_step_init_scheme: str = "random" # "random" or "uniform" + time_step_floor: float = 0.0001 + # misc + layer_norm_epsilon: float = 1e-05 + hidden_act: str = "silu" + initializer_range: float = 0.1 + residual_in_fp32: bool = True + rescale_prenorm_residual: bool = False + use_cache: bool = True + tie_word_embeddings: bool = True diff --git a/src/delphi/train/config/models/typed_model_config.py b/src/delphi/train/config/models/typed_model_config.py new file mode 100644 index 00000000..8eae1de8 --- /dev/null +++ b/src/delphi/train/config/models/typed_model_config.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TypedModelConfig: + """ + This is a dummy class for typing purposes. We could make a Union class that we update + every time we add a TypedModelConfig class, but that would mean remembering to go update + another thing when adding a new TypedModelConfig. + """ + + def __init__(self): + raise NotImplementedError( + "TypedModelConfig is a dummy class to provide typing for actual ModelConfig classes. It shouldn't ever be instantiated." 
+ ) diff --git a/src/delphi/train/config/optimizer_config.py b/src/delphi/train/config/optimizer_config.py new file mode 100644 index 00000000..619e9d3c --- /dev/null +++ b/src/delphi/train/config/optimizer_config.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + + +@dataclass +class OptimizerConfig: + # adamw optimizer + gradient_accumulation_steps: int = 4 # used to simulate larger batch sizes + learning_rate: float = 5e-4 # max learning rate + weight_decay: float = 1e-1 + beta1: float = 0.9 + beta2: float = 0.95 + grad_clip: float = 1.0 # clip gradients at this value, or disable if == 0.0 + # learning rate decay settings + decay_lr: bool = True # whether to decay the learning rate + warmup_iters: int = 1000 # how many steps to warm up for + min_lr: float = 0.0 # should be ~learning_rate/10 per Chinchill diff --git a/src/delphi/train/config/utils.py b/src/delphi/train/config/utils.py new file mode 100644 index 00000000..3f3a7e6f --- /dev/null +++ b/src/delphi/train/config/utils.py @@ -0,0 +1,83 @@ +import json +import logging +from pathlib import Path + +from beartype.typing import Any, Iterable +from dacite import from_dict +from platformdirs import user_config_dir + +from delphi.constants import CONFIG_PRESETS_DIR + +from .gigaconfig import GigaConfig + + +def _merge_dicts(merge_into: dict[str, Any], merge_from: dict[str, Any]): + """recursively merge two dicts, with values in merge_from taking precedence""" + for key, val in merge_from.items(): + if ( + key in merge_into + and isinstance(merge_into[key], dict) + and isinstance(val, dict) + ): + _merge_dicts(merge_into[key], val) + else: + merge_into[key] = val + + +def get_preset_paths() -> Iterable[Path]: + return Path(CONFIG_PRESETS_DIR).glob("*.json") # type: ignore + + +def get_user_config_path() -> Path: + _user_config_dir = Path(user_config_dir(appname="delphi")) + _user_config_dir.mkdir(parents=True, exist_ok=True) + user_config_path = _user_config_dir / "config.json" + return user_config_path + + +def get_presets_by_name() -> dict[str, GigaConfig]: + return { + preset.stem: build_config_from_files([preset]) for preset in get_preset_paths() + } + + +def get_config_dicts_from_files(config_files: list[Path]) -> list[dict[str, Any]]: + """loads config files in ascending priority order""" + config_dicts = [] + for config_file in config_files: + logging.info(f"Loading {config_file}") + with open(config_file, "r") as f: + config_dicts.append(json.load(f)) + return config_dicts + + +def combine_configs(configs: list[dict[str, Any]]) -> dict[str, Any]: + # combine configs dicts, with key "priority" setting precendence (higher priority overrides lower priority) + sorted_configs = sorted(configs, key=lambda c: c.get("priority", -999)) + combined_config = dict() + for config in sorted_configs: + _merge_dicts(merge_into=combined_config, merge_from=config) + return combined_config + + +def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]: + configs_in_order = get_config_dicts_from_files(config_files) + combined_config = combine_configs(configs_in_order) + return combined_config + + +def build_config_from_files_and_overrides( + config_files: list[Path], overrides: dict[str, Any] +) -> GigaConfig: + combined_config = build_config_dict_from_files(config_files) + _merge_dicts(merge_into=combined_config, merge_from=overrides) + return from_dict(GigaConfig, combined_config) + + +def build_config_from_files(config_files: list[Path]) -> GigaConfig: + return build_config_from_files_and_overrides(config_files, {}) + + 
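Editor's note: a minimal sketch of how the "priority" key handled by `combine_configs` above decides precedence when several config dicts overlap. The `base` and `debug` dicts below are hypothetical, and the import assumes the `delphi` package from this branch is installed.

```python
from delphi.train.config.utils import combine_configs

# two hypothetical config dicts; the higher "priority" is merged last and therefore wins
base = {"priority": 0, "batch_size": 64, "model_config": {"model_type": "llama2"}}
debug = {"priority": 100, "batch_size": 8}

merged = combine_configs([base, debug])
assert merged["batch_size"] == 8                         # overridden by the debug dict
assert merged["model_config"]["model_type"] == "llama2"  # untouched keys survive
```

Command-line overrides passed to `build_config_from_files_and_overrides` are merged after all config files, so they take precedence over every file.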
+def load_preset(preset_name: str) -> GigaConfig: + preset_path = Path(CONFIG_PRESETS_DIR) / f"{preset_name}.json" # type: ignore + return build_config_from_files([preset_path]) diff --git a/src/delphi/train/config/wandb_config.py b/src/delphi/train/config/wandb_config.py new file mode 100644 index 00000000..716dec0a --- /dev/null +++ b/src/delphi/train/config/wandb_config.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class WandbConfig: + log: bool = False + project: str = "delphi" + entity: str = "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" diff --git a/src/delphi/train/iteration_params.py b/src/delphi/train/iteration_params.py new file mode 100644 index 00000000..8a7e8d88 --- /dev/null +++ b/src/delphi/train/iteration_params.py @@ -0,0 +1,39 @@ +import logging +from dataclasses import dataclass + +from datasets import Dataset + +from .config import GigaConfig + + +@dataclass +class IterationParams: + num_batches: int + num_steps: int + eval_iters: int + lr_decay_iters: int + tokens_per_iter: int + + +def set_iteration_params( + config: GigaConfig, train_ds: Dataset, validation_ds: Dataset +) -> IterationParams: + num_batches = len(train_ds) // config.batch_size + # we take gradient_accumulation_steps batches per step (one in each microstep) + num_steps = num_batches // config.optimizer.gradient_accumulation_steps + eval_iters = min(12, len(validation_ds) // config.batch_size) + lr_decay_iters = ( + config.max_epochs * num_batches + ) # should be ~=max_iters per Chinchilla + tokens_per_iter = ( + config.optimizer.gradient_accumulation_steps + * config.batch_size + * config.max_seq_len + ) + logging.debug(f"tokens per iteration will be: {tokens_per_iter:,}") + logging.debug( + f"breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} max seq len" + ) + return IterationParams( + num_batches, num_steps, eval_iters, lr_decay_iters, tokens_per_iter + ) diff --git a/src/delphi/train/run_context.py b/src/delphi/train/run_context.py new file mode 100644 index 00000000..ced25908 --- /dev/null +++ b/src/delphi/train/run_context.py @@ -0,0 +1,14 @@ +# get contextual information about a training run + +from dataclasses import dataclass + +import torch + + +@dataclass +class RunContext: + device: torch.device + torch_version: str + delphi_version: str + transformers_version: str + os: str diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py new file mode 100644 index 00000000..7b4d8fbe --- /dev/null +++ b/src/delphi/train/train_step.py @@ -0,0 +1,148 @@ +import logging +import time +from collections.abc import Callable, Generator + +import torch +from datasets import Dataset + +from .config import GigaConfig +from .config.models import ModelTypes +from .iteration_params import IterationParams +from .run_context import RunContext +from .utils import EvalData, ModelTrainingState, estimate_loss, get_next_xy, set_lr + + +def train_step( + model_training_state: ModelTrainingState, + train_ds: Dataset, + validation_ds: Dataset, + iteration_params: IterationParams, + eval_callbacks: list[Callable], + config: GigaConfig, + train_batch_iter: Generator, + run_context: RunContext, +): + """ + Runs a training step, updating (mutating in place) model_training_state + returns true if training should break, false otherwise + """ + model = model_training_state.model + optimizer = model_training_state.optimizer + + # here's how each 
train step works: + # 1. Set learning rate + # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint + # 3. forward backward update + # 4. log timing + + # 1. determine and set the learning rate for this iteration + model_training_state.lr = set_lr( + iteration_params.lr_decay_iters, + config, + optimizer, + model_training_state.iter_num, + ) + + # 2. evaluate the loss on train/val sets and write checkpoints + if model_training_state.iter_num % config.eval_interval == 0: + losses = estimate_loss( + model=model, + eval_iters=iteration_params.eval_iters, + batch_size=config.batch_size, + split_to_ds={"train": train_ds, "val": validation_ds}, + device=run_context.device, + epoch=model_training_state.epoch, + ) + new_best_val_loss = False + if losses["val"] < model_training_state.best_val_loss: + model_training_state.best_val_loss = float(losses["val"]) + new_best_val_loss = True + eval_data = EvalData( + tokens_per_iter=iteration_params.tokens_per_iter, + losses=losses, + new_best_val_loss=new_best_val_loss, + config=config, + model_training_state=model_training_state, + run_context=run_context, + ) + logging.info( + f"step {model_training_state.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}" + ) + for callback in eval_callbacks: + callback(eval_data) + + # 3. forward backward update, with optional gradient accumulation to simulate larger batch size + logging.info( + f"gradient accumulation steps: {config.optimizer.gradient_accumulation_steps}, " + f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}" + ) + for micro_step in range(config.optimizer.gradient_accumulation_steps): + X, Y = get_next_xy(train_batch_iter, run_context.device) + loss = ( + model(X, labels=Y, return_dict=True).loss + / config.optimizer.gradient_accumulation_steps + ) + loss.backward() + # clip the gradient + if config.grad_clip != 0.0: + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) # type: ignore + optimizer.step() + + # flush the gradients as soon as we can, no need for this memory anymore + optimizer.zero_grad(set_to_none=True) + + # 4. log timing + t1 = time.time() + dt = t1 - model_training_state.t0 + model_training_state.t0 = t1 + if model_training_state.iter_num % config.log_interval == 0: + # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point + lossf = loss.item() * config.optimizer.gradient_accumulation_steps + if ( + model_training_state.local_iter_num >= 5 + ): # let the training loop settle a bit + mfu = estimate_mfu( + config=config, model=model_training_state.model, timedelta=dt + ) + model_training_state.running_mfu = ( + mfu + if model_training_state.running_mfu == -1.0 + else 0.9 * model_training_state.running_mfu + 0.1 * mfu + ) + logging.debug( + ( + f"{model_training_state.iter_num} | loss {lossf:.4f} | lr {model_training_state.lr:e} | " + f"{dt*1000:.2f}ms | mfu {model_training_state.running_mfu*100:.2f}%" + ) + ) + model_training_state.iter_num += 1 + model_training_state.local_iter_num += 1 + + +def estimate_mfu(config: GigaConfig, model: torch.nn.Module, timedelta: float) -> float: + """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" + # first estimate the number of flops we do per iteration. 
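+    # (editor's note) the formula used below combines a parameter term and an attention term:
+    #   flops_per_token = 6*N + 12*L*H*Q*T
+    # 6*N counts the forward+backward matmul FLOPs over the N parameters, and 12*L*H*Q*T adds
+    # the attention FLOPs per token for a context of length T; multiplying by T again gives the
+    # FLOPs for one full forward+backward pass over a sequence (flops_per_fwdbwd).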
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 + N = sum(p.numel() for p in model.parameters()) + if config.model_config.model_type == ModelTypes.LLAMA2: + cfg = model.config + L, H, Q, T = ( + cfg.num_hidden_layers, + cfg.num_attention_heads, + cfg.hidden_size // cfg.num_attention_heads, + cfg.max_position_embeddings, + ) + else: + logging.warn( + f"estimate_mfu not implemented for {config.model_config.model_type}, setting MFU to -1" + ) + return -1.0 + flops_per_token = 6 * N + 12 * L * H * Q * T + flops_per_fwdbwd = flops_per_token * T + fwdbwd_per_iter = config.batch_size * config.optimizer.gradient_accumulation_steps + flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter + # express our flops throughput as ratio of A100 bfloat16 peak flops + flops_achieved = flops_per_iter * (1.0 / timedelta) # per second + flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS + mfu = flops_achieved / flops_promised + return mfu diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py new file mode 100644 index 00000000..ab189f14 --- /dev/null +++ b/src/delphi/train/training.py @@ -0,0 +1,96 @@ +import logging +import os +from dataclasses import fields +from typing import cast + +import torch +from datasets import Dataset +from tqdm import tqdm +from transformers import __version__ as transformers_version + +from delphi import __version__ as delphi_version + +from . import wandb_utils +from .config import GigaConfig +from .iteration_params import set_iteration_params +from .run_context import RunContext +from .train_step import train_step +from .utils import ( + ModelTrainingState, + batch_generator, + get_device, + initialize_model_training_state, + load_delphi_training_dataset, + save_checkpoint_if_needed, +) + + +def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: + logging.info("Starting training...") + logging.debug("Setting torch.use_deterministic_algorithms(True)") + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + torch.manual_seed(config.torch_seed) + logging.info("Config:") + for field in fields(config): + logging.info(f" {field.name}: {getattr(config, field.name)}") + # system + run_context = RunContext( + device=get_device(config.device), + torch_version=torch.__version__, + delphi_version=delphi_version, + transformers_version=transformers_version, + os=os.uname().version, + ) + logging.debug(f"Run context: {run_context}") + + # load data + logging.debug("Loading data...") + train_ds = cast( + Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit) + ) + validation_ds = cast( + Dataset, + load_delphi_training_dataset("validation", limit=config.val_sample_limit), + ) + + # derive iteration params (num_batches, num_steps, etc) + iteration_params = set_iteration_params(config, train_ds, validation_ds) + + # setup + logging.info("Setting up...") + os.makedirs(config.output_dir, exist_ok=True) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + + # model init + model_training_state = initialize_model_training_state(config, run_context.device) + + # setup eval callbacks + eval_callbacks = [save_checkpoint_if_needed] + if config.wandb_config.log: + wandb_utils.init_wandb(config) + eval_callbacks.append(wandb_utils.log_to_wandb) + + # training loop + logging.info("Starting training...") + for epoch in range(config.max_epochs): + train_batch_iter = iter( + batch_generator( + 
train_ds, config.batch_size, epoch, config.batch_ordering_seed + ) + ) + model_training_state.epoch = epoch + for step in tqdm(range(iteration_params.num_steps)): + model_training_state.step = step + train_step( + model_training_state, + train_ds, + validation_ds, + iteration_params, + eval_callbacks, + config, + train_batch_iter, + run_context, + ) + return model_training_state, run_context diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py new file mode 100644 index 00000000..300576e9 --- /dev/null +++ b/src/delphi/train/utils.py @@ -0,0 +1,319 @@ +import json +import logging +import math +import os +import time +from collections.abc import Generator +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import cast + +import safetensors.torch as st +import torch +from datasets import Dataset +from huggingface_hub import HfApi +from torch.optim import AdamW +from transformers import PreTrainedModel + +from delphi import constants +from delphi.eval.utils import load_delphi_dataset + +from .config.gigaconfig import GigaConfig +from .config.models import ModelTypes, get_delphi_config +from .config.models.model_config import ModelConfig +from .run_context import RunContext +from .shuffle import shuffle_list + + +@dataclass +class ModelTrainingState: + model: torch.nn.Module + optimizer: torch.optim.Optimizer + iter_num: int + local_iter_num: int + best_val_loss: float + running_mfu: float + t0: float + epoch: int + step: int + lr: float = 1.0e-5 + + +@dataclass +class EvalData: + # values we expose to eval callback functions + tokens_per_iter: int + losses: dict[str, float] + new_best_val_loss: bool + config: GigaConfig + model_training_state: ModelTrainingState + run_context: RunContext + + +def get_device(device_str: str = "auto") -> torch.device: + """ + Get torch device specified by device_str. May pass "auto" to set torch device automatically. 
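+    With "auto", the resolution order is: CUDA if available, else MPS (Apple silicon), else CPU.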
+ """ + # cuda if available; else mps if apple silicon; else cpu + if device_str == "auto": + if torch.cuda.is_available(): + device_str = "cuda" + elif torch.backends.mps.is_available(): + device_str = "mps" + else: + device_str = "cpu" + return torch.device(device_str) + + +def get_optimizer( + model: torch.nn.Module, + config: GigaConfig, + output_dir=None, + device: torch.device = torch.device("cpu"), +) -> AdamW: + optimizer = AdamW( + lr=config.optimizer.learning_rate, + params=model.parameters(), + weight_decay=config.optimizer.weight_decay, + betas=(config.optimizer.beta1, config.optimizer.beta2), + ) + if output_dir is not None: + opt_path = os.path.join(output_dir, "opt.pt") + with open(opt_path, "rb") as f: + optimizer.load_state_dict(torch.load(f)) + return optimizer + + +def get_lr( + iter_num: int, + warmup_iters: int, + learning_rate: float, + lr_decay_iters: int, + min_lr: float, +): + # 1) linear warmup for warmup_iters steps + if iter_num < warmup_iters: + return learning_rate * iter_num / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if iter_num > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (iter_num - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +def set_lr( + lr_decay_iters: int, + config: GigaConfig, + optimizer: torch.optim.Optimizer, + iter_num: int, +): + lr = ( + get_lr( + iter_num=iter_num, + warmup_iters=config.optimizer.warmup_iters, + learning_rate=config.optimizer.learning_rate, + lr_decay_iters=lr_decay_iters, + min_lr=config.optimizer.min_lr, + ) + if config.optimizer.decay_lr + else config.optimizer.learning_rate + ) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + return lr + + +def save_checkpoint_if_needed(eval_data: EvalData): + mts = eval_data.model_training_state + # we save if it's not the first iter AND at least one of: + # 1) we have a new best validation loss + # 2) always_save_checkpoint is set + if mts.iter_num == 0: + return + if (not eval_data.new_best_val_loss) and ( + not eval_data.config.always_save_checkpoint + ): + return + results_path = os.path.join(eval_data.config.output_dir, f"iter_{mts.iter_num:06d}") + logging.info(f"saving checkpoint to {results_path}") + save_results( + config=eval_data.config, + train_results=mts, + run_context=eval_data.run_context, + results_path=results_path, + ) + + +def load_model_from_checkpoint(config: GigaConfig, output_dir: str) -> torch.nn.Module: + model = config_to_model(config.model_config) + st.load_model(model, os.path.join(output_dir, "model", "model.safetensors")) + return model + + +def config_to_model(config: ModelConfig) -> PreTrainedModel: + # get ModelType object from name ('llama2' -> ModelType(...)) + delphi_config = get_delphi_config(config) + model_type = ModelTypes.get(config.model_type) + return model_type.model(model_type.config(**asdict(delphi_config))) + + +def initialize_model_training_state( + config: GigaConfig, device: torch.device +) -> ModelTrainingState: + t0 = time.time() + training_state = None + if config.init_from == "scratch": + # init a new model from scratch + logging.debug("Initializing a new model from scratch") + model = config_to_model(config.model_config) + checkpoint = None + # TODO: resume from huggingface model + elif config.init_from == "resume": + logging.info(f"Resuming training 
from {config.output_dir}") + checkpoint = config.output_dir + model = load_model_from_checkpoint(config, checkpoint) + with open(os.path.join(checkpoint, "training_state.json"), "r") as f: + training_state = json.load(f) + model.to(device) # type: ignore + # optimizer + optimizer = get_optimizer( + model=model, + config=config, + output_dir=config.output_dir + if (Path(config.output_dir) / "opt.safetensors").exists() + else None, + device=device, + ) + epoch = training_state.get("epoch", 0) if training_state is not None else 0 + step = training_state.get("step", 0) if training_state is not None else 0 + best_val_loss = training_state.get("best_val_loss", 1e9) if training_state else 1e9 + iter_num = training_state.get("iter_num", 0) if training_state else 0 + local_iter_num = training_state.get("local_iter_num", 0) if training_state else 0 + running_mfu = training_state.get("running_mfu", 0.0) if training_state else -1.0 + checkpoint = None # free up memory + return ModelTrainingState( + model=model, + optimizer=optimizer, + iter_num=iter_num, + local_iter_num=local_iter_num, + best_val_loss=best_val_loss, + running_mfu=running_mfu, + t0=t0, + epoch=epoch, + step=step, + ) + + +def load_delphi_training_dataset(split: str, limit: int = -1): + """For training, we want (X, Y) pairs, where X is a chunk of text and Y is the next token.) + To construct this, we take the original tokenized dataset, break it into max_seq_len+1 length chunks, + and then take [:-1] as X and [1:] as Y. + """ + if limit == -1: + ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, split) + else: + ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, split).select( + range(limit) + ) + ds.set_format("torch") + return ds + + +def get_next_xy( + train_batch_iter: Generator, + device: torch.device + # train_batch_iter: Generator[dict[str, list[int]], None, None], device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + data = next(train_batch_iter).to(device) + X, Y = data[:, :-1], data[:, 1:] + return X, Y + + +def batch_generator( + dataset: Dataset, batch_size: int, epoch: int, ordering_seed: int +) -> Generator[torch.Tensor, None, None]: + sampler = list(range(len(dataset))) # type: ignore + shuffle_list(sampler, seed=ordering_seed + epoch) + sampler = torch.Tensor(sampler) + for samples in sampler.split(batch_size): + yield dataset[samples]["tokens"] + + +@torch.no_grad() +def estimate_loss( + model: torch.nn.Module, + eval_iters: int, + batch_size: int, + split_to_ds: dict[str, Dataset], + device: torch.device, + epoch: int, +) -> dict[str, float]: + """helps estimate an arbitrarily accurate loss over either split using many batches""" + out = {} + model.eval() + for split, ds in split_to_ds.items(): + batch_iter = iter(batch_generator(ds, batch_size, epoch, 1234)) + losses = torch.zeros(eval_iters) # keep on CPU + for k in range(min(eval_iters, len(ds) // batch_size)): # type: ignore + X, Y = get_next_xy(batch_iter, device) + loss = model(X, labels=Y, return_dict=True).loss + losses[k] = loss.item() + out[split] = losses.mean() + model.train() + return out + + +def upload_to_huggingface(eval_data: EvalData): + model = eval_data.model_training_state.model + if isinstance(model, PreTrainedModel): + model = cast(PreTrainedModel, model) + model.save_pretrained(eval_data.config.output_dir) + + +def save_results( + config: GigaConfig, + train_results: ModelTrainingState, + run_context: RunContext, + results_path: str, +): + os.makedirs(results_path, exist_ok=True) + with 
open(os.path.join(results_path, "config.json"), "w") as file: + json.dump(asdict(config), file, indent=2) + model = train_results.model + if isinstance(model, PreTrainedModel): + model = cast(PreTrainedModel, model) + model.save_pretrained( + save_directory=os.path.join(results_path, "model"), + ) + else: + st.save_model( + train_results.model, + os.path.join(results_path, "model", "model.safetensors"), + ) + with open(os.path.join(results_path, "opt.pt"), "wb") as f: + torch.save(train_results.optimizer.state_dict(), f) + with open(os.path.join(results_path, "training_state.json"), "w") as file: + training_state_dict = { + "iter_num": train_results.iter_num, + "local_iter_num": train_results.local_iter_num, + "best_val_loss": train_results.best_val_loss, + "running_mfu": train_results.running_mfu, + "lr": train_results.lr, + "epoch": train_results.epoch, + "step": train_results.step, + } + json.dump(training_state_dict, file, indent=2) + with open(os.path.join(results_path, "run_context.json"), "w") as file: + run_context_dict = asdict(run_context) + run_context_dict["device"] = str(run_context.device) + json.dump(run_context_dict, file, indent=2) + if config.huggingface.push_checkpoints_to_hub: + api = HfApi() + api.upload_folder( + folder_path=results_path, + repo_id=str(config.huggingface.repo_id), + path_in_repo=f"iter_{train_results.iter_num}/", + ) diff --git a/src/delphi/train/wandb_utils.py b/src/delphi/train/wandb_utils.py new file mode 100644 index 00000000..89ed5518 --- /dev/null +++ b/src/delphi/train/wandb_utils.py @@ -0,0 +1,43 @@ +import logging +import os +from dataclasses import asdict + +import wandb + +from .config import GigaConfig +from .utils import EvalData + + +def silence_wandb(): + # set env var WANDB_SILENT=true + os.environ["WANDB_SILENT"] = "true" + + +def init_wandb(config: GigaConfig): + # if log level < debug, silence wandb + if logging.getLogger().level > logging.INFO: + silence_wandb() + wandb.init( + entity=config.wandb_config.entity, + project=config.wandb_config.project, + name=config.run_name, + config=asdict(config), + ) + + +def log_to_wandb(eval_data: EvalData): + mts = eval_data.model_training_state + try: + wandb.log( + { + "iter": mts.iter_num, + "tokens": mts.iter_num * eval_data.tokens_per_iter, + "loss/train": eval_data.losses["train"], + "loss/val": eval_data.losses["val"], + "lr": mts.lr, + "mfu": mts.running_mfu * 100, # convert to percentage + }, + step=mts.iter_num, + ) + except Exception as e: + logging.error(f"logging to wandb failed: {e}") diff --git a/tests/train/test_wandb_utils.py b/tests/train/test_wandb_utils.py new file mode 100644 index 00000000..093fd0c5 --- /dev/null +++ b/tests/train/test_wandb_utils.py @@ -0,0 +1,99 @@ +import os +from dataclasses import asdict +from unittest.mock import MagicMock, patch + +import pytest +import torch +from dacite import from_dict + +from delphi.train.config import GigaConfig +from delphi.train.config.models import TypedLlamaConfig +from delphi.train.run_context import RunContext +from delphi.train.utils import EvalData, initialize_model_training_state +from delphi.train.wandb_utils import init_wandb, log_to_wandb, silence_wandb + + +@pytest.fixture +def mock_giga_config(): + config = from_dict( + GigaConfig, + { + "run_name": "test_run", + "device": "cpu", + "model_config": { + "model_type": "llama2", + "llama2": asdict(TypedLlamaConfig()), + }, + "wandb_config": { + "log": True, + "entity": "test_entity", + "project": "test_project", + }, + }, + ) + return config + + +@pytest.fixture +def 
mock_model_training_state(mock_giga_config): + device = torch.device(mock_giga_config.device) + # this is gross and horrible, sorry, I'm rushing + mts = initialize_model_training_state(config=mock_giga_config, device=device) + mts.step = 1 + mts.epoch = 1 + mts.iter_num = 1 + mts.lr = 0.001 + mts.running_mfu = 3.0 + return mts + + +@pytest.fixture +def mock_eval_data(mock_giga_config, mock_model_training_state): + eval_data = EvalData( + model_training_state=mock_model_training_state, + tokens_per_iter=1000, + losses={"train": 0.5, "val": 0.4}, + new_best_val_loss=False, + config=mock_giga_config, + run_context=RunContext( + device=torch.device("cpu"), + torch_version="-1", + delphi_version="-1", + os="test", + transformers_version="-1", + ), + ) + return eval_data + + +@patch.dict("os.environ", {}, clear=True) +def test_silence_wandb(): + silence_wandb() + assert os.environ["WANDB_SILENT"] == "true" + + +@patch("wandb.init") +def test_init_wandb(mock_wandb_init: MagicMock, mock_giga_config): + init_wandb(mock_giga_config) + mock_wandb_init.assert_called_once_with( + entity="test_entity", + project="test_project", + name="test_run", + config=asdict(mock_giga_config), + ) + + +@patch("wandb.log") +def test_log_to_wandb(mock_wandb_log, mock_eval_data): + log_to_wandb(mock_eval_data) + mock_wandb_log.assert_called_once_with( + { + "iter": 1, + "tokens": 1000, + "loss/train": 0.5, + "loss/val": 0.4, + "lr": 0.001, + "mfu": 300.0, + }, + step=1, + )
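
Reviewer note (not part of the diff above): the warmup-plus-cosine schedule in `get_lr` has closed-form values at a few points, and the PR currently only adds tests for the wandb utilities. Below is a minimal sketch of what a schedule test could look like, assuming `get_lr` remains importable from `delphi.train.utils` as in this diff; the file placement `tests/train/test_utils.py` is a hypothetical suggestion, not something this PR adds.

# tests/train/test_utils.py (hypothetical placement)
import pytest

from delphi.train.utils import get_lr


def test_get_lr_warmup_cosine_schedule():
    kwargs = dict(warmup_iters=100, learning_rate=1e-3, lr_decay_iters=1000, min_lr=1e-4)
    # linear warmup: halfway through warmup yields half the peak learning rate
    assert get_lr(iter_num=50, **kwargs) == pytest.approx(5e-4)
    # past lr_decay_iters the schedule clamps to min_lr
    assert get_lr(iter_num=2000, **kwargs) == pytest.approx(1e-4)
    # cosine midpoint: halfway between min_lr and learning_rate
    assert get_lr(iter_num=550, **kwargs) == pytest.approx(5.5e-4)

The three checkpoints (mid-warmup, past lr_decay_iters, cosine midpoint) have exact expected values, so no tolerance tuning is needed beyond pytest.approx's default.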