Commit c77f913 (1 parent: 72b6259)
Showing 3 changed files with 261 additions and 0 deletions.
@@ -0,0 +1,129 @@
import { Tensor } from "../tensor/tensor.js";
import { Module } from "./module.js";
import { Linear } from "./linear.js";

type ActivationType = "relu" | "silu" | "gelu" | "swiglu" | "none";

export class MLP extends Module {
  up: Linear; // Project up to larger dimension
  down: Linear; // Project back down
  activation: ActivationType;

  constructor(
    dim: number, // input/output dimension
    hiddenDim: number, // hidden dimension
    activation: ActivationType = "relu",
  ) {
    super("mlp");

    // For SwiGLU, we need double the hidden dimension for gating
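    // (The up projection emits [gate | value] concatenated along the last
    // axis; applyActivation splits it back apart, so the down projection
    // still only ever sees hiddenDim features.)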
    const actualHiddenDim = activation === "swiglu" ? hiddenDim * 2 : hiddenDim;

    this.up = new Linear(dim, actualHiddenDim);
    this.down = new Linear(hiddenDim, dim);
    this.activation = activation;
  }

  private async gelu(x: Tensor): Promise<[Tensor, number]> {
    // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
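    // This is the tanh approximation of the exact GELU, x * Φ(x); the
    // constants sqrt(2/π) and 0.044715 come from the original GELU paper
    // (Hendrycks & Gimpel, 2016).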
    const sqrt2OverPi = Math.sqrt(2 / Math.PI);

    // Calculate x^3
    const [xSquared] = await x.mul(x);
    const [xCubed] = await xSquared.mul(x);

    // Calculate 0.044715 * x^3
    const [scaledCube] = await xCubed.mul(
      Tensor.full(x.shape, 0.044715, false),
    );

    // Add x to the scaled cube
    const [innerSum] = await x.add(scaledCube);

    // Multiply by sqrt(2/π)
    const [scaled] = await innerSum.mul(
      Tensor.full(x.shape, sqrt2OverPi, false),
    );

    // Calculate tanh using (e^x - e^-x)/(e^x + e^-x)
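    // (tanh is assembled from exp here, presumably because Tensor has no
    // native tanh op; note e^x can overflow to Infinity for large inputs.)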
    const [exp] = await scaled.exp();
    const [negScaled] = await scaled.mul(Tensor.full(x.shape, -1, false));
    const [negExp] = await negScaled.exp();

    const [numerator] = await exp.sub(negExp);
    const [denominator] = await exp.add(negExp);

    const [tanh] = await numerator.div(denominator);

    // Add 1 to tanh result
    const [tanhPlusOne] = await tanh.add(Tensor.full(x.shape, 1, false));

    // Multiply by x
    const [xTimesSum] = await x.mul(tanhPlusOne);

    // Multiply by 0.5 for final result
    return xTimesSum.mul(Tensor.full(x.shape, 0.5, false));
  }

  private async silu(x: Tensor): Promise<[Tensor, number]> {
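    // SiLU(x) = x * sigmoid(x), with sigmoid built as 1 / (1 + e^(-x))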
    const [negX] = await x.mul(Tensor.full(x.shape, -1, false));
    const [expNegX] = await negX.exp();
    const [onePlusExpNegX] = await expNegX.add(Tensor.full(x.shape, 1, false));

    const [sigmoid] = await Tensor.full(x.shape, 1, false).div(onePlusExpNegX);
    return x.mul(sigmoid);
  }

  private async applyActivation(x: Tensor): Promise<[Tensor, number]> {
    switch (this.activation) {
      case "relu":
        return x.relu();
      case "silu":
        return this.silu(x);
      case "gelu":
        return this.gelu(x);
      case "swiglu": {
        // Split the tensor in half for gate and value paths
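        // SwiGLU([a | b]) = SiLU(a) ⊙ b, where a and b are the two halves
        // of the doubled-width up projection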
        const halfSize = Math.floor(x.shape[x.shape.length - 1] / 2);
        const [gate, value] = await Promise.all([
          x.slice(":", [0, halfSize]),
          x.slice(":", [halfSize, x.shape[x.shape.length - 1]]),
        ]);
        const [gateActivated] = await this.silu(gate);
        return gateActivated.mul(value);
      }
      case "none":
        return [x, -1];
      default:
        throw new Error(`Unknown activation type: ${this.activation}`);
    }
  }

  async forward(...inputs: [Tensor]): Promise<[Tensor]> {
    const [input] = inputs;

    // Project up to hidden dimension
    const [upProjected] = await this.up.forward(input);

    // Apply activation
    const [activated] = await this.applyActivation(upProjected);

    // Project back down
    return this.down.forward(activated);
  }

  // Helper method for creating standard configurations
  static create(config: {
    dim: number; // input/output dimension
    hiddenMul?: number; // multiplier for hidden dimension (default 4)
    activation?: ActivationType;
  }): MLP {
    const {
      dim,
      hiddenMul = 4, // typical transformer uses 4x dimension for FFN
      activation = "relu",
    } = config;

    return new MLP(dim, dim * hiddenMul, activation);
  }
}
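
For orientation, a minimal usage sketch (the values here are hypothetical; it assumes the same Tensor constructor the test below exercises, where the final boolean is presumably a requires-grad flag):

// Hypothetical: a small SwiGLU MLP over a [2, 4] input
const mlp = MLP.create({ dim: 4, hiddenMul: 2, activation: "swiglu" });

const input = new Tensor(
  new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),
  [2, 4], // [seq_length, dim]
  false, // no gradient tracking (assumed meaning of the flag)
);

const [output] = await mlp.forward(input); // output.shape === [2, 4]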
@@ -0,0 +1,131 @@
import { test, expect } from "@playwright/test";

test("MLP with SwiGLU activation forward pass with known values", async ({ | ||
page, | ||
}) => { | ||
await page.goto("http://localhost:8080"); | ||
|
||
page.on("console", (msg) => { | ||
console.log(msg); | ||
}); | ||
|
||
  // Inject test function
  await page.evaluate(() => {
    return new Promise<void>((resolve) => {
      // @ts-expect-error ignore error for tests
      import("/dist/bundle.js").then((module) => {
        const { Tensor, MLP } = module;

        window.runSwiGLUTest = async function () {
          // Create sample input tensor with known values
          const inputDim = 4;
          const hiddenDim = 8; // Will be doubled internally for SwiGLU
          const seqLength = 2;

          const input = new Tensor(
            new Float32Array([
              0.1, 0.2, 0.3, 0.4, // First sequence
              0.5, 0.6, 0.7, 0.8, // Second sequence
            ]),
            [seqLength, inputDim],
            false,
          );

          // Create MLP with SwiGLU activation
          const mlp = new MLP(inputDim, hiddenDim, "swiglu");

          // Set known weights and biases for reproducibility
          mlp.up.weight = new Tensor(
            new Float32Array([
              // First half for gate
              0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
              0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
              0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
              0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1,
              // Second half for value
              0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
              0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
              0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
              0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
            ]),
            [inputDim, hiddenDim * 2],
            true,
          );

          mlp.up.bias = new Tensor(
            new Float32Array([
              // Gate bias
              0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
              // Value bias
              0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
            ]),
            [hiddenDim * 2],
            true,
          );

          mlp.down.weight = new Tensor(
            new Float32Array([
              0.1, 0.2, 0.3, 0.4,
              0.2, 0.3, 0.4, 0.5,
              0.3, 0.4, 0.5, 0.6,
              0.4, 0.5, 0.6, 0.7,
              0.5, 0.6, 0.7, 0.8,
              0.6, 0.7, 0.8, 0.9,
              0.7, 0.8, 0.9, 1.0,
              0.8, 0.9, 1.0, 1.1,
            ]),
            [hiddenDim, inputDim],
            true,
          );

          mlp.down.bias = new Tensor(
            new Float32Array([0.1, 0.1, 0.1, 0.1]),
            [inputDim],
            true,
          );

          // Forward pass
          const [output] = await mlp.forward(input);

          return {
            inputShape: input.shape,
            inputData: Array.from(input.data),
            outputShape: output.shape,
            outputData: Array.from(output.data),
          };
        };
        resolve();
      });
    });
  });

  // Run the test function in the browser context
  const result = await page.evaluate(() => window.runSwiGLUTest());

  // Validate shapes
  expect(result.inputShape).toEqual([2, 4]); // [seq_length, input_dim]
  expect(result.outputShape).toEqual([2, 4]); // [seq_length, input_dim]
  console.log("result.outputData:", result.outputData.toString());

  // Expected values computed using the same architecture with PyTorch
  const expectedOutput = [
    0.7809, 0.9126, 1.0443, 1.176, 5.0712, 5.9646, 6.8581, 7.7515,
  ];
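  // (A hypothetical reproduction: two torch.nn.Linear layers carrying these
  // exact weights and biases, with F.silu(gate) * value for the SwiGLU step;
  // the reference script itself is not part of this commit.)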

  // Validate output values
  result.outputData.forEach((value, idx) => {
    expect(value).toBeCloseTo(expectedOutput[idx], 4);
  });

  await page.close();
});

declare global {
  interface Window {
    runSwiGLUTest: () => Promise<{
      inputShape: number[];
      inputData: number[];
      outputShape: number[];
      outputData: number[];
    }>;
  }
}