From c77f91342fe7dc134f5eacee1cd7b2a487786c94 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Thu, 19 Dec 2024 19:08:44 -0500
Subject: [PATCH] feat: swiglu mlp

---
 src/index.ts                  |   1 +
 src/layers/mlp.ts             | 129 +++++++++++++++++++++++++++++++++
 tests/integration/mlp.test.ts | 131 ++++++++++++++++++++++++++++++++++
 3 files changed, 261 insertions(+)
 create mode 100644 src/layers/mlp.ts
 create mode 100644 tests/integration/mlp.test.ts

diff --git a/src/index.ts b/src/index.ts
index c0a965d..fdbf50f 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -12,4 +12,5 @@ export * from "./layers/module.js";
 export * from "./layers/embedding.js";
 export * from "./layers/linear.js";
 export * from "./layers/norm.js";
+export * from "./layers/mlp.js";
 export * from "./ops/div.js";
diff --git a/src/layers/mlp.ts b/src/layers/mlp.ts
new file mode 100644
index 0000000..a338091
--- /dev/null
+++ b/src/layers/mlp.ts
@@ -0,0 +1,129 @@
+import { Tensor } from "../tensor/tensor.js";
+import { Module } from "./module.js";
+import { Linear } from "./linear.js";
+
+type ActivationType = "relu" | "silu" | "gelu" | "swiglu" | "none";
+
+export class MLP extends Module {
+  up: Linear; // Project up to larger dimension
+  down: Linear; // Project back down
+  activation: ActivationType;
+
+  constructor(
+    dim: number, // input/output dimension
+    hiddenDim: number, // hidden dimension
+    activation: ActivationType = "relu",
+  ) {
+    super("mlp");
+
+    // For SwiGLU, we need double the hidden dimension for gating
+    const actualHiddenDim = activation === "swiglu" ? hiddenDim * 2 : hiddenDim;
+
+    this.up = new Linear(dim, actualHiddenDim);
+    this.down = new Linear(hiddenDim, dim);
+    this.activation = activation;
+  }
+
+  private async gelu(x: Tensor): Promise<[Tensor, number]> {
+    // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+    const sqrt2OverPi = Math.sqrt(2 / Math.PI);
+
+    // Calculate x^3
+    const [xSquared] = await x.mul(x);
+    const [xCubed] = await xSquared.mul(x);
+
+    // Calculate 0.044715 * x^3
+    const [scaledCube] = await xCubed.mul(
+      Tensor.full(x.shape, 0.044715, false),
+    );
+
+    // Add x to the scaled cube
+    const [innerSum] = await x.add(scaledCube);
+
+    // Multiply by sqrt(2/π)
+    const [scaled] = await innerSum.mul(
+      Tensor.full(x.shape, sqrt2OverPi, false),
+    );
+
+    // Calculate tanh using (e^x - e^-x)/(e^x + e^-x)
+    const [exp] = await scaled.exp();
+    const [negScaled] = await scaled.mul(Tensor.full(x.shape, -1, false));
+    const [negExp] = await negScaled.exp();
+
+    const [numerator] = await exp.sub(negExp);
+    const [denominator] = await exp.add(negExp);
+
+    const [tanh] = await numerator.div(denominator);
+
+    // Add 1 to tanh result
+    const [tanhPlusOne] = await tanh.add(Tensor.full(x.shape, 1, false));
+
+    // Multiply by x
+    const [xTimesSum] = await x.mul(tanhPlusOne);
+
+    // Multiply by 0.5 for final result
+    return xTimesSum.mul(Tensor.full(x.shape, 0.5, false));
+  }
+
+  private async silu(x: Tensor): Promise<[Tensor, number]> {
+    const [negX] = await x.mul(Tensor.full(x.shape, -1, false));
+    const [expNegX] = await negX.exp();
+    const [onePlusExpNegX] = await expNegX.add(Tensor.full(x.shape, 1, false));
+
+    const [sigmoid] = await Tensor.full(x.shape, 1, false).div(onePlusExpNegX);
+    return x.mul(sigmoid);
+  }
+
+  private async applyActivation(x: Tensor): Promise<[Tensor, number]> {
+    switch (this.activation) {
+      case "relu":
+        return x.relu();
+      case "silu":
+        return this.silu(x);
+      case "gelu":
+        return this.gelu(x);
+      case "swiglu": {
+        // Split the tensor in half for gate and value paths
+        const halfSize = Math.floor(x.shape[x.shape.length - 1] / 2);
+        const [gate, value] = await Promise.all([
+          x.slice(":", [0, halfSize]),
+          x.slice(":", [halfSize, x.shape[x.shape.length - 1]]),
+        ]);
+        const [gateActivated] = await this.silu(gate);
+        return gateActivated.mul(value);
+      }
+      case "none":
+        return [x, -1];
+      default:
+        throw new Error(`Unknown activation type: ${this.activation}`);
+    }
+  }
+
+  async forward(...inputs: [Tensor]): Promise<[Tensor]> {
+    const [input] = inputs;
+
+    // Project up to hidden dimension
+    const [upProjected] = await this.up.forward(input);
+
+    // Apply activation
+    const [activated] = await this.applyActivation(upProjected);
+
+    // Project back down
+    return this.down.forward(activated);
+  }
+
+  // Helper method for creating standard configurations
+  static create(config: {
+    dim: number; // input/output dimension
+    hiddenMul?: number; // multiplier for hidden dimension (default 4)
+    activation?: ActivationType;
+  }): MLP {
+    const {
+      dim,
+      hiddenMul = 4, // typical transformer uses 4x dimension for FFN
+      activation = "relu",
+    } = config;
+
+    return new MLP(dim, dim * hiddenMul, activation);
+  }
+}
diff --git a/tests/integration/mlp.test.ts b/tests/integration/mlp.test.ts
new file mode 100644
index 0000000..e329c1b
--- /dev/null
+++ b/tests/integration/mlp.test.ts
@@ -0,0 +1,131 @@
+import { test, expect } from "@playwright/test";
+
+test("MLP with SwiGLU activation forward pass with known values", async ({
+  page,
+}) => {
+  await page.goto("http://localhost:8080");
+
+  page.on("console", (msg) => {
+    console.log(msg);
+  });
+
+  // Inject test function
+  await page.evaluate(() => {
+    return new Promise((resolve) => {
+      // @ts-expect-error ignore error for tests
+      import("/dist/bundle.js").then((module) => {
+        const { Tensor, MLP } = module;
+
+        window.runSwiGLUTest = async function () {
+          // Create sample input tensor with known values
+          const inputDim = 4;
+          const hiddenDim = 8; // Will be doubled internally for SwiGLU
+          const seqLength = 2;
+
+          const input = new Tensor(
+            new Float32Array([
+              0.1,
+              0.2,
+              0.3,
+              0.4, // First sequence
+              0.5,
+              0.6,
+              0.7,
+              0.8, // Second sequence
+            ]),
+            [seqLength, inputDim],
+            false,
+          );
+
+          // Create MLP with SwiGLU activation
+          const mlp = new MLP(inputDim, hiddenDim, "swiglu");
+
+          // Set known weights and biases for reproducibility
+          mlp.up.weight = new Tensor(
+            new Float32Array([
+              // First half for gate
+              0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.2, 0.3, 0.4, 0.5, 0.6,
+              0.7, 0.8, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.4, 0.5,
+              0.6, 0.7, 0.8, 0.9, 1.0, 1.1,
+              // Second half for value
+              0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2,
+              0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.4, 0.4,
+              0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
+            ]),
+            [inputDim, hiddenDim * 2],
+            true,
+          );
+
+          mlp.up.bias = new Tensor(
+            new Float32Array([
+              // Gate bias
+              0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
+              // Value bias
+              0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
+            ]),
+            [hiddenDim * 2],
+            true,
+          );
+
+          mlp.down.weight = new Tensor(
+            new Float32Array([
+              0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.3, 0.4, 0.5, 0.6, 0.4,
+              0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.8, 0.6, 0.7, 0.8, 0.9, 0.7, 0.8,
+              0.9, 1.0, 0.8, 0.9, 1.0, 1.1,
+            ]),
+            [hiddenDim, inputDim],
+            true,
+          );
+
+          mlp.down.bias = new Tensor(
+            new Float32Array([0.1, 0.1, 0.1, 0.1]),
+            [inputDim],
+            true,
+          );
+
+          // Forward pass
+          const [output] = await mlp.forward(input);
+
+          return {
+            inputShape: input.shape,
+            inputData: Array.from(input.data),
+            outputShape: output.shape,
+            outputData: Array.from(output.data),
+          };
+        };
+        resolve();
+      });
+    });
+  });
+
+  // Run the test function in the browser context
+  const result = await page.evaluate(() => window.runSwiGLUTest());
+
+  // Validate shapes
+  expect(result.inputShape).toEqual([2, 4]); // [batch_size, input_dim]
+  expect(result.outputShape).toEqual([2, 4]); // [batch_size, input_dim]
+  console.log("result.outputData:", result.outputData.toString());
+
+  // Expected values computed using the same architecture with PyTorch
+  const expectedOutput = [
+    0.7809, 0.9126, 1.0443, 1.176, 5.0712, 5.9646, 6.8581, 7.7515,
+  ];
+
+  // Validate output values
+  result.outputData.forEach((value, idx) => {
+    expect(value).toBeCloseTo(expectedOutput[idx], 4);
+  });
+
+  await page.close();
+});
+
+declare global {
+  interface Window {
+    runSwiGLUTest: () => Promise<{
+      inputShape: number[];
+      inputData: number[];
+      outputShape: number[];
+      outputData: number[];
+    }>;
+  }
+}
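
A minimal usage sketch of the new layer outside the Playwright harness (not part of the patch). It assumes the package entry point re-exports Tensor and MLP as src/index.ts does above, that the import path resolves in your module setup, and that the code runs in an ES-module context where top-level await is available.

import { Tensor, MLP } from "./src/index.js"; // assumed import path; adjust to your build output

// dim = 4, hiddenMul = 2 -> hiddenDim = 8; the constructor doubles the up
// projection to 16 internally for the SwiGLU gate and value halves.
const mlp = MLP.create({ dim: 4, hiddenMul: 2, activation: "swiglu" });

// One sequence of four features, no gradient tracking.
const input = new Tensor(new Float32Array([0.1, 0.2, 0.3, 0.4]), [1, 4], false);

// Up-projection, SwiGLU gating, down-projection back to dim.
const [output] = await mlp.forward(input); // shape: [1, 4]
console.log(Array.from(output.data));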