From c97576b7d6e63bfb136a9ce597e711a58e420c73 Mon Sep 17 00:00:00 2001 From: Robin Date: Fri, 24 May 2024 14:37:58 +0200 Subject: [PATCH] Update bn256 from cloudflare's changes (#518) * Moved scheme.go and threshold.go into internals and uncommented bls_test.go * Removed nerr++ in favor of len(errors) * Changed the path for test * Sorting imports * Update bn256 from cloudflare's changes * Added the bn256/hash.go and its tests --- pairing/bn256/README.md | 2 +- pairing/bn256/constants.go | 15 ++++ pairing/bn256/gfp.go | 66 +++++++++++++- pairing/bn256/gfp_decl.go | 1 - pairing/bn256/gfp_generic.go | 1 - pairing/bn256/gfp_test.go | 166 +++++++++++++++++++++++++++++++++++ pairing/bn256/hash.go | 107 ++++++++++++++++++++++ pairing/bn256/hash_test.go | 55 ++++++++++++ pairing/bn256/mul_arm64.h | 1 + pairing/bn256/optate.go | 3 +- 10 files changed, 409 insertions(+), 8 deletions(-) create mode 100644 pairing/bn256/gfp_test.go create mode 100644 pairing/bn256/hash.go create mode 100644 pairing/bn256/hash_test.go diff --git a/pairing/bn256/README.md b/pairing/bn256/README.md index 23050e963..01ad08b4a 100644 --- a/pairing/bn256/README.md +++ b/pairing/bn256/README.md @@ -50,4 +50,4 @@ The basis for this package is [Cloudflare's bn256 implementation](https://github which itself is an improved version of the [official bn256 package](https://golang.org/x/crypto/bn256). The package at hand maintains compatibility to Cloudflare's library. The biggest difference is the replacement of their [public API](https://github.com/cloudflare/bn256/blob/master/bn256.go) by a new -one that is compatible to Kyber's scalar, point, group, and suite interfaces. +one that is compatible to Kyber's scalar, point, group, and suite interfaces. *Last update 05.2024* diff --git a/pairing/bn256/constants.go b/pairing/bn256/constants.go index 943751a07..d31c5f7f6 100644 --- a/pairing/bn256/constants.go +++ b/pairing/bn256/constants.go @@ -54,3 +54,18 @@ var r2 = &gfP{0x9c21c3ff7e444f56, 0x409ed151b2efb0c2, 0xc6dc37b80fb1651, 0x7c36e // r3 is R^3 where R = 2^256 mod p. var r3 = &gfP{0x2af2dfb9324a5bb8, 0x388f899054f538a4, 0xdf2ff66396b107a7, 0x24ebbbb3a2529292} + +// pPlus1Over4 is (p+1)/4. +var pPlus1Over4 = [4]uint64{0x86172b1b1782259a, 0x7b96e234482d6d67, 0x6a9bfb2e18613708, 0x23ed4078d2a8e1fe} + +// pMinus2 is p-2. +var pMinus2 = [4]uint64{0x185cac6c5e089665, 0xee5b88d120b5b59e, 0xaa6fecb86184dc21, 0x8fb501e34aa387f9} + +// pMinus1Over2 is (p-1)/2. +var pMinus1Over2 = [4]uint64{0x0c2e56362f044b33, 0xf72dc468905adacf, 0xd537f65c30c26e10, 0x47da80f1a551c3fc} + +// s is the Montgomery encoding of the square root of -3. Then, s = sqrt(-3) * 2^256 mod p. +var s = &gfP{0x236e675956be783b, 0x053957e6f379ab64, 0xe60789a768f4a5c4, 0x04f8979dd8bad754} + +// sMinus1Over2 is the Montgomery encoding of (s-1)/2. Then, sMinus1Over2 = ( (s-1) / 2) * 2^256 mod p. +var sMinus1Over2 = &gfP{0x3642364f386c1db8, 0xe825f92d2acd661f, 0xf2aba7e846c19d14, 0x5a0bcea3dc52b7a0} diff --git a/pairing/bn256/gfp.go b/pairing/bn256/gfp.go index d6df830a7..6e9e9006a 100644 --- a/pairing/bn256/gfp.go +++ b/pairing/bn256/gfp.go @@ -1,7 +1,10 @@ package bn256 import ( + "crypto/sha256" + "encoding/binary" "fmt" + "golang.org/x/crypto/hkdf" "math/big" ) @@ -36,6 +39,29 @@ func newGFpFromBigInt(bigInt *big.Int) *gfP { return out } +func hashToBase(msg, dst []byte) *gfP { + var t [48]byte + info := []byte{'H', '2', 'C', byte(0), byte(1)} + r := hkdf.New(sha256.New, msg, dst, info) + if _, err := r.Read(t[:]); err != nil { + panic(err) + } + var x big.Int + v := x.SetBytes(t[:]).Mod(&x, p).Bytes() + v32 := [32]byte{} + for i := len(v) - 1; i >= 0; i-- { + v32[len(v)-1-i] = v[i] + } + u := &gfP{ + binary.LittleEndian.Uint64(v32[0*8 : 1*8]), + binary.LittleEndian.Uint64(v32[1*8 : 2*8]), + binary.LittleEndian.Uint64(v32[2*8 : 3*8]), + binary.LittleEndian.Uint64(v32[3*8 : 4*8]), + } + montEncode(u, u) + return u +} + func (e *gfP) String() string { return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0]) } @@ -47,9 +73,7 @@ func (e *gfP) Set(f *gfP) { e[3] = f[3] } -func (e *gfP) Invert(f *gfP) { - bits := [4]uint64{0x185cac6c5e089665, 0xee5b88d120b5b59e, 0xaa6fecb86184dc21, 0x8fb501e34aa387f9} - +func (e *gfP) exp(f *gfP, bits [4]uint64) { sum, power := &gfP{}, &gfP{} sum.Set(rN1) power.Set(f) @@ -67,6 +91,15 @@ func (e *gfP) Invert(f *gfP) { e.Set(sum) } +func (e *gfP) Invert(f *gfP) { + e.exp(f, pMinus2) +} + +func (e *gfP) Sqrt(f *gfP) { + // Since p = 4k+3, then e = f^(k+1) is a root of f. + e.exp(f, pPlus1Over4) +} + func (e *gfP) Marshal(out []byte) { for w := uint(0); w < 4; w++ { for b := uint(0); b < 8; b++ { @@ -96,3 +129,30 @@ func (e *gfP) BigInt() *big.Int { func montEncode(c, a *gfP) { gfpMul(c, a, r2) } func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) } + +func sign0(e *gfP) int { + x := &gfP{} + montDecode(x, e) + for w := 3; w >= 0; w-- { + if x[w] > pMinus1Over2[w] { + return 1 + } else if x[w] < pMinus1Over2[w] { + return -1 + } + } + return 1 +} + +func legendre(e *gfP) int { + f := &gfP{} + // Since p = 4k+3, then e^(2k+1) is the Legendre symbol of e. + f.exp(e, pMinus1Over2) + + montDecode(f, f) + + if *f != [4]uint64{} { + return 2*int(f[0]&1) - 1 + } + + return 0 +} diff --git a/pairing/bn256/gfp_decl.go b/pairing/bn256/gfp_decl.go index bdb6a8915..23df6f186 100644 --- a/pairing/bn256/gfp_decl.go +++ b/pairing/bn256/gfp_decl.go @@ -1,5 +1,4 @@ //go:build (amd64 && !generic) || (arm64 && !generic) -// +build amd64,!generic arm64,!generic package bn256 diff --git a/pairing/bn256/gfp_generic.go b/pairing/bn256/gfp_generic.go index 7742dda4c..944208c67 100644 --- a/pairing/bn256/gfp_generic.go +++ b/pairing/bn256/gfp_generic.go @@ -1,5 +1,4 @@ //go:build (!amd64 && !arm64) || generic -// +build !amd64,!arm64 generic package bn256 diff --git a/pairing/bn256/gfp_test.go b/pairing/bn256/gfp_test.go new file mode 100644 index 000000000..54570b4f2 --- /dev/null +++ b/pairing/bn256/gfp_test.go @@ -0,0 +1,166 @@ +package bn256 + +import ( + "crypto/rand" + "encoding/binary" + "io" + "math/big" + "testing" +) + +// randomGF returns a random integer between 0 and p-1. +func randomGF(r io.Reader) *big.Int { + k, err := rand.Int(r, p) + if err != nil { + panic(err) + } + return k +} + +// toBigInt converts a field element into its reduced (mod p) +// integer representation. +func toBigInt(a *gfP) *big.Int { + v := &gfP{} + montDecode(v, a) + c := new(big.Int) + for i := len(v) - 1; i >= 0; i-- { + c.Lsh(c, 64) + c.Add(c, new(big.Int).SetUint64(v[i])) + } + return c +} + +// togfP converts an integer into a field element (in +// Montgomery representation). This function assumes the +// input is between 0 and p-1; otherwise it panics. +func togfP(k *big.Int) *gfP { + if k.Cmp(p) >= 0 { + panic("not in the range 0 to p-1") + } + v := k.Bytes() + v32 := [32]byte{} + for i := len(v) - 1; i >= 0; i-- { + v32[len(v)-1-i] = v[i] + } + u := &gfP{ + binary.LittleEndian.Uint64(v32[0*8 : 1*8]), + binary.LittleEndian.Uint64(v32[1*8 : 2*8]), + binary.LittleEndian.Uint64(v32[2*8 : 3*8]), + binary.LittleEndian.Uint64(v32[3*8 : 4*8]), + } + montEncode(u, u) + return u +} + +func TestGFp(t *testing.T) { + const testTimes = 1 << 8 + + t.Run("add", func(t *testing.T) { + c := &gfP{} + bigC := new(big.Int) + for i := 0; i < testTimes; i++ { + bigA := randomGF(rand.Reader) + bigB := randomGF(rand.Reader) + want := bigC.Add(bigA, bigB).Mod(bigC, p) + + a := togfP(bigA) + b := togfP(bigB) + gfpAdd(c, a, b) + got := toBigInt(c) + + if got.Cmp(want) != 0 { + t.Errorf("got: %v want:%v", got, want) + } + } + }) + + t.Run("sub", func(t *testing.T) { + c := &gfP{} + bigC := new(big.Int) + for i := 0; i < testTimes; i++ { + bigA := randomGF(rand.Reader) + bigB := randomGF(rand.Reader) + want := bigC.Sub(bigA, bigB).Mod(bigC, p) + + a := togfP(bigA) + b := togfP(bigB) + gfpSub(c, a, b) + got := toBigInt(c) + + if got.Cmp(want) != 0 { + t.Errorf("got: %v want:%v", got, want) + } + } + }) + + t.Run("mul", func(t *testing.T) { + c := &gfP{} + bigC := new(big.Int) + for i := 0; i < testTimes; i++ { + bigA := randomGF(rand.Reader) + bigB := randomGF(rand.Reader) + want := bigC.Mul(bigA, bigB).Mod(bigC, p) + + a := togfP(bigA) + b := togfP(bigB) + gfpMul(c, a, b) + got := toBigInt(c) + + if got.Cmp(want) != 0 { + t.Errorf("got: %v want:%v", got, want) + } + } + }) + + t.Run("neg", func(t *testing.T) { + c := &gfP{} + bigC := new(big.Int) + for i := 0; i < testTimes; i++ { + bigA := randomGF(rand.Reader) + want := bigC.Neg(bigA).Mod(bigC, p) + + a := togfP(bigA) + gfpNeg(c, a) + got := toBigInt(c) + + if got.Cmp(want) != 0 { + t.Errorf("got: %v want:%v", got, want) + } + } + }) + + t.Run("inv", func(t *testing.T) { + c := &gfP{} + bigC := new(big.Int) + for i := 0; i < testTimes; i++ { + bigA := randomGF(rand.Reader) + want := bigC.ModInverse(bigA, p) + + a := togfP(bigA) + c.Invert(a) + got := toBigInt(c) + + if got.Cmp(want) != 0 { + t.Errorf("got: %v want:%v", got, want) + } + } + }) + + t.Run("sqrt", func(t *testing.T) { + c := &gfP{} + bigC := new(big.Int) + for i := 0; i < testTimes; i++ { + bigA := randomGF(rand.Reader) + bigA.Mul(bigA, bigA).Mod(bigA, p) + want := bigC.ModSqrt(bigA, p) + + a := togfP(bigA) + c.Sqrt(a) + got := toBigInt(c) + + if got.Cmp(want) != 0 { + t.Errorf("got: %v want:%v", got, want) + } + } + }) +} diff --git a/pairing/bn256/hash.go b/pairing/bn256/hash.go new file mode 100644 index 000000000..696fb8de3 --- /dev/null +++ b/pairing/bn256/hash.go @@ -0,0 +1,107 @@ +package bn256 + +import "go.dedis.ch/kyber/v3" + +// HashG1 implements a hashing function into the G1 group. +// +// dst represents domain separation tag, similar to salt, for the hash. +func HashG1(msg, dst []byte) kyber.Point { + return mapToCurve(hashToBase(msg, dst)) +} + +func mapToCurve(t *gfP) kyber.Point { + one := *newGFp(1) + + // calculate w = (s * t)/(1 + B + t^2) + // we calculate w0 = s * t * (1 + B + t^2) and inverse of it, so that w = (st)^2/w0 + // and then later x3 = 1 + (1 + B + t^2)^4/w0^2 + w := &gfP{} + + // a = (1 + B + t^2) + a := &gfP{} + t2 := &gfP{} + gfpMul(t2, t, t) + gfpAdd(a, curveB, t2) + gfpAdd(a, a, &one) + + st := &gfP{} + gfpMul(st, s, t) + + w0 := &gfP{} + gfpMul(w0, st, a) + w0.Invert(w0) + + gfpMul(w, st, st) + gfpMul(w, w, w0) + + e := sign0(t) + cp := &curvePoint{z: one, t: one} + + // calculate x1 = ((-1 + s) / 2) - t * w + tw := &gfP{} + gfpMul(tw, t, w) + x1 := &gfP{} + gfpSub(x1, sMinus1Over2, tw) + + // check if y=x1^3+3 is a square + y := &gfP{} + y.Set(x1) + gfpMul(y, x1, x1) + gfpMul(y, y, x1) + gfpAdd(y, y, curveB) + if legendre(y) == 1 { + cp.x = *x1 + y.Sqrt(y) + if e != sign0(y) { + gfpNeg(y, y) + } + cp.y = *y + + pg1 := pointG1{cp} + return pg1.Clone() + } + + // calculate x2 = -1 - x1 + x2 := newGFp(-1) + gfpSub(x2, x2, x1) + + // check if y=x2^3+3 is a square + y.Set(x2) + gfpMul(y, x2, x2) + gfpMul(y, y, x2) + gfpAdd(y, y, curveB) + if legendre(y) == 1 { + cp.x = *x2 + y.Sqrt(y) + if e != sign0(y) { + gfpNeg(y, y) + } + cp.y = *y + + pg1 := pointG1{cp} + return pg1.Clone() + } + + // calculate x3 = 1 + (1/ww) = 1 + a^4 * w0^2 + x3 := &gfP{} + gfpMul(x3, a, a) + gfpMul(x3, x3, x3) + gfpMul(x3, x3, w0) + gfpMul(x3, x3, w0) + gfpAdd(x3, x3, &one) + + y.Set(x3) + gfpMul(y, x3, x3) + gfpMul(y, y, x3) + gfpAdd(y, y, curveB) + + cp.x = *x3 + y.Sqrt(y) + if e != sign0(y) { + gfpNeg(y, y) + } + cp.y = *y + + pg1 := pointG1{cp} + return pg1.Clone() +} diff --git a/pairing/bn256/hash_test.go b/pairing/bn256/hash_test.go new file mode 100644 index 000000000..5c5353921 --- /dev/null +++ b/pairing/bn256/hash_test.go @@ -0,0 +1,55 @@ +package bn256 + +import ( + "testing" + + "bytes" +) + +func TestKnownHashes(t *testing.T) { + for i, mh := range marshaledHashes { + g := HashG1([]byte{byte(i)}, nil) + b, _ := g.MarshalBinary() + if !bytes.Equal(mh[:], b) { + t.Fatal("hash doesn't match a known value") + } + } +} + +var buf = make([]byte, 8192) + +func benchmarkSize(b *testing.B, size int) { + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + HashG1(buf[:size], nil) + } +} + +func BenchmarkHashG1Size8bytes(b *testing.B) { + b.ResetTimer() + benchmarkSize(b, 8) +} + +func BenchmarkHashG1Size1k(b *testing.B) { + b.ResetTimer() + benchmarkSize(b, 1024) +} + +func BenchmarkHashG1Size8k(b *testing.B) { + b.ResetTimer() + benchmarkSize(b, 8192) +} + +var marshaledHashes = [11][64]byte{ + [64]byte{80, 233, 64, 52, 60, 233, 95, 49, 57, 115, 89, 101, 189, 182, 251, 43, 158, 186, 22, 10, 130, 128, 127, 143, 10, 158, 148, 102, 148, 86, 194, 111, 98, 232, 82, 178, 190, 193, 65, 1, 58, 126, 154, 37, 11, 185, 207, 250, 219, 202, 140, 196, 2, 35, 223, 87, 13, 60, 204, 201, 34, 231, 118, 206}, + [64]byte{123, 250, 39, 222, 32, 210, 254, 221, 94, 5, 32, 6, 19, 120, 252, 162, 110, 53, 149, 185, 209, 83, 189, 194, 77, 40, 160, 168, 17, 143, 13, 121, 72, 31, 247, 190, 150, 8, 159, 57, 145, 45, 129, 145, 164, 29, 156, 159, 182, 177, 142, 145, 38, 236, 98, 84, 157, 8, 164, 38, 123, 73, 215, 23}, + [64]byte{129, 78, 244, 19, 205, 198, 70, 100, 63, 152, 218, 52, 132, 20, 180, 241, 223, 109, 93, 80, 59, 6, 16, 183, 99, 5, 202, 77, 136, 165, 254, 32, 124, 242, 44, 52, 28, 76, 54, 116, 113, 243, 51, 101, 114, 70, 190, 124, 81, 194, 77, 8, 163, 135, 148, 175, 224, 248, 184, 44, 167, 124, 10, 30}, + [64]byte{113, 140, 119, 103, 41, 163, 49, 69, 93, 208, 11, 126, 85, 100, 1, 11, 151, 207, 202, 144, 7, 154, 203, 84, 123, 255, 67, 107, 189, 188, 93, 14, 131, 167, 214, 27, 85, 82, 122, 220, 131, 237, 192, 206, 159, 132, 216, 254, 227, 52, 232, 216, 182, 154, 170, 46, 99, 78, 137, 79, 90, 30, 236, 16}, + [64]byte{128, 205, 34, 132, 54, 241, 30, 185, 253, 248, 45, 227, 78, 202, 148, 137, 224, 86, 199, 253, 98, 156, 169, 132, 129, 141, 118, 247, 102, 200, 47, 231, 62, 4, 169, 180, 190, 184, 212, 40, 88, 118, 134, 129, 149, 108, 105, 153, 54, 153, 40, 159, 189, 245, 63, 172, 43, 49, 22, 246, 154, 57, 63, 57}, + [64]byte{44, 243, 231, 191, 3, 107, 182, 73, 39, 43, 51, 20, 25, 235, 151, 112, 207, 24, 28, 96, 201, 60, 175, 210, 179, 42, 117, 101, 16, 196, 82, 238, 126, 198, 61, 68, 228, 96, 166, 130, 139, 167, 181, 195, 46, 10, 51, 83, 59, 165, 249, 111, 205, 113, 80, 43, 240, 194, 72, 240, 64, 235, 120, 34}, + [64]byte{125, 159, 122, 73, 206, 48, 230, 111, 229, 18, 224, 100, 101, 149, 116, 190, 47, 116, 78, 156, 94, 87, 164, 157, 156, 211, 110, 229, 191, 250, 213, 83, 139, 111, 120, 241, 26, 131, 125, 200, 87, 166, 76, 136, 241, 37, 113, 44, 200, 158, 236, 122, 0, 33, 172, 198, 242, 255, 33, 101, 142, 245, 180, 243}, + [64]byte{130, 99, 125, 203, 106, 197, 191, 151, 248, 98, 27, 76, 200, 122, 173, 139, 129, 31, 54, 51, 206, 49, 122, 51, 57, 88, 139, 191, 42, 22, 158, 125, 100, 87, 23, 89, 148, 160, 5, 224, 46, 35, 217, 254, 28, 247, 86, 227, 186, 200, 3, 206, 50, 134, 14, 193, 23, 58, 2, 161, 52, 1, 201, 136}, + [64]byte{76, 28, 164, 59, 70, 75, 165, 57, 131, 109, 238, 103, 17, 89, 191, 194, 78, 248, 115, 8, 108, 206, 46, 235, 52, 219, 98, 231, 194, 252, 229, 98, 55, 45, 194, 177, 115, 176, 207, 167, 174, 12, 94, 199, 63, 175, 214, 137, 190, 168, 67, 247, 107, 64, 169, 74, 250, 174, 177, 141, 93, 207, 71, 147}, + [64]byte{45, 115, 123, 118, 162, 144, 82, 134, 198, 17, 162, 200, 91, 168, 191, 115, 31, 66, 81, 201, 111, 250, 133, 16, 247, 62, 92, 251, 227, 234, 116, 183, 16, 117, 103, 177, 94, 201, 169, 155, 59, 218, 174, 242, 28, 66, 171, 113, 245, 247, 98, 236, 193, 26, 85, 62, 215, 101, 229, 214, 191, 153, 176, 168}, + [64]byte{143, 123, 127, 149, 167, 27, 159, 25, 254, 211, 196, 88, 17, 185, 138, 237, 62, 140, 84, 177, 134, 58, 193, 141, 25, 152, 79, 6, 41, 39, 248, 117, 52, 208, 167, 215, 212, 60, 250, 228, 1, 232, 111, 254, 154, 18, 209, 55, 207, 200, 68, 60, 163, 106, 59, 27, 12, 72, 130, 141, 182, 103, 16, 80}, +} diff --git a/pairing/bn256/mul_arm64.h b/pairing/bn256/mul_arm64.h index d405eb8f7..b43404bb9 100644 --- a/pairing/bn256/mul_arm64.h +++ b/pairing/bn256/mul_arm64.h @@ -126,6 +126,7 @@ SBCS R6, R22, R11 \ SBCS R7, R23, R12 \ SBCS R8, R24, R13 \ + SBCS $0, R0, R0 \ \ CSEL CS, R10, R21, R1 \ CSEL CS, R11, R22, R2 \ diff --git a/pairing/bn256/optate.go b/pairing/bn256/optate.go index 126c64ca6..a235a45be 100644 --- a/pairing/bn256/optate.go +++ b/pairing/bn256/optate.go @@ -196,9 +196,8 @@ func miller(q *twistPoint, p *curvePoint) *gfP12 { r = newR r2.Square(&minusQ2.y) - a, b, c, newR = lineFunctionAdd(r, minusQ2, bAffine, r2) + a, b, c, _ = lineFunctionAdd(r, minusQ2, bAffine, r2) mulLine(ret, a, b, c) - r = newR return ret }