diff --git a/src/index.ts b/src/index.ts index 004f0d5..33e77a3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,3 +10,4 @@ export * from "./ops/relu.js"; export * from "./autograd/function.js"; export * from "./layers/module.js"; export * from "./layers/embedding.js"; +export * from "./layers/linear.js"; diff --git a/src/layers/linear.ts b/src/layers/linear.ts new file mode 100644 index 0000000..196c4bd --- /dev/null +++ b/src/layers/linear.ts @@ -0,0 +1,20 @@ +import { Tensor } from "../tensor/tensor.js"; +import { Module } from "./module.js"; + +export class Linear extends Module { + weight: Tensor; + bias: Tensor; + + constructor(inputSize: number, outputSize: number) { + super("linear"); + this.weight = Tensor.randn([inputSize, outputSize], true); + this.bias = Tensor.randn([outputSize], true); + } + + async forward(...inputs: [Tensor]): Promise<[Tensor]> { + const [input] = inputs; + const [output] = await input.matmul(this.weight); + const [outputBias] = await output.add(this.bias); + return [outputBias]; + } +} \ No newline at end of file diff --git a/tests/integration/linear.test.ts b/tests/integration/linear.test.ts new file mode 100644 index 0000000..3e8a0c6 --- /dev/null +++ b/tests/integration/linear.test.ts @@ -0,0 +1,84 @@ +import { test, expect } from "@playwright/test"; + +test("Linear forward pass with known values", async ({ page }) => { + await page.goto("http://localhost:8080"); + + page.on("console", (msg) => { + console.log(msg); + }); + + // Inject test function + await page.evaluate(() => { + return new Promise((resolve) => { + // @ts-expect-error ignore error for tests + import("/dist/bundle.js").then((module) => { + const { Tensor, Linear } = module; + + window.runLinearTest = async function () { + const inputSize = 3; + const outputSize = 2; + + // Create linear layer + const linear = new Linear(inputSize, outputSize); + + // Set known weights and biases for deterministic testing + linear.weight.data = new Float32Array([ + 0.1, 0.2, // First row + 0.3, 0.4, // Second row + 0.5, 0.6 // Third row + ]); + + linear.bias.data = new Float32Array([0.1, 0.2]); + + // Create input tensor + const input = new Tensor( + new Float32Array([1.0, 2.0, 3.0]), // Sample input + [1, 3], // Batch size 1, input size 3 + false + ); + + // Forward pass + const [output] = await linear.forward(input); + + return { + inputData: Array.from(input.data), + weights: Array.from(linear.weight.data), + biases: Array.from(linear.bias.data), + outputShape: output.shape, + outputData: Array.from(output.data) + }; + }; + resolve(); + }); + }); + }); + + // Run the test function in the browser context + const result = await page.evaluate(() => window.runLinearTest()); + + // Validate shapes + expect(result.outputShape).toEqual([1, 2]); // Batch size x Output size + + // Calculate expected output manually: + // output[0] = (1.0 * 0.1 + 2.0 * 0.3 + 3.0 * 0.5) + 0.1 = 2.0 + // output[1] = (1.0 * 0.2 + 2.0 * 0.4 + 3.0 * 0.6) + 0.2 = 2.8 + const expectedOutput = [2.3, 3.0]; + + // Check if outputs match expected values within a small tolerance + expect(result.outputData[0]).toBeCloseTo(expectedOutput[0], 5); + expect(result.outputData[1]).toBeCloseTo(expectedOutput[1], 5); + + await page.close(); +}); + +declare global { + interface Window { + runLinearTest: () => Promise<{ + inputData: number[]; + weights: number[]; + biases: number[]; + outputShape: number[]; + outputData: number[]; + }>; + } +} \ No newline at end of file