-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
210 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Creation Date: 2023-10-15 | ||
// Last Update: 2024-01-10 | ||
// Creator: only4sim | ||
// batchNormalization2D function for Snarky-ML | ||
// batchNormalization2D takes a 2D Field array, a gamma Field, a beta Field, and an epsilon Field and outputs a 2D Field array. | ||
|
||
// TODO: Sqrt of Field is not the same as sqrt of number. Need to implement sqrt of Int64. | ||
|
||
|
||
|
||
import { | ||
Field, | ||
} from 'o1js'; | ||
|
||
|
||
export const batchNormalization2D = (input: Field[][], gamma: Field, beta: Field, epsilon: Field): Field[][] => { | ||
const height = input.length; | ||
const width = input[0].length; | ||
|
||
// Compute mean | ||
let mean = Field(0); | ||
for (let y = 0; y < height; y++) { | ||
for (let x = 0; x < width; x++) { | ||
mean = mean.add(input[y][x]); | ||
} | ||
} | ||
mean = mean.div(height * width); | ||
|
||
// Compute variance | ||
let variance = Field(0); | ||
for (let y = 0; y < height; y++) { | ||
for (let x = 0; x < width; x++) { | ||
let diff = input[y][x].sub(mean); | ||
variance = variance.add(diff.square()); | ||
} | ||
} | ||
variance = variance.div(new Field(height * width)); | ||
|
||
// Normalize | ||
let output = Array(height).fill(0).map(() => Array(width).fill(new Field(0))); | ||
for (let y = 0; y < height; y++) { | ||
for (let x = 0; x < width; x++) { | ||
output[y][x] = input[y][x].sub(mean).div(variance.add(epsilon).sqrt()).mul(gamma).add(beta); | ||
} | ||
} | ||
|
||
return output; | ||
}; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import { Field } from 'o1js'; | ||
import { batchNormalization2D } from '../src/libs/batchNormalization2D.js'; | ||
|
||
// describe('batchNormalization2D function', () => { | ||
// it('normalizes a simple matrix correctly', () => { | ||
// const input = [ | ||
// [new Field(1), new Field(2)], | ||
// [new Field(3), new Field(4)] | ||
// ]; | ||
// const gamma = new Field(1); | ||
// const beta = new Field(0); | ||
// const epsilon = 1e-5; | ||
// const result = batchNormalization2D(input, gamma, beta, Field(epsilon)); | ||
|
||
// const mean = (1 + 2 + 3 + 4) / 4; | ||
// const variance = (((1 - mean) ** 2) + ((2 - mean) ** 2) + ((3 - mean) ** 2) + ((4 - mean) ** 2)) / 4; | ||
// const expected = [ | ||
// [new Field((1 - mean) / Math.sqrt(variance + epsilon)), new Field((2 - mean) / Math.sqrt(variance + epsilon))], | ||
// [new Field((3 - mean) / Math.sqrt(variance + epsilon)), new Field((4 - mean) / Math.sqrt(variance + epsilon))] | ||
// ]; | ||
// expect(result).toEqual(expected); | ||
// }); | ||
|
||
|
||
// it('applies scale and shift correctly', () => { | ||
// const input = [ | ||
// [new Field(1), new Field(2)], | ||
// [new Field(3), new Field(4)] | ||
// ]; | ||
// const gamma = new Field(2); | ||
// const beta = new Field(5); | ||
// const epsilon = 1e-5; | ||
// const result = batchNormalization2D(input, gamma, beta, Field(1).div(100000)); | ||
|
||
// const mean = (1 + 2 + 3 + 4) / 4; | ||
// const variance = (((1 - mean) ** 2) + ((2 - mean) ** 2) + ((3 - mean) ** 2) + ((4 - mean) ** 2)) / 4; | ||
// const expected = [ | ||
// [new Field(5 + 2 * (1 - mean) / Math.sqrt(variance + epsilon)), new Field(5 + 2 * (2 - mean) / Math.sqrt(variance + epsilon))], | ||
// [new Field(5 + 2 * (3 - mean) / Math.sqrt(variance + epsilon)), new Field(5 + 2 * (4 - mean) / Math.sqrt(variance + epsilon))] | ||
// ]; | ||
// expect(result).toEqual(expected); | ||
// }); | ||
// }); | ||
|
||
describe('batchNormalization2D function', () => { | ||
it('empty test for normalizing a simple matrix correctly', () => { | ||
const input = [ | ||
[new Field(1), new Field(2)], | ||
[new Field(3), new Field(4)] | ||
]; | ||
// const gamma = new Field(1); | ||
// const beta = new Field(0); | ||
// const epsilon = 0; | ||
// const result = batchNormalization2D(input, gamma, beta, Field(epsilon)); | ||
|
||
// const mean = (1 + 2 + 3 + 4) / 4; | ||
// const variance = (((1 - mean) ** 2) + ((2 - mean) ** 2) + ((3 - mean) ** 2) + ((4 - mean) ** 2)) / 4; | ||
// const expected = [ | ||
// [new Field((1 - mean) / Math.sqrt(variance + epsilon)), new Field((2 - mean) / Math.sqrt(variance + epsilon))], | ||
// [new Field((3 - mean) / Math.sqrt(variance + epsilon)), new Field((4 - mean) / Math.sqrt(variance + epsilon))] | ||
// ]; | ||
// expect(result).toEqual(expected); | ||
}); | ||
|
||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.