
Commit

Merge pull request #807 from epfml/654-improve-gpt-julien
Fix and rework GPT-TF.js
JulienVig authored Dec 6, 2024
2 parents 8176df3 + 30de4fb commit 7844f97
Showing 28 changed files with 835 additions and 353 deletions.
76 changes: 76 additions & 0 deletions .github/workflows/record-cypress.yml
@@ -0,0 +1,76 @@
name: record-cypress
on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  download-datasets:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
          submodules: true
      - uses: actions/cache@v4
        with:
          path: datasets
          key: datasets-${{ hashFiles('datasets/**') }}
      - run: datasets/populate

  build-lib:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          cache: npm
      - run: npm ci
      - run: npm --workspace=discojs run build

  build-lib-web:
    needs: build-lib
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          cache: npm
      - run: npm ci
      - run: npm run --workspace=discojs build
      - run: npm run --workspace=discojs-web build

  record-test-webapp:
    needs: [build-lib, build-lib-web, download-datasets]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
          submodules: true
      - uses: actions/cache@v4
        with:
          path: datasets
          key: datasets-${{ hashFiles('datasets/**') }}
      - uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          cache: npm
      - run: npm ci
      - run: npm --workspace={discojs,discojs-web} run build
      - run: npm --workspace=webapp run test:unit
      - uses: cypress-io/github-action@v6
        with:
          working-directory: webapp
          install: false
          start: npm start
          wait-on: 'http://localhost:8081' # Waits for above
          # Records to Cypress Cloud
          # https://docs.cypress.io/guides/cloud/projects#Set-up-a-project-to-record
          record: true
        env:
          VITE_SERVER_URL: http://server
          CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }}
1 change: 1 addition & 0 deletions cli/package.json
@@ -7,6 +7,7 @@
"watch": "nodemon --ext ts --ignore dist --watch ../discojs-node/dist --watch ../server/dist --watch . --exec npm run",
"start": "npm run build && node dist/cli.js",
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
"train_gpt": "npm run build && node dist/train_gpt.js",
"build": "tsc",
"lint": "npx eslint .",
"test": ": nothing"
28 changes: 13 additions & 15 deletions cli/src/benchmark_gpt.ts
@@ -1,3 +1,4 @@
import '@tensorflow/tfjs-node';
import { List } from "immutable";
import { parse } from "ts-command-line-args";
import { AutoTokenizer } from "@xenova/transformers";
@@ -41,7 +42,7 @@ const args = { ...defaultArgs, ...parsedArgs }
* Benchmark results are reported in https://github.com/epfml/disco/pull/659
*/

async function main(args: Required<CLIArguments>): Promise<void> {
const { inference: benchmarkInference, modelType,
contextLength, batchSize, modelPath } = args

@@ -68,20 +69,20 @@ async function main(args: Required<CLIArguments>): Promise<void> {
const config: models.GPTConfig = {
modelType: modelType as models.GPTConfig['modelType'],
maxIter: iterationsPerEpoch,
blockSize: contextLength,
lr: 0.0001,
vocabSize: 50258 // default wikitext task uses the gpt2 tokenizer with vocabSize 50258
contextLength,
}

// Load the dataset after setting the Task batch size and max sequence length
// to make sure the dataset is batched and tokenized correctly
task.trainingInformation.batchSize = batchSize
task.trainingInformation.maxSequenceLength = contextLength
task.trainingInformation.contextLength = contextLength
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
.map(text => processing.tokenize(tokenizer, text))
.flatten()
.batch(config.contextLength + 1, 1)

const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + 1
const preprocessedDataset = dataset
.map((line) => processing.tokenizeAndLeftPad(line, tokenizer, maxLength))
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.batch(batchSize);

@@ -108,25 +109,22 @@ async function main(args: Required<CLIArguments>): Promise<void> {

// Benchmark parameters
const prompt = 'The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,'
const nbNewTokens = 200
const maxNewTokens = 200
const iterations = 10
console.log("Generating", nbNewTokens, "new tokens")
console.log("Generating", maxNewTokens, "new tokens")

let tokens = List(
(tokenizer(prompt, { return_tensor: false }) as { input_ids: number[] })
.input_ids,
);
let tokens = processing.tokenize(tokenizer, prompt);

let inferenceTime = 0
for (let i = 0; i < iterations; i++) {
const timeStart = performance.now()
for (let n = 0; n < nbNewTokens; n++) {
for (let n = 0; n < maxNewTokens; n++) {
const next: number = (await model.predict(List.of(tokens))).first();
tokens = tokens.push(next)
}
inferenceTime += performance.now() - timeStart
}
console.log(`Inference time: ${(inferenceTime/ nbNewTokens / iterations).toFixed(2)} ms/token`)
console.log(`Inference time: ${(inferenceTime/ maxNewTokens / iterations).toFixed(2)} ms/token`)
}
await new Promise((resolve, reject) => {
server.once('close', resolve)
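As an aside, the `.batch(contextLength + 1, 1)` / `[tokens.pop(), tokens.last()]` pipeline above can be puzzling on first read. Below is a minimal, self-contained sketch (not part of the PR; `toInputLabelPairs` is a hypothetical name) of what it produces, assuming `.batch(size, overlap)` yields windows of `size` tokens in which consecutive windows share `overlap` tokens:

import { List } from "immutable";

// Each window holds contextLength + 1 tokens; consecutive windows share one
// token, so the label token of one window is the first input token of the next.
function toInputLabelPairs(
  tokens: List<number>,
  contextLength: number,
): Array<[List<number>, number]> {
  const size = contextLength + 1;
  const stride = size - 1; // windows overlap by exactly one token
  const pairs: Array<[List<number>, number]> = [];
  for (let start = 0; start + size <= tokens.size; start += stride) {
    const window = tokens.slice(start, start + size);
    // same split as `[tokens.pop(), tokens.last()]` above: the first
    // contextLength tokens are the input, the final token is the label
    pairs.push([window.pop(), window.last() as number]);
  }
  return pairs;
}

// contextLength 3 over tokens 0..6 gives [0,1,2] -> 3 and [3,4,5] -> 6
console.log(toInputLabelPairs(List([0, 1, 2, 3, 4, 5, 6]), 3));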
49 changes: 49 additions & 0 deletions cli/src/train_gpt.ts
@@ -0,0 +1,49 @@
import "@tensorflow/tfjs-node"
import { AutoTokenizer } from "@xenova/transformers";
import { models, processing, Dataset } from "@epfml/discojs";
import { List } from "immutable";

async function main(): Promise<void> {
const data = "Lorem ipsum dolor sit amet, consectetur adipis"
const seed = 42

const config: models.GPTConfig = {
modelType: 'gpt-nano',
lr: 0.01,
maxIter: 50,
evaluateEvery:50,
maxEvalBatches: 10,
contextLength: 16,
seed
}

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')

const tokenDataset = new Dataset([data])
.map((text: string) => processing.tokenize(tokenizer, text))
.flatten()
.batch(config.contextLength + 1, 1)
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.repeat()
.batch(8);

const model = new models.GPT(config)
for await (const logs of model.train(tokenDataset, undefined)) {
console.log(logs)
}

let tokens = processing.tokenize(tokenizer, "Lorem");

const maxNewTokens = 14
for (let n = 0; n < maxNewTokens; n++) {
const next: number = (await model.predict(
List.of(tokens), { seed })
).first();
tokens = tokens.push(next)
}
const generation = tokenizer.decode(tokens.toArray(), { skip_special_tokens: true })
console.log(generation)
}

// You can run this example with "npm run run_gpt" from this folder
main().catch(console.error)
11 changes: 6 additions & 5 deletions discojs-node/src/loaders.spec.ts
@@ -51,12 +51,13 @@ describe("image directory parser", () => {

describe("text parser", () => {
it("parses basic file", async () => {
const text = ["a", "b", "c"].join("\n")
await withFile(async ({ path }) => {
await fs.writeFile(path, ["a", "b", "c"].join("\n"));

const parsed = loadText(path);

expect(await parsed.size()).to.equal(3);
await fs.writeFile(path, text);
const sequences = await arrayFromAsync(loadText(path))
expect(sequences.length).to.equal(1);
expect(sequences[0]).to.equal(text);
});
});
});
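Both loader specs collect results with an `arrayFromAsync` helper whose body is truncated in this view; presumably it is the usual collect-an-async-iterable pattern, along these lines:

// Collect every element of an async iterable into an array.
async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
  const ret: T[] = [];
  for await (const e of iter) ret.push(e);
  return ret;
}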
26 changes: 19 additions & 7 deletions discojs-node/src/loaders/text.ts
@@ -1,14 +1,26 @@
import * as fs from "node:fs/promises";
import * as readline from "node:readline/promises";

import createDebug from "debug";
import { createReadStream } from 'node:fs';
import { Dataset, Text } from "@epfml/discojs";

const debug = createDebug("discojs-node:loaders:text");

/**
* Returns chunks of text. Use `minChunkSize` to ensure that
* each chunk is bigger than the expected sequence length.
*
* @param path path to the text file to read
* @returns a dataset of raw text chunks
*/
export function load(path: string): Dataset<Text> {
return new Dataset(async function* () {
const input = (await fs.open(path)).createReadStream({ encoding: "utf8" });
// Create a stream to read the text file chunk by chunk
const stream = createReadStream(path, { encoding: "utf8" });
for await (const chunk of stream) {
if (typeof chunk !== 'string')
throw new Error('Expected file stream to yield string')

// `readline` is a bit overkill but seems standard
// https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line
yield* readline.createInterface({ input, crlfDelay: Infinity });
debug("yield chunk of length: %o", chunk.length);
yield chunk
}
});
}
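A minimal usage sketch of the reworked loader (not from the PR), assuming `loadText` is the name it is re-exported under, as in the specs, and that `Dataset` is async-iterable, as the specs exercise via `arrayFromAsync`:

import { AutoTokenizer } from "@xenova/transformers";
import { processing } from "@epfml/discojs";
import { loadText } from "@epfml/discojs-node";

// Stream a file chunk by chunk and tokenize each chunk,
// mirroring the benchmark_gpt.ts pipeline above.
async function preview(path: string): Promise<void> {
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
  const tokenized = loadText(path).map((chunk) =>
    processing.tokenize(tokenizer, chunk),
  );
  for await (const tokens of tokenized)
    console.log("chunk of", tokens.size, "tokens");
}

preview("../datasets/wikitext/wiki.train.tokens").catch(console.error);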
23 changes: 8 additions & 15 deletions discojs-web/src/loaders.spec.ts
@@ -1,5 +1,4 @@
import { describe, it, expect } from "vitest";

import { loadCSV, loadText } from "./loaders/index.js";

async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
@@ -22,22 +21,16 @@ });
});

describe("text parser", () => {
it("loads", async () => {
it("loads a simple sequence", async () => {
const text = ["first", "second", "third"].join("\n")

// jsdom doesn't implement .text on File/Blob
// trick from https://github.com/jsdom/jsdom/issues/2555
const text = await (
await fetch(
// data URL content need to be url-encoded
["data:,first", "second", "third"].join("%0A"),
)
const file = await (
await fetch( "data:," + encodeURIComponent(text))
).blob();

const parsed = loadText(text);

expect(await arrayFromAsync(parsed)).to.have.ordered.members([
"first",
"second",
"third",
]);
const parsed = loadText(file)
expect(await parsed.size()).to.equal(1);
expect((await arrayFromAsync(parsed))[0]).to.equal(text);
});
});
25 changes: 0 additions & 25 deletions discojs-web/src/loaders/text.ts
@@ -1,35 +1,10 @@
import { Dataset, Text } from "@epfml/discojs";

class LineStream extends TransformStream<string, string> {
constructor() {
let current_line = "";

super({
transform: (chunk, controller) => {
const [head, ...lines] = chunk.split(/\r\n|\r|\n/);
const first_line = current_line + head;

if (lines.length === 0) {
current_line = first_line;
return;
}

controller.enqueue(first_line);
for (const line of lines.slice(0, -1)) controller.enqueue(line);

current_line = lines[lines.length - 1];
},
flush: (controller) => controller.enqueue(current_line),
});
}
}

export function load(file: Blob): Dataset<Text> {
return new Dataset(async function* () {
const reader = file
.stream()
.pipeThrough(new TextDecoderStream())
.pipeThrough(new LineStream())
.getReader();

while (true) {
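The rest of this hunk is truncated, but since `LineStream` is deleted, the reworked web loader presumably decodes the Blob and yields raw text chunks, mirroring the Node.js loader above. A sketch of that shape:

import { Dataset, Text } from "@epfml/discojs";

export function load(file: Blob): Dataset<Text> {
  return new Dataset(async function* () {
    // Decode the Blob's bytes to text and yield each chunk as-is.
    const reader = file.stream().pipeThrough(new TextDecoderStream()).getReader();
    while (true) {
      const { value: chunk, done } = await reader.read();
      if (done) break;
      yield chunk;
    }
  });
}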