diff --git a/std/algebra/emulated/sw_bls12381/bls_sig.go b/std/algebra/emulated/sw_bls12381/bls_sig.go new file mode 100644 index 0000000000..2969485391 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/bls_sig.go @@ -0,0 +1,38 @@ +package sw_bls12381 + +import ( + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" +) + +const g2_dst = "BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_" + +func BlsAssertG2Verification(api frontend.API, pub G1Affine, sig G2Affine, msg []uints.U8) error { + pairing, e := NewPairing(api) + if e != nil { + return e + } + + // public key cannot be infinity + xtest := pairing.g1.curveF.IsZero(&pub.X) + ytest := pairing.g1.curveF.IsZero(&pub.Y) + pubTest := api.Or(xtest, ytest) + api.AssertIsEqual(pubTest, 0) + + // prime order subgroup checks + pairing.AssertIsOnG1(&pub) + pairing.AssertIsOnG2(&sig) + + var g1GNeg bls12381.G1Affine + _, _, g1Gen, _ := bls12381.Generators() + g1GNeg.Neg(&g1Gen) + g1GN := NewG1Affine(g1GNeg) + + h, e := HashToG2(api, msg, []byte(g2_dst)) + if e != nil { + return e + } + + return pairing.PairingCheck([]*G1Affine{&g1GN, &pub}, []*G2Affine{&sig, h}) +} diff --git a/std/algebra/emulated/sw_bls12381/bls_sig_test.go b/std/algebra/emulated/sw_bls12381/bls_sig_test.go new file mode 100644 index 0000000000..47827ab258 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/bls_sig_test.go @@ -0,0 +1,83 @@ +package sw_bls12381 + +import ( + "encoding/hex" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +type blsG2SigCircuit struct { + Pub bls12381.G1Affine + msg []byte + Sig bls12381.G2Affine +} + +func (c *blsG2SigCircuit) Define(api frontend.API) error { + msg := uints.NewU8Array(c.msg) + return BlsAssertG2Verification(api, NewG1Affine(c.Pub), NewG2Affine(c.Sig), msg) +} + +// "pubkey": "0xa491d1b0ecd9bb917989f0e74f0dea0422eac4a873e5e2644f368dffb9a6e20fd6e10c1b77654d067c0618f6e5a7f79a", +// "message": "0x5656565656565656565656565656565656565656565656565656565656565656", +// "signature": "0x882730e5d03f6b42c3abc26d3372625034e1d871b65a8a6b900a56dae22da98abbe1b68f85e49fe7652a55ec3d0591c20767677e33e5cbb1207315c41a9ac03be39c2e7668edc043d6cb1d9fd93033caa8a1c5b0e84bedaeb6c64972503a43eb"}, +// "output": true} +func TestBlsSigTestSolve(t *testing.T) { + assert := test.NewAssert(t) + + msgHex := "5656565656565656565656565656565656565656565656565656565656565656" + pubHex := "a491d1b0ecd9bb917989f0e74f0dea0422eac4a873e5e2644f368dffb9a6e20fd6e10c1b77654d067c0618f6e5a7f79a" + sigHex := "882730e5d03f6b42c3abc26d3372625034e1d871b65a8a6b900a56dae22da98abbe1b68f85e49fe7652a55ec3d0591c20767677e33e5cbb1207315c41a9ac03be39c2e7668edc043d6cb1d9fd93033caa8a1c5b0e84bedaeb6c64972503a43eb" + + msgBytes := make([]byte, len(msgHex)>>1) + hex.Decode(msgBytes, []byte(msgHex)) + pubBytes := make([]byte, len(pubHex)>>1) + hex.Decode(pubBytes, []byte(pubHex)) + sigBytes := make([]byte, len(sigHex)>>1) + hex.Decode(sigBytes, []byte(sigHex)) + + var pub bls12381.G1Affine + _, e := pub.SetBytes(pubBytes) + if e != nil { + t.Fail() + } + var sig bls12381.G2Affine + _, e = sig.SetBytes(sigBytes) + if e != nil { + t.Fail() + } + + var g1GNeg bls12381.G1Affine + _, _, g1Gen, _ := bls12381.Generators() + g1GNeg.Neg(&g1Gen) + + h, e := bls12381.HashToG2(msgBytes, []byte(g2_dst)) + if e != nil { + t.Fail() + } + + b, e := bls12381.PairingCheck([]bls12381.G1Affine{g1GNeg, pub}, []bls12381.G2Affine{sig, h}) + if e != nil { + t.Fail() + } + if !b { + t.Fail() // invalid inputs, won't verify + } + + circuit := blsG2SigCircuit{ + Pub: pub, + msg: msgBytes, + Sig: sig, + } + witness := blsG2SigCircuit{ + Pub: pub, + msg: msgBytes, + Sig: sig, + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go index e6a275cfbb..7421bd6d8f 100644 --- a/std/algebra/emulated/sw_bls12381/g2.go +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -13,6 +13,7 @@ type G2 struct { *fields_bls12381.Ext2 u1, w *emulated.Element[BaseField] v *fields_bls12381.E2 + api frontend.API } type g2AffP struct { @@ -50,6 +51,7 @@ func NewG2(api frontend.API) *G2 { w: &w, u1: &u1, v: &v, + api: api, } } @@ -96,6 +98,18 @@ func (g2 *G2) psi(q *G2Affine) *G2Affine { } } +func (g2 *G2) psi2(q *G2Affine) *G2Affine { + x := g2.Ext2.MulByElement(&q.P.X, g2.w) + y := g2.Ext2.Neg(&q.P.Y) + + return &G2Affine{ + P: g2AffP{ + X: *x, + Y: *y, + }, + } +} + func (g2 *G2) scalarMulBySeed(q *G2Affine) *G2Affine { z := g2.triple(q) @@ -136,6 +150,60 @@ func (g2 G2) add(p, q *G2Affine) *G2Affine { } } +// Follow sw_emulated.Curve.AddUnified to implement the Brier and Joye algorithm +// to handle edge cases, i.e., p == q, p == 0 or/and q == 0 +func (g2 G2) addUnified(p, q *G2Affine) *G2Affine { + + // selector1 = 1 when p is (0,0) and 0 otherwise + selector1 := g2.api.And(g2.Ext2.IsZero(&p.P.X), g2.Ext2.IsZero(&p.P.Y)) + // selector2 = 1 when q is (0,0) and 0 otherwise + selector2 := g2.api.And(g2.Ext2.IsZero(&q.P.X), g2.Ext2.IsZero(&q.P.Y)) + + // λ = ((p.x+q.x)² - p.x*q.x + a)/(p.y + q.y) + pxqx := g2.Ext2.Mul(&p.P.X, &q.P.X) + pxplusqx := g2.Ext2.Add(&p.P.X, &q.P.X) + num := g2.Ext2.Mul(pxplusqx, pxplusqx) + num = g2.Ext2.Sub(num, pxqx) + denum := g2.Ext2.Add(&p.P.Y, &q.P.Y) + // if p.y + q.y = 0, assign dummy 1 to denum and continue + selector3 := g2.Ext2.IsZero(denum) + denum = g2.Ext2.Select(selector3, g2.Ext2.One(), denum) + λ := g2.Ext2.DivUnchecked(num, denum) // we already know that denum won't be zero + + // x = λ^2 - p.x - q.x + xr := g2.Ext2.Mul(λ, λ) + xr = g2.Ext2.Sub(xr, pxplusqx) + + // y = λ(p.x - xr) - p.y + yr := g2.Ext2.Sub(&p.P.X, xr) + yr = g2.Ext2.Mul(yr, λ) + yr = g2.Ext2.Sub(yr, &p.P.Y) + result := &G2Affine{ + P: g2AffP{ + X: *xr, + Y: *yr, + }, + } + + zero := g2.Ext2.Zero() + // if p=(0,0) return q + resultX := *g2.Select(selector1, &q.P.X, &result.P.X) + resultY := *g2.Select(selector1, &q.P.Y, &result.P.Y) + // if q=(0,0) return p + resultX = *g2.Select(selector2, &p.P.X, &resultX) + resultY = *g2.Select(selector2, &p.P.Y, &resultY) + // if p.y + q.y = 0, return (0, 0) + resultX = *g2.Select(selector3, zero, &resultX) + resultY = *g2.Select(selector3, zero, &resultY) + + return &G2Affine{ + P: g2AffP{ + X: resultX, + Y: resultY, + }, + } +} + func (g2 G2) neg(p *G2Affine) *G2Affine { xr := &p.P.X yr := g2.Ext2.Neg(&p.P.Y) diff --git a/std/algebra/emulated/sw_bls12381/g2_test.go b/std/algebra/emulated/sw_bls12381/g2_test.go index 9d4a90d0e4..534843ccd4 100644 --- a/std/algebra/emulated/sw_bls12381/g2_test.go +++ b/std/algebra/emulated/sw_bls12381/g2_test.go @@ -37,6 +37,116 @@ func TestAddG2TestSolve(t *testing.T) { assert.NoError(err) } +func TestAddG2FailureCaseTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bls12381.G2Affine + res.Double(&in1) + witness := addG2Circuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2Circuit{}, &witness, ecc.BN254.ScalarField()) + // the add() function cannot handle identical inputs + assert.Error(err) +} + +type addG2UnifiedCircuit struct { + In1, In2 G2Affine + Res G2Affine +} + +func (c *addG2UnifiedCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + res := g2.addUnified(&c.In1, &c.In2) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestAddG2UnifiedTestSolveAdd(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + _, in2 := randomG1G2Affines() + var res bls12381.G2Affine + res.Add(&in1, &in2) + witness := addG2UnifiedCircuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in2), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestAddG2UnifiedTestSolveDbl(t *testing.T) { + assert := test.NewAssert(t) + _, in1 := randomG1G2Affines() + var res bls12381.G2Affine + res.Double(&in1) + witness := addG2UnifiedCircuit{ + In1: NewG2Affine(in1), + In2: NewG2Affine(in1), + Res: NewG2Affine(res), + } + err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestAddG2UnifiedTestSolveEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + _, p := randomG1G2Affines() + var np, zero bls12381.G2Affine + np.Neg(&p) + zero.Sub(&p, &p) + + // p + (-p) == (0, 0) + witness := addG2UnifiedCircuit{ + In1: NewG2Affine(p), + In2: NewG2Affine(np), + Res: NewG2Affine(zero), + } + err := test.IsSolved(&addG2UnifiedCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + + // (-p) + p == (0, 0) + witness2 := addG2UnifiedCircuit{ + In1: NewG2Affine(np), + In2: NewG2Affine(p), + Res: NewG2Affine(zero), + } + err2 := test.IsSolved(&addG2UnifiedCircuit{}, &witness2, ecc.BN254.ScalarField()) + assert.NoError(err2) + + // p + (0, 0) == p + witness3 := addG2UnifiedCircuit{ + In1: NewG2Affine(p), + In2: NewG2Affine(zero), + Res: NewG2Affine(p), + } + err3 := test.IsSolved(&addG2UnifiedCircuit{}, &witness3, ecc.BN254.ScalarField()) + assert.NoError(err3) + + // (0, 0) + p == p + witness4 := addG2UnifiedCircuit{ + In1: NewG2Affine(zero), + In2: NewG2Affine(p), + Res: NewG2Affine(p), + } + err4 := test.IsSolved(&addG2UnifiedCircuit{}, &witness4, ecc.BN254.ScalarField()) + assert.NoError(err4) + + // (0, 0) + (0, 0) == (0, 0) + witness5 := addG2UnifiedCircuit{ + In1: NewG2Affine(zero), + In2: NewG2Affine(zero), + Res: NewG2Affine(zero), + } + err5 := test.IsSolved(&addG2UnifiedCircuit{}, &witness5, ecc.BN254.ScalarField()) + assert.NoError(err5) + +} + type doubleG2Circuit struct { In1 G2Affine Res G2Affine diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2.go b/std/algebra/emulated/sw_bls12381/hash_to_g2.go new file mode 100644 index 0000000000..55d8d9a52f --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2.go @@ -0,0 +1,373 @@ +package sw_bls12381 + +import ( + "math/big" + "slices" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/hash/tofield" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/std/math/uints" +) + +const ( + security_level = 128 + len_per_base_element = 64 +) + +func HashToG2(api frontend.API, msg []uints.U8, dst []byte) (*G2Affine, error) { + fp, e := emulated.NewField[emulated.BLS12381Fp](api) + if e != nil { + return &G2Affine{}, e + } + ext2 := fields_bls12381.NewExt2(api) + mapper := newMapper(api, ext2, fp) + g2 := NewG2(api) + + // Steps: + // 1. u = hash_to_field(msg, 2) + // 2. Q0 = map_to_curve(u[0]) + // 3. Q1 = map_to_curve(u[1]) + // 4. R = Q0 + Q1 # Point addition + // 5. P = clear_cofactor(R) + // 6. return P + lenPerBaseElement := len_per_base_element + lenInBytes := lenPerBaseElement * 4 + uniformBytes, e := tofield.ExpandMsgXmd(api, msg, dst, lenInBytes) + if e != nil { + return &G2Affine{}, e + } + + ele1 := bytesToElement(api, fp, uniformBytes[:lenPerBaseElement]) + ele2 := bytesToElement(api, fp, uniformBytes[lenPerBaseElement:lenPerBaseElement*2]) + ele3 := bytesToElement(api, fp, uniformBytes[lenPerBaseElement*2:lenPerBaseElement*3]) + ele4 := bytesToElement(api, fp, uniformBytes[lenPerBaseElement*3:]) + + // we will still do iso_map before point addition, as we do not have point addition in E' (yet) + Q0 := mapper.mapToCurve(fields_bls12381.E2{A0: *ele1, A1: *ele2}) + Q1 := mapper.mapToCurve(fields_bls12381.E2{A0: *ele3, A1: *ele4}) + Q0 = mapper.isogeny(&Q0.P.X, &Q0.P.Y) + Q1 = mapper.isogeny(&Q1.P.X, &Q1.P.Y) + + R := g2.addUnified(Q0, Q1) + + return clearCofactor(g2, fp, R), nil +} + +func bytesToElement(api frontend.API, fp *emulated.Field[emulated.BLS12381Fp], data []uints.U8) *emulated.Element[emulated.BLS12381Fp] { + // data in BE, need to convert to LE + slices.Reverse(data) + + bits := make([]frontend.Variable, len(data)*8) + for i := 0; i < len(data); i++ { + u8 := data[i] + u8Bits := api.ToBinary(u8.Val, 8) + for j := 0; j < 8; j++ { + bits[i*8+j] = u8Bits[j] + } + } + + cutoff := 17 + tailBits, headBits := bits[:cutoff*8], bits[cutoff*8:] + tail := fp.FromBits(tailBits...) + head := fp.FromBits(headBits...) + + byteMultiplier := big.NewInt(256) + headMultiplier := byteMultiplier.Exp(byteMultiplier, big.NewInt(int64(cutoff)), big.NewInt(0)) + head = fp.MulConst(head, headMultiplier) + + return fp.Add(head, tail) +} + +type sswuMapper struct { + A, B, Z fields_bls12381.E2 + ext2 *fields_bls12381.Ext2 + fp *emulated.Field[emulated.BLS12381Fp] + api frontend.API + iso *isogeny +} + +func newMapper(api frontend.API, ext2 *fields_bls12381.Ext2, fp *emulated.Field[emulated.BLS12381Fp]) *sswuMapper { + coeff_a := fields_bls12381.E2{ + A0: emulated.ValueOf[emparams.BLS12381Fp](0), + A1: emulated.ValueOf[emparams.BLS12381Fp](240), + } + coeff_b := fields_bls12381.E2{ + A0: emulated.ValueOf[emparams.BLS12381Fp](1012), + A1: emulated.ValueOf[emparams.BLS12381Fp](1012), + } + + one := emulated.ValueOf[emulated.BLS12381Fp](1) + two := emulated.ValueOf[emulated.BLS12381Fp](2) + zeta := fields_bls12381.E2{ + A0: *fp.Neg(&two), + A1: *fp.Neg(&one), + } + + return &sswuMapper{ + A: coeff_a, + B: coeff_b, + Z: zeta, + ext2: ext2, + fp: fp, + api: api, + iso: newIsogeny(), + } +} + +// Apply the Simplified SWU for the E' curve (RFC 9380 Section 6.6.3) +func (m sswuMapper) mapToCurve(u fields_bls12381.E2) *G2Affine { + // SSWU Steps: + // 1. tv1 = u^2 + tv1 := m.ext2.Square(&u) + // 2. tv1 = Z * tv1 + tv1 = m.ext2.Mul(&m.Z, tv1) + // 3. tv2 = tv1^2 + tv2 := m.ext2.Square(tv1) + // 4. tv2 = tv2 + tv1 + tv2 = m.ext2.Add(tv2, tv1) + // 5. tv3 = tv2 + 1 + tv3 := m.ext2.Add(tv2, m.ext2.One()) + // 6. tv3 = B * tv3 + tv3 = m.ext2.Mul(&m.B, tv3) + // 7. tv4 = CMOV(Z, -tv2, tv2 != 0) + s1 := m.ext2.IsZero(tv2) + tv4 := m.ext2.Select(s1, &m.Z, m.ext2.Neg(tv2)) + // 8. tv4 = A * tv4 + tv4 = m.ext2.Mul(&m.A, tv4) + // 9. tv2 = tv3^2 + tv2 = m.ext2.Square(tv3) + // 10. tv6 = tv4^2 + tv6 := m.ext2.Square(tv4) + // 11. tv5 = A * tv6 + tv5 := m.ext2.Mul(&m.A, tv6) + // 12. tv2 = tv2 + tv5 + tv2 = m.ext2.Add(tv2, tv5) + // 13. tv2 = tv2 * tv3 + tv2 = m.ext2.Mul(tv2, tv3) + // 14. tv6 = tv6 * tv4 + tv6 = m.ext2.Mul(tv6, tv4) + // 15. tv5 = B * tv6 + tv5 = m.ext2.Mul(&m.B, tv6) + // 16. tv2 = tv2 + tv5 + tv2 = m.ext2.Add(tv2, tv5) + // 17. x = tv1 * tv3 + x := m.ext2.Mul(tv1, tv3) + // 18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6) + isGx1Square, y1 := m.sqrtRatio(tv2, tv6) + // 19. y = tv1 * u + y := m.ext2.Mul(tv1, &u) + // 20. y = y * y1 + y = m.ext2.Mul(y, y1) + // 21. x = CMOV(x, tv3, is_gx1_square) + x = m.ext2.Select(isGx1Square, tv3, x) + // 22. y = CMOV(y, y1, is_gx1_square) + y = m.ext2.Select(isGx1Square, y1, y) + // 23. e1 = sgn0(u) == sgn0(y) + sgn0U := m.sgn0(&u) + sgn0Y := m.sgn0(y) + diff := m.api.Sub(sgn0U, sgn0Y) + e1 := m.api.IsZero(diff) + // 24. y = CMOV(-y, y, e1) + yNeg := m.ext2.Neg(y) + y = m.ext2.Select(e1, y, yNeg) + // 25. x = x / tv4 + x = m.ext2.DivUnchecked(x, tv4) + // 26. return (x, y) + return &G2Affine{ + P: g2AffP{X: *x, Y: *y}, + } +} + +func (m sswuMapper) sgn0(x *fields_bls12381.E2) frontend.Variable { + // Steps for sgn0_m_eq_2 + // 1. sign_0 = x_0 mod 2 + A0 := m.fp.Reduce(&x.A0) + x0 := m.fp.ToBits(A0) + sign0 := x0[0] + // 2. zero_0 = x_0 == 0 + zero0 := m.fp.IsZero(&x.A0) + // 3. sign_1 = x_1 mod 2 + A1 := m.fp.Reduce(&x.A1) + x1 := m.fp.ToBits(A1) + sign1 := x1[0] + // 4. s = sign_0 OR (zero_0 AND sign_1) # Avoid short-circuit logic ops + tv := m.api.And(zero0, sign1) + s := m.api.Or(sign0, tv) + // 5. return s + return s +} + +// Let's not mechanically translate the spec algorithm (Section F.2.1) into R1CS circuits. +// We could simply compute the result as a hint, then apply proper constraints, which is: +// for output of (b, y) +// +// b1 := {b = True AND y^2 * v = u} +// b2 := {b = False AND y^2 * v = Z * u} +// AssertTrue: {b1 OR b2} +func (m sswuMapper) sqrtRatio(u, v *fields_bls12381.E2) (frontend.Variable, *fields_bls12381.E2) { + // Steps + // 1. extract the base values of u, v, then compute G2SqrtRatio with gnark-crypto + x, err := m.fp.NewHint(GetHints()[0], 3, &u.A0, &u.A1, &v.A0, &v.A1) + if err != nil { + panic("failed to calculate sqrtRatio with gnark-crypto " + err.Error()) + } + + b := m.fp.IsZero(x[0]) + y := fields_bls12381.E2{A0: *x[1], A1: *x[2]} + + // 2. apply constraints + // b1 := {b = True AND y^2 * v = u} + m.api.AssertIsBoolean(b) + y2 := m.ext2.Square(&y) + y2v := m.ext2.Mul(y2, v) + bY2vu := m.ext2.IsZero(m.ext2.Sub(y2v, u)) + b1 := m.api.And(b, bY2vu) + + // b2 := {b = False AND y^2 * v = Z * u} + uZ := m.ext2.Mul(&m.Z, u) + bY2vZu := m.ext2.IsZero(m.ext2.Sub(y2v, uZ)) + nb := m.api.IsZero(b) + b2 := m.api.And(nb, bY2vZu) + + cmp := m.api.Or(b1, b2) + m.api.AssertIsEqual(cmp, 1) + + return b, &y +} + +type g2Polynomial []fields_bls12381.E2 + +func (p g2Polynomial) eval(m *sswuMapper, at fields_bls12381.E2) (pAt *fields_bls12381.E2) { + pAt = &p[len(p)-1] + + for i := len(p) - 2; i >= 0; i-- { + pAt = m.ext2.Mul(pAt, &at) + pAt = m.ext2.Add(pAt, &p[i]) + } + + return +} + +type isogeny struct { + x_numerator, x_denominator, y_numerator, y_denominator g2Polynomial +} + +func newIsogeny() *isogeny { + return &isogeny{ + x_numerator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235542", + "889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235542"), + *e2FromStrings( + "0", + "2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706522"), + *e2FromStrings( + "2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706526", + "1334136518407222464472596608578634718852294273313002628444019378708010550163612621480895876376338554679298090853261"), + *e2FromStrings( + "3557697382419259905260257622876359250272784728834673675850718343221361467102966990615722337003569479144794908942033", + "0"), + }), + x_denominator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "0", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559715"), + *e2FromStrings( + "12", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559775"), + *e2FromStrings( + "1", + "0"), + }), + y_numerator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "3261222600550988246488569487636662646083386001431784202863158481286248011511053074731078808919938689216061999863558", + "3261222600550988246488569487636662646083386001431784202863158481286248011511053074731078808919938689216061999863558"), + *e2FromStrings( + "0", + "889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235518"), + *e2FromStrings( + "2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706524", + "1334136518407222464472596608578634718852294273313002628444019378708010550163612621480895876376338554679298090853263"), + *e2FromStrings( + "2816510427748580758331037284777117739799287910327449993381818688383577828123182200904113516794492504322962636245776", + "0"), + }), + y_denominator: g2Polynomial([]fields_bls12381.E2{ + *e2FromStrings( + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559355", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559355"), + *e2FromStrings( + "0", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559571"), + *e2FromStrings( + "18", + "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559769"), + *e2FromStrings( + "1", + "0"), + }), + } +} + +// Map the point from E' to E +func (m sswuMapper) isogeny(x, y *fields_bls12381.E2) *G2Affine { + xn := m.iso.x_numerator.eval(&m, *x) + + xd := m.iso.x_denominator.eval(&m, *x) + xdInv := m.ext2.Inverse(xd) + + yn := m.iso.y_numerator.eval(&m, *x) + yn = m.ext2.Mul(yn, y) + + yd := m.iso.y_denominator.eval(&m, *x) + ydInv := m.ext2.Inverse(yd) + + return &G2Affine{ + P: g2AffP{ + X: *m.ext2.Mul(xn, xdInv), + Y: *m.ext2.Mul(yn, ydInv), + }, + } +} + +func e2FromStrings(x, y string) *fields_bls12381.E2 { + A0, _ := new(big.Int).SetString(x, 10) + A1, _ := new(big.Int).SetString(y, 10) + + a0 := emulated.ValueOf[emulated.BLS12381Fp](A0) + a1 := emulated.ValueOf[emulated.BLS12381Fp](A1) + + return &fields_bls12381.E2{A0: a0, A1: a1} +} + +// Follow RFC 9380 Apendix G.3 to compute efficiently. +func clearCofactor(g2 *G2, fp *emulated.Field[emparams.BLS12381Fp], p *G2Affine) *G2Affine { + // Steps: + // 1. t1 = c1 * P + // c1 = -15132376222941642752 + t1 := g2.scalarMulBySeed(p) + // 2. t2 = psi(P) + t2 := g2.psi(p) + // 3. t3 = 2 * P + t3 := g2.double(p) + // 4. t3 = psi2(t3) + t3 = g2.psi2(t3) + // 5. t3 = t3 - t2 + t3 = g2.sub(t3, t2) + // 6. t2 = t1 + t2 + t2 = g2.addUnified(t1, t2) + // 7. t2 = c1 * t2 + t2 = g2.scalarMulBySeed(t2) + // 8. t3 = t3 + t2 + t3 = g2.addUnified(t3, t2) + // 9. t3 = t3 - t1 + t3 = g2.sub(t3, t1) + // 10. Q = t3 - P + Q := g2.sub(t3, p) + // 11. return Q + return Q +} diff --git a/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go new file mode 100644 index 0000000000..683106b5b6 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/hash_to_g2_test.go @@ -0,0 +1,266 @@ +package sw_bls12381 + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls12381fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" + "github.com/consensys/gnark/std/hash/tofield" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +func getMsgs() []string { + return []string{"", "a", "ab", "abc", "abcd", "abcde", "abcdef", "abcdefg", "1", "2", "3", "4", "5", "5656565656565656565656565656565656565656565656565656565656565656"} +} + +func getDst() []byte { + dstHex := "412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a236" + dst := make([]byte, len(dstHex)/2) + hex.Decode(dst, []byte(dstHex)) + return dst +} + +type hashToFieldCircuit struct { + Msg []byte + Dst []byte + Res bls12381fp.Element +} + +func (c *hashToFieldCircuit) Define(api frontend.API) error { + msg := uints.NewU8Array(c.Msg) + uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.Dst, 64) + fp, _ := emulated.NewField[emulated.BLS12381Fp](api) + + ele := bytesToElement(api, fp, uniformBytes) + + fp.AssertIsEqual(ele, fp.NewElement(c.Res)) + + return nil +} + +func TestHashToFieldTestSolve(t *testing.T) { + assert := test.NewAssert(t) + dst := getDst() + + for _, msg := range getMsgs() { + + rawEles, _ := bls12381fp.Hash([]byte(msg), dst, 1) + + circuit := hashToFieldCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: rawEles[0], + } + witness := hashToFieldCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: rawEles[0], + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} + +type mapToCurveCircuit struct { + Msg []byte + Dst []byte + Res G2Affine +} + +func (c *mapToCurveCircuit) Define(api frontend.API) error { + msg := uints.NewU8Array(c.Msg) + fp, _ := emulated.NewField[emulated.BLS12381Fp](api) + ext2 := fields_bls12381.NewExt2(api) + mapper := newMapper(api, ext2, fp) + + uniformBytes, _ := tofield.ExpandMsgXmd(api, msg, c.Dst, 128) + ele1 := bytesToElement(api, fp, uniformBytes[:64]) + ele2 := bytesToElement(api, fp, uniformBytes[64:]) + e := fields_bls12381.E2{A0: *ele1, A1: *ele2} + affine := mapper.mapToCurve(e) + + g2 := NewG2(api) + g2.AssertIsEqual(affine, &c.Res) + + return nil +} + +func TestMapToCurveTestSolve(t *testing.T) { + assert := test.NewAssert(t) + dst := getDst() + + for _, msg := range getMsgs() { + + rawEles, _ := bls12381fp.Hash([]byte(msg), dst, 2) + rawAffine := bls12381.MapToCurve2(&bls12381.E2{A0: rawEles[0], A1: rawEles[1]}) + wrappedRawAffine := NewG2Affine(rawAffine) + + circuit := mapToCurveCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: wrappedRawAffine, + } + witness := mapToCurveCircuit{ + Msg: []byte(msg), + Dst: dst, + Res: wrappedRawAffine, + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} + +type clearCofactorCircuit struct { + In G2Affine + Res G2Affine +} + +func (c *clearCofactorCircuit) Define(api frontend.API) error { + g2 := NewG2(api) + fp, _ := emulated.NewField[emulated.BLS12381Fp](api) + res := clearCofactor(g2, fp, &c.In) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestClearCofactorTestSolve(t *testing.T) { + assert := test.NewAssert(t) + _, in := randomG1G2Affines() + + inAffine := NewG2Affine(in) + + in.ClearCofactor(&in) + circuit := clearCofactorCircuit{ + In: inAffine, + Res: NewG2Affine(in), + } + witness := clearCofactorCircuit{ + In: inAffine, + Res: NewG2Affine(in), + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type hashToG2Circuit struct { + Msg []byte + Dst []byte + Res G2Affine +} + +func (c *hashToG2Circuit) Define(api frontend.API) error { + res, e := HashToG2(api, uints.NewU8Array(c.Msg), c.Dst) + if e != nil { + return e + } + + g2 := NewG2(api) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestHashToG2TestSolve(t *testing.T) { + assert := test.NewAssert(t) + dst := getDst() + + for _, msg := range getMsgs() { + + expected, _ := bls12381.HashToG2([]uint8(msg), dst) + wrappedRes := NewG2Affine(expected) + + circuit := hashToG2Circuit{ + Msg: []uint8(msg), + Dst: dst, + Res: wrappedRes, + } + witness := hashToG2Circuit{ + Msg: []uint8(msg), + Dst: dst, + Res: wrappedRes, + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + } +} + +type hashToG2BenchCircuit struct { + Msg []byte + Dst []byte +} + +func (c *hashToG2BenchCircuit) Define(api frontend.API) error { + _, e := HashToG2(api, uints.NewU8Array(c.Msg), c.Dst) + return e +} + +func BenchmarkHashToG2(b *testing.B) { + + dst := getDst() + + msg := "abcd" + witness := hashToG2BenchCircuit{ + Msg: []uint8(msg), + Dst: dst, + } + w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + b.Fatal(err) + } + var ccs constraint.ConstraintSystem + b.Run("compile scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &hashToG2BenchCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + var buf bytes.Buffer + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("scs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + b.Run("solve scs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + b.Run("compile r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &hashToG2BenchCircuit{}); err != nil { + b.Fatal(err) + } + } + }) + buf.Reset() + _, err = ccs.WriteTo(&buf) + if err != nil { + b.Fatal(err) + } + b.Logf("r1cs size: %d (bytes), nb constraints %d, nbInstructions: %d", buf.Len(), ccs.GetNbConstraints(), ccs.GetNbInstructions()) + + b.Run("solve r1cs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ccs.Solve(w); err != nil { + b.Fatal(err) + } + } + }) + +} diff --git a/std/algebra/emulated/sw_bls12381/hints.go b/std/algebra/emulated/sw_bls12381/hints.go new file mode 100644 index 0000000000..6d6bbc4431 --- /dev/null +++ b/std/algebra/emulated/sw_bls12381/hints.go @@ -0,0 +1,44 @@ +package sw_bls12381 + +import ( + "fmt" + "math/big" + + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/math/emulated" +) + +func GetHints() []solver.Hint { + return []solver.Hint{ + sqrtRatioHint, + } +} + +func init() { + solver.RegisterHint(GetHints()...) +} + +func sqrtRatioHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 4 { + return fmt.Errorf("expecting 4 inputs") + } + if len(outputs) != 3 { + return fmt.Errorf("expecting 3 outputs") + } + + var z0, z1, u0, u1, v0, v1 fp.Element + u0.SetBigInt(inputs[0]) + u1.SetBigInt(inputs[1]) + v0.SetBigInt(inputs[2]) + v1.SetBigInt(inputs[3]) + + b := bls12381.G2SqrtRatio(&z0, &z1, &u0, &u1, &v0, &v1) + outputs[0].SetUint64(b) + z0.BigInt(outputs[1]) + z1.BigInt(outputs[2]) + return nil + }) +} diff --git a/std/hash/tofield/doc.go b/std/hash/tofield/doc.go new file mode 100644 index 0000000000..b1612201a7 --- /dev/null +++ b/std/hash/tofield/doc.go @@ -0,0 +1,3 @@ +// Package tofield provides ZKP circuits for expanding messages to field elements, according to +// RFC9380 (section 5.3.1). +package tofield diff --git a/std/hash/tofield/expand.go b/std/hash/tofield/expand.go new file mode 100644 index 0000000000..f6d3e07e1a --- /dev/null +++ b/std/hash/tofield/expand.go @@ -0,0 +1,102 @@ +package tofield + +import ( + "errors" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/sha2" + "github.com/consensys/gnark/std/math/uints" +) + +const ( + block_size = 64 +) + +// ExpandMsgXmd expands msg to a slice of lenInBytes bytes according to RFC9380 (section 5.3.1) +// Spec: https://datatracker.ietf.org/doc/html/rfc9380#name-expand_message_xmd (hashutils.go) +// Implementation was adapted from gnark-crypto/field/hash.ExpandMsgXmd. +func ExpandMsgXmd(api frontend.API, msg []uints.U8, dst []byte, lenInBytes int) ([]uints.U8, error) { + h, e := sha2.New(api) + if e != nil { + return nil, e + } + + ell := (lenInBytes + h.Size() - 1) / h.Size() // ceil(len_in_bytes / b_in_bytes) + if ell > 255 { + return nil, errors.New("invalid lenInBytes") + } + if len(dst) > 255 { + return nil, errors.New("invalid domain size (>255 bytes)") + } + sizeDomain := uint8(len(dst)) + + dst_prime := make([]uints.U8, len(dst)+1) + copy(dst_prime, uints.NewU8Array(dst)) + dst_prime[len(dst)] = uints.NewU8(uint8(sizeDomain)) + + Z_pad_raw := make([]uint8, block_size) + Z_pad := uints.NewU8Array(Z_pad_raw) + h.Write(Z_pad) + h.Write(msg) + h.Write([]uints.U8{uints.NewU8(uint8(lenInBytes >> 8)), uints.NewU8(uint8(lenInBytes)), uints.NewU8(0)}) + h.Write(dst_prime) + b0 := h.Sum() + + h, e = sha2.New(api) + if e != nil { + return nil, e + } + h.Write(b0) + h.Write([]uints.U8{uints.NewU8(1)}) + h.Write(dst_prime) + b1 := h.Sum() + + res := make([]uints.U8, lenInBytes) + copy(res[:h.Size()], b1) + + for i := 2; i <= ell; i++ { + h, e = sha2.New(api) + if e != nil { + return nil, e + } + + // b_i = H(strxor(b₀, b_(i - 1)) ∥ I2OSP(i, 1) ∥ DST_prime) + strxor := make([]uints.U8, h.Size()) + for j := 0; j < h.Size(); j++ { + strxor[j], e = xor(api, b0[j], b1[j]) + if e != nil { + return res, e + } + } + h.Write(strxor) + h.Write([]uints.U8{uints.NewU8(uint8(i))}) + h.Write(dst_prime) + b1 = h.Sum() + copy(res[h.Size()*(i-1):min(h.Size()*i, len(res))], b1) + } + + return res, nil +} + +func xor(api frontend.API, a, b uints.U8) (uints.U8, error) { + aBits := api.ToBinary(a.Val, 8) + bBits := api.ToBinary(b.Val, 8) + cBits := make([]frontend.Variable, 8) + + for i := 0; i < 8; i++ { + cBits[i] = api.Xor(aBits[i], bBits[i]) + } + + uapi, err := uints.New[uints.U32](api) + if err != nil { + return uints.NewU8(255), err + } + return uapi.ByteValueOf(api.FromBinary(cBits...)), nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/std/hash/tofield/expand_test.go b/std/hash/tofield/expand_test.go new file mode 100644 index 0000000000..1df5271c59 --- /dev/null +++ b/std/hash/tofield/expand_test.go @@ -0,0 +1,158 @@ +package tofield + +import ( + "encoding/hex" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + "github.com/consensys/gnark/test" +) + +type expandMsgXmdCircuit struct { + Msg []uints.U8 + Dst []uint8 + Len int + Expected []uints.U8 +} + +type expandMsgXmdTestCase struct { + msg string + lenInBytes int + uniformBytesHex string +} + +func (c *expandMsgXmdCircuit) Define(api frontend.API) error { + uapi, err := uints.New[uints.U32](api) + if err != nil { + return err + } + expanded, err := ExpandMsgXmd(api, c.Msg, c.Dst, c.Len) + if err != nil { + return err + } + + for i := 0; i < c.Len; i++ { + uapi.ByteAssertEq(expanded[i], c.Expected[i]) + } + + return nil +} + +// adapted from gnark-crypto/field/hash/hashutils_test.go +func TestExpandMsgXmd(t *testing.T) { + //name := "expand_message_xmd" + dst := "QUUX-V01-CS02-with-expander-SHA256-128" + //hash := "SHA256" + //k := 128 + + testCases := []expandMsgXmdTestCase{ + { + "", + 0x20, + "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", + }, + + { + "abc", + 0x20, + "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615", + }, + + { + "abcdef0123456789", + 0x20, + "eff31487c770a893cfb36f912fbfcbff40d5661771ca4b2cb4eafe524333f5c1", + }, + + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x20, + "b23a1d2b4d97b2ef7785562a7e8bac7eed54ed6e97e29aa51bfe3f12ddad1ff9", + }, + + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x20, + "4623227bcc01293b8c130bf771da8c298dede7383243dc0993d2d94823958c4c", + }, + { + "", + 0x80, + "af84c27ccfd45d41914fdff5df25293e221afc53d8ad2ac06d5e3e29485dadbee0d121587713a3e0dd4d5e69e93eb7cd4f5df4cd103e188cf60cb02edc3edf18eda8576c412b18ffb658e3dd6ec849469b979d444cf7b26911a08e63cf31f9dcc541708d3491184472c2c29bb749d4286b004ceb5ee6b9a7fa5b646c993f0ced", + }, + { + "", + 0x20, + "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", + }, + { + "abc", + 0x80, + "abba86a6129e366fc877aab32fc4ffc70120d8996c88aee2fe4b32d6c7b6437a647e6c3163d40b76a73cf6a5674ef1d890f95b664ee0afa5359a5c4e07985635bbecbac65d747d3d2da7ec2b8221b17b0ca9dc8a1ac1c07ea6a1e60583e2cb00058e77b7b72a298425cd1b941ad4ec65e8afc50303a22c0f99b0509b4c895f40", + }, + { + "abcdef0123456789", + 0x80, + "ef904a29bffc4cf9ee82832451c946ac3c8f8058ae97d8d629831a74c6572bd9ebd0df635cd1f208e2038e760c4994984ce73f0d55ea9f22af83ba4734569d4bc95e18350f740c07eef653cbb9f87910d833751825f0ebefa1abe5420bb52be14cf489b37fe1a72f7de2d10be453b2c9d9eb20c7e3f6edc5a60629178d9478df", + }, + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x80, + "80be107d0884f0d881bb460322f0443d38bd222db8bd0b0a5312a6fedb49c1bbd88fd75d8b9a09486c60123dfa1d73c1cc3169761b17476d3c6b7cbbd727acd0e2c942f4dd96ae3da5de368d26b32286e32de7e5a8cb2949f866a0b80c58116b29fa7fabb3ea7d520ee603e0c25bcaf0b9a5e92ec6a1fe4e0391d1cdbce8c68a", + }, + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x80, + "546aff5444b5b79aa6148bd81728704c32decb73a3ba76e9e75885cad9def1d06d6792f8a7d12794e90efed817d96920d728896a4510864370c207f99bd4a608ea121700ef01ed879745ee3e4ceef777eda6d9e5e38b90c86ea6fb0b36504ba4a45d22e86f6db5dd43d98a294bebb9125d5b794e9d2a81181066eb954966a487", + }, + //test cases not in the standard + { + "", + 0x30, + "3808e9bb0ade2df3aa6f1b459eb5058a78142f439213ddac0c97dcab92ae5a8408d86b32bbcc87de686182cbdf65901f", + }, + { + "abc", + 0x30, + "2b877f5f0dfd881405426c6b87b39205ef53a548b0e4d567fc007cb37c6fa1f3b19f42871efefca518ac950c27ac4e28", + }, + { + "abcdef0123456789", + 0x30, + "226da1780b06e59723714f80da9a63648aebcfc1f08e0db87b5b4d16b108da118214c1450b0e86f9cefeb44903fd3aba", + }, + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x30, + "12b23ae2e888f442fd6d0d85d90a0d7ed5337d38113e89cdc7c22db91bd0abaec1023e9a8f0ef583a111104e2f8a0637", + }, + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x30, + "1aaee90016547a85ab4dc55e4f78a364c2e239c0e58b05753453c63e6e818334005e90d9ce8f047bddab9fbb315f8722", + }, + } + + for _, testCase := range testCases { + uniformBytes := make([]uint8, len(testCase.uniformBytesHex)>>1) + hex.Decode(uniformBytes, []uint8(testCase.uniformBytesHex)) + witness := expandMsgXmdCircuit{ + Msg: uints.NewU8Array([]uint8(testCase.msg)), + Dst: []uint8(dst), + Len: testCase.lenInBytes, + Expected: uints.NewU8Array(uniformBytes), + } + circuit := expandMsgXmdCircuit{ + Msg: uints.NewU8Array(make([]uint8, len(testCase.msg))), + Dst: []uint8(dst), + Len: testCase.lenInBytes, + Expected: uints.NewU8Array(uniformBytes), + } + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } + } +} diff --git a/std/hints.go b/std/hints.go index 709c5e9b01..33149e9ef2 100644 --- a/std/hints.go +++ b/std/hints.go @@ -7,6 +7,7 @@ import ( "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" "github.com/consensys/gnark/std/algebra/emulated/fields_bw6761" + "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/algebra/native/fields_bls12377" "github.com/consensys/gnark/std/algebra/native/fields_bls24315" @@ -51,6 +52,7 @@ func registerHints() { solver.RegisterHint(fields_bls24315.GetHints()...) // emulated curves solver.RegisterHint(sw_emulated.GetHints()...) + solver.RegisterHint(sw_bls12381.GetHints()...) // native curves solver.RegisterHint(sw_bls12377.GetHints()...) solver.RegisterHint(sw_bls24315.GetHints()...) diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 3cac83ea2c..107f91265d 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -148,25 +148,25 @@ func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { return f.api.Or(res0, resP) } -// // Cmp returns: -// // - -1 if a < b -// // - 0 if a = b -// // - 1 if a > b -// // -// // The method internally reduces the element and asserts that the value is less -// // than the modulus. -// func (f *Field[T]) Cmp(a, b *Element[T]) frontend.Variable { -// ca := f.Reduce(a) -// f.AssertIsInRange(ca) -// cb := f.Reduce(b) -// f.AssertIsInRange(cb) -// var res frontend.Variable = 0 -// for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { -// lmbCmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) -// res = f.api.Select(f.api.IsZero(res), lmbCmp, res) -// } -// return res -// } +// Cmp returns: +// - -1 if a < b +// - 0 if a = b +// - 1 if a > b +// +// The method internally reduces the element and asserts that the value is less +// than the modulus. +func (f *Field[T]) Cmp(a, b *Element[T]) frontend.Variable { + ca := f.Reduce(a) + f.AssertIsInRange(ca) + cb := f.Reduce(b) + f.AssertIsInRange(cb) + var res frontend.Variable = 0 + for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { + lmbCmp := f.api.Cmp(ca.Limbs[i], cb.Limbs[i]) + res = f.api.Select(f.api.IsZero(res), lmbCmp, res) + } + return res +} // TODO(@ivokub) // func (f *Field[T]) AssertIsDifferent(a, b *Element[T]) { diff --git a/std/math/uints/uint8.go b/std/math/uints/uint8.go index 42b30856f6..461c22f0b9 100644 --- a/std/math/uints/uint8.go +++ b/std/math/uints/uint8.go @@ -24,10 +24,12 @@ package uints import ( "fmt" + "math" "math/bits" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/internal/logderivprecomp" + "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/bitslice" "github.com/consensys/gnark/std/rangecheck" ) @@ -168,6 +170,37 @@ func (bf *BinaryField[T]) ByteValueOf(a frontend.Variable) U8 { return U8{Val: a, internal: true} } +// Convert any varialbe to bits first then to U8 array +// Note that if expectedLen is shorter than actual value, the converted value is *not* +// equal to the original value! +// TODO optimization +func (bf *BinaryField[T]) ByteArrayValueOf(a frontend.Variable, expectedLen ...int) []U8 { + var opt bits.BaseConversionOption + var bs []frontend.Variable + if len(expectedLen) == 1 { + opt = bits.WithNbDigits(expectedLen[0] * 8) + bs = bits.ToBinary(bf.api, a, opt) + } else { + bs = bits.ToBinary(bf.api, a) + } + + lenBits := len(bs) + lenBytes := int(math.Ceil(float64(lenBits) / 8.0)) + + ret := make([]U8, lenBytes) + for i := 0; i < lenBytes; i++ { + b := bs[i*8] + for j := 1; j < 8 && i*8+j < lenBits; j++ { + v := bs[i*8+j] + v = bf.api.Mul(v, 1<