diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 397ae1dd..34a73eea 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -6,7 +6,7 @@ on:
- main
pull_request:
branches:
- - '*'
+ - "*"
permissions:
actions: write
@@ -17,10 +17,19 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
+ with:
+ submodules: recursive
- name: setup python
uses: actions/setup-python@v5
with:
python-version: "3.10"
+ cache: "pip"
+ - name: cache models and datasets
+ uses: actions/cache@v3
+ with:
+ path: |
+ ~/.cache/huggingface
+ key: ${{ runner.os }}-huggingface-cache-v1 # increment this key to invalidate the cache when new models/datasets are added
- name: dependencies
run: |
python -m pip install --upgrade pip
@@ -31,4 +40,4 @@ jobs:
- name: isort
run: isort --check .
- name: pytest
- run: pytest
\ No newline at end of file
+ run: pytest
diff --git a/.gitignore b/.gitignore
index 68bc17f9..0bce421a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,11 @@ __pycache__/
# C extensions
*.so
+bin
+include
+lib64
+pyvenv.cfg
+
# Distribution / packaging
.Python
build/
@@ -158,3 +163,15 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+
+# ignore wandb files
+**/wandb/*
+**/*.wandb
+**/wandb-summary.json
+**/wandb-metadata.json
+
+# scratch notebook
+notebooks/scratch.ipynb
+
+# dsstore
+.DS_Store
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index 1bdff169..009900cb 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
-[submodule "src/delphi/train/llama2c"]
- path = src/delphi/train/llama2c
+[submodule "src/llama2c"]
+ path = src/llama2c
url = https://github.com/delphi-suite/llama2.c.git
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 84fc89ad..fbe11bd0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,4 +8,3 @@ repos:
rev: 5.13.2
hooks:
- id: isort
- name: isort (python)
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 934a143e..5a69a6b6 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -7,4 +7,5 @@
"source.organizeImports": "explicit"
},
"python.analysis.typeCheckingMode": "basic",
+ "black-formatter.importStrategy": "fromEnvironment",
}
\ No newline at end of file
diff --git a/README.md b/README.md
index c796115a..105d7ca7 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,24 @@
# Delphi
+
Interpreting Small Language Models Across Time and Scale
-# setup
-1. make python 3.10 virtual env in `.venv`
-2. install dependencies `pip install -r requirements.txt`
-3. install the project in editable state `pip install -e .`
-4. run tests `pytest`
+# Setup
+
+1. Clone this repo and submodules: `git clone https://github.com/delphi-suite/delphi.git --recurse-submodules`
+2. Make a Python 3.10 virtual env in `.venv`
+3. Install dependencies: `pip install -r requirements.txt`
+4. Install the project in editable state: `pip install -e .`
+5. Run the tests: `pytest`
+
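+For reference, the full sequence looks something like this (assuming `python3.10` is available on your PATH):
+
+```bash
+git clone https://github.com/delphi-suite/delphi.git --recurse-submodules
+cd delphi
+python3.10 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+pip install -e .
+pytest
+```
+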
+## Submodule Setup
+
+If you cloned without `--recurse-submodules`, you can still initialize and fetch the submodules later with:
+```bash
+git submodule init
+git submodule update
+```
+
+# Formatting
+
-# formatting
We're using black & isort to format the code. To make sure your changes adhere to the rules:
1. follow setup instructions above
@@ -16,24 +27,25 @@ We're using black & isort to format the code. To make sure your changes adhere t
When you save a file vscode should automatically format it. Otherwise, pre-commit will do that, but you will need to add the changes and commit again.
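+
+To run the formatters manually over the whole repo, something like:
+
+```bash
+black .
+isort .
+```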
-# pull requests
+# Pull Requests
1. make a branch
- - if it relates to an existing issue
- - go to the issue page and click *Create a branch* under *Development*
- - if the default name is not very long, keep it; otherwise, make it shorter, but keep the issue number in the front
- - otherwise pick a short but descriptive name, a few hyphen-separated-words
+ - if it relates to an existing issue
+ - go to the issue page and click _Create a branch_ under _Development_
+ - if the default name is not very long, keep it; otherwise, make it shorter, but keep the issue number in the front
+     - otherwise pick a short but descriptive name, a few hyphen-separated words
2. make your changes
- - include unit tests
- - update README if needed
+ - include unit tests
+ - update README if needed
+ - if new huggingface datasets/models are added to testing, increment the cache number in `.github/workflows/checks.yml`
3. make a pull request
- - if it isn't ready for review yet, mark it as draft
- - check if CI is passing
- - if the change is big, try to keep the commit history clean using interactive rebase
- - don't push more often than it's needed, we're running github actions on a free tier
- - if there were any changes to the main branch, rebase on top of it
- - explain the change
- - provide short description; focus on things that were not mentioned in the relevant issue
- - comment important sections of the code in *Files changed* tab
- - when it's ready, add the relevant stakeholders as reviewers
-4. after the comments are resolved and PR is approved, merge it using *Squash and merge*
\ No newline at end of file
+ - if it isn't ready for review yet, mark it as draft
+ - check if CI is passing
+ - if the change is big, try to keep the commit history clean using interactive rebase
+    - don't push more often than needed, we're running GitHub Actions on a free tier
+ - if there were any changes to the main branch, rebase on top of it
+ - explain the change
+        - provide a short description; focus on things that were not mentioned in the relevant issue
+ - comment important sections of the code in _Files changed_ tab
+ - when it's ready, add the relevant stakeholders as reviewers
+4. after the comments are resolved and PR is approved, merge it using _Squash and merge_
diff --git a/notebooks/end2end_demo.ipynb b/notebooks/end2end_demo.ipynb
new file mode 100644
index 00000000..f08aba38
--- /dev/null
+++ b/notebooks/end2end_demo.ipynb
@@ -0,0 +1,133 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from typing import cast\n",
+ "import pickle\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "from datasets import load_dataset, Dataset\n",
+ "\n",
+ "from delphi.constants import STATIC_ASSETS_DIR\n",
+ "from delphi.eval import utils\n",
+ "from delphi.eval import constants\n",
+ "from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
+ "\n",
+ "# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
+ "from delphi.eval.token_labelling import TOKEN_LABELS"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load data\n",
+ "tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
+ "\n",
+ "# TODO: convert to use static paths\n",
+ "# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
+ "# token_groups = pickle.load(f)\n",
+ "# model_group_stats = calc_model_group_stats(\n",
+ "# tokenized_corpus_dataset, logprob_datasets, token_groups, token_groups[0].keys()\n",
+ "# )\n",
+ "with open(f\"{STATIC_ASSETS_DIR}/model_group_stats.pkl\", \"rb\") as f:\n",
+ " model_group_stats = pickle.load(f)\n",
+ "\n",
+ "logprob_datasets = utils.load_logprob_datasets(\"validation\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Visualization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0f8846898fbb4a1b9e872ff6511acd3d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "performance_data = defaultdict(dict)\n",
+ "for model in constants.LLAMA2_MODELS:\n",
+ " for token_group_desc in TOKEN_LABELS:\n",
+ " if (model, token_group_desc) not in model_group_stats:\n",
+ " continue\n",
+ " stats = model_group_stats[(model, token_group_desc)]\n",
+ " performance_data[model][token_group_desc] = (\n",
+ " -stats[\"median\"],\n",
+ " -stats[\"75th\"],\n",
+ " -stats[\"25th\"],\n",
+ " )\n",
+ "\n",
+ "visualize_per_token_category(\n",
+ " performance_data,\n",
+ " log_scale=True,\n",
+ " bg_color=\"LightGrey\",\n",
+ " line_color=\"Red\",\n",
+ " marker_color=\"Orange\",\n",
+ " bar_color=\"Green\",\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "tinyevals",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb
new file mode 100644
index 00000000..198057c7
--- /dev/null
+++ b/notebooks/per_token_plot.ipynb
@@ -0,0 +1,100 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "696575431f65420e9dc22c3b3476bfbb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from collections import defaultdict\n",
+ "import math\n",
+ "import random\n",
+ "import numpy as np\n",
+ "\n",
+ "from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
+ "\n",
+ "\n",
+ "random.seed(0)\n",
+ "\n",
+ "# generate mock data\n",
+ "model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n",
+ "categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n",
+ "entries = [200, 100, 150, 300]\n",
+ "performance_data = defaultdict()\n",
+ "for i, model in enumerate(model_names):\n",
+ " performance_data[model] = defaultdict()\n",
+ " for cat in categories:\n",
+ " x = [math.log2(random.random()) for _ in range(entries[i])]\n",
+ " means = np.mean(x)\n",
+ " err_low = means - np.percentile(x, 25)\n",
+ " err_hi = np.percentile(x, 75) - means\n",
+ " performance_data[model][cat] = (-means, err_low, err_hi)\n",
+ "\n",
+ "\n",
+ "visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cb3af5248a4a40118c36a527c927289d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "visualize_per_token_category(performance_data, log_scale=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/token_labelling.ipynb b/notebooks/token_labelling.ipynb
new file mode 100644
index 00000000..45423d8c
--- /dev/null
+++ b/notebooks/token_labelling.ipynb
@@ -0,0 +1,435 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Giving tokens a label - How to categorize tokens\n",
+ "\n",
+ "\n",
+ "The first part of this notebook explains how to label tokens and how the labelling functions work.\n",
+ "\n",
+ "The second part shows how all tokens used by our delphi language models are labelled.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "# autoreload\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from pprint import pprint \n",
+ "\n",
+ "import spacy\n",
+ "from tqdm.auto import tqdm\n",
+ "\n",
+ "import delphi\n",
+ "\n",
+ "from delphi.eval import token_labelling"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "# 1) How to use the token labelling functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We analyze a simple sentence and receive the respective tokens with their analyzed attributes. \n",
+ "The grammatical/linguistic analysis is done by a model provided by spaCy for the English language."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Peter \t PROPN \t nsubj \t PERSON\n",
+ "is \t AUX \t ROOT \t \n",
+ "a \t DET \t det \t \n",
+ "person \t NOUN \t attr \t \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load the english model\n",
+ "nlp = spacy.load(\"en_core_web_sm\")\n",
+ "\n",
+ "# Create a Doc object from a given text\n",
+ "doc = nlp(\"Peter is a person\")\n",
+ "\n",
+ "token = doc[0]\n",
+ "for tok in doc:\n",
+ " print(tok,\"\\t\", tok.pos_, \"\\t\", tok.dep_, \"\\t\", tok.ent_type_)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's get the label for our custom token that we just printed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'Capitalized': True,\n",
+ " 'Is Adjective': False,\n",
+ " 'Is Adposition': False,\n",
+ " 'Is Adverb': False,\n",
+ " 'Is Auxiliary': False,\n",
+ " 'Is Coordinating conjuction': False,\n",
+ " 'Is Determiner': False,\n",
+ " 'Is Interjunction': False,\n",
+ " 'Is Named Entity': True,\n",
+ " 'Is Noun': False,\n",
+ " 'Is Numeral': False,\n",
+ " 'Is Other': False,\n",
+ " 'Is Particle': False,\n",
+ " 'Is Pronoun': False,\n",
+ " 'Is Proper Noun': True,\n",
+ " 'Is Punctuation': False,\n",
+ " 'Is Subordinating conjuction': False,\n",
+ " 'Is Symbol': False,\n",
+ " 'Is Verb': False,\n",
+ " 'Starts with space': False}\n"
+ ]
+ }
+ ],
+ "source": [
+ "from delphi.eval import token_labelling\n",
+ "\n",
+ "label = token_labelling.label_single_token(token)\n",
+ "pprint(label)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's get an understanding of what the labels actually mean.\n",
+ "Use this function to receive an explanation for a single token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------- Explanation of token labels --------\n",
+ "Token text: Peter\n",
+ "Token dependency: nominal subject\n",
+ "Token POS: proper noun\n",
+ "---------------- Token labels ---------------\n",
+ " 0 Starts with space False\n",
+ " 1 Capitalized True\n",
+ " 2 Is Adjective False\n",
+ " 3 Is Adposition False\n",
+ " 4 Is Adverb False\n",
+ " 5 Is Auxiliary False\n",
+ " 6 Is Coordinating conjuction False\n",
+ " 7 Is Determiner False\n",
+ " 8 Is Interjunction False\n",
+ " 9 Is Noun False\n",
+ " 10 Is Numeral False\n",
+ " 11 Is Particle False\n",
+ " 12 Is Pronoun False\n",
+ " 13 Is Proper Noun True\n",
+ " 14 Is Punctuation False\n",
+ " 15 Is Subordinating conjuction False\n",
+ " 16 Is Symbol False\n",
+ " 17 Is Verb False\n",
+ " 18 Is Other False\n",
+ " 19 Is Named Entity True\n"
+ ]
+ }
+ ],
+ "source": [
+ "token_labelling.explain_token_labels(token)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you are interested in all the possible labels that spaCy can assign to a token, call the same function without any argument:\n",
+ "```Python\n",
+ ">>> token_labelling.explain_token_labels()\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Batched token labelling\n",
+ "Next, let us analyze a batch of sentences and have them labelled.\n",
+ "> In the example below the input sentences are not yet tokenized, so spaCy uses its internal tokenizer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Token: Peter\n",
+ "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n",
+ "False | True | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True \n",
+ "---\n",
+ "Token: is\n",
+ "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n",
+ "False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False \n",
+ "---\n",
+ "Token: a\n",
+ "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n",
+ "False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False \n",
+ "---\n",
+ "Token: person\n",
+ "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n",
+ "False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False \n",
+ "---\n",
+ "Token: .\n",
+ "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n",
+ "False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False \n",
+ "---\n",
+ "\n",
+ "\n",
+ "5\n",
+ "[{'Starts with space': False, 'Capitalized': True, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': True, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': True}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': True, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': True, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': True, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': True, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "sentences = [\n",
+ " \"Peter is a person.\"\n",
+ "]\n",
+ "labels = token_labelling.label_batch_sentences(sentences, tokenized=False, verbose=True)\n",
+ "\n",
+ "print(len(labels[0]))\n",
+ "print(labels[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now with our own tokenization, e.g. the one used for our TinyStories models.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5\n",
+ "[{'Starts with space': False, 'Capitalized': True, 'Is Noun': True, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': False, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': True, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': False, 'Is Pronoun': False, 'Is Adjective': True, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': True, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': False, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "sentences = [\n",
+ " [\"This \", \"is \", \"a \", \"sentence\", \".\"]\n",
+ "]\n",
+ "labelled_sentences = token_labelling.label_batch_sentences(sentences, tokenized=True, verbose=False)\n",
+ "\n",
+ "print(len(labelled_sentences[0]))\n",
+ "print(labelled_sentences[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 2) Labelling all tokens in the dataset\n",
+ "\n",
+ "Now we want to label all the tokens that our tokenizer knows - its entire vocabulary.\n",
+ "\n",
+ "Using the script in `scripts/label_all_tokens.py` we get the files:\n",
+ "- `src/delphi/eval/all_tokens_list.txt`\n",
+ "- `src/delphi/eval/labelled_token_ids_dict.pkl`\n",
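+ "\n",
+ "For reference, the command is roughly (using the script's default model name):\n",
+ "```bash\n",
+ "python scripts/label_all_tokens.py --model_name delphi-suite/delphi-llama2-100k\n",
+ "```\n",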
+ "\n",
+ "Let's load the tokenizer so that we can look at the labelled tokens.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\joshu\\anaconda3\\envs\\delphi2\\lib\\site-packages\\transformers\\utils\\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
+ " _torch_pytree._register_pytree_node(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The vocab size is: 4096\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Get all the tokens of the tokenizer\n",
+ "from transformers import AutoTokenizer, PreTrainedTokenizer\n",
+ "\n",
+ "\n",
+ "# Decode a sentence\n",
+ "def decode(tokenizer: PreTrainedTokenizer, token_ids: list[int]) -> str:\n",
+ " return tokenizer.decode(token_ids, skip_special_tokens=True)\n",
+ "\n",
+ "model = \"delphi-suite/delphi-llama2-100k\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model)\n",
+ "vocab_size = tokenizer.vocab_size\n",
+ "print(\"The vocab size is:\", vocab_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load the pickle."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pickle\n",
+ "path = \"../src/delphi/eval/labelled_token_ids_dict.pkl\"\n",
+ "# load \n",
+ "with open(path, \"rb\") as f:\n",
+ " labelled_token_ids_dict = pickle.load(f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Look at some random tokens and their labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The token id is: 1143\n",
+ "The decoded token is: has\n",
+ "The label is:\n",
+ "{'Capitalized': False,\n",
+ " 'Is Adjective': False,\n",
+ " 'Is Adposition': False,\n",
+ " 'Is Adverb': False,\n",
+ " 'Is Auxiliary': False,\n",
+ " 'Is Coordinating conjuction': False,\n",
+ " 'Is Determiner': False,\n",
+ " 'Is Interjunction': True,\n",
+ " 'Is Named Entity': False,\n",
+ " 'Is Noun': False,\n",
+ " 'Is Numeral': False,\n",
+ " 'Is Other': False,\n",
+ " 'Is Particle': False,\n",
+ " 'Is Pronoun': False,\n",
+ " 'Is Proper Noun': False,\n",
+ " 'Is Punctuation': False,\n",
+ " 'Is Subordinating conjuction': False,\n",
+ " 'Is Symbol': False,\n",
+ " 'Is Verb': False,\n",
+ " 'Starts with space': False}\n"
+ ]
+ }
+ ],
+ "source": [
+ "import random\n",
+ "from pprint import pprint\n",
+ "# Get a random token id between 0 and 4000\n",
+ "token_id = random.randint(0, 4000)\n",
+ "# decode the token id\n",
+ "decoded_token = decode(tokenizer, [token_id])\n",
+ "# get the corresponding label\n",
+ "label = labelled_token_ids_dict[token_id]\n",
+ "# print the results\n",
+ "print(\"The token id is:\", token_id)\n",
+ "print(\"The decoded token is:\", decoded_token)\n",
+ "print(\"The label is:\")\n",
+ "pprint(label)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv_tinyevals",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/vis_demo.ipynb b/notebooks/vis_demo.ipynb
new file mode 100644
index 00000000..842804d0
--- /dev/null
+++ b/notebooks/vis_demo.ipynb
@@ -0,0 +1,148 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch; torch.set_grad_enabled(False)\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "\n",
+ "from delphi.eval.utils import tokenize, get_next_and_top_k_probs, load_validation_dataset\n",
+ "from delphi.eval.vis import vis_sample_prediction_probs\n",
+ "\n",
+ "model_name = \"roneneldan/TinyStories-1M\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+ "ds = load_validation_dataset(\"tinystories-v2-clean\")\n",
+ "ds_txt = ds[\"story\"][:100]\n",
+ "ds_tok = [tokenize(tokenizer, txt) for txt in ds_txt]\n",
+ "sample_tok = ds_tok[0]\n",
+ "\n",
+ "correct_probs, top_3_probs = get_next_and_top_k_probs(model, sample_tok, k=3)\n",
+ "_, top_5_probs = get_next_and_top_k_probs(model, sample_tok, k=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### collect top k predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
<|endoftext|>
Once
upon
a
time
,
there
was
a
kind
girl
named
Lily
.
Lily
loved
to
mix
things
.
One
day
,
she
found
a
big
box
full
of
colors
.
Lily
was
very
happy
.
\\n
L
ily
took
out
a
strip
of
red
and
a
strip
of
blue
.
She
mixed
them
together
and
made
a
new
color
,
purple
!
Lily
was
so
excited
.
She
wanted
to
mix
more
colors
.
\\n
Next
,
Lily
took
a
strip
of
yellow
and
a
strip
of
green
.
She
mixed
them
together
and
made
a
new
color
,
orange
!
Lily
was
very
proud
of
herself
.
She
showed
her
new
colors
to
her
mom
and
dad
,
and
they
were
proud
of
her
too
.
They
all
lived
happily
ever
after
.
\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "_ = vis_sample_prediction_probs(sample_tok, correct_probs, top_3_probs, tokenizer)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " <|endoftext|>
Once
upon
a
time
,
there
was
a
kind
girl
named
Lily
.
Lily
loved
to
mix
things
.
One
day
,
she
found
a
big
box
full
of
colors
.
Lily
was
very
happy
.
\\n
L
ily
took
out
a
strip
of
red
and
a
strip
of
blue
.
She
mixed
them
together
and
made
a
new
color
,
purple
!
Lily
was
so
excited
.
She
wanted
to
mix
more
colors
.
\\n
Next
,
Lily
took
a
strip
of
yellow
and
a
strip
of
green
.
She
mixed
them
together
and
made
a
new
color
,
orange
!
Lily
was
very
proud
of
herself
.
She
showed
her
new
colors
to
her
mom
and
dad
,
and
they
were
proud
of
her
too
.
They
all
lived
happily
ever
after
.
\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "_ = vis_sample_prediction_probs(sample_tok, correct_probs, top_5_probs, tokenizer)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..c31b13f5
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,10 @@
+[tool.black]
+extend-exclude = 'src/llama2c'
+
+[tool.isort]
+profile = 'black'
+known_third_party = ['llama2c', 'wandb']
+extend_skip = ['src/llama2c']
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 65b457a4..e14ef757 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,10 @@ black==23.12.1
jaxtyping==0.2.25
beartype==0.16.4
pre-commit==3.6.0
-isort==5.13.2
\ No newline at end of file
+isort==5.13.2
+spacy==3.7.2
+chardet==5.2.0
+sentencepiece==0.1.99
+protobuf==4.25.2
+plotly==5.18.0
+spacy-transformers==1.3.4
\ No newline at end of file
diff --git a/scripts/generate_logprobs.sh b/scripts/generate_logprobs.sh
new file mode 100644
index 00000000..fc1f836a
--- /dev/null
+++ b/scripts/generate_logprobs.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
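+# Usage (run from the repo root; set USERNAME and TOKEN below to your own values first):
+#   bash scripts/generate_logprobs.sh
+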
+# Define the batch size
+BATCH_SIZE=80 # This worked well on my CPU, but 200 was too much
+DATASET_NAME="delphi-suite/tinystories-v2-clean-tokenized"
+USERNAME="transcendingvictor" # your Hugging Face username
+TOKEN="hf_aaaaaaaaaaaaaaaaaaaaaaaaaa" # your Hugging Face API token
+
+
+# List of models
+declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k"
+ "delphi-suite/delphi-llama2-200k"
+ "delphi-suite/delphi-llama2-400k"
+ "delphi-suite/delphi-llama2-800k"
+ "delphi-suite/delphi-llama2-1.6m"
+ "delphi-suite/delphi-llama2-3.2m"
+ "delphi-suite/delphi-llama2-6.4m"
+ "delphi-suite/delphi-llama2-12.8m"
+ "delphi-suite/delphi-llama2-25.6m")
+
+# Loop through each model and generate log probabilities
+for MODEL_NAME in "${MODEL_NAMES[@]}"
+do
+ echo "Processing $MODEL_NAME"
+ python scripts/inference.py "$MODEL_NAME" --batch-size "$BATCH_SIZE" --dataset-name "$DATASET_NAME" --username "$USERNAME" --token "$TOKEN"
+done
+
+echo "All models processed."
diff --git a/scripts/inference.py b/scripts/inference.py
new file mode 100644
index 00000000..52076e3b
--- /dev/null
+++ b/scripts/inference.py
@@ -0,0 +1,114 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+from datasets import Dataset, load_dataset
+from jaxtyping import Int
+from tqdm.auto import tqdm
+from transformers import AutoModelForCausalLM
+
+from delphi.eval.utils import get_all_and_next_logprobs, load_validation_dataset
+
+torch.set_grad_enabled(False)
+
+
+def main(
+ model_name: str,
+ batch_size: Int,
+ dataset_name: str,
+ username: str,
+ token: str,
+ funct_test: bool = False,
+):
+ """
+    Outputs the log probabilities of the next token for each token in the validation dataset
+    and uploads the resulting dataset to Hugging Face.
+    Args:
+    - model_name: The name of the model to use for inference
+    - batch_size: The batch size for processing. 80 worked well on CPU.
+    - dataset_name: The name of the dataset from which the validation set will be loaded
+    - username: Hugging Face API username
+    - token: Hugging Face API token
+    - funct_test: If True, only the first 320 sequences are processed (functional test mode)
+ """
+ val_ds = load_validation_dataset(dataset_name)
+
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ total_sequences = (
+ len(val_ds) if not funct_test else 320
+ ) # Use only 320 sequences if funct_test is True
+
+ logprobs = np.empty((total_sequences, 513))
+ logprobs[:, 0] = float("nan")
+ for i in tqdm(range(0, total_sequences, batch_size)):
+ batch_end = min(i + batch_size, total_sequences)
+ batch_sequences = [val_ds[j]["tokens"] for j in range(i, batch_end)]
+ batch_sequences_tensor = torch.tensor(batch_sequences)
+
+ logprobs_tensor = get_all_and_next_logprobs(model, batch_sequences_tensor)[1]
+ logprobs[i:batch_end, 1:] = logprobs_tensor.cpu().numpy()
+
+ df_dataset = pd.DataFrame({"logprobs": [row for row in logprobs]})
+ hf_dataset = Dataset.from_pandas(df_dataset)
+
+    # the HF username (used to build repo_id) and the HF token are set in generate_logprobs.sh
+
+ repo_id = f"{username}/{model_name.rsplit('/', 1)[-1]}-validation-logprobs"
+ if funct_test:
+ repo_id += "-funct-test"
+ hf_dataset.push_to_hub(
+ repo_id=repo_id,
+ split="validation",
+ private=False,
+ token=token,
+ )
+
+
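+# Example invocation (a sketch; the model/dataset names are the ones used in
+# scripts/generate_logprobs.sh, and the credentials are placeholders):
+#   python scripts/inference.py delphi-suite/delphi-llama2-100k \
+#       --batch-size 80 \
+#       --dataset-name delphi-suite/tinystories-v2-clean-tokenized \
+#       --username <your-hf-username> --token <your-hf-token>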
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Run inference and generate log probabilities."
+ )
+ parser.add_argument(
+ "model_name", type=str, help="Model name with or without delphi-suite/ prefix"
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=80,
+ help="Batch size for processing (default: 80)",
+ )
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ help="Dataset name with or without delphi-suite/ prefix",
+ )
+ parser.add_argument(
+ "--username",
+ type=str,
+ help="Hugging Face API username",
+ )
+ parser.add_argument(
+ "--token",
+ type=str,
+ help="Hugging Face API token",
+ )
+ parser.add_argument(
+ "--test-funct", action="store_true", help="Enable test function mode"
+ )
+
+ args = parser.parse_args()
+
+ if "/" not in args.model_name:
+ args.model_name = "delphi-suite/" + args.model_name
+
+ main(
+ args.model_name,
+ args.batch_size,
+ args.dataset_name,
+ args.username,
+ args.token,
+ args.test_funct,
+ )
diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py
new file mode 100644
index 00000000..6519eaca
--- /dev/null
+++ b/scripts/label_all_tokens.py
@@ -0,0 +1,109 @@
+import argparse
+import pickle
+
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from delphi.constants import STATIC_ASSETS_DIR
+from delphi.eval import token_labelling
+
+
+def tokenize(
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, sample_txt: str
+) -> int:
+ # supposedly this can be different than prepending the bos token id
+ return tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0]
+
+
+# Decode a sentence
+def decode(
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, token_ids: int | list[int]
+) -> str:
+ return tokenizer.decode(token_ids, skip_special_tokens=True)
+
+
+def main():
+ # Setup argparse
+ parser = argparse.ArgumentParser(description="Tokenization and labeling utility.")
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ help="Name of the model to use for tokenization and labeling.",
+ default="delphi-suite/delphi-llama2-100k",
+ required=False,
+ )
+ args = parser.parse_args()
+
+ # Access command-line arguments
+
+ model_name = args.model_name
+
+ print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
+ print(f"You chose the model: {model_name}\n")
+ print(
+ f"The language model will be loaded from Huggingface and its tokenizer used to do two things:\n\t1) Create a list of all tokens in the tokenizer's vocabulary.\n\t2) Label each token with its part of speech, dependency, and named entity recognition tags.\nThe respective results will be saved to files located at: '{STATIC_ASSETS_DIR}'\n"
+ )
+
+ # ================ (1) =================
+ print("(1) Create a list of all tokens in the tokenizer's vocabulary ...")
+
+ # Load the tokenizer from Huggingface
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ vocab_size = tokenizer.vocab_size
+ print("Loaded the tokenizer.\nThe vocab size is:", vocab_size)
+
+ # Create a list of all tokens in the tokenizer's vocabulary
+ tokens_str = "" # will hold all tokens and their ids
+ for i in range(tokenizer.vocab_size):
+ tokens_str += f"{i},{decode(tokenizer, i)}\n"
+
+ # Save the list of all tokens to a file
+ filename = "all_tokens_list.txt"
+ filepath = STATIC_ASSETS_DIR.joinpath(filename)
+ with open(f"{filepath}", "w", encoding="utf-8") as f:
+ f.write(tokens_str)
+
+ print(f"Saved the list of all tokens to:\n\t{filepath}\n")
+
+ # ================ (2) =================
+ print("(2) Label each token ...")
+
+ # let's label each token
+ labelled_token_ids_dict: dict[int, dict[str, bool]] = {} # token_id: labels
+ max_token_id = tokenizer.vocab_size # stop at which token id, vocab size
+ # we iterate over all token_ids individually
+ for token_id in tqdm(range(0, max_token_id), desc="Labelling tokens"):
+ # decode the token_ids to get a list of tokens, a 'sentence'
+ tokens = decode(tokenizer, token_id) # list of tokens == sentence
+ # put the sentence into a list, to make it a batch of sentences
+ sentences = [tokens]
+ # label the batch of sentences
+ labels = token_labelling.label_batch_sentences(
+ sentences, tokenized=True, verbose=False
+ )
+ # create a dict with the token_ids and their labels
+ # update the labelled_token_ids_dict with the new dict
+ labelled_token_ids_dict[token_id] = labels[0][0]
+
+ # Save the labelled tokens to a file
+ filename = "labelled_token_ids_dict.pkl"
+ filepath = STATIC_ASSETS_DIR.joinpath(filename)
+ with open(f"{filepath}", "wb") as f:
+ pickle.dump(labelled_token_ids_dict, f)
+
+ print(f"Saved the labelled tokens to:\n\t{filepath}\n")
+
+    # sanity check that the pickled and the original dict are the same
+ print("Sanity check ...", end="")
+ # load pickle
+ with open(f"{filepath}", "rb") as f:
+ pickled = pickle.load(f)
+ # compare
+ assert labelled_token_ids_dict == pickled
+ print(" completed.")
+
+ print(" END ".center(50, "="))
+
+
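+# Example invocation (a sketch; --model_name defaults to delphi-suite/delphi-llama2-100k):
+#   python scripts/label_all_tokens.py --model_name delphi-suite/delphi-llama2-100k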
+if __name__ == "__main__":
+ main()
diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py
new file mode 100755
index 00000000..5bafbffe
--- /dev/null
+++ b/scripts/map_tokens.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+
+import argparse
+
+import pandas as pd
+from datasets import Dataset
+
+from delphi.constants import STATIC_ASSETS_DIR
+from delphi.eval.token_map import token_map
+from delphi.eval.utils import load_validation_dataset
+
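+# Example invocation (a sketch; the dataset name and credentials are placeholders):
+#   python scripts/map_tokens.py delphi-suite/tinystories-v2-clean-tokenized \
+#       --username <your-hf-username> --token <your-hf-token>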
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="")
+
+ parser.add_argument(
+ "dataset_name",
+ type=str,
+ help="Dataset from huggingface to run token_map on",
+ )
+ parser.add_argument(
+ "--username",
+ type=str,
+ help="Hugging Face API username",
+ )
+ parser.add_argument(
+ "--token",
+ type=str,
+ help="Hugging Face API token",
+ )
+ parser.add_argument(
+ "--tokenizer-size",
+ type=int,
+ default=4096,
+ help="Size of the tokenizer",
+ )
+ args = parser.parse_args()
+
+ dataset = load_validation_dataset(args.dataset_name)
+
+ hf_dataset = Dataset.from_dict(
+ {"prompt_pos_idx": token_map(dataset, args.tokenizer_size)}
+ )
+
+ repo_id = f"{args.username}/v0-token-map" # location in to hf
+
+ hf_dataset.push_to_hub(
+ repo_id=repo_id,
+ split="validation",
+ private=False,
+ token=args.token,
+ )
diff --git a/scripts/run_training.py b/scripts/run_training.py
new file mode 100644
index 00000000..e244b8bc
--- /dev/null
+++ b/scripts/run_training.py
@@ -0,0 +1,63 @@
+import argparse
+import copy
+import json
+from dataclasses import fields
+from typing import Any
+
+from delphi.train.gigaconfig import GigaConfig, debug_config
+from delphi.train.training import run_training
+
+
+def update_config(config: GigaConfig, new_vals: dict[str, Any]):
+ for field in fields(config):
+ if new_vals.get(field.name) is not None:
+ setattr(config, field.name, new_vals[field.name])
+
+
+def main():
+ # Setup argparse
+ parser = argparse.ArgumentParser(description="Train a delphi model")
+ config_arg_group = parser.add_argument_group("Config arguments")
+ for field in fields(GigaConfig):
+ config_arg_group.add_argument(
+ f"--{field.name}",
+ type=field.type,
+ required=False,
+ help=f"Default: {field.default}",
+ )
+ parser.add_argument(
+ "--config_file",
+ help=(
+ "Path to a json file containing config values (see sample_config.json). "
+ "Specific values can be overridden with --arguments."
+ ),
+ required=False,
+ type=str,
+ )
+ parser.add_argument(
+ "--debug",
+ help="Use debug config values. Overridden by config file values and --arguments.",
+ required=False,
+ action="store_true",
+ )
+ args = parser.parse_args()
+
+ # setup config
+ if args.debug:
+ config = copy.copy(debug_config)
+ else:
+ config = GigaConfig()
+ # config file overrides default values
+ if args.config_file is not None:
+ with open(args.config_file, "r") as f:
+ config_dict = json.load(f)
+ update_config(config, config_dict)
+ # specific arguments override everything else
+ update_config(config, vars(args))
+
+ # run training
+ run_training(config)
+
+
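+# Example invocations (sketches; see scripts/sample_config.json for the available config fields):
+#   python scripts/run_training.py --debug
+#   python scripts/run_training.py --config_file scripts/sample_config.json --batch_size 64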
+if __name__ == "__main__":
+ main()
diff --git a/scripts/sample_config.json b/scripts/sample_config.json
new file mode 100644
index 00000000..067c4da8
--- /dev/null
+++ b/scripts/sample_config.json
@@ -0,0 +1,34 @@
+{
+ "out_dir": "out",
+ "eval_interval": 500,
+ "log_interval": 1,
+ "eval_iters": 10,
+ "eval_only": false,
+ "always_save_checkpoint": false,
+ "init_from": "scratch",
+ "wandb_log": true,
+ "wandb_entity": "jaiwithani",
+ "wandb_project": "delphi",
+ "wandb_run_name": "2024_03_07_17_43_09",
+ "batch_size": 64,
+ "max_seq_len": 512,
+ "vocab_size": 4096,
+ "dim": 48,
+ "n_layers": 2,
+ "n_heads": 2,
+ "n_kv_heads": 2,
+ "multiple_of": 32,
+ "dropout": 0.0,
+ "gradient_accumulation_steps": 4,
+ "learning_rate": 0.0005,
+ "max_epochs": 2,
+ "weight_decay": 0.1,
+ "beta1": 0.9,
+ "beta2": 0.95,
+ "grad_clip": 1.0,
+ "decay_lr": true,
+ "warmup_iters": 1000,
+ "min_lr": 0.0,
+ "train_sample_limit": 256,
+ "val_sample_limit": -1
+}
\ No newline at end of file
diff --git a/scripts/upload_stories.py b/scripts/upload_stories.py
index 6a420153..e1afc042 100644
--- a/scripts/upload_stories.py
+++ b/scripts/upload_stories.py
@@ -1,24 +1,20 @@
import json
-import pandas as pd
+import pandas as pd
from datasets import Dataset
-
splits = [
("../train/llama2c/data/TinyStoriesV2-GPT4-train-clean.json", "train"),
- ("../train/llama2c/data/TinyStoriesV2-GPT4-valid-clean.json", "validation")
+ ("../train/llama2c/data/TinyStoriesV2-GPT4-valid-clean.json", "validation"),
]
+
def load_dataset(filepath):
- with open(filepath, 'r', encoding='utf-8') as file:
+ with open(filepath, "r", encoding="utf-8") as file:
return json.load(file)
-
-
-for (filename, split) in splits:
+
+
+for filename, split in splits:
stories = load_dataset(filename)
dataset = Dataset.from_pandas(pd.DataFrame(stories))
- dataset.push_to_hub(
- repo_id="",
- split=split,
- token=""
- )
\ No newline at end of file
+ dataset.push_to_hub(repo_id="", split=split, token="")
diff --git a/scripts/upload_tokens.py b/scripts/upload_tokens.py
index 33cd937b..d83f00e5 100644
--- a/scripts/upload_tokens.py
+++ b/scripts/upload_tokens.py
@@ -1,10 +1,10 @@
+from functools import partial
+
import pandas as pd
from datasets import Dataset
-from functools import partial
from delphi import PretokDataset
-
batch_size = 1
max_seq_len = 512
vocab_size = 4096
@@ -12,7 +12,6 @@
device = "cuda"
for split in ["train", "validation"]:
-
ds = PretokDataset(
split=split,
batch_size=batch_size,
@@ -20,18 +19,14 @@
vocab_size=vocab_size,
vocab_source=vocab_source,
)
-
+
num_batches = len(PretokDataset)
-
+
tokens = []
for idx, (chunk) in enumerate(ds):
- if idx >= num_batches:
+ if idx >= num_batches:
break
- tokens.append({'tokens': chunk.numpy()})
-
+ tokens.append({"tokens": chunk.numpy()})
+
dataset = Dataset.from_pandas(pd.DataFrame(tokens))
- dataset.push_to_hub(
- repo_id="",
- split=split,
- token=""
- )
+ dataset.push_to_hub(repo_id="", split=split, token="")
diff --git a/setup.py b/setup.py
index 80059fc2..a4156702 100644
--- a/setup.py
+++ b/setup.py
@@ -5,4 +5,5 @@
version="0.1",
packages=find_packages(where="src"),
package_dir={"": "src"},
+ package_data={"delphi.static": ["*"]},
)
diff --git a/src/delphi/constants.py b/src/delphi/constants.py
new file mode 100644
index 00000000..6e26b4fa
--- /dev/null
+++ b/src/delphi/constants.py
@@ -0,0 +1,6 @@
+from importlib.resources import files
+
+STATIC_ASSETS_DIR = files("delphi.static")
+
+CORPUS_DATASET = "delphi-suite/tinystories-v2-clean"
+TOKENIZED_CORPUS_DATASET = "delphi-suite/v0-tinystories-v2-clean-tokenized"
diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py
new file mode 100644
index 00000000..d9c5d4c1
--- /dev/null
+++ b/src/delphi/eval/calc_model_group_stats.py
@@ -0,0 +1,54 @@
+import numpy as np
+
+
+def calc_model_group_stats(
+ tokenized_corpus_dataset: list,
+ logprobs_by_dataset: dict[str, list[list[float]]],
+ token_labels_by_token: dict[int, dict[str, bool]],
+ token_labels: list[str],
+) -> dict[tuple[str, str], dict[str, float]]:
+ """
+ For each (model, token group) pair, calculate useful stats (for visualization)
+
+ args:
+    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
+    - logprobs_by_dataset: a dict mapping model names to lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
+    - token_labels_by_token: a dict mapping token ids to their labels, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
+    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
+
+    returns: a dict of (model, token group) pairs to a dict of stats,
+    e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
+
+    Technically `token_labels` is redundant, as its entries are also keys in `token_labels_by_token`,
+    but it's better to be explicit
+
+ stats calculated: mean, median, min, max, 25th percentile, 75th percentile
+ """
+ model_group_stats = {}
+ for model in logprobs_by_dataset:
+ group_logprobs = {}
+ print(f"Processing model {model}")
+ dataset = logprobs_by_dataset[model]
+ for ix_doc_lp, document_lps in enumerate(dataset):
+ tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
+ for ix_token, token in enumerate(tokens):
+ if ix_token == 0: # skip the first token, which isn't predicted
+ continue
+ logprob = document_lps[ix_token]
+ for token_group_desc in token_labels:
+ if token_labels_by_token[token][token_group_desc]:
+ if token_group_desc not in group_logprobs:
+ group_logprobs[token_group_desc] = []
+ group_logprobs[token_group_desc].append(logprob)
+ for token_group_desc in token_labels:
+ if token_group_desc in group_logprobs:
+ model_group_stats[(model, token_group_desc)] = {
+ "mean": np.mean(group_logprobs[token_group_desc]),
+ "median": np.median(group_logprobs[token_group_desc]),
+ "min": np.min(group_logprobs[token_group_desc]),
+ "max": np.max(group_logprobs[token_group_desc]),
+ "25th": np.percentile(group_logprobs[token_group_desc], 25),
+ "75th": np.percentile(group_logprobs[token_group_desc], 75),
+ }
+ return model_group_stats
diff --git a/src/delphi/eval/compare_models.py b/src/delphi/eval/compare_models.py
new file mode 100644
index 00000000..e03b300c
--- /dev/null
+++ b/src/delphi/eval/compare_models.py
@@ -0,0 +1,91 @@
+from dataclasses import dataclass
+
+import torch
+from jaxtyping import Int
+from transformers import PreTrainedModel
+
+from delphi.eval.utils import get_all_and_next_logprobs_single
+
+
+def identify_model(model: PreTrainedModel) -> str:
+ return model.config.name_or_path
+
+
+@dataclass
+class TokenPrediction:
+ token: int
+ base_model_prob: float
+ lift_model_prob: float
+
+
+@dataclass
+class NextTokenStats:
+ base_model: str
+ lift_model: str
+ next_prediction: TokenPrediction
+ topk: list[TokenPrediction]
+
+
+def compare_models(
+ model_a: PreTrainedModel,
+ model_b: PreTrainedModel,
+ sample_tok: Int[torch.Tensor, "seq"],
+ top_k: int = 3,
+) -> list[NextTokenStats | None]:
+ """
+ Compare the probabilities of the next token for two models and get the top k token predictions according to model B.
+ Args:
+ - model_a: The first model (assumed to be the base model)
+ - model_b: The second model (assumed to be the improved model)
+ - sample_tok: The tokenized prompt
+    - top_k: The number of top token predictions to retrieve (default is 3)
+ Returns:
+    A list of NextTokenStats objects, one per token in the prompt, aligned to the token being predicted.
+    The first entry is None, since the first token is not predicted from any preceding context.
+ """
+ assert (
+ model_a.device == model_b.device
+ ), "Both models must be on the same device for comparison."
+
+ device = model_a.device
+ sample_tok = sample_tok.to(device)
+
+ logprobs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
+ logprobs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)
+
+ probs_a = torch.exp(logprobs_a)
+ probs_b = torch.exp(logprobs_b)
+
+ top_k_b = torch.topk(probs_b, top_k, dim=-1)
+ top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)
+
+ top_k_b_tokens = top_k_b.indices
+ top_k_b_probs = top_k_b.values
+
+ comparisons = []
+ # ignore first token when evaluating predictions
+ comparisons.append(None)
+
+    for i, (next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b) in enumerate(
+        zip(next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs)
+    ):
+ nts = NextTokenStats(
+ base_model=identify_model(model_a),
+ lift_model=identify_model(model_b),
+ next_prediction=TokenPrediction(
+                token=int(sample_tok[i + 1].item()),  # the token being predicted at this position
+ base_model_prob=next_p_a.item(),
+ lift_model_prob=next_p_b.item(),
+ ),
+ topk=[
+ TokenPrediction(
+ token=int(top_toks_b[i].item()),
+ base_model_prob=top_probs_a[i].item(),
+ lift_model_prob=top_probs_b[i].item(),
+ )
+ for i in range(top_k)
+ ],
+ )
+ comparisons.append(nts)
+
+ return comparisons
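+
+
+# Usage sketch (the model names are delphi checkpoints referenced elsewhere in this repo;
+# sample_tok is a tokenized prompt, e.g. from delphi.eval.utils.tokenize):
+#   from transformers import AutoModelForCausalLM
+#   model_a = AutoModelForCausalLM.from_pretrained("delphi-suite/delphi-llama2-100k")
+#   model_b = AutoModelForCausalLM.from_pretrained("delphi-suite/delphi-llama2-200k")
+#   comparisons = compare_models(model_a, model_b, sample_tok, top_k=3)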
diff --git a/src/delphi/eval/constants.py b/src/delphi/eval/constants.py
new file mode 100644
index 00000000..30b3e36b
--- /dev/null
+++ b/src/delphi/eval/constants.py
@@ -0,0 +1,14 @@
+corpus_dataset = "delphi-suite/tinystories-v2-clean"
+tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"
+
+LLAMA2_MODELS = [
+ "delphi-llama2-100k",
+ "delphi-llama2-200k",
+ "delphi-llama2-400k",
+ "delphi-llama2-800k",
+ "delphi-llama2-1.6m",
+ "delphi-llama2-3.2m",
+ "delphi-llama2-6.4m",
+ "delphi-llama2-12.8m",
+ "delphi-llama2-25.6m",
+]
diff --git a/src/delphi/eval/token_labelling.py b/src/delphi/eval/token_labelling.py
new file mode 100644
index 00000000..80673e03
--- /dev/null
+++ b/src/delphi/eval/token_labelling.py
@@ -0,0 +1,210 @@
+from typing import Callable, Optional
+
+import spacy
+from spacy.tokens import Doc, Token
+from spacy.util import is_package
+
+# make sure the english language model capabilities are installed by the equivalent of:
+# python -m spacy download en_core_web_sm
+# Should be run once, initially. Download only starts if not already installed.
+SPACY_MODEL = "en_core_web_sm" # small: "en_core_web_sm", large: "en_core_web_trf"
+NLP = None # global var to hold the language model
+if not is_package(SPACY_MODEL):
+ spacy.cli.download(SPACY_MODEL, False, False)
+
+
+TOKEN_LABELS: dict[str, Callable] = {
+ # --- custom categories ---
+ "Starts with space": (lambda token: token.text.startswith(" ")), # bool
+ "Capitalized": (lambda token: token.text[0].isupper()), # bool
+ # --- POS (part-of-speech) categories ---
+ # They include the Universal POS tags (https://universaldependencies.org/u/pos/)
+ # -> "POS Tag": (lambda token: token.pos_), # 'NOUN', 'VB', ..
+ "Is Adjective": (lambda token: token.pos_ == "ADJ"),
+ "Is Adposition": (lambda token: token.pos_ == "ADP"),
+ "Is Adverb": (lambda token: token.pos_ == "ADV"),
+ "Is Auxiliary": (lambda token: token.pos_ == "AUX"),
+ "Is Coordinating conjuction": (lambda token: token.pos_ == "CCONJ"),
+ "Is Determiner": (lambda token: token.pos_ == "DET"),
+ "Is Interjunction": (lambda token: token.pos_ == "INTJ"),
+ "Is Noun": (lambda token: token.pos_ == "NOUN"),
+ "Is Numeral": (lambda token: token.pos_ == "NUM"),
+ "Is Particle": (lambda token: token.pos_ == "PART"),
+ "Is Pronoun": (lambda token: token.pos_ == "PRON"),
+ "Is Proper Noun": (lambda token: token.pos_ == "PROPN"),
+ "Is Punctuation": (lambda token: token.pos_ == "PUNCT"),
+ "Is Subordinating conjuction": (lambda token: token.pos_ == "SCONJ"),
+ "Is Symbol": (lambda token: token.pos_ == "SYM"),
+ "Is Verb": (lambda token: token.pos_ == "VERB"),
+ "Is Other": (lambda token: token.pos_ == "X"),
+ # --- dependency categories ---
+ # -> "Dependency": (lambda token: token.dep_), # 'nsubj', 'ROOT', 'dobj', ..
+ # "Is Subject": (lambda token: token.dep_ == "nsubj"),
+ # "Is Object": (lambda token: token.dep_ == "dobj"),
+ # "Is Root": (
+ # lambda token: token.dep_ == "ROOT"
+ # ), # root of the sentence (often a verb)
+ # "Is auxiliary": (lambda token: token.dep_ == "aux"),
+ # --- Named entity recognition (NER) categories ---
+ # "Named Entity Type": (lambda token: token.ent_type_), # '', 'PERSON', 'ORG', 'GPE', ..
+ "Is Named Entity": (lambda token: token.ent_type_ != ""),
+}
+
+
+def explain_token_labels(token: Optional[Token] = None) -> None:
+ """
+ Prints the explanation of a specific token's labels or of ALL
+ possible labels (POS, dependency, NER, ...), if no token is provided.
+
+ Parameters
+ ----------
+ token : Optional[Token], optional
+        The token whose labels should be explained. If None, all
+        possible labels are explained, by default None.
+ """
+ if token is not None:
+ # get token labels
+ labels = label_single_token(token)
+ print(" Explanation of token labels ".center(45, "-"))
+ print("Token text:".ljust(20), token.text)
+ print("Token dependency:".ljust(20), spacy.glossary.explain(token.dep_))
+ print("Token POS:".ljust(20), spacy.glossary.explain(token.pos_))
+ print(" Token labels ".center(45, "-"))
+ for i, (label_name, value) in enumerate(labels.items()):
+ print(f" {i:2} ", label_name.ljust(20), value)
+
+ else:
+ glossary = spacy.glossary.GLOSSARY
+ print(
+ f"Explanation of all {len(glossary.keys())} token labels (POS, dependency, NER, ...):"
+ )
+ for label, key in glossary.items():
+ print(" ", label.ljust(10), key)
+
+
+def label_single_token(token: Token | None) -> dict[str, bool]:
+ """
+    Labels a single token that has been analyzed by the spaCy
+    library.
+
+ Parameters
+ ----------
+ token : Token | None
+ The token to be labelled.
+
+ Returns
+ -------
+ dict[str, bool]
+ Returns a dictionary with the token's labels as keys and their
+ corresponding boolean values.
+ """
+ labels = dict() # The dict holding labels of a single token
+    # if token is None, then it is an empty '' string token or similar
+ if token is None:
+ for label_name, category_check in TOKEN_LABELS.items():
+ labels[label_name] = False
+ labels["Is Other"] = True
+ return labels
+ # all other cases / normal tokens
+ for label_name, category_check in TOKEN_LABELS.items():
+ labels[label_name] = category_check(token)
+ return labels
+
+
+def label_sentence(tokens: Doc | list[Token]) -> list[dict[str, bool]]:
+ """
+ Labels spaCy Tokens in a sentence. Takes the context of the token into account
+ for dependency labels (e.g. subject, object, ...), IF dependency labels are turned on.
+
+ Parameters
+ ----------
+ tokens : list[Token]
+ A list of tokens.
+
+ Returns
+ -------
+ list[dict[str, bool]]
+ Returns a list of the tokens' labels.
+ """
+ labelled_tokens = list() # list holding labels for all tokens of sentence
+ # if the list is empty it is because token is '' empty string or similar
+ if len(tokens) == 0:
+ labels = label_single_token(None)
+ labelled_tokens.append(labels)
+ return labelled_tokens
+ # in all other cases
+ for token in tokens:
+ labels = label_single_token(token)
+ labelled_tokens.append(labels)
+ return labelled_tokens
+
+
+def label_batch_sentences(
+ sentences: list[str] | list[list[str]],
+ tokenized: bool = True,
+ verbose: bool = False,
+) -> list[list[dict[str, bool]]]:
+ """
+ Labels tokens in a sentence batchwise. Takes the context of the token into
+ account for dependency labels (e.g. subject, object, ...).
+
+ Parameters
+ ----------
+    sentences : list[str] | list[list[str]]
+        A batch/list of sentences, each either a full string or a list of tokens.
+    tokenized : bool, optional
+        Whether the sentences are already tokenized, by default True. If True,
+        `sentences` must be list[list[str]]; if the sentences are full strings, set to False.
+ verbose : bool, optional
+ Whether to print the tokens and their labels to the console, by default False.
+
+ Returns
+ -------
+    list[list[dict[str, bool]]]
+        Returns a list of sentences, each a list with one entry per token,
+        where each entry provides the labels/categories for that token.
+        Sentence -> Token -> Labels
+ """
+ global NLP, SPACY_MODEL
+
+ if NLP is None:
+ # Load english language model
+ NLP = spacy.load(SPACY_MODEL)
+    # list of sentences, each a list of per-token label dicts
+ labelled_sentences: list[list[dict[str, bool]]] = list()
+
+ # go through each sentence in the batch
+ for sentence in sentences:
+ if tokenized:
+ # sentence is a list of tokens
+ doc = Doc(NLP.vocab, words=sentence) # type: ignore
+ # Apply the spaCy pipeline, except for the tokenizer
+ for name, proc in NLP.pipeline:
+ if name != "tokenizer":
+ doc = proc(doc)
+ else:
+ # sentence is a single string
+ doc = NLP(sentence) # type: ignore
+
+        # labels for all tokens of the sentence
+        labelled_tokens = label_sentence(doc)
+
+ # print the token and its labels to console
+ if verbose is True:
+ # go through each token in the sentence
+ for token, labelled_token in zip(doc, labelled_tokens):
+ print(f"Token: {token}")
+ print(" | ".join(list(TOKEN_LABELS.keys())))
+ printable = [
+ str(l).ljust(len(name)) for name, l in labelled_token.items()
+ ]
+ printable = " | ".join(printable)
+ print(printable)
+ print("---")
+ # add current sentence's tokens' labels to the list
+ labelled_sentences.append(labelled_tokens)
+
+ if verbose is True:
+ print("\n")
+
+ return labelled_sentences
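+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: label one pre-tokenized sentence. Assumes the spaCy
+    # model above is installed (it is downloaded on first import if missing).
+    example_labels = label_batch_sentences([["Peter", "is", "a", "person"]], tokenized=True)
+    # one dict of boolean labels per token, keyed by the names in TOKEN_LABELS
+    print(example_labels[0][0])
+    explain_token_labels()  # prints the glossary of all spaCy labels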
diff --git a/src/delphi/eval/token_map.py b/src/delphi/eval/token_map.py
new file mode 100644
index 00000000..4ac7b0df
--- /dev/null
+++ b/src/delphi/eval/token_map.py
@@ -0,0 +1,18 @@
+import os
+from typing import cast
+
+from datasets import Dataset
+
+
+def token_map(
+ tokenized_dataset: Dataset,
+ tokenizer_size: int,
+) -> list[list[tuple[int, int]]]:
+ """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset."""
+
+ mapping = [[] for _ in range(tokenizer_size)]
+ for prompt_idx, prompt in enumerate(tokenized_dataset):
+ prompt = cast(dict, prompt)
+ for position_idx, token in enumerate(prompt["tokens"]):
+ mapping[token].append((prompt_idx, position_idx))
+ return mapping
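+
+
+if __name__ == "__main__":
+    # Usage sketch: build the token -> (prompt_idx, position_idx) index for a toy
+    # in-memory dataset (same shape as in tests/eval/test_token_map.py).
+    toy_ds = Dataset.from_dict({"tokens": [[0, 1, 2], [2, 1, 0]]})
+    mapping = token_map(toy_ds, tokenizer_size=3)
+    print(mapping[2])  # [(0, 2), (1, 0)] -- everywhere token 2 occurs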
diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py
index ee9893d9..2d052974 100644
--- a/src/delphi/eval/utils.py
+++ b/src/delphi/eval/utils.py
@@ -4,6 +4,9 @@
import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from delphi.eval import constants
def get_all_logprobs(
@@ -14,6 +17,13 @@ def get_all_logprobs(
return torch.log_softmax(logits, dim=-1)
+# convenience wrapper for calling on a single sample
+def get_single_logprobs(
+ model: Callable, input_ids: Int[torch.Tensor, "seq"]
+) -> Float[torch.Tensor, "seq vocab"]:
+ return get_all_logprobs(model, input_ids.unsqueeze(0))[0]
+
+
def gather_logprobs(
logprobs: Float[torch.Tensor, "batch seq vocab"],
tokens: Int[torch.Tensor, "batch seq"],
@@ -21,24 +31,87 @@ def gather_logprobs(
return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
-def get_next_logprobs(
- model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
-) -> Float[torch.Tensor, "batch shorter_seq"]:
+def get_all_and_next_logprobs(
+ model: Callable,
+ input_ids: Int[torch.Tensor, "batch seq"],
+) -> tuple[
+ Float[torch.Tensor, "batch shorter_seq vocab"],
+ Float[torch.Tensor, "batch shorter_seq"],
+]:
logprobs = get_all_logprobs(model, input_ids[:, :-1])
next_tokens = input_ids[:, 1:]
- return gather_logprobs(logprobs, next_tokens)
+ return logprobs, gather_logprobs(logprobs, next_tokens)
+
+def get_all_and_next_logprobs_single(
+ model: Callable,
+ input_ids: Int[torch.Tensor, "seq"],
+) -> tuple[
+ Float[torch.Tensor, "shorter_seq vocab"],
+ Float[torch.Tensor, "shorter_seq"],
+]:
+ all_logprobs, next_logprobs = get_all_and_next_logprobs(
+ model, input_ids.unsqueeze(0)
+ )
+ return all_logprobs[0], next_logprobs[0]
+
+
+def get_next_and_top_k_probs(
+ model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
+) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk,]:
+ all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
+ all_probs = torch.exp(all_logprobs)
+ next_probs = torch.exp(next_logprobs)
+ top_k = torch.topk(all_probs, k, dim=-1)
+ return next_probs, top_k
-def load_validation_dataset(dataset_name: str) -> Dataset:
+
+def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Dataset:
+ # check that split is either "train" or "validation"
+ if split not in ["train", "validation"]:
+ raise ValueError(f"Split must be either 'train' or 'validation', not {split}")
if "/" not in dataset_name:
dataset_name = f"delphi-suite/{dataset_name}"
- data_str = f"data/validation-*.parquet"
+ data_files_str = f"data/{split}-*.parquet"
dataset = load_dataset(
dataset_name,
- data_files=data_str,
+ data_files=data_files_str,
verification_mode="no_checks",
- # this seems to be the only split when using data_files
- # regardless of the files we're actually loading
- split="train",
+        # Currently, load_dataset returns a DatasetDict *unless* a split is specified,
+        # even if no split is specified within the data files. If there's no split arg,
+        # huggingface just says everything is in the "train" split and returns {"train": dataset}.
+        # In our case the data_files glob already selects the files for the requested split, so we
+        # shouldn't need to specify a split. But we do need to pass one to get a Dataset object
+        # rather than a DatasetDict. See https://github.com/huggingface/datasets/issues/5189
+ split=f"train{slice}",
)
return cast(Dataset, dataset)
+
+
+def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset:
+ return load_delphi_dataset(dataset_name, "validation", slice)
+
+
+def load_train_dataset(dataset_name: str, slice: str = "") -> Dataset:
+ return load_delphi_dataset(dataset_name, "train", slice)
+
+
+def tokenize(
+ tokenizer: PreTrainedTokenizerBase, sample_txt: str
+) -> Int[torch.Tensor, "seq"]:
+    # supposedly this can be different from simply prepending the bos token id
+ return cast(
+ Int[torch.Tensor, "seq"],
+ tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
+ )
+
+
+def load_logprob_dataset(model: str) -> Dataset:
+ return load_dataset(f"transcendingvictor/{model}-validation-logprobs") # type: ignore
+
+
+def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
+ return {
+ model: cast(dict, load_logprob_dataset(model)[split])["logprobs"]
+ for model in constants.LLAMA2_MODELS
+ }
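+
+
+if __name__ == "__main__":
+    # Minimal end-to-end sketch (mirrors tests/eval/test_compare_models.py):
+    # downloads a small public model and the validation split, then inspects
+    # the next-token and top-k probabilities for one sample.
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    with torch.set_grad_enabled(False):
+        model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
+        tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
+        ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
+        sample_tok = tokenize(tokenizer, ds_txt[0])
+        next_probs, top_k = get_next_and_top_k_probs(model, sample_tok, k=3)
+        print(next_probs.shape, top_k.indices.shape)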
diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py
new file mode 100644
index 00000000..5dd4fdb2
--- /dev/null
+++ b/src/delphi/eval/vis.py
@@ -0,0 +1,140 @@
+import uuid
+from typing import cast
+
+import torch
+from IPython.core.display import HTML
+from IPython.core.display_functions import display
+from jaxtyping import Float, Int
+from transformers import PreTrainedTokenizerBase
+
+
+def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
+ # for the endoftext token
+ # no prediction, no color
+ colors = ["white"]
+ for p in probs.tolist():
+ red_gap = 150 # the higher it is, the less red the tokens will be
+ green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
+ colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
+ return colors
+
+
+def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
+ tok_str = tokenizer.decode(tok).replace(" ", " ").replace("\n", r"\n")
+ prob_str = f"{prob:.2%}"
+ return f"{prob_str:>6} |{tok_str}|"
+
+
+def token_to_html(
+ token: int,
+ tokenizer: PreTrainedTokenizerBase,
+ bg_color: str,
+ data: dict,
+) -> str:
+ data = data or {} # equivalent to if not data: data = {}
+ # non-breakable space, w/o it leading spaces wouldn't be displayed
+ str_token = tokenizer.decode(token).replace(" ", " ")
+
+ # background or user-select (for \n) goes here
+ specific_styles = {}
+ # for now just adds line break or doesn't
+ br = ""
+
+ if bg_color:
+ specific_styles["background-color"] = bg_color
+ if str_token == "\n":
+ # replace new line character with two characters: \ and n
+ str_token = r"\n"
+ # add line break in html
+ br += "
"
+ # this is so we can copy the prompt without "\n"s
+ specific_styles["user-select"] = "none"
+
+ style_str = data_str = ""
+ # converting style dict into the style attribute
+ if specific_styles:
+ inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
+ style_str = f" style='{inside_style_str}'"
+ if data:
+ data_str = "".join(
+ f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items()
+ )
+ return f"{str_token}
{br}"
+
+
+_token_style = {
+ "border": "1px solid #888",
+ "display": "inline-block",
+ # each character of the same width, so we can easily spot a space
+ "font-family": "monospace",
+ "font-size": "14px",
+ "color": "black",
+ "background-color": "white",
+ "margin": "1px 0px 1px 1px",
+ "padding": "0px 1px 1px 1px",
+}
+_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])
+
+
+def vis_sample_prediction_probs(
+ sample_tok: Int[torch.Tensor, "pos"],
+ correct_probs: Float[torch.Tensor, "pos"],
+ top_k_probs: torch.return_types.topk,
+ tokenizer: PreTrainedTokenizerBase,
+) -> str:
+ colors = probs_to_colors(correct_probs)
+ token_htmls = []
+
+ # Generate a unique ID for this instance (so we can have multiple instances on the same page)
+ unique_id = str(uuid.uuid4())
+
+ token_class = f"token_{unique_id}"
+ hover_div_id = f"hover_info_{unique_id}"
+
+ for i in range(sample_tok.shape[0]):
+ tok = cast(int, sample_tok[i].item())
+ data = {}
+ if i > 0:
+ correct_prob = correct_probs[i - 1].item()
+ data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer)
+ top_k_probs_tokens = top_k_probs.indices[i - 1]
+ top_k_probs_values = top_k_probs.values[i - 1]
+ for j in range(top_k_probs_tokens.shape[0]):
+ top_tok = top_k_probs_tokens[j].item()
+ top_tok = cast(int, top_tok)
+ top_prob = top_k_probs_values[j].item()
+ data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)
+
+ token_htmls.append(
+ token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
+ "class='token'", f"class='{token_class}'"
+ )
+ )
+
+ html_str = f"""
+
+ {"".join(token_htmls)}
+
+ """
+ display(HTML(html_str))
+ return html_str
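+
+
+# Usage sketch (intended for a notebook), assuming `model`, `tokenizer` and a
+# tokenized `sample_tok` obtained via delphi.eval.utils:
+#
+#   from delphi.eval.utils import get_next_and_top_k_probs
+#   next_probs, top_k = get_next_and_top_k_probs(model, sample_tok, k=3)
+#   vis_sample_prediction_probs(sample_tok, next_probs, top_k, tokenizer)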
diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py
new file mode 100644
index 00000000..618840b0
--- /dev/null
+++ b/src/delphi/eval/vis_per_token_model.py
@@ -0,0 +1,67 @@
+import ipywidgets
+import numpy as np
+import plotly.graph_objects as go
+
+
+def visualize_per_token_category(
+ input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str
+) -> ipywidgets.VBox:
+ model_names = list(input.keys())
+ categories = list(input[model_names[0]].keys())
+ category = categories[0]
+
+ def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
+ return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]
+
+ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ x = np.array([input[name][category] for name in model_names]).T
+ means, err_lo, err_hi = x[0], x[1], x[2]
+ return means, err_lo, err_hi
+
+ means, err_low, err_hi = get_plot_values(category)
+ g = go.FigureWidget(
+ data=go.Scatter(
+ x=model_names,
+ y=means,
+ error_y=dict(
+ type="data",
+ symmetric=False,
+ array=err_hi,
+ arrayminus=err_low,
+ color=kwargs.get("bar_color", "purple"),
+ ),
+ marker=dict(
+ color=kwargs.get("marker_color", "SkyBlue"),
+ size=15,
+ line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
+ ),
+ hovertext=get_hovertexts(means, err_low, err_hi),
+ hoverinfo="text+x",
+ ),
+ layout=go.Layout(
+ yaxis=dict(
+ title="Loss",
+ type="log" if log_scale else "linear",
+ ),
+ plot_bgcolor=kwargs.get("bg_color", "AliceBlue"),
+ ),
+ )
+
+ selected_category = ipywidgets.Dropdown(
+ options=categories,
+ placeholder="",
+ description="Token Category:",
+ disabled=False,
+ )
+
+ def response(change):
+ means, err_lo, err_hi = get_plot_values(selected_category.value)
+ with g.batch_update():
+ g.data[0].y = means
+ g.data[0].error_y["array"] = err_hi
+ g.data[0].error_y["arrayminus"] = err_lo
+ g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)
+
+ selected_category.observe(response, names="value")
+
+ return ipywidgets.VBox([selected_category, g])
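+
+
+if __name__ == "__main__":
+    # Usage sketch with toy numbers: `input` maps model name -> category ->
+    # (mean, err_low, err_high); every model must provide the same categories.
+    toy_input = {
+        "llama2-100k": {"Is Noun": (3.2, 0.1, 0.1), "Is Verb": (3.5, 0.2, 0.2)},
+        "llama2-200k": {"Is Noun": (2.9, 0.1, 0.1), "Is Verb": (3.1, 0.2, 0.2)},
+    }
+    widget = visualize_per_token_category(toy_input, log_scale=True)
+    # in a notebook, `widget` renders as a category dropdown plus the plot
+    print(type(widget))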
diff --git a/src/delphi/static/README.md b/src/delphi/static/README.md
new file mode 100644
index 00000000..815b0c42
--- /dev/null
+++ b/src/delphi/static/README.md
@@ -0,0 +1,10 @@
+# TODO: move this to delphi/static
+# Static Data Files
+
+
+## `token_map.pkl`
+Pickle file with all locations of all tokens: a dict mapping each token id to a list of (doc, pos) pairs.
+
+## `model_group_stats.pkl`
+Useful statistics for data visualization of (model, token group) pairs; a dict mapping (model, token group) to a dict of (str, float):
+e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
\ No newline at end of file
diff --git a/src/delphi/static/__init__.py b/src/delphi/static/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/delphi/static/all_tokens_list.txt b/src/delphi/static/all_tokens_list.txt
new file mode 100644
index 00000000..438dddae
Binary files /dev/null and b/src/delphi/static/all_tokens_list.txt differ
diff --git a/src/delphi/static/labelled_token_ids_dict.pkl b/src/delphi/static/labelled_token_ids_dict.pkl
new file mode 100644
index 00000000..e8442a01
Binary files /dev/null and b/src/delphi/static/labelled_token_ids_dict.pkl differ
diff --git a/src/delphi/static/model_group_stats.pkl b/src/delphi/static/model_group_stats.pkl
new file mode 100644
index 00000000..7f5297c4
Binary files /dev/null and b/src/delphi/static/model_group_stats.pkl differ
diff --git a/src/delphi/static/token_map.pkl b/src/delphi/static/token_map.pkl
new file mode 100644
index 00000000..dec1a6a7
Binary files /dev/null and b/src/delphi/static/token_map.pkl differ
diff --git a/src/delphi/train/architectures.py b/src/delphi/train/architectures.py
new file mode 100644
index 00000000..134bfde3
--- /dev/null
+++ b/src/delphi/train/architectures.py
@@ -0,0 +1,74 @@
+from dataclasses import fields
+
+import torch
+from llama2c import model_export
+from llama2c.model import ModelArgs as Llama2ModelArgs
+from llama2c.model import Transformer as Llama2Model
+
+
+class ModelTypes:
+ LLAMA2C = "llama2c"
+ MAMBA = "mamba"
+
+
+args_to_load_from_checkpoint = {
+ ModelTypes.LLAMA2C: [
+ "dim",
+ "n_layers",
+ "n_heads",
+ "n_kv_heads",
+ "vocab_size",
+ "multiple_of",
+ "max_seq_len",
+ ],
+ ModelTypes.MAMBA: [
+ "n_layers",
+ "model_dim",
+ "vocab_size",
+ ],
+}
+
+
+def initialize_model(**model_args) -> torch.nn.Module:
+ if model_args["architecture"] == ModelTypes.LLAMA2C:
+ # filter model_args for fields in Llama2ModelArgs
+ llama2_arg_names = {f.name for f in fields(Llama2ModelArgs)}
+ llama2_args = {k: v for k, v in model_args.items() if k in llama2_arg_names}
+ return Llama2Model(Llama2ModelArgs(**llama2_args))
+ else:
+ raise NotImplementedError(
+ f"Architecture {model_args['architecture']} not yet implemented"
+ )
+
+
+def load_model(model_args, checkpoint) -> torch.nn.Module:
+ arch = model_args["architecture"]
+ checkpoint_model_args = checkpoint["model_args"]
+ for k in args_to_load_from_checkpoint[arch]:
+ model_args[k] = checkpoint_model_args[k]
+ if arch == ModelTypes.LLAMA2C:
+ # create the model
+ gptconf = Llama2ModelArgs(**model_args)
+ model = Llama2Model(gptconf)
+ state_dict = checkpoint["model"]
+ # fix the keys of the state dictionary :(
+ # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+ unwanted_prefix = "_orig_mod."
+ for k, v in list(state_dict.items()):
+ if k.startswith(unwanted_prefix):
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+ model.load_state_dict(state_dict)
+ return model
+ else:
+ raise NotImplementedError(f"Architecture {arch} not yet implemented")
+
+
+def export_model(model, model_architecture, output_path):
+ if model_architecture == ModelTypes.LLAMA2C:
+ model_export(
+ model,
+ output_path,
+ version=0,
+ )
+ else:
+ raise NotImplementedError("only llama2c model export is supported for now")
diff --git a/src/delphi/train/gigaconfig.py b/src/delphi/train/gigaconfig.py
new file mode 100644
index 00000000..e3567519
--- /dev/null
+++ b/src/delphi/train/gigaconfig.py
@@ -0,0 +1,85 @@
+from dataclasses import dataclass
+from datetime import datetime
+
+from beartype import beartype
+
+from delphi.train.architectures import ModelTypes
+
+
+@beartype
+@dataclass
+class GigaConfig:
+ """This is a terrible hack to get usable config objects to pass around
+ It's way too big and ties way too many things together. This should be broken
+ into several smaller configs.
+ """
+
+    # device
+    device: str = "auto"
+
+    # model architecture
+    architecture: str = ModelTypes.LLAMA2C
+
+ # I/O
+ out_dir: str = "out"
+ eval_interval: int = 2000
+ log_interval: int = 1
+ eval_iters: int = 100
+ eval_only: bool = False # if True, script exits right after the first eval
+ always_save_checkpoint: bool = (
+ False # if True, always save a checkpoint after each eval
+ )
+ init_from: str = "scratch" # 'scratch' or 'resume'
+ # wandb logging
+    wandb_log: bool = True  # set to False to disable wandb logging
+ wandb_entity: str = "jannik-brinkmann"
+ wandb_project: str = "delphi"
+ wandb_run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
+ # data
+ batch_size: int = (
+ 64 # if gradient_accumulation_steps > 1, this is the micro-batch size
+ )
+ # TODO: delete this, use doc size always
+ max_seq_len: int = 256
+ vocab_size: int = 32000 # the Llama 2 tokenizer has 32K tokens
+ # model
+ dim: int = 288
+ n_layers: int = 6
+ n_heads: int = 6
+ n_kv_heads: int = 6
+ multiple_of: int = 32
+ dropout: float = 0.0
+ # adamw optimizer
+ gradient_accumulation_steps: int = 4 # used to simulate larger batch sizes
+ learning_rate: float = 5e-4 # max learning rate
+ max_epochs: int = 10 # total number of training epochs
+ weight_decay: float = 1e-1
+ beta1: float = 0.9
+ beta2: float = 0.95
+ grad_clip: float = 1.0 # clip gradients at this value, or disable if == 0.0
+ # learning rate decay settings
+ decay_lr: bool = True # whether to decay the learning rate
+ warmup_iters: int = 1000 # how many steps to warm up for
+    min_lr: float = 0.0  # should be ~learning_rate/10 per Chinchilla
+ # reproducibility
+    seed: int = 1337
+ # TODO: seeds for batch ordering and weight initialization
+ # debugging
+ train_sample_limit: int = -1 # -1 implies no limit
+ val_sample_limit: int = -1
+
+
+# Jai Overrides TODO: remove these
+debug_config = GigaConfig(
+ wandb_entity="jaiwithani",
+ vocab_size=4096,
+ max_seq_len=512,
+ dim=48,
+ n_layers=2,
+ n_heads=2,
+ n_kv_heads=2,
+ max_epochs=2,
+ eval_interval=500,
+ eval_iters=10,
+ train_sample_limit=256,
+)
diff --git a/src/delphi/train/iteration_params.py b/src/delphi/train/iteration_params.py
new file mode 100644
index 00000000..2fa3c63f
--- /dev/null
+++ b/src/delphi/train/iteration_params.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class IterationParams:
+ num_batches: int
+ num_steps: int
+ eval_iters: int
+ lr_decay_iters: int
+ tokens_per_iter: int
+
+
+def set_iteration_params(config, train_ds, validation_ds) -> IterationParams:
+ num_batches = len(train_ds) // config.batch_size
+ num_steps = num_batches // config.gradient_accumulation_steps
+ eval_iters = min(12, len(validation_ds) // config.batch_size)
+ lr_decay_iters = (
+ config.max_epochs * num_batches
+ ) # should be ~=max_iters per Chinchilla
+ tokens_per_iter = (
+ config.gradient_accumulation_steps * config.batch_size * config.max_seq_len
+ )
+ print(f"tokens per iteration will be: {tokens_per_iter:,}")
+ print(
+ f"breaks down as: {config.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} max seq len"
+ )
+ return IterationParams(
+ num_batches, num_steps, eval_iters, lr_decay_iters, tokens_per_iter
+ )
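+
+
+if __name__ == "__main__":
+    # Worked example using the GigaConfig defaults (batch_size=64,
+    # gradient_accumulation_steps=4, max_seq_len=256, max_epochs=10) on a
+    # hypothetical 10,000-doc training set: num_batches = 10000 // 64 = 156,
+    # num_steps = 156 // 4 = 39, tokens_per_iter = 4 * 64 * 256 = 65,536.
+    @dataclass
+    class ToyConfig:
+        batch_size: int = 64
+        gradient_accumulation_steps: int = 4
+        max_seq_len: int = 256
+        max_epochs: int = 10
+
+    print(set_iteration_params(ToyConfig(), range(10_000), range(500)))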
diff --git a/src/delphi/train/mamba.py b/src/delphi/train/mamba.py
new file mode 100644
index 00000000..04e8154e
--- /dev/null
+++ b/src/delphi/train/mamba.py
@@ -0,0 +1,32 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+
+@dataclass
+class MambaArgs(MambaConfig):
+ pass
+
+
+class Mamba(MambaLMHeadModel):
+ def __init__(self, params: MambaArgs) -> None:
+ super().__init__(params)
+
+ def forward(
+ self, input_ids: torch.Tensor, target_ids: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
+ num_last_tokens: if > 0, only return the logits for the last n tokens
+ """
+ hidden_states = self.backbone(input_ids)
+ logits = self.lm_head(hidden_states)
+ self.last_loss = F.cross_entropy(
+ logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1
+ )
+
+ return logits
diff --git a/src/delphi/train/tokenized_chunks_dataset.py b/src/delphi/train/tokenized_chunks_dataset.py
new file mode 100644
index 00000000..92bf67aa
--- /dev/null
+++ b/src/delphi/train/tokenized_chunks_dataset.py
@@ -0,0 +1,55 @@
+import torch
+from torch.utils.data import Dataset
+
+from delphi.train.shuffle import shuffle_list
+
+
+class TokenizedChunksDataset(Dataset):
+ def __init__(self, tokenized_docs, max_seq_len, device):
+ self.device = device
+ self.tokenized_docs = tokenized_docs
+ self.max_len = max_seq_len
+ self.batched_tokens = (
+ torch.Tensor()
+ ) # will be initialized in initialize_samples
+
+ def initialize_samples(self):
+        # self.tokenized_docs is a sequence of dicts. Each entry is just {"tokens": [int]}
+        # where [int] is doc_len long.
+        # We want to turn this into a (num_samples, max_len + 1) tensor of ints;
+        # the +1 is for the last Y token prediction, and implies an overlap of 1 token between samples.
+        # This is because each sample will be broken into X [:-1] and Y [1:].
+ tensor_tokens = torch.stack(
+ [torch.tensor(doc["tokens"]) for doc in self.tokenized_docs]
+ ).to(self.device)
+ self.batched_tokens = tensor_tokens.flatten().unfold(
+ 0, self.max_len + 1, self.max_len
+ )
+ self.indices = self._default_indices()
+
+ def _default_indices(self):
+ return list(range(len(self.batched_tokens)))
+
+ def shuffle(self, epoch: int):
+ """this is inefficient, but tinyevals are tiny, so nbd probably"""
+ # reset for idempotent determinism
+ self.indices = self._default_indices()
+ shuffle_list(self.indices, seed=epoch)
+
+ def __len__(self):
+ return len(self.batched_tokens)
+
+ def get_sample_window(self, idx):
+ return self.batched_tokens[idx % len(self.batched_tokens), :]
+
+ def __getitem__(self, idx):
+ sample = self.get_sample_window(idx)
+ X = sample[:-1]
+ Y = sample[1:]
+ return X, Y
+
+ def __iter__(self):
+ while True:
+ for idx in self.indices:
+ X, Y = self[idx]
+ yield X, Y
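+
+
+if __name__ == "__main__":
+    # Worked example: two 6-token docs with max_seq_len=4. The 12 flattened tokens
+    # are unfolded into windows of length 5 with stride 4 (1-token overlap), giving
+    # samples [0..4] and [4..8]; leftover tokens that don't fill a window are dropped.
+    toy_docs = [{"tokens": [0, 1, 2, 3, 4, 5]}, {"tokens": [6, 7, 8, 9, 10, 11]}]
+    ds = TokenizedChunksDataset(toy_docs, max_seq_len=4, device="cpu")
+    ds.initialize_samples()
+    for i in range(len(ds)):
+        X, Y = ds[i]
+        print(X.tolist(), Y.tolist())  # X is sample[:-1], Y is sample[1:]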
diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py
new file mode 100644
index 00000000..21a994c1
--- /dev/null
+++ b/src/delphi/train/train_step.py
@@ -0,0 +1,125 @@
+import time
+
+import torch
+
+from delphi.train.utils import EvalData, ModelTrainingState, estimate_loss, set_lr
+
+
+def train_step(
+ model_training_state: ModelTrainingState,
+ train_ds,
+ validation_ds,
+ iteration_params,
+ eval_callbacks,
+ config,
+ train_batch_iter,
+) -> bool:
+ """
+ Runs a training step, updating (mutating in place) model_training_state
+ returns true if training should break, false otherwise
+ """
+ model = model_training_state.model
+ optimizer = model_training_state.optimizer
+
+ # here's how each train step works:
+ # 1. Set learning rate
+ # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint
+ # 3. forward backward update
+ # 4. log timing
+
+ # 1. determine and set the learning rate for this iteration
+ # TODO: move lr to ModelTrainingState
+ lr = set_lr(
+ iteration_params.lr_decay_iters,
+ config,
+ optimizer,
+ model_training_state.iter_num,
+ )
+
+ # 2. evaluate the loss on train/val sets and write checkpoints
+ if model_training_state.iter_num % config.eval_interval == 0:
+ losses = estimate_loss(
+ model=model,
+ eval_iters=iteration_params.eval_iters,
+ batch_size=config.batch_size,
+ split_to_ds={"train": train_ds, "val": validation_ds},
+ )
+ new_best_val_loss = False
+ if losses["val"] < model_training_state.best_val_loss:
+ model_training_state.best_val_loss = float(losses["val"])
+ new_best_val_loss = True
+ # TODO: refactor EvalData to use ModelTrainingState
+ eval_data = EvalData(
+ iter_num=model_training_state.iter_num,
+ tokens_per_iter=iteration_params.tokens_per_iter,
+ running_mfu=model_training_state.running_mfu,
+ lr=lr,
+ losses=losses,
+ best_val_loss=model_training_state.best_val_loss,
+ new_best_val_loss=new_best_val_loss,
+ model=model,
+ model_args=model_training_state.model_args,
+ optimizer=optimizer,
+ config=config,
+ )
+ print(
+ f"step {model_training_state.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
+ )
+ for callback in eval_callbacks:
+ callback(eval_data)
+
+ if model_training_state.iter_num == 0 and config.eval_only:
+ return True
+
+ # 3. forward backward update, with optional gradient accumulation to simulate larger batch size
+ X, Y = next(train_batch_iter)
+ print(
+ f"gradient accumulation steps: {config.gradient_accumulation_steps}, "
+ f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}"
+ )
+ for micro_step in range(
+ min(
+ config.gradient_accumulation_steps,
+ iteration_params.num_steps - model_training_state.iter_num + 1,
+ )
+ ):
+ logits = model(X, Y)
+ loss = model.last_loss / config.gradient_accumulation_steps
+ # immediately async prefetch next batch while model is doing the forward pass on the GPU
+ X, Y = next(train_batch_iter)
+ loss.backward()
+ # clip the gradient
+ if config.grad_clip != 0.0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) # type: ignore
+ optimizer.step()
+
+ # flush the gradients as soon as we can, no need for this memory anymore
+ optimizer.zero_grad(set_to_none=True)
+
+ # 4. log timing
+ t1 = time.time()
+ dt = t1 - model_training_state.t0
+ model_training_state.t0 = t1
+ if model_training_state.iter_num % config.log_interval == 0:
+ # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
+ lossf = loss.item() * config.gradient_accumulation_steps
+ if (
+ model_training_state.local_iter_num >= 5
+ ): # let the training loop settle a bit
+ mfu = model.estimate_mfu(
+ config.batch_size * config.gradient_accumulation_steps, dt
+ )
+ model_training_state.running_mfu = (
+ mfu
+ if model_training_state.running_mfu == -1.0
+ else 0.9 * model_training_state.running_mfu + 0.1 * mfu
+ )
+ print(
+ (
+ f"{model_training_state.iter_num} | loss {lossf:.4f} | lr {lr:e} | "
+ f"{dt*1000:.2f}ms | mfu {model_training_state.running_mfu*100:.2f}%"
+ )
+ )
+ model_training_state.iter_num += 1
+ model_training_state.local_iter_num += 1
+ return False
diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index ecbd034e..99374b02 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -1,276 +1,74 @@
+import os
+import time
+from dataclasses import fields
-from datetime import datetime
-from dataclasses import dataclass
-import os
import torch
-import time
-from contextlib import nullcontext
-import math
-@dataclass
-def TrainingConfig(config):
- # -----------------------------------------------------------------------------
- # I/O
- out_dir: str = "out"
- eval_interval: int = 2000
- log_interval:int = 1
- eval_iters:int = 100
- eval_only:bool = False # if True, script exits right after the first eval
- always_save_checkpoint:bool = False # if True, always save a checkpoint after each eval
- init_from:bool = "scratch" # 'scratch' or 'resume'
- # wandb logging
- wandb_log:bool = False # disabled by default
- wandb_project:str = "llamac"
- wandb_run_name:str = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
- # data
- batch_size:int = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
- max_seq_len:int = 256
- vocab_source:str = "llama2" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
- vocab_size:str = 32000 # the Llama 2 tokenizer has 32K tokens
- # model
- dim:int = 288
- n_layers:int = 6
- n_heads:int = 6
- n_kv_heads:int = 6
- multiple_of:int = 32
- dropout:int = 0.0
- # adamw optimizer
- gradient_accumulation_steps:int = 4 # used to simulate larger batch sizes
- learning_rate:float = 5e-4 # max learning rate
- max_iters:int = 100000 # total number of training iterations
- weight_decay:float = 1e-1
- beta1:float = 0.9
- beta2:float = 0.95
- grad_clip:float = 1.0 # clip gradients at this value, or disable if == 0.0
- # learning rate decay settings
- decay_lr:bool = True # whether to decay the learning rate
- warmup_iters:int = 1000 # how many steps to warm up for
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from delphi.train import wandb_utils
+from delphi.train.gigaconfig import GigaConfig, debug_config
+from delphi.train.iteration_params import set_iteration_params
+from delphi.train.train_step import train_step
+from delphi.train.utils import (
+ get_device,
+ load_delphi_training_dataset,
+ load_model_training_state,
+ save_checkpoint_if_needed,
+)
+
+
+def run_training(config: GigaConfig):
+ print("Starting training...")
+ print()
+ print("Config:")
+ for field in fields(config):
+ print(f" {field.name}: {getattr(config, field.name)}")
# system
- device:str = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
- dtype:str = "bfloat16" # float32|bfloat16|float16
- compile:bool = True # use PyTorch 2.0 to compile the model to be faster
- # -----------------------------------------------------------------------------
- config_keys = [
- k
- for k, v in globals().items()
- if not k.startswith("_") and isinstance(v, (int, float, bool, str))
- ]
- exec(open("configurator.py").read()) # overrides from command line or config file
- config = {k: globals()[k] for k in config_keys} # will be useful for logging
-
-
- # -----------------------------------------------------------------------------
-
- # fixing some hyperparams to sensible defaults
- lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
- min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
-
- # validating checks
- assert vocab_source in ["llama2", "custom"]
- assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
+ device = get_device(config.device)
- # various inits, derived attributes, I/O setup
- seed = 1337
- os.makedirs(out_dir, exist_ok=True)
-
-
-def model_initialization(config):
+ # load data
+ print("Loading data...")
+ train_ds = load_delphi_training_dataset(
+ "train", config.max_seq_len, device, limit=config.train_sample_limit
+ )
+ validation_ds = load_delphi_training_dataset(
+ "validation", config.max_seq_len, device, limit=config.val_sample_limit
+ )
- #model
- if config["model"] == "llama2":
- from delphi.models.llama2 import LLaMA2, LLaMA2Args
- model_args = dict(
- dim=config["dim"],
- n_layers=config["n_layers"],
- n_heads=config["n_heads"],
- n_kv_heads=config["n_kv_heads"],
- vocab_size=config["vocab_size"],
- multiple_of=config["multiple_of"],
- max_seq_len=config["max_seq_len"],
- dropout=config["dropout"],
- )
- gptconf = LLaMA2Args(**model_args)
- model = LLaMA2(gptconf)
- elif config["model"] == "mamba":
- from delphi.models.mamba import Mamba, MambaArgs
- model_args = dict(
- dim=config["dim"],
- n_layers=config["n_layers"],
- vocab_size=config["vocab_size"],
- )
- mambaconf = MambaArgs(**model_args)
- model = Mamba(mambaconf)
-
- if config["init_from"] == "resume":
- print(f"Resuming training from {config['out_dir']}")
- # resume training from a checkpoint.
- ckpt_path = os.path.join(config['out_dir'], "ckpt.pt")
- checkpoint = torch.load(ckpt_path, map_location=config['device'])
- checkpoint_model_args = checkpoint["model_args"]
- # force these config attributes to be equal otherwise we can't even resume training
- # the rest of the attributes (e.g. dropout) can stay as desired from command line
- for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
- model_args[k] = checkpoint_model_args[k]
- # create the model
- state_dict = checkpoint["model"]
- # fix the keys of the state dictionary :(
- # honestly no idea how checkpoints sometimes get this prefix, have to debug more
- unwanted_prefix = "_orig_mod."
- for k, v in list(state_dict.items()):
- if k.startswith(unwanted_prefix):
- state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
- model.load_state_dict(state_dict)
- config.iter_num = checkpoint["iter_num"]
- config.best_val_loss = checkpoint["best_val_loss"]
-
- model.to(config["device"])
- # compile the model
- if config["compile"]:
- print("compiling the model... (takes a ~minute)")
- unoptimized_model = model
- model = torch.compile(model) # requires PyTorch 2.0
- return model,model_args
+ # derive iteration params (num_batches, num_steps, etc)
+ iteration_params = set_iteration_params(config, train_ds, validation_ds)
-def train_loop(model, TrainConf):
- torch.manual_seed(TrainConf.seed)
+ # setup
+ print("Setting up...")
+ os.makedirs(config.out_dir, exist_ok=True)
+ torch.manual_seed(config.seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
- device_type = "cuda" if "cuda" in TrainConf.device else "cpu" # for later use in torch.autocast
- # note: float16 data type will automatically use a GradScaler
- ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[TrainConf.dtype]
- ctx = (
- nullcontext()
- if device_type == "cpu"
- else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
- )
- scaler = torch.cuda.amp.GradScaler(enabled=(TrainConf.dtype == "float16"))
- # optimizer
- optimizer = model.configure_optimizers(TrainConf.weight_decay, TrainConf.learning_rate, (TrainConf.beta1, TrainConf.beta2), device_type)
- if TrainConf.init_from == "resume" and "optimizer" in checkpoint:
- optimizer.load_state_dict(checkpoint["optimizer"])
- checkpoint = None # free up memory
-
- if TrainConf.wandb_log:
- import wandb
- wandb.init(project=TrainConf.wandb_project, name=TrainConf.wandb_run_name, config=TrainConf.config)
-
- train_batch_iter = TrainConf.iter_batches(split="train")
- X, Y = next(train_batch_iter) # fetch the very first batch
- t0 = time.time()
- local_iter_num = 0 # number of iterations in the lifetime of this process
- raw_model = model # unwrap DDP container if needed
- running_mfu = -1.0
- while True:
- # determine and set the learning rate for this iteration
- lr = get_lr(iter_num,TrainConf) if TrainConf.decay_lr else TrainConf.learning_rate
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
-
- # evaluate the loss on train/val sets and write checkpoints
- if iter_num % TrainConf.eval_interval == 0:
- losses = estimate_loss()
- print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
- if TrainConf.wandb_log:
- try:
- wandb.log(
- {
- "iter": iter_num,
- "tokens": iter_num * TrainConf.tokens_per_iter,
- "loss/train": losses["train"],
- "loss/val": losses["val"],
- "lr": lr,
- "mfu": running_mfu * 100, # convert to percentage
- }, step = iter_num
- )
- except Exception as e:
- print(f"logging to wandb failed: {e}")
- if losses["val"] < best_val_loss or TrainConf.always_save_checkpoint:
- best_val_loss = losses["val"]
- if iter_num > 0:
- checkpoint = {
- "model": raw_model.state_dict(),
- "optimizer": optimizer.state_dict(),
- "model_args": model_args,
- "iter_num": iter_num,
- "best_val_loss": best_val_loss,
- "config": config,
- }
- print(f"saving checkpoint to {out_dir}")
- torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
- model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
- if iter_num == 0 and eval_only:
- break
-
- # forward backward update, with optional gradient accumulation to simulate larger batch size
- # and using the GradScaler if data type is float16
- for micro_step in range(gradient_accumulation_steps):
- with ctx:
- logits = model(X, Y)
- loss = raw_model.last_loss
- loss = loss / gradient_accumulation_steps
- # immediately async prefetch next batch while model is doing the forward pass on the GPU
- X, Y = next(train_batch_iter)
- # backward pass, with gradient scaling if training in fp16
- scaler.scale(loss).backward()
- # clip the gradient
- if grad_clip != 0.0:
- scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
- # step the optimizer and scaler if training in fp16
- scaler.step(optimizer)
- scaler.update()
- # flush the gradients as soon as we can, no need for this memory anymore
- optimizer.zero_grad(set_to_none=True)
-
- # timing and logging
- t1 = time.time()
- dt = t1 - t0
- t0 = t1
- if iter_num % log_interval == 0:
- # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
- lossf = loss.item() * gradient_accumulation_steps
- if local_iter_num >= 5: # let the training loop settle a bit
- mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
- running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
- print(
- f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
+ # model init
+ model_training_state = load_model_training_state(config, device)
+
+ # setup eval callbacks
+ eval_callbacks = [save_checkpoint_if_needed]
+ if config.wandb_log:
+ wandb_utils.init_wandb(config)
+ eval_callbacks.append(wandb_utils.log_to_wandb)
+
+ # training loop
+ print("Starting training...")
+ for epoch in range(config.max_epochs):
+ train_ds.shuffle(epoch)
+ train_batch_iter = iter(DataLoader(train_ds, batch_size=config.batch_size)) # type: ignore
+ for _ in tqdm(range(iteration_params.num_steps)):
+ breaknow = train_step(
+ model_training_state,
+ train_ds,
+ validation_ds,
+ iteration_params,
+ eval_callbacks,
+ config,
+ train_batch_iter,
)
- iter_num += 1
- local_iter_num += 1
-
- # termination conditions
- if iter_num > max_iters:
- break
-
-@torch.no_grad()
-def estimate_loss():
- out = {}
- model.eval()
- for split in ["train", "val"]:
- batch_iter = iter_batches(split=split)
- losses = torch.zeros(eval_iters) # keep on CPU
- for k in range(eval_iters):
- X, Y = next(batch_iter)
- with ctx:
- logits = model(X, Y)
- loss = raw_model.last_loss
- losses[k] = loss.item()
- out[split] = losses.mean()
- model.train()
- return out
-
-# learning rate decay scheduler (cosine with warmup)
-def get_lr(it,TrainConf):
- # 1) linear warmup for warmup_iters steps
- if it < TrainConf.warmup_iters:
- return TrainConf.learning_rate * it / TrainConf.warmup_iters
- # 2) if it > lr_decay_iters, return min learning rate
- if it > TrainConf.lr_decay_iters:
- return TrainConf.min_lr
- # 3) in between, use cosine decay down to min learning rate
- decay_ratio = (it - TrainConf.warmup_iters) / (TrainConf.lr_decay_iters - TrainConf.warmup_iters)
- assert 0 <= decay_ratio <= 1
- coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
- return TrainConf.min_lr + coeff * (TrainConf.learning_rate - TrainConf.min_lr)
-
-
+ if breaknow:
+ break
diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py
index 73c3e9e6..619e0c14 100644
--- a/src/delphi/train/utils.py
+++ b/src/delphi/train/utils.py
@@ -1,6 +1,269 @@
import json
+import math
+import os
+import time
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, cast
+
+import torch
+from torch import Tensor
+from torch.optim import AdamW
+from torch.utils.data import DataLoader, Dataset
+
+from delphi import constants
+from delphi.eval.utils import load_delphi_dataset
+from delphi.train.architectures import export_model, initialize_model, load_model
+from delphi.train.gigaconfig import GigaConfig
+from delphi.train.tokenized_chunks_dataset import TokenizedChunksDataset
+
def load_config(config_path):
- with open(config_path, 'r') as file:
+ with open(config_path, "r") as file:
return json.load(file)
-
\ No newline at end of file
+
+
+def get_device(device_str: str = "auto") -> torch.device:
+ """
+ Get torch device specified by device_str. May pass "auto" to set torch device automatically.
+ """
+ # cuda if available; else mps if apple silicon; else cpu
+ if device_str == "auto":
+ if torch.cuda.is_available():
+ device_str = "cuda"
+ elif torch.backends.mps.is_available():
+ device_str = "mps"
+ else:
+ device_str = "cpu"
+ return torch.device(device_str)
+
+
+@dataclass
+class ModelMidTrain:
+ # hack for packing the values touched by resume_model in a single object
+ model: torch.nn.Module
+ iter_num: int
+ best_val_loss: float
+ checkpoint: Any
+
+
+def resume_model(
+ resume_from_path: Path, device: torch.device, **model_args
+) -> ModelMidTrain:
+ ckpt_path = resume_from_path / "ckpt.pt"
+ checkpoint = torch.load(ckpt_path, map_location=device)
+ model = load_model(model_args, checkpoint)
+ iter_num = checkpoint["iter_num"]
+ best_val_loss = checkpoint["best_val_loss"]
+ return ModelMidTrain(
+ model=model,
+ iter_num=iter_num,
+ best_val_loss=best_val_loss,
+ checkpoint=checkpoint,
+ )
+
+
+def get_optimizer(
+ model: torch.nn.Module,
+ config: GigaConfig,
+ device: torch.device,
+ checkpoint=None,
+) -> AdamW:
+ device_type = device.type
+ optimizer = model.configure_optimizers(
+ config.weight_decay,
+ config.learning_rate,
+ (config.beta1, config.beta2),
+ device_type,
+ )
+ if checkpoint is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ return optimizer
+
+
+@torch.no_grad()
+def estimate_loss(
+ model: torch.nn.Module,
+ eval_iters: int,
+ batch_size: int,
+ split_to_ds: dict[str, Dataset],
+) -> dict[str, float]:
+ """helps estimate an arbitrarily accurate loss over either split using many batches"""
+ out = {}
+ model.eval()
+ for split, ds in split_to_ds.items():
+ batch_iter = iter(DataLoader(ds, batch_size=batch_size)) # type: ignore
+        num_iters = min(eval_iters, len(ds) // batch_size)  # type: ignore
+        losses = torch.zeros(num_iters)  # keep on CPU
+        for k in range(num_iters):
+ X, Y = next(batch_iter)
+ # forward pass, which will also compute the loss
+ _logits = model(X, Y)
+ loss = cast(Tensor, model.last_loss)
+ losses[k] = loss.item()
+        out[split] = losses.mean().item()
+ model.train()
+ return out
+
+
+def get_lr(it, warmup_iters, learning_rate, lr_decay_iters, min_lr):
+ # 1) linear warmup for warmup_iters steps
+ if it < warmup_iters:
+ return learning_rate * it / warmup_iters
+ # 2) if it > lr_decay_iters, return min learning rate
+ if it > lr_decay_iters:
+ return min_lr
+ # 3) in between, use cosine decay down to min learning rate
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
+ assert 0 <= decay_ratio <= 1
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
+ return min_lr + coeff * (learning_rate - min_lr)
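+
+# Worked example with the GigaConfig defaults (learning_rate=5e-4, warmup_iters=1000,
+# min_lr=0.0) and lr_decay_iters=10_000:
+#   get_lr(500, 1000, 5e-4, 10_000, 0.0)    -> 2.5e-4  (halfway through warmup)
+#   get_lr(5500, 1000, 5e-4, 10_000, 0.0)   -> 2.5e-4  (decay_ratio=0.5, coeff=0.5)
+#   get_lr(20_000, 1000, 5e-4, 10_000, 0.0) -> 0.0     (past lr_decay_iters)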
+
+
+def set_lr(
+ lr_decay_iters: int,
+ config: GigaConfig,
+ optimizer: torch.optim.Optimizer,
+ iter_num: int,
+):
+ lr = (
+ get_lr(
+ iter_num,
+ config.warmup_iters,
+ config.learning_rate,
+ lr_decay_iters,
+ config.min_lr,
+ )
+ if config.decay_lr
+ else config.learning_rate
+ )
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+ return lr
+
+
+@dataclass
+class EvalData:
+ # values we expose to eval callback functions
+ iter_num: int
+ tokens_per_iter: int
+ running_mfu: float
+ lr: float
+ losses: dict[str, float]
+ best_val_loss: float
+ new_best_val_loss: bool
+ model: torch.nn.Module
+ model_args: Any
+ optimizer: torch.optim.Optimizer
+ config: GigaConfig
+
+
+def save_checkpoint_if_needed(eval_data: EvalData):
+ # we save if it's not the first iter AND at least one of:
+ # 1) we have a new best validation loss
+ # 2) always_save_checkpoint is set
+ if eval_data.iter_num == 0:
+ return
+ if (not eval_data.new_best_val_loss) and (
+ not eval_data.config.always_save_checkpoint
+ ):
+ return
+ checkpoint = {
+ "model": eval_data.model.state_dict(),
+ "optimizer": eval_data.optimizer.state_dict(),
+ "model_args": eval_data.model_args,
+ "iter_num": eval_data.iter_num,
+ "best_val_loss": eval_data.best_val_loss,
+ "config": asdict(eval_data.config),
+ }
+ print(f"saving checkpoint to {eval_data.config.out_dir}")
+ torch.save(checkpoint, os.path.join(eval_data.config.out_dir, "ckpt.pt"))
+ export_model(
+ eval_data.model,
+ eval_data.model_args["architecture"],
+ os.path.join(eval_data.config.out_dir, "model.bin"),
+ )
+
+
+@dataclass
+class ModelTrainingState:
+ model: torch.nn.Module
+ optimizer: torch.optim.Optimizer
+ model_args: Any
+ iter_num: int
+ local_iter_num: int
+ best_val_loss: float
+ running_mfu: float
+ t0: float
+
+
+def load_model_training_state(
+ config: GigaConfig, device: torch.device
+) -> ModelTrainingState:
+ iter_num = 0
+ local_iter_num = 0
+ best_val_loss = 1e9
+ running_mfu = -1.0
+ t0 = time.time()
+ model_args = dict(
+ architecture=config.architecture,
+ dim=config.dim,
+ n_layers=config.n_layers,
+ n_heads=config.n_heads,
+ n_kv_heads=config.n_kv_heads,
+ vocab_size=config.vocab_size,
+ multiple_of=config.multiple_of,
+ max_seq_len=config.max_seq_len,
+ dropout=config.dropout,
+ ) # start with model_args from command line
+ if config.init_from == "scratch":
+ # init a new model from scratch
+ print("Initializing a new model from scratch")
+ model = initialize_model(**model_args)
+ checkpoint = None
+ # TODO: resume from huggingface model
+ elif config.init_from == "resume":
+ print(f"Resuming training from {config.out_dir}")
+ model_mid_train = resume_model(Path(config.out_dir), device, **model_args)
+ model = model_mid_train.model
+ iter_num = model_mid_train.iter_num
+ best_val_loss = model_mid_train.best_val_loss
+ checkpoint = model_mid_train.checkpoint
+ model.to(device)
+ # optimizer
+ optimizer = get_optimizer(
+ model=model,
+ config=config,
+ device=device,
+ checkpoint=checkpoint
+ if checkpoint is not None and "optimizer" in checkpoint
+ else None,
+ )
+ checkpoint = None # free up memory
+ return ModelTrainingState(
+ model=model,
+ optimizer=optimizer,
+ model_args=model_args,
+ iter_num=iter_num,
+ local_iter_num=local_iter_num,
+ best_val_loss=best_val_loss,
+ running_mfu=running_mfu,
+ t0=t0,
+ )
+
+
+def load_delphi_training_dataset(
+ split: str, max_seq_len: int, device: torch.device, limit: int = -1
+):
+ """For training, we want (X, Y) pairs, where X is a chunk of text and Y is the next token.)
+ To construct this, we take the original tokenized dataset, break it into max_seq_len+1 length chunks,
+ and then take [:-1] as X and [1:] as Y.
+ """
+ if limit == -1:
+ ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, split)
+ else:
+ ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, split).select(
+ range(limit)
+ )
+ token_ds = TokenizedChunksDataset(ds, max_seq_len, device)
+ token_ds.initialize_samples()
+ return token_ds
diff --git a/src/delphi/train/wandb_utils.py b/src/delphi/train/wandb_utils.py
new file mode 100644
index 00000000..33d37667
--- /dev/null
+++ b/src/delphi/train/wandb_utils.py
@@ -0,0 +1,32 @@
+from dataclasses import asdict
+
+import wandb
+
+from delphi.train.gigaconfig import GigaConfig
+from delphi.train.utils import EvalData
+
+
+def init_wandb(config: GigaConfig):
+ wandb.init(
+ entity=config.wandb_entity,
+ project=config.wandb_project,
+ name=config.wandb_run_name,
+ config=asdict(config),
+ )
+
+
+def log_to_wandb(eval_data: EvalData):
+ try:
+ wandb.log(
+ {
+ "iter": eval_data.iter_num,
+ "tokens": eval_data.iter_num * eval_data.tokens_per_iter,
+ "loss/train": eval_data.losses["train"],
+ "loss/val": eval_data.losses["val"],
+ "lr": eval_data.lr,
+ "mfu": eval_data.running_mfu * 100, # convert to percentage
+ },
+ step=eval_data.iter_num,
+ )
+ except Exception as e:
+ print(f"logging to wandb failed: {e}")
diff --git a/src/llama2c b/src/llama2c
new file mode 160000
index 00000000..07947ec9
--- /dev/null
+++ b/src/llama2c
@@ -0,0 +1 @@
+Subproject commit 07947ec91095980c60f7c39bf42d2105a0d81bb1
diff --git a/tests/eval/test_compare_models.py b/tests/eval/test_compare_models.py
new file mode 100644
index 00000000..0521b0cb
--- /dev/null
+++ b/tests/eval/test_compare_models.py
@@ -0,0 +1,23 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from delphi.eval.compare_models import NextTokenStats, compare_models
+from delphi.eval.utils import load_validation_dataset, tokenize
+
+
+def test_compare_models():
+ with torch.set_grad_enabled(False):
+ model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
+ model_instruct = AutoModelForCausalLM.from_pretrained(
+ "roneneldan/TinyStories-Instruct-1M"
+ )
+ ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
+ tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
+ sample_tok = tokenize(tokenizer, ds_txt[0])
+ K = 3
+ model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
+ # ignore the first element comparison
+ assert model_comparison[0] is None
+ assert isinstance(model_comparison[1], NextTokenStats)
+ assert len(model_comparison) == sample_tok.shape[0]
+ assert len(model_comparison[1].topk) == K
diff --git a/tests/eval/test_token_labelling.py b/tests/eval/test_token_labelling.py
new file mode 100644
index 00000000..a727ddc0
--- /dev/null
+++ b/tests/eval/test_token_labelling.py
@@ -0,0 +1,114 @@
+import pytest
+import spacy
+from spacy.language import Language
+from spacy.tokens import Doc
+
+import delphi.eval.token_labelling as tl
+
+
+@pytest.fixture
+def dummy_doc() -> tuple[str, Doc, dict[str, bool]]:
+ """
+ Create a dummy Doc (list of Tokens) with specific attributes for testing purposes.
+ """
+ nlp_dummy = Language()
+
+ # Assume we're creating a dummy token with specific attributes
+ words = ["Peter", "is", "a", "person"]
+    spaces = [True, True, True, True]  # whitespace follows every token
+ pos_tags = ["PROPN", "AUX", "DET", "NOUN"] # Part-of-speech tag
+ dep_tags = ["nsubj", "ROOT", "det", "attr"] # Dependency tag
+ ner_tags = ["PERSON", "", "", ""] # Named entity tag
+
+ # Ensure the length of pos_tags and dep_tags matches the length of words
+ assert len(words) == len(pos_tags) == len(dep_tags) == len(ner_tags)
+
+    # Create a Doc from the dummy tokens
+ doc = Doc(nlp_dummy.vocab, words=words, spaces=spaces)
+
+ # Manually set POS, dependency and NER tags
+ for token, pos, dep, ner_tag in zip(doc, pos_tags, dep_tags, ner_tags):
+ token.pos_, token.dep_, token.ent_type_ = pos, dep, ner_tag
+
+ # Token labels for "Peter" in the dummy doc
+ PETER_TOKEN_LABEL = {
+ "Starts with space": False,
+ "Capitalized": True,
+ "Is Adjective": False,
+ "Is Adposition": False,
+ "Is Adverb": False,
+ "Is Auxiliary": False,
+ "Is Coordinating conjuction": False,
+ "Is Determiner": False,
+ "Is Interjunction": False,
+ "Is Noun": False,
+ "Is Numeral": False,
+ "Is Particle": False,
+ "Is Pronoun": False,
+ "Is Proper Noun": True,
+ "Is Punctuation": False,
+ "Is Subordinating conjuction": False,
+ "Is Symbol": False,
+ "Is Verb": False,
+ "Is Other": False,
+ "Is Named Entity": True,
+ }
+ text = " ".join(words)
+ return text, doc, PETER_TOKEN_LABEL
+
+
+def test_explain_token_labels(dummy_doc):
+ """
+ Test the explain_token_labels function.
+ """
+ # explain all labels
+ tl.explain_token_labels()
+ # print explanations for the first token in doc
+ text, doc, PETER_TOKEN_LABEL = dummy_doc
+ tl.explain_token_labels(doc[0])
+
+
+def test_label_single_token(dummy_doc):
+ """
+ Test the label_single_token function.
+ """
+ # create a dummy token
+ text, doc, PETER_TOKEN_LABEL = dummy_doc
+ token = doc[0]
+ # label the token
+ labels = tl.label_single_token(token)
+ # check if the labels are correct
+ assert labels == PETER_TOKEN_LABEL
+
+
+def test_label_sentence(dummy_doc):
+ """
+ Test the label_sentence function.
+ """
+ text, doc, PETER_TOKEN_LABEL = dummy_doc
+ # label the sentence
+ labels = tl.label_sentence(doc)
+ # assert the first token is labeled correctly
+ assert labels[0] == PETER_TOKEN_LABEL
+ # iterate through tokens in doc
+ for token, label in zip(doc, labels):
+ assert label == tl.label_single_token(token)
+
+
+def test_label_batch_sentences(dummy_doc):
+ """
+ Test the label_batch_sentences function.
+ """
+ # create a batch of sentences
+ text, doc, PETER_TOKEN_LABEL = dummy_doc
+ text = text.split(" ")
+ batch = [text, text, text]
+ # label the batch
+ labels = tl.label_batch_sentences(batch, tokenized=True)
+ # assert the first token is labeled correctly
+ assert labels[0][0] == PETER_TOKEN_LABEL
+ assert labels[1][0] == PETER_TOKEN_LABEL
+ assert labels[2][0] == PETER_TOKEN_LABEL
+ # iterate through tokens in doc
+ for token, label in zip(doc, labels[0]):
+ assert label == tl.label_single_token(token)
diff --git a/tests/eval/test_token_map.py b/tests/eval/test_token_map.py
new file mode 100644
index 00000000..2f896326
--- /dev/null
+++ b/tests/eval/test_token_map.py
@@ -0,0 +1,47 @@
+import pytest
+from datasets import Dataset
+
+from delphi.eval.token_map import token_map
+
+
+def test_token_map():
+ tokenized_dataset = Dataset.from_dict(
+ {
+ "tokens": [
+ [0, 1, 2, 3, 4, 5, 0, 6, 7],
+ [0, 1, 2, 3, 4, 5, 0, 6, 7],
+ [0, 1, 2, 3, 4, 5, 0, 6, 7],
+ ]
+ }
+ )
+ assert token_map(tokenized_dataset, tokenizer_size=9) == [
+ [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
+ [(0, 1), (1, 1), (2, 1)],
+ [(0, 2), (1, 2), (2, 2)],
+ [(0, 3), (1, 3), (2, 3)],
+ [(0, 4), (1, 4), (2, 4)],
+ [(0, 5), (1, 5), (2, 5)],
+ [(0, 7), (1, 7), (2, 7)],
+ [(0, 8), (1, 8), (2, 8)],
+ [], # token 8 is not present in the dataset
+ ]
+
+ # fmt: off
+ tokenized_dataset = Dataset.from_dict(
+ { # one really long prompt
+ "tokens": [
+ [0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7]
+ ]
+ }
+ )
+ # fmt: on
+ assert token_map(tokenized_dataset, tokenizer_size=8) == [
+ [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
+ [(0, 1), (0, 10), (0, 19)],
+ [(0, 2), (0, 11), (0, 20)],
+ [(0, 3), (0, 12), (0, 21)],
+ [(0, 4), (0, 13), (0, 22)],
+ [(0, 5), (0, 14), (0, 23)],
+ [(0, 7), (0, 16), (0, 25)],
+ [(0, 8), (0, 17), (0, 26)],
+ ]
diff --git a/tests/scripts/functional_test_generate_logprobs.sh b/tests/scripts/functional_test_generate_logprobs.sh
new file mode 100644
index 00000000..95085645
--- /dev/null
+++ b/tests/scripts/functional_test_generate_logprobs.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+# Test to check whether inference.py uploads log probabilities to Hugging Face.
+# Similar to generate_logprobs.sh, but much smaller.
+
+BATCH_SIZE=80
+DATASET_NAME="delphi-suite/tinystories-v2-clean-tokenized"
+USERNAME="transcendingvictor" # Your Hugging Face username
+TOKEN="hf_aaaaaaaaaaaaaaaaaaaaaaaaa" # Your Hugging Face API token
+
+# List of models
+declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k"
+ "delphi-suite/delphi-llama2-200k"
+ )
+
+# Loop through each model and generate log probabilities
+for MODEL_NAME in "${MODEL_NAMES[@]}"
+do
+ echo "Processing $MODEL_NAME"
+ python scripts/inference.py "$MODEL_NAME" --batch-size "$BATCH_SIZE" --dataset-name "$DATASET_NAME" --username "$USERNAME" --token "$TOKEN" --test-funct
+done
+
+echo "All models processed."
\ No newline at end of file