Skip to content

Commit

Permalink
feat: add sha3 primitive (#817)
Browse files Browse the repository at this point in the history
* refactor: keccakf1600 accepts and returns [25]uints.U64 instead of [25]frontend.Variable

* feat: add sha3 primitive

* fix: sha3 test current hash global elimination => moved to struct
  • Loading branch information
NikitaMasych authored Aug 29, 2023
1 parent 0896fe1 commit 2a6e749
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 26 deletions.
5 changes: 5 additions & 0 deletions std/hash/sha3/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Package sha3 provides ZKP circuits for SHA3 hash algorithms applying sponge construction over
// Keccak f-[1600] permutation function.
//
// Instances correspond golang.org/x/crypto/sha3, except SHA224, which is not x64 compatible.
package sha3
94 changes: 94 additions & 0 deletions std/hash/sha3/hashes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sha3

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/uints"
)

// New256 creates a new SHA3-256 hash.
// Its generic security strength is 256 bits against preimage attacks,
// and 128 bits against collision attacks.
func New256(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
uapi: uapi,
state: newState(),
dsbyte: 0x06,
rate: 136,
outputLen: 32,
}, nil
}

// New384 creates a new SHA3-384 hash.
// Its generic security strength is 384 bits against preimage attacks,
// and 192 bits against collision attacks.
func New384(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
uapi: uapi,
state: newState(),
dsbyte: 0x06,
rate: 104,
outputLen: 48,
}, nil
}

// New512 creates a new SHA3-512 hash.
// Its generic security strength is 512 bits against preimage attacks,
// and 256 bits against collision attacks.
func New512(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
uapi: uapi,
state: newState(),
dsbyte: 0x06,
rate: 72,
outputLen: 64,
}, nil
}

// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
uapi: uapi,
state: newState(),
dsbyte: 0x01,
rate: 136,
outputLen: 32,
}, nil
}

// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U64](api)
if err != nil {
return nil, err
}
return &digest{
uapi: uapi,
state: newState(),
dsbyte: 0x01,
rate: 72,
outputLen: 64,
}, nil
}
91 changes: 91 additions & 0 deletions std/hash/sha3/sha3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package sha3

import (
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/keccakf"
)

type digest struct {
uapi *uints.BinaryField[uints.U64]
state [25]uints.U64 // 1600 bits state: 25 x 64
in []uints.U8 // input to be digested
dsbyte byte // dsbyte contains the "domain separation" bits and the first bit of the padding
rate int // the number of bytes of state to use
outputLen int // the default output size in bytes
}

func (d *digest) Write(in []uints.U8) {
d.in = append(d.in, in...)
}

func (d *digest) Size() int { return d.outputLen }

func (d *digest) Reset() {
d.in = nil
d.state = newState()
}

func (d *digest) Sum() []uints.U8 {
padded := d.padding()
blocks := d.composeBlocks(padded)
d.absorbing(blocks)
return d.squeezeBlocks()
}

func (d *digest) padding() []uints.U8 {
padded := make([]uints.U8, len(d.in))
copy(padded[:], d.in[:])

switch q := d.rate - (len(padded) % d.rate); q {
case 1:
padded = append(padded, uints.NewU8(d.dsbyte^0x80))
case 2:
padded = append(padded, uints.NewU8(d.dsbyte))
padded = append(padded, uints.NewU8(0x80))
default:
padded = append(padded, uints.NewU8(d.dsbyte))
padded = append(padded, uints.NewU8Array(make([]uint8, q-2))...)
padded = append(padded, uints.NewU8(0x80))
}

return padded
}

func (d *digest) composeBlocks(padded []uints.U8) [][]uints.U64 {
blocks := make([][]uints.U64, len(padded)/d.rate)

for i := range blocks {
block := make([]uints.U64, d.rate/8)
for j := range block {
u64 := padded[j*8 : j*8+8]
block[j] = d.uapi.PackLSB(u64...)
}
blocks[i] = block
padded = padded[d.rate:]
}

return blocks
}

func (d *digest) absorbing(blocks [][]uints.U64) {
for _, block := range blocks {
for i := range block {
d.state[i] = d.uapi.Xor(d.state[i], block[i])
}
d.state = keccakf.Permute(d.uapi, d.state)
}
}

func (d *digest) squeezeBlocks() (result []uints.U8) {
for i := 0; i < d.outputLen/8; i++ {
result = append(result, d.uapi.UnpackLSB(d.state[i])...)
}
return
}

func newState() (state [25]uints.U64) {
for i := range state {
state[i] = uints.NewU64(0)
}
return
}
90 changes: 90 additions & 0 deletions std/hash/sha3/sha3_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package sha3

import (
"crypto/rand"
"fmt"
"hash"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
zkhash "github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/test"
"golang.org/x/crypto/sha3"
)

type testCase struct {
zk func(api frontend.API) (zkhash.BinaryHasher, error)
native func() hash.Hash
}

var testCases = map[string]testCase{
"SHA3-256": {New256, sha3.New256},
"SHA3-384": {New384, sha3.New384},
"SHA3-512": {New512, sha3.New512},
"Keccak-256": {NewLegacyKeccak256, sha3.NewLegacyKeccak256},
"Keccak-512": {NewLegacyKeccak512, sha3.NewLegacyKeccak512},
}

type sha3Circuit struct {
In []uints.U8
Expected []uints.U8

hasher string
}

func (c *sha3Circuit) Define(api frontend.API) error {
newHasher, ok := testCases[c.hasher]
if !ok {
return fmt.Errorf("hash function unknown: %s", c.hasher)
}
h, err := newHasher.zk(api)
if err != nil {
return err
}
uapi, err := uints.New[uints.U64](api)
if err != nil {
return err
}

h.Write(c.In)
res := h.Sum()

for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestSHA3(t *testing.T) {
assert := test.NewAssert(t)
in := make([]byte, 310)
_, err := rand.Reader.Read(in)
assert.NoError(err)

for name := range testCases {
assert.Run(func(assert *test.Assert) {
name := name
strategy := testCases[name]
h := strategy.native()
h.Write(in)
expected := h.Sum(nil)

circuit := &sha3Circuit{
In: make([]uints.U8, len(in)),
Expected: make([]uints.U8, len(expected)),
hasher: name,
}

witness := &sha3Circuit{
In: uints.NewU8Array(in),
Expected: uints.NewU8Array(expected),
}

if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil {
t.Fatalf("%s: %s", name, err)
}
}, name)
}
}
19 changes: 12 additions & 7 deletions std/permutation/keccakf/keccak_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keccakf_test

import (
"github.com/consensys/gnark/std/math/uints"
"testing"

"github.com/consensys/gnark-crypto/ecc"
Expand All @@ -11,20 +12,24 @@ import (
)

type keccakfCircuit struct {
In [25]frontend.Variable
Expected [25]frontend.Variable `gnark:",public"`
In [25]uints.U64
Expected [25]uints.U64 `gnark:",public"`
}

func (c *keccakfCircuit) Define(api frontend.API) error {
var res [25]frontend.Variable
var res [25]uints.U64
for i := range res {
res[i] = c.In[i]
}
uapi, err := uints.New[uints.U64](api)
if err != nil {
return err
}
for i := 0; i < 2; i++ {
res = keccakf.Permute(api, res)
res = keccakf.Permute(uapi, res)
}
for i := range res {
api.AssertIsEqual(res[i], c.Expected[i])
uapi.AssertEq(res[i], c.Expected[i])
}
return nil
}
Expand All @@ -41,8 +46,8 @@ func TestKeccakf(t *testing.T) {
}
witness := keccakfCircuit{}
for i := range nativeIn {
witness.In[i] = nativeIn[i]
witness.Expected[i] = res[i]
witness.In[i] = uints.NewU64(nativeIn[i])
witness.Expected[i] = uints.NewU64(res[i])
}
assert := test.NewAssert(t)
assert.ProverSucceeded(&keccakfCircuit{}, &witness,
Expand Down
25 changes: 6 additions & 19 deletions std/permutation/keccakf/keccakf.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
package keccakf

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/uints"
)

Expand Down Expand Up @@ -50,24 +49,12 @@ var piln = [24]int{
15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1,
}

// Permute applies Keccak-F permutation on the input a and returns the permuted
// vector. The input array must consist of 64-bit (unsigned) integers. The
// returned array also contains 64-bit unsigned integers.
func Permute(api frontend.API, a [25]frontend.Variable) [25]frontend.Variable {
var in [25]uints.U64
uapi, err := uints.New[uints.U64](api)
if err != nil {
panic(err) // TODO: return error instead
}
for i := range a {
in[i] = uapi.ValueOf(a[i])
}
res := permute(uapi, in)
var out [25]frontend.Variable
for i := range out {
out[i] = uapi.ToValue(res[i])
}
return out
// Permute applies Keccak-F permutation on the input and returns the permuted vector.
// Original input is not modified.
func Permute(uapi *uints.BinaryField[uints.U64], input [25]uints.U64) [25]uints.U64 {
var state [25]uints.U64
copy(state[:], input[:])
return permute(uapi, state)
}

func permute(uapi *uints.BinaryField[uints.U64], st [25]uints.U64) [25]uints.U64 {
Expand Down

0 comments on commit 2a6e749

Please sign in to comment.