From 2a6e749ecee92c6f5d3711e8c678f280a9744edf Mon Sep 17 00:00:00 2001 From: Nikita Masych <92444221+NikitaMasych@users.noreply.github.com> Date: Tue, 29 Aug 2023 19:35:06 +0300 Subject: [PATCH] feat: add sha3 primitive (#817) * refactor: keccakf1600 accepts and returns [25]uints.U64 instead of [25]frontend.Variable * feat: add sha3 primitive * fix: sha3 test current hash global elimination => moved to struct --- std/hash/sha3/doc.go | 5 ++ std/hash/sha3/hashes.go | 94 ++++++++++++++++++++++++++ std/hash/sha3/sha3.go | 91 +++++++++++++++++++++++++ std/hash/sha3/sha3_test.go | 90 ++++++++++++++++++++++++ std/permutation/keccakf/keccak_test.go | 19 ++++-- std/permutation/keccakf/keccakf.go | 25 ++----- 6 files changed, 298 insertions(+), 26 deletions(-) create mode 100644 std/hash/sha3/doc.go create mode 100644 std/hash/sha3/hashes.go create mode 100644 std/hash/sha3/sha3.go create mode 100644 std/hash/sha3/sha3_test.go diff --git a/std/hash/sha3/doc.go b/std/hash/sha3/doc.go new file mode 100644 index 0000000000..84a8d0f340 --- /dev/null +++ b/std/hash/sha3/doc.go @@ -0,0 +1,5 @@ +// Package sha3 provides ZKP circuits for SHA3 hash algorithms applying sponge construction over +// Keccak f-[1600] permutation function. +// +// Instances correspond golang.org/x/crypto/sha3, except SHA224, which is not x64 compatible. +package sha3 diff --git a/std/hash/sha3/hashes.go b/std/hash/sha3/hashes.go new file mode 100644 index 0000000000..b2282f9108 --- /dev/null +++ b/std/hash/sha3/hashes.go @@ -0,0 +1,94 @@ +package sha3 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/uints" +) + +// New256 creates a new SHA3-256 hash. +// Its generic security strength is 256 bits against preimage attacks, +// and 128 bits against collision attacks. +func New256(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U64](api) + if err != nil { + return nil, err + } + return &digest{ + uapi: uapi, + state: newState(), + dsbyte: 0x06, + rate: 136, + outputLen: 32, + }, nil +} + +// New384 creates a new SHA3-384 hash. +// Its generic security strength is 384 bits against preimage attacks, +// and 192 bits against collision attacks. +func New384(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U64](api) + if err != nil { + return nil, err + } + return &digest{ + uapi: uapi, + state: newState(), + dsbyte: 0x06, + rate: 104, + outputLen: 48, + }, nil +} + +// New512 creates a new SHA3-512 hash. +// Its generic security strength is 512 bits against preimage attacks, +// and 256 bits against collision attacks. +func New512(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U64](api) + if err != nil { + return nil, err + } + return &digest{ + uapi: uapi, + state: newState(), + dsbyte: 0x06, + rate: 72, + outputLen: 64, + }, nil +} + +// NewLegacyKeccak256 creates a new Keccak-256 hash. +// +// Only use this function if you require compatibility with an existing cryptosystem +// that uses non-standard padding. All other users should use New256 instead. +func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U64](api) + if err != nil { + return nil, err + } + return &digest{ + uapi: uapi, + state: newState(), + dsbyte: 0x01, + rate: 136, + outputLen: 32, + }, nil +} + +// NewLegacyKeccak512 creates a new Keccak-512 hash. +// +// Only use this function if you require compatibility with an existing cryptosystem +// that uses non-standard padding. All other users should use New512 instead. +func NewLegacyKeccak512(api frontend.API) (hash.BinaryHasher, error) { + uapi, err := uints.New[uints.U64](api) + if err != nil { + return nil, err + } + return &digest{ + uapi: uapi, + state: newState(), + dsbyte: 0x01, + rate: 72, + outputLen: 64, + }, nil +} diff --git a/std/hash/sha3/sha3.go b/std/hash/sha3/sha3.go new file mode 100644 index 0000000000..76cd2c8a0a --- /dev/null +++ b/std/hash/sha3/sha3.go @@ -0,0 +1,91 @@ +package sha3 + +import ( + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/std/permutation/keccakf" +) + +type digest struct { + uapi *uints.BinaryField[uints.U64] + state [25]uints.U64 // 1600 bits state: 25 x 64 + in []uints.U8 // input to be digested + dsbyte byte // dsbyte contains the "domain separation" bits and the first bit of the padding + rate int // the number of bytes of state to use + outputLen int // the default output size in bytes +} + +func (d *digest) Write(in []uints.U8) { + d.in = append(d.in, in...) +} + +func (d *digest) Size() int { return d.outputLen } + +func (d *digest) Reset() { + d.in = nil + d.state = newState() +} + +func (d *digest) Sum() []uints.U8 { + padded := d.padding() + blocks := d.composeBlocks(padded) + d.absorbing(blocks) + return d.squeezeBlocks() +} + +func (d *digest) padding() []uints.U8 { + padded := make([]uints.U8, len(d.in)) + copy(padded[:], d.in[:]) + + switch q := d.rate - (len(padded) % d.rate); q { + case 1: + padded = append(padded, uints.NewU8(d.dsbyte^0x80)) + case 2: + padded = append(padded, uints.NewU8(d.dsbyte)) + padded = append(padded, uints.NewU8(0x80)) + default: + padded = append(padded, uints.NewU8(d.dsbyte)) + padded = append(padded, uints.NewU8Array(make([]uint8, q-2))...) + padded = append(padded, uints.NewU8(0x80)) + } + + return padded +} + +func (d *digest) composeBlocks(padded []uints.U8) [][]uints.U64 { + blocks := make([][]uints.U64, len(padded)/d.rate) + + for i := range blocks { + block := make([]uints.U64, d.rate/8) + for j := range block { + u64 := padded[j*8 : j*8+8] + block[j] = d.uapi.PackLSB(u64...) + } + blocks[i] = block + padded = padded[d.rate:] + } + + return blocks +} + +func (d *digest) absorbing(blocks [][]uints.U64) { + for _, block := range blocks { + for i := range block { + d.state[i] = d.uapi.Xor(d.state[i], block[i]) + } + d.state = keccakf.Permute(d.uapi, d.state) + } +} + +func (d *digest) squeezeBlocks() (result []uints.U8) { + for i := 0; i < d.outputLen/8; i++ { + result = append(result, d.uapi.UnpackLSB(d.state[i])...) + } + return +} + +func newState() (state [25]uints.U64) { + for i := range state { + state[i] = uints.NewU64(0) + } + return +} diff --git a/std/hash/sha3/sha3_test.go b/std/hash/sha3/sha3_test.go new file mode 100644 index 0000000000..0336746519 --- /dev/null +++ b/std/hash/sha3/sha3_test.go @@ -0,0 +1,90 @@ +package sha3 + +import ( + "crypto/rand" + "fmt" + "hash" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + zkhash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" + "golang.org/x/crypto/sha3" +) + +type testCase struct { + zk func(api frontend.API) (zkhash.BinaryHasher, error) + native func() hash.Hash +} + +var testCases = map[string]testCase{ + "SHA3-256": {New256, sha3.New256}, + "SHA3-384": {New384, sha3.New384}, + "SHA3-512": {New512, sha3.New512}, + "Keccak-256": {NewLegacyKeccak256, sha3.NewLegacyKeccak256}, + "Keccak-512": {NewLegacyKeccak512, sha3.NewLegacyKeccak512}, +} + +type sha3Circuit struct { + In []uints.U8 + Expected []uints.U8 + + hasher string +} + +func (c *sha3Circuit) Define(api frontend.API) error { + newHasher, ok := testCases[c.hasher] + if !ok { + return fmt.Errorf("hash function unknown: %s", c.hasher) + } + h, err := newHasher.zk(api) + if err != nil { + return err + } + uapi, err := uints.New[uints.U64](api) + if err != nil { + return err + } + + h.Write(c.In) + res := h.Sum() + + for i := range c.Expected { + uapi.ByteAssertEq(c.Expected[i], res[i]) + } + return nil +} + +func TestSHA3(t *testing.T) { + assert := test.NewAssert(t) + in := make([]byte, 310) + _, err := rand.Reader.Read(in) + assert.NoError(err) + + for name := range testCases { + assert.Run(func(assert *test.Assert) { + name := name + strategy := testCases[name] + h := strategy.native() + h.Write(in) + expected := h.Sum(nil) + + circuit := &sha3Circuit{ + In: make([]uints.U8, len(in)), + Expected: make([]uints.U8, len(expected)), + hasher: name, + } + + witness := &sha3Circuit{ + In: uints.NewU8Array(in), + Expected: uints.NewU8Array(expected), + } + + if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil { + t.Fatalf("%s: %s", name, err) + } + }, name) + } +} diff --git a/std/permutation/keccakf/keccak_test.go b/std/permutation/keccakf/keccak_test.go index b7151a893c..ec08257069 100644 --- a/std/permutation/keccakf/keccak_test.go +++ b/std/permutation/keccakf/keccak_test.go @@ -1,6 +1,7 @@ package keccakf_test import ( + "github.com/consensys/gnark/std/math/uints" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -11,20 +12,24 @@ import ( ) type keccakfCircuit struct { - In [25]frontend.Variable - Expected [25]frontend.Variable `gnark:",public"` + In [25]uints.U64 + Expected [25]uints.U64 `gnark:",public"` } func (c *keccakfCircuit) Define(api frontend.API) error { - var res [25]frontend.Variable + var res [25]uints.U64 for i := range res { res[i] = c.In[i] } + uapi, err := uints.New[uints.U64](api) + if err != nil { + return err + } for i := 0; i < 2; i++ { - res = keccakf.Permute(api, res) + res = keccakf.Permute(uapi, res) } for i := range res { - api.AssertIsEqual(res[i], c.Expected[i]) + uapi.AssertEq(res[i], c.Expected[i]) } return nil } @@ -41,8 +46,8 @@ func TestKeccakf(t *testing.T) { } witness := keccakfCircuit{} for i := range nativeIn { - witness.In[i] = nativeIn[i] - witness.Expected[i] = res[i] + witness.In[i] = uints.NewU64(nativeIn[i]) + witness.Expected[i] = uints.NewU64(res[i]) } assert := test.NewAssert(t) assert.ProverSucceeded(&keccakfCircuit{}, &witness, diff --git a/std/permutation/keccakf/keccakf.go b/std/permutation/keccakf/keccakf.go index 8f5e3ae346..48aded9131 100644 --- a/std/permutation/keccakf/keccakf.go +++ b/std/permutation/keccakf/keccakf.go @@ -11,7 +11,6 @@ package keccakf import ( - "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -50,24 +49,12 @@ var piln = [24]int{ 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1, } -// Permute applies Keccak-F permutation on the input a and returns the permuted -// vector. The input array must consist of 64-bit (unsigned) integers. The -// returned array also contains 64-bit unsigned integers. -func Permute(api frontend.API, a [25]frontend.Variable) [25]frontend.Variable { - var in [25]uints.U64 - uapi, err := uints.New[uints.U64](api) - if err != nil { - panic(err) // TODO: return error instead - } - for i := range a { - in[i] = uapi.ValueOf(a[i]) - } - res := permute(uapi, in) - var out [25]frontend.Variable - for i := range out { - out[i] = uapi.ToValue(res[i]) - } - return out +// Permute applies Keccak-F permutation on the input and returns the permuted vector. +// Original input is not modified. +func Permute(uapi *uints.BinaryField[uints.U64], input [25]uints.U64) [25]uints.U64 { + var state [25]uints.U64 + copy(state[:], input[:]) + return permute(uapi, state) } func permute(uapi *uints.BinaryField[uints.U64], st [25]uints.U64) [25]uints.U64 {