Commit
feat: swiglu mlp
zanussbaum committed Dec 20, 2024
1 parent 72b6259 commit c77f913
Showing 3 changed files with 261 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/index.ts
@@ -12,4 +12,5 @@ export * from "./layers/module.js";
export * from "./layers/embedding.js";
export * from "./layers/linear.js";
export * from "./layers/norm.js";
export * from "./layers/mlp.js";
export * from "./ops/div.js";
129 changes: 129 additions & 0 deletions src/layers/mlp.ts
@@ -0,0 +1,129 @@
import { Tensor } from "../tensor/tensor.js";
import { Module } from "./module.js";
import { Linear } from "./linear.js";

type ActivationType = "relu" | "silu" | "gelu" | "swiglu" | "none";
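// Note on "swiglu": the up projection produces a concatenated [gate | value]
// tensor, and the activation returns SiLU(gate) * value, halving the width
// back to hiddenDim (see applyActivation below).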

export class MLP extends Module {
up: Linear; // Project up to larger dimension
down: Linear; // Project back down
activation: ActivationType;

constructor(
dim: number, // input/output dimension
hiddenDim: number, // hidden dimension
activation: ActivationType = "relu",
) {
super("mlp");

// For SwiGLU, we need double the hidden dimension for gating
const actualHiddenDim = activation === "swiglu" ? hiddenDim * 2 : hiddenDim;

this.up = new Linear(dim, actualHiddenDim);
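// The down projection always consumes hiddenDim: for SwiGLU, the gating
// multiplies the two halves together, collapsing 2 * hiddenDim back to hiddenDim.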
this.down = new Linear(hiddenDim, dim);
this.activation = activation;
}

private async gelu(x: Tensor): Promise<[Tensor, number]> {
// GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
const sqrt2OverPi = Math.sqrt(2 / Math.PI);

// Calculate x^3
const [xSquared] = await x.mul(x);
const [xCubed] = await xSquared.mul(x);

// Calculate 0.044715 * x^3
const [scaledCube] = await xCubed.mul(
Tensor.full(x.shape, 0.044715, false),
);

// Add x to the scaled cube
const [innerSum] = await x.add(scaledCube);

// Multiply by sqrt(2/π)
const [scaled] = await innerSum.mul(
Tensor.full(x.shape, sqrt2OverPi, false),
);

// Calculate tanh using (e^x - e^-x)/(e^x + e^-x)
const [exp] = await scaled.exp();
const [negScaled] = await scaled.mul(Tensor.full(x.shape, -1, false));
const [negExp] = await negScaled.exp();

const [numerator] = await exp.sub(negExp);
const [denominator] = await exp.add(negExp);

const [tanh] = await numerator.div(denominator);

// Add 1 to tanh result
const [tanhPlusOne] = await tanh.add(Tensor.full(x.shape, 1, false));

// Multiply by x
const [xTimesSum] = await x.mul(tanhPlusOne);

// Multiply by 0.5 for final result
return xTimesSum.mul(Tensor.full(x.shape, 0.5, false));
}

private async silu(x: Tensor): Promise<[Tensor, number]> {
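// SiLU(x) = x * sigmoid(x) = x / (1 + e^(-x)), built from the mul/exp/add/div primitives.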
const [negX] = await x.mul(Tensor.full(x.shape, -1, false));
const [expNegX] = await negX.exp();
const [onePlusExpNegX] = await expNegX.add(Tensor.full(x.shape, 1, false));

const [sigmoid] = await Tensor.full(x.shape, 1, false).div(onePlusExpNegX);
return x.mul(sigmoid);
}

private async applyActivation(x: Tensor): Promise<[Tensor, number]> {
switch (this.activation) {
case "relu":
return x.relu();
case "silu":
return this.silu(x);
case "gelu":
return this.gelu(x);
case "swiglu": {
// Split the tensor in half for gate and value paths
const halfSize = Math.floor(x.shape[x.shape.length - 1] / 2);
const [gate, value] = await Promise.all([
x.slice(":", [0, halfSize]),
x.slice(":", [halfSize, x.shape[x.shape.length - 1]]),
]);
const [gateActivated] = await this.silu(gate);
return gateActivated.mul(value);
}
case "none":
return [x, -1];
default:
throw new Error(`Unknown activation type: ${this.activation}`);
}
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
const [input] = inputs;

// Project up to hidden dimension
const [upProjected] = await this.up.forward(input);

// Apply activation
const [activated] = await this.applyActivation(upProjected);

// Project back down
return this.down.forward(activated);
}

// Helper method for creating standard configurations
static create(config: {
dim: number; // input/output dimension
hiddenMul?: number; // multiplier for hidden dimension (default 4)
activation?: ActivationType;
}): MLP {
const {
dim,
hiddenMul = 4, // typical transformer uses 4x dimension for FFN
activation = "relu",
} = config;

return new MLP(dim, dim * hiddenMul, activation);
}
}
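
For reference, a minimal usage sketch of the new layer (not part of this commit). It assumes Tensor and MLP are both exported from the built bundle, that the Tensor constructor takes (data, shape, requiresGrad) as in the integration test below, and that all ops resolve asynchronously; the import path is an assumption.

// Hypothetical usage sketch (not part of the commit); import path is assumed.
import { Tensor, MLP } from "../dist/bundle.js";

async function demo(): Promise<void> {
  // Two tokens with model dimension 4; hiddenDim 8 is doubled to 16
  // internally by the up projection because of the SwiGLU gate.
  const input = new Tensor(
    new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),
    [2, 4],
    false,
  );

  const mlp = new MLP(4, 8, "swiglu");
  const [output] = await mlp.forward(input);
  console.log(output.shape); // [2, 4]: the output keeps the input dimension
}

void demo();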
131 changes: 131 additions & 0 deletions tests/integration/mlp.test.ts
@@ -0,0 +1,131 @@
import { test, expect } from "@playwright/test";

test("MLP with SwiGLU activation forward pass with known values", async ({
page,
}) => {
await page.goto("http://localhost:8080");

page.on("console", (msg) => {
console.log(msg);
});

// Inject test function
await page.evaluate(() => {
return new Promise<void>((resolve) => {
// @ts-expect-error ignore error for tests
import("/dist/bundle.js").then((module) => {
const { Tensor, MLP } = module;

window.runSwiGLUTest = async function () {
// Create sample input tensor with known values
const inputDim = 4;
const hiddenDim = 8; // Will be doubled internally for SwiGLU
const seqLength = 2;

const input = new Tensor(
new Float32Array([
0.1,
0.2,
0.3,
0.4, // First sequence
0.5,
0.6,
0.7,
0.8, // Second sequence
]),
[seqLength, inputDim],
false,
);

// Create MLP with SwiGLU activation
const mlp = new MLP(inputDim, hiddenDim, "swiglu");

// Set known weights and biases for reproducibility
mlp.up.weight = new Tensor(
new Float32Array([
// First half for gate
0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.2, 0.3, 0.4, 0.5, 0.6,
0.7, 0.8, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.4, 0.5,
0.6, 0.7, 0.8, 0.9, 1.0, 1.1,
// Second half for value
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2,
0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.4, 0.4,
0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
]),
[inputDim, hiddenDim * 2],
true,
);

mlp.up.bias = new Tensor(
new Float32Array([
// Gate bias
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
// Value bias
0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
]),
[hiddenDim * 2],
true,
);

mlp.down.weight = new Tensor(
new Float32Array([
0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.3, 0.4, 0.5, 0.6, 0.4,
0.5, 0.6, 0.7, 0.5, 0.6, 0.7, 0.8, 0.6, 0.7, 0.8, 0.9, 0.7, 0.8,
0.9, 1.0, 0.8, 0.9, 1.0, 1.1,
]),
[hiddenDim, inputDim],
true,
);

mlp.down.bias = new Tensor(
new Float32Array([0.1, 0.1, 0.1, 0.1]),
[inputDim],
true,
);

// Forward pass
const [output] = await mlp.forward(input);

return {
inputShape: input.shape,
inputData: Array.from(input.data),
outputShape: output.shape,
outputData: Array.from(output.data),
};
};
resolve();
});
});
});

// Run the test function in the browser context
const result = await page.evaluate(() => window.runSwiGLUTest());

// Validate shapes
expect(result.inputShape).toEqual([2, 4]); // [seq_length, input_dim]
expect(result.outputShape).toEqual([2, 4]); // [seq_length, input_dim]
console.log("result.outputData:", result.outputData.toString());

// Expected values computed using the same architecture with PyTorch
const expectedOutput = [
0.7809, 0.9126, 1.0443, 1.176, 5.0712, 5.9646, 6.8581, 7.7515,
];

// Validate output values
result.outputData.forEach((value, idx) => {
expect(value).toBeCloseTo(expectedOutput[idx], 4);
});

await page.close();
});

declare global {
interface Window {
runSwiGLUTest: () => Promise<{
inputShape: number[];
inputData: number[];
outputShape: number[];
outputData: number[];
}>;
}
}
