feat: nomic embed webgpu architecture #37

Draft · wants to merge 23 commits into base: main
2 changes: 1 addition & 1 deletion package.json
@@ -10,7 +10,7 @@
"start": "http-server -c-1",
"build-clean": "npm run clean && npm run build && npm run build-bundle",
"dev": "npm run build-clean && npm run start",
"prettier": "prettier --write .",
"prettier": "prettier --write tests/**/* src/**/*.ts",
"build-bundle": "esbuild src/index.ts --bundle --outfile=dist/bundle.js --format=esm --target=es2020",
"unit": "npm run build-clean && node --experimental-vm-modules node_modules/jest/bin/jest.js",
"integration": "npm run build-clean && npx playwright test",
11 changes: 6 additions & 5 deletions src/autograd/function.ts
@@ -155,21 +155,22 @@ export abstract class BinaryOp extends AutogradFunction {
pass.setPipeline(this.pipeline);
pass.setBindGroup(0, bindGroup);

+ // TODO: set these as overrides in the layers/ops level since the kernels are different
const WORKGROUP_SIZE = 16;
const TILE_SIZE = 8;
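// Assumes the matmul kernel tiles the output so each workgroup covers
// TILE_SIZE * WORKGROUP_SIZE = 128 rows/columns, hence the ceil-divides below.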
- const workgropuA = Math.ceil(a.shape[0] / (TILE_SIZE * WORKGROUP_SIZE));
- const workgropuB = Math.ceil(b.shape[1] / (TILE_SIZE * WORKGROUP_SIZE));
+ const workgroupA = Math.ceil(a.shape[0] / (TILE_SIZE * WORKGROUP_SIZE));
+ const workgroupB = Math.ceil(b.shape[1] / (TILE_SIZE * WORKGROUP_SIZE));
console.log(
"a.shape[0]:",
a.shape[0],
"b.shape[1]:",
b.shape[1],
"launching workgroups",
- workgropuA,
+ workgroupA,
",",
- workgropuB,
+ workgroupB,
);
- pass.dispatchWorkgroups(workgropuA, workgropuB);
+ pass.dispatchWorkgroups(workgroupA, workgroupB);
pass.end();

const stagingBuffer = this.device.createBuffer({
8 changes: 8 additions & 0 deletions src/index.ts
@@ -7,4 +7,12 @@ export * from "./ops/exp2.js";
export * from "./ops/log2.js";
export * from "./ops/ln.js";
export * from "./ops/relu.js";
export * from "./ops/div.js";
export * from "./autograd/function.js";
export * from "./layers/module.js";
export * from "./layers/embedding.js";
export * from "./layers/linear.js";
export * from "./layers/norm.js";
export * from "./layers/mlp.js";
export * from "./layers/attention.js";
export * from "./model/nomic_embed.js"
111 changes: 111 additions & 0 deletions src/layers/attention.ts
@@ -0,0 +1,111 @@
import { Tensor } from "../tensor/tensor.js";
import { Module } from "./module.js";
import { Linear } from "./linear.js";

export class MultiHeadAttention extends Module {
qkv: Linear; // Combined projection for Query, Key, Value
output: Linear; // Output projection
num_heads: number;
head_dim: number;
hidden_dim: number;

constructor(hidden_dim: number, num_heads: number) {
super("multihead_attention");

this.num_heads = num_heads;
this.head_dim = Math.floor(hidden_dim / num_heads);
this.hidden_dim = hidden_dim;

if (this.head_dim * num_heads !== hidden_dim) {
throw new Error(
`Hidden dimension ${hidden_dim} must be divisible by number of heads ${num_heads}`,
);
}

// Combined QKV projection
// Projects to 3x hidden_dim for Q, K, V
this.qkv = new Linear(hidden_dim, hidden_dim * 3);

// Output projection
this.output = new Linear(hidden_dim, hidden_dim);
}

private async reshapeToHeads(tensor: Tensor): Promise<Tensor[]> {
const heads: Tensor[] = [];

// Each head will be (seqlen, head_dim)
for (let i = 0; i < this.num_heads; i++) {
const start = i * this.head_dim;
const end = start + this.head_dim;
const head = await tensor.slice(":", [start, end]);
heads.push(head);
}

return heads;
}

private async scaledDotProductAttention(
query: Tensor,
key: Tensor,
value: Tensor,
): Promise<[Tensor, number]> {
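// Computes softmax(Q · K^T / sqrt(head_dim)) · V for a single head.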
// Scale factor is 1/sqrt(head_dim)
const scale = 1 / Math.sqrt(this.head_dim);
const scaleTensor = Tensor.full(
[query.shape[0], key.shape[0]],
scale,
false,
);
// Compute attention scores
const [scores] = await query.matmul(key.transpose());
const [scaledScores] = await scores.mul(scaleTensor);

// Naive softmax along the key dimension (no max-subtraction, so very large scores can overflow)
const [expScores] = await scaledScores.exp();
const sumExp = await expScores.sum([1]);

const [attention] = await expScores.div(sumExp);

// Apply attention to values
return attention.matmul(value);
}

async forward(input: Tensor): Promise<[Tensor]> {
// Project input to Q, K, V
const [qkv] = await this.qkv.forward(input);

// Split into Q, K, V
const query = await qkv.slice(":", [0, this.hidden_dim]);
const key = await qkv.slice(":", [this.hidden_dim, this.hidden_dim * 2]);
const value = await qkv.slice(":", [
this.hidden_dim * 2,
this.hidden_dim * 3,
]);

// Split each of Q, K, V into heads
const queryHeads = await this.reshapeToHeads(query);
const keyHeads = await this.reshapeToHeads(key);
const valueHeads = await this.reshapeToHeads(value);

// Compute attention for each head
// TODO: looping over heads is slow; a batched matmul (bmm) op would run all heads in one dispatch
const headOutputs: Tensor[] = [];
for (let i = 0; i < this.num_heads; i++) {
const [headOutput] = await this.scaledDotProductAttention(
queryHeads[i],
keyHeads[i],
valueHeads[i],
);
headOutputs.push(headOutput);
}

// Concatenate heads
let concatOutput = headOutputs[0];
for (let i = 1; i < headOutputs.length; i++) {
concatOutput = await concatOutput.concat(headOutputs[i], 1);
}

// Final output projection
return this.output.forward(concatOutput);
}
}
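A minimal usage sketch for the layer above. The `Tensor.normal(shape, requires_grad, std)` factory and the `forward` signature follow the other layers in this PR; the concrete sizes are illustrative, not fixed by the architecture.

```ts
import { Tensor } from "../tensor/tensor.js";
import { MultiHeadAttention } from "./attention.js";

// 8 tokens, hidden size 64, split across 4 heads of 16 dims each (illustrative sizes).
const attention = new MultiHeadAttention(64, 4);

// Random stand-in for a (seqLen, hidden_dim) sequence of token embeddings.
const input = Tensor.normal([8, 64], false, 0.02);

// Output keeps the (seqLen, hidden_dim) shape; each row mixes information from all positions.
const [output] = await attention.forward(input);
```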
121 changes: 121 additions & 0 deletions src/layers/embedding.ts
@@ -0,0 +1,121 @@
import { Tensor } from "../tensor/tensor.js";
import { Module } from "./module.js";

export class Embedding extends Module {
vocab_size: number;
emb_dim: number;
embedding: Tensor;
constructor(vocab_size: number, emb_dim: number) {
super("embedding");

this.vocab_size = vocab_size;
this.emb_dim = emb_dim;
this.embedding = Tensor.normal([vocab_size, emb_dim], true, 0.02);
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
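// gather looks up one row of the (vocab_size, emb_dim) table for each token id in the input.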
const [embeddings] = await this.embedding.gather(inputs[0]);
return [embeddings];
}
}

export class RotaryEmbedding extends Module {
base: number;
dimension: number;
theta: Tensor;
sequenceLength: number;
idxTheta: Tensor | null = null;
constructor(base: number, dimension: number) {
super("rope_embedding");
this.base = base;
this.dimension = dimension;

const theta = this.createTheta(dimension, base);
this.theta = new Tensor(theta, [1, dimension / 2], true);

this.sequenceLength = 0;
this.idxTheta = null;
}

createTheta(dimension: number, base: number = 10000): Float32Array {
// One theta value per pair of embedding dimensions (dimension / 2 entries)
const result = new Float32Array(dimension / 2);

// theta_k = base^(-2k / dimension), the standard RoPE frequency schedule
for (let i = 0; i < dimension; i += 2) {
const value = 1.0 / Math.pow(base, i / dimension);
result[i / 2] = value;
}

return result;
}

async buildCache(sequenceLength: number) {
const posIdx = new Float32Array(sequenceLength);
for (let i = 0; i < sequenceLength; i++) {
posIdx[i] = i;
}

const posTensor = new Tensor(posIdx, [sequenceLength, 1], true);
let [idxTheta] = await posTensor.matmul(this.theta);

idxTheta = await idxTheta.concat(idxTheta, 1);

return [idxTheta];
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
const [x] = inputs;

const currSeqLen = x.shape[0];
const d2 = Math.floor(this.dimension / 2);

if (currSeqLen > this.sequenceLength || this.idxTheta === null) {
const [cache] = await this.buildCache(currSeqLen);
this.sequenceLength = currSeqLen;
this.idxTheta = cache;
}

const idxTheta = this.idxTheta;

const idxThetaLength = idxTheta.data.length;
const cosIdxThetaArr = new Float32Array(idxThetaLength);
const sinIdxThetaArr = new Float32Array(idxThetaLength);

for (let i = 0; i < idxThetaLength; i++) {
cosIdxThetaArr[i] = Math.cos(idxTheta.data[i]);
sinIdxThetaArr[i] = Math.sin(idxTheta.data[i]);
}

const cosIdxTheta = new Tensor(
cosIdxThetaArr,
[currSeqLen, this.dimension],
x.requires_grad,
);
const sinIdxTheta = new Tensor(
sinIdxThetaArr,
[currSeqLen, this.dimension],
x.requires_grad,
);

// Rewrite using tensor operations and select
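// Rotation: rope = x * cos(idxTheta) + rotate_half(x) * sin(idxTheta),
// where rotate_half maps the two halves [x1, x2] of the rope dims to [-x2, x1].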
const leftHalf = await x.slice(":", [null, d2]);
const rightHalf = await x.slice(":", [d2, this.dimension]);
const [negHalf] = await rightHalf.mul(Tensor.full([1], -1));

const half = await negHalf.concat(leftHalf, 1);
const xRope = await x.slice(":", [null, this.dimension]);

const [xRopePos] = await xRope.mul(cosIdxTheta);
const [xRopeNeg] = await half.mul(sinIdxTheta);

let [rope] = await xRopePos.add(xRopeNeg);
if (this.dimension < x.shape[1]) {
const xPass = await x.slice(":", [null, null, d2]);

rope = await rope.concat(xPass, 1);
}

return [rope];
}
}
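A small sketch of applying the rotary embedding to one head's queries. The base of 10000 and a head dimension of 64 are conventional RoPE defaults used here for illustration; the cos/sin cache is rebuilt lazily whenever a longer sequence arrives.

```ts
import { Tensor } from "../tensor/tensor.js";
import { RotaryEmbedding } from "./embedding.js";

// Rotary embedding over a 64-dim head, base 10000 (the usual RoPE default).
const rope = new RotaryEmbedding(10000, 64);

// A (seqLen, head_dim) block of queries for one attention head.
const query = Tensor.normal([16, 64], false, 0.02);

// forward() builds and caches the cos/sin tables for 16 positions, then rotates each pair of dims.
const [rotatedQuery] = await rope.forward(query);
```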
20 changes: 20 additions & 0 deletions src/layers/linear.ts
@@ -0,0 +1,20 @@
import { Tensor } from "../tensor/tensor.js";
import { Module } from "./module.js";

export class Linear extends Module {
weight: Tensor;
bias: Tensor;

constructor(inputSize: number, outputSize: number) {
super("linear");
this.weight = Tensor.normal([inputSize, outputSize], true, 0.02);
this.bias = Tensor.full([outputSize], 0, true);
}

async forward(...inputs: [Tensor]): Promise<[Tensor]> {
const [input] = inputs;
const [output] = await input.matmul(this.weight);
const [outputBias] = await output.add(this.bias);
return [outputBias];
}
}
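Finally, a short sketch of the Linear layer in isolation: output = input.matmul(weight) + bias, with the bias broadcast across rows (sizes are illustrative).

```ts
import { Tensor } from "../tensor/tensor.js";
import { Linear } from "./linear.js";

// Maps 128 input features to 64 output features.
const layer = new Linear(128, 64);

// A batch of 32 rows; forward returns a (32, 64) tensor.
const input = Tensor.normal([32, 128], false, 0.02);
const [output] = await layer.forward(input);
```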