From 581c332cffa71c4e6cb7de22416f592a3839c27b Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 20 Nov 2024 12:47:10 +0000 Subject: [PATCH 01/21] perf: add fast path for zero const multiplication --- std/math/emulated/field_mul.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index a1947cb7c..4617db966 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -192,6 +192,10 @@ func (mc *mulCheck[T]) cleanEvaluations() { // mulMod returns a*b mod r. In practice it computes the result using a hint and // defers the actual multiplication check. func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { + // fast path - if one of the inputs is on zero limbs (it is zero), then the result is also zero + if len(a.Limbs) == 0 || len(b.Limbs) == 0 { + return f.Zero() + } f.enforceWidthConditional(a) f.enforceWidthConditional(b) f.enforceWidthConditional(p) From a36ee9717a675b40a9ec8f36c48103e2c60421e9 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Sun, 24 Nov 2024 22:50:37 +0000 Subject: [PATCH 02/21] chore: move tinyfield to smallfield package --- backend/witness/vector.go | 2 +- backend/witness/witness.go | 2 +- constraint/core.go | 2 +- constraint/tinyfield/coeff.go | 2 +- constraint/tinyfield/r1cs_test.go | 2 +- constraint/tinyfield/solver.go | 2 +- constraint/tinyfield/system.go | 2 +- frontend/cs/r1cs/builder.go | 2 +- frontend/cs/scs/builder.go | 2 +- internal/generator/backend/main.go | 2 +- .../backend/template/imports.go.tmpl | 6 +- internal/smallfields/circuits_test.go | 66 +++++++++++++++++++ internal/{ => smallfields}/tinyfield/arith.go | 0 internal/{ => smallfields}/tinyfield/doc.go | 0 .../{ => smallfields}/tinyfield/element.go | 0 .../tinyfield/element_ops_purego.go | 0 .../tinyfield/element_test.go | 0 .../{ => smallfields}/tinyfield/vector.go | 0 .../tinyfield/vector_test.go | 0 test/solver_test.go | 2 +- 20 files changed, 80 insertions(+), 14 deletions(-) create mode 100644 internal/smallfields/circuits_test.go rename internal/{ => smallfields}/tinyfield/arith.go (100%) rename internal/{ => smallfields}/tinyfield/doc.go (100%) rename internal/{ => smallfields}/tinyfield/element.go (100%) rename internal/{ => smallfields}/tinyfield/element_ops_purego.go (100%) rename internal/{ => smallfields}/tinyfield/element_test.go (100%) rename internal/{ => smallfields}/tinyfield/vector.go (100%) rename internal/{ => smallfields}/tinyfield/vector_test.go (100%) diff --git a/backend/witness/vector.go b/backend/witness/vector.go index 248e293ab..94f309a80 100644 --- a/backend/witness/vector.go +++ b/backend/witness/vector.go @@ -13,7 +13,7 @@ import ( fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" ) diff --git a/backend/witness/witness.go b/backend/witness/witness.go index 5436fc75d..210a6b3bb 100644 --- a/backend/witness/witness.go +++ b/backend/witness/witness.go @@ -56,7 +56,7 @@ import ( fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/smallfields/tinyfield" ) var ErrInvalidWitness = errors.New("invalid witness") diff --git a/constraint/core.go b/constraint/core.go index 5d5ff2de4..a5dd9ed26 100644 --- a/constraint/core.go +++ b/constraint/core.go @@ -11,7 +11,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "github.com/consensys/gnark/profile" diff --git a/constraint/tinyfield/coeff.go b/constraint/tinyfield/coeff.go index 73f8c0c5b..48a3af341 100644 --- a/constraint/tinyfield/coeff.go +++ b/constraint/tinyfield/coeff.go @@ -23,7 +23,7 @@ import ( "github.com/consensys/gnark/internal/utils" "math/big" - fr "github.com/consensys/gnark/internal/tinyfield" + fr "github.com/consensys/gnark/internal/smallfields/tinyfield" ) // CoeffTable ensure we store unique coefficients in the constraint system diff --git a/constraint/tinyfield/r1cs_test.go b/constraint/tinyfield/r1cs_test.go index 425f6e3f2..7b678816d 100644 --- a/constraint/tinyfield/r1cs_test.go +++ b/constraint/tinyfield/r1cs_test.go @@ -30,7 +30,7 @@ import ( "github.com/consensys/gnark/constraint/tinyfield" - fr "github.com/consensys/gnark/internal/tinyfield" + fr "github.com/consensys/gnark/internal/smallfields/tinyfield" ) func TestSerialization(t *testing.T) { diff --git a/constraint/tinyfield/solver.go b/constraint/tinyfield/solver.go index 84e920572..a77dbc8d2 100644 --- a/constraint/tinyfield/solver.go +++ b/constraint/tinyfield/solver.go @@ -31,7 +31,7 @@ import ( "sync" "sync/atomic" - fr "github.com/consensys/gnark/internal/tinyfield" + fr "github.com/consensys/gnark/internal/smallfields/tinyfield" ) // solver represent the state of the solver during a call to System.Solve(...) diff --git a/constraint/tinyfield/system.go b/constraint/tinyfield/system.go index e8671cea2..a736422d3 100644 --- a/constraint/tinyfield/system.go +++ b/constraint/tinyfield/system.go @@ -27,7 +27,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" - fr "github.com/consensys/gnark/internal/tinyfield" + fr "github.com/consensys/gnark/internal/smallfields/tinyfield" ) type R1CS = system diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index bb0ac3c50..e4ca2b71f 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -31,7 +31,7 @@ import ( "github.com/consensys/gnark/internal/circuitdefer" "github.com/consensys/gnark/internal/frontendtype" "github.com/consensys/gnark/internal/kvstore" - "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index db029b225..eaeaf544c 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -30,7 +30,7 @@ import ( "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/internal/circuitdefer" "github.com/consensys/gnark/internal/kvstore" - "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 38ed29bc5..391f8bf72 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -62,7 +62,7 @@ func main() { CurveID: "BW6_633", } tiny_field := templateData{ - RootPath: "../../../internal/tinyfield/", + RootPath: "../../../internal/smallfields/tinyfield/", CSPath: "../../../constraint/tinyfield", Curve: "tinyfield", CurveID: "UNKNOWN", diff --git a/internal/generator/backend/template/imports.go.tmpl b/internal/generator/backend/template/imports.go.tmpl index 8cab3b4c3..aa9c942dd 100644 --- a/internal/generator/backend/template/imports.go.tmpl +++ b/internal/generator/backend/template/imports.go.tmpl @@ -1,6 +1,6 @@ {{- define "import_fr" }} {{- if eq .Curve "tinyfield"}} - fr "github.com/consensys/gnark/internal/tinyfield" + fr "github.com/consensys/gnark/internal/smallfields/tinyfield" {{- else}} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr" {{- end}} @@ -8,7 +8,7 @@ {{- define "import_fp" }} {{- if eq .Curve "tinyfield"}} - fr "github.com/consensys/gnark/internal/tinyfield" + fr "github.com/consensys/gnark/internal/smallfields/tinyfield" {{- else}} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fp" {{- end}} @@ -41,7 +41,7 @@ {{- define "import_witness" }} {{- if eq .Curve "tinyfield"}} - {{toLower .CurveID}}witness "github.com/consensys/gnark/internal/tinyfield/witness" + {{toLower .CurveID}}witness "github.com/consensys/gnark/internal/smallfields/tinyfield/witness" {{- else}} {{toLower .CurveID}}witness "github.com/consensys/gnark/internal/backend/{{toLower .Curve}}/witness" {{- end}} diff --git a/internal/smallfields/circuits_test.go b/internal/smallfields/circuits_test.go new file mode 100644 index 000000000..0a4cfc721 --- /dev/null +++ b/internal/smallfields/circuits_test.go @@ -0,0 +1,66 @@ +package smallfields + +import ( + "testing" + + "github.com/consensys/gnark-crypto/field/goldilocks" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type NativeCircuit struct { + A frontend.Variable `gnark:",public"` + B frontend.Variable `gnark:",secret"` +} + +func (circuit *NativeCircuit) Define(api frontend.API) error { + res := api.Mul(circuit.A, circuit.A) + api.AssertIsEqual(res, circuit.B) + return nil +} + +type EmulatedCircuit[T emulated.FieldParams] struct { + A emulated.Element[T] `gnark:",public"` + B emulated.Element[T] `gnark:",secret"` +} + +func (circuit *EmulatedCircuit[T]) Define(api frontend.API) error { + f, err := emulated.NewField[T](api) + if err != nil { + return err + } + res := f.Mul(&circuit.A, &circuit.A) + f.AssertIsEqual(res, &circuit.B) + return nil +} + +func TestNativeCircuit(t *testing.T) { + assert := test.NewAssert(t) + + err := test.IsSolved(&NativeCircuit{}, &NativeCircuit{A: 2, B: 4}, goldilocks.Modulus()) + assert.NoError(err) +} + +type smallBN struct { + emparams.BN254Fp +} + +func (smallBN) BitsPerLimb() uint { + return 16 +} + +func (smallBN) NbLimbs() uint { + return 16 +} + +func TestEmulatedCircuit(t *testing.T) { + assert := test.NewAssert(t) + + err := test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](2), B: emulated.ValueOf[emparams.BN254Fp](4)}, goldilocks.Modulus()) + assert.Error(err) + + err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)}, goldilocks.Modulus()) + assert.NoError(err) +} diff --git a/internal/tinyfield/arith.go b/internal/smallfields/tinyfield/arith.go similarity index 100% rename from internal/tinyfield/arith.go rename to internal/smallfields/tinyfield/arith.go diff --git a/internal/tinyfield/doc.go b/internal/smallfields/tinyfield/doc.go similarity index 100% rename from internal/tinyfield/doc.go rename to internal/smallfields/tinyfield/doc.go diff --git a/internal/tinyfield/element.go b/internal/smallfields/tinyfield/element.go similarity index 100% rename from internal/tinyfield/element.go rename to internal/smallfields/tinyfield/element.go diff --git a/internal/tinyfield/element_ops_purego.go b/internal/smallfields/tinyfield/element_ops_purego.go similarity index 100% rename from internal/tinyfield/element_ops_purego.go rename to internal/smallfields/tinyfield/element_ops_purego.go diff --git a/internal/tinyfield/element_test.go b/internal/smallfields/tinyfield/element_test.go similarity index 100% rename from internal/tinyfield/element_test.go rename to internal/smallfields/tinyfield/element_test.go diff --git a/internal/tinyfield/vector.go b/internal/smallfields/tinyfield/vector.go similarity index 100% rename from internal/tinyfield/vector.go rename to internal/smallfields/tinyfield/vector.go diff --git a/internal/tinyfield/vector_test.go b/internal/smallfields/tinyfield/vector_test.go similarity index 100% rename from internal/tinyfield/vector_test.go rename to internal/smallfields/tinyfield/vector_test.go diff --git a/test/solver_test.go b/test/solver_test.go index 88e0d39b9..152e8039a 100644 --- a/test/solver_test.go +++ b/test/solver_test.go @@ -19,7 +19,7 @@ import ( "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/internal/backend/circuits" "github.com/consensys/gnark/internal/kvstore" - "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" ) From e5a6f6196e559afbe2b0ab3e0e7e5d4c04f8b8d6 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Sun, 24 Nov 2024 23:08:30 +0000 Subject: [PATCH 03/21] feat: auto-generate fields when requested --- internal/generator/backend/main.go | 56 +++++++++++++++++++----------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 391f8bf72..53eb3e33c 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -62,21 +62,22 @@ func main() { CurveID: "BW6_633", } tiny_field := templateData{ - RootPath: "../../../internal/smallfields/tinyfield/", - CSPath: "../../../constraint/tinyfield", - Curve: "tinyfield", - CurveID: "UNKNOWN", - noBackend: true, - NoGKR: true, + RootPath: "../../../internal/smallfields/tinyfield/", + CSPath: "../../../constraint/tinyfield", + Curve: "tinyfield", + CurveID: "UNKNOWN", + noBackend: true, + NoGKR: true, + autoGenerateField: "0x2f", } - - // autogenerate tinyfield - tinyfieldConf, err := config.NewFieldConfig("tinyfield", "Element", "0x2f", false) - if err != nil { - panic(err) - } - if err := generator.GenerateFF(tinyfieldConf, tiny_field.RootPath, "", ""); err != nil { - panic(err) + baby_bear_field := templateData{ + RootPath: "../../../internal/smallfields/babybear/", + CSPath: "../../../constraint/babybear", + Curve: "babybear", + CurveID: "UNKNOWN", + noBackend: true, + NoGKR: true, + autoGenerateField: "0x78000001", } datas := []templateData{ @@ -88,6 +89,7 @@ func main() { bls24_317, bw6_633, tiny_field, + baby_bear_field, } const importCurve = "../imports.go.tmpl" @@ -100,6 +102,16 @@ func main() { go func(d templateData) { defer wg.Done() + // auto-generate small fields + if d.autoGenerateField != "" { + conf, err := config.NewFieldConfig(d.Curve, "Element", d.autoGenerateField, false) + if err != nil { + panic(err) + } + if err := generator.GenerateFF(conf, d.RootPath, "", ""); err != nil { + panic(err) + } + } var ( groth16Dir = strings.Replace(d.RootPath, "{?}", "groth16", 1) @@ -128,7 +140,7 @@ func main() { } // gkr backend - if d.Curve != "tinyfield" { + if !d.NoGKR { entries = []bavard.Entry{{File: filepath.Join(csDir, "gkr.go"), Templates: []string{"gkr.go.tmpl", importCurve}}} if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { panic(err) @@ -218,10 +230,12 @@ func main() { } type templateData struct { - RootPath string - CSPath string - Curve string - CurveID string - noBackend bool - NoGKR bool + RootPath string + CSPath string + Curve string + CurveID string + + autoGenerateField string + noBackend bool + NoGKR bool } From 721b89f977048f022c537aa90de20b2dd0c59ac5 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 25 Nov 2024 11:28:06 +0000 Subject: [PATCH 04/21] feat: generated babybear field --- constraint/babybear/coeff.go | 230 ++ constraint/babybear/marshal.go | 101 + constraint/babybear/r1cs_test.go | 186 ++ constraint/babybear/solver.go | 648 +++++ constraint/babybear/system.go | 304 +++ internal/smallfields/babybear/arith.go | 60 + internal/smallfields/babybear/doc.go | 53 + internal/smallfields/babybear/element.go | 1059 ++++++++ .../babybear/element_ops_purego.go | 127 + internal/smallfields/babybear/element_test.go | 2208 +++++++++++++++++ internal/smallfields/babybear/vector.go | 340 +++ internal/smallfields/babybear/vector_test.go | 365 +++ 12 files changed, 5681 insertions(+) create mode 100644 constraint/babybear/coeff.go create mode 100644 constraint/babybear/marshal.go create mode 100644 constraint/babybear/r1cs_test.go create mode 100644 constraint/babybear/solver.go create mode 100644 constraint/babybear/system.go create mode 100644 internal/smallfields/babybear/arith.go create mode 100644 internal/smallfields/babybear/doc.go create mode 100644 internal/smallfields/babybear/element.go create mode 100644 internal/smallfields/babybear/element_ops_purego.go create mode 100644 internal/smallfields/babybear/element_test.go create mode 100644 internal/smallfields/babybear/vector.go create mode 100644 internal/smallfields/babybear/vector_test.go diff --git a/constraint/babybear/coeff.go b/constraint/babybear/coeff.go new file mode 100644 index 000000000..79a51b786 --- /dev/null +++ b/constraint/babybear/coeff.go @@ -0,0 +1,230 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "encoding/binary" + "errors" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/utils" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/babybear/fr" +) + +// CoeffTable ensure we store unique coefficients in the constraint system +type CoeffTable struct { + Coefficients []fr.Element + mCoeffs map[fr.Element]uint32 // maps coefficient to coeffID +} + +func newCoeffTable(capacity int) CoeffTable { + r := CoeffTable{ + Coefficients: make([]fr.Element, 5, 5+capacity), + mCoeffs: make(map[fr.Element]uint32, capacity), + } + + r.Coefficients[constraint.CoeffIdZero].SetUint64(0) + r.Coefficients[constraint.CoeffIdOne].SetOne() + r.Coefficients[constraint.CoeffIdTwo].SetUint64(2) + r.Coefficients[constraint.CoeffIdMinusOne].SetInt64(-1) + r.Coefficients[constraint.CoeffIdMinusTwo].SetInt64(-2) + + return r + +} + +func (ct *CoeffTable) toBytes() []byte { + buf := make([]byte, 0, 8+len(ct.Coefficients)*fr.Bytes) + ctLen := uint64(len(ct.Coefficients)) + + buf = binary.LittleEndian.AppendUint64(buf, ctLen) + for _, c := range ct.Coefficients { + for _, w := range c { + buf = binary.LittleEndian.AppendUint64(buf, w) + } + } + + return buf +} + +func (ct *CoeffTable) fromBytes(buf []byte) error { + if len(buf) < 8 { + return errors.New("invalid buffer size") + } + ctLen := binary.LittleEndian.Uint64(buf[:8]) + buf = buf[8:] + + if uint64(len(buf)) < ctLen*fr.Bytes { + return errors.New("invalid buffer size") + } + ct.Coefficients = make([]fr.Element, ctLen) + for i := uint64(0); i < ctLen; i++ { + var c fr.Element + k := int(i) * fr.Bytes + for j := 0; j < fr.Limbs; j++ { + c[j] = binary.LittleEndian.Uint64(buf[k+j*8 : k+(j+1)*8]) + } + ct.Coefficients[i] = c + } + return nil +} + +func (ct *CoeffTable) AddCoeff(coeff constraint.Element) uint32 { + c := (*fr.Element)(coeff[:]) + var cID uint32 + if c.IsZero() { + cID = constraint.CoeffIdZero + } else if c.IsOne() { + cID = constraint.CoeffIdOne + } else if c.Equal(&two) { + cID = constraint.CoeffIdTwo + } else if c.Equal(&minusOne) { + cID = constraint.CoeffIdMinusOne + } else if c.Equal(&minusTwo) { + cID = constraint.CoeffIdMinusTwo + } else { + cc := *c + if id, ok := ct.mCoeffs[cc]; ok { + cID = id + } else { + cID = uint32(len(ct.Coefficients)) + ct.Coefficients = append(ct.Coefficients, cc) + ct.mCoeffs[cc] = cID + } + } + return cID +} + +func (ct *CoeffTable) MakeTerm(coeff constraint.Element, variableID int) constraint.Term { + cID := ct.AddCoeff(coeff) + return constraint.Term{VID: uint32(variableID), CID: cID} +} + +// CoeffToString implements constraint.Resolver +func (ct *CoeffTable) CoeffToString(cID int) string { + return ct.Coefficients[cID].String() +} + +// implements constraint.Field +type field struct{} + +var _ constraint.Field = &field{} + +var ( + two fr.Element + minusOne fr.Element + minusTwo fr.Element +) + +func init() { + minusOne.SetOne() + minusOne.Neg(&minusOne) + two.SetOne() + two.Double(&two) + minusTwo.Neg(&two) +} + +func (engine *field) FromInterface(i interface{}) constraint.Element { + var e fr.Element + if _, err := e.SetInterface(i); err != nil { + // need to clean that --> some code path are dissimilar + // for example setting a fr.Element from an fp.Element + // fails with the above but succeeds through big int... (2-chains) + b := utils.FromInterface(i) + e.SetBigInt(&b) + } + var r constraint.Element + copy(r[:], e[:]) + return r +} +func (engine *field) ToBigInt(c constraint.Element) *big.Int { + e := (*fr.Element)(c[:]) + r := new(big.Int) + e.BigInt(r) + return r + +} +func (engine *field) Mul(a, b constraint.Element) constraint.Element { + _a := (*fr.Element)(a[:]) + _b := (*fr.Element)(b[:]) + _a.Mul(_a, _b) + return a +} + +func (engine *field) Add(a, b constraint.Element) constraint.Element { + _a := (*fr.Element)(a[:]) + _b := (*fr.Element)(b[:]) + _a.Add(_a, _b) + return a +} +func (engine *field) Sub(a, b constraint.Element) constraint.Element { + _a := (*fr.Element)(a[:]) + _b := (*fr.Element)(b[:]) + _a.Sub(_a, _b) + return a +} +func (engine *field) Neg(a constraint.Element) constraint.Element { + e := (*fr.Element)(a[:]) + e.Neg(e) + return a + +} +func (engine *field) Inverse(a constraint.Element) (constraint.Element, bool) { + if a.IsZero() { + return a, false + } + e := (*fr.Element)(a[:]) + if e.IsZero() { + return a, false + } else if e.IsOne() { + return a, true + } + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + + e.Inverse(e) + return a, true +} + +func (engine *field) IsOne(a constraint.Element) bool { + e := (*fr.Element)(a[:]) + return e.IsOne() +} + +func (engine *field) One() constraint.Element { + e := fr.One() + var r constraint.Element + copy(r[:], e[:]) + return r +} + +func (engine *field) String(a constraint.Element) string { + e := (*fr.Element)(a[:]) + return e.String() +} + +func (engine *field) Uint64(a constraint.Element) (uint64, bool) { + e := (*fr.Element)(a[:]) + if !e.IsUint64() { + return 0, false + } + return e.Uint64(), true +} diff --git a/constraint/babybear/marshal.go b/constraint/babybear/marshal.go new file mode 100644 index 000000000..ea13e57da --- /dev/null +++ b/constraint/babybear/marshal.go @@ -0,0 +1,101 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/blang/semver/v4" +) + +// WriteTo encodes R1CS into provided io.Writer using cbor +func (cs *system) WriteTo(w io.Writer) (int64, error) { + b, err := cs.System.ToBytes() + if err != nil { + return 0, err + } + + c := cs.CoeffTable.toBytes() + + totalLen := uint64(len(b) + len(c)) + gnarkVersion := semver.MustParse(cs.GnarkVersion) + // write totalLen, gnarkVersion.Major, gnarkVersion.Minor, gnarkVersion.Patch using + // binary.LittleEndian + if err := binary.Write(w, binary.LittleEndian, totalLen); err != nil { + return 0, err + } + if err := binary.Write(w, binary.LittleEndian, gnarkVersion.Major); err != nil { + return 0, err + } + if err := binary.Write(w, binary.LittleEndian, gnarkVersion.Minor); err != nil { + return 0, err + } + if err := binary.Write(w, binary.LittleEndian, gnarkVersion.Patch); err != nil { + return 0, err + } + + // write the system + n, err := w.Write(b) + if err != nil { + return int64(n), err + } + + // write the coeff table + m, err := w.Write(c) + return int64(n+m) + 4*8, err +} + +// ReadFrom attempts to decode R1CS from io.Reader using cbor +func (cs *system) ReadFrom(r io.Reader) (int64, error) { + var totalLen uint64 + if err := binary.Read(r, binary.LittleEndian, &totalLen); err != nil { + return 0, err + } + + var major, minor, patch uint64 + if err := binary.Read(r, binary.LittleEndian, &major); err != nil { + return 0, err + } + if err := binary.Read(r, binary.LittleEndian, &minor); err != nil { + return 0, err + } + if err := binary.Read(r, binary.LittleEndian, &patch); err != nil { + return 0, err + } + // TODO @gbotrel validate version, duplicate logic with core.go CheckSerializationHeader + if major != 0 || minor < 10 { + return 0, fmt.Errorf("unsupported gnark version %d.%d.%d", major, minor, patch) + } + + data := make([]byte, totalLen) + if _, err := io.ReadFull(r, data); err != nil { + return 0, err + } + n, err := cs.System.FromBytes(data) + if err != nil { + return 0, err + } + data = data[n:] + + if err := cs.CoeffTable.fromBytes(data); err != nil { + return 0, err + } + + return int64(totalLen) + 4*8, nil +} diff --git a/constraint/babybear/r1cs_test.go b/constraint/babybear/r1cs_test.go new file mode 100644 index 000000000..a99e33b36 --- /dev/null +++ b/constraint/babybear/r1cs_test.go @@ -0,0 +1,186 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs_test + +import ( + "bytes" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/internal/backend/circuits" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + cs "github.com/consensys/gnark/constraint/babybear" + + "github.com/consensys/gnark-crypto/ecc/babybear/fr" +) + +func TestSerialization(t *testing.T) { + + var buffer, buffer2 bytes.Buffer + + for name := range circuits.Circuits { + t.Run(name, func(t *testing.T) { + tc := circuits.Circuits[name] + + r1cs1, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) + if err != nil { + t.Fatal(err) + } + if testing.Short() && r1cs1.GetNbConstraints() > 50 { + return + } + + // compile a second time to ensure determinism + r1cs2, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, tc.Circuit) + if err != nil { + t.Fatal(err) + } + + { + buffer.Reset() + t.Log(name) + var err error + var written, read int64 + written, err = r1cs1.WriteTo(&buffer) + if err != nil { + t.Fatal(err) + } + var reconstructed cs.R1CS + read, err = reconstructed.ReadFrom(&buffer) + if err != nil { + t.Fatal(err) + } + if written != read { + t.Fatal("didn't read same number of bytes we wrote") + } + + // compare original and reconstructed + if diff := cmp.Diff(r1cs1, &reconstructed, + cmpopts.IgnoreFields(cs.R1CS{}, + "System.q", + "field", + "CoeffTable.mCoeffs", + "System.lbWireLevel", + "System.genericHint", + "System.SymbolTable", + "System.bitLen")); diff != "" { + t.Fatalf("round trip mismatch (-want +got):\n%s", diff) + } + } + + // ensure determinism in compilation / serialization / reconstruction + { + buffer.Reset() + n, err := r1cs1.WriteTo(&buffer) + if err != nil { + t.Fatal(err) + } + if n == 0 { + t.Fatal("No bytes are written") + } + + buffer2.Reset() + _, err = r1cs2.WriteTo(&buffer2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buffer.Bytes(), buffer2.Bytes()) { + t.Fatal("compilation of R1CS is not deterministic") + } + + var r, r2 cs.R1CS + n, err = r.ReadFrom(&buffer) + if err != nil { + t.Fatal(nil) + } + if n == 0 { + t.Fatal("No bytes are read") + } + _, err = r2.ReadFrom(&buffer2) + if err != nil { + t.Fatal(nil) + } + + if !reflect.DeepEqual(r, r2) { + t.Fatal("compilation of R1CS is not deterministic (reconstruction)") + } + } + }) + + } +} + +const n = 10000 + +type circuit struct { + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + for i := 0; i < n; i++ { + circuit.X = api.Add(api.Mul(circuit.X, circuit.X), circuit.X, 42) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func BenchmarkSolve(b *testing.B) { + + var w circuit + w.X = 1 + w.Y = 1 + witness, err := frontend.NewWitness(&w, fr.Modulus()) + if err != nil { + b.Fatal(err) + } + + b.Run("scs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), scs.NewBuilder, &c) + if err != nil { + b.Fatal(err) + } + b.Log("scs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + + b.Run("r1cs", func(b *testing.B) { + var c circuit + ccs, err := frontend.Compile(fr.Modulus(), r1cs.NewBuilder, &c, frontend.WithCompressThreshold(10)) + if err != nil { + b.Fatal(err) + } + b.Log("r1cs nbConstraints", ccs.GetNbConstraints()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ccs.IsSolved(witness) + } + }) + +} diff --git a/constraint/babybear/solver.go b/constraint/babybear/solver.go new file mode 100644 index 000000000..33ce08640 --- /dev/null +++ b/constraint/babybear/solver.go @@ -0,0 +1,648 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/field/pool" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/rs/zerolog" + "math" + "math/big" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/consensys/gnark-crypto/ecc/babybear/fr" +) + +// solver represent the state of the solver during a call to System.Solve(...) +type solver struct { + *system + + // values and solved are index by the wire (variable) id + values []fr.Element + solved []bool + nbSolved uint64 + + // maps hintID to hint function + mHintsFunctions map[csolver.HintID]csolver.Hint + + // used to out api.Println + logger zerolog.Logger + nbTasks int + + a, b, c fr.Vector // R1CS solver will compute the a,b,c matrices + + q *big.Int +} + +func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { + // parse options + opt, err := csolver.NewConfig(opts...) + if err != nil { + return nil, err + } + + // check witness size + witnessOffset := 0 + if cs.Type == constraint.SystemR1CS { + witnessOffset++ + } + + nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables + expectedWitnessSize := len(cs.Public) - witnessOffset + len(cs.Secret) + + if len(witness) != expectedWitnessSize { + return nil, fmt.Errorf("invalid witness size, got %d, expected %d", len(witness), expectedWitnessSize) + } + + // check all hints are there + hintFunctions := opt.HintFunctions + + // hintsDependencies is from compile time; it contains the list of hints the solver **needs** + var missing []string + for hintUUID, hintID := range cs.MHintsDependencies { + if _, ok := hintFunctions[hintUUID]; !ok { + missing = append(missing, hintID) + } + } + + if len(missing) > 0 { + return nil, fmt.Errorf("solver missing hint(s): %v", missing) + } + + s := solver{ + system: cs, + values: make([]fr.Element, nbWires), + solved: make([]bool, nbWires), + mHintsFunctions: hintFunctions, + logger: opt.Logger, + nbTasks: opt.NbTasks, + q: cs.Field(), + } + + // set the witness indexes as solved + if witnessOffset == 1 { + s.solved[0] = true // ONE_WIRE + s.values[0].SetOne() + } + copy(s.values[witnessOffset:], witness) + for i := range witness { + s.solved[i+witnessOffset] = true + } + + // keep track of the number of wire instantiations we do, for a post solve sanity check + // to ensure we instantiated all wires + s.nbSolved += uint64(len(witness) + witnessOffset) + + if s.Type == constraint.SystemR1CS { + n := ecc.NextPowerOfTwo(uint64(cs.GetNbConstraints())) + s.a = make(fr.Vector, cs.GetNbConstraints(), n) + s.b = make(fr.Vector, cs.GetNbConstraints(), n) + s.c = make(fr.Vector, cs.GetNbConstraints(), n) + } + + return &s, nil +} + +func (s *solver) set(id int, value fr.Element) { + if s.solved[id] { + panic("solving the same wire twice should never happen.") + } + s.values[id] = value + s.solved[id] = true + atomic.AddUint64(&s.nbSolved, 1) +} + +// computeTerm computes coeff*variable +func (s *solver) computeTerm(t constraint.Term) fr.Element { + cID, vID := t.CoeffID(), t.WireID() + + if t.IsConstant() { + return s.Coefficients[cID] + } + + if cID != 0 && !s.solved[vID] { + panic("computing a term with an unsolved wire") + } + + switch cID { + case constraint.CoeffIdZero: + return fr.Element{} + case constraint.CoeffIdOne: + return s.values[vID] + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + return res + case constraint.CoeffIdMinusOne: + var res fr.Element + res.Neg(&s.values[vID]) + return res + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + return res + } +} + +// r += (t.coeff*t.value) +// TODO @gbotrel check t.IsConstant on the caller side when necessary +func (s *solver) accumulateInto(t constraint.Term, r *fr.Element) { + cID := t.CoeffID() + vID := t.WireID() + + if t.IsConstant() { + r.Add(r, &s.Coefficients[cID]) + return + } + + switch cID { + case constraint.CoeffIdZero: + return + case constraint.CoeffIdOne: + r.Add(r, &s.values[vID]) + case constraint.CoeffIdTwo: + var res fr.Element + res.Double(&s.values[vID]) + r.Add(r, &res) + case constraint.CoeffIdMinusOne: + r.Sub(r, &s.values[vID]) + default: + var res fr.Element + res.Mul(&s.Coefficients[cID], &s.values[vID]) + r.Add(r, &res) + } +} + +// solveWithHint executes a hint and assign the result to its defined outputs. +func (s *solver) solveWithHint(h *constraint.HintMapping) error { + // ensure hint function was provided + f, ok := s.mHintsFunctions[h.HintID] + if !ok { + return errors.New("missing hint function") + } + + // tmp IO big int memory + nbInputs := len(h.Inputs) + nbOutputs := int(h.OutputRange.End - h.OutputRange.Start) + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbOutputs; i++ { + outputs[i] = pool.BigInt.Get() + outputs[i].SetUint64(0) + } + + q := pool.BigInt.Get() + q.Set(s.q) + + for i := 0; i < nbInputs; i++ { + var v fr.Element + for _, term := range h.Inputs[i] { + if term.IsConstant() { + v.Add(&v, &s.Coefficients[term.CoeffID()]) + continue + } + s.accumulateInto(term, &v) + } + inputs[i] = pool.BigInt.Get() + v.BigInt(inputs[i]) + } + + err := f(q, inputs, outputs) + + var v fr.Element + for i := range outputs { + v.SetBigInt(outputs[i]) + s.set(int(h.OutputRange.Start)+i, v) + pool.BigInt.Put(outputs[i]) + } + + for i := range inputs { + pool.BigInt.Put(inputs[i]) + } + + pool.BigInt.Put(q) + + return err +} + +func (s *solver) printLogs(logs []constraint.LogEntry) { + if s.logger.GetLevel() == zerolog.Disabled { + return + } + + for i := 0; i < len(logs); i++ { + logLine := s.logValue(logs[i]) + s.logger.Debug().Str(zerolog.CallerFieldName, logs[i].Caller).Msg(logLine) + } +} + +const unsolvedVariable = "" + +func (s *solver) logValue(log constraint.LogEntry) string { + var toResolve []interface{} + var ( + eval fr.Element + missingValue bool + ) + for j := 0; j < len(log.ToResolve); j++ { + // before eval le + + missingValue = false + eval.SetZero() + + for _, t := range log.ToResolve[j] { + // for each term in the linear expression + + cID, vID := t.CoeffID(), t.WireID() + if t.IsConstant() { + // just add the constant + eval.Add(&eval, &s.Coefficients[cID]) + continue + } + + if !s.solved[vID] { + missingValue = true + break // stop the loop we can't evaluate. + } + + tv := s.computeTerm(t) + eval.Add(&eval, &tv) + } + + // after + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + + } + if len(log.Stack) > 0 { + var sbb strings.Builder + for _, lID := range log.Stack { + location := s.SymbolTable.Locations[lID] + function := s.SymbolTable.Functions[location.FunctionID] + + sbb.WriteString(function.Name) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(function.Filename) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(int(location.Line))) + sbb.WriteByte('\n') + } + toResolve = append(toResolve, sbb.String()) + } + return fmt.Sprintf(log.Format, toResolve...) +} + +// divByCoeff sets res = res / t.Coeff +func (solver *solver) divByCoeff(res *fr.Element, cID uint32) { + switch cID { + case constraint.CoeffIdOne: + return + case constraint.CoeffIdMinusOne: + res.Neg(res) + case constraint.CoeffIdZero: + panic("division by 0") + default: + // this is slow, but shouldn't happen as divByCoeff is called to + // remove the coeff of an unsolved wire + // but unsolved wires are (in gnark frontend) systematically set with a coeff == 1 or -1 + res.Div(res, &solver.Coefficients[cID]) + } +} + +// Implement constraint.Solver +func (s *solver) GetValue(cID, vID uint32) constraint.Element { + var r constraint.Element + e := s.computeTerm(constraint.Term{CID: cID, VID: vID}) + copy(r[:], e[:]) + return r +} +func (s *solver) GetCoeff(cID uint32) constraint.Element { + var r constraint.Element + copy(r[:], s.Coefficients[cID][:]) + return r +} +func (s *solver) SetValue(vID uint32, f constraint.Element) { + s.set(int(vID), *(*fr.Element)(f[:])) +} + +func (s *solver) IsSolved(vID uint32) bool { + return s.solved[vID] +} + +// Read interprets input calldata as either a LinearExpression (if R1CS) or a Term (if Plonkish), +// evaluates it and return the result and the number of uint32 word read. +func (s *solver) Read(calldata []uint32) (constraint.Element, int) { + if s.Type == constraint.SystemSparseR1CS { + if calldata[0] != 1 { + panic("invalid calldata") + } + return s.GetValue(calldata[1], calldata[2]), 3 + } + var r fr.Element + n := int(calldata[0]) + j := 1 + for k := 0; k < n; k++ { + // we read k Terms + s.accumulateInto(constraint.Term{CID: calldata[j], VID: calldata[j+1]}, &r) + j += 2 + } + + var ret constraint.Element + copy(ret[:], r[:]) + return ret, j +} + +// processInstruction decodes the instruction and execute blueprint-defined logic. +// an instruction can encode a hint, a custom constraint or a generic constraint. +func (solver *solver) processInstruction(pi constraint.PackedInstruction, scratch *scratch) error { + // fetch the blueprint + blueprint := solver.Blueprints[pi.BlueprintID] + inst := pi.Unpack(&solver.System) + cID := inst.ConstraintOffset // here we have 1 constraint in the instruction only + + if solver.Type == constraint.SystemR1CS { + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + // TODO @gbotrel we use the solveR1C method for now, having user-defined + // blueprint for R1CS would require constraint.Solver interface to add methods + // to set a,b,c since it's more efficient to compute these while we solve. + bc.DecompressR1C(&scratch.tR1C, inst) + return solver.solveR1C(cID, &scratch.tR1C) + } + } + + // blueprint declared "I know how to solve this." + if bc, ok := blueprint.(constraint.BlueprintSolvable); ok { + if err := bc.Solve(solver, inst); err != nil { + return solver.wrapErrWithDebugInfo(cID, err) + } + return nil + } + + // blueprint encodes a hint, we execute. + // TODO @gbotrel may be worth it to move hint logic in blueprint "solve" + if bc, ok := blueprint.(constraint.BlueprintHint); ok { + bc.DecompressHint(&scratch.tHint, inst) + return solver.solveWithHint(&scratch.tHint) + } + + return nil +} + +// run runs the solver. it return an error if a constraint is not satisfied or if not all wires +// were instantiated. +func (solver *solver) run() error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 // TODO @gbotrel revisit that with blocks. + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + var wg sync.WaitGroup + chTasks := make(chan []uint32, solver.nbTasks) + chError := make(chan error, solver.nbTasks) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < solver.nbTasks; i++ { + go func() { + var scratch scratch + for t := range chTasks { + for _, i := range t { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + var scratch scratch + + // for each level, we push the tasks + for _, level := range solver.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 || solver.nbTasks == 1 { + // we do it sequentially + for _, i := range level { + if err := solver.processInstruction(solver.Instructions[i], &scratch); err != nil { + return err + } + } + continue + } + + // number of tasks for this level is set to number of CPU + // but if we don't have enough work for all our CPU, it can be lower. + nbTasks := solver.nbTasks + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + if int(solver.nbSolved) != len(solver.values) { + return errors.New("solver didn't assign a value to all wires") + } + + return nil +} + +// solveR1C compute unsolved wires in the constraint, if any and set the solver accordingly +// +// returns an error if the solver called a hint function that errored +// returns false, nil if there was no wire to solve +// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that +// the constraint is satisfied later. +func (solver *solver) solveR1C(cID uint32, r *constraint.R1C) error { + a, b, c := &solver.a[cID], &solver.b[cID], &solver.c[cID] + + // the index of the non-zero entry shows if L, R or O has an uninstantiated wire + // the content is the ID of the wire non instantiated + var loc uint8 + + var termToCompute constraint.Term + + processLExp := func(l constraint.LinearExpression, val *fr.Element, locValue uint8) { + for _, t := range l { + vID := t.WireID() + + // wire is already computed, we just accumulate in val + if solver.solved[vID] { + solver.accumulateInto(t, val) + continue + } + + if loc != 0 { + panic("found more than one wire to instantiate") + } + termToCompute = t + loc = locValue + } + } + + processLExp(r.L, a, 1) + processLExp(r.R, b, 2) + processLExp(r.O, c, 3) + + if loc == 0 { + // there is nothing to solve, may happen if we have an assertion + // (ie a constraints that doesn't yield any output) + // or if we solved the unsolved wires with hint functions + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + return nil + } + + // we compute the wire value and instantiate it + wID := termToCompute.WireID() + + // solver result + var wire fr.Element + + switch loc { + case 1: + if !b.IsZero() { + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) + } else { + // we didn't actually ensure that a * b == c + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 2: + if !a.IsZero() { + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) + } else { + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return solver.wrapErrWithDebugInfo(cID, fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())) + } + } + case 3: + wire.Mul(a, b). + Sub(&wire, c) + + c.Add(c, &wire) + } + + // wire is the term (coeff * value) + // but in the solver we want to store the value only + // note that in gnark frontend, coeff here is always 1 or -1 + solver.divByCoeff(&wire, termToCompute.CID) + solver.set(wID, wire) + + return nil +} + +// UnsatisfiedConstraintError wraps an error with useful metadata on the unsatisfied constraint +type UnsatisfiedConstraintError struct { + Err error + CID int // constraint ID + DebugInfo *string // optional debug info +} + +func (r *UnsatisfiedConstraintError) Error() string { + if r.DebugInfo != nil { + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, *r.DebugInfo) + } + return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) +} + +func (solver *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { + var debugInfo *string + if dID, ok := solver.MDebug[int(cID)]; ok { + debugInfo = new(string) + *debugInfo = solver.logValue(solver.DebugInfo[dID]) + } + return &UnsatisfiedConstraintError{CID: int(cID), Err: err, DebugInfo: debugInfo} +} + +// temporary variables to avoid memallocs in hotloop +type scratch struct { + tR1C constraint.R1C + tHint constraint.HintMapping +} diff --git a/constraint/babybear/system.go b/constraint/babybear/system.go new file mode 100644 index 000000000..6e98dfbbd --- /dev/null +++ b/constraint/babybear/system.go @@ -0,0 +1,304 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "io" + "time" + + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + csolver "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/logger" + + "github.com/consensys/gnark-crypto/ecc" + + "github.com/consensys/gnark-crypto/ecc/babybear/fr" +) + +type R1CS = system +type SparseR1CS = system + +// system is a curved-typed constraint.System with a concrete coefficient table (fr.Element) +type system struct { + constraint.System + CoeffTable + field +} + +// NewR1CS is a constructor for R1CS. It is meant to be use by gnark frontend only, +// and should not be used by gnark users. See groth16.NewCS(...) instead. +func NewR1CS(capacity int) *R1CS { + return newSystem(capacity, constraint.SystemR1CS) +} + +// NewSparseR1CS is a constructor for SparseR1CS. It is meant to be use by gnark frontend only, +// and should not be used by gnark users. See plonk.NewCS(...) instead. +func NewSparseR1CS(capacity int) *SparseR1CS { + return newSystem(capacity, constraint.SystemSparseR1CS) +} + +func newSystem(capacity int, t constraint.SystemType) *system { + return &system{ + System: constraint.NewSystem(fr.Modulus(), capacity, t), + CoeffTable: newCoeffTable(capacity / 10), + } +} + +// Solve solves the constraint system with provided witness. +// If it's a R1CS returns R1CSSolution +// If it's a SparseR1CS returns SparseR1CSSolution +func (cs *system) Solve(witness witness.Witness, opts ...csolver.Option) (any, error) { + log := logger.Logger().With().Int("nbConstraints", cs.GetNbConstraints()).Logger() + start := time.Now() + + v := witness.Vector().(fr.Vector) + + // init the solver + solver, err := newSolver(cs, v, opts...) + if err != nil { + log.Err(err).Send() + return nil, err + } + + // reset the stateful blueprints + for i := range cs.Blueprints { + if b, ok := cs.Blueprints[i].(constraint.BlueprintStateful); ok { + b.Reset() + } + } + + // defer log printing once all solver.values are computed + // (or sooner, if a constraint is not satisfied) + defer solver.printLogs(cs.Logs) + + // run it. + if err := solver.run(); err != nil { + log.Err(err).Send() + return nil, err + } + + log.Debug().Dur("took", time.Since(start)).Msg("constraint system solver done") + + // format the solution + // TODO @gbotrel revisit post-refactor + if cs.Type == constraint.SystemR1CS { + var res R1CSSolution + res.W = solver.values + res.A = solver.a + res.B = solver.b + res.C = solver.c + return &res, nil + } else { + // sparse R1CS + var res SparseR1CSSolution + // query l, r, o in Lagrange basis, not blinded + res.L, res.R, res.O = evaluateLROSmallDomain(cs, solver.values) + + return &res, nil + } + +} + +// IsSolved +// Deprecated: use _, err := Solve(...) instead +func (cs *system) IsSolved(witness witness.Witness, opts ...csolver.Option) error { + _, err := cs.Solve(witness, opts...) + return err +} + +// GetR1Cs return the list of R1C +func (cs *system) GetR1Cs() []constraint.R1C { + toReturn := make([]constraint.R1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintR1C); ok { + var r1c constraint.R1C + bc.DecompressR1C(&r1c, inst.Unpack(&cs.System)) + toReturn = append(toReturn, r1c) + } + } + return toReturn +} + +// GetNbCoefficients return the number of unique coefficients needed in the R1CS +func (cs *system) GetNbCoefficients() int { + return len(cs.Coefficients) +} + +// CurveID returns curve ID as defined in gnark-crypto +func (cs *system) CurveID() ecc.ID { + return ecc.UNKNOWN +} + +func (cs *system) GetCoefficient(i int) (r constraint.Element) { + copy(r[:], cs.Coefficients[i][:]) + return +} + +// GetSparseR1Cs return the list of SparseR1C +func (cs *system) GetSparseR1Cs() []constraint.SparseR1C { + + toReturn := make([]constraint.SparseR1C, 0, cs.GetNbConstraints()) + + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + var sparseR1C constraint.SparseR1C + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + toReturn = append(toReturn, sparseR1C) + } + } + return toReturn +} + +// evaluateLROSmallDomain extracts the solver l, r, o, and returns it in lagrange form. +// solver = [ public | secret | internal ] +// TODO @gbotrel refactor; this seems to be a small util function for plonk +func evaluateLROSmallDomain(cs *system, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { + + //s := int(pk.Domain[0].Cardinality) + s := cs.GetNbConstraints() + len(cs.Public) // len(spr.Public) is for the placeholder constraints + s = int(ecc.NextPowerOfTwo(uint64(s))) + + var l, r, o []fr.Element + l = make([]fr.Element, s, s+4) // +4 to leave room for the blinding in plonk + r = make([]fr.Element, s, s+4) + o = make([]fr.Element, s, s+4) + s0 := solution[0] + + for i := 0; i < len(cs.Public); i++ { // placeholders + l[i] = solution[i] + r[i] = s0 + o[i] = s0 + } + offset := len(cs.Public) + nbConstraints := cs.GetNbConstraints() + + var sparseR1C constraint.SparseR1C + j := 0 + for _, inst := range cs.Instructions { + blueprint := cs.Blueprints[inst.BlueprintID] + if bc, ok := blueprint.(constraint.BlueprintSparseR1C); ok { + bc.DecompressSparseR1C(&sparseR1C, inst.Unpack(&cs.System)) + + l[offset+j] = solution[sparseR1C.XA] + r[offset+j] = solution[sparseR1C.XB] + o[offset+j] = solution[sparseR1C.XC] + j++ + } + } + + offset += nbConstraints + + for i := 0; i < s-offset; i++ { // offset to reach 2**n constraints (where the id of l,r,o is 0, so we assign solver[0]) + l[offset+i] = s0 + r[offset+i] = s0 + o[offset+i] = s0 + } + + return l, r, o + +} + +// R1CSSolution represent a valid assignment to all the variables in the constraint system. +// The vector W such that Aw o Bw - Cw = 0 +type R1CSSolution struct { + W fr.Vector + A, B, C fr.Vector +} + +func (t *R1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.W.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.A.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.B.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.C.WriteTo(w) + n += a + return n, err +} + +func (t *R1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.W.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.A.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.B.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.C.ReadFrom(r) + n += a + return n, err +} + +// SparseR1CSSolution represent a valid assignment to all the variables in the constraint system. +type SparseR1CSSolution struct { + L, R, O fr.Vector +} + +func (t *SparseR1CSSolution) WriteTo(w io.Writer) (int64, error) { + n, err := t.L.WriteTo(w) + if err != nil { + return n, err + } + a, err := t.R.WriteTo(w) + n += a + if err != nil { + return n, err + } + a, err = t.O.WriteTo(w) + n += a + return n, err + +} + +func (t *SparseR1CSSolution) ReadFrom(r io.Reader) (int64, error) { + n, err := t.L.ReadFrom(r) + if err != nil { + return n, err + } + a, err := t.R.ReadFrom(r) + n += a + if err != nil { + return n, err + } + a, err = t.O.ReadFrom(r) + n += a + return n, err +} + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/internal/smallfields/babybear/arith.go b/internal/smallfields/babybear/arith.go new file mode 100644 index 000000000..3dfd7e5ff --- /dev/null +++ b/internal/smallfields/babybear/arith.go @@ -0,0 +1,60 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import ( + "math/bits" +) + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/internal/smallfields/babybear/doc.go b/internal/smallfields/babybear/doc.go new file mode 100644 index 000000000..65011bef3 --- /dev/null +++ b/internal/smallfields/babybear/doc.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package babybear contains field arithmetic operations for modulus = 0x78000001. +// +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication) +// +// The modulus is hardcoded in all the operations. +// +// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: +// +// type Element [1]uint64 +// +// # Usage +// +// Example API signature: +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element +// +// and can be used like so: +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus q = +// +// q[base10] = 2013265921 +// q[base16] = 0x78000001 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +package babybear diff --git a/internal/smallfields/babybear/element.go b/internal/smallfields/babybear/element.go new file mode 100644 index 000000000..ec8175f7c --- /dev/null +++ b/internal/smallfields/babybear/element.go @@ -0,0 +1,1059 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "io" + "math/big" + "math/bits" + "reflect" + "strconv" + "strings" + + "github.com/bits-and-blooms/bitset" + "github.com/consensys/gnark-crypto/field/hash" + "github.com/consensys/gnark-crypto/field/pool" +) + +// Element represents a field element stored on 1 words (uint64) +// +// Element are assumed to be in Montgomery form in all methods. +// +// Modulus q = +// +// q[base10] = 2013265921 +// q[base16] = 0x78000001 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +type Element [1]uint64 + +const ( + Limbs = 1 // number of 64 bits words needed to represent a Element + Bits = 31 // number of bits needed to represent a Element + Bytes = 8 // number of bytes needed to represent a Element +) + +// Field modulus q +const ( + q0 uint64 = 2013265921 + q uint64 = q0 +) + +var qElement = Element{ + q0, +} + +var _modulus big.Int // q stored as big.Int + +// Modulus returns q as a big.Int +// +// q[base10] = 2013265921 +// q[base16] = 0x78000001 +func Modulus() *big.Int { + return new(big.Int).Set(&_modulus) +} + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 14393504411089371135 + +func init() { + _modulus.SetString("78000001", 16) +} + +// NewElement returns a new Element from a uint64 value +// +// it is equivalent to +// +// var v Element +// v.SetUint64(...) +func NewElement(v uint64) Element { + z := Element{v} + z.Mul(&z, &rSquare) + return z +} + +// SetUint64 sets z to v and returns z +func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + *z = Element{v} + return z.Mul(z, &rSquare) // z.toMont() +} + +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + // absolute value of v + m := v >> 63 + z.SetUint64(uint64((v ^ m) - m)) + + if m != 0 { + // v is negative + z.Neg(z) + } + + return z +} + +// Set z = x and returns z +func (z *Element) Set(x *Element) *Element { + z[0] = x[0] + return z +} + +// SetInterface converts provided interface into Element +// returns an error if provided type is not supported +// supported types: +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte +func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + if i1 == nil { + return nil, errors.New("can't set babybear.Element with ") + } + + switch c1 := i1.(type) { + case Element: + return z.Set(&c1), nil + case *Element: + if c1 == nil { + return nil, errors.New("can't set babybear.Element with ") + } + return z.Set(c1), nil + case uint8: + return z.SetUint64(uint64(c1)), nil + case uint16: + return z.SetUint64(uint64(c1)), nil + case uint32: + return z.SetUint64(uint64(c1)), nil + case uint: + return z.SetUint64(uint64(c1)), nil + case uint64: + return z.SetUint64(c1), nil + case int8: + return z.SetInt64(int64(c1)), nil + case int16: + return z.SetInt64(int64(c1)), nil + case int32: + return z.SetInt64(int64(c1)), nil + case int64: + return z.SetInt64(c1), nil + case int: + return z.SetInt64(int64(c1)), nil + case string: + return z.SetString(c1) + case *big.Int: + if c1 == nil { + return nil, errors.New("can't set babybear.Element with ") + } + return z.SetBigInt(c1), nil + case big.Int: + return z.SetBigInt(&c1), nil + case []byte: + return z.SetBytes(c1), nil + default: + return nil, errors.New("can't set babybear.Element from type " + reflect.TypeOf(i1).String()) + } +} + +// SetZero z = 0 +func (z *Element) SetZero() *Element { + z[0] = 0 + return z +} + +// SetOne z = 1 (in Montgomery form) +func (z *Element) SetOne() *Element { + z[0] = 1172168163 + return z +} + +// Div z = x*y⁻¹ (mod q) +func (z *Element) Div(x, y *Element) *Element { + var yInv Element + yInv.Inverse(y) + z.Mul(x, &yInv) + return z +} + +// Equal returns z == x; constant-time +func (z *Element) Equal(x *Element) bool { + return z.NotEqual(x) == 0 +} + +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + return (z[0] ^ x[0]) +} + +// IsZero returns z == 0 +func (z *Element) IsZero() bool { + return (z[0]) == 0 +} + +// IsOne returns z == 1 +func (z *Element) IsOne() bool { + return z[0] == 1172168163 +} + +// IsUint64 reports whether z can be represented as an uint64. +func (z *Element) IsUint64() bool { + return true +} + +// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. +func (z *Element) Uint64() uint64 { + return z.Bits()[0] +} + +// FitsOnOneWord reports whether z words (except the least significant word) are 0 +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. +func (z *Element) FitsOnOneWord() bool { + return true +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + _z := z.Bits() + _x := x.Bits() + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 + + _z := z.Bits() + + var b uint64 + _, b = bits.Sub64(_z[0], 1006632961, 0) + + return b == 0 +} + +// SetRandom sets z to a uniform random value in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case, value of z is undefined. +func (z *Element) SetRandom() (*Element, error) { + // this code is generated for all modulus + // and derived from go/src/crypto/rand/util.go + + // l is number of limbs * 8; the number of bytes needed to reconstruct 1 uint64 + const l = 8 + + // bitLen is the maximum bit length needed to encode a value < q. + const bitLen = 31 + + // k is the maximum byte length needed to encode a value < q. + const k = (bitLen + 7) / 8 + + // b is the number of bits in the most significant byte of q-1. + b := uint(bitLen % 8) + if b == 0 { + b = 8 + } + + var bytes [l]byte + + for { + // note that bytes[k:l] is always 0 + if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { + return nil, err + } + + // Clear unused bits in in the most significant byte to increase probability + // that the candidate is < q. + bytes[k-1] &= uint8(int(1<> 1 + z[0] >>= 1 + +} + +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) fromMont() *Element { + fromMont(z) + return z +} + +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + z[0], _ = bits.Add64(x[0], y[0], 0) + if z[0] >= q { + z[0] -= q + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + if x[0]&(1<<63) == (1 << 63) { + // if highest bit is set, then we have a carry to x + x, we shift and subtract q + z[0] = (x[0] << 1) - q + } else { + // highest bit is not set, but x + x can still be >= q + z[0] = (x[0] << 1) + if z[0] >= q { + z[0] -= q + } + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + if b != 0 { + z[0] += q + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + z[0] = q - x[0] + return z +} + +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + return z +} + +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t [2]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + + t[1], D = bits.Add64(t[1], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + + t[0], C = bits.Add64(t[1], C, 0) + t[1], _ = bits.Add64(0, D, C) + + if t[1] != 0 { + // we need to reduce, we have a result on 2 words + z[0], _ = bits.Sub64(t[0], q0, 0) + return + } + + // copy t into z + z[0] = t[0] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + z[0] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := bitset.New(uint(len(a))) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes.Set(uint(i)) + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes.Test(uint(i)) { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + return bits.Len64(z[0]) +} + +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := hash.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + pool.BigInt.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = pool.BigInt.Get() + defer pool.BigInt.Put(e) + e.Neg(k) + } + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 663890614, +} + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) +} + +// String returns the decimal representation of z as generated by +// z.Text(10). +func (z *Element) String() string { + return z.Text(10) +} + +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[0:8], z[0]) + + return res.SetBytes(b[:]) +} + +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(zzNeg[0], base) + } + } + zz := z.Bits() + return strconv.FormatUint(zz[0], base) +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + +// ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.fromMont() + return z.toBigInt(res) +} + +// Bits provides access to z by returning its value as a little-endian [1]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [1]uint64 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := pool.BigInt.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + pool.BigInt.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 8-byte integer. +// If e is not a 8-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid babybear.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 < v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } + } + + return z.toMont() +} + +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. +func (z *Element) SetString(number string) (*Element, error) { + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + + return z, nil +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 8-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid babybear.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[0:8], e[0]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid babybear.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int +) + +func init() { + _bLegendreExponentElement, _ = new(big.Int).SetString("3c000000", 16) + const sqrtExponentElement = "7" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) +} + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) +func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.Exp(*z, _bLegendreExponentElement) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.Exp(*x, _bSqrtExponentElement) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = xˢ = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 1738020498, + } + r := uint64(27) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of xˢ + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) + } + if t.IsZero() { + return z.SetZero() + } + if !t.IsOne() { + // t != 1, we don't have a square root + return nil + } + for { + var m uint64 + t = b + + // for t != 1 + for !t.IsOne() { + t.Square(&t) + m++ + } + + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) (mod q) + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } + + g.Square(&t) + y.Mul(&y, &t) + b.Mul(&b, &g) + r = m + } +} + +// Inverse z = x⁻¹ (mod q) +// +// if x == 0, sets and returns z = x +func (z *Element) Inverse(x *Element) *Element { + // Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" + const q uint64 = q0 + if x.IsZero() { + z.SetZero() + return z + } + + var r, s, u, v uint64 + u = q + s = 663890614 // s = r² + r = 0 + v = x[0] + + var carry, borrow uint64 + + for (u != 1) && (v != 1) { + for v&1 == 0 { + v >>= 1 + if s&1 == 0 { + s >>= 1 + } else { + s, carry = bits.Add64(s, q, 0) + s >>= 1 + if carry != 0 { + s |= (1 << 63) + } + } + } + for u&1 == 0 { + u >>= 1 + if r&1 == 0 { + r >>= 1 + } else { + r, carry = bits.Add64(r, q, 0) + r >>= 1 + if carry != 0 { + r |= (1 << 63) + } + } + } + if v >= u { + v -= u + s, borrow = bits.Sub64(s, r, 0) + if borrow == 1 { + s += q + } + } else { + u -= v + r, borrow = bits.Sub64(r, s, 0) + if borrow == 1 { + r += q + } + } + } + + if u == 1 { + z[0] = r + } else { + z[0] = s + } + + return z +} diff --git a/internal/smallfields/babybear/element_ops_purego.go b/internal/smallfields/babybear/element_ops_purego.go new file mode 100644 index 000000000..0db661075 --- /dev/null +++ b/internal/smallfields/babybear/element_ops_purego.go @@ -0,0 +1,127 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + var y Element + y.SetUint64(3) + x.Mul(x, &y) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + var y Element + y.SetUint64(5) + x.Mul(x, &y) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y Element + y.SetUint64(13) + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): + // hi, lo := x * y + // m := (lo * qInvNeg) mod R + // (*) r := (hi * R + lo + m * q) / R + // reduce r if necessary + + // On the emphasized line, we get r = hi + (lo + m * q) / R + // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 + // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. + // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) + // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. + + var r uint64 + hi, lo := bits.Mul64(x[0], y[0]) + if lo != 0 { + hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow + } + m := lo * qInvNeg + hi2, _ := bits.Mul64(m, q) + r, carry := bits.Add64(hi2, hi, 0) + + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r + + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): + // hi, lo := x * y + // m := (lo * qInvNeg) mod R + // (*) r := (hi * R + lo + m * q) / R + // reduce r if necessary + + // On the emphasized line, we get r = hi + (lo + m * q) / R + // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 + // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. + // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) + // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. + + var r uint64 + hi, lo := bits.Mul64(x[0], x[0]) + if lo != 0 { + hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow + } + m := lo * qInvNeg + hi2, _ := bits.Mul64(m, q) + r, carry := bits.Add64(hi2, hi, 0) + + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r + + return z +} diff --git a/internal/smallfields/babybear/element_test.go b/internal/smallfields/babybear/element_test.go new file mode 100644 index 000000000..cb30fd82d --- /dev/null +++ b/internal/smallfields/babybear/element_test.go @@ -0,0 +1,2208 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "math/bits" + + "testing" + + "github.com/leanovate/gopter" + ggen "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function + +var benchResElement Element + +func BenchmarkElementSelect(b *testing.B) { + var x, y Element + x.SetRandom() + y.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Select(i%3, &x, &y) + } +} + +func BenchmarkElementSetRandom(b *testing.B) { + var x Element + x.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = x.SetRandom() + } +} + +func BenchmarkElementSetBytes(b *testing.B) { + var x Element + x.SetRandom() + bb := x.Bytes() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.SetBytes(bb[:]) + } + +} + +func BenchmarkElementMulByConstants(b *testing.B) { + b.Run("mulBy3", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy3(&benchResElement) + } + }) + b.Run("mulBy5", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy5(&benchResElement) + } + }) + b.Run("mulBy13", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} + +func BenchmarkElementButterfly(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 663890614, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 663890614, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{2}) + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + + { + a := qElement + a[0] = 0 + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + s := testValues[i] + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return a.smallerThanModulus() && a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementBytes(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) + + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + properties := gopter.NewProperties(parameters) + genA := gen() + properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) +} + +func TestElementMulByConstants(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { + return false + } + } + + return true + }, + genA, + )) + + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) + + b := a.element + b.Mul(&b, &constant) + + MulBy3(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) + + b := a.element + b.Mul(&b, &constant) + + MulBy5(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLegendre(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementButterflies(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDiv(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSquare(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementInverse(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + + properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) + s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. + formatValue := func(v int64) string { + var a big.Int + a.SetInt64(v) + a.Mod(&a, Modulus()) + const maxUint16 = 65535 + var aNeg big.Int + aNeg.Neg(&a).Mod(&aNeg, Modulus()) + if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { + return "-" + aNeg.Text(10) + } + return a.Text(10) + } + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g.element[0] %= (qElement[0] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g.element[0] %= (qElement[0] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element + + g = Element{ + genParams.NextUint64(), + } + + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } + } + + return g +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + + var carry uint64 + a[0], _ = bits.Add64(a[0], qElement[0], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/internal/smallfields/babybear/vector.go b/internal/smallfields/babybear/vector.go new file mode 100644 index 000000000..a404f074f --- /dev/null +++ b/internal/smallfields/babybear/vector.go @@ -0,0 +1,340 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector *Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { + // encode slice length + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { + return 0, err + } + + n := int64(4) + + var buf [Bytes]byte + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) + m, err := w.Write(buf[:]) + n += int64(m) + if err != nil { + return n, err + } + } + return n, nil +} + +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. +// The validation consist of checking that the elements are smaller than the modulus, and +// converting them to montgomery form. +func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + var z Element + for i := start; i < end; i++ { + // we have to set vector[i] + bstart := i * Bytes + bend := bstart + Bytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint64(b[0:8]) + + if !z.smallerThanModulus() { + atomic.AddUint64(&cptErrors, 1) + return + } + z.toMont() + (*vector)[i] = z + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + +// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + (*vector)[i], err = BigEndian.Element(&buf) + if err != nil { + return n, err + } + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j. +func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. +// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/internal/smallfields/babybear/vector_test.go b/internal/smallfields/babybear/vector_test.go new file mode 100644 index 000000000..abc2da2b0 --- /dev/null +++ b/internal/smallfields/babybear/vector_test.go @@ -0,0 +1,365 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import ( + "bytes" + "fmt" + "github.com/stretchr/testify/require" + "os" + "reflect" + "sort" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +func TestVectorSort(t *testing.T) { + assert := require.New(t) + + v := make(Vector, 3) + v[0].SetUint64(2) + v[1].SetUint64(3) + v[2].SetUint64(1) + + sort.Sort(v) + + assert.Equal("[1,2,3]", v.String()) +} + +func TestVectorRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 3) + v1[0].SetUint64(2) + v1[1].SetUint64(3) + v1[2].SetUint64(1) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func TestVectorEmptyRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 0) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr +} + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + mixer[0] %= (qElement[0] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + mixer[0] %= (qElement[0] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} From 1d1fbcbee284deb5f7f22f5cf441c86a157aafa3 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 25 Nov 2024 12:27:17 +0000 Subject: [PATCH 05/21] feat: automate small field imports --- internal/generator/backend/main.go | 10 ++++----- .../backend/template/imports.go.tmpl | 22 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 53eb3e33c..8daf2c1db 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -68,7 +68,7 @@ func main() { CurveID: "UNKNOWN", noBackend: true, NoGKR: true, - autoGenerateField: "0x2f", + AutoGenerateField: "0x2f", } baby_bear_field := templateData{ RootPath: "../../../internal/smallfields/babybear/", @@ -77,7 +77,7 @@ func main() { CurveID: "UNKNOWN", noBackend: true, NoGKR: true, - autoGenerateField: "0x78000001", + AutoGenerateField: "0x78000001", } datas := []templateData{ @@ -103,8 +103,8 @@ func main() { go func(d templateData) { defer wg.Done() // auto-generate small fields - if d.autoGenerateField != "" { - conf, err := config.NewFieldConfig(d.Curve, "Element", d.autoGenerateField, false) + if d.AutoGenerateField != "" { + conf, err := config.NewFieldConfig(d.Curve, "Element", d.AutoGenerateField, false) if err != nil { panic(err) } @@ -235,7 +235,7 @@ type templateData struct { Curve string CurveID string - autoGenerateField string + AutoGenerateField string noBackend bool NoGKR bool } diff --git a/internal/generator/backend/template/imports.go.tmpl b/internal/generator/backend/template/imports.go.tmpl index aa9c942dd..fc5b188bc 100644 --- a/internal/generator/backend/template/imports.go.tmpl +++ b/internal/generator/backend/template/imports.go.tmpl @@ -1,14 +1,14 @@ {{- define "import_fr" }} - {{- if eq .Curve "tinyfield"}} - fr "github.com/consensys/gnark/internal/smallfields/tinyfield" + {{- if ne .AutoGenerateField "" }} + fr "github.com/consensys/gnark/internal/smallfields/{{ .Curve }}" {{- else}} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr" {{- end}} {{- end }} {{- define "import_fp" }} - {{- if eq .Curve "tinyfield"}} - fr "github.com/consensys/gnark/internal/smallfields/tinyfield" + {{- if ne .AutoGenerateField "" }} + fp "github.com/consensys/gnark/internal/smallfields/{{ .Curve }}" {{- else}} "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fp" {{- end}} @@ -20,16 +20,16 @@ {{- end}} {{- define "import_curve" }} - {{- if ne .Curve "tinyfield"}} - curve "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}" - {{- else }} + {{- if ne .AutoGenerateField "" }} "github.com/consensys/gnark-crypto/ecc" + {{- else}} + curve "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}" {{- end}} {{- end }} {{- define "import_backend_cs" }} - {{- if eq .Curve "tinyfield"}} - "github.com/consensys/gnark/constraint/tinyfield" + {{- if ne .AutoGenerateField "" }} + cs "github.com/consensys/gnark/constraint/{{ .Curve }}" {{- else}} cs "github.com/consensys/gnark/constraint/{{toLower .Curve}}" {{- end}} @@ -40,8 +40,8 @@ {{- end }} {{- define "import_witness" }} - {{- if eq .Curve "tinyfield"}} - {{toLower .CurveID}}witness "github.com/consensys/gnark/internal/smallfields/tinyfield/witness" + {{- if ne .AutoGenerateField "" }} + {{toLower .CurveID}}witness "github.com/consensys/gnark/internal/smallfields/{{ .Curve }}/witness" {{- else}} {{toLower .CurveID}}witness "github.com/consensys/gnark/internal/backend/{{toLower .Curve}}/witness" {{- end}} From 2e1453f4c3f8c4c262706dead390c2fb6a1ab404 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 25 Nov 2024 12:30:39 +0000 Subject: [PATCH 06/21] chore: generate --- constraint/babybear/coeff.go | 2 +- constraint/babybear/r1cs_test.go | 2 +- constraint/babybear/solver.go | 2 +- constraint/babybear/system.go | 2 +- constraint/tinyfield/r1cs_test.go | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/constraint/babybear/coeff.go b/constraint/babybear/coeff.go index 79a51b786..528df7679 100644 --- a/constraint/babybear/coeff.go +++ b/constraint/babybear/coeff.go @@ -23,7 +23,7 @@ import ( "github.com/consensys/gnark/internal/utils" "math/big" - "github.com/consensys/gnark-crypto/ecc/babybear/fr" + fr "github.com/consensys/gnark/internal/smallfields/babybear" ) // CoeffTable ensure we store unique coefficients in the constraint system diff --git a/constraint/babybear/r1cs_test.go b/constraint/babybear/r1cs_test.go index a99e33b36..980647958 100644 --- a/constraint/babybear/r1cs_test.go +++ b/constraint/babybear/r1cs_test.go @@ -30,7 +30,7 @@ import ( cs "github.com/consensys/gnark/constraint/babybear" - "github.com/consensys/gnark-crypto/ecc/babybear/fr" + fr "github.com/consensys/gnark/internal/smallfields/babybear" ) func TestSerialization(t *testing.T) { diff --git a/constraint/babybear/solver.go b/constraint/babybear/solver.go index 33ce08640..f77278c85 100644 --- a/constraint/babybear/solver.go +++ b/constraint/babybear/solver.go @@ -31,7 +31,7 @@ import ( "sync" "sync/atomic" - "github.com/consensys/gnark-crypto/ecc/babybear/fr" + fr "github.com/consensys/gnark/internal/smallfields/babybear" ) // solver represent the state of the solver during a call to System.Solve(...) diff --git a/constraint/babybear/system.go b/constraint/babybear/system.go index 6e98dfbbd..1e6e65f54 100644 --- a/constraint/babybear/system.go +++ b/constraint/babybear/system.go @@ -27,7 +27,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/babybear/fr" + fr "github.com/consensys/gnark/internal/smallfields/babybear" ) type R1CS = system diff --git a/constraint/tinyfield/r1cs_test.go b/constraint/tinyfield/r1cs_test.go index 7b678816d..3196b20c5 100644 --- a/constraint/tinyfield/r1cs_test.go +++ b/constraint/tinyfield/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/tinyfield" + cs "github.com/consensys/gnark/constraint/tinyfield" fr "github.com/consensys/gnark/internal/smallfields/tinyfield" ) From 73ce430f6d8f1acf46803953d5dd1480babbd758 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 25 Nov 2024 14:28:46 +0000 Subject: [PATCH 07/21] feat: support babybear in builders --- backend/witness/vector.go | 19 +++++++++++++++++-- frontend/cs/r1cs/builder.go | 6 ++++++ frontend/cs/scs/builder.go | 6 ++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/backend/witness/vector.go b/backend/witness/vector.go index 94f309a80..3bbe17516 100644 --- a/backend/witness/vector.go +++ b/backend/witness/vector.go @@ -13,6 +13,7 @@ import ( fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark/internal/smallfields/babybear" "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" ) @@ -37,9 +38,11 @@ func newVector(field *big.Int, size int) (any, error) { default: if field.Cmp(tinyfield.Modulus()) == 0 { return make(tinyfield.Vector, size), nil - } else { - return nil, errors.New("unsupported modulus") } + if field.Cmp(babybear.Modulus()) == 0 { + return make(babybear.Vector, size), nil + } + return nil, errors.New("unsupported modulus") } } @@ -77,6 +80,10 @@ func newFrom(from any, n int) (any, error) { a := make(tinyfield.Vector, n) copy(a, wt) return a, nil + case babybear.Vector: + a := make(babybear.Vector, n) + copy(a, wt) + return a, nil default: return nil, errors.New("unsupported modulus") } @@ -155,6 +162,12 @@ func set(v any, index int, value any) error { } _, err := pv[index].SetInterface(value) return err + case babybear.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err default: panic("invalid input") } @@ -243,6 +256,8 @@ func resize(v any, n int) any { return make(fr_bw6633.Vector, n) case tinyfield.Vector: return make(tinyfield.Vector, n) + case babybear.Vector: + return make(babybear.Vector, n) default: panic("invalid input") } diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index e4ca2b71f..0a7d63efc 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -31,10 +31,12 @@ import ( "github.com/consensys/gnark/internal/circuitdefer" "github.com/consensys/gnark/internal/frontendtype" "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/internal/smallfields/babybear" "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" + babybearr1cs "github.com/consensys/gnark/constraint/babybear" bls12377r1cs "github.com/consensys/gnark/constraint/bls12-377" bls12381r1cs "github.com/consensys/gnark/constraint/bls12-381" bls24315r1cs "github.com/consensys/gnark/constraint/bls24-315" @@ -114,6 +116,10 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { builder.cs = tinyfieldr1cs.NewR1CS(config.Capacity) break } + if field.Cmp(babybear.Modulus()) == 0 { + builder.cs = babybearr1cs.NewR1CS(config.Capacity) + break + } panic("not implemented") } diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index eaeaf544c..a5420acce 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -30,10 +30,12 @@ import ( "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/internal/circuitdefer" "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/internal/smallfields/babybear" "github.com/consensys/gnark/internal/smallfields/tinyfield" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" + babybearr1cs "github.com/consensys/gnark/constraint/babybear" bls12377r1cs "github.com/consensys/gnark/constraint/bls12-377" bls12381r1cs "github.com/consensys/gnark/constraint/bls12-381" bls24315r1cs "github.com/consensys/gnark/constraint/bls24-315" @@ -112,6 +114,10 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder { b.cs = tinyfieldr1cs.NewSparseR1CS(config.Capacity) break } + if field.Cmp(babybear.Modulus()) == 0 { + b.cs = babybearr1cs.NewSparseR1CS(config.Capacity) + break + } panic("not implemented") } From a3b3814c16bcaf89480544ab96e3d35d9ba8af1a Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 3 Dec 2024 00:13:31 +0000 Subject: [PATCH 08/21] feat: add field-extension package --- std/math/fieldextension/default_extensions.go | 15 ++ std/math/fieldextension/fieldextension.go | 174 ++++++++++++++ .../fieldextension/fieldextension_test.go | 212 ++++++++++++++++++ std/math/fieldextension/option.go | 36 +++ 4 files changed, 437 insertions(+) create mode 100644 std/math/fieldextension/default_extensions.go create mode 100644 std/math/fieldextension/fieldextension.go create mode 100644 std/math/fieldextension/fieldextension_test.go create mode 100644 std/math/fieldextension/option.go diff --git a/std/math/fieldextension/default_extensions.go b/std/math/fieldextension/default_extensions.go new file mode 100644 index 000000000..de257316e --- /dev/null +++ b/std/math/fieldextension/default_extensions.go @@ -0,0 +1,15 @@ +package fieldextension + +var defaultExtensions = map[string][]int{ + "2013265921-default": {11, 0, 0, 0, 0, 0, 0, 0, 1}, // x^8 - 11 -- BabyBear field + "2013265921-8": {11, 0, 0, 0, 0, 0, 0, 0, 1}, // x^8 - 11 -- BabyBear field + "2013265921-4": {11, 0, 0, 0, 1}, // x^4 - 11 -- BabyBear field + + "2130706433-default": {3, 0, 0, 0, 0, 0, 0, 0, 1}, // x^8 - 3 -- KoalaBear field + "2130706433-8": {3, 0, 0, 0, 0, 0, 0, 0, 1}, // x^8 - 3 -- KoalaBear field + "2130706433-4": {3, 0, 0, 0, 1}, // x^4 - 3 -- KoalaBear field + + "18446744069414584321-default": {7, 0, 0, 0, 1}, // x^4 - 7 -- Goldilocks field + "18446744069414584321-4": {7, 0, 0, 0, 1}, // x^4 - 7 -- Goldilocks field + "18446744069414584321-2": {7, 0, 1}, // x^2 - 7 -- Goldilocks field +} diff --git a/std/math/fieldextension/fieldextension.go b/std/math/fieldextension/fieldextension.go new file mode 100644 index 000000000..3504384a9 --- /dev/null +++ b/std/math/fieldextension/fieldextension.go @@ -0,0 +1,174 @@ +// Package fieldextension provides operations over an extension field of the native field. +// +// The operations inside the circuit are performed in the native field. In case +// of small fields, we need to perform some operations over an extension field +// to achieve the required soundness level. This package provides some +// primitives to perform such operations. +package fieldextension + +import ( + "fmt" + "strconv" + + "github.com/consensys/gnark/frontend" +) + +type extensionType int + +const ( + minimal extensionType = iota // x^n + 1 + simple // x^n + d + generic // everything else +) + +type Extension struct { + api frontend.API + + extension []int // we expect the extension defining modulus to have small small coefficients + extensionType +} + +// NewExtension returns a new extension field object. +func NewExtension(api frontend.API, opts ...Option) (*Extension, error) { + cfg, err := newConfig(opts...) + if err != nil { + return nil, fmt.Errorf("apply options: %w", err) + } + // extension is provided + if cfg.extension != nil { + if cfg.extension[len(cfg.extension)-1] != 1 { + return nil, fmt.Errorf("last coefficient of the extension must be 1") + } + et := simple + if cfg.extension[0] == 1 { + et = minimal + } + for i := 1; i < len(cfg.extension)-1; i++ { + if cfg.extension[i] != 0 { + et = generic + break + } + } + return &Extension{api: api, extension: cfg.extension, extensionType: et}, nil + } + degree := "default" + if cfg.degree != -1 { + degree = strconv.Itoa(cfg.degree) + } + + extension, ok := defaultExtensions[fmt.Sprintf("%s-%s", api.Compiler().Field(), degree)] + if !ok { + return nil, fmt.Errorf("no default extension for native modulus and not explicit extension provided") + } + return &Extension{api: api, extension: extension, extensionType: simple}, nil +} + +type ExtensionVariable []frontend.Variable + +func (e *Extension) Reduce(a ExtensionVariable) []frontend.Variable { + if e.extensionType == generic { + // TODO: implement later + panic("not implemented") + } + if len(a) < len(e.extension) { + // no reduction needed + return a + } + // we don't want to change a in place + ret := make([]frontend.Variable, len(a)) + copy(ret, a) + for len(ret) >= len(e.extension) { + q := ret[len(e.extension)-1:] + if e.extensionType == simple { + // in case we have minimal extension, we don't need to multiply q by + // the extension + q = e.MulByElement(q, e.extension[0]) + } + commonLen := min(len(q), len(e.extension)-1) + for i := 0; i < commonLen; i++ { + ret[i] = e.api.Add(ret[i], q[i]) + } + for i := commonLen; i < len(q); i++ { + ret[i] = q[i] + } + ret = ret[:max(len(q), len(e.extension)-1)] + } + return ret +} + +func (e *Extension) Mul(a, b ExtensionVariable) ExtensionVariable { + ret := e.MulNoReduce(a, b) + return e.Reduce(ret) +} + +func (e *Extension) MulNoReduce(a, b ExtensionVariable) ExtensionVariable { + ret := make([]frontend.Variable, len(a)+len(b)-1) + for i := range ret { + ret[i] = 0 + } + for i := range a { + for j := range b { + ret[i+j] = e.api.Add(ret[i+j], e.api.Mul(a[i], b[j])) + } + } + return ret +} + +func (e *Extension) Add(a, b ExtensionVariable) ExtensionVariable { + commonLen := min(len(a), len(b)) + ret := make([]frontend.Variable, max(len(a), len(b))) + for i := 0; i < commonLen; i++ { + ret[i] = e.api.Add(a[i], b[i]) + } + for i := commonLen; i < len(a); i++ { + ret[i] = a[i] + } + for i := commonLen; i < len(b); i++ { + ret[i] = b[i] + } + return ret +} + +func (e *Extension) Sub(a, b ExtensionVariable) ExtensionVariable { + commonLen := min(len(a), len(b)) + ret := make([]frontend.Variable, max(len(a), len(b))) + for i := 0; i < commonLen; i++ { + ret[i] = e.api.Sub(a[i], b[i]) + } + for i := commonLen; i < len(a); i++ { + ret[i] = a[i] + } + for i := commonLen; i < len(b); i++ { + ret[i] = e.api.Neg(b[i]) + } + return ret +} + +func (e *Extension) Div(a, b ExtensionVariable) ExtensionVariable { + panic("not implemented") +} + +func (e *Extension) Inverse(a ExtensionVariable) ExtensionVariable { + panic("not implemented") +} + +func (e *Extension) MulByElement(a ExtensionVariable, b frontend.Variable) ExtensionVariable { + ret := make([]frontend.Variable, len(a)) + for i := range a { + ret[i] = e.api.Mul(a[i], b) + } + return ret +} + +func (e *Extension) AssertIsEqual(a, b ExtensionVariable) { + commonLen := min(len(a), len(b)) + for i := 0; i < commonLen; i++ { + e.api.AssertIsEqual(a[i], b[i]) + } + for i := commonLen; i < len(a); i++ { + e.api.AssertIsEqual(a[i], 0) + } + for i := commonLen; i < len(b); i++ { + e.api.AssertIsEqual(b[i], 0) + } +} diff --git a/std/math/fieldextension/fieldextension_test.go b/std/math/fieldextension/fieldextension_test.go new file mode 100644 index 000000000..6099fdf48 --- /dev/null +++ b/std/math/fieldextension/fieldextension_test.go @@ -0,0 +1,212 @@ +package fieldextension + +import ( + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/smallfields/babybear" + "github.com/consensys/gnark/test" +) + +type ReduceCircut struct { + Input []frontend.Variable + Reduced []frontend.Variable +} + +func (c *ReduceCircut) Define(api frontend.API) error { + e, err := NewExtension(api) + if err != nil { + return err + } + res := e.Reduce(c.Input) + e.AssertIsEqual(c.Reduced, res) + return nil +} + +func TestReduce(t *testing.T) { + assert := test.NewAssert(t) + for _, tc := range []struct { + input, reduced []int + }{ + {[]int{1467980320, 1137445292}, []int{1467980320, 1137445292, 0, 0, 0, 0, 0, 0}}, + {[]int{1906025257, 900972831, 355994451}, []int{1906025257, 900972831, 355994451, 0, 0, 0, 0, 0}}, + {[]int{1315269736, 1305411155, 1484949641, 1487157818}, []int{1315269736, 1305411155, 1484949641, 1487157818, 0, 0, 0, 0}}, + {[]int{930655562, 191916507, 245232235, 249903878, 1688769114}, []int{930655562, 191916507, 245232235, 249903878, 1688769114, 0, 0, 0}}, + {[]int{1900558240, 1034669852, 62012066, 1636768938, 1951223124, 800157949}, []int{1900558240, 1034669852, 62012066, 1636768938, 1951223124, 800157949, 0, 0}}, + {[]int{1506768621, 1188015241, 521233244, 464809937, 288133325, 339109914, 1107846641}, []int{1506768621, 1188015241, 521233244, 464809937, 288133325, 339109914, 1107846641, 0}}, + + {[]int{1467980320, 1137445292}, []int{1467980320, 1137445292}}, + {[]int{1906025257, 900972831, 355994451}, []int{1906025257, 900972831, 355994451}}, + {[]int{1315269736, 1305411155, 1484949641, 1487157818}, []int{1315269736, 1305411155, 1484949641, 1487157818}}, + {[]int{930655562, 191916507, 245232235, 249903878, 1688769114}, []int{930655562, 191916507, 245232235, 249903878, 1688769114}}, + {[]int{1900558240, 1034669852, 62012066, 1636768938, 1951223124, 800157949}, []int{1900558240, 1034669852, 62012066, 1636768938, 1951223124, 800157949}}, + {[]int{1506768621, 1188015241, 521233244, 464809937, 288133325, 339109914, 1107846641}, []int{1506768621, 1188015241, 521233244, 464809937, 288133325, 339109914, 1107846641}}, + + {[]int{1200147517, 527805146, 1459729161, 298883860, 1301476489, 186161068, 997795829, 257063407}, []int{1200147517, 527805146, 1459729161, 298883860, 1301476489, 186161068, 997795829, 257063407}}, + {[]int{1353990425, 388912686, 1299455585, 514865345, 286702144, 1363798779, 1209821622, 492855042, 1874476453}, []int{1840572198, 388912686, 1299455585, 514865345, 286702144, 1363798779, 1209821622, 492855042}}, + {[]int{953116276, 1677525413, 1330847726, 935325903, 367765685, 666819005, 1259969643, 141562180, 860612033, 1047773391}, []int{353519034, 1123437188, 1330847726, 935325903, 367765685, 666819005, 1259969643, 141562180}}, + {[]int{752482022, 867506333, 1723423219, 361328340, 1241112226, 476919145, 182725336, 1468842972, 551661607, 617211228, 590726493}, []int{780961936, 1617032078, 168350958, 361328340, 1241112226, 476919145, 182725336, 1468842972}}, + {[]int{1013674053, 1587348044, 1207155881, 1116555932, 478056632, 1288268012, 1451373934, 1796131301, 1248869310, 1814778483, 275039764, 1209127427}, []int{658375016, 1417252147, 206061443, 324096182, 478056632, 1288268012, 1451373934, 1796131301}}, + {[]int{1498948171, 1003564318, 56559900, 1147491866, 1124826785, 1729654757, 819256679, 1503546020, 1907968762, 427148195, 1043561808, 666209549, 911979384}, []int{340679422, 1675662621, 1469410183, 422733221, 1090270404, 1729654757, 819256679, 1503546020}}, + {[]int{681684962, 586522244, 2004199348, 221839431, 1345587360, 872049662, 1613021061, 1424383966, 558639729, 930888084, 1820774825, 932126772, 391304517, 456497888}, []int{786924218, 759961563, 1900063213, 408904318, 1623405205, 1866994588, 1613021061, 1424383966}}, + {[]int{1727523785, 1360879540, 1735135715, 148864715, 863920986, 1616761360, 984599128, 940289252, 1438501308, 1482499975, 491492725, 1154531434, 1826009304, 581938247, 1777373046}, []int{1444910805, 1562251897, 1101757927, 769114963, 817364120, 1978284314, 403043424, 940289252}}, + {[]int{877033825, 65655943, 178692522, 262720537, 1210970053, 1422087058, 1186945353, 1031788918, 1005627592, 1704405255, 1402192569, 1123357903, 932685461, 1556434039, 1727137317, 1394736528}, []int{1872607732, 694720459, 1509949334, 540061944, 1404180519, 423468198, 52796630, 267763358}}, + {[]int{1134977302, 1648475213, 477296831, 768465845, 11704608, 542135288, 1210705842, 1659979380, 1241159085, 1275966185, 1014704952, 526651747, 1562882049, 1877460597, 1053879427, 1280810580, 1520192977}, []int{1431017196, 1591241801, 1572721698, 521837299, 1097279779, 1061542645, 723784013, 1656034313}}, + {[]int{1555011839, 787815769, 182103774, 781857749, 471054316, 962296659, 1274815353, 1018400075, 1523982888, 1230566410, 1442405474, 639520646, 1147075418, 1454417950, 191443111, 424358887, 598662772, 1646007563}, []int{160052574, 84773776, 1955702541, 1776787092, 1009288388, 854766741, 1367423653, 1659815990}}, + {[]int{1779267173, 1543263614, 1661719435, 1900359777, 893592452, 1006662846, 1442338151, 1005387230, 1302908503, 1159753495, 1351425996, 1816874489, 1790280276, 1100455927, 955172652, 726156767, 566008880, 282490572, 1152386475}, []int{41166504, 163529167, 944692949, 1753319946, 454016278, 1032082517, 1882907718, 940047983}}, + {[]int{1002186586, 435254065, 1396145151, 1622293879, 174136851, 1037681320, 768511646, 675947902, 188213580, 1414683255, 1986089618, 1203380915, 1567490308, 36246439, 1651769094, 105294241, 1140394104, 1184467882, 986378799, 1295577854}, []int{131608080, 269375833, 1666351158, 496800993, 1310402871, 1436392149, 818578391, 1834184553}}, + } { + bb8 := make([]frontend.Variable, len(tc.input)) + for i := range tc.input { + bb8[i] = frontend.Variable(tc.input[i]) + } + bb8red := make([]frontend.Variable, len(tc.reduced)) + for i := range tc.reduced { + bb8red[i] = frontend.Variable(tc.reduced[i]) + } + err := test.IsSolved(&ReduceCircut{Input: make([]frontend.Variable, len(bb8)), Reduced: make([]frontend.Variable, len(bb8red))}, + &ReduceCircut{Input: bb8, Reduced: bb8red}, babybear.Modulus()) + assert.NoError(err) + } +} + +type AddCircuit struct { + A, B, C []frontend.Variable +} + +func (c *AddCircuit) Define(api frontend.API) error { + e, err := NewExtension(api) + if err != nil { + return err + } + res := e.Add(c.A, c.B) + e.AssertIsEqual(c.C, res) + return nil +} + +func TestAdd(t *testing.T) { + assert := test.NewAssert(t) + for _, tc := range []struct { + a, b, c []int + }{ + {[]int{1504941483, 528713979, 1590716977, 1030723568, 691448958, 45161890, 558331570, 1584182780, 884750304, 1178012232, 1236551897, 1743822194, 1102524691, 949136580, 968686988, 1807636110, 1419005839}, []int{1451698632, 267757499, 1153206782, 291258043, 1014114345, 588561574, 161218185, 1655775873, 115681370, 24609626, 1495418674}, []int{943374194, 796471478, 730657838, 1321981611, 1705563303, 633723464, 719549755, 1226692732, 1000431674, 1202621858, 718704650, 1743822194, 1102524691, 949136580, 968686988, 1807636110, 1419005839}}, + {[]int{1874099641, 1982653637, 1187310579, 475561226, 1692092055}, []int{1230758452, 1959413289, 1645368110, 432360750, 1418838351, 687464797, 1234750833, 1203209996, 888838337, 852882006, 356386082, 916503764, 1792596122, 1102186785, 1444663299}, []int{1091592172, 1928801005, 819412768, 907921976, 1097664485, 687464797, 1234750833, 1203209996, 888838337, 852882006, 356386082, 916503764, 1792596122, 1102186785, 1444663299}}, + {[]int{1795185782, 1766445854, 504379178, 1820376092, 137151794, 1064960087, 1759175291, 585123542, 1604030370, 1511659175, 916198528, 1166864589, 1699685308}, []int{805729528, 430124370, 1260617837, 297604025, 613457793, 20971739, 105513811}, []int{587649389, 183304303, 1764997015, 104714196, 750609587, 1085931826, 1864689102, 585123542, 1604030370, 1511659175, 916198528, 1166864589, 1699685308}}, + {[]int{370766974, 1556330301, 1310468525, 1225434398, 928378540, 435540789, 361405873, 1035503425, 545600368, 120758801, 1022518983, 1758884239, 1312473265, 1134254141}, []int{1297416854, 637666735, 325808988, 824671410}, []int{1668183828, 180731115, 1636277513, 36839887, 928378540, 435540789, 361405873, 1035503425, 545600368, 120758801, 1022518983, 1758884239, 1312473265, 1134254141}}, + {[]int{1891025422, 2009813156, 1706954798, 1389626918, 1029725850, 1487402244, 717521687, 10632936, 73787955, 744460996, 1457784272, 1484874357, 811933684, 1652886077, 1531756184, 753745186, 1714652400}, []int{457901568, 944951453, 169650164, 315210583, 1580068898, 1204321039, 648114211, 1202582296, 197451510, 734577008, 1745397260, 1991793135, 1515312634, 736548227, 1072265360, 1703339801, 5096947, 881796514}, []int{335661069, 941498688, 1876604962, 1704837501, 596528827, 678457362, 1365635898, 1213215232, 271239465, 1479038004, 1189915611, 1463401571, 313980397, 376168383, 590755623, 443819066, 1719749347, 881796514}}, + {[]int{1621710603, 226525544, 1202575715}, []int{1412522468, 178072249, 1954193329, 164698463, 2004081065, 1337457847, 1308872918}, []int{1020967150, 404597793, 1143503123, 164698463, 2004081065, 1337457847, 1308872918}}, + {[]int{409456482, 1543428783, 135589462, 1688687654, 1313059883, 348554791, 299198720, 1323721072, 1389838688, 822515643, 927970864, 1040608757, 1776611271, 1797713807, 712571504, 775475735, 1363147356, 787062335, 734743186, 334849816}, []int{1779280006, 450050841, 889363814, 1440765181, 1194153487, 1482798286, 28525033, 743091086, 1967868359, 423958824, 259288007, 640076739, 873173657, 1402881862, 627946497, 315209236, 676276018, 1482056562, 107131096, 295273407}, []int{175470567, 1993479624, 1024953276, 1116186914, 493947449, 1831353077, 327723753, 53546237, 1344441126, 1246474467, 1187258871, 1680685496, 636519007, 1187329748, 1340518001, 1090684971, 26157453, 255852976, 841874282, 630123223}}, + {[]int{439132273, 1435362348, 652986404, 595027578, 50394610, 1163471868, 1350110751, 1387888121, 1541711601, 1311011531, 629723242, 332422020, 1846595946, 1630183415, 892729502, 29895452, 1044010203}, []int{1813396452, 114068876, 1327268679, 1868447085, 184894747, 1182003852}, []int{239262804, 1549431224, 1980255083, 450208742, 235289357, 332209799, 1350110751, 1387888121, 1541711601, 1311011531, 629723242, 332422020, 1846595946, 1630183415, 892729502, 29895452, 1044010203}}, + {[]int{166180982, 764991101, 689087390, 429838129, 645158827, 1453030567, 1933567468, 1814820989, 457070860, 1832972348, 222162489, 312570738, 1353658637, 97753143, 1606729033, 596918423, 1097411730}, []int{682845415, 1075084129, 166827081, 1149467700, 750197496, 1980081828, 137604657, 584718339, 1309044764, 1639753374, 1544780495, 1889342289}, []int{849026397, 1840075230, 855914471, 1579305829, 1395356323, 1419846474, 57906204, 386273407, 1766115624, 1459459801, 1766942984, 188647106, 1353658637, 97753143, 1606729033, 596918423, 1097411730}}, + } { + bb8a := make([]frontend.Variable, len(tc.a)) + for i := range tc.a { + bb8a[i] = frontend.Variable(tc.a[i]) + } + bb8b := make([]frontend.Variable, len(tc.b)) + for i := range tc.b { + bb8b[i] = frontend.Variable(tc.b[i]) + } + bb8c := make([]frontend.Variable, len(tc.c)) + for i := range tc.c { + bb8c[i] = frontend.Variable(tc.c[i]) + } + err := test.IsSolved(&AddCircuit{A: make([]frontend.Variable, len(bb8a)), B: make([]frontend.Variable, len(bb8b)), C: make([]frontend.Variable, len(bb8c))}, + &AddCircuit{A: bb8a, B: bb8b, C: bb8c}, babybear.Modulus()) + assert.NoError(err) + } +} + +type SubCircuit struct { + A, B, C []frontend.Variable +} + +func (c *SubCircuit) Define(api frontend.API) error { + e, err := NewExtension(api) + if err != nil { + return err + } + res := e.Sub(c.A, c.B) + e.AssertIsEqual(c.C, res) + return nil +} + +func TestSub(t *testing.T) { + assert := test.NewAssert(t) + for _, tc := range []struct { + a, b, c []int + }{ + {[]int{1146194893, 161636653, 1838869339, 53943494, 240077858, 1545249092, 1809326915, 1715283441, 1371628, 294589792, 350818866, 391858895, 1629176799, 601342455, 1570046548, 1407018614, 116964098}, []int{1358850047, 1241999865, 899127662}, []int{1800610767, 932902709, 939741677, 53943494, 240077858, 1545249092, 1809326915, 1715283441, 1371628, 294589792, 350818866, 391858895, 1629176799, 601342455, 1570046548, 1407018614, 116964098}}, + {[]int{1930372684, 1864892085, 1136595379, 1655262918, 778003842, 1395703951, 674238279, 303428310, 1869785911, 1465648550, 1654265669, 601993522, 1573728473, 678122861}, []int{758288169, 655811754, 1808890303}, []int{1172084515, 1209080331, 1340970997, 1655262918, 778003842, 1395703951, 674238279, 303428310, 1869785911, 1465648550, 1654265669, 601993522, 1573728473, 678122861}}, + {[]int{198803940, 683262254, 1171724940, 220582, 1436309010, 1011767254, 1619789563, 984205254, 1230618647, 661342751, 1574746193, 850095862, 1888386567}, []int{29202675, 1965459445, 1226138134, 614755, 823163111, 1965257586, 570492890, 714310672, 1863719043, 316112110, 751275028, 1305876957, 76087403, 289554855, 543603956, 1343584811}, []int{169601265, 731068730, 1958852727, 2012871748, 613145899, 1059775589, 1049296673, 269894582, 1380165525, 345230641, 823471165, 1557484826, 1812299164, 1723711066, 1469661965, 669681110}}, + {[]int{867771133, 1674871834, 173849765, 1667039402, 1926702105, 192555144}, []int{793120622, 876063077, 577433800, 1846006825, 1905707677, 1851151225}, []int{74650511, 798808757, 1609681886, 1834298498, 20994428, 354669840}}, + {[]int{56091872, 813716739, 362113363, 1053599731, 178619716, 1801257436, 864815551, 1305284265, 340955220, 1066690326, 674386095, 370881527, 1974134341, 167570042, 1480417387, 190897437}, []int{414748620, 1946157966, 678505871, 1157487387, 1854184016, 438292057, 1226900614, 2009898878, 557555644, 1058000961, 951280428, 1740323340, 1389148174, 315149809, 1822366716, 1274014418, 1803141600, 27865225}, []int{1654609173, 880824694, 1696873413, 1909378265, 337701621, 1362965379, 1651180858, 1308651308, 1796665497, 8689365, 1736371588, 643824108, 584986167, 1865686154, 1671316592, 930148940, 210124321, 1985400696}}, + {[]int{894252018, 1208416601, 802813920, 406175937, 1248756763, 2010718340, 132883210, 808520913}, []int{521827370, 428787881, 1443028395, 248442971, 1526599792, 1784112161, 1259960262, 1432566078, 234210554, 377567478, 1616559930, 1457879671, 1783692545, 1166700134, 63192557, 238060092, 1077493263}, []int{372424648, 779628720, 1373051446, 157732966, 1735422892, 226606179, 886188869, 1389220756, 1779055367, 1635698443, 396705991, 555386250, 229573376, 846565787, 1950073364, 1775205829, 935772658}}, + {[]int{1222756305, 1532801094, 1965391915, 1635685881, 1432129702, 1842258559, 818133559, 126161692, 1872764052, 1885587202, 388899896, 1271969485, 1753820414, 551808295, 272431669, 879739774, 672550552}, []int{296928275, 1436034937, 1801721783, 1498823779, 841763593, 248672479, 124418116, 1495721918, 555622041, 962101046, 1267239367, 1607045139, 1006652808, 369825252, 1129445804}, []int{925828030, 96766157, 163670132, 136862102, 590366109, 1593586080, 693715443, 643705695, 1317142011, 923486156, 1134926450, 1678190267, 747167606, 181983043, 1156251786, 879739774, 672550552}}, + {[]int{1731622955, 615410865, 1558496679, 195832953, 78170750, 61301540, 424972314, 1058412714}, []int{1077167652, 248376566, 1905047628, 1483682839, 135881338, 1082317338, 975917104, 914666340}, []int{654455303, 367034299, 1666714972, 725416035, 1955555333, 992250123, 1462321131, 143746374}}, + {[]int{1362372527, 1429758972, 1923199203, 808799268, 908434557, 22885471, 289022981, 655969201, 944182779, 947702885}, []int{1640391773, 1285351917, 1033611649, 157640943, 584694384}, []int{1735246675, 144407055, 889587554, 651158325, 323740173, 22885471, 289022981, 655969201, 944182779, 947702885}}, + } { + bb8a := make([]frontend.Variable, len(tc.a)) + for i := range tc.a { + bb8a[i] = frontend.Variable(tc.a[i]) + } + bb8b := make([]frontend.Variable, len(tc.b)) + for i := range tc.b { + bb8b[i] = frontend.Variable(tc.b[i]) + } + bb8c := make([]frontend.Variable, len(tc.c)) + for i := range tc.c { + bb8c[i] = frontend.Variable(tc.c[i]) + } + err := test.IsSolved(&SubCircuit{A: make([]frontend.Variable, len(bb8a)), B: make([]frontend.Variable, len(bb8b)), C: make([]frontend.Variable, len(bb8c))}, + &SubCircuit{A: bb8a, B: bb8b, C: bb8c}, babybear.Modulus()) + assert.NoError(err) + } +} + +type MulCircuit struct { + A, B, C []frontend.Variable +} + +func (c *MulCircuit) Define(api frontend.API) error { + e, err := NewExtension(api) + if err != nil { + return err + } + res := e.Mul(c.A, c.B) + e.AssertIsEqual(c.C, res) + return nil +} + +func TestMul(t *testing.T) { + assert := test.NewAssert(t) + for _, tc := range []struct { + a, b, c []int + }{ + {[]int{234968604, 1416371157, 1226800682, 893689929, 1778035510, 146580532, 280014629, 1865717137, 982812264, 531104756, 624717176}, []int{1870372928, 89929324, 1716676259}, []int{1906632739, 115903316, 672362298, 305415989, 834985591, 1605817228, 1210941820, 985790928}}, + {[]int{1087077945, 320581995, 1629282702, 1741108544, 1040857706, 1916768501, 1565495085, 823889356, 1417428004, 1583630854, 1114754081, 1910869750, 187917565, 1438312600}, []int{1563899835, 168797949, 1371079710, 2987340, 1026622935, 1246885219, 506032556, 1788593166, 237013976, 1824355399, 625048497, 68448670, 1607339381, 951954832, 885388282, 683432779, 1575631187}, []int{720142330, 1056607600, 416577423, 1478261035, 1220299325, 903507263, 959938193, 355726286}}, + {[]int{1270288480, 1120584133, 721331187, 1421659182, 1094444484, 359616929, 969570910, 1882596876, 1297123805, 1881461151, 97448081}, []int{2003818742, 1858628164, 1023684969, 1085350554, 781453742, 1116677995, 1468065106, 1335317024, 1486544729, 1673869660, 144423861}, []int{426591007, 224682575, 778802683, 1271911177, 251644533, 207528538, 1964476679, 1876339154}}, + {[]int{1853783118, 1380960591, 964095257, 695279244, 315564693, 1867490771, 53851649, 1343775624, 653780889, 1583674803}, []int{1969070769, 1769394471, 414599120, 647597532, 1788546055, 224442741, 1412932412, 680401167, 298718932, 1146328071, 1478899454, 1909103677, 1428990649, 1439633502, 54662272, 596249162, 461878709, 563248862, 1000459500, 1645847614}, []int{504670415, 227198315, 1349561269, 501560516, 894895922, 1993202942, 1850127592, 1108750151}}, + {[]int{414037409, 953085481, 1924772772, 1517340116, 1237653110, 133837088, 1315588440, 238864701}, []int{1304959022, 1925100119, 978981709, 1918377397, 1207231558, 281134995, 502889770}, []int{142108887, 1889729586, 1250323514, 594024056, 1230999660, 1787836861, 1534177366, 801580624}}, + {[]int{947344343, 423823149, 344707902, 700248832, 566581327, 1849547514, 399209144, 1091850846, 1364174972, 1803614392, 1634840199, 1026184357, 1838704001, 203731055, 743992513, 251080705, 1036651012, 759652320, 577883317, 1716209722, 1529813228}, []int{886233417, 796245045, 174590451, 434936528, 626331990}, []int{1530952022, 551906495, 1612686281, 1586062699, 971563806, 686771089, 907205414, 1181573107}}, + {[]int{7027166, 463591444, 1803846561, 1505438619, 2012281334, 1039204555, 1439978503, 1620975569, 870977727, 746630744, 1686836478, 1924796057}, []int{819644881, 463949651, 522020012, 1377665054, 831007978, 1954765014, 1976440214, 1392258642, 52259004, 1536634317, 1591661847, 628460335, 150161825, 1915169606, 671751539, 196434398, 1160799204, 1385730435, 583362563}, []int{128480043, 456509411, 299825145, 1189434742, 1790453058, 918340297, 1075473370, 992322875}}, + {[]int{1711268919, 1677353510}, []int{315387358, 877475853, 1779977986, 1816170934, 1740575889, 1377265373, 2007938566, 486612909, 953317838, 150150087, 1034065308, 130344828, 1720755480, 1194766973, 74573519, 511551933, 1766944307, 214027799, 226716130, 1055958601, 1902536491}, []int{1633902657, 998090610, 1658849111, 175196657, 889509699, 1868699941, 1847876682, 1348099822}}, + {[]int{1926641336, 233517016, 1382898361, 516240145, 730324703, 196139649, 1751487986, 1718388392, 93866265, 234489342, 1447664327, 978489786, 629636261}, []int{1812584753, 1146117727, 185071390}, []int{1032160623, 1983522902, 1373330365, 1300425158, 649650457, 1711009205, 567980350, 546199098}}, + } { + bb8a := make([]frontend.Variable, len(tc.a)) + for i := range tc.a { + bb8a[i] = frontend.Variable(tc.a[i]) + } + bb8b := make([]frontend.Variable, len(tc.b)) + for i := range tc.b { + bb8b[i] = frontend.Variable(tc.b[i]) + } + bb8c := make([]frontend.Variable, len(tc.c)) + for i := range tc.c { + bb8c[i] = frontend.Variable(tc.c[i]) + } + err := test.IsSolved(&MulCircuit{A: make([]frontend.Variable, len(bb8a)), B: make([]frontend.Variable, len(bb8b)), C: make([]frontend.Variable, len(bb8c))}, + &MulCircuit{A: bb8a, B: bb8b, C: bb8c}, babybear.Modulus()) + assert.NoError(err) + } +} diff --git a/std/math/fieldextension/option.go b/std/math/fieldextension/option.go new file mode 100644 index 000000000..256405a00 --- /dev/null +++ b/std/math/fieldextension/option.go @@ -0,0 +1,36 @@ +package fieldextension + +type config struct { + extension []int + degree int +} + +type Option func(*config) error + +// WithDegree forces the degree of the extension field. If not set then we +// choose the degree which provides soundness over the native field. +func WithDegree(degree int) Option { + return func(c *config) error { + c.degree = degree + return nil + } +} + +func WithExtension(extension []int) Option { + return func(c *config) error { + c.extension = extension + return nil + } +} + +func newConfig(opts ...Option) (*config, error) { + c := &config{ + degree: -1, + } + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, err + } + } + return c, nil +} From 3a2f75b1454232be5127bebba3553fc08732c2cd Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 3 Dec 2024 12:56:11 +0000 Subject: [PATCH 09/21] feat: add experimental WideCommitter interface --- frontend/builder.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/frontend/builder.go b/frontend/builder.go index f39262be1..b06dfdc38 100644 --- a/frontend/builder.go +++ b/frontend/builder.go @@ -91,6 +91,26 @@ type Committer interface { Commit(toCommit ...Variable) (commitment Variable, err error) } +// WideCommitter allows to commit to the variables and returns the commitment as +// an extension field element. The commitment can be used as a challenge using +// Fiat-Shamir heuristic. This method is required when the circuit is defined +// over a small field where the individual commitment would be too small to +// achieve desired soundness level. +// +// This is experimental API and may be subject to change. It is not relevant for +// pairing-based backends where the commitment is in a large field and is not +// defined for such cases. Thus, the caller should check if this or [Committer] +// interfaces is implemented and use the appropriate method. +type WideCommitter interface { + // WideCommit commits to the variables and returns the commitments. + // This method is required when the circuit is defined over a small field + // where the individual commitment would be too small to achieve desired + // soundness level. + // + // The width parameter defines the number of elements in the commitment. + WideCommit(width int, toCommit ...Variable) (commitment []Variable, err error) +} + // Rangechecker allows to externally range-check the variables to be of // specified width. Not all compilers implement this interface. Users should // instead use [github.com/consensys/gnark/std/rangecheck] package which From d01d7cd22a948c786b7e3592781f8fd6610a2514 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 3 Dec 2024 13:23:21 +0000 Subject: [PATCH 10/21] feat: implement wide commitment for multicommit --- std/multicommit/nativecommit.go | 64 ++++++++++++++++++++++------ std/multicommit/nativecommit_test.go | 34 ++++++++++++++- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/std/multicommit/nativecommit.go b/std/multicommit/nativecommit.go index 60f10f76b..467aa49a7 100644 --- a/std/multicommit/nativecommit.go +++ b/std/multicommit/nativecommit.go @@ -23,9 +23,16 @@ import ( ) type multicommitter struct { - closed bool - vars []frontend.Variable - cbs []WithCommitmentFn + closed bool + vars []frontend.Variable + cbs []WithCommitmentFn + wcbs []wcbInfo + maxWidth int +} + +type wcbInfo struct { + cb WithWideCommitmentFn + width int } type ctxMulticommitterKey struct{} @@ -80,20 +87,34 @@ func getCached(api frontend.API) *multicommitter { func (mct *multicommitter) commitAndCall(api frontend.API) error { // close collecting input in case anyone wants to check more variables to commit to. mct.closed = true - if len(mct.cbs) == 0 { + if len(mct.cbs) == 0 && len(mct.wcbs) == 0 { // shouldn't happen. we defer this function on creating multicommitter // instance. It is probably some race. panic("calling committer with zero callbacks") } - committer, ok := api.Compiler().(frontend.Committer) - if !ok { - panic("compiler doesn't implement frontend.Committer") - } - rootCmt, err := committer.Commit(mct.vars...) - if err != nil { - return fmt.Errorf("commit: %w", err) + var rootCmt []frontend.Variable + var err error + if len(mct.wcbs) > 0 && mct.maxWidth > 1 { + committer, ok := api.Compiler().(frontend.WideCommitter) + if !ok { + panic("compiler doesn't implement frontend.WideCommitter") + } + rootCmt, err = committer.WideCommit(mct.maxWidth, mct.vars...) + if err != nil { + return fmt.Errorf("wide commit: %w", err) + } + } else { + committer, ok := api.Compiler().(frontend.Committer) + if !ok { + panic("compiler doesn't implement frontend.Committer") + } + rootCmt = make([]frontend.Variable, 1) + rootCmt[0], err = committer.Commit(mct.vars...) + if err != nil { + return fmt.Errorf("commit: %w", err) + } } - cmt := rootCmt + cmt := rootCmt[0] if err = mct.cbs[0](api, cmt); err != nil { return fmt.Errorf("callback 0: %w", err) } @@ -103,6 +124,11 @@ func (mct *multicommitter) commitAndCall(api frontend.API) error { return fmt.Errorf("callback %d: %w", i, err) } } + for i := 0; i < len(mct.wcbs); i++ { + if err := mct.wcbs[i].cb(api, rootCmt[:mct.wcbs[i].width]); err != nil { + return fmt.Errorf("wide callback %d: %w", i, err) + } + } return nil } @@ -115,6 +141,10 @@ func (mct *multicommitter) commitAndCall(api frontend.API) error { // leads to panic. However, the method can call defer for other callbacks. type WithCommitmentFn func(api frontend.API, commitment frontend.Variable) error +// WithWideCommitmentFn is as [WidthCommitmentFn], but instead receives a slice +// of commitments. The commitments is generated in the extension field. +type WithWideCommitmentFn func(api frontend.API, commitment []frontend.Variable) error + // WithCommitment schedules the function cb to be called with a unique // commitment. We append the variables committedVariables to be committed to // with the native [frontend.Committer] interface. @@ -126,3 +156,13 @@ func WithCommitment(api frontend.API, cb WithCommitmentFn, committedVariables .. mct.vars = append(mct.vars, committedVariables...) mct.cbs = append(mct.cbs, cb) } + +func WithWideCommitment(api frontend.API, cb WithWideCommitmentFn, width int, committedVariable ...frontend.Variable) { + mct := getCached(api) + if mct.closed { + panic("called WithCommitment recursively") + } + mct.maxWidth = max(mct.maxWidth, width) + mct.vars = append(mct.vars, committedVariable...) + mct.wcbs = append(mct.wcbs, wcbInfo{cb: cb, width: width}) +} diff --git a/std/multicommit/nativecommit_test.go b/std/multicommit/nativecommit_test.go index b78f51865..44b143b6e 100644 --- a/std/multicommit/nativecommit_test.go +++ b/std/multicommit/nativecommit_test.go @@ -6,6 +6,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/internal/smallfields/babybear" + "github.com/consensys/gnark/std/math/fieldextension" "github.com/consensys/gnark/test" ) @@ -51,7 +53,7 @@ func TestMultipleCommitments(t *testing.T) { circuit := multipleCommitmentCircuit{} assignment := multipleCommitmentCircuit{X: 10} assert := test.NewAssert(t) - assert.ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) // right now PLONK doesn't implement commitment + assert.ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) } type noCommitVariable struct { @@ -71,3 +73,33 @@ func TestNoCommitVariable(t *testing.T) { assert := test.NewAssert(t) assert.ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) } + +type wideCommitment struct { + X frontend.Variable +} + +func (c *wideCommitment) Define(api frontend.API) error { + WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + api.AssertIsDifferent(commitment, 0) + return nil + }, c.X) + WithWideCommitment(api, func(api frontend.API, commitment []frontend.Variable) error { + fe, err := fieldextension.NewExtension(api, fieldextension.WithDegree(8)) + if err != nil { + return err + } + res := fe.Mul(commitment, commitment) + for i := range res { + api.AssertIsDifferent(res[i], 0) + } + return nil + }, 8, c.X) + return nil +} + +func TestWideCommitment(t *testing.T) { + assert := test.NewAssert(t) + err := test.IsSolved(&wideCommitment{}, &wideCommitment{X: 10}, babybear.Modulus()) + // TODO: when we have implemented for PLONK, then also check for that. We're never going to implement for Groth16 + assert.NoError(err) +} From 40a8c2fbf2fbf5f9f4053f3c67e64328f8b036fa Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 3 Dec 2024 13:23:39 +0000 Subject: [PATCH 11/21] test: add WideCommit to test engine --- test/engine.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/engine.go b/test/engine.go index 38182d881..f18a84d38 100644 --- a/test/engine.go +++ b/test/engine.go @@ -692,6 +692,25 @@ func (e *engine) Commit(v ...frontend.Variable) (frontend.Variable, error) { return res, nil } +func (e *engine) WideCommit(width int, v ...frontend.Variable) ([]frontend.Variable, error) { + nb := (e.FieldBitLen() + 7) / 8 + buf := make([]byte, nb) + hasher := sha3.NewCShake128(nil, []byte("gnark test engine")) + for i := range v { + vs := e.toBigInt(v[i]) + bs := vs.FillBytes(buf) + hasher.Write(bs) + } + res := make([]frontend.Variable, width) + for i := 0; i < width; i++ { + hasher.Read(buf) + resi := new(big.Int).SetBytes(buf) + resi.Mod(resi, e.modulus()) + res[i] = new(big.Int).Set(resi) + } + return res, nil +} + func (e *engine) Defer(cb func(frontend.API) error) { circuitdefer.Put(e, cb) } From 86076893b9eeefd6944b71b28b054e38c6088a3b Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 9 Dec 2024 17:02:43 +0000 Subject: [PATCH 12/21] fix: multiplication var when in native commit mode --- std/multicommit/nativecommit.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/multicommit/nativecommit.go b/std/multicommit/nativecommit.go index 467aa49a7..44a0fa95f 100644 --- a/std/multicommit/nativecommit.go +++ b/std/multicommit/nativecommit.go @@ -119,7 +119,7 @@ func (mct *multicommitter) commitAndCall(api frontend.API) error { return fmt.Errorf("callback 0: %w", err) } for i := 1; i < len(mct.cbs); i++ { - cmt = api.Mul(rootCmt, cmt) + cmt = api.Mul(rootCmt[0], cmt) if err := mct.cbs[i](api, cmt); err != nil { return fmt.Errorf("callback %d: %w", i, err) } From 8621dd34858dff017a6eda3c3c84eba4c6ccdae1 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 9 Dec 2024 17:03:13 +0000 Subject: [PATCH 13/21] feat: add utility methods to field extension --- std/math/fieldextension/fieldextension.go | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/std/math/fieldextension/fieldextension.go b/std/math/fieldextension/fieldextension.go index 3504384a9..f52054174 100644 --- a/std/math/fieldextension/fieldextension.go +++ b/std/math/fieldextension/fieldextension.go @@ -172,3 +172,33 @@ func (e *Extension) AssertIsEqual(a, b ExtensionVariable) { e.api.AssertIsEqual(b[i], 0) } } + +func (e *Extension) Zero() ExtensionVariable { + ret := make(ExtensionVariable, len(e.extension)) + for i := range ret { + ret[i] = frontend.Variable(0) + } + return ret +} + +func (e *Extension) One() ExtensionVariable { + ret := make(ExtensionVariable, len(e.extension)) + ret[0] = frontend.Variable(1) + for i := 1; i < len(ret); i++ { + ret[i] = frontend.Variable(0) + } + return ret +} + +func (e *Extension) AsExtensionVariable(a frontend.Variable) ExtensionVariable { + ret := make(ExtensionVariable, len(e.extension)) + ret[0] = a + for i := 1; i < len(ret); i++ { + ret[i] = frontend.Variable(0) + } + return ret +} + +func (e *Extension) Degree() int { + return len(e.extension) - 1 +} From 0751e9975ec0c46bcbfde4f2b63afe2f113eca9b Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 9 Dec 2024 17:04:48 +0000 Subject: [PATCH 14/21] feat: initialize mulchecks in field extensioifen small native field --- std/math/emulated/field.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index d7087974c..81f9cdf7e 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -10,6 +10,7 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" limbs "github.com/consensys/gnark/std/internal/limbcomposition" + "github.com/consensys/gnark/std/math/fieldextension" "github.com/consensys/gnark/std/rangecheck" "github.com/rs/zerolog" "golang.org/x/exp/constraints" @@ -22,6 +23,8 @@ import ( type Field[T FieldParams] struct { // api is the native API api frontend.API + // extensionApi is the extension API when we need to perform multiplication checks over the extension field + extensionApi *fieldextension.Extension // f carries the ring parameters fParams T @@ -73,6 +76,14 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { constrainedLimbs: make(map[[16]byte]struct{}), checker: rangecheck.New(native), } + if native.Compiler().Field().BitLen() < 128 { + f.log.Debug().Msg("using small native field, multiplication checks will be performed in extension field") + extapi, err := fieldextension.NewExtension(native) + if err != nil { + return nil, fmt.Errorf("extension field: %w", err) + } + f.extensionApi = extapi + } // ensure prime is correctly set if f.fParams.IsPrime() { From 983ec62e3a20ed185470ab347e8fc44d7ea9ba97 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 9 Dec 2024 17:05:24 +0000 Subject: [PATCH 15/21] feat: implement mulchecks in field extension --- std/math/emulated/field_mul.go | 259 ++++++++++++++++++++++++++------- 1 file changed, 207 insertions(+), 52 deletions(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 4617db966..5da7cc059 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark/frontend" limbs "github.com/consensys/gnark/std/internal/limbcomposition" + "github.com/consensys/gnark/std/math/fieldextension" "github.com/consensys/gnark/std/multicommit" ) @@ -18,6 +19,11 @@ import ( // checks. // // Currently used for multiplication and multivariate evaluation checks. +// +// The methods [evalRound1], [evalRound2] and [check] may receive as inputs +// either [frontend.Variable] or [fieldextension.ExtensionVariable]. The +// implementation should differentiate on the different input types and use the +// appropriate API (native or extension). type deferredChecker interface { // toCommit outputs the variable which should be committed to. The checker // then uses the commitment to obtain the verifier challenge for the @@ -165,9 +171,37 @@ func (mc *mulCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { if mc.p != nil { peval = mc.p.evaluation } - ls := api.Mul(mc.a.evaluation, mc.b.evaluation) - rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) - api.AssertIsEqual(ls, rs) + // we either have to perform the equality check in the native field or in + // the extension field. It was already determined at the [Field] + // initialization time which kind of check needs to be done. + if mc.f.extensionApi == nil { + ls := api.Mul(mc.a.evaluation, mc.b.evaluation) + rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) + api.AssertIsEqual(ls, rs) + } else { + // here we use the fact that [frontend.Variable] is defined as any, but + // we have actually provided [ExtensionVariable]. We type assert to be + // able to use the fieldextension API. + // + // the computations are same as in the previous conditional block, but + // only in the extension. + aext := mc.a.evaluation.(fieldextension.ExtensionVariable) + bext := mc.b.evaluation.(fieldextension.ExtensionVariable) + ls := mc.f.extensionApi.Mul(aext, bext) + + rext := mc.r.evaluation.(fieldextension.ExtensionVariable) + pevalext := peval.(fieldextension.ExtensionVariable) + cext := mc.c.evaluation.(fieldextension.ExtensionVariable) + kext := mc.k.evaluation.(fieldextension.ExtensionVariable) + coefext := coef.(fieldextension.ExtensionVariable) + pkext := mc.f.extensionApi.Mul(pevalext, kext) + ccoefext := mc.f.extensionApi.Mul(coefext, cext) + + rs := mc.f.extensionApi.Add(rext, pkext) + rs = mc.f.extensionApi.Add(rs, ccoefext) + + mc.f.extensionApi.AssertIsEqual(ls, rs) + } } // cleanEvaluations cleans the cached evaluation values. This is necessary for @@ -254,6 +288,18 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele if len(at) < len(a.Limbs)-1 { panic("evaluation powers less than limbs") } + var sum frontend.Variable + if f.extensionApi != nil { + sum = f.evalWithChallengeExtension(a, at) + } else { + sum = f.evalWithChallengeNative(a, at) + } + a.isEvaluated = true + a.evaluation = sum + return a +} + +func (f *Field[T]) evalWithChallengeNative(a *Element[T], at []frontend.Variable) frontend.Variable { var sum frontend.Variable = 0 if len(a.Limbs) > 0 { sum = f.api.Mul(a.Limbs[0], 1) // copy because we use MulAcc @@ -261,9 +307,30 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele for i := 1; i < len(a.Limbs); i++ { sum = f.api.MulAcc(sum, a.Limbs[i], at[i-1]) } - a.isEvaluated = true - a.evaluation = sum - return a + return sum +} + +func (f *Field[T]) evalWithChallengeExtension(a *Element[T], at []frontend.Variable) frontend.Variable { + // even though at is []frontend.Variable, then we abuse the fact that + // frontend.Variable is defined as any and at is []ExtensionVariable. We + // type assert it. + atext := make([]fieldextension.ExtensionVariable, len(at)) + for i := 0; i < len(at); i++ { + atexti, ok := at[i].(fieldextension.ExtensionVariable) + if !ok { + panic("not an extension variable") + } + atext[i] = atexti + } + sum := f.extensionApi.Zero() + if len(a.Limbs) > 0 { + sum = f.extensionApi.AsExtensionVariable(a.Limbs[0]) + } + for i := 1; i < len(a.Limbs); i++ { + toAdd := f.extensionApi.MulByElement(atext[i-1], a.Limbs[i]) + sum = f.extensionApi.Add(sum, toAdd) + } + return sum } // performMulChecks should be deferred to actually perform all the @@ -288,43 +355,92 @@ func (f *Field[T]) performDeferredChecks(api frontend.API) error { for i := range f.deferredChecks { toCommit = append(toCommit, f.deferredChecks[i].toCommit()...) } - // we give all the inputs as inputs to obtain random verifier challenge. - multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { - // for efficiency, we compute all powers of the challenge as slice at. - coefsLen := int(f.fParams.NbLimbs()) - for i := range f.deferredChecks { - coefsLen = max(coefsLen, f.deferredChecks[i].maxLen()) - } - at := make([]frontend.Variable, coefsLen) - at[0] = commitment - for i := 1; i < len(at); i++ { - at[i] = api.Mul(at[i-1], commitment) - } - // evaluate all r, k, c - for i := range f.deferredChecks { - f.deferredChecks[i].evalRound1(at) - } - // assuming r is input to some other multiplication, then is already evaluated - for i := range f.deferredChecks { - f.deferredChecks[i].evalRound2(at) - } - // evaluate p(X) at challenge - pval := f.evalWithChallenge(f.Modulus(), at) - // compute (2^t-X) at challenge - coef := big.NewInt(1) - coef.Lsh(coef, f.fParams.BitsPerLimb()) - ccoef := api.Sub(coef, commitment) - // verify all mulchecks - for i := range f.deferredChecks { - f.deferredChecks[i].check(api, pval.evaluation, ccoef) - } - // clean cached evaluation. Helps in case we compile the same circuit - // multiple times. - for i := range f.deferredChecks { - f.deferredChecks[i].cleanEvaluations() - } - return nil - }, toCommit...) + if f.extensionApi == nil { + // we give all the inputs as inputs to obtain random verifier challenge. + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + // for efficiency, we compute all powers of the challenge as slice at. + coefsLen := int(f.fParams.NbLimbs()) + for i := range f.deferredChecks { + coefsLen = max(coefsLen, f.deferredChecks[i].maxLen()) + } + at := make([]frontend.Variable, coefsLen) + at[0] = commitment + for i := 1; i < len(at); i++ { + at[i] = api.Mul(at[i-1], commitment) + } + // evaluate all r, k, c + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound1(at) + } + // assuming r is input to some other multiplication, then is already evaluated + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound2(at) + } + // evaluate p(X) at challenge + pval := f.evalWithChallenge(f.Modulus(), at) + // compute (2^t-X) at challenge + coef := big.NewInt(1) + coef.Lsh(coef, f.fParams.BitsPerLimb()) + ccoef := api.Sub(coef, commitment) + // verify all mulchecks + for i := range f.deferredChecks { + f.deferredChecks[i].check(api, pval.evaluation, ccoef) + } + // clean cached evaluation. Helps in case we compile the same circuit + // multiple times. + for i := range f.deferredChecks { + f.deferredChecks[i].cleanEvaluations() + } + return nil + }, toCommit...) + } else { + // this is the same as above, but we have challenges in the extension + // field. The commitment argument below is actually extension field + // element, but we give it as []frontend.Variable for interface + // compatibility. + multicommit.WithWideCommitment(api, func(api frontend.API, commitment []frontend.Variable) error { + // for efficiency, we compute all powers of the challenge as slice at. + coefsLen := int(f.fParams.NbLimbs()) + for i := range f.deferredChecks { + coefsLen = max(coefsLen, f.deferredChecks[i].maxLen()) + } + at := make([]fieldextension.ExtensionVariable, coefsLen) + at[0] = commitment + for i := 1; i < len(at); i++ { + atexti := at[i-1] + at[i] = f.extensionApi.Mul((fieldextension.ExtensionVariable)(atexti), commitment) + } + atv := make([]frontend.Variable, len(at)) + for i := range at { + atv[i] = at[i] + } + // evaluate all r, k, c + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound1(atv) + } + // assuming r is input to some other multiplication, then is already evaluated + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound2(atv) + } + // evaluate p(X) at challenge + pval := f.evalWithChallenge(f.Modulus(), atv) + // compute (2^t-X) at challenge + coef := big.NewInt(1) + coef.Lsh(coef, f.fParams.BitsPerLimb()) + coefext := f.extensionApi.AsExtensionVariable(coef) + ccoef := f.extensionApi.Sub(coefext, commitment) + // verify all mulchecks + for i := range f.deferredChecks { + f.deferredChecks[i].check(api, pval.evaluation, ccoef) + } + // clean cached evaluation. Helps in case we compile the same circuit + // multiple times. + for i := range f.deferredChecks { + f.deferredChecks[i].cleanEvaluations() + } + return nil + }, f.extensionApi.Degree(), toCommit...) + } return nil } @@ -807,19 +923,58 @@ func (mc *mvCheck[T]) evalRound2(at []frontend.Variable) { } } +// check checks that the multivariate polynomial f(x1(ch), x2(ch), ...) = r(ch) +// + k(ch)*p(ch) + (2^t-ch) c(ch) holds. As p and (2^t-ch) are same over all +// checks then we get them as arguments to this method. func (mc *mvCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { - ls := frontend.Variable(0) - for i, term := range mc.mv.Terms { - termProd := frontend.Variable(mc.mv.Coefficients[i]) - for i, pow := range term { - for j := 0; j < pow; j++ { - termProd = api.Mul(termProd, mc.vals[i].evaluation) + // we either have to perform the equality check in the native field or in + // the extension field. It was already determined at the [Field] + // initialization time which kind of check needs to be done. + if mc.f.extensionApi == nil { + ls := frontend.Variable(0) + for i, term := range mc.mv.Terms { + termProd := frontend.Variable(mc.mv.Coefficients[i]) + for i, pow := range term { + for j := 0; j < pow; j++ { + termProd = api.Mul(termProd, mc.vals[i].evaluation) + } } + ls = api.Add(ls, termProd) } - ls = api.Add(ls, termProd) + rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) + api.AssertIsEqual(ls, rs) + } else { + // here we use the fact that [frontend.Variable] is defined as any, but + // we have actually provided [ExtensionVariable]. We type assert to be + // able to use the fieldextension API. + // + // the computations are same as in the previous conditional block, but + // only in the extension. + ls := mc.f.extensionApi.Zero() + for i, term := range mc.mv.Terms { + termProd := mc.f.extensionApi.AsExtensionVariable(mc.mv.Coefficients[i]) + for i, pow := range term { + for j := 0; j < pow; j++ { + valsexti := mc.vals[i].evaluation.(fieldextension.ExtensionVariable) + termProd = mc.f.extensionApi.Mul(termProd, valsexti) + } + } + ls = mc.f.extensionApi.Add(ls, termProd) + } + rext := mc.r.evaluation.(fieldextension.ExtensionVariable) + pevalext := peval.(fieldextension.ExtensionVariable) + kext := mc.k.evaluation.(fieldextension.ExtensionVariable) + cext := mc.c.evaluation.(fieldextension.ExtensionVariable) + coefext := coef.(fieldextension.ExtensionVariable) + + pkext := mc.f.extensionApi.Mul(pevalext, kext) + ccoefext := mc.f.extensionApi.Mul(coefext, cext) + + rs := mc.f.extensionApi.Add(rext, pkext) + rs = mc.f.extensionApi.Add(rs, ccoefext) + + mc.f.extensionApi.AssertIsEqual(ls, rs) } - rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) - api.AssertIsEqual(ls, rs) } func (mc *mvCheck[T]) cleanEvaluations() { From 364746468ecd7bc1c3ef19f33140cff02365c85b Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 9 Dec 2024 17:05:43 +0000 Subject: [PATCH 16/21] chore: remove unused field --- std/math/emulated/field.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 81f9cdf7e..1c7ff8d89 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -34,16 +34,14 @@ type Field[T FieldParams] struct { maxOfOnce sync.Once // constants for often used elements n, 0 and 1. Allocated only once - nConstOnce sync.Once - nConst *Element[T] - nprevConstOnce sync.Once - nprevConst *Element[T] - zeroConstOnce sync.Once - zeroConst *Element[T] - oneConstOnce sync.Once - oneConst *Element[T] - shortOneConstOnce sync.Once - shortOneConst *Element[T] + nConstOnce sync.Once + nConst *Element[T] + nprevConstOnce sync.Once + nprevConst *Element[T] + zeroConstOnce sync.Once + zeroConst *Element[T] + oneConstOnce sync.Once + oneConst *Element[T] log zerolog.Logger From c36c92eb5ef62ea1cdd9d8ed1d88499db361dad5 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 10 Dec 2024 00:25:47 +0000 Subject: [PATCH 17/21] chore: avoid initializing hasher when 1-column logderivarg --- std/internal/logderivarg/logderivarg.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/std/internal/logderivarg/logderivarg.go b/std/internal/logderivarg/logderivarg.go index 4dcc55933..f7adac69a 100644 --- a/std/internal/logderivarg/logderivarg.go +++ b/std/internal/logderivarg/logderivarg.go @@ -147,6 +147,10 @@ func Build(api frontend.API, table Table, queries Table) error { } func randLinearCoefficients(api frontend.API, nbRow int, commitment frontend.Variable) (rowCoeffs []frontend.Variable, challenge frontend.Variable) { + if nbRow == 1 { + // to avoid initializing the hasher. + return []frontend.Variable{1}, commitment + } hasher, err := mimc.NewMiMC(api) if err != nil { panic(err) From e905c21b935ef182d579968e50e232d9d48c3ebf Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 25 Nov 2024 11:36:29 +0000 Subject: [PATCH 18/21] XXX babybear tests --- internal/smallfields/circuits_test.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/smallfields/circuits_test.go b/internal/smallfields/circuits_test.go index 0a4cfc721..d40d1ae88 100644 --- a/internal/smallfields/circuits_test.go +++ b/internal/smallfields/circuits_test.go @@ -5,6 +5,7 @@ import ( "github.com/consensys/gnark-crypto/field/goldilocks" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/smallfields/babybear" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/emulated/emparams" "github.com/consensys/gnark/test" @@ -48,11 +49,11 @@ type smallBN struct { } func (smallBN) BitsPerLimb() uint { - return 16 + return 11 } func (smallBN) NbLimbs() uint { - return 16 + return 24 } func TestEmulatedCircuit(t *testing.T) { @@ -63,4 +64,10 @@ func TestEmulatedCircuit(t *testing.T) { err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)}, goldilocks.Modulus()) assert.NoError(err) + + err = test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](2), B: emulated.ValueOf[emparams.BN254Fp](4)}, babybear.Modulus()) + assert.Error(err) + + err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)}, babybear.Modulus()) + assert.NoError(err) } From 046701bceb194bf15fa58c3e585ed0427e6059df Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 10 Dec 2024 00:25:59 +0000 Subject: [PATCH 19/21] XXX tests --- internal/smallfields/circuits_test.go | 151 ++++++++++++++++++++++++-- 1 file changed, 143 insertions(+), 8 deletions(-) diff --git a/internal/smallfields/circuits_test.go b/internal/smallfields/circuits_test.go index d40d1ae88..d0aea3262 100644 --- a/internal/smallfields/circuits_test.go +++ b/internal/smallfields/circuits_test.go @@ -1,14 +1,25 @@ package smallfields import ( + "crypto/rand" + "fmt" + "math/big" "testing" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/field/goldilocks" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/smallfields/babybear" + "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/emulated/emparams" "github.com/consensys/gnark/test" + + babybearcs "github.com/consensys/gnark/constraint/babybear" + bls12377cs "github.com/consensys/gnark/constraint/bls12-377" ) type NativeCircuit struct { @@ -49,25 +60,149 @@ type smallBN struct { } func (smallBN) BitsPerLimb() uint { - return 11 + return 10 } func (smallBN) NbLimbs() uint { - return 24 + return 26 } func TestEmulatedCircuit(t *testing.T) { assert := test.NewAssert(t) - err := test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](2), B: emulated.ValueOf[emparams.BN254Fp](4)}, goldilocks.Modulus()) - assert.Error(err) + a, err := rand.Int(rand.Reader, emparams.BN254Fp{}.Modulus()) + assert.NoError(err) + b := new(big.Int).Mul(a, a) + b.Mod(b, emparams.BN254Fp{}.Modulus()) + + err = test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](a), B: emulated.ValueOf[emparams.BN254Fp](b)}, ecc.BN254.ScalarField()) + assert.NoError(err) + + err = test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](a), B: emulated.ValueOf[emparams.BN254Fp](b)}, goldilocks.Modulus()) + assert.NoError(err) + + err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](a), B: emulated.ValueOf[smallBN](b)}, goldilocks.Modulus()) + assert.NoError(err) + + err = test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](a), B: emulated.ValueOf[emparams.BN254Fp](b)}, babybear.Modulus()) + assert.NoError(err) + + err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](a), B: emulated.ValueOf[smallBN](b)}, babybear.Modulus()) + assert.NoError(err) +} + +func TestCompileEmulatedCircuit(t *testing.T) { + assert := test.NewAssert(t) + f := babybear.Modulus() + + circuit := &EmulatedCircuit[smallBN]{} + assignemnt := &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)} + + ccs, err := frontend.Compile(f, scs.NewBuilder, circuit) + assert.NoError(err) + + w, err := frontend.NewWitness(assignemnt, f) + assert.NoError(err) + + res, err := ccs.Solve(w) + assert.NoError(err) + + tres, ok := res.(*babybearcs.SparseR1CSSolution) + assert.True(ok) + + fmt.Println(tres.L.String()) + fmt.Println(tres.R.String()) + fmt.Println(tres.O.String()) + + ccs2, err := frontend.Compile(f, r1cs.NewBuilder, circuit) + assert.NoError(err) + + res2, err := ccs2.Solve(w) + assert.NoError(err) + + tres2, ok := res2.(*babybearcs.R1CSSolution) + assert.True(ok) + + fmt.Println(tres2.W.String()) + fmt.Println(tres2.A.String()) + fmt.Println(tres2.B.String()) + fmt.Println(tres2.C.String()) +} + +type PairCircuit struct { + InG1 sw_bn254.G1Affine + InG2 sw_bn254.G2Affine + Res sw_bn254.GTEl +} + +func (c *PairCircuit) Define(api frontend.API) error { + pairing, err := sw_bn254.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG1(&c.InG1) + pairing.AssertIsOnG2(&c.InG2) + res, err := pairing.Pair([]*sw_bn254.G1Affine{&c.InG1}, []*sw_bn254.G2Affine{&c.InG2}) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Res) + return nil +} - err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)}, goldilocks.Modulus()) +func TestPairTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := randomG1G2Affines() + res, err := bn254.Pair([]bn254.G1Affine{p}, []bn254.G2Affine{q}) assert.NoError(err) + witness := PairCircuit{ + InG1: sw_bn254.NewG1Affine(p), + InG2: sw_bn254.NewG2Affine(q), + Res: sw_bn254.NewGTEl(res), + } + err = test.IsSolved(&PairCircuit{}, &witness, babybear.Modulus()) + assert.NoError(err) + + // ccs, err := frontend.Compile(babybear.Modulus(), scs.NewBuilder, &PairCircuit{}) + // assert.NoError(err) + // _ = ccs + + // w, err := frontend.NewWitness(&witness, babybear.Modulus()) + // assert.NoError(err) - err = test.IsSolved(&EmulatedCircuit[emparams.BN254Fp]{}, &EmulatedCircuit[emparams.BN254Fp]{A: emulated.ValueOf[emparams.BN254Fp](2), B: emulated.ValueOf[emparams.BN254Fp](4)}, babybear.Modulus()) - assert.Error(err) + // sol, err := ccs.Solve(w) + // assert.NoError(err) - err = test.IsSolved(&EmulatedCircuit[smallBN]{}, &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)}, babybear.Modulus()) + // tres, ok := sol.(*babybearcs.SparseR1CSSolution) + // assert.True(ok) + // fmt.Println(tres.L.Len()) + + ccs2, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &PairCircuit{}) + _ = ccs2 + assert.NoError(err) + w2, err := frontend.NewWitness(&witness, ecc.BLS12_377.ScalarField()) + assert.NoError(err) + sol2, err := ccs2.Solve(w2) assert.NoError(err) + tres, ok := sol2.(*bls12377cs.SparseR1CSSolution) + assert.True(ok) + fmt.Println(tres.L.Len()) +} + +func randomG1G2Affines() (bn254.G1Affine, bn254.G2Affine) { + _, _, G1AffGen, G2AffGen := bn254.Generators() + mod := bn254.ID.ScalarField() + s1, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + s2, err := rand.Int(rand.Reader, mod) + if err != nil { + panic(err) + } + var p bn254.G1Affine + p.ScalarMultiplication(&G1AffGen, s1) + var q bn254.G2Affine + q.ScalarMultiplication(&G2AffGen, s2) + return p, q } From e0377bb01677b2e158322a46d0f46901c7ee95b9 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 10 Dec 2024 00:26:20 +0000 Subject: [PATCH 20/21] XXX emulated parameter change for small fields --- std/math/emulated/emparams/emparams.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go index b6ac1ce9a..9ebff80eb 100644 --- a/std/math/emulated/emparams/emparams.go +++ b/std/math/emulated/emparams/emparams.go @@ -103,6 +103,14 @@ type BN254Fp struct{ fourLimbPrimeField } func (fp BN254Fp) Modulus() *big.Int { return ecc.BN254.BaseField() } +func (BN254Fp) BitsPerLimb() uint { + return 10 +} + +func (BN254Fp) NbLimbs() uint { + return 26 +} + // BN254Fr provides type parametrization for field emulation: // - limbs: 4 // - limb width: 64 bits @@ -117,6 +125,14 @@ type BN254Fr struct{ fourLimbPrimeField } func (fp BN254Fr) Modulus() *big.Int { return ecc.BN254.ScalarField() } +func (BN254Fr) BitsPerLimb() uint { + return 10 +} + +func (BN254Fr) NbLimbs() uint { + return 26 +} + // BLS12377Fp provides type parametrization for field emulation: // - limbs: 6 // - limb width: 64 bits From d1ec461fa757dedef6021780815941f3fcb2c2db Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 10 Dec 2024 00:35:42 +0000 Subject: [PATCH 21/21] XXX tests --- internal/smallfields/circuits_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/smallfields/circuits_test.go b/internal/smallfields/circuits_test.go index d0aea3262..c609e5d4f 100644 --- a/internal/smallfields/circuits_test.go +++ b/internal/smallfields/circuits_test.go @@ -96,12 +96,12 @@ func TestCompileEmulatedCircuit(t *testing.T) { f := babybear.Modulus() circuit := &EmulatedCircuit[smallBN]{} - assignemnt := &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)} + assignment := &EmulatedCircuit[smallBN]{A: emulated.ValueOf[smallBN](2), B: emulated.ValueOf[smallBN](4)} ccs, err := frontend.Compile(f, scs.NewBuilder, circuit) assert.NoError(err) - w, err := frontend.NewWitness(assignemnt, f) + w, err := frontend.NewWitness(assignment, f) assert.NoError(err) res, err := ccs.Solve(w)