Skip to content

Commit

Permalink
Deploying to gh-pages from @ ed1ece2 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Jul 24, 2024
1 parent 43ac827 commit 76a2428
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 24 deletions.
63 changes: 40 additions & 23 deletions nnotepad/js/nnotepad.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ export class ComputeError extends Error {
// General WebNN Utilities
// ============================================================

const kArgTypeOperandList = 1;
const kArgTypeNonOperand = 2;
const kArgTypeOperand = 3;

class WebNNUtil {
static bufferForOperand(operand) {
const size = [...operand.shape()].reduce((a, b) => a * b, 1);
Expand Down Expand Up @@ -60,21 +64,22 @@ class WebNNUtil {
throw new Error(`Unsupported dataType ${type}`);
}

static isNonOperandArg(name, index) {
static argumentType(name, index) {
return ({
concat: [0, 1],
expand: [1],
gru: [3, 4],
gruCell: [4],
lstm: [3, 4],
lstmCell: [5],
pad: [1, 2],
reshape: [1],
slice: [1, 2],
softmax: [1], // TODO: Distinguish overloads
split: [1],
concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand},
expand: {1: kArgTypeNonOperand},
gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
gruCell: {4: kArgTypeNonOperand},
lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
lstmCell: {5: kArgTypeNonOperand},
pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
reshape: {1: kArgTypeNonOperand},
slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
softmax: {1: kArgTypeNonOperand},
split: {1: kArgTypeNonOperand},
})[name]
?.includes(index);
?.[index] ||
kArgTypeOperand;
}
}

Expand Down Expand Up @@ -379,7 +384,7 @@ export class NNotepad {
}
throw new Error(`unexpected line type: ${line.type}`);
}
function serializeExpr(expr, nonOperand = false) {
function serializeExpr(expr, argumentType = kArgTypeOperand) {
if (expr.op) {
if (expr.lhs) {
return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${
Expand All @@ -394,11 +399,21 @@ export class NNotepad {
case 'boolean':
return String(expr.value);
case 'number':
return nonOperand ? Util.stringify(expr.value) :
serializeScalar(expr.value, expr.dataType);
switch (argumentType) {
case kArgTypeNonOperand:
return Util.stringify(expr.value);
default:
return serializeScalar(expr.value, expr.dataType);
}
case 'array':
return nonOperand ? serializeArray(expr.value) :
serializeTensor(expr.value, expr.dataType);
switch (argumentType) {
case kArgTypeNonOperand:
return serializeArray(expr.value, kArgTypeNonOperand);
case kArgTypeOperandList:
return serializeArray(expr.value, kArgTypeOperand);
default:
return serializeTensor(expr.value, expr.dataType);
}
case 'dict':
return serializeDict(expr.dict);
case 'identifier':
Expand All @@ -414,7 +429,7 @@ export class NNotepad {
.map((k) => {
const v = dict[k];
k = Util.stringify(k);
return `${k}: ${serializeExpr(v, true)}`;
return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`;
})
.join(', ') +
'}';
Expand Down Expand Up @@ -465,8 +480,10 @@ export class NNotepad {
elements.map((n) => Util.stringifyNumber(n, dataType)).join(',')}]))`;
}

function serializeArray(array) {
return '[' + array.map((expr) => serializeExpr(expr)).join(', ') + ']';
function serializeArray(array, argumentType) {
return '[' +
array.map((expr) => serializeExpr(expr, argumentType)).join(', ') +
']';
}

function serializeCall(name, args) {
Expand Down Expand Up @@ -506,8 +523,8 @@ export class NNotepad {

return `_.${name}(${
args.map(
(arg, index) => serializeExpr(
arg, WebNNUtil.isNonOperandArg(name, index)))
(arg, index) =>
serializeExpr(arg, WebNNUtil.argumentType(name, index)))
.join(', ')})`;
}
}
Expand Down
15 changes: 14 additions & 1 deletion nnotepad/js/tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,27 @@ document.addEventListener('DOMContentLoaded', async (e) => {
{dataType: 'float32', shape: [2], buffer: [3, 4]},
]);

Harness.section('Multiple input tensors');
Harness.section('Non-operand arguments: array of operands');
await test(
`A = [1,2] B = [3,4] concat([A,B], 0)`,
{dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]});
await test(
`concat([identity([1,2]),identity([3,4])], 0)`,
{dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]});

Harness.section('Non-operand arguments: array of numbers');
await test(
`T = [[1,2,3],[4,5,6]] reshape(T, [1, 3, 2, 1])`,
{dataType: 'float32', shape: [1, 3, 2, 1], buffer: [1, 2, 3, 4, 5, 6]});
await test(
`expand([1], [2, 2])`,
{dataType: 'float32', shape: [2, 2], buffer: [1, 1, 1, 1]});

Harness.section('Non-operand arguments: simple numbers');
await test(
`softmax([1], 0)`,
{dataType: 'float32', shape: [1], buffer: [1]});

Harness.section('Regression tests');
await test(
`concat([[1,2],[3,4]], 0)`,
Expand Down

0 comments on commit 76a2428

Please sign in to comment.