From 0f0ae6acc48918c9e5fe1e8ae0d508f790e1f96a Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sun, 15 Dec 2024 21:20:54 -0500 Subject: [PATCH 01/23] feat: basic working gather implementation --- src/tensor/tensor.ts | 13 ++++++ tests/integration/gather.test.ts | 77 ++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tests/integration/gather.test.ts diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index 759b9b7..ff04606 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -100,6 +100,19 @@ export class Tensor { return reluOp.forward(this); } + // In Tensor class + async gather(indices: Tensor): Promise<[Tensor, number]> { + // Convert indices to one-hot + const oneHot = new Float32Array(indices.shape[0] * this.shape[0]).fill(0); + for (let i = 0; i < indices.shape[0]; i++) { + oneHot[i * this.shape[0] + indices.data[i]] = 1; + } + const oneHotTensor = new Tensor(oneHot, [indices.shape[0], this.shape[0]], indices.requires_grad); + + // Use existing matmul + return oneHotTensor.matmul(this); + } + transpose() { const [rows, cols] = this.shape; const transposedData = new Float32Array(this.data.length); diff --git a/tests/integration/gather.test.ts b/tests/integration/gather.test.ts new file mode 100644 index 0000000..f27ad5a --- /dev/null +++ b/tests/integration/gather.test.ts @@ -0,0 +1,77 @@ +import { test, expect } from "@playwright/test"; + +test("Gather forward and backward pass", async ({ page }) => { + await page.goto("http://localhost:8080"); + + page.on("console", (msg) => { + console.log(msg); + }); + + await page.evaluate(() => { + return new Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor } = module; + + // @ts-expect-error ignore error for tests + window.runGatherTest = async function () { + // Create a simple embedding matrix with 3 embeddings of dimension 2 + const embeddings = new Tensor( + new Float32Array([ + 1.0, 2.0, // embedding 0 + 3.0, 4.0, // embedding 1 + 5.0, 6.0 // embedding 2 + ]), + [3, 2], + true + ); + + // Look up embeddings at indices 1, 0 (second embedding, then first) + const indices = new Tensor( + new Float32Array([1, 0]), + [2, 1], + false + ); + + // Forward pass - gather embeddings + const [output] = await embeddings.gather(indices); + + // Backward pass + await output.backward(); + + return { + embeddings: embeddings, + indices: indices, + output: output, + grad_embeddings: embeddings.grad, + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + // @ts-expect-error ignore error for tests + const result = await page.evaluate(() => window.runGatherTest()); + + expect(result.output.shape).toEqual([2, 2]); // 2 selected embeddings of dimension 2 + expect(result.grad_embeddings.shape).toEqual([3, 2]); // Same shape as input embeddings + + // Forward pass assertions - should get embeddings at indices 1 and 0 + const outputData = new Float32Array(Object.values(result.output.data)); + expect(outputData).toEqual(new Float32Array([ + 3.0, 4.0, // embedding at index 1 + 1.0, 2.0 // embedding at index 0 + ])); + + // Backward pass assertions - gradient should accumulate at the selected indices + const gradData = new Float32Array(Object.values(result.grad_embeddings.data)); + expect(gradData).toEqual(new Float32Array([ + 1.0, 1.0, // gradient for embedding 0 (selected second) + 1.0, 1.0, // gradient for embedding 1 (selected first) + 0.0, 0.0 // gradient for embedding 2 (not selected) + 
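+      // Why all ones at the selected rows (illustrative sketch, not an API claim):
+      // backward() presumably seeds the output gradient with ones, and since
+      // gather is implemented as oneHot @ embeddings, the embedding gradient is
+      // oneHot^T @ grad_output — each selected row receives [1, 1], unselected rows stay zero.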
])); + + await page.close(); +}); \ No newline at end of file From c36766de4cc5740824ddcb397c1c54ada2f13b20 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 16 Dec 2024 00:50:19 -0500 Subject: [PATCH 02/23] feat: embedding layer with matmul --- src/index.ts | 2 + src/layers/embedding.ts | 20 ++++++++ src/tensor/tensor.ts | 22 +++++++-- tests/integration/embedding.test.ts | 75 +++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 src/layers/embedding.ts create mode 100644 tests/integration/embedding.test.ts diff --git a/src/index.ts b/src/index.ts index ce384a6..004f0d5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,3 +8,5 @@ export * from "./ops/log2.js"; export * from "./ops/ln.js"; export * from "./ops/relu.js"; export * from "./autograd/function.js"; +export * from "./layers/module.js"; +export * from "./layers/embedding.js"; diff --git a/src/layers/embedding.ts b/src/layers/embedding.ts new file mode 100644 index 0000000..36604ba --- /dev/null +++ b/src/layers/embedding.ts @@ -0,0 +1,20 @@ +import { Tensor } from "../tensor/tensor.js"; +import { Module } from "./module.js"; + +export class Embedding extends Module { + vocab_size: number; + emb_dim: number; + embedding: Tensor; + constructor(vocab_size: number, emb_dim: number) { + super("embedding"); + + this.vocab_size = vocab_size; + this.emb_dim = emb_dim; + this.embedding = Tensor.randn([vocab_size, emb_dim], true); + } + + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [embeddings] = await this.embedding.gather(inputs[0]); + return [embeddings]; + } +} diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index ff04606..cab301d 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -52,6 +52,16 @@ export class Tensor { return Tensor.full(tensor.shape, 0, tensor.requires_grad); } + static randn(shape: number[], requires_grad = false) { + const data = new Float32Array(shape.reduce((a, b) => a * b)); + + for (let i = 0; i < data.length; i++) { + data[i] = Math.random() * 2 - 1; + } + + return new Tensor(data, shape, requires_grad); + } + async add(tensor: Tensor) { const addOp = await Add.create(); @@ -100,16 +110,20 @@ export class Tensor { return reluOp.forward(this); } - // In Tensor class async gather(indices: Tensor): Promise<[Tensor, number]> { // Convert indices to one-hot const oneHot = new Float32Array(indices.shape[0] * this.shape[0]).fill(0); for (let i = 0; i < indices.shape[0]; i++) { - oneHot[i * this.shape[0] + indices.data[i]] = 1; + let index = indices.data[i] + i * this.shape[0]; + // set one hot value for the whole vector + console.log("before setting one hot", oneHot.toString(), "at index", index); + oneHot.fill(1, index, index + 1); + console.log("after setting one hot", oneHot.toString(), "at index", index); } + const oneHotTensor = new Tensor(oneHot, [indices.shape[0], this.shape[0]], indices.requires_grad); - - // Use existing matmul + console.log("oneHotTensor", oneHotTensor.data.toString()); + return oneHotTensor.matmul(this); } diff --git a/tests/integration/embedding.test.ts b/tests/integration/embedding.test.ts new file mode 100644 index 0000000..c01ea76 --- /dev/null +++ b/tests/integration/embedding.test.ts @@ -0,0 +1,75 @@ +import { test, expect } from "@playwright/test"; + +test("Embedding forward pass with known values", async ({ page }) => { + await page.goto("http://localhost:8080"); + + page.on("console", (msg) => { + console.log(msg); + }); + + // Inject test function + await page.evaluate(() => { + return new 
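+      // Test-injection pattern used throughout these tests: the page loads the
+      // ESM bundle, attaches the test function to window, and resolves once it
+      // is registered; Playwright then invokes it via page.evaluate below.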
Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor, Embedding } = module; + + window.runEmbeddingTest = async function () { + const vocabSize = 128; + const embeddingDim = 2; // Using small dim for easy verification + + // Create embedding layer + const embedding = new Embedding(vocabSize, embeddingDim); + console.log(embedding); + + // Create input tensor with indices + const inputIndices = new Tensor( + new Float32Array([1, 5, 10]), // Sample indices + [3], // Sequence length of 3 + false + ); + + // Forward pass + const [embeddings] = await embedding.forward(inputIndices); + + return { + inputIndices: Array.from(inputIndices.data), + embedding: embedding.embedding, + outputShape: embeddings.shape, + outputData: Array.from(embeddings.data) + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + const result = await page.evaluate(() => window.runEmbeddingTest()); + + // Validate shapes + expect(result.outputShape).toEqual([3, 2]); // Sequence length x Embedding dim + + const expectedOutput = [ + result.embedding.data[2], + result.embedding.data[3], + result.embedding.data[10], + result.embedding.data[11] , + result.embedding.data[20], + result.embedding.data[21] + ]; + + expect(result.outputData).toEqual(expectedOutput); + + await page.close(); +}); + +declare global { + interface Window { + runEmbeddingTest: () => Promise<{ + inputIndices: number[]; + outputShape: number[]; + outputData: number[]; + }>; + } +} \ No newline at end of file From b2765961afaefd2d685bc6a5e1cc63ba59e0bdbc Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 16 Dec 2024 00:51:56 -0500 Subject: [PATCH 03/23] fix: type error --- tests/integration/embedding.test.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/embedding.test.ts b/tests/integration/embedding.test.ts index c01ea76..d1be2ce 100644 --- a/tests/integration/embedding.test.ts +++ b/tests/integration/embedding.test.ts @@ -68,6 +68,7 @@ declare global { interface Window { runEmbeddingTest: () => Promise<{ inputIndices: number[]; + embedding: { data: number[] }; outputShape: number[]; outputData: number[]; }>; From 54de6b4191cb15d19cd78fa86dc819042a6d6e2c Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 16 Dec 2024 00:52:56 -0500 Subject: [PATCH 04/23] style: let vs const --- src/layers/module.ts | 20 ++++++++++++++++++++ src/tensor/tensor.ts | 4 +--- 2 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 src/layers/module.ts diff --git a/src/layers/module.ts b/src/layers/module.ts new file mode 100644 index 0000000..5312934 --- /dev/null +++ b/src/layers/module.ts @@ -0,0 +1,20 @@ +import { Tensor } from "../tensor/tensor"; +export abstract class Module { + protected name: string; + constructor(name: string) { + if (name === null || name === undefined) { + throw Error("Name cannot be null or undefined"); + } + + this.name = name; + } + + /** + * Abstract method that must be implemented by all layer subclasses + * Defines the forward pass computation of the layer + * @param inputs - Input tensor(s) to the layer + * @returns Output tensor(s) from the layer + */ + abstract forward(...inputs: [Tensor]): Promise<[Tensor]> + +} \ No newline at end of file diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index cab301d..4f82a78 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -114,11 +114,9 @@ export class Tensor { // Convert indices to one-hot const oneHot = new 
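+    // One-hot layout sketch: row i holds a single 1 at column indices.data[i],
+    // i.e. flat offset i * numEmbeddings + indices.data[i]. For indices [1, 0]
+    // over 3 embeddings this yields [[0, 1, 0], [1, 0, 0]] (illustrative).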
Float32Array(indices.shape[0] * this.shape[0]).fill(0); for (let i = 0; i < indices.shape[0]; i++) { - let index = indices.data[i] + i * this.shape[0]; + const index = indices.data[i] + i * this.shape[0]; // set one hot value for the whole vector - console.log("before setting one hot", oneHot.toString(), "at index", index); oneHot.fill(1, index, index + 1); - console.log("after setting one hot", oneHot.toString(), "at index", index); } const oneHotTensor = new Tensor(oneHot, [indices.shape[0], this.shape[0]], indices.requires_grad); From b080f0b39139b83c039950630b35ca4b590cdbe0 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 16 Dec 2024 01:05:29 -0500 Subject: [PATCH 05/23] fix: broadcast for when adding value that has same shape as last dim of a --- src/ops/add.ts | 7 ++++++- src/tensor/tensor.ts | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/ops/add.ts b/src/ops/add.ts index a8de232..e902384 100644 --- a/src/ops/add.ts +++ b/src/ops/add.ts @@ -10,7 +10,12 @@ export class Add extends BinaryOp { if (b.shape.length === 1 && b.shape[0] === 1) { // Broadcast scalar b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } else { + } + else if (b.shape.length == 1 && b.shape[0] == a.shape[1]) { + b = Tensor.broadcast(b, a.shape[0], b.requires_grad); + console.log("broadcasted", b.data.toString()); + } + else { throw new Error( `Incompatible shapes for Add: ${a.shape} and ${b.shape}`, ); diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index 4f82a78..8b68365 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -62,6 +62,17 @@ export class Tensor { return new Tensor(data, shape, requires_grad); } + static broadcast(tensor: Tensor, size: number, requires_grad = false) { + const shape = [size, ...tensor.shape]; + const data = new Float32Array(shape.reduce((a, b) => a * b)); + + for (let i = 0; i < data.length; i++) { + data[i] = tensor.data[i % tensor.shape.reduce((a, b) => a * b)]; + } + + return new Tensor(data, shape, requires_grad); + } + async add(tensor: Tensor) { const addOp = await Add.create(); From b7a02a2f5cf7aa65a7a1a609703002643e242a25 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 16 Dec 2024 01:05:40 -0500 Subject: [PATCH 06/23] feat: linear layer --- src/index.ts | 1 + src/layers/linear.ts | 20 ++++++++ tests/integration/linear.test.ts | 84 ++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+) create mode 100644 src/layers/linear.ts create mode 100644 tests/integration/linear.test.ts diff --git a/src/index.ts b/src/index.ts index 004f0d5..33e77a3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,3 +10,4 @@ export * from "./ops/relu.js"; export * from "./autograd/function.js"; export * from "./layers/module.js"; export * from "./layers/embedding.js"; +export * from "./layers/linear.js"; diff --git a/src/layers/linear.ts b/src/layers/linear.ts new file mode 100644 index 0000000..196c4bd --- /dev/null +++ b/src/layers/linear.ts @@ -0,0 +1,20 @@ +import { Tensor } from "../tensor/tensor.js"; +import { Module } from "./module.js"; + +export class Linear extends Module { + weight: Tensor; + bias: Tensor; + + constructor(inputSize: number, outputSize: number) { + super("linear"); + this.weight = Tensor.randn([inputSize, outputSize], true); + this.bias = Tensor.randn([outputSize], true); + } + + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [input] = inputs; + const [output] = await input.matmul(this.weight); + const [outputBias] = await output.add(this.bias); + return [outputBias]; + 
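+    // Shape walkthrough (illustrative): input [batch, inputSize] @ weight
+    // [inputSize, outputSize] -> [batch, outputSize]; Add then broadcasts the
+    // [outputSize] bias across the batch rows (see the add.ts branch above).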
  }
+}
\ No newline at end of file
diff --git a/tests/integration/linear.test.ts b/tests/integration/linear.test.ts
new file mode 100644
index 0000000..3e8a0c6
--- /dev/null
+++ b/tests/integration/linear.test.ts
@@ -0,0 +1,84 @@
+import { test, expect } from "@playwright/test";
+
+test("Linear forward pass with known values", async ({ page }) => {
+  await page.goto("http://localhost:8080");
+
+  page.on("console", (msg) => {
+    console.log(msg);
+  });
+
+  // Inject test function
+  await page.evaluate(() => {
+    return new Promise((resolve) => {
+      // @ts-expect-error ignore error for tests
+      import("/dist/bundle.js").then((module) => {
+        const { Tensor, Linear } = module;
+
+        window.runLinearTest = async function () {
+          const inputSize = 3;
+          const outputSize = 2;
+
+          // Create linear layer
+          const linear = new Linear(inputSize, outputSize);
+
+          // Set known weights and biases for deterministic testing
+          linear.weight.data = new Float32Array([
+            0.1, 0.2, // First row
+            0.3, 0.4, // Second row
+            0.5, 0.6  // Third row
+          ]);
+
+          linear.bias.data = new Float32Array([0.1, 0.2]);
+
+          // Create input tensor
+          const input = new Tensor(
+            new Float32Array([1.0, 2.0, 3.0]), // Sample input
+            [1, 3], // Batch size 1, input size 3
+            false
+          );
+
+          // Forward pass
+          const [output] = await linear.forward(input);
+
+          return {
+            inputData: Array.from(input.data),
+            weights: Array.from(linear.weight.data),
+            biases: Array.from(linear.bias.data),
+            outputShape: output.shape,
+            outputData: Array.from(output.data)
+          };
+        };
+        resolve();
+      });
+    });
+  });
+
+  // Run the test function in the browser context
+  const result = await page.evaluate(() => window.runLinearTest());
+
+  // Validate shapes
+  expect(result.outputShape).toEqual([1, 2]); // Batch size x Output size
+
+  // Calculate expected output manually:
+  // output[0] = (1.0 * 0.1 + 2.0 * 0.3 + 3.0 * 0.5) + 0.1 = 2.3
+  // output[1] = (1.0 * 0.2 + 2.0 * 0.4 + 3.0 * 0.6) + 0.2 = 3.0
+  const expectedOutput = [2.3, 3.0];
+
+  // Check if outputs match expected values within a small tolerance
+  expect(result.outputData[0]).toBeCloseTo(expectedOutput[0], 5);
+  expect(result.outputData[1]).toBeCloseTo(expectedOutput[1], 5);
+
+  await page.close();
+});
+
+declare global {
+  interface Window {
+    runLinearTest: () => Promise<{
+      inputData: number[];
+      weights: number[];
+      biases: number[];
+      outputShape: number[];
+      outputData: number[];
+    }>;
+  }
+}
\ No newline at end of file
From d1781871e240116098dda1dbf4e312069197a736 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Mon, 16 Dec 2024 01:50:23 -0500
Subject: [PATCH 07/23] feat: div op (prob unnecessary)

---
 src/index.ts                  |  1 +
 src/ops/div.ts                | 42 ++++++++++++++++++++++
 src/shaders/div.ts            | 22 ++++++++++++
 src/tensor/tensor.ts          |  6 ++++
 tests/integration/div.test.ts | 68 +++++++++++++++++++++++++++++
 5 files changed, 139 insertions(+)
 create mode 100644 src/ops/div.ts
 create mode 100644 src/shaders/div.ts
 create mode 100644 tests/integration/div.test.ts

diff --git a/src/index.ts b/src/index.ts
index 33e77a3..83d8ea8 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -11,3 +11,4 @@ export * from "./autograd/function.js";
 export * from "./layers/module.js";
 export * from "./layers/embedding.js";
 export * from "./layers/linear.js";
+export * from "./ops/div.js";
diff --git a/src/ops/div.ts b/src/ops/div.ts
new file mode 100644
index 0000000..afb0173
--- /dev/null
+++ b/src/ops/div.ts
@@ -0,0 +1,42 @@
+import { BinaryOp } from "../autograd/function.js";
+import { Tensor } from "../tensor/tensor.js";
+import { divShader } from "../shaders/div.js";
+
+export class Div extends BinaryOp {
+  protected readonly shader: string = divShader;
+
+  validateShapes(a: Tensor, b: Tensor): Tensor {
+    if (!a.shape.every((value, index) => value === b.shape[index])) {
+      if (b.shape.length === 1 && b.shape[0] === 1) {
+        // Broadcast scalar
+        b = Tensor.full(a.shape, b.data[0], b.requires_grad);
+      } else {
+        throw new Error(
+          `Incompatible shapes for Div: ${a.shape} and ${b.shape}`,
+        );
+      }
+    }
+    return b;
+  }
+
+  async backward(grad_output: Tensor): Promise<Tensor[]> {
+    const [a, b] = this.inputs;
+    const [aRequiresGrad, bRequiresGrad] = this.requiresGrad;
+
+    const grad_a_result = await this.forward(grad_output, b);
+    const grad_a = aRequiresGrad ? grad_a_result[0] : null;
+    if (grad_a !== null) {
+      await a.setGrad(grad_a);
+    }
+
+    const grad_b_result = await this.forward(a, grad_output);
+    const grad_b = bRequiresGrad ? grad_b_result[0] : null;
+    if (grad_b !== null) {
+      await b.setGrad(grad_b);
+    }
+
+    return [grad_a, grad_b].filter(
+      (tensor): tensor is Tensor => tensor !== null,
+    );
+  }
+}
\ No newline at end of file
diff --git a/src/shaders/div.ts b/src/shaders/div.ts
new file mode 100644
index 0000000..bf62dcb
--- /dev/null
+++ b/src/shaders/div.ts
@@ -0,0 +1,22 @@
+export const divShader = `
+struct Dimensions {
+  M: u32,
+  N: u32,
+}
+
+@group(0) @binding(0) var<uniform> dimensions: Dimensions;
+@group(0) @binding(1) var<storage, read> a: array<f32>;
+@group(0) @binding(2) var<storage, read> scalar: array<f32>;
+@group(0) @binding(3) var<storage, read_write> result: array<f32>;
+
+@compute @workgroup_size(64)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+  let global_idx = global_id.x;
+  let row = global_idx / dimensions.N;
+  let col = global_idx % dimensions.N;
+
+  if (global_idx < dimensions.M * dimensions.N) {
+    result[row * dimensions.N + col] = a[row * dimensions.N + col] / scalar[row * dimensions.N + col];
+  }
+}
+`;
\ No newline at end of file
diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts
index 8b68365..ad5047d 100644
--- a/src/tensor/tensor.ts
+++ b/src/tensor/tensor.ts
@@ -6,6 +6,7 @@ import { Log2 } from "../ops/log2.js";
 import { ReLU } from "../ops/relu.js";
 import { Exp2 } from "../ops/exp2.js";
 import { Ln } from "../ops/ln.js";
+import { Div } from "../ops/div.js";

 import { AutogradFunction } from "../autograd/function.js";

@@ -85,6 +86,11 @@ export class Tensor {
     return mulOp.forward(this, tensor);
   }

+  async div(tensor: Tensor): Promise<[Tensor, number]> {
+    const divOp = await Div.create();
+    return divOp.forward(this, tensor);
+  }
+
   async matmul(tensor: Tensor) {
     const matmulOp = await MatMul.create();

diff --git a/tests/integration/div.test.ts b/tests/integration/div.test.ts
new file mode 100644
index 0000000..d4aa49a
--- /dev/null
+++ b/tests/integration/div.test.ts
@@ -0,0 +1,68 @@
+import { test, expect } from "@playwright/test";
+
+
+test("Elementwise scalar/broadcasted division forward and backward pass", async ({
+  page,
+}) => {
+  await page.goto("http://localhost:8080");
+
+  page.on("console", (msg) => {
+    console.log(msg);
+  });
+
+  // Inject your test function
+  await page.evaluate(() => {
+    return new Promise((resolve) => {
+      // @ts-expect-error ignore error for tests
+      import("/dist/bundle.js").then((module) => {
+        const { Tensor } = module;
+
+        // @ts-expect-error ignore error for tests
+        window.runDivTest = async function () {
+          const x = new Tensor(
+            new Float32Array([2.0, 4.0, 6.0, 8.0, 10.0, 12.0]),
+            [2, 3],
+            true,
+          );
+          const y = new Tensor(new Float32Array([2.0]), [1], false);
+
+          // Forward pass
+          const [z] = await
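+          // For z = x / y with y = 2: dz/dx = 1 / y = 0.5, which the grad_x
+          // assertions below check; y has requires_grad = false, so grad_y
+          // stays null and no gradient is produced for it.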
x.div(y); + + await z.backward(); + + return { + x: x, + y: y, + z: z, + grad_x: x.grad, + grad_y: y.grad, + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + // @ts-expect-error ignore error for tests + const result = await page.evaluate(() => window.runDivTest()); + + // Perform assertions + expect(result.x.shape).toEqual([2, 3]); + expect(result.y.shape).toEqual([1]); + expect(result.z.shape).toEqual([2, 3]); + expect(result.grad_x.shape).toEqual([2, 3]); + // check that grad_y is undefined + expect(result.grad_y).toBeNull(); + + const zData = new Float32Array(Object.values(result.z.data)); + const gradXData = new Float32Array(Object.values(result.grad_x.data)); + + expect(zData).toEqual(new Float32Array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + + // For division, gradient with respect to x is 1/y + expect(gradXData).toEqual(new Float32Array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])); + + await page.close(); +}); \ No newline at end of file From e1830a0eab4b755d2125d8e6979aa90db2a1215c Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 16 Dec 2024 01:57:46 -0500 Subject: [PATCH 08/23] feat: mean, std, sqrt by cursor. are they right? i'm too tired to check rn --- src/tensor/tensor.ts | 87 ++++++++++++++++++++++++++ tests/integration/tensor_stats.test.ts | 78 +++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 tests/integration/tensor_stats.test.ts diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index ad5047d..e09c513 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -86,6 +86,92 @@ export class Tensor { return mulOp.forward(this, tensor); } + async sub(tensor: Tensor) { + const negOne = Tensor.full(tensor.shape, -1, false); + const [negTensor] = await tensor.mul(negOne); + return this.add(negTensor); + } + + async mean(dims: number[]): Promise { + // Calculate new shape after reduction + const shape = this.shape.slice(); + const size = dims.reduce((acc, dim) => acc * shape[dim], 1); + + dims.sort((a, b) => b - a); // Sort in descending order to remove correctly + dims.forEach(dim => shape.splice(dim, 1)); + if (shape.length === 0) shape.push(1); + + const result = new Float32Array(shape.reduce((a, b) => a * b, 1)); + + // For 1D case + if (this.shape.length === 1 && dims.includes(0)) { + let sum = 0; + for (let i = 0; i < this.data.length; i++) { + sum += this.data[i]; + } + result[0] = sum / size; + return new Tensor(result, shape, this.requires_grad); + } + + // For higher dimensions (keeping existing logic for 2D) + const stride = this.shape[1]; + for (let i = 0; i < this.shape[0]; i++) { + let sum = 0; + for (let j = 0; j < stride; j++) { + sum += this.data[i * stride + j]; + } + result[i] = sum / size; + } + + return new Tensor(result, shape, this.requires_grad); + } + + async variance(dims: number[]): Promise { + const mean = await this.mean(dims); + const shape = this.shape.slice(); + const size = dims.reduce((acc, dim) => acc * shape[dim], 1); + + dims.sort((a, b) => b - a); + dims.forEach(dim => shape.splice(dim, 1)); + if (shape.length === 0) shape.push(1); + + const result = new Float32Array(shape.reduce((a, b) => a * b, 1)); + + // For 1D case + if (this.shape.length === 1 && dims.includes(0)) { + let sumSquaredDiff = 0; + const meanValue = mean.data[0]; + for (let i = 0; i < this.data.length; i++) { + const diff = this.data[i] - meanValue; + sumSquaredDiff += diff * diff; + } + result[0] = sumSquaredDiff / size; + return new Tensor(result, shape, this.requires_grad); + } + + // For higher 
dimensions + const stride = this.shape[1]; + for (let i = 0; i < this.shape[0]; i++) { + let sumSquaredDiff = 0; + const meanValue = mean.data[i]; + for (let j = 0; j < stride; j++) { + const diff = this.data[i * stride + j] - meanValue; + sumSquaredDiff += diff * diff; + } + result[i] = sumSquaredDiff / size; + } + + return new Tensor(result, shape, this.requires_grad); + } + + async sqrt(): Promise { + const result = new Float32Array(this.data.length); + for (let i = 0; i < this.data.length; i++) { + result[i] = Math.sqrt(this.data[i]); + } + return new Tensor(result, this.shape.slice(), this.requires_grad); + } + async div(tensor: Tensor): Promise<[Tensor, number]> { const divOp = await Div.create(); return divOp.forward(this, tensor); @@ -205,4 +291,5 @@ export class Tensor { return topo_order; } + } diff --git a/tests/integration/tensor_stats.test.ts b/tests/integration/tensor_stats.test.ts new file mode 100644 index 0000000..2fc4cfc --- /dev/null +++ b/tests/integration/tensor_stats.test.ts @@ -0,0 +1,78 @@ +import { Tensor } from "../../src/tensor/tensor.js"; +import { test, expect } from "@playwright/test"; + +test.describe("Tensor Statistics Operations", () => { + test.describe("mean", () => { + test("should calculate mean along specified dimensions", async () => { + const data = new Float32Array([1, 2, 3, 4, 5, 6]); + const tensor = new Tensor(data, [2, 3], false); + + const mean = await tensor.mean([1]); + expect(mean.shape).toEqual([2]); + expect(Array.from(mean.data)).toEqual([2, 5]); // [mean(1,2,3), mean(4,5,6)] + }); + + test("should handle single dimension tensors", async () => { + const data = new Float32Array([1, 2, 3, 4]); + const tensor = new Tensor(data, [4], false); + + const mean = await tensor.mean([0]); + expect(mean.shape).toEqual([1]); + expect(mean.data[0]).toBeCloseTo(2.5); // mean(1,2,3,4) + }); + }); + + test.describe("variance", () => { + test("should calculate variance along specified dimensions", async () => { + const data = new Float32Array([1, 2, 3, 4, 5, 6]); + const tensor = new Tensor(data, [2, 3], false); + + const variance = await tensor.variance([1]); + expect(variance.shape).toEqual([2]); + // Variance of [1,2,3] and [4,5,6] + expect(Array.from(variance.data).map(x => Number(x.toFixed(2)))).toEqual([0.67, 0.67]); + }); + + test("should handle single dimension tensors", async () => { + const data = new Float32Array([2, 4, 4, 6]); + const tensor = new Tensor(data, [4], false); + + const variance = await tensor.variance([0]); + expect(variance.shape).toEqual([1]); + expect(variance.data[0]).toBeCloseTo(2); // variance of [2,4,4,6] + }); + }); + + test.describe("sqrt", () => { + test("should calculate element-wise square root", async () => { + const data = new Float32Array([1, 4, 9, 16]); + const tensor = new Tensor(data, [4], false); + + const sqrt = await tensor.sqrt(); + expect(sqrt.shape).toEqual([4]); + expect(Array.from(sqrt.data)).toEqual([1, 2, 3, 4]); + }); + + test("should handle multi-dimensional tensors", async () => { + const data = new Float32Array([1, 4, 9, 16, 25, 36]); + const tensor = new Tensor(data, [2, 3], false); + + const sqrt = await tensor.sqrt(); + expect(sqrt.shape).toEqual([2, 3]); + expect(Array.from(sqrt.data)).toEqual([1, 2, 3, 4, 5, 6]); + }); + }); + + test.describe("combined operations", () => { + test("should correctly compute standard deviation using sqrt(variance)", async () => { + const data = new Float32Array([2, 4, 4, 6]); + const tensor = new Tensor(data, [4], false); + + const variance = await 
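+      // Worked numbers: [2, 4, 4, 6] has mean 4 and squared diffs [4, 0, 0, 4],
+      // so variance = 8 / 4 = 2 and the standard deviation is sqrt(2) ≈ 1.4142.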
tensor.variance([0]);
+      const stdDev = await variance.sqrt();
+
+      expect(stdDev.shape).toEqual([1]); // The shape should be [1] for a scalar result
+      expect(stdDev.data[0]).toBeCloseTo(Math.sqrt(2)); // The actual value check
+    });
+  });
+});
\ No newline at end of file
From 9cf4b635316b9aee5f89be00dfa60fc961d4153b Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Tue, 17 Dec 2024 01:03:34 -0500
Subject: [PATCH 09/23] feat: layernorm + small tests

---
 src/index.ts                   |  1 +
 src/layers/norm.ts             | 44 ++++++++++++++++++
 src/ops/add.ts                 | 28 +++++++++++-
 src/ops/div.ts                 | 22 ++++++++-
 src/ops/mul.ts                 | 34 ++++++++++----
 src/tensor/tensor.ts           |  6 +++
 tests/integration/norm.test.ts | 82 ++++++++++++++++++++++++++++++++++
 7 files changed, 205 insertions(+), 12 deletions(-)
 create mode 100644 src/layers/norm.ts
 create mode 100644 tests/integration/norm.test.ts

diff --git a/src/index.ts b/src/index.ts
index 83d8ea8..c0a965d 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -11,4 +11,5 @@ export * from "./autograd/function.js";
 export * from "./layers/module.js";
 export * from "./layers/embedding.js";
 export * from "./layers/linear.js";
+export * from "./layers/norm.js";
 export * from "./ops/div.js";
diff --git a/src/layers/norm.ts b/src/layers/norm.ts
new file mode 100644
index 0000000..8491d30
--- /dev/null
+++ b/src/layers/norm.ts
@@ -0,0 +1,44 @@
+import { Tensor } from "../tensor/tensor.js";
+import { Module } from "./module.js";
+
+export class LayerNorm extends Module {
+  normalized_shape: number[];
+  eps: Tensor;
+  gamma: Tensor;
+  beta: Tensor;
+  constructor(normalized_shape: number[], eps: number) {
+    super("layer_norm");
+    this.normalized_shape = normalized_shape;
+    this.eps = Tensor.full([1], eps); // Make eps a scalar tensor
+    // gamma and beta should match the feature dimension
+    this.gamma = Tensor.full([1, normalized_shape[0]], 1); // [1, 3] for broadcasting
+    this.beta = Tensor.full([1, normalized_shape[0]], 0); // [1, 3] for broadcasting
+  }
+
+  async forward(x: Tensor): Promise<[Tensor]> {
+    const reduction_dims = [1]; // Reduce over the feature dimension
+
+    // Calculate mean and reshape for broadcasting
+    const mean = await x.mean(reduction_dims);
+    mean.shape = [mean.shape[0], 1]; // [2, 1]
+
+    const variance = await x.variance(reduction_dims);
+    variance.shape = [variance.shape[0], 1]; // [2, 1]
+
+    console.log("x shape:", x.shape); // [2, 3]
+    console.log("mean shape:", mean.shape); // [2, 1]
+    console.log("variance shape:", variance.shape); // [2, 1]
+    console.log("gamma shape:", this.gamma.shape); // [1, 3]
+    console.log("beta shape:", this.beta.shape); // [1, 3]
+
+    const [numerator] = await x.sub(mean); // [2, 3]
+    const [denominator] = await variance.add(this.eps);
+    const sqrtDenom = await denominator.sqrt();
+    const [normalized] = await numerator.div(sqrtDenom);
+
+    const [gamma] = await normalized.mul(this.gamma); // [2, 3] * [1, 3] -> [2, 3]
+    const [beta] = await gamma.add(this.beta); // [2, 3] + [1, 3] -> [2, 3]
+    return [beta];
+  }
+}
\ No newline at end of file
diff --git a/src/ops/add.ts b/src/ops/add.ts
index e902384..72d3103 100644
--- a/src/ops/add.ts
+++ b/src/ops/add.ts
@@ -11,9 +11,33 @@ export class Add extends BinaryOp {
       // Broadcast scalar
       b = Tensor.full(a.shape, b.data[0], b.requires_grad);
     }
-    else if (b.shape.length == 1 && b.shape[0] == a.shape[1]) {
+    else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) {
+      // Broadcast [m] to [n, m]
       b = Tensor.broadcast(b, a.shape[0], b.requires_grad);
-
console.log("broadcasted", b.data.toString()); + } + else if (b.shape.length === 2 && b.shape[1] === 1) { + // Broadcast [n, 1] to [n, m] + const newShape = [b.shape[0], a.shape[1]]; + console.log("Broadcasting [n,1] to shape:", newShape); + const newData = new Float32Array(newShape[0] * newShape[1]); + for (let i = 0; i < b.shape[0]; i++) { + for (let j = 0; j < a.shape[1]; j++) { + newData[i * a.shape[1] + j] = b.data[i]; + } + } + b = new Tensor(newData, newShape, b.requires_grad); + } + else if (b.shape.length === 2 && b.shape[0] === 1 && b.shape[1] === a.shape[1]) { + // Broadcast [1, m] to [n, m] + const newShape = [a.shape[0], b.shape[1]]; + console.log("Broadcasting [1,m] to shape:", newShape); + const newData = new Float32Array(newShape[0] * newShape[1]); + for (let i = 0; i < a.shape[0]; i++) { + for (let j = 0; j < b.shape[1]; j++) { + newData[i * b.shape[1] + j] = b.data[j]; + } + } + b = new Tensor(newData, newShape, b.requires_grad); } else { throw new Error( diff --git a/src/ops/div.ts b/src/ops/div.ts index afb0173..6b4a704 100644 --- a/src/ops/div.ts +++ b/src/ops/div.ts @@ -10,14 +10,32 @@ export class Div extends BinaryOp { if (b.shape.length === 1 && b.shape[0] === 1) { // Broadcast scalar b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } else { + } + else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) { + // Broadcast [m] to [n, m] + b = Tensor.broadcast(b, a.shape[0], b.requires_grad); + } + else if (b.shape.length === 2 && b.shape[1] === 1) { + // Broadcast [n, 1] to [n, m] + const newShape = [b.shape[0], a.shape[1]]; + console.log("Broadcasting [n,1] to shape:", newShape); + const newData = new Float32Array(newShape[0] * newShape[1]); + // Repeat the values across the second dimension + for (let i = 0; i < b.shape[0]; i++) { + for (let j = 0; j < a.shape[1]; j++) { + newData[i * a.shape[1] + j] = b.data[i]; + } + } + b = new Tensor(newData, newShape, b.requires_grad); + } + else { throw new Error( `Incompatible shapes for Div: ${a.shape} and ${b.shape}`, ); } } return b; - } +} async backward(grad_output: Tensor): Promise { const [a, b] = this.inputs; diff --git a/src/ops/mul.ts b/src/ops/mul.ts index 8f8a3f9..d76a3d8 100644 --- a/src/ops/mul.ts +++ b/src/ops/mul.ts @@ -6,16 +6,34 @@ export class Mul extends BinaryOp { protected readonly shader: string = mulShader; validateShapes(a: Tensor, b: Tensor): Tensor { - if (!a.shape.every((value, index) => value === b.shape[index])) { - if (b.shape.length === 1 && b.shape[0] === 1) { - // Broadcast scalar - b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } else { - throw new Error( - `Incompatible shapes for Mul: ${a.shape} and ${b.shape}`, - ); + // Handle broadcasting for 2D tensors + if (a.shape.length === 2 && b.shape.length === 1) { + // Broadcasting b [n] to [m, n] + const newShape = [a.shape[0], b.shape[0]]; + b = Tensor.full(newShape, b.data[0], b.requires_grad); + } else if (a.shape.length === 2 && b.shape.length === 2) { + // Handle [m, 1] broadcasting to [m, n] + if (b.shape[1] === 1) { + const newShape = [b.shape[0], a.shape[1]]; + b = Tensor.full(newShape, b.data[0], b.requires_grad); } + } else if (b.shape.length === 1 && b.shape[0] === 1) { + // Scalar broadcasting + b = Tensor.full(a.shape, b.data[0], b.requires_grad); + } else if (!a.shape.every((value, index) => value === b.shape[index])) { + throw new Error( + `Incompatible shapes for Mul: ${a.shape} and ${b.shape}`, + ); } + + // Ensure 2D shapes for WebGPU operations + if (a.shape.length === 1) { + a.shape = [a.shape[0], 1]; + } 
+ if (b.shape.length === 1) { + b.shape = [b.shape[0], 1]; + } + return b; } diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index e09c513..163a643 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -87,6 +87,12 @@ export class Tensor { } async sub(tensor: Tensor) { + if (tensor.shape.length === 1 && this.shape.length === 2) { + // Broadcasting [n] to [m, n] + const newShape = [this.shape[0], tensor.shape[0]]; + tensor = Tensor.full(newShape, tensor.data[0], tensor.requires_grad); + } + const negOne = Tensor.full(tensor.shape, -1, false); const [negTensor] = await tensor.mul(negOne); return this.add(negTensor); diff --git a/tests/integration/norm.test.ts b/tests/integration/norm.test.ts new file mode 100644 index 0000000..540a404 --- /dev/null +++ b/tests/integration/norm.test.ts @@ -0,0 +1,82 @@ +import { test, expect } from "@playwright/test"; + +test("LayerNorm forward pass with known values", async ({ page }) => { + await page.goto("http://localhost:8080"); + + page.on("console", (msg) => { + console.log(msg); + }); + + // Inject test function + await page.evaluate(() => { + return new Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor, LayerNorm } = module; + + window.runLayerNormTest = async function () { + // Create a simple input tensor with known values + const input = new Tensor( + new Float32Array([1, 2, 3, 4, 5, 6]), // Sample values + [2, 3], // 2 sequences, 3 features each + false + ); + + // Create LayerNorm with normalized_shape [3] + const layerNorm = new LayerNorm([3], 1e-5); + + // Set known values for gamma and beta + layerNorm.gamma.data.set([1.0, 1.0, 1.0]); + layerNorm.beta.data.set([0.0, 0.0, 0.0]); + + // Forward pass + const [output] = await layerNorm.forward(input); + + return { + inputData: Array.from(input.data), + inputShape: input.shape, + outputShape: output.shape, + outputData: Array.from(output.data), + gamma: Array.from(layerNorm.gamma.data), + beta: Array.from(layerNorm.beta.data) + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + const result = await page.evaluate(() => window.runLayerNormTest()); + + // Validate shapes + expect(result.inputShape).toEqual([2, 3]); + expect(result.outputShape).toEqual([2, 3]); + + // For the input [1,2,3] and [4,5,6], with gamma=1 and beta=0, + // we can pre-calculate the expected normalized values + const expectedOutput = [ + -1.224744871391589, 0, 1.224744871391589, // First sequence normalized + -1.224744871391589, 0, 1.224744871391589 // Second sequence normalized + ]; + + // Check if output matches expected values (using approximate equality) + result.outputData.forEach((val, idx) => { + expect(val).toBeCloseTo(expectedOutput[idx], 4); + }); + + await page.close(); +}); + +declare global { + interface Window { + runLayerNormTest: () => Promise<{ + inputData: number[]; + inputShape: number[]; + outputShape: number[]; + outputData: number[]; + gamma: number[]; + beta: number[]; + }>; + } +} \ No newline at end of file From 1a3a4f224c2c05637da780f97a119450fe7f4f5f Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 17 Dec 2024 01:33:21 -0500 Subject: [PATCH 10/23] fix: mul tests --- src/ops/mul.ts | 37 +++++++++++++------------------------ src/tensor/tensor.ts | 1 - 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/src/ops/mul.ts b/src/ops/mul.ts index d76a3d8..7c1ec86 100644 --- a/src/ops/mul.ts +++ b/src/ops/mul.ts @@ -6,34 +6,23 @@ export class Mul 
extends BinaryOp { protected readonly shader: string = mulShader; validateShapes(a: Tensor, b: Tensor): Tensor { - // Handle broadcasting for 2D tensors - if (a.shape.length === 2 && b.shape.length === 1) { - // Broadcasting b [n] to [m, n] - const newShape = [a.shape[0], b.shape[0]]; - b = Tensor.full(newShape, b.data[0], b.requires_grad); - } else if (a.shape.length === 2 && b.shape.length === 2) { - // Handle [m, 1] broadcasting to [m, n] - if (b.shape[1] === 1) { - const newShape = [b.shape[0], a.shape[1]]; - b = Tensor.full(newShape, b.data[0], b.requires_grad); + if (!a.shape.every((value, index) => value === b.shape[index])) { + if (b.shape.length === 1 && b.shape[0] === 1) { + // Broadcast scalar + b = Tensor.full(a.shape, b.data[0], b.requires_grad); + } else if (b.shape[0] === 1 && b.shape[1] === a.shape[1]) { + // broadcast [1, n] to [m, n] + b = Tensor.full(a.shape, b.data[0], b.requires_grad); + } + else { + throw new Error( + `Incompatible shapes for Mul: ${a.shape} and ${b.shape}`, + ); } - } else if (b.shape.length === 1 && b.shape[0] === 1) { - // Scalar broadcasting - b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } else if (!a.shape.every((value, index) => value === b.shape[index])) { - throw new Error( - `Incompatible shapes for Mul: ${a.shape} and ${b.shape}`, - ); } - - // Ensure 2D shapes for WebGPU operations - if (a.shape.length === 1) { + if (a.shape.length === 1){ a.shape = [a.shape[0], 1]; } - if (b.shape.length === 1) { - b.shape = [b.shape[0], 1]; - } - return b; } diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index 163a643..a09c38f 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -229,7 +229,6 @@ export class Tensor { } const oneHotTensor = new Tensor(oneHot, [indices.shape[0], this.shape[0]], indices.requires_grad); - console.log("oneHotTensor", oneHotTensor.data.toString()); return oneHotTensor.matmul(this); } From 210a6ee03324e5dd7582f9afe3da8bbe8ca052d7 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 18 Dec 2024 01:21:39 -0500 Subject: [PATCH 11/23] feat: Tensor.concat --- src/tensor/tensor.ts | 54 ++++++++++++++++++++++++++++++++++++++ tests/unit/tensor.test.ts | 55 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index a09c38f..aed7cfc 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -297,4 +297,58 @@ export class Tensor { return topo_order; } + async concat(tensor: Tensor, axis: number): Promise { + // Validate axis + if (axis < 0 || axis >= this.shape.length) { + throw new Error(`Invalid axis ${axis}. 
Must be between 0 and ${this.shape.length - 1}`); + } + + // For axis 0 concatenation, all other dimensions must match exactly + if (axis === 0) { + // For 1D tensors, they must have the same shape + if (this.shape.length === 1 && this.shape[0] !== tensor.shape[0]) { + throw new Error(`Shape mismatch: tensors have different shapes at non-concatenating dimensions`); + } + } + + // For other axes, validate shapes - all dimensions except concat axis must match + for (let i = 0; i < this.shape.length; i++) { + if (i !== axis && this.shape[i] !== tensor.shape[i]) { + throw new Error(`Shape mismatch: tensors have different shapes at non-concatenating dimensions`); + } + } + + // Calculate new shape + const newShape = [...this.shape]; + newShape[axis] += tensor.shape[axis]; + + // Create new data array + const newData = new Float32Array(newShape.reduce((a, b) => a * b)); + + // Calculate strides for both tensors + const stride = this.shape[axis]; + const preAxisSize = this.shape.slice(0, axis).reduce((a, b) => a * b, 1); + const postAxisSize = this.shape.slice(axis + 1).reduce((a, b) => a * b, 1); + + // Copy data from both tensors + for (let i = 0; i < preAxisSize; i++) { + for (let j = 0; j < postAxisSize; j++) { + // Copy from first tensor + for (let k = 0; k < this.shape[axis]; k++) { + const srcIdx = i * stride * postAxisSize + k * postAxisSize + j; + const dstIdx = i * (stride + tensor.shape[axis]) * postAxisSize + k * postAxisSize + j; + newData[dstIdx] = this.data[srcIdx]; + } + // Copy from second tensor + for (let k = 0; k < tensor.shape[axis]; k++) { + const srcIdx = i * tensor.shape[axis] * postAxisSize + k * postAxisSize + j; + const dstIdx = i * (stride + tensor.shape[axis]) * postAxisSize + (k + stride) * postAxisSize + j; + newData[dstIdx] = tensor.data[srcIdx]; + } + } + } + + return new Tensor(newData, newShape, this.requires_grad || tensor.requires_grad); + } + } diff --git a/tests/unit/tensor.test.ts b/tests/unit/tensor.test.ts index d04f9a8..f7a2918 100644 --- a/tests/unit/tensor.test.ts +++ b/tests/unit/tensor.test.ts @@ -102,4 +102,59 @@ describe("Tensor", () => { ); }); }); + + describe("concat", () => { + it("should concatenate 1D tensors along axis 0", async () => { + const t1 = new Tensor(new Float32Array([1, 2, 3]), [3]); + const t2 = new Tensor(new Float32Array([4, 5, 6]), [3]); + + const result = await t1.concat(t2, 0); + + expect(result.shape).toEqual([6]); + expect(Array.from(result.data)).toEqual([1, 2, 3, 4, 5, 6]); + }); + + it("should concatenate 2D tensors along axis 0", async () => { + const t1 = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]); + const t2 = new Tensor(new Float32Array([5, 6, 7, 8]), [2, 2]); + + const result = await t1.concat(t2, 0); + + expect(result.shape).toEqual([4, 2]); + expect(Array.from(result.data)).toEqual([1, 2, 3, 4, 5, 6, 7, 8]); + }); + + it("should concatenate 2D tensors along axis 1", async () => { + const t1 = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]); + const t2 = new Tensor(new Float32Array([5, 6, 7, 8]), [2, 2]); + + const result = await t1.concat(t2, 1); + + expect(result.shape).toEqual([2, 4]); + expect(Array.from(result.data)).toEqual([1, 2, 5, 6, 3, 4, 7, 8]); + }); + + it("should throw error for invalid axis", async () => { + const t1 = new Tensor(new Float32Array([1, 2]), [2]); + const t2 = new Tensor(new Float32Array([3, 4]), [2]); + + await expect(t1.concat(t2, 1)).rejects.toThrow("Invalid axis"); + }); + + it("should throw error for shape mismatch", async () => { + const t1 = new Tensor(new 
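+    // For 1-D tensors this implementation requires equal lengths even on
+    // axis 0 (see the axis === 0 branch above), so concatenating a [2] with a
+    // [3] rejects with a shape mismatch rather than producing a [5].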
Float32Array([1, 2]), [2]); + const t2 = new Tensor(new Float32Array([3, 4, 5]), [3]); + + await expect(t1.concat(t2, 0)).rejects.toThrow("Shape mismatch"); + }); + + it("should preserve requires_grad", async () => { + const t1 = new Tensor(new Float32Array([1, 2]), [2], true); + const t2 = new Tensor(new Float32Array([3, 4]), [2], false); + + const result = await t1.concat(t2, 0); + + expect(result.requires_grad).toBe(true); + }); + }); }); From df97d2bac948b39a8b22bcc4c663e4cba0d21868 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 18 Dec 2024 09:24:23 -0500 Subject: [PATCH 12/23] style: prettier --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index d1d4e92..411a295 100644 --- a/package.json +++ b/package.json @@ -10,7 +10,7 @@ "start": "http-server -c-1", "build-clean": "npm run clean && npm run build && npm run build-bundle", "dev": "npm run build-clean && npm run start", - "prettier": "prettier --write .", + "prettier": "prettier --write tests/**/* src/**/*.ts", "build-bundle": "esbuild src/index.ts --bundle --outfile=dist/bundle.js --format=esm --target=es2020", "unit": "npm run build-clean && node --experimental-vm-modules node_modules/jest/bin/jest.js", "integration": "npm run build-clean && npx playwright test", From b1b402da874592eb8abb7e1b23ebd17da8dac155 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 18 Dec 2024 09:24:39 -0500 Subject: [PATCH 13/23] feat: slice --- src/tensor/tensor.ts | 295 +++++++++++++++++++++++++++++++++++--- tests/unit/tensor.test.ts | 126 ++++++++++++++-- 2 files changed, 388 insertions(+), 33 deletions(-) diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index aed7cfc..8cf2f67 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -10,6 +10,13 @@ import { Div } from "../ops/div.js"; import { AutogradFunction } from "../autograd/function.js"; +type SliceArg = + | number + | [number | null] + | [number | null, number | null] + | [number | null, number | null, number | null] + | ":"; + export class Tensor { data: Float32Array; shape: number[]; @@ -92,7 +99,7 @@ export class Tensor { const newShape = [this.shape[0], tensor.shape[0]]; tensor = Tensor.full(newShape, tensor.data[0], tensor.requires_grad); } - + const negOne = Tensor.full(tensor.shape, -1, false); const [negTensor] = await tensor.mul(negOne); return this.add(negTensor); @@ -102,13 +109,13 @@ export class Tensor { // Calculate new shape after reduction const shape = this.shape.slice(); const size = dims.reduce((acc, dim) => acc * shape[dim], 1); - + dims.sort((a, b) => b - a); // Sort in descending order to remove correctly - dims.forEach(dim => shape.splice(dim, 1)); + dims.forEach((dim) => shape.splice(dim, 1)); if (shape.length === 0) shape.push(1); - + const result = new Float32Array(shape.reduce((a, b) => a * b, 1)); - + // For 1D case if (this.shape.length === 1 && dims.includes(0)) { let sum = 0; @@ -118,7 +125,7 @@ export class Tensor { result[0] = sum / size; return new Tensor(result, shape, this.requires_grad); } - + // For higher dimensions (keeping existing logic for 2D) const stride = this.shape[1]; for (let i = 0; i < this.shape[0]; i++) { @@ -128,21 +135,21 @@ export class Tensor { } result[i] = sum / size; } - + return new Tensor(result, shape, this.requires_grad); } - + async variance(dims: number[]): Promise { const mean = await this.mean(dims); const shape = this.shape.slice(); const size = dims.reduce((acc, dim) => acc * shape[dim], 1); - + dims.sort((a, b) => b - a); - 
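+    // Splicing in descending order keeps the remaining dim indices valid
+    // while the reduced axes are removed one by one.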
dims.forEach(dim => shape.splice(dim, 1)); + dims.forEach((dim) => shape.splice(dim, 1)); if (shape.length === 0) shape.push(1); - + const result = new Float32Array(shape.reduce((a, b) => a * b, 1)); - + // For 1D case if (this.shape.length === 1 && dims.includes(0)) { let sumSquaredDiff = 0; @@ -154,7 +161,7 @@ export class Tensor { result[0] = sumSquaredDiff / size; return new Tensor(result, shape, this.requires_grad); } - + // For higher dimensions const stride = this.shape[1]; for (let i = 0; i < this.shape[0]; i++) { @@ -166,10 +173,10 @@ export class Tensor { } result[i] = sumSquaredDiff / size; } - + return new Tensor(result, shape, this.requires_grad); } - + async sqrt(): Promise { const result = new Float32Array(this.data.length); for (let i = 0; i < this.data.length; i++) { @@ -228,7 +235,11 @@ export class Tensor { oneHot.fill(1, index, index + 1); } - const oneHotTensor = new Tensor(oneHot, [indices.shape[0], this.shape[0]], indices.requires_grad); + const oneHotTensor = new Tensor( + oneHot, + [indices.shape[0], this.shape[0]], + indices.requires_grad, + ); return oneHotTensor.matmul(this); } @@ -300,21 +311,27 @@ export class Tensor { async concat(tensor: Tensor, axis: number): Promise { // Validate axis if (axis < 0 || axis >= this.shape.length) { - throw new Error(`Invalid axis ${axis}. Must be between 0 and ${this.shape.length - 1}`); + throw new Error( + `Invalid axis ${axis}. Must be between 0 and ${this.shape.length - 1}`, + ); } // For axis 0 concatenation, all other dimensions must match exactly if (axis === 0) { // For 1D tensors, they must have the same shape if (this.shape.length === 1 && this.shape[0] !== tensor.shape[0]) { - throw new Error(`Shape mismatch: tensors have different shapes at non-concatenating dimensions`); + throw new Error( + `Shape mismatch: tensors have different shapes at non-concatenating dimensions`, + ); } } // For other axes, validate shapes - all dimensions except concat axis must match for (let i = 0; i < this.shape.length; i++) { if (i !== axis && this.shape[i] !== tensor.shape[i]) { - throw new Error(`Shape mismatch: tensors have different shapes at non-concatenating dimensions`); + throw new Error( + `Shape mismatch: tensors have different shapes at non-concatenating dimensions`, + ); } } @@ -336,19 +353,251 @@ export class Tensor { // Copy from first tensor for (let k = 0; k < this.shape[axis]; k++) { const srcIdx = i * stride * postAxisSize + k * postAxisSize + j; - const dstIdx = i * (stride + tensor.shape[axis]) * postAxisSize + k * postAxisSize + j; + const dstIdx = + i * (stride + tensor.shape[axis]) * postAxisSize + + k * postAxisSize + + j; newData[dstIdx] = this.data[srcIdx]; } // Copy from second tensor for (let k = 0; k < tensor.shape[axis]; k++) { - const srcIdx = i * tensor.shape[axis] * postAxisSize + k * postAxisSize + j; - const dstIdx = i * (stride + tensor.shape[axis]) * postAxisSize + (k + stride) * postAxisSize + j; + const srcIdx = + i * tensor.shape[axis] * postAxisSize + k * postAxisSize + j; + const dstIdx = + i * (stride + tensor.shape[axis]) * postAxisSize + + (k + stride) * postAxisSize + + j; newData[dstIdx] = tensor.data[srcIdx]; } } } - return new Tensor(newData, newShape, this.requires_grad || tensor.requires_grad); + return new Tensor( + newData, + newShape, + this.requires_grad || tensor.requires_grad, + ); + } + async slice(...args: SliceArg[]): Promise { + if (args.length > this.shape.length) { + throw new Error( + `Too many indices for tensor of dimension ${this.shape.length}`, + ); + } + + // 
Convert all arguments to normalized slice specs + const slices = args.map((arg, dim) => + this.normalizeSlice(arg, this.shape[dim]), + ); + console.log("slices:", slices); + + // Calculate output shape and stride info + const { outputShape, isReducedDim } = this.calculateOutputShape( + slices, + this.shape, + ); + + // Handle empty result case + if (outputShape.length === 0 || outputShape.some((dim) => dim === 0)) { + return new Tensor(new Float32Array(0), outputShape, this.requires_grad); + } + + // Create output tensor + const outputSize = outputShape.reduce((a, b) => a * b, 1); + const result = new Float32Array(outputSize); + + // For each output position, calculate corresponding input position + await this.populateSlicedData( + result, + outputSize, + outputShape, + slices, + isReducedDim, + ); + + return new Tensor(result, outputShape, this.requires_grad); } + private async populateSlicedData( + result: Float32Array, + outputSize: number, + outputShape: number[], + slices: [number, number, number][], + isReducedDim: boolean[], + ): Promise { + // Process in chunks to avoid blocking the main thread + const CHUNK_SIZE = 1000; + + for (let i = 0; i < outputSize; i += CHUNK_SIZE) { + const end = Math.min(i + CHUNK_SIZE, outputSize); + + for (let j = i; j < end; j++) { + const outputCoords = this.indexToCoords(j, outputShape); + const inputCoords = this.mapToInputCoords( + outputCoords, + slices, + isReducedDim, + ); + const inputIndex = this.coordsToIndex(inputCoords, this.shape); + result[j] = this.data[inputIndex]; + } + + // Yield to event loop periodically + if (end < outputSize) { + await new Promise((resolve) => setTimeout(resolve, 0)); + } + } + } + + private calculateOutputShape( + slices: [number, number, number][], + inputShape: number[], + ) { + // Pad slices to match input dimensions + const fullSlices = [...slices]; + while (fullSlices.length < inputShape.length) { + fullSlices.push([0, inputShape[fullSlices.length], 1]); + } + + // Track which dimensions are being reduced (single number index) + const isReducedDim = fullSlices.map( + ([start, end, step]) => end - start === 1 && step === 1, + ); + + // Calculate output shape, handling both positive and negative steps + const outputShape = fullSlices + .map(([start, end, step], i) => { + if (isReducedDim[i]) return 0; + + if (step > 0) { + return Math.max(0, Math.ceil((end - start) / step)); + } else { + // For negative steps, we need to handle the range differently + // When going backwards, we need to include the start position + const numElements = Math.max( + 0, + Math.ceil((start - end + 1) / Math.abs(step)), + ); + return numElements; + } + }) + .filter((size) => size !== 0); + + return { outputShape, isReducedDim }; + } + + private normalizeSlice( + arg: SliceArg, + dimSize: number, + ): [number, number, number] { + // Handle single number index + if (typeof arg === "number") { + const idx = arg < 0 ? dimSize + arg : arg; + if (idx < 0 || idx >= dimSize) { + throw new Error( + `Index ${arg} is out of bounds for dimension ${dimSize}`, + ); + } + return [idx, idx + 1, 1]; + } + + // Handle full slice + if (arg === ":") { + return [0, dimSize, 1]; + } + + // Handle array spec [start, end, step] + let [start, end, step] = arg as [ + number | null, + number | null, + number | null, + ]; + step = step ?? 1; + + if (step === 0) { + throw new Error("Slice step cannot be zero"); + } + + // Handle negative step + if (step < 0) { + // Default start is end of dimension for negative step + start = start ?? 
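+      // Negative-step defaults mirror Python slicing: start falls back to the
+      // last index and end to the sentinel -1 ("before the beginning").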
dimSize - 1; + // Default end is before beginning of dimension + end = end ?? -1; + + // Convert negative indices to positive + start = start < 0 ? dimSize + start : start; + // For negative step, don't convert negative end index if it's the default -1 + end = end < 0 && end !== -1 ? dimSize + end : end; + + // Clamp to valid range for negative step + start = Math.min(dimSize - 1, Math.max(0, start)); + end = Math.min(dimSize - 1, Math.max(0, end)); + } else { + // Default start is beginning of dimension for positive step + start = start ?? 0; + // Default end is end of dimension + end = end ?? dimSize; + + // Convert negative indices to positive + start = start < 0 ? dimSize + start : start; + end = end < 0 ? dimSize + end : end; + + // Clamp to valid range + start = Math.min(dimSize - 1, Math.max(0, start)); + end = Math.min(dimSize, Math.max(0, end)); + } + + return [start, end, step]; + } + + private indexToCoords(index: number, shape: number[]): number[] { + const coords = []; + let remaining = index; + let stride = shape.reduce((a, b) => a * b, 1); + + for (const dimSize of shape) { + stride = stride / dimSize; + const coord = Math.floor(remaining / stride); + remaining = remaining % stride; + coords.push(coord); + } + + return coords; + } + + private mapToInputCoords( + outputCoords: number[], + slices: [number, number, number][], + isReducedDim: boolean[], + ): number[] { + const inputCoords: number[] = []; + let outputIdx = 0; + + for (let i = 0; i < isReducedDim.length; i++) { + if (isReducedDim[i]) { + // For reduced dimensions, use the start index + inputCoords.push(slices[i][0]); + } else { + // For slice dimensions, calculate the actual position + const [start, , step] = slices[i]; + inputCoords.push(start + outputCoords[outputIdx] * step); + outputIdx++; + } + } + + return inputCoords; + } + + private coordsToIndex(coords: number[], shape: number[]): number { + let index = 0; + let stride = 1; + + for (let i = coords.length - 1; i >= 0; i--) { + index += coords[i] * stride; + stride *= shape[i]; + } + + return index; + } } diff --git a/tests/unit/tensor.test.ts b/tests/unit/tensor.test.ts index f7a2918..d6e19be 100644 --- a/tests/unit/tensor.test.ts +++ b/tests/unit/tensor.test.ts @@ -107,9 +107,9 @@ describe("Tensor", () => { it("should concatenate 1D tensors along axis 0", async () => { const t1 = new Tensor(new Float32Array([1, 2, 3]), [3]); const t2 = new Tensor(new Float32Array([4, 5, 6]), [3]); - + const result = await t1.concat(t2, 0); - + expect(result.shape).toEqual([6]); expect(Array.from(result.data)).toEqual([1, 2, 3, 4, 5, 6]); }); @@ -117,9 +117,9 @@ describe("Tensor", () => { it("should concatenate 2D tensors along axis 0", async () => { const t1 = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]); const t2 = new Tensor(new Float32Array([5, 6, 7, 8]), [2, 2]); - + const result = await t1.concat(t2, 0); - + expect(result.shape).toEqual([4, 2]); expect(Array.from(result.data)).toEqual([1, 2, 3, 4, 5, 6, 7, 8]); }); @@ -127,9 +127,9 @@ describe("Tensor", () => { it("should concatenate 2D tensors along axis 1", async () => { const t1 = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]); const t2 = new Tensor(new Float32Array([5, 6, 7, 8]), [2, 2]); - + const result = await t1.concat(t2, 1); - + expect(result.shape).toEqual([2, 4]); expect(Array.from(result.data)).toEqual([1, 2, 5, 6, 3, 4, 7, 8]); }); @@ -137,24 +137,130 @@ describe("Tensor", () => { it("should throw error for invalid axis", async () => { const t1 = new Tensor(new Float32Array([1, 2]), [2]); 
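+      // Axis must be strictly less than the tensor rank, so axis 1 on these
+      // 1-D tensors should reject with "Invalid axis".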
const t2 = new Tensor(new Float32Array([3, 4]), [2]); - + await expect(t1.concat(t2, 1)).rejects.toThrow("Invalid axis"); }); it("should throw error for shape mismatch", async () => { const t1 = new Tensor(new Float32Array([1, 2]), [2]); const t2 = new Tensor(new Float32Array([3, 4, 5]), [3]); - + await expect(t1.concat(t2, 0)).rejects.toThrow("Shape mismatch"); }); it("should preserve requires_grad", async () => { const t1 = new Tensor(new Float32Array([1, 2]), [2], true); const t2 = new Tensor(new Float32Array([3, 4]), [2], false); - + const result = await t1.concat(t2, 0); - + expect(result.requires_grad).toBe(true); }); }); + describe("slice", () => { + it("should slice a 1D tensor with basic indexing", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5]), [5]); + const result = await tensor.slice([1, 4]); + + expect(result.shape).toEqual([3]); + expect(Array.from(result.data)).toEqual([2, 3, 4]); + }); + + it("should handle full slice with ':'", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [4]); + const result = await tensor.slice(":"); + + expect(result.shape).toEqual([4]); + expect(Array.from(result.data)).toEqual([1, 2, 3, 4]); + }); + + it("should slice with step size", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5, 6]), [6]); + const result = await tensor.slice([null, null, 2]); + + expect(result.shape).toEqual([3]); + expect(Array.from(result.data)).toEqual([1, 3, 5]); + }); + + it("should handle negative indices", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5]), [5]); + const result = await tensor.slice([-3, -1]); + + expect(result.shape).toEqual([2]); + expect(Array.from(result.data)).toEqual([3, 4]); + }); + + it("should slice a 2D tensor along both dimensions", async () => { + const tensor = new Tensor( + new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9]), + [3, 3], + ); + const result = await tensor.slice([0, 2], [1, 3]); + + expect(result.shape).toEqual([2, 2]); + expect(Array.from(result.data)).toEqual([2, 3, 5, 6]); + }); + + it("should handle reverse slicing with negative step", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5]), [5]); + const result = await tensor.slice([null, null, -1]); + + expect(result.shape).toEqual([5]); + expect(Array.from(result.data)).toEqual([5, 4, 3, 2, 1]); + }); + + it("should preserve requires_grad", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [4], true); + const result = await tensor.slice([1, 3]); + + expect(result.requires_grad).toBe(true); + }); + + it("should handle mixed slicing with numbers and slices", async () => { + const tensor = new Tensor( + new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9]), + [3, 3], + ); + const result = await tensor.slice(1, ":"); + + expect(result.shape).toEqual([3]); + expect(Array.from(result.data)).toEqual([4, 5, 6]); + }); + + it("should slice the first half of a dimension", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5, 6]), [6]); + const result = await tensor.slice([0, 3]); + + expect(result.shape).toEqual([3]); + expect(Array.from(result.data)).toEqual([1, 2, 3]); + }); + + it("should handle overlapping step slices", async () => { + const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5]), [5]); + const result = await tensor.slice([0, 4, 2]); + + expect(result.shape).toEqual([2]); + expect(Array.from(result.data)).toEqual([1, 3]); + }); + + it("should throw error for invalid dimensions", async () => { + const tensor 
= new Tensor(new Float32Array([1, 2, 3]), [3]); + await expect(tensor.slice(":", ":")).rejects.toThrow( + "Too many indices for tensor", + ); + }); + + it("should handle 3D tensor slicing", async () => { + const tensor = new Tensor( + new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), + [3, 2, 2], + ); + const result = await tensor.slice(":", [0, 2], ":"); + + expect(result.shape).toEqual([3, 2, 2]); + expect(Array.from(result.data)).toEqual([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + ]); + }); + }); }); From 8fa9e8e7019d5d96c08de05ca0bdb80d2af710fd Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 18 Dec 2024 09:24:54 -0500 Subject: [PATCH 14/23] style: prettier format --- src/layers/linear.ts | 28 +++++----- src/layers/module.ts | 21 ++++---- src/layers/norm.ts | 74 +++++++++++++------------- src/ops/add.ts | 16 +++--- src/ops/div.ts | 13 ++--- src/ops/mul.ts | 5 +- src/shaders/add.ts | 2 +- src/shaders/div.ts | 2 +- src/shaders/exp.ts | 2 +- src/shaders/exp2.ts | 2 +- src/shaders/ln.ts | 2 +- src/shaders/log2.ts | 2 +- src/shaders/matmul.ts | 2 +- src/shaders/mul.ts | 2 +- src/shaders/relu.ts | 2 +- tests/integration/div.test.ts | 3 +- tests/integration/embedding.test.ts | 18 +++---- tests/integration/gather.test.ts | 50 +++++++++-------- tests/integration/linear.test.ts | 27 +++++----- tests/integration/norm.test.ts | 20 ++++--- tests/integration/tensor_stats.test.ts | 22 ++++---- 21 files changed, 163 insertions(+), 152 deletions(-) diff --git a/src/layers/linear.ts b/src/layers/linear.ts index 196c4bd..dcd0157 100644 --- a/src/layers/linear.ts +++ b/src/layers/linear.ts @@ -2,19 +2,19 @@ import { Tensor } from "../tensor/tensor.js"; import { Module } from "./module.js"; export class Linear extends Module { - weight: Tensor; - bias: Tensor; + weight: Tensor; + bias: Tensor; - constructor(inputSize: number, outputSize: number) { - super("linear"); - this.weight = Tensor.randn([inputSize, outputSize], true); - this.bias = Tensor.randn([outputSize], true); - } + constructor(inputSize: number, outputSize: number) { + super("linear"); + this.weight = Tensor.randn([inputSize, outputSize], true); + this.bias = Tensor.randn([outputSize], true); + } - async forward(...inputs: [Tensor]): Promise<[Tensor]> { - const [input] = inputs; - const [output] = await input.matmul(this.weight); - const [outputBias] = await output.add(this.bias); - return [outputBias]; - } -} \ No newline at end of file + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [input] = inputs; + const [output] = await input.matmul(this.weight); + const [outputBias] = await output.add(this.bias); + return [outputBias]; + } +} diff --git a/src/layers/module.ts b/src/layers/module.ts index 5312934..a7306d8 100644 --- a/src/layers/module.ts +++ b/src/layers/module.ts @@ -1,20 +1,19 @@ import { Tensor } from "../tensor/tensor"; export abstract class Module { - protected name: string; - constructor(name: string) { - if (name === null || name === undefined) { - throw Error("Name cannot be null or undefined"); - } - - this.name = name; + protected name: string; + constructor(name: string) { + if (name === null || name === undefined) { + throw Error("Name cannot be null or undefined"); } - /** + this.name = name; + } + + /** * Abstract method that must be implemented by all layer subclasses * Defines the forward pass computation of the layer * @param inputs - Input tensor(s) to the layer * @returns Output tensor(s) from the layer */ - abstract forward(...inputs: [Tensor]): Promise<[Tensor]> - -} \ No newline 
at end of file + abstract forward(...inputs: [Tensor]): Promise<[Tensor]>; +} diff --git a/src/layers/norm.ts b/src/layers/norm.ts index 8491d30..ec1be60 100644 --- a/src/layers/norm.ts +++ b/src/layers/norm.ts @@ -2,43 +2,43 @@ import { Tensor } from "../tensor/tensor.js"; import { Module } from "./module.js"; export class LayerNorm extends Module { - normalized_shape: number[]; - eps: Tensor; - gamma: Tensor; - beta: Tensor; - constructor(normalized_shape: number[], eps: number) { - super("layer_norm"); - this.normalized_shape = normalized_shape; - // eps should be [2, 1] for broadcasting - this.eps = Tensor.full([1], eps); // Make eps a scalar tensor - // gamma and beta should match the feature dimension - this.gamma = Tensor.full([1, normalized_shape[0]], 1); // [1, 3] for broadcasting - this.beta = Tensor.full([1, normalized_shape[0]], 0); // [1, 3] for broadcasting - } + normalized_shape: number[]; + eps: Tensor; + gamma: Tensor; + beta: Tensor; + constructor(normalized_shape: number[], eps: number) { + super("layer_norm"); + this.normalized_shape = normalized_shape; + // eps should be [2, 1] for broadcasting + this.eps = Tensor.full([1], eps); // Make eps a scalar tensor + // gamma and beta should match the feature dimension + this.gamma = Tensor.full([1, normalized_shape[0]], 1); // [1, 3] for broadcasting + this.beta = Tensor.full([1, normalized_shape[0]], 0); // [1, 3] for broadcasting + } - async forward(x: Tensor): Promise<[Tensor]> { - const reduction_dims = [1]; // Reduce over the feature dimension + async forward(x: Tensor): Promise<[Tensor]> { + const reduction_dims = [1]; // Reduce over the feature dimension - // Calculate mean and reshape for broadcasting - const mean = await x.mean(reduction_dims); - mean.shape = [mean.shape[0], 1]; // [2, 1] - - const variance = await x.variance(reduction_dims); - variance.shape = [variance.shape[0], 1]; // [2, 1] - - console.log("x shape:", x.shape); // [2, 3] - console.log("mean shape:", mean.shape); // [2, 1] - console.log("variance shape:", variance.shape); // [2, 1] - console.log("gamma shape:", this.gamma.shape); // [1, 3] - console.log("beta shape:", this.beta.shape); // [1, 3] + // Calculate mean and reshape for broadcasting + const mean = await x.mean(reduction_dims); + mean.shape = [mean.shape[0], 1]; // [2, 1] - const [numerator] = await x.sub(mean); // [2, 3] - const [denominator] = await variance.add(this.eps); - const sqrtDenom = await denominator.sqrt(); - const [normalized] = await numerator.div(sqrtDenom); - - const [gamma] = await normalized.mul(this.gamma); // [2, 3] * [1, 3] -> [2, 3] - const [beta] = await gamma.add(this.beta); // [2, 3] + [1, 3] -> [2, 3] - return [beta]; - } -} \ No newline at end of file + const variance = await x.variance(reduction_dims); + variance.shape = [variance.shape[0], 1]; // [2, 1] + + console.log("x shape:", x.shape); // [2, 3] + console.log("mean shape:", mean.shape); // [2, 1] + console.log("variance shape:", variance.shape); // [2, 1] + console.log("gamma shape:", this.gamma.shape); // [1, 3] + console.log("beta shape:", this.beta.shape); // [1, 3] + + const [numerator] = await x.sub(mean); // [2, 3] + const [denominator] = await variance.add(this.eps); + const sqrtDenom = await denominator.sqrt(); + const [normalized] = await numerator.div(sqrtDenom); + + const [gamma] = await normalized.mul(this.gamma); // [2, 3] * [1, 3] -> [2, 3] + const [beta] = await gamma.add(this.beta); // [2, 3] + [1, 3] -> [2, 3] + return [beta]; + } +} diff --git a/src/ops/add.ts b/src/ops/add.ts index 
72d3103..382601e 100644 --- a/src/ops/add.ts +++ b/src/ops/add.ts @@ -10,12 +10,10 @@ export class Add extends BinaryOp { if (b.shape.length === 1 && b.shape[0] === 1) { // Broadcast scalar b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } - else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) { + } else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) { // Broadcast [m] to [n, m] b = Tensor.broadcast(b, a.shape[0], b.requires_grad); - } - else if (b.shape.length === 2 && b.shape[1] === 1) { + } else if (b.shape.length === 2 && b.shape[1] === 1) { // Broadcast [n, 1] to [n, m] const newShape = [b.shape[0], a.shape[1]]; console.log("Broadcasting [n,1] to shape:", newShape); @@ -26,8 +24,11 @@ export class Add extends BinaryOp { } } b = new Tensor(newData, newShape, b.requires_grad); - } - else if (b.shape.length === 2 && b.shape[0] === 1 && b.shape[1] === a.shape[1]) { + } else if ( + b.shape.length === 2 && + b.shape[0] === 1 && + b.shape[1] === a.shape[1] + ) { // Broadcast [1, m] to [n, m] const newShape = [a.shape[0], b.shape[1]]; console.log("Broadcasting [1,m] to shape:", newShape); @@ -38,8 +39,7 @@ export class Add extends BinaryOp { } } b = new Tensor(newData, newShape, b.requires_grad); - } - else { + } else { throw new Error( `Incompatible shapes for Add: ${a.shape} and ${b.shape}`, ); diff --git a/src/ops/div.ts b/src/ops/div.ts index 6b4a704..3f03ab2 100644 --- a/src/ops/div.ts +++ b/src/ops/div.ts @@ -10,12 +10,10 @@ export class Div extends BinaryOp { if (b.shape.length === 1 && b.shape[0] === 1) { // Broadcast scalar b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } - else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) { + } else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) { // Broadcast [m] to [n, m] b = Tensor.broadcast(b, a.shape[0], b.requires_grad); - } - else if (b.shape.length === 2 && b.shape[1] === 1) { + } else if (b.shape.length === 2 && b.shape[1] === 1) { // Broadcast [n, 1] to [n, m] const newShape = [b.shape[0], a.shape[1]]; console.log("Broadcasting [n,1] to shape:", newShape); @@ -27,15 +25,14 @@ export class Div extends BinaryOp { } } b = new Tensor(newData, newShape, b.requires_grad); - } - else { + } else { throw new Error( `Incompatible shapes for Div: ${a.shape} and ${b.shape}`, ); } } return b; -} + } async backward(grad_output: Tensor): Promise { const [a, b] = this.inputs; @@ -57,4 +54,4 @@ export class Div extends BinaryOp { (tensor): tensor is Tensor => tensor !== null, ); } -} \ No newline at end of file +} diff --git a/src/ops/mul.ts b/src/ops/mul.ts index 7c1ec86..45291f6 100644 --- a/src/ops/mul.ts +++ b/src/ops/mul.ts @@ -13,14 +13,13 @@ export class Mul extends BinaryOp { } else if (b.shape[0] === 1 && b.shape[1] === a.shape[1]) { // broadcast [1, n] to [m, n] b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } - else { + } else { throw new Error( `Incompatible shapes for Mul: ${a.shape} and ${b.shape}`, ); } } - if (a.shape.length === 1){ + if (a.shape.length === 1) { a.shape = [a.shape[0], 1]; } return b; diff --git a/src/shaders/add.ts b/src/shaders/add.ts index 599d62f..8e8f565 100644 --- a/src/shaders/add.ts +++ b/src/shaders/add.ts @@ -19,4 +19,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = a[row * dimensions.N + col] + scalar[row * dimensions.N + col]; } } -`; \ No newline at end of file +`; diff --git a/src/shaders/div.ts b/src/shaders/div.ts index bf62dcb..c1b72da 100644 --- a/src/shaders/div.ts +++ b/src/shaders/div.ts @@ -19,4 
+19,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = a[row * dimensions.N + col] / scalar[row * dimensions.N + col]; } } -`; \ No newline at end of file +`; diff --git a/src/shaders/exp.ts b/src/shaders/exp.ts index 420e566..b9b62fc 100644 --- a/src/shaders/exp.ts +++ b/src/shaders/exp.ts @@ -18,4 +18,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = exp(a[row * dimensions.N + col]); } } -` \ No newline at end of file +`; diff --git a/src/shaders/exp2.ts b/src/shaders/exp2.ts index f1b74c9..cf5c2d1 100644 --- a/src/shaders/exp2.ts +++ b/src/shaders/exp2.ts @@ -18,4 +18,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = exp2(a[row * dimensions.N + col]); } } -`; \ No newline at end of file +`; diff --git a/src/shaders/ln.ts b/src/shaders/ln.ts index 21822fe..efb75e3 100644 --- a/src/shaders/ln.ts +++ b/src/shaders/ln.ts @@ -18,4 +18,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = log(a[row * dimensions.N + col]); } } -`; \ No newline at end of file +`; diff --git a/src/shaders/log2.ts b/src/shaders/log2.ts index 2114af9..e85097e 100644 --- a/src/shaders/log2.ts +++ b/src/shaders/log2.ts @@ -18,4 +18,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = log2(a[row * dimensions.N + col]); } } -`; \ No newline at end of file +`; diff --git a/src/shaders/matmul.ts b/src/shaders/matmul.ts index 926804a..97515c5 100644 --- a/src/shaders/matmul.ts +++ b/src/shaders/matmul.ts @@ -326,4 +326,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } } } -`; \ No newline at end of file +`; diff --git a/src/shaders/mul.ts b/src/shaders/mul.ts index 642bd51..6bb0053 100644 --- a/src/shaders/mul.ts +++ b/src/shaders/mul.ts @@ -19,4 +19,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = a[row * dimensions.N + col] * scalar[row * dimensions.N + col]; } } -`; \ No newline at end of file +`; diff --git a/src/shaders/relu.ts b/src/shaders/relu.ts index 93a5443..0b6e600 100644 --- a/src/shaders/relu.ts +++ b/src/shaders/relu.ts @@ -18,4 +18,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { result[row * dimensions.N + col] = max(a[row * dimensions.N + col], 0); } } -`; \ No newline at end of file +`; diff --git a/tests/integration/div.test.ts b/tests/integration/div.test.ts index d4aa49a..277dc17 100644 --- a/tests/integration/div.test.ts +++ b/tests/integration/div.test.ts @@ -1,6 +1,5 @@ import { test, expect } from "@playwright/test"; - test("Elementwise scalar/broadcasted division forward and backward pass", async ({ page, }) => { @@ -65,4 +64,4 @@ test("Elementwise scalar/broadcasted division forward and backward pass", async expect(gradXData).toEqual(new Float32Array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])); await page.close(); -}); \ No newline at end of file +}); diff --git a/tests/integration/embedding.test.ts b/tests/integration/embedding.test.ts index d1be2ce..8eae819 100644 --- a/tests/integration/embedding.test.ts +++ b/tests/integration/embedding.test.ts @@ -13,11 +13,11 @@ test("Embedding forward pass with known values", async ({ page }) => { // @ts-expect-error ignore error for tests import("/dist/bundle.js").then((module) => { const { Tensor, Embedding } = module; - + window.runEmbeddingTest = async function () { const vocabSize = 128; const embeddingDim = 2; // Using small dim for easy verification - + 
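          // Note: Embedding.forward delegates the lookup to Tensor.gather,
          // so for indices [1, 5, 10] the returned rows are embedding-table
          // rows 1, 5 and 10, i.e. the flat data offsets [2,3], [10,11] and
          // [20,21] that the assertions below check against.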
// Create embedding layer const embedding = new Embedding(vocabSize, embeddingDim); console.log(embedding); @@ -26,7 +26,7 @@ test("Embedding forward pass with known values", async ({ page }) => { const inputIndices = new Tensor( new Float32Array([1, 5, 10]), // Sample indices [3], // Sequence length of 3 - false + false, ); // Forward pass @@ -36,7 +36,7 @@ test("Embedding forward pass with known values", async ({ page }) => { inputIndices: Array.from(inputIndices.data), embedding: embedding.embedding, outputShape: embeddings.shape, - outputData: Array.from(embeddings.data) + outputData: Array.from(embeddings.data), }; }; resolve(); @@ -48,17 +48,17 @@ test("Embedding forward pass with known values", async ({ page }) => { const result = await page.evaluate(() => window.runEmbeddingTest()); // Validate shapes - expect(result.outputShape).toEqual([3, 2]); // Sequence length x Embedding dim + expect(result.outputShape).toEqual([3, 2]); // Sequence length x Embedding dim const expectedOutput = [ result.embedding.data[2], result.embedding.data[3], result.embedding.data[10], - result.embedding.data[11] , + result.embedding.data[11], result.embedding.data[20], - result.embedding.data[21] + result.embedding.data[21], ]; - + expect(result.outputData).toEqual(expectedOutput); await page.close(); @@ -73,4 +73,4 @@ declare global { outputData: number[]; }>; } -} \ No newline at end of file +} diff --git a/tests/integration/gather.test.ts b/tests/integration/gather.test.ts index f27ad5a..58f9470 100644 --- a/tests/integration/gather.test.ts +++ b/tests/integration/gather.test.ts @@ -18,20 +18,19 @@ test("Gather forward and backward pass", async ({ page }) => { // Create a simple embedding matrix with 3 embeddings of dimension 2 const embeddings = new Tensor( new Float32Array([ - 1.0, 2.0, // embedding 0 - 3.0, 4.0, // embedding 1 - 5.0, 6.0 // embedding 2 + 1.0, + 2.0, // embedding 0 + 3.0, + 4.0, // embedding 1 + 5.0, + 6.0, // embedding 2 ]), [3, 2], - true + true, ); // Look up embeddings at indices 1, 0 (second embedding, then first) - const indices = new Tensor( - new Float32Array([1, 0]), - [2, 1], - false - ); + const indices = new Tensor(new Float32Array([1, 0]), [2, 1], false); // Forward pass - gather embeddings const [output] = await embeddings.gather(indices); @@ -55,23 +54,32 @@ test("Gather forward and backward pass", async ({ page }) => { // @ts-expect-error ignore error for tests const result = await page.evaluate(() => window.runGatherTest()); - expect(result.output.shape).toEqual([2, 2]); // 2 selected embeddings of dimension 2 - expect(result.grad_embeddings.shape).toEqual([3, 2]); // Same shape as input embeddings + expect(result.output.shape).toEqual([2, 2]); // 2 selected embeddings of dimension 2 + expect(result.grad_embeddings.shape).toEqual([3, 2]); // Same shape as input embeddings // Forward pass assertions - should get embeddings at indices 1 and 0 const outputData = new Float32Array(Object.values(result.output.data)); - expect(outputData).toEqual(new Float32Array([ - 3.0, 4.0, // embedding at index 1 - 1.0, 2.0 // embedding at index 0 - ])); + expect(outputData).toEqual( + new Float32Array([ + 3.0, + 4.0, // embedding at index 1 + 1.0, + 2.0, // embedding at index 0 + ]), + ); // Backward pass assertions - gradient should accumulate at the selected indices const gradData = new Float32Array(Object.values(result.grad_embeddings.data)); - expect(gradData).toEqual(new Float32Array([ - 1.0, 1.0, // gradient for embedding 0 (selected second) - 1.0, 1.0, // gradient for embedding 
1 (selected first) - 0.0, 0.0 // gradient for embedding 2 (not selected) - ])); + expect(gradData).toEqual( + new Float32Array([ + 1.0, + 1.0, // gradient for embedding 0 (selected second) + 1.0, + 1.0, // gradient for embedding 1 (selected first) + 0.0, + 0.0, // gradient for embedding 2 (not selected) + ]), + ); await page.close(); -}); \ No newline at end of file +}); diff --git a/tests/integration/linear.test.ts b/tests/integration/linear.test.ts index 3e8a0c6..c21cfc4 100644 --- a/tests/integration/linear.test.ts +++ b/tests/integration/linear.test.ts @@ -13,28 +13,31 @@ test("Linear forward pass with known values", async ({ page }) => { // @ts-expect-error ignore error for tests import("/dist/bundle.js").then((module) => { const { Tensor, Linear } = module; - + window.runLinearTest = async function () { const inputSize = 3; const outputSize = 2; - + // Create linear layer const linear = new Linear(inputSize, outputSize); - + // Set known weights and biases for deterministic testing linear.weight.data = new Float32Array([ - 0.1, 0.2, // First row - 0.3, 0.4, // Second row - 0.5, 0.6 // Third row + 0.1, + 0.2, // First row + 0.3, + 0.4, // Second row + 0.5, + 0.6, // Third row ]); - + linear.bias.data = new Float32Array([0.1, 0.2]); // Create input tensor const input = new Tensor( new Float32Array([1.0, 2.0, 3.0]), // Sample input [1, 3], // Batch size 1, input size 3 - false + false, ); // Forward pass @@ -45,7 +48,7 @@ test("Linear forward pass with known values", async ({ page }) => { weights: Array.from(linear.weight.data), biases: Array.from(linear.bias.data), outputShape: output.shape, - outputData: Array.from(output.data) + outputData: Array.from(output.data), }; }; resolve(); @@ -57,13 +60,13 @@ test("Linear forward pass with known values", async ({ page }) => { const result = await page.evaluate(() => window.runLinearTest()); // Validate shapes - expect(result.outputShape).toEqual([1, 2]); // Batch size x Output size + expect(result.outputShape).toEqual([1, 2]); // Batch size x Output size // Calculate expected output manually: // output[0] = (1.0 * 0.1 + 2.0 * 0.3 + 3.0 * 0.5) + 0.1 = 2.0 // output[1] = (1.0 * 0.2 + 2.0 * 0.4 + 3.0 * 0.6) + 0.2 = 2.8 const expectedOutput = [2.3, 3.0]; - + // Check if outputs match expected values within a small tolerance expect(result.outputData[0]).toBeCloseTo(expectedOutput[0], 5); expect(result.outputData[1]).toBeCloseTo(expectedOutput[1], 5); @@ -81,4 +84,4 @@ declare global { outputData: number[]; }>; } -} \ No newline at end of file +} diff --git a/tests/integration/norm.test.ts b/tests/integration/norm.test.ts index 540a404..580ef7b 100644 --- a/tests/integration/norm.test.ts +++ b/tests/integration/norm.test.ts @@ -13,18 +13,18 @@ test("LayerNorm forward pass with known values", async ({ page }) => { // @ts-expect-error ignore error for tests import("/dist/bundle.js").then((module) => { const { Tensor, LayerNorm } = module; - + window.runLayerNormTest = async function () { // Create a simple input tensor with known values const input = new Tensor( new Float32Array([1, 2, 3, 4, 5, 6]), // Sample values [2, 3], // 2 sequences, 3 features each - false + false, ); - + // Create LayerNorm with normalized_shape [3] const layerNorm = new LayerNorm([3], 1e-5); - + // Set known values for gamma and beta layerNorm.gamma.data.set([1.0, 1.0, 1.0]); layerNorm.beta.data.set([0.0, 0.0, 0.0]); @@ -38,7 +38,7 @@ test("LayerNorm forward pass with known values", async ({ page }) => { outputShape: output.shape, outputData: Array.from(output.data), gamma: 
Array.from(layerNorm.gamma.data), - beta: Array.from(layerNorm.beta.data) + beta: Array.from(layerNorm.beta.data), }; }; resolve(); @@ -56,8 +56,12 @@ test("LayerNorm forward pass with known values", async ({ page }) => { // For the input [1,2,3] and [4,5,6], with gamma=1 and beta=0, // we can pre-calculate the expected normalized values const expectedOutput = [ - -1.224744871391589, 0, 1.224744871391589, // First sequence normalized - -1.224744871391589, 0, 1.224744871391589 // Second sequence normalized + -1.224744871391589, + 0, + 1.224744871391589, // First sequence normalized + -1.224744871391589, + 0, + 1.224744871391589, // Second sequence normalized ]; // Check if output matches expected values (using approximate equality) @@ -79,4 +83,4 @@ declare global { beta: number[]; }>; } -} \ No newline at end of file +} diff --git a/tests/integration/tensor_stats.test.ts b/tests/integration/tensor_stats.test.ts index 2fc4cfc..a26fb2a 100644 --- a/tests/integration/tensor_stats.test.ts +++ b/tests/integration/tensor_stats.test.ts @@ -6,7 +6,7 @@ test.describe("Tensor Statistics Operations", () => { test("should calculate mean along specified dimensions", async () => { const data = new Float32Array([1, 2, 3, 4, 5, 6]); const tensor = new Tensor(data, [2, 3], false); - + const mean = await tensor.mean([1]); expect(mean.shape).toEqual([2]); expect(Array.from(mean.data)).toEqual([2, 5]); // [mean(1,2,3), mean(4,5,6)] @@ -15,7 +15,7 @@ test.describe("Tensor Statistics Operations", () => { test("should handle single dimension tensors", async () => { const data = new Float32Array([1, 2, 3, 4]); const tensor = new Tensor(data, [4], false); - + const mean = await tensor.mean([0]); expect(mean.shape).toEqual([1]); expect(mean.data[0]).toBeCloseTo(2.5); // mean(1,2,3,4) @@ -26,17 +26,19 @@ test.describe("Tensor Statistics Operations", () => { test("should calculate variance along specified dimensions", async () => { const data = new Float32Array([1, 2, 3, 4, 5, 6]); const tensor = new Tensor(data, [2, 3], false); - + const variance = await tensor.variance([1]); expect(variance.shape).toEqual([2]); // Variance of [1,2,3] and [4,5,6] - expect(Array.from(variance.data).map(x => Number(x.toFixed(2)))).toEqual([0.67, 0.67]); + expect( + Array.from(variance.data).map((x) => Number(x.toFixed(2))), + ).toEqual([0.67, 0.67]); }); test("should handle single dimension tensors", async () => { const data = new Float32Array([2, 4, 4, 6]); const tensor = new Tensor(data, [4], false); - + const variance = await tensor.variance([0]); expect(variance.shape).toEqual([1]); expect(variance.data[0]).toBeCloseTo(2); // variance of [2,4,4,6] @@ -47,7 +49,7 @@ test.describe("Tensor Statistics Operations", () => { test("should calculate element-wise square root", async () => { const data = new Float32Array([1, 4, 9, 16]); const tensor = new Tensor(data, [4], false); - + const sqrt = await tensor.sqrt(); expect(sqrt.shape).toEqual([4]); expect(Array.from(sqrt.data)).toEqual([1, 2, 3, 4]); @@ -56,7 +58,7 @@ test.describe("Tensor Statistics Operations", () => { test("should handle multi-dimensional tensors", async () => { const data = new Float32Array([1, 4, 9, 16, 25, 36]); const tensor = new Tensor(data, [2, 3], false); - + const sqrt = await tensor.sqrt(); expect(sqrt.shape).toEqual([2, 3]); expect(Array.from(sqrt.data)).toEqual([1, 2, 3, 4, 5, 6]); @@ -67,12 +69,12 @@ test.describe("Tensor Statistics Operations", () => { test("should correctly compute standard deviation using sqrt(variance)", async () => { const data = new 
Float32Array([2, 4, 4, 6]); const tensor = new Tensor(data, [4], false); - + const variance = await tensor.variance([0]); const stdDev = await variance.sqrt(); - + expect(stdDev.shape).toEqual([1]); // The shape should be [1] for a scalar result expect(stdDev.data[0]).toBeCloseTo(Math.sqrt(2)); // The actual value check }); }); -}); \ No newline at end of file +}); From 72b625991376f106d23455f22b7ed27434f5b979 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Thu, 19 Dec 2024 11:03:37 -0500 Subject: [PATCH 15/23] feat: rotary pos embed --- src/layers/embedding.ts | 125 ++++++++++++++++++++++++++++--- src/tensor/tensor.ts | 1 - tests/integration/rotary.test.ts | 121 ++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 tests/integration/rotary.test.ts diff --git a/src/layers/embedding.ts b/src/layers/embedding.ts index 36604ba..54ae9ce 100644 --- a/src/layers/embedding.ts +++ b/src/layers/embedding.ts @@ -2,19 +2,120 @@ import { Tensor } from "../tensor/tensor.js"; import { Module } from "./module.js"; export class Embedding extends Module { - vocab_size: number; - emb_dim: number; - embedding: Tensor; - constructor(vocab_size: number, emb_dim: number) { - super("embedding"); - - this.vocab_size = vocab_size; - this.emb_dim = emb_dim; - this.embedding = Tensor.randn([vocab_size, emb_dim], true); + vocab_size: number; + emb_dim: number; + embedding: Tensor; + constructor(vocab_size: number, emb_dim: number) { + super("embedding"); + + this.vocab_size = vocab_size; + this.emb_dim = emb_dim; + this.embedding = Tensor.randn([vocab_size, emb_dim], true); + } + + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [embeddings] = await this.embedding.gather(inputs[0]); + return [embeddings]; + } +} + +export class RotaryEmbedding extends Module { + base: number; + dimension: number; + theta: Tensor; + sequenceLength: number; + idxTheta: Tensor | null = null; + constructor(base: number, dimension: number) { + super("rope_embedding"); + this.base = base; + this.dimension = dimension; + + const theta = this.createTheta(dimension, base); + this.theta = new Tensor(theta, [1, dimension / 2], true); + + this.sequenceLength = 0; + this.idxTheta = null; + } + + createTheta(dimension: number, base: number = 10000): Float32Array { + // Create a new Float32Array of the specified size + const result = new Float32Array(dimension / 2); + + // Calculate values for each position + for (let i = 0; i < dimension; i += 2) { + const value = 1.0 / Math.pow(base, i / dimension); + result[i / 2] = value; + } + + return result; + } + + async buildCache(sequenceLength: number) { + const posIdx = new Float32Array(sequenceLength); + for (let i = 0; i < sequenceLength; i++) { + posIdx[i] = i; + } + + const posTensor = new Tensor(posIdx, [sequenceLength, 1], true); + let [idxTheta] = await posTensor.matmul(this.theta); + + idxTheta = await idxTheta.concat(idxTheta, 1); + + return [idxTheta]; + } + + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [x] = inputs; + + const currSeqLen = x.shape[0]; + const d2 = Math.floor(this.dimension / 2); + + if (currSeqLen > this.sequenceLength || this.idxTheta === null) { + const [cache] = await this.buildCache(currSeqLen); + this.sequenceLength = currSeqLen; + this.idxTheta = cache; + } + + const idxTheta = this.idxTheta; + + const idxThetaLength = idxTheta.data.length; + const cosIdxThetaArr = new Float32Array(idxThetaLength); + const sinIdxThetaArr = new Float32Array(idxThetaLength); + + for (let i = 0; i < 
idxThetaLength; i++) { + cosIdxThetaArr[i] = Math.cos(idxTheta.data[i]); + sinIdxThetaArr[i] = Math.sin(idxTheta.data[i]); } - async forward(...inputs: [Tensor]): Promise<[Tensor]> { - const [embeddings] = await this.embedding.gather(inputs[0]); - return [embeddings]; + const cosIdxTheta = new Tensor( + cosIdxThetaArr, + [currSeqLen, this.dimension], + x.requires_grad, + ); + const sinIdxTheta = new Tensor( + sinIdxThetaArr, + [currSeqLen, this.dimension], + x.requires_grad, + ); + + // Rewrite using tensor operations and select + const leftHalf = await x.slice(":", [null, d2]); + const rightHalf = await x.slice(":", [d2, this.dimension]); + const [negHalf] = await rightHalf.mul(Tensor.full([1], -1)); + + const half = await negHalf.concat(leftHalf, 1); + const xRope = await x.slice(":", [null, this.dimension]); + + const [xRopePos] = await xRope.mul(cosIdxTheta); + const [xRopeNeg] = await half.mul(sinIdxTheta); + + let [rope] = await xRopePos.add(xRopeNeg); + if (this.dimension < x.shape[1]) { + const xPass = await x.slice(":", [null, null, d2]); + + rope = await rope.concat(xPass, 1); } + + return [rope]; + } } diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index 8cf2f67..f8039e8 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -389,7 +389,6 @@ export class Tensor { const slices = args.map((arg, dim) => this.normalizeSlice(arg, this.shape[dim]), ); - console.log("slices:", slices); // Calculate output shape and stride info const { outputShape, isReducedDim } = this.calculateOutputShape( diff --git a/tests/integration/rotary.test.ts b/tests/integration/rotary.test.ts new file mode 100644 index 0000000..e382422 --- /dev/null +++ b/tests/integration/rotary.test.ts @@ -0,0 +1,121 @@ +import { test, expect } from "@playwright/test"; + +test("Rotary positional embedding forward pass with known values", async ({ + page, +}) => { + await page.goto("http://localhost:8080"); + + page.on("console", (msg) => { + console.log(msg); + }); + + // Inject test function + await page.evaluate(() => { + return new Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor, RotaryEmbedding } = module; + + window.runRotaryEmbeddingTest = async function () { + const seqLength = 4; + const dimension = 8; // Must be divisible by 2 for rotary embeddings + const base = 10000.0; + + // Create rotary embedding layer + const rotaryEmbed = new RotaryEmbedding(base, dimension); + + // Create sample input tensor + const input = new Tensor( + new Float32Array([ + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.2, 0.3, 0.4, 0.5, 0.6, + 0.7, 0.8, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, + ]), + [seqLength, dimension], + false, + ); + + // Forward pass + const [rotatedOutput] = await rotaryEmbed.forward(input); + console.log("rotatedOutput", rotatedOutput.data.toString()); + + // Get the theta values and position encodings for verification + const theta = rotaryEmbed.createTheta(dimension, base); + const [cache] = await rotaryEmbed.buildCache(seqLength); + + return { + inputShape: input.shape, + inputData: Array.from(input.data), + outputShape: rotatedOutput.shape, + outputData: Array.from(rotatedOutput.data), + theta: Array.from(theta), + cache: Array.from(cache.data), + idxTheta: Array.from(rotaryEmbed.idxTheta), + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + const result = await page.evaluate(() => window.runRotaryEmbeddingTest()); + + // 
Validate shapes + expect(result.inputShape).toEqual([4, 8]); + expect(result.outputShape).toEqual([4, 8]); + + // Validate theta calculation + const expectedTheta = Array.from([1.0, 0.1, 0.01, 0.001]); + + result.theta.forEach((value, idx) => { + expect(value).toBeCloseTo(expectedTheta[idx], 4); + }); + + const expectedIdxTheta = Array.from([ + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0e-1, 1.0e-2, 1.0e-3, 2.0, 2.0e-1, 2.0e-2, + 2.0e-3, 3.0, 3.0e-1, 3.0e-2, 3.0e-3, + ]); + + result.idxTheta.forEach((value, idx) => { + expect(value).toBeCloseTo(expectedIdxTheta[idx], 4); + }); + + // Validate cache calculation + const expectedCacheOutput = Array.from([ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0e-1, 1.0e-2, 1.0e-3, 1.0, + 1.0e-1, 1.0e-2, 1.0e-3, 2.0, 2.0e-1, 2.0e-2, 2.0e-3, 2.0, 2.0e-1, 2.0e-2, + 2.0e-3, 3.0, 3.0e-1, 3.0e-2, 3.0e-3, 3.0, 3.0e-1, 3.0e-2, 3.0e-3, + ]); + + result.cache.forEach((value, idx) => { + expect(value).toBeCloseTo(expectedCacheOutput[idx], 4); + }); + + const expectedRotatedOutput = Array.from([ + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.3968, 0.2286, 0.392, 0.4991, + 0.4925, 0.7265, 0.804, 0.9005, -0.7614, 0.2331, 0.4819, 0.598, -0.0185, + 0.8635, 0.9098, 1.0012, -0.5089, 0.2117, 0.5697, 0.6967, -0.7355, 1.0076, + 1.0175, 1.1021, + ]); + + result.outputData.forEach((value, idx) => { + expect(value).toBeCloseTo(expectedRotatedOutput[idx], 4); + }); + + await page.close(); +}); + +declare global { + interface Window { + runRotaryEmbeddingTest: () => Promise<{ + inputShape: number[]; + inputData: number[]; + outputShape: number[]; + outputData: number[]; + theta: number[]; + cache: number[]; + idxTheta: number[]; + }>; + } +} From c77f91342fe7dc134f5eacee1cd7b2a487786c94 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Thu, 19 Dec 2024 19:08:44 -0500 Subject: [PATCH 16/23] feat: swiglu mlp --- src/index.ts | 1 + src/layers/mlp.ts | 129 +++++++++++++++++++++++++++++++++ tests/integration/mlp.test.ts | 131 ++++++++++++++++++++++++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 src/layers/mlp.ts create mode 100644 tests/integration/mlp.test.ts diff --git a/src/index.ts b/src/index.ts index c0a965d..fdbf50f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,4 +12,5 @@ export * from "./layers/module.js"; export * from "./layers/embedding.js"; export * from "./layers/linear.js"; export * from "./layers/norm.js"; +export * from "./layers/mlp.js"; export * from "./ops/div.js"; diff --git a/src/layers/mlp.ts b/src/layers/mlp.ts new file mode 100644 index 0000000..a338091 --- /dev/null +++ b/src/layers/mlp.ts @@ -0,0 +1,129 @@ +import { Tensor } from "../tensor/tensor.js"; +import { Module } from "./module.js"; +import { Linear } from "./linear.js"; + +type ActivationType = "relu" | "silu" | "gelu" | "swiglu" | "none"; + +export class MLP extends Module { + up: Linear; // Project up to larger dimension + down: Linear; // Project back down + activation: ActivationType; + + constructor( + dim: number, // input/output dimension + hiddenDim: number, // hidden dimension + activation: ActivationType = "relu", + ) { + super("mlp"); + + // For SwiGLU, we need double the hidden dimension for gating + const actualHiddenDim = activation === "swiglu" ? 
hiddenDim * 2 : hiddenDim; + + this.up = new Linear(dim, actualHiddenDim); + this.down = new Linear(hiddenDim, dim); + this.activation = activation; + } + + private async gelu(x: Tensor): Promise<[Tensor, number]> { + // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))) + const sqrt2OverPi = Math.sqrt(2 / Math.PI); + + // Calculate x^3 + const [xSquared] = await x.mul(x); + const [xCubed] = await xSquared.mul(x); + + // Calculate 0.044715 * x^3 + const [scaledCube] = await xCubed.mul( + Tensor.full(x.shape, 0.044715, false), + ); + + // Add x to the scaled cube + const [innerSum] = await x.add(scaledCube); + + // Multiply by sqrt(2/π) + const [scaled] = await innerSum.mul( + Tensor.full(x.shape, sqrt2OverPi, false), + ); + + // Calculate tanh using (e^x - e^-x)/(e^x + e^-x) + const [exp] = await scaled.exp(); + const [negScaled] = await scaled.mul(Tensor.full(x.shape, -1, false)); + const [negExp] = await negScaled.exp(); + + const [numerator] = await exp.sub(negExp); + const [denominator] = await exp.add(negExp); + + const [tanh] = await numerator.div(denominator); + + // Add 1 to tanh result + const [tanhPlusOne] = await tanh.add(Tensor.full(x.shape, 1, false)); + + // Multiply by x + const [xTimesSum] = await x.mul(tanhPlusOne); + + // Multiply by 0.5 for final result + return xTimesSum.mul(Tensor.full(x.shape, 0.5, false)); + } + + private async silu(x: Tensor): Promise<[Tensor, number]> { + const [negX] = await x.mul(Tensor.full(x.shape, -1, false)); + const [expNegX] = await negX.exp(); + const [onePlusExpNegX] = await expNegX.add(Tensor.full(x.shape, 1, false)); + + const [sigmoid] = await Tensor.full(x.shape, 1, false).div(onePlusExpNegX); + return x.mul(sigmoid); + } + + private async applyActivation(x: Tensor): Promise<[Tensor, number]> { + switch (this.activation) { + case "relu": + return x.relu(); + case "silu": + return this.silu(x); + case "gelu": + return this.gelu(x); + case "swiglu": { + // Split the tensor in half for gate and value paths + const halfSize = Math.floor(x.shape[x.shape.length - 1] / 2); + const [gate, value] = await Promise.all([ + x.slice(":", [0, halfSize]), + x.slice(":", [halfSize, x.shape[x.shape.length - 1]]), + ]); + const [gateActivated] = await this.silu(gate); + return gateActivated.mul(value); + } + case "none": + return [x, -1]; + default: + throw new Error(`Unknown activation type: ${this.activation}`); + } + } + + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [input] = inputs; + + // Project up to hidden dimension + const [upProjected] = await this.up.forward(input); + + // Apply activation + const [activated] = await this.applyActivation(upProjected); + + // Project back down + return this.down.forward(activated); + } + + // Helper method for creating standard configurations + static create(config: { + dim: number; // input/output dimension + hiddenMul?: number; // multiplier for hidden dimension (default 4) + activation?: ActivationType; + }): MLP { + const { + dim, + hiddenMul = 4, // typical transformer uses 4x dimension for FFN + activation = "relu", + } = config; + + return new MLP(dim, dim * hiddenMul, activation); + } +} diff --git a/tests/integration/mlp.test.ts b/tests/integration/mlp.test.ts new file mode 100644 index 0000000..e329c1b --- /dev/null +++ b/tests/integration/mlp.test.ts @@ -0,0 +1,131 @@ +import { test, expect } from "@playwright/test"; + +test("MLP with SwiGLU activation forward pass with known values", async ({ + page, +}) => { + await page.goto("http://localhost:8080"); + + 
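  // This test exercises the "swiglu" path: the up-projection emits
  // 2 * hiddenDim features, forward() splits them into a gate half and a
  // value half, and the hidden activation is silu(gate) * value before
  // the down-projection.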
page.on("console", (msg) => { + console.log(msg); + }); + + // Inject test function + await page.evaluate(() => { + return new Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor, MLP } = module; + + window.runSwiGLUTest = async function () { + // Create sample input tensor with known values + const inputDim = 4; + const hiddenDim = 8; // Will be doubled internally for SwiGLU + const seqLength = 2; + + const input = new Tensor( + new Float32Array([ + 0.1, + 0.2, + 0.3, + 0.4, // First sequence + 0.5, + 0.6, + 0.7, + 0.8, // Second sequence + ]), + [seqLength, inputDim], + false, + ); + + // Create MLP with SwiGLU activation + const mlp = new MLP(inputDim, hiddenDim, "swiglu"); + + // Set known weights and biases for reproducibility + mlp.up.weight = new Tensor( + new Float32Array([ + // First half for gate + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.2, 0.3, 0.4, 0.5, 0.6, + 0.7, 0.8, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, + // Second half for value + 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, + 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.4, 0.4, + 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, + ]), + [inputDim, hiddenDim * 2], + true, + ); + + mlp.up.bias = new Tensor( + new Float32Array([ + // Gate bias + 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, + // Value bias + 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, + ]), + [hiddenDim * 2], + true, + ); + + mlp.down.weight = new Tensor( + new Float32Array([ + 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.3, 0.4, 0.5, 0.6, 0.4, + 0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.8, 0.6, 0.7, 0.8, 0.9, 0.7, 0.8, + 0.9, 1.0, 0.8, 0.9, 1.0, 1.1, + ]), + [hiddenDim, inputDim], + true, + ); + + mlp.down.bias = new Tensor( + new Float32Array([0.1, 0.1, 0.1, 0.1]), + [inputDim], + true, + ); + + // Forward pass + const [output] = await mlp.forward(input); + + return { + inputShape: input.shape, + inputData: Array.from(input.data), + outputShape: output.shape, + outputData: Array.from(output.data), + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + const result = await page.evaluate(() => window.runSwiGLUTest()); + + // Validate shapes + expect(result.inputShape).toEqual([2, 4]); // [batch_size, input_dim] + expect(result.outputShape).toEqual([2, 4]); // [batch_size, input_dim] + console.log("result.outputData:", result.outputData.toString()); + + // Expected values computed using the same architecture with PyTorch + const expectedOutput = [ + 0.7809, 0.9126, 1.0443, 1.176, 5.0712, 5.9646, 6.8581, 7.7515, + ]; + + // Validate output values + result.outputData.forEach((value, idx) => { + expect(value).toBeCloseTo(expectedOutput[idx], 4); + }); + + await page.close(); +}); + +declare global { + interface Window { + runSwiGLUTest: () => Promise<{ + inputShape: number[]; + inputData: number[]; + outputShape: number[]; + outputData: number[]; + }>; + } +} From 6a4cef34a3074f03fd507998bd93faf3978d4573 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 20 Dec 2024 01:04:24 -0500 Subject: [PATCH 17/23] feat: mha attention!! 
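Each head applies standard scaled dot-product attention,
softmax(Q * K^T / sqrt(head_dim)) * V, to its slice of a fused QKV
projection; the head outputs are concatenated and passed through an
output projection. Div is generalized to NumPy-style broadcasting so
the per-row softmax denominator can divide the [seq, seq] score matrix
directly. A minimal usage sketch, assuming the bundle exports used by
the integration tests and an async context:

    const { Tensor, MultiHeadAttention } = await import("/dist/bundle.js");

    const attn = new MultiHeadAttention(4, 2); // hidden_dim 4, 2 heads of dim 2
    const x = new Tensor(
      new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),
      [2, 4],
      false,
    );
    const [y] = await attn.forward(x); // y.shape === [2, 4]

Weights are randomly initialized, so outputs are only reproducible when
the qkv and output weights are set explicitly, as the test below does.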
--- src/index.ts | 3 +- src/layers/attention.ts | 107 ++++++++++++++++++++++ src/ops/div.ts | 99 ++++++++++++++++----- src/tensor/tensor.ts | 32 +++++++ tests/integration/attention.test.ts | 132 ++++++++++++++++++++++++++++ 5 files changed, 351 insertions(+), 22 deletions(-) create mode 100644 src/layers/attention.ts create mode 100644 tests/integration/attention.test.ts diff --git a/src/index.ts b/src/index.ts index fdbf50f..db66762 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,10 +7,11 @@ export * from "./ops/exp2.js"; export * from "./ops/log2.js"; export * from "./ops/ln.js"; export * from "./ops/relu.js"; +export * from "./ops/div.js"; export * from "./autograd/function.js"; export * from "./layers/module.js"; export * from "./layers/embedding.js"; export * from "./layers/linear.js"; export * from "./layers/norm.js"; export * from "./layers/mlp.js"; -export * from "./ops/div.js"; +export * from "./layers/attention.js"; diff --git a/src/layers/attention.ts b/src/layers/attention.ts new file mode 100644 index 0000000..a88670f --- /dev/null +++ b/src/layers/attention.ts @@ -0,0 +1,107 @@ +import { Tensor } from "../tensor/tensor.js"; +import { Module } from "./module.js"; +import { Linear } from "./linear.js"; + +export class MultiHeadAttention extends Module { + qkv: Linear; // Combined projection for Query, Key, Value + output: Linear; // Output projection + num_heads: number; + head_dim: number; + hidden_dim: number; + + constructor(hidden_dim: number, num_heads: number) { + super("multihead_attention"); + + this.num_heads = num_heads; + this.head_dim = Math.floor(hidden_dim / num_heads); + this.hidden_dim = hidden_dim; + + if (this.head_dim * num_heads !== hidden_dim) { + throw new Error( + `Hidden dimension ${hidden_dim} must be divisible by number of heads ${num_heads}`, + ); + } + + // Combined QKV projection + // Projects to 3x hidden_dim for Q, K, V + this.qkv = new Linear(hidden_dim, hidden_dim * 3); + + // Output projection + this.output = new Linear(hidden_dim, hidden_dim); + } + + private async reshapeToHeads(tensor: Tensor): Promise { + const heads: Tensor[] = []; + + // Each head will be (seqlen, head_dim) + for (let i = 0; i < this.num_heads; i++) { + const start = i * this.head_dim; + const end = start + this.head_dim; + const head = await tensor.slice(":", [start, end]); + heads.push(head); + } + + return heads; + } + + private async scaledDotProductAttention( + query: Tensor, + key: Tensor, + value: Tensor, + ): Promise<[Tensor, number]> { + // Scale factor is 1/sqrt(head_dim) + const scale = 1 / Math.sqrt(this.head_dim); + const scaleTensor = Tensor.full(query.shape, scale, false); + + // Compute attention scores + const [scores] = await query.matmul(key.transpose()); + const [scaledScores] = await scores.mul(scaleTensor); + + // Softmax implementation + const [expScores] = await scaledScores.exp(); + const sumExp = await expScores.sum([1]); + + const [attention] = await expScores.div(sumExp); + + // Apply attention to values + return attention.matmul(value); + } + + async forward(input: Tensor): Promise<[Tensor]> { + // Project input to Q, K, V + const [qkv] = await this.qkv.forward(input); + + // Split into Q, K, V + const query = await qkv.slice(":", [0, this.hidden_dim]); + const key = await qkv.slice(":", [this.hidden_dim, this.hidden_dim * 2]); + const value = await qkv.slice(":", [ + this.hidden_dim * 2, + this.hidden_dim * 3, + ]); + + // Split each of Q, K, V into heads + const queryHeads = await this.reshapeToHeads(query); + const keyHeads = await 
this.reshapeToHeads(key); + const valueHeads = await this.reshapeToHeads(value); + + // Compute attention for each head + const headOutputs: Tensor[] = []; + for (let i = 0; i < this.num_heads; i++) { + const [headOutput] = await this.scaledDotProductAttention( + queryHeads[i], + keyHeads[i], + valueHeads[i], + ); + headOutputs.push(headOutput); + } + + // Concatenate heads + let concatOutput = headOutputs[0]; + for (let i = 1; i < headOutputs.length; i++) { + concatOutput = await concatOutput.concat(headOutputs[i], 1); + } + + // Final output projection + return this.output.forward(concatOutput); + } +} diff --git a/src/ops/div.ts b/src/ops/div.ts index 3f03ab2..828d84d 100644 --- a/src/ops/div.ts +++ b/src/ops/div.ts @@ -6,32 +6,89 @@ export class Div extends BinaryOp { protected readonly shader: string = divShader; validateShapes(a: Tensor, b: Tensor): Tensor { - if (!a.shape.every((value, index) => value === b.shape[index])) { - if (b.shape.length === 1 && b.shape[0] === 1) { - // Broadcast scalar - b = Tensor.full(a.shape, b.data[0], b.requires_grad); - } else if (b.shape.length === 1 && b.shape[0] === a.shape[1]) { - // Broadcast [m] to [n, m] - b = Tensor.broadcast(b, a.shape[0], b.requires_grad); - } else if (b.shape.length === 2 && b.shape[1] === 1) { - // Broadcast [n, 1] to [n, m] - const newShape = [b.shape[0], a.shape[1]]; - console.log("Broadcasting [n,1] to shape:", newShape); - const newData = new Float32Array(newShape[0] * newShape[1]); - // Repeat the values across the second dimension - for (let i = 0; i < b.shape[0]; i++) { - for (let j = 0; j < a.shape[1]; j++) { - newData[i * a.shape[1] + j] = b.data[i]; - } - } - b = new Tensor(newData, newShape, b.requires_grad); + // Handle scalar case first + if (b.shape.length === 1 && b.shape[0] === 1) { + return Tensor.full(a.shape, b.data[0], b.requires_grad); + } + + // Get dimensions of both tensors + const dimA = a.shape.length; + const dimB = b.shape.length; + + // Calculate the number of dimensions in the output + const maxDim = Math.max(dimA, dimB); + + // Pad shapes with 1s from the left to match max dimensions + const paddedA = Array(maxDim - dimA) + .fill(1) + .concat(a.shape); + const paddedB = Array(maxDim - dimB) + .fill(1) + .concat(b.shape); + + // Check if shapes can be broadcast + const outputShape = []; + for (let i = 0; i < maxDim; i++) { + if (paddedA[i] === paddedB[i]) { + outputShape.push(paddedA[i]); + } else if (paddedA[i] === 1) { + outputShape.push(paddedB[i]); + } else if (paddedB[i] === 1) { + outputShape.push(paddedA[i]); } else { throw new Error( - `Incompatible shapes for Div: ${a.shape} and ${b.shape}`, + `Incompatible shapes for broadcasting: ${a.shape} and ${b.shape}`, ); } } - return b; + + // If shapes are already compatible, return original tensor + if (outputShape.every((dim, i) => dim === b.shape[i])) { + return b; + } + + // Create new broadcasted tensor + const newSize = outputShape.reduce((acc, dim) => acc * dim, 1); + const newData = new Float32Array(newSize); + + // For a tensor of shape [n] being broadcast to [n, m], + // we want to repeat each element m times consecutively + if (b.shape.length === 1 && outputShape.length === 2) { + const n = b.shape[0]; + const m = outputShape[1]; + + for (let i = 0; i < n; i++) { + for (let j = 0; j < m; j++) { + newData[i * m + j] = b.data[i]; + } + } + } else { + // General case for broadcasting across multiple dimensions + for (let i = 0; i < newSize; i++) { + // Convert flat index to coordinates + let remaining = i; + const coords = []; + for 
(const dim of outputShape) { + coords.push(remaining % dim); + remaining = Math.floor(remaining / dim); + } + coords.reverse(); + + // Map to input tensor coordinates + let inputIdx = 0; + let stride = 1; + for (let dim = b.shape.length - 1; dim >= 0; dim--) { + const outputDim = dim + (outputShape.length - b.shape.length); + const coord = coords[outputDim] % b.shape[dim]; + inputIdx += coord * stride; + stride *= b.shape[dim]; + } + + newData[i] = b.data[inputIdx]; + } + } + + return new Tensor(newData, outputShape, b.requires_grad); } async backward(grad_output: Tensor): Promise { diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index f8039e8..2910ae5 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -139,6 +139,38 @@ export class Tensor { return new Tensor(result, shape, this.requires_grad); } + async sum(dims: number[]): Promise { + const shape = this.shape.slice(); + + dims.sort((a, b) => b - a); // Sort in descending order to remove correctly + dims.forEach((dim) => shape.splice(dim, 1)); + if (shape.length === 0) shape.push(1); + + const result = new Float32Array(shape.reduce((a, b) => a * b, 1)); + + // For 1D case + if (this.shape.length === 1 && dims.includes(0)) { + let sum = 0; + for (let i = 0; i < this.data.length; i++) { + sum += this.data[i]; + } + result[0] = sum; + return new Tensor(result, shape, this.requires_grad); + } + + // For higher dimensions (keeping existing logic for 2D) + const stride = this.shape[1]; + for (let i = 0; i < this.shape[0]; i++) { + let sum = 0; + for (let j = 0; j < stride; j++) { + sum += this.data[i * stride + j]; + } + result[i] = sum; + } + + return new Tensor(result, shape, this.requires_grad); + } + async variance(dims: number[]): Promise { const mean = await this.mean(dims); const shape = this.shape.slice(); diff --git a/tests/integration/attention.test.ts b/tests/integration/attention.test.ts new file mode 100644 index 0000000..c63f7b3 --- /dev/null +++ b/tests/integration/attention.test.ts @@ -0,0 +1,132 @@ +import { test, expect } from "@playwright/test"; + +test("MultiHeadAttention forward pass with known values", async ({ page }) => { + await page.goto("http://localhost:8080"); + + page.on("console", (msg) => { + console.log(msg); + }); + + // Inject test function + await page.evaluate(() => { + return new Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor, MultiHeadAttention } = module; + + window.runAttentionTest = async function () { + // Create sample input tensor with known values + const seqLength = 2; + const hiddenDim = 4; + const numHeads = 2; + const headDim = hiddenDim / numHeads; + + const input = new Tensor( + new Float32Array([ + 0.1, + 0.2, + 0.3, + 0.4, // First sequence + 0.5, + 0.6, + 0.7, + 0.8, // Second sequence + ]), + [seqLength, hiddenDim], + false, + ); + + // Create MultiHeadAttention + const attention = new MultiHeadAttention(hiddenDim, numHeads); + + // Set known weights and biases for reproducibility + attention.qkv.weight = new Tensor( + new Float32Array([ + // Q weights + 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.3, 0.4, 0.5, 0.6, 0.4, + 0.5, 0.6, 0.7, + // K weights + 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.4, + 0.4, 0.4, 0.4, + // V weights + 0.5, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, + 0.8, 0.8, 0.8, + ]), + [hiddenDim, hiddenDim * 3], + true, + ); + + attention.qkv.bias = new Tensor( + new Float32Array([ + // Q bias + 0.1, 0.1, 0.1, 0.1, + // K bias + 
0.2, 0.2, 0.2, 0.2, + // V bias + 0.3, 0.3, 0.3, 0.3, + ]), + [hiddenDim * 3], + true, + ); + + attention.output.weight = new Tensor( + new Float32Array([ + 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.3, 0.4, 0.5, 0.6, 0.4, + 0.5, 0.6, 0.7, + ]), + [hiddenDim, hiddenDim], + true, + ); + + attention.output.bias = new Tensor( + new Float32Array([0.1, 0.1, 0.1, 0.1]), + [hiddenDim], + true, + ); + + // Forward pass + const [output] = await attention.forward(input); + + return { + inputShape: input.shape, + inputData: Array.from(input.data), + outputShape: output.shape, + outputData: Array.from(output.data), + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + const result = await page.evaluate(() => window.runAttentionTest()); + + // Validate shapes + expect(result.inputShape).toEqual([2, 4]); // [seq_len, hidden_dim] + expect(result.outputShape).toEqual([2, 4]); // [seq_len, hidden_dim] + console.log("result.outputData:", result.outputData.toString()); + + // Expected values computed using the same architecture with PyTorch + const expectedOutput = [ + 1.4622, 1.9985, 2.5347, 3.0709, 1.5701, 2.1462, 2.7224, 3.2985, + ]; + + // Validate output values + result.outputData.forEach((value, idx) => { + expect(value).toBeCloseTo(expectedOutput[idx], 4); + }); + + await page.close(); +}); + +declare global { + interface Window { + runAttentionTest: () => Promise<{ + inputShape: number[]; + inputData: number[]; + outputShape: number[]; + outputData: number[]; + }>; + } +} From fad15ff1d81a576eb8b30d871987863596fea0da Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 20 Dec 2024 01:05:01 -0500 Subject: [PATCH 18/23] style: comment --- src/layers/attention.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/attention.ts b/src/layers/attention.ts index a88670f..488e9e2 100644 --- a/src/layers/attention.ts +++ b/src/layers/attention.ts @@ -85,6 +85,7 @@ export class MultiHeadAttention extends Module { const valueHeads = await this.reshapeToHeads(value); // Compute attention for each head + // this will be slow, we should create bmm const headOutputs: Tensor[] = []; for (let i = 0; i < this.num_heads; i++) { const [headOutput] = await this.scaledDotProductAttention( From 8a2be80f0ece47d1cdcbaff91f2b4408ac8b5539 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 20 Dec 2024 09:45:04 -0500 Subject: [PATCH 19/23] fix: remove unused --- tests/integration/attention.test.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/attention.test.ts b/tests/integration/attention.test.ts index c63f7b3..b8ca9bf 100644 --- a/tests/integration/attention.test.ts +++ b/tests/integration/attention.test.ts @@ -19,7 +19,6 @@ test("MultiHeadAttention forward pass with known values", async ({ page }) => { const seqLength = 2; const hiddenDim = 4; const numHeads = 2; - const headDim = hiddenDim / numHeads; const input = new Tensor( new Float32Array([ From 2312317dea72525bf1c23fb37f4e7c94fa9c949a Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 20 Dec 2024 09:47:48 -0500 Subject: [PATCH 20/23] feat: pow --- src/tensor/tensor.ts | 8 ++++++ tests/unit/tensor.test.ts | 54 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts index 2910ae5..315794e 100644 --- a/src/tensor/tensor.ts +++ b/src/tensor/tensor.ts @@ -171,6 +171,14 @@ export class Tensor { return new Tensor(result, shape, this.requires_grad); } + async pow(p: number): Promise<[Tensor]> { + const result = new 
+    for (let i = 0; i < this.data.length; i++) {
+      result[i] = this.data[i] ** p;
+    }
+    return [new Tensor(result, this.shape.slice(), this.requires_grad)];
+  }
+
   async variance(dims: number[]): Promise<Tensor> {
     const mean = await this.mean(dims);
     const shape = this.shape.slice();
diff --git a/tests/unit/tensor.test.ts b/tests/unit/tensor.test.ts
index d6e19be..f57e01e 100644
--- a/tests/unit/tensor.test.ts
+++ b/tests/unit/tensor.test.ts
@@ -263,4 +263,58 @@ describe("Tensor", () => {
       ]);
     });
   });
+  describe("pow", () => {
+    it("should correctly square a tensor (power of 2)", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]);
+      const [result] = await tensor.pow(2);
+
+      expect(result.shape).toEqual([2, 2]);
+      expect(Array.from(result.data)).toEqual([1, 4, 9, 16]);
+    });
+
+    it("should correctly calculate square root (power of 0.5)", async () => {
+      const tensor = new Tensor(new Float32Array([1, 4, 9, 16]), [2, 2]);
+      const [result] = await tensor.pow(0.5);
+
+      expect(result.shape).toEqual([2, 2]);
+      expect(Array.from(result.data)).toEqual([1, 2, 3, 4]);
+    });
+
+    it("should handle negative numbers with even powers", async () => {
+      const tensor = new Tensor(new Float32Array([-2, -3, 2, 3]), [2, 2]);
+      const [result] = await tensor.pow(2);
+
+      expect(result.shape).toEqual([2, 2]);
+      expect(Array.from(result.data)).toEqual([4, 9, 4, 9]);
+    });
+
+    it("should preserve requires_grad", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2], true);
+      const [result] = await tensor.pow(2);
+
+      expect(result.requires_grad).toBe(true);
+    });
+
+    it("should handle power of 1 (identity)", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]);
+      const [result] = await tensor.pow(1);
+
+      expect(Array.from(result.data)).toEqual([1, 2, 3, 4]);
+    });
+
+    it("should handle power of 0 (all ones)", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [2, 2]);
+      const [result] = await tensor.pow(0);
+
+      expect(Array.from(result.data)).toEqual([1, 1, 1, 1]);
+    });
+
+    it("should maintain shape for 1D tensors", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [4]);
+      const [result] = await tensor.pow(2);
+
+      expect(result.shape).toEqual([4]);
+      expect(Array.from(result.data)).toEqual([1, 4, 9, 16]);
+    });
+  });
 });

From 77ec7ba66b62aa4e018681abfc11b54a89845651 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Fri, 20 Dec 2024 10:03:10 -0500
Subject: [PATCH 21/23] feat: norm and fix/test sum

---
 src/tensor/tensor.ts      |  54 +++++++++++++---
 tests/unit/tensor.test.ts | 127 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 173 insertions(+), 8 deletions(-)

diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts
index 315794e..bf4c68b 100644
--- a/src/tensor/tensor.ts
+++ b/src/tensor/tensor.ts
@@ -142,12 +142,26 @@ export class Tensor {
   async sum(dims: number[]): Promise<Tensor> {
     const shape = this.shape.slice();
 
-    dims.sort((a, b) => b - a); // Sort in descending order to remove correctly
+    // Sort dimensions in descending order for correct removal
+    dims.sort((a, b) => b - a);
     dims.forEach((dim) => shape.splice(dim, 1));
     if (shape.length === 0) shape.push(1);
 
     const result = new Float32Array(shape.reduce((a, b) => a * b, 1));
 
+    // Special case: if we're summing all dimensions, just sum everything
+    if (
+      dims.length === this.shape.length ||
+      (this.shape.length === 2 && dims.includes(0) && dims.includes(1))
+    ) {
+      let sum = 0;
+      for (let i = 0; i < this.data.length; i++) {
+        sum += this.data[i];
+      }
+      result[0] = sum;
+      return new Tensor(result, [1], this.requires_grad);
+    }
+
     // For 1D case
     if (this.shape.length === 1 && dims.includes(0)) {
       let sum = 0;
@@ -158,14 +172,31 @@ export class Tensor {
       return new Tensor(result, shape, this.requires_grad);
     }
 
-    // For higher dimensions (keeping existing logic for 2D)
-    const stride = this.shape[1];
-    for (let i = 0; i < this.shape[0]; i++) {
-      let sum = 0;
-      for (let j = 0; j < stride; j++) {
-        sum += this.data[i * stride + j];
+    // For 2D case
+    if (this.shape.length === 2) {
+      if (dims.includes(0)) {
+        // Sum along first dimension (vertically)
+        const cols = this.shape[1];
+        const rows = this.shape[0];
+        for (let j = 0; j < cols; j++) {
+          let sum = 0;
+          for (let i = 0; i < rows; i++) {
+            sum += this.data[i * cols + j];
+          }
+          result[j] = sum;
+        }
+      } else if (dims.includes(1)) {
+        // Sum along second dimension (horizontally)
+        const cols = this.shape[1];
+        const rows = this.shape[0];
+        for (let i = 0; i < rows; i++) {
+          let sum = 0;
+          for (let j = 0; j < cols; j++) {
+            sum += this.data[i * cols + j];
+          }
+          result[i] = sum;
+        }
       }
-      result[i] = sum;
     }
 
     return new Tensor(result, shape, this.requires_grad);
@@ -179,6 +210,13 @@ export class Tensor {
     return [new Tensor(result, this.shape.slice(), this.requires_grad)];
   }
 
+  async norm(p: number = 2, dim: number = 0): Promise<[Tensor]> {
+    const [norm] = await this.pow(p);
+    const sumNorm = await norm.sum([dim]);
+    const [rootNorm] = await sumNorm.pow(1 / p);
+    return [rootNorm];
+  }
+
   async variance(dims: number[]): Promise<Tensor> {
     const mean = await this.mean(dims);
     const shape = this.shape.slice();
diff --git a/tests/unit/tensor.test.ts b/tests/unit/tensor.test.ts
index f57e01e..775a856 100644
--- a/tests/unit/tensor.test.ts
+++ b/tests/unit/tensor.test.ts
@@ -317,4 +317,131 @@ describe("Tensor", () => {
       expect(Array.from(result.data)).toEqual([1, 4, 9, 16]);
     });
   });
+  describe("sum", () => {
+    it("should sum 1D tensor correctly", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [4]);
+      const result = await tensor.sum([0]);
+
+      expect(result.shape).toEqual([1]);
+      expect(Array.from(result.data)[0]).toBe(10); // 1 + 2 + 3 + 4
+    });
+
+    it("should sum 2D tensor along first dimension", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5, 6]), [2, 3]);
+      const result = await tensor.sum([0]);
+
+      expect(result.shape).toEqual([3]);
+      expect(Array.from(result.data)).toEqual([5, 7, 9]); // [1+4, 2+5, 3+6]
+    });
+
+    it("should sum 2D tensor along second dimension", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5, 6]), [2, 3]);
+      const result = await tensor.sum([1]);
+
+      expect(result.shape).toEqual([2]);
+      expect(Array.from(result.data)).toEqual([6, 15]); // [1+2+3, 4+5+6]
+    });
+
+    it("should preserve requires_grad", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4]), [4], true);
+      const result = await tensor.sum([0]);
+
+      expect(result.requires_grad).toBe(true);
+    });
+
+    it("should handle tensor with all zeros", async () => {
+      const tensor = new Tensor(new Float32Array([0, 0, 0, 0]), [2, 2]);
+      const result = await tensor.sum([1]);
+
+      expect(result.shape).toEqual([2]);
+      expect(Array.from(result.data)).toEqual([0, 0]);
+    });
+
+    it("should handle single element tensor", async () => {
+      const tensor = new Tensor(new Float32Array([5]), [1]);
+      const result = await tensor.sum([0]);
+
+      expect(result.shape).toEqual([1]);
+      expect(Array.from(result.data)[0]).toBe(5);
+    });
+
+    it("should handle summing along multiple dimensions", async () => {
+      const tensor = new Tensor(new Float32Array([1, 2, 3, 4, 5, 6]), [2, 3]);
+      const result = await tensor.sum([0, 1]);
+
+      expect(result.shape).toEqual([1]);
+      expect(Array.from(result.data)[0]).toBe(21); // sum of all elements
+    });
+
+    it("should maintain correct shape after summing", async () => {
+      const tensor = new Tensor(
+        new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9]),
+        [3, 3],
+      );
+      const result = await tensor.sum([0]);
+
+      expect(result.shape).toEqual([3]);
+      expect(Array.from(result.data)).toEqual([12, 15, 18]); // column sums
+    });
+  });
+
+  it("should default requires_grad to false", () => {
+    const data = new Float32Array([1, 2, 3, 4]);
+    const shape = [2, 2];
+    const tensor = new Tensor(data, shape);
+
+    expect(tensor.requires_grad).toBe(false);
+  });
+
+  it("should throw an error if the number of elements in data and shape are different", () => {
+    const data = new Float32Array([1, 2, 3]);
+    const shape = [2, 2];
+
+    expect(() => new Tensor(data, shape)).toThrow("Incompatible shapes");
+  });
+
+  describe("norm", () => {
+    it("should calculate Euclidean norm along default dimension", async () => {
+      const tensor = new Tensor(new Float32Array([3, 4]), [2]);
+      const [result] = await tensor.norm();
+
+      // Should be sqrt(3^2 + 4^2) = 5
+      expect(result.shape).toEqual([1]);
+      expect(Array.from(result.data)[0]).toBeCloseTo(5);
+    });
+
+    it("should handle 2D tensor Euclidean norm", async () => {
+      const tensor = new Tensor(new Float32Array([3, 4, 6, 8]), [2, 2]);
+      const [result] = await tensor.norm();
+
+      console.log("result:", result.data.toString());
+
+      // For dim=0: sqrt(3^2 + 6^2) and sqrt(4^2 + 8^2)
+      expect(result.shape).toEqual([2]);
+      expect(Array.from(result.data)[0]).toBeCloseTo(6.708203932499369); // sqrt(45)
+      expect(Array.from(result.data)[1]).toBeCloseTo(8.94427190999916); // sqrt(80)
+    });
+
+    it("should preserve requires_grad", async () => {
+      const tensor = new Tensor(new Float32Array([3, 4]), [2], true);
+      const [result] = await tensor.norm();
+
+      expect(result.requires_grad).toBe(true);
+    });
+
+    it("should handle tensor with all zeros", async () => {
+      const tensor = new Tensor(new Float32Array([0, 0, 0, 0]), [2, 2]);
+      const [result] = await tensor.norm();
+
+      expect(Array.from(result.data)).toEqual([0, 0]);
+    });
+
+    it("should handle 1D tensor with single element", async () => {
+      const tensor = new Tensor(new Float32Array([5]), [1]);
+      const [result] = await tensor.norm();
+
+      expect(result.shape).toEqual([1]);
+      expect(Array.from(result.data)[0]).toBe(5);
+    });
+  });
 });

From d5d579656e4afb7879c8493bd766be0818c32e6e Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Fri, 20 Dec 2024 10:03:27 -0500
Subject: [PATCH 22/23] fix: use args for variable number of inputs

---
 src/layers/module.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/module.ts b/src/layers/module.ts
index a7306d8..5a4a88b 100644
--- a/src/layers/module.ts
+++ b/src/layers/module.ts
@@ -15,5 +15,5 @@ export abstract class Module {
    * @param inputs - Input tensor(s) to the layer
    * @returns Output tensor(s) from the layer
    */
-  abstract forward(...inputs: [Tensor]): Promise<[Tensor]>;
+  abstract forward(...args: Tensor[]): Promise<[Tensor]>;
 }

From 9de39ad86f8b855166c2d3849a86a57e6b311c18 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Fri, 20 Dec 2024 11:46:41 -0500
Subject: [PATCH 23/23] feat: (almost) working nomic embed architecture

---
 src/autograd/function.ts              |  11 +-
 src/index.ts                          |   1 +
 src/layers/attention.ts               |   7 +-
 src/layers/embedding.ts               |   2 +-
 src/layers/linear.ts                  |   4 +-
 src/layers/norm.ts                    |   8 +-
 src/model/nomic_embed.ts              | 195 ++++++++++++++++++++++++++
 src/ops/add.ts                        |   5 +
 src/tensor/tensor.ts                  |  49 +++++--
 tests/integration/nomic_embed.test.ts | 127 +++++++++++++++++
 10 files changed, 379 insertions(+), 30 deletions(-)
 create mode 100644 src/model/nomic_embed.ts
 create mode 100644 tests/integration/nomic_embed.test.ts

diff --git a/src/autograd/function.ts b/src/autograd/function.ts
index b460d0c..5771c33 100644
--- a/src/autograd/function.ts
+++ b/src/autograd/function.ts
@@ -155,21 +155,22 @@ export abstract class BinaryOp extends AutogradFunction {
     pass.setPipeline(this.pipeline);
     pass.setBindGroup(0, bindGroup);
 
+    // TODO: set these as overrides in the layers/ops level since the kernels are different
     const WORKGROUP_SIZE = 16;
     const TILE_SIZE = 8;
-    const workgropuA = Math.ceil(a.shape[0] / (TILE_SIZE * WORKGROUP_SIZE));
-    const workgropuB = Math.ceil(b.shape[1] / (TILE_SIZE * WORKGROUP_SIZE));
+    const workgroupA = Math.ceil(a.shape[0] / (TILE_SIZE * WORKGROUP_SIZE));
+    const workgroupB = Math.ceil(b.shape[1] / (TILE_SIZE * WORKGROUP_SIZE));
     console.log(
       "a.shape[0]:",
       a.shape[0],
       "b.shape[1]:",
       b.shape[1],
       "launching workgroups",
-      workgropuA,
+      workgroupA,
       ",",
-      workgropuB,
+      workgroupB,
     );
-    pass.dispatchWorkgroups(workgropuA, workgropuB);
+    pass.dispatchWorkgroups(workgroupA, workgroupB);
     pass.end();
 
     const stagingBuffer = this.device.createBuffer({
diff --git a/src/index.ts b/src/index.ts
index db66762..4eb468d 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -15,3 +15,4 @@ export * from "./layers/linear.js";
 export * from "./layers/norm.js";
 export * from "./layers/mlp.js";
 export * from "./layers/attention.js";
+export * from "./model/nomic_embed.js"
\ No newline at end of file
diff --git a/src/layers/attention.ts b/src/layers/attention.ts
index 488e9e2..236f727 100644
--- a/src/layers/attention.ts
+++ b/src/layers/attention.ts
@@ -51,8 +51,11 @@ export class MultiHeadAttention extends Module {
   ): Promise<[Tensor, number]> {
     // Scale factor is 1/sqrt(head_dim)
     const scale = 1 / Math.sqrt(this.head_dim);
-    const scaleTensor = Tensor.full(query.shape, scale, false);
-
+    const scaleTensor = Tensor.full(
+      [query.shape[0], key.shape[0]],
+      scale,
+      false,
+    );
     // Compute attention scores
     const [scores] = await query.matmul(key.transpose());
     const [scaledScores] = await scores.mul(scaleTensor);
diff --git a/src/layers/embedding.ts b/src/layers/embedding.ts
index 54ae9ce..7608a5b 100644
--- a/src/layers/embedding.ts
+++ b/src/layers/embedding.ts
@@ -10,7 +10,7 @@ export class Embedding extends Module {
 
     this.vocab_size = vocab_size;
     this.emb_dim = emb_dim;
-    this.embedding = Tensor.randn([vocab_size, emb_dim], true);
+    this.embedding = Tensor.normal([vocab_size, emb_dim], true, 0.02);
   }
 
   async forward(...inputs: [Tensor]): Promise<[Tensor]> {
diff --git a/src/layers/linear.ts b/src/layers/linear.ts
index dcd0157..5aea6af 100644
--- a/src/layers/linear.ts
+++ b/src/layers/linear.ts
@@ -7,8 +7,8 @@ export class Linear extends Module {
   constructor(inputSize: number, outputSize: number) {
     super("linear");
 
-    this.weight = Tensor.randn([inputSize, outputSize], true);
-    this.bias = Tensor.randn([outputSize], true);
+    this.weight = Tensor.normal([inputSize, outputSize], true, 0.02);
+    this.bias = Tensor.full([outputSize], 0, true);
   }
 
   async forward(...inputs: [Tensor]): Promise<[Tensor]> {
diff --git a/src/layers/norm.ts b/src/layers/norm.ts
index ec1be60..8adda7d 100644
--- a/src/layers/norm.ts
+++ b/src/layers/norm.ts
@@ -21,18 +21,14 @@ export class LayerNorm extends Module {
 
     // Calculate mean and reshape for broadcasting
     const mean = await x.mean(reduction_dims);
+    console.log("mean.data", mean.data.toString());
     mean.shape = [mean.shape[0], 1]; // [2, 1]
 
     const variance = await x.variance(reduction_dims);
     variance.shape = [variance.shape[0], 1]; // [2, 1]
 
-    console.log("x shape:", x.shape); // [2, 3]
-    console.log("mean shape:", mean.shape); // [2, 1]
-    console.log("variance shape:", variance.shape); // [2, 1]
-    console.log("gamma shape:", this.gamma.shape); // [1, 3]
-    console.log("beta shape:", this.beta.shape); // [1, 3]
-
     const [numerator] = await x.sub(mean); // [2, 3]
+    console.log("numerator.data", numerator.data.toString());
     const [denominator] = await variance.add(this.eps);
     const sqrtDenom = await denominator.sqrt();
     const [normalized] = await numerator.div(sqrtDenom);
diff --git a/src/model/nomic_embed.ts b/src/model/nomic_embed.ts
new file mode 100644
index 0000000..7a64acc
--- /dev/null
+++ b/src/model/nomic_embed.ts
@@ -0,0 +1,195 @@
+import { Tensor } from "../tensor/tensor.js";
+import { Module } from "../layers/module.js";
+import { LayerNorm } from "../layers/norm.js";
+import { MultiHeadAttention } from "../layers/attention.js";
+import { MLP } from "../layers/mlp.js";
+import { Embedding } from "../layers/embedding.js";
+
+export interface NomicEmbedConfig {
+  vocab_size: number;
+  hidden_size: number;
+  num_hidden_layers: number;
+  num_attention_heads: number;
+  intermediate_size: number;
+  hidden_act: string;
+  hidden_dropout_prob: number;
+  attention_probs_dropout_prob: number;
+  max_position_embeddings: number;
+  type_vocab_size: number;
+  initializer_range: number;
+  layer_norm_eps: number;
+  pad_token_id: number;
+  position_embedding_type: string;
+  use_cache: boolean;
+  classifier_dropout: number | null;
+  rotary_emb_fraction: number;
+  use_flash_attn: boolean;
+  qkv_proj_bias: boolean;
+  mlp_fc1_bias: boolean;
+  mlp_fc2_bias: boolean;
+  causal: boolean;
+}
+
+class NomicBertEmbeddings extends Module {
+  private wordEmbeddings: Embedding;
+  private positionEmbeddings: Embedding | null;
+  private typeEmbeddings: Embedding | null;
+  private maxPositionEmbeddings: number;
+  private typeVocabSize: number;
+
+  constructor(config: NomicEmbedConfig) {
+    super("bert_embeddings");
+
+    // Word embeddings
+    this.wordEmbeddings = new Embedding(config.vocab_size, config.hidden_size);
+
+    // Position embeddings if using absolute positions
+    this.maxPositionEmbeddings = config.max_position_embeddings;
+    this.positionEmbeddings =
+      this.maxPositionEmbeddings > 0 && config.rotary_emb_fraction <= 0
+        ? new Embedding(config.max_position_embeddings, config.hidden_size)
+        : null;
+
+    // Token type embeddings if used
+    this.typeVocabSize = config.type_vocab_size;
+    this.typeEmbeddings =
+      this.typeVocabSize > 0
+        ? new Embedding(config.type_vocab_size, config.hidden_size)
+        : null;
+  }
+
+  async forward(
+    inputIds: Tensor,
+    positionIds?: Tensor,
+    tokenTypeIds?: Tensor,
+    inputsEmbeds?: Tensor,
+  ): Promise<[Tensor]> {
+    // Get word embeddings
+    let [embeddings] = inputsEmbeds
+      ? [inputsEmbeds]
+      : await this.wordEmbeddings.forward(inputIds);
+
+    // Add token type embeddings if used
+    // if (this.typeEmbeddings && this.typeVocabSize > 0 && tokenTypeIds) {
+    //   const [typeEmbeddings] = await this.typeEmbeddings.forward(tokenTypeIds);
+    //   console.log("typeEmbeddings.data", typeEmbeddings.data.toString());
+    //   console.log("typeEmbeddings.shape", typeEmbeddings.shape);
+    //   [embeddings] = await embeddings.add(typeEmbeddings);
+    // }
+
+    return [embeddings];
+  }
+}
+
+class NomicBertLayer extends Module {
+  private attention: MultiHeadAttention;
+  private mlp: MLP;
+  private layerNorm1: LayerNorm;
+  private layerNorm2: LayerNorm;
+
+  constructor(config: NomicEmbedConfig) {
+    super("bert_layer");
+    this.attention = new MultiHeadAttention(
+      config.hidden_size,
+      config.num_attention_heads,
+    );
+    this.mlp = new MLP(config.hidden_size, config.intermediate_size);
+    this.layerNorm1 = new LayerNorm(
+      [config.hidden_size],
+      config.layer_norm_eps,
+    );
+    this.layerNorm2 = new LayerNorm(
+      [config.hidden_size],
+      config.layer_norm_eps,
+    );
+  }
+
+  async forward(...inputs: [Tensor]): Promise<[Tensor]> {
+    // Self-attention
+    const [hiddenStates] = inputs;
+    const [normed1] = await this.layerNorm1.forward(hiddenStates);
+    const [attnOutput] = await this.attention.forward(normed1);
+    const [residual1] = await hiddenStates.add(attnOutput);
+
+    // MLP
+    const [normed2] = await this.layerNorm2.forward(residual1);
+    const [mlpOutput] = await this.mlp.forward(normed2);
+    const [residual2] = await residual1.add(mlpOutput);
+    return [residual2];
+  }
+}
+
+class NomicBertEncoder extends Module {
+  private layers: NomicBertLayer[];
+
+  constructor(config: NomicEmbedConfig) {
+    super("bert_encoder");
+    this.layers = Array(config.num_hidden_layers)
+      .fill(null)
+      .map(() => new NomicBertLayer(config));
+  }
+
+  async forward(...args: Tensor[]): Promise<[Tensor]> {
+    let [hiddenStates, attentionMask] = args;
+    let currentOutput = hiddenStates;
+
+    // Pass through each layer
+    for (const layer of this.layers) {
+      [currentOutput] = await layer.forward(currentOutput);
+    }
+
+    return [currentOutput];
+  }
+}
+
+export class NomicEmbed extends Module {
+  private embeddings: NomicBertEmbeddings;
+  private encoder: NomicBertEncoder;
+  private emb_ln: LayerNorm;
+
+  constructor(config: NomicEmbedConfig) {
+    super("nomic_embed");
+
+    // Initialize components
+    this.embeddings = new NomicBertEmbeddings(config);
+    this.encoder = new NomicBertEncoder(config);
+    this.emb_ln = new LayerNorm([config.hidden_size], config.layer_norm_eps);
+  }
+
+  private async meanPooling(
+    modelOutput: Tensor,
+    attentionMask: Tensor,
+  ): Promise<[Tensor]> {
+    return [await modelOutput.mean([0])];
+  }
+
+  async forward(...args: Tensor[]): Promise<[Tensor]> {
+    // Get embeddings
+    const [inputIds, attentionMask, positionIds, tokenTypeIds] = args;
+    const [hidden] = await this.embeddings.forward(
+      inputIds,
+      positionIds,
+      tokenTypeIds,
+    );
+    console.log("hidden.data", hidden.data.toString());
+
+    // Apply layer norm
+    const [normed] = await this.emb_ln.forward(hidden);
+    console.log("normed.data", normed.data.toString());
+
+    // Pass through encoder
+    const [encoded] = await this.encoder.forward(normed, attentionMask);
+    // Mean pooling
+    console.log("encoded.data", encoded.data.toString());
+    const [pooled] = await this.meanPooling(encoded, attentionMask);
+    console.log("pooled.shape", pooled.shape);
+
+    const [norm] = await pooled.norm(2, 0);
+    console.log("norm.shape", norm.shape);
+    console.log("norm", norm.data.toString());
+
+    const [pooledNormed] = await pooled.div(norm);
+    // Normalize embeddings
+    return [pooledNormed];
+  }
+}
diff --git a/src/ops/add.ts b/src/ops/add.ts
index 382601e..a0476fa 100644
--- a/src/ops/add.ts
+++ b/src/ops/add.ts
@@ -45,6 +45,11 @@ export class Add extends BinaryOp {
         );
       }
     }
+    console.log("add a.shape:", a.shape);
+    console.log("a.data:", a.data.toString());
+    console.log("add broadcasted b.shape:", b.shape);
+    console.log("b.data:", b.data.toString());
+
     return b;
   }
 
diff --git a/src/tensor/tensor.ts b/src/tensor/tensor.ts
index bf4c68b..ac90168 100644
--- a/src/tensor/tensor.ts
+++ b/src/tensor/tensor.ts
@@ -70,6 +70,20 @@ export class Tensor {
     return new Tensor(data, shape, requires_grad);
   }
 
+  static normal(
+    shape: number[],
+    requires_grad = false,
+    initializer_range = 0.01,
+  ) {
+    const data = new Float32Array(shape.reduce((a, b) => a * b));
+
+    for (let i = 0; i < data.length; i++) {
+      data[i] = Math.random() * 2 * initializer_range - initializer_range;
+    }
+
+    return new Tensor(data, shape, requires_grad);
+  }
+
   static broadcast(tensor: Tensor, size: number, requires_grad = false) {
     const shape = [size, ...tensor.shape];
     const data = new Float32Array(shape.reduce((a, b) => a * b));
@@ -102,6 +116,7 @@ export class Tensor {
     const negOne = Tensor.full(tensor.shape, -1, false);
 
     const [negTensor] = await tensor.mul(negOne);
+    console.log("this.shape", this.shape);
     return this.add(negTensor);
   }
 
@@ -305,21 +320,27 @@ export class Tensor {
   }
 
   async gather(indices: Tensor): Promise<[Tensor, number]> {
-    // Convert indices to one-hot
-    const oneHot = new Float32Array(indices.shape[0] * this.shape[0]).fill(0);
-    for (let i = 0; i < indices.shape[0]; i++) {
-      const index = indices.data[i] + i * this.shape[0];
-      // set one hot value for the whole vector
-      oneHot.fill(1, index, index + 1);
-    }
-
-    const oneHotTensor = new Tensor(
-      oneHot,
-      [indices.shape[0], this.shape[0]],
-      indices.requires_grad,
-    );
+    // For input shape [batch_size] and embedding matrix [vocab_size, embedding_dim]
+    // We want output shape [batch_size, embedding_dim]
+    const batchSize = indices.shape[0];
+    const embeddingDim = this.shape[1];
+    const result = new Float32Array(batchSize * embeddingDim);
+
+    // For each item in the batch
+    for (let i = 0; i < batchSize; i++) {
+      const tokenId = indices.data[i];
+      // Copy the entire embedding vector for this token
+      const sourceOffset = tokenId * embeddingDim;
+      const targetOffset = i * embeddingDim;
+      for (let j = 0; j < embeddingDim; j++) {
+        result[targetOffset + j] = this.data[sourceOffset + j];
+      }
+    }
 
-    return oneHotTensor.matmul(this);
+    return [
+      new Tensor(result, [batchSize, embeddingDim], indices.requires_grad),
+      -1,
+    ];
   }
 
   transpose() {
diff --git a/tests/integration/nomic_embed.test.ts b/tests/integration/nomic_embed.test.ts
new file mode 100644
index 0000000..432ea9c
--- /dev/null
+++ b/tests/integration/nomic_embed.test.ts
@@ -0,0 +1,127 @@
+import { test, expect } from "@playwright/test";
+
+test("NomicEmbed forward pass with known values", async ({ page }) => {
+  await page.goto("http://localhost:8080");
+
+  page.on("console", (msg) => {
+    console.log(msg);
+  });
+
+  // Inject test function
+  await page.evaluate(() => {
+    return new Promise((resolve) => {
+      // @ts-expect-error ignore error for tests
+      import("/dist/bundle.js").then((module) => {
+        const { Tensor, NomicEmbed } = module;
+
+        window.runNomicEmbedTest = async function () {
+          // Create configuration matching the HF config
+          const config = {
+            vocab_size: 30528,
+            hidden_size: 768,
+            num_hidden_layers: 2,
+            num_attention_heads: 2,
+            intermediate_size: 3072,
+            hidden_act: "swiglu",
+            hidden_dropout_prob: 0.0,
+            attention_probs_dropout_prob: 0.0,
+            max_position_embeddings: 8192,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: "rotary",
+            use_cache: true,
+            classifier_dropout: null,
+            rotary_emb_fraction: 1.0,
+            use_flash_attn: false,
+            qkv_proj_bias: false,
+            mlp_fc1_bias: false,
+            mlp_fc2_bias: false,
+            causal: false,
+          };
+
+          // Create sample input tensors
+          const seqLength = 1; // Small sequence for testing
+
+          // Create input IDs tensor with some token IDs
+          const inputIds = new Tensor(
+            new Float32Array([1]),
+            [seqLength],
+            false,
+          );
+
+          // Create attention mask (all 1s for no masking)
+          const attentionMask = new Tensor(
+            new Float32Array([1]),
+            [seqLength],
+            false,
+          );
+
+          // Create position IDs (optional)
+          const positionIds = new Tensor(
+            new Float32Array([0]),
+            [seqLength],
+            false,
+          );
+
+          // Create token type IDs (optional)
+          const tokenTypeIds = new Tensor(
+            new Float32Array([0]),
+            [seqLength],
+            false,
+          );
+
+          // Initialize model
+          const model = new NomicEmbed(config);
+
+          // Forward pass
+          const [output] = await model.forward(
+            inputIds,
+            attentionMask,
+            positionIds,
+            tokenTypeIds,
+          );
+
+          return {
+            inputShape: inputIds.shape,
+            outputShape: output.shape,
+            outputData: Array.from(output.data),
+          };
+        };
+        resolve();
+      });
+    });
+  });
+
+  // Run the test function in the browser context
+  const result = await page.evaluate(() => window.runNomicEmbedTest());
+
+  // Test input shape
+  expect(result.inputShape).toEqual([1]); // [sequence_length]
+
+  // Test output shape - should be [hidden_size] after pooling and normalization
+  expect(result.outputShape).toEqual([768]); // [hidden_size]
+
+  // Verify output is normalized (L2 norm should be close to 1)
+  const l2Norm = Math.sqrt(
+    result.outputData.reduce((sum, val) => sum + val * val, 0),
+  );
+  expect(l2Norm).toBeCloseTo(1, 6);
+
+  // Verify output values are within reasonable range
+  result.outputData.forEach((value) => {
+    expect(Math.abs(value)).toBeLessThan(1); // Normalized values should be < 1
+  });
+
+  await page.close();
+});
+
+declare global {
+  interface Window {
+    runNomicEmbedTest: () => Promise<{
+      inputShape: number[];
+      outputShape: number[];
+      outputData: number[];
+    }>;
+  }
}
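
A closing note on the gather rewrite in PATCH 23: the one-hot matmul formulation from PATCH 01 got its backward pass for free through the existing matmul autograd, while the direct row-copy above returns -1 in place of a gradient-function handle, so the embedding matrix no longer receives gradients. If a differentiable gather is needed later, its backward pass is a scatter-add of the output gradient into the gathered rows. The sketch below is illustrative only and not part of any patch; it assumes the Tensor fields (data, shape) and the Tensor.zeros helper already present in tensor.ts, and the gatherBackward name is hypothetical.

// Hypothetical sketch: scatter-add backward for the row-copy gather.
// gradOutput has shape [batchSize, embeddingDim] and indices has shape
// [batchSize]; the result has the shape of the embedding matrix, with each
// output row added back into the row it was gathered from. Duplicate
// indices accumulate, matching the gradient expectations in the PATCH 01
// gather test.
function gatherBackward(
  embeddings: Tensor, // [vocabSize, embeddingDim]
  indices: Tensor,    // [batchSize]
  gradOutput: Tensor, // [batchSize, embeddingDim]
): Tensor {
  const embeddingDim = embeddings.shape[1];
  const grad = Tensor.zeros(embeddings); // zeros-like helper from tensor.ts
  for (let i = 0; i < indices.shape[0]; i++) {
    const row = indices.data[i];
    for (let j = 0; j < embeddingDim; j++) {
      grad.data[row * embeddingDim + j] +=
        gradOutput.data[i * embeddingDim + j];
    }
  }
  return grad;
}

Similarly, meanPooling in nomic_embed.ts accepts an attentionMask but currently ignores it, so padded positions are averaged in. Standard sentence-embedding pooling zeroes out masked positions and divides by the count of unmasked tokens. A minimal sketch under the same assumptions, with a hypothetical maskedMeanPool name:

// Hypothetical sketch: mask-aware mean pooling over the sequence axis.
// hidden has shape [seqLength, hiddenDim]; mask has shape [seqLength]
// with 0/1 entries marking padded/real positions.
function maskedMeanPool(hidden: Tensor, mask: Tensor): Tensor {
  const [seqLength, hiddenDim] = hidden.shape;
  const out = new Float32Array(hiddenDim);
  let count = 0;
  for (let i = 0; i < seqLength; i++) {
    if (mask.data[i] === 0) continue; // skip padded positions
    count += 1;
    for (let j = 0; j < hiddenDim; j++) {
      out[j] += hidden.data[i * hiddenDim + j];
    }
  }
  const denom = Math.max(count, 1); // guard against an all-masked input
  for (let j = 0; j < hiddenDim; j++) {
    out[j] /= denom;
  }
  return new Tensor(out, [hiddenDim], hidden.requires_grad);
}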