diff --git a/.gotestfmt/downloads.gotpl b/.gotestfmt/downloads.gotpl deleted file mode 100644 index ca1cf92f55..0000000000 --- a/.gotestfmt/downloads.gotpl +++ /dev/null @@ -1,36 +0,0 @@ -{{- /*gotype: github.com/gotesttools/gotestfmt/v2/parser.Downloads*/ -}} -{{- /* -This template contains the format for a package download. -*/ -}} -{{- $settings := .Settings -}} -{{- if or .Packages .Reason -}} - {{- if or (not .Settings.HideSuccessfulDownloads) .Failed -}} - {{- if .Failed -}} - ❌ - {{- else -}} - 📥 - {{- end -}} - {{ " " }} Dependency downloads - {{ "\n" -}} - - {{- range .Packages -}} - {{- if or (not $settings.HideSuccessfulDownloads) .Failed -}} - {{- " " -}} - {{- if .Failed -}} - ❌ - {{- else -}} - 📦 - {{- end -}} - {{- " " -}} - {{- .Package }} {{ .Version -}} - {{- "\n" -}} - {{ with .Reason -}} - {{- " " -}}{{ . -}}{{ "\n" -}} - {{- end -}} - {{- end -}} - {{- end -}} - {{- with .Reason -}} - {{- " " -}}🛑 {{ . }}{{ "\n" -}} - {{- end -}} - {{- end -}} -{{- end -}} diff --git a/.gotestfmt/package.gotpl b/.gotestfmt/package.gotpl deleted file mode 100644 index 504949a86b..0000000000 --- a/.gotestfmt/package.gotpl +++ /dev/null @@ -1,42 +0,0 @@ -{{- /*gotype: github.com/gotesttools/gotestfmt/v2/parser.Package*/ -}} - -{{- $settings := .Settings -}} -{{- if and (or (not $settings.HideSuccessfulPackages) (ne .Result "PASS")) (or (not $settings.HideEmptyPackages) (ne .Result "SKIP") (ne (len .TestCases) 0)) -}} - 📦 `{{ .Name }}` - {{- with .Coverage -}} - ({{ . }}% coverage) - {{- end -}} - {{- "\n" -}} - {{- with .Reason -}} - {{- " " -}}🛑 {{ . -}}{{- "\n" -}} - {{- end -}} - {{- with .Output -}} - ```{{- "\n" -}} - {{- . -}}{{- "\n" -}} - ```{{- "\n" -}} - {{- end -}} - {{- with .TestCases -}} - {{- range . -}} - {{- if or (not $settings.HideSuccessfulTests) (ne .Result "PASS") -}} - {{- if eq .Result "PASS" -}} - ✅ - {{- else if eq .Result "SKIP" -}} - 🚧 - {{- else -}} - ❌ - {{- end -}} - {{ " " }}`{{- .Name -}}` {{ .Duration -}} - {{- "\n" -}} - - {{- with .Output -}} - ```{{- "\n" -}} - {{- formatTestOutput . 
$settings -}}{{- "\n" -}} - ```{{- "\n" -}} - {{- end -}} - - {{- "\n" -}} - {{- end -}} - {{- end -}} - {{- end -}} - {{- "\n" -}} -{{- end -}} diff --git a/internal/parallel/execute.go b/internal/parallel/execute.go new file mode 100644 index 0000000000..05f9a8f666 --- /dev/null +++ b/internal/parallel/execute.go @@ -0,0 +1,56 @@ +package parallel + +import ( + "runtime" + "sync" +) + +// Execute process in parallel the work function +func Execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/std/fiat-shamir/settings.go b/std/fiat-shamir/settings.go index 146a64355e..2c475e83e4 100644 --- a/std/fiat-shamir/settings.go +++ b/std/fiat-shamir/settings.go @@ -3,6 +3,9 @@ package fiatshamir import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/emulated" + gohash "hash" + "math/big" ) type Settings struct { @@ -12,6 +15,20 @@ type Settings struct { Hash hash.FieldHasher } +type SettingsBigInt struct { + Transcript *Transcript + Prefix string + BaseChallenges []big.Int + Hash gohash.Hash +} + +type SettingsEmulated[FR emulated.FieldParams] struct { + Transcript *Transcript + Prefix string + BaseChallenges []emulated.Element[FR] + Hash hash.FieldHasher +} + func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings { return Settings{ Transcript: transcript, @@ -20,9 +37,39 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro } } +func WithTranscriptBigInt(transcript *Transcript, prefix string, baseChallenges ...big.Int) SettingsBigInt { + return SettingsBigInt{ + Transcript: transcript, + Prefix: prefix, + BaseChallenges: baseChallenges, + } +} + +func WithTranscriptFr[FR emulated.FieldParams](transcript *Transcript, prefix string, baseChallenges ...emulated.Element[FR]) SettingsEmulated[FR] { + return SettingsEmulated[FR]{ + Transcript: transcript, + Prefix: prefix, + BaseChallenges: baseChallenges, + } +} + func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings { return Settings{ BaseChallenges: baseChallenges, Hash: hash, } } + +func WithHashBigInt(hash gohash.Hash, baseChallenges ...big.Int) SettingsBigInt { + return SettingsBigInt{ + BaseChallenges: baseChallenges, + Hash: hash, + } +} + +func WithHashFr[FR emulated.FieldParams](hash hash.FieldHasher, baseChallenges ...emulated.Element[FR]) SettingsEmulated[FR] { + return SettingsEmulated[FR]{ + BaseChallenges: baseChallenges, + Hash: hash, + } +} diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index a715a9d98e..4da2629934 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -308,6 +308,7 @@ func Verify(api 
frontend.API, c Circuit, assignment WireAssignment, proof Proof, claims := newClaimsManager(c, assignment) var firstChallenge []frontend.Variable + // why no bind values here? firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) if err != nil { return err @@ -327,7 +328,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(finalEvalProof) != 0 || len(proofW.RoundPolyEvaluations) != 0 { return fmt.Errorf("no proof allowed for input wire with a single claim") } @@ -470,16 +471,16 @@ func (a WireAssignment) NumVars() int { func (p Proof) Serialize() []frontend.Variable { size := 0 for i := range p { - for j := range p[i].PartialSumPolys { - size += len(p[i].PartialSumPolys[j]) + for j := range p[i].RoundPolyEvaluations { + size += len(p[i].RoundPolyEvaluations[j]) } size += len(p[i].FinalEvalProof.([]frontend.Variable)) } res := make([]frontend.Variable, 0, size) for i := range p { - for j := range p[i].PartialSumPolys { - res = append(res, p[i].PartialSumPolys[j]...) + for j := range p[i].RoundPolyEvaluations { + res = append(res, p[i].RoundPolyEvaluations[j]...) } res = append(res, p[i].FinalEvalProof.([]frontend.Variable)...) } @@ -519,9 +520,9 @@ func DeserializeProof(sorted []*Wire, serializedProof []frontend.Variable) (Proo reader := variablesReader(serializedProof) for i, wI := range sorted { if !wI.noProof() { - proof[i].PartialSumPolys = make([]polynomial.Polynomial, logNbInstances) - for j := range proof[i].PartialSumPolys { - proof[i].PartialSumPolys[j] = reader.nextN(wI.Gate.Degree() + 1) + proof[i].RoundPolyEvaluations = make([]polynomial.Polynomial, logNbInstances) + for j := range proof[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[j] = reader.nextN(wI.Gate.Degree() + 1) } } proof[i].FinalEvalProof = reader.nextN(wI.nbUniqueInputs()) diff --git a/std/gkr/gkr_test.go b/std/gkr/gkr_test.go index d24b25a95c..31ebd5c112 100644 --- a/std/gkr/gkr_test.go +++ b/std/gkr/gkr_test.go @@ -3,6 +3,7 @@ package gkr import ( "encoding/json" "fmt" + "math/big" "os" "path/filepath" "reflect" @@ -165,8 +166,8 @@ type TestCase struct { type TestCaseInfo struct { Hash HashDescription `json:"hash"` Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` + Input [][]big.Int `json:"input"` + Output [][]big.Int `json:"output"` Proof PrintableProof `json:"proof"` } @@ -275,8 +276,8 @@ func (g _select) Degree() int { type PrintableProof []PrintableSumcheckProof type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` + RoundPolyEvaluations [][]interface{} `json:"roundPolyEvaluations"` } func unmarshalProof(printable PrintableProof) (proof Proof) { @@ -294,9 +295,9 @@ func unmarshalProof(printable PrintableProof) (proof Proof) { proof[i].FinalEvalProof = nil } - proof[i].PartialSumPolys = make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)) - for k := range printable[i].PartialSumPolys { - proof[i].PartialSumPolys[k] = ToVariableSlice(printable[i].PartialSumPolys[k]) + proof[i].RoundPolyEvaluations = make([]polynomial.Polynomial, len(printable[i].RoundPolyEvaluations)) + for k := range 
printable[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[k] = ToVariableSlice(printable[i].RoundPolyEvaluations[k]) } } return @@ -327,7 +328,6 @@ func TestLoadCircuit(t *testing.T) { assert.Equal(t, []*Wire{}, c[0].Inputs) assert.Equal(t, []*Wire{&c[0]}, c[1].Inputs) assert.Equal(t, []*Wire{&c[1]}, c[2].Inputs) - } func TestTopSortTrivial(t *testing.T) { diff --git a/std/gkr/test_vectors/single_identity_gate_two_instances.json b/std/gkr/test_vectors/single_identity_gate_two_instances.json index ce326d0a63..fa38a03cb6 100644 --- a/std/gkr/test_vectors/single_identity_gate_two_instances.json +++ b/std/gkr/test_vectors/single_identity_gate_two_instances.json @@ -19,13 +19,13 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 5 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -8 diff --git a/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json index 2c95f044f2..a995f7197a 100644 --- a/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json +++ b/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -23,7 +23,7 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ 0, 0 @@ -34,7 +34,7 @@ "finalEvalProof": [ 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -16 @@ -45,7 +45,7 @@ "finalEvalProof": [ 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -16 diff --git a/std/gkr/test_vectors/single_input_two_outs_two_instances.json b/std/gkr/test_vectors/single_input_two_outs_two_instances.json index d348303d0e..6dace72193 100644 --- a/std/gkr/test_vectors/single_input_two_outs_two_instances.json +++ b/std/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -23,7 +23,7 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ 0, 0 @@ -34,7 +34,7 @@ "finalEvalProof": [ 0 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -4, -36, @@ -46,7 +46,7 @@ "finalEvalProof": [ 0 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -2, -12 diff --git a/std/gkr/test_vectors/single_mimc_gate_four_instances.json b/std/gkr/test_vectors/single_mimc_gate_four_instances.json index 525459ecb1..1162e56f36 100644 --- a/std/gkr/test_vectors/single_mimc_gate_four_instances.json +++ b/std/gkr/test_vectors/single_mimc_gate_four_instances.json @@ -29,18 +29,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ -1, -3 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -32640, -2239484, diff --git a/std/gkr/test_vectors/single_mimc_gate_two_instances.json b/std/gkr/test_vectors/single_mimc_gate_two_instances.json index 7fa23ce4b1..12d7755dd5 100644 --- a/std/gkr/test_vectors/single_mimc_gate_two_instances.json +++ b/std/gkr/test_vectors/single_mimc_gate_two_instances.json @@ -23,18 +23,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 1, 0 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -2187, -65536, diff --git a/std/gkr/test_vectors/single_mul_gate_two_instances.json b/std/gkr/test_vectors/single_mul_gate_two_instances.json index 75c1d59c3d..ba854e37f5 100644 --- 
a/std/gkr/test_vectors/single_mul_gate_two_instances.json +++ b/std/gkr/test_vectors/single_mul_gate_two_instances.json @@ -23,18 +23,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 5, 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -9, -32, diff --git a/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json index 10e5f1ff3c..e145c7d18d 100644 --- a/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json +++ b/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -19,13 +19,13 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 3 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -1, 0 @@ -36,7 +36,7 @@ "finalEvalProof": [ 3 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -1, 0 diff --git a/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json index 19e127df71..e972222802 100644 --- a/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json +++ b/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -23,18 +23,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ -1, 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -16 diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 6c1f19b04d..9dbc471e25 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -245,7 +245,7 @@ func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) { constLimbs := make([]*big.Int, len(v.Limbs)) for i, l := range v.Limbs { // for each limb we get it's constant value if we can, or fail. - if constLimbs[i], ok = f.api.ConstantValue(l); !ok { + if constLimbs[i], ok = f.api.Compiler().ConstantValue(l); !ok { return nil, false } } diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 5c2c700663..ac20b22b0b 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -34,6 +34,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { f.enforceWidthConditional(a) f.enforceWidthConditional(b) + ba, aConst := f.constantValue(a) bb, bConst := f.constantValue(b) if aConst && bConst { diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 278b9a5024..5177873adf 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -115,47 +115,98 @@ func (mc *mulCheck[T]) cleanEvaluations() { // mulMod returns a*b mod r. In practice it computes the result using a hint and // defers the actual multiplication check. 
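// Editor's note (not part of the patch, added for context): the deferred check recorded
// in f.mulChecks is, to the best of my reading of the surrounding code, discharged later by
// evaluating the limb polynomials at a random challenge χ and asserting
// a(χ)·b(χ) = k(χ)·p(χ) + r(χ) + c(χ)·(2ʷ − χ), where w is the limb width and c collects
// the carries; one batched evaluation thus stands in for a per-multiplication recomposition.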
func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { - f.enforceWidthConditional(a) - f.enforceWidthConditional(b) - f.enforceWidthConditional(p) - k, r, c, err := f.callMulHint(a, b, true, p) - if err != nil { - panic(err) - } - mc := mulCheck[T]{ - f: f, - a: a, - b: b, - c: c, - k: k, - r: r, - p: p, - } - f.mulChecks = append(f.mulChecks, mc) - return r + return f.mulModProfiling(a, b, p, true) + // f.enforceWidthConditional(a) + // f.enforceWidthConditional(b) + // f.enforceWidthConditional(p) + // k, r, c, err := f.callMulHint(a, b, true, p) + // if err != nil { + // panic(err) + // } + // mc := mulCheck[T]{ + // f: f, + // a: a, + // b: b, + // c: c, + // k: k, + // r: r, + // p: p, + // } + // f.mulChecks = append(f.mulChecks, mc) + // return r } // checkZero creates multiplication check a * 1 = 0 + k*p. func (f *Field[T]) checkZero(a *Element[T], p *Element[T]) { + f.mulModProfiling(a, f.shortOne(), p, false) // the method works similarly to mulMod, but we know that we are multiplying // by one and expected result should be zero. + // f.enforceWidthConditional(a) + // f.enforceWidthConditional(p) + // b := f.shortOne() + // k, r, c, err := f.callMulHint(a, b, false, p) + // if err != nil { + // panic(err) + // } + // mc := mulCheck[T]{ + // f: f, + // a: a, + // b: b, // one on single limb to speed up the polynomial evaluation + // c: c, + // k: k, + // r: r, // expected to be zero on zero limbs. + // p: p, + // } + // f.mulChecks = append(f.mulChecks, mc) +} + +func (f *Field[T]) mulModProfiling(a, b *Element[T], p *Element[T], isMulMod bool) *Element[T] { f.enforceWidthConditional(a) - f.enforceWidthConditional(p) - b := f.shortOne() - k, r, c, err := f.callMulHint(a, b, false, p) + f.enforceWidthConditional(b) + k, r, c, err := f.callMulHint(a, b, isMulMod, p) if err != nil { panic(err) } mc := mulCheck[T]{ f: f, a: a, - b: b, // one on single limb to speed up the polynomial evaluation + b: b, c: c, k: k, - r: r, // expected to be zero on zero limbs. - p: p, + r: r, } - f.mulChecks = append(f.mulChecks, mc) + var toCommit []frontend.Variable + toCommit = append(toCommit, mc.a.Limbs...) + toCommit = append(toCommit, mc.b.Limbs...) + toCommit = append(toCommit, mc.r.Limbs...) + toCommit = append(toCommit, mc.k.Limbs...) + toCommit = append(toCommit, mc.c.Limbs...) + multicommit.WithCommitment(f.api, func(api frontend.API, commitment frontend.Variable) error { + // we do nothing. We just want to ensure that we count the commitments + return nil + }, toCommit...) + // XXX: or use something variable to count the commitments and constraints properly. Maybe can use 123 from hint? + commitment := 123 + + // for efficiency, we compute all powers of the challenge as slice at. 
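// (editor's note) at[i] ends up holding χ^(i+1) for the profiling challenge χ, so the
// evaluations below (evalRound1/evalRound2 and the modulus) all reuse the same precomputed
// powers instead of recomputing them per operand; the fixed `commitment = 123` above is
// only a stand-in used for constraint counting, as the XXX comment notes.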
+ coefsLen := max(len(mc.a.Limbs), len(mc.b.Limbs), + len(mc.c.Limbs), len(mc.k.Limbs)) + at := make([]frontend.Variable, coefsLen) + at[0] = commitment + for i := 1; i < len(at); i++ { + at[i] = f.api.Mul(at[i-1], commitment) + } + mc.evalRound1(at) + mc.evalRound2(at) + // evaluate p(X) at challenge + pval := f.evalWithChallenge(f.Modulus(), at) + // compute (2^t-X) at challenge + coef := big.NewInt(1) + coef.Lsh(coef, f.fParams.BitsPerLimb()) + ccoef := f.api.Sub(coef, commitment) + // verify all mulchecks + mc.check(f.api, pval.evaluation, ccoef) + return r } // evalWithChallenge represents element a as a polynomial a(X) and evaluates at diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index a9f0d9cda3..2ef1f26889 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -3,10 +3,10 @@ package emulated import ( "errors" "fmt" - "math/bits" - "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/selector" + "math/big" + "math/bits" ) // Div computes a/b and returns it. It uses [DivHint] as a hint function. @@ -368,3 +368,37 @@ type overflowError struct { func (e overflowError) Error() string { return fmt.Sprintf("op %s overflow %d exceeds max %d", e.op, e.nextOverflow, e.maxOverflow) } + +func (f *Field[T]) String(a *Element[T]) string { + // for debug only, if is not test engine then no-op + var fp T + blimbs := make([]*big.Int, len(a.Limbs)) + for i, v := range a.Limbs { + switch vv := v.(type) { + case *big.Int: + blimbs[i] = vv + case big.Int: + blimbs[i] = &vv + case int: + blimbs[i] = new(big.Int) + blimbs[i].SetInt64(int64(vv)) + case uint: + blimbs[i] = new(big.Int) + blimbs[i].SetUint64(uint64(vv)) + default: + return "???" + } + } + res := new(big.Int) + err := recompose(blimbs, fp.BitsPerLimb(), res) + if err != nil { + return "!!!" + } + reduced := new(big.Int).Mod(res, fp.Modulus()) + return reduced.String() +} + +func (f *Field[T]) Println(a *Element[T]) { + res := f.String(a) + fmt.Println(res) +} diff --git a/std/math/polynomial/polynomial.go b/std/math/polynomial/polynomial.go index e09ef69ef1..511dc4107d 100644 --- a/std/math/polynomial/polynomial.go +++ b/std/math/polynomial/polynomial.go @@ -22,6 +22,10 @@ type Univariate[FR emulated.FieldParams] []emulated.Element[FR] // coefficients. type Multilinear[FR emulated.FieldParams] []emulated.Element[FR] +func (ml *Multilinear[FR]) NumVars() int { + return bits.Len(uint(len(*ml) - 1)) +} + func valueOf[FR emulated.FieldParams](univ []*big.Int) []emulated.Element[FR] { ret := make([]emulated.Element[FR], len(univ)) for i := range univ { @@ -61,6 +65,9 @@ type Polynomial[FR emulated.FieldParams] struct { // FromSlice maps slice of emulated element values to their references. 
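// (editor's note) this helper is what lets callers pass a []emulated.Element[FR] where a
// []*emulated.Element[FR] is expected, e.g.
//   gate.Evaluate(engine, polynomial.FromSlice(inputEvaluations)...)
// as done in the non-native GKR verifier later in this diff; the new empty-slice guard
// simply makes the zero-input case explicit.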
func FromSlice[FR emulated.FieldParams](in []emulated.Element[FR]) []*emulated.Element[FR] { + if len(in) == 0 { + return []*emulated.Element[FR]{} + } r := make([]*emulated.Element[FR], len(in)) for i := range in { r[i] = &in[i] diff --git a/std/polynomial/polynomial.go b/std/polynomial/polynomial.go index 0953cb3ac7..4bd0940023 100644 --- a/std/polynomial/polynomial.go +++ b/std/polynomial/polynomial.go @@ -3,6 +3,7 @@ package polynomial import ( "math/bits" + "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/frontend" ) @@ -11,6 +12,53 @@ type MultiLin []frontend.Variable var minFoldScaledLogSize = 16 +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate assumes len(m) = 1 << len(at) +// it doesn't modify m +func (m MultiLin) EvaluatePool(api frontend.API, at []frontend.Variable, pool *Pool) frontend.Variable { + _m := _clone(m, pool) + + /*minFoldScaledLogSize := 16 + if api is r1cs { + minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs + }*/ + + scaleCorrectionFactor := frontend.Variable(1) + // at each iteration fold by at[i] + for len(_m) > 1 { + if len(_m) >= minFoldScaledLogSize { + scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0])) + } else { + _m.Fold(api, at[0]) + } + _m = _m[:len(_m)/2] + at = at[1:] + } + + if len(at) != 0 { + panic("incompatible evaluation vector size") + } + + result := _m[0] + + _dump(_m, pool) + + return api.Mul(result, scaleCorrectionFactor) +} + // Evaluate assumes len(m) = 1 << len(at) // it doesn't modify m func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable { @@ -27,7 +75,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va if len(_m) >= minFoldScaledLogSize { scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0])) } else { - _m.fold(api, at[0]) + _m.Fold(api, at[0]) } _m = _m[:len(_m)/2] at = at[1:] @@ -42,7 +90,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va // fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size // WARNING: The user should halve m themselves after the call -func (m MultiLin) fold(api frontend.API, at frontend.Variable) { +func (m MultiLin) Fold(api frontend.API, at frontend.Variable) { zero := m[:len(m)/2] one := m[len(m)/2:] for j := range zero { @@ -51,6 +99,43 @@ func (m MultiLin) fold(api frontend.API, at frontend.Variable) { } } +func (m *MultiLin) FoldParallel(api frontend.API, r frontend.Variable) utils.Task { + mid := len(*m) / 2 + bottom, top := (*m)[:mid], (*m)[mid:] + + *m = bottom + + return func(start, end int) { + var t frontend.Variable // no need to update the top part + for i := start; i < end; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t = api.Sub(&top[i], &bottom[i]) + t = api.Mul(&t, &r) + bottom[i] = api.Add(&bottom[i], &t) + } + } +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(api frontend.API, q []frontend.Variable) { + n := len(q) + + if len(*m) != 1< p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large +} + +func (p *Pool) Make(n int) []frontend.Variable { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, 
n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]frontend.Variable) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse.Load(ptr); ok { + p.inUse.Delete(ptr) + metadata.(inUseData).pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } + } +} + +func (p *Pool) addInUse(ptr *frontend.Variable, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) + + if prevPcs, ok := p.inUse.Load(ptr); ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.(inUseData).allocatedFor))) + } + p.inUse.Store(ptr, inUseData{ + allocatedFor: pcs[:n], + pool: pool, + }) +} + +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) +} + +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + p.inUse.Range(func(_, pcs any) bool { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.(inUseData).allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) + } + return true + }) +} + +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) make(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []frontend.Variable) *frontend.Variable { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*frontend.Variable)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse + } + + stats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(stats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []frontend.Variable) []frontend.Variable { + res := p.Make(len(slice)) + copy(res, slice) + return res +} diff --git a/std/recursion/gkr/gkr_nonnative.go b/std/recursion/gkr/gkr_nonnative.go new file mode 100644 index 0000000000..0f3de809e5 --- /dev/null +++ b/std/recursion/gkr/gkr_nonnative.go @@ -0,0 +1,2101 @@ +package gkrnonative + +import ( + "fmt" + cryptofiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/parallel" + fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/polynomial" + "github.com/consensys/gnark/std/recursion" + "github.com/consensys/gnark/std/recursion/sumcheck" + "math/big" + "slices" + "strconv" +) + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(*sumcheck.BigIntEngine, ...*big.Int) []*big.Int + Degree() int + NbInputs() int + NbOutputs() int + GetName() 
string +} + +type WireBundle struct { + Gate Gate + Layer int + Depth int + Inputs []*Wires // if there are no Inputs, the wire is assumed an input wire + Outputs []*Wires `SameBundle:"true"` // if there are no Outputs, the wire is assumed an output wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +func bundleKey(wireBundle *WireBundle) string { + return fmt.Sprintf("%d-%s", wireBundle.Layer, wireBundle.Gate.GetName()) +} + +func bundleKeyEmulated[FR emulated.FieldParams](wireBundle *WireBundleEmulated[FR]) string { + return fmt.Sprintf("%d-%s", wireBundle.Layer, wireBundle.Gate.GetName()) +} + +// InitFirstWireBundle initializes the first WireBundle for Layer 0 padded with IdentityGate as relayer +func InitFirstWireBundle(inputsLen int, numLayers int) WireBundle { + gate := IdentityGate[*sumcheck.BigIntEngine, *big.Int]{Arity: inputsLen} + inputs := make([]*Wires, inputsLen) + for i := 0; i < inputsLen; i++ { + inputs[i] = &Wires{ + SameBundle: true, + BundleIndex: -1, + BundleLength: inputsLen, + WireIndex: i, + nbUniqueOutputs: 0, + } + } + + outputs := make([]*Wires, gate.NbOutputs()) + for i := 0; i < len(outputs); i++ { + outputs[i] = &Wires{ + SameBundle: true, + BundleIndex: 0, + BundleLength: len(outputs), + WireIndex: i, + nbUniqueOutputs: 0, + } + } + + return WireBundle{ + Gate: gate, + Layer: 0, + Depth: numLayers, + Inputs: inputs, + Outputs: outputs, + nbUniqueOutputs: 0, + } +} + +// NewWireBundle connects previous output wires to current input wires and initializes the current output wires +func NewWireBundle(gate Gate, inputWires []*Wires, layer int, numLayers int) WireBundle { + inputs := make([]*Wires, len(inputWires)) + for i := 0; i < len(inputWires); i++ { + inputs[i] = &Wires{ + SameBundle: true, + BundleIndex: layer - 1, //takes inputs from previous layer + BundleLength: len(inputs), + WireIndex: i, + nbUniqueOutputs: inputWires[i].nbUniqueOutputs, + } + } + + outputs := make([]*Wires, gate.NbOutputs()) + for i := 0; i < len(outputs); i++ { + outputs[i] = &Wires{ + SameBundle: true, + BundleIndex: layer, + BundleLength: len(outputs), + WireIndex: i, + nbUniqueOutputs: 0, + } + } + + return WireBundle{ + Gate: gate, + Layer: layer, + Depth: numLayers, + Inputs: inputs, + Outputs: outputs, + nbUniqueOutputs: 0, + } +} + +type Wires struct { + SameBundle bool + BundleIndex int + BundleLength int + WireIndex int + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +// Gate must be a low-degree polynomial +type GateEmulated[FR emulated.FieldParams] interface { + Evaluate(*sumcheck.EmuEngine[FR], ...*emulated.Element[FR]) []*emulated.Element[FR] + NbInputs() int + NbOutputs() int + Degree() int + GetName() string +} + +type WireEmulated[FR emulated.FieldParams] struct { + Gate GateEmulated[FR] + Inputs []*WireEmulated[FR] // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. 
providing two inputs to the same gate counts as one) +} + +type WireBundleEmulated[FR emulated.FieldParams] struct { + Gate GateEmulated[FR] + Layer int + Depth int + Inputs []*Wires // if there are no Inputs, the wire is assumed an input wire + Outputs []*Wires `SameBundle:"true"` // if there are no Outputs, the wire is assumed an output wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +// InitFirstWireBundle initializes the first WireBundle for Layer 0 padded with IdentityGate as relayer +func InitFirstWireBundleEmulated[FR emulated.FieldParams](inputsLen int, numLayers int) WireBundleEmulated[FR] { + gate := IdentityGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{Arity: inputsLen} + inputs := make([]*Wires, inputsLen) + for i := 0; i < inputsLen; i++ { + inputs[i] = &Wires{ + SameBundle: true, + BundleIndex: -1, + BundleLength: inputsLen, + WireIndex: i, + nbUniqueOutputs: 0, + } + } + + outputs := make([]*Wires, gate.NbOutputs()) + for i := 0; i < len(outputs); i++ { + outputs[i] = &Wires{ + SameBundle: true, + BundleIndex: 0, + BundleLength: len(outputs), + WireIndex: i, + nbUniqueOutputs: 0, + } + } + + return WireBundleEmulated[FR]{ + Gate: gate, + Layer: 0, + Depth: numLayers, + Inputs: inputs, + Outputs: outputs, + nbUniqueOutputs: 0, + } +} + +// NewWireBundle connects previous output wires to current input wires and initializes the current output wires +func NewWireBundleEmulated[FR emulated.FieldParams](gate GateEmulated[FR], inputWires []*Wires, layer int, numLayers int) WireBundleEmulated[FR] { + inputs := make([]*Wires, len(inputWires)) + for i := 0; i < len(inputWires); i++ { + inputs[i] = &Wires{ + SameBundle: true, + BundleIndex: layer - 1, + BundleLength: len(inputs), + WireIndex: i, + nbUniqueOutputs: inputWires[i].nbUniqueOutputs, + } + } + + outputs := make([]*Wires, gate.NbOutputs()) + for i := 0; i < len(outputs); i++ { + outputs[i] = &Wires{ + SameBundle: true, + BundleIndex: layer, + BundleLength: len(outputs), + WireIndex: i, + nbUniqueOutputs: 0, + } + } + + return WireBundleEmulated[FR]{ + Gate: gate, + Layer: layer, + Depth: numLayers, + Inputs: inputs, + Outputs: outputs, + nbUniqueOutputs: 0, + } +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) nbUniqueInputs() int { + set := make(map[*Wire]struct{}, len(w.Inputs)) + for _, in := range w.Inputs { + set[in] = struct{}{} + } + return len(set) +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +type CircuitBundle []WireBundle + +func (w WireBundle) IsInput() bool { + return w.Layer == 0 +} + +func (w WireBundle) IsOutput() bool { + return w.Layer == w.Depth - 1 + //return w.nbUniqueOutputs == 0 && w.Layer != 0 +} + +func (w WireBundle) NbClaims() int { + //todo check this + if w.IsOutput() { + return w.Gate.NbOutputs() + } + return w.nbUniqueOutputs +} + +func (w WireBundle) nbUniqueInputs() int { + set := make(map[*Wires]struct{}, len(w.Inputs)) + for _, in := range w.Inputs { + set[in] = struct{}{} + } + return len(set) +} + +func (w WireBundle) noProof() bool { + return w.IsInput() // && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = utils.Max(res, 
c[i].Gate.Degree()) + } + } + return res +} + +type CircuitBundleEmulated[FR emulated.FieldParams] []WireBundleEmulated[FR] +//todo change these methods +func (w WireBundleEmulated[FR]) IsInput() bool { + return w.Layer == 0 +} + +func (w WireBundleEmulated[FR]) IsOutput() bool { + return w.Layer == w.Depth - 1 + //return w.nbUniqueOutputs == 0 +} + +//todo check this - assuming single claim per individual wire +func (w WireBundleEmulated[FR]) NbClaims() int { + return w.Gate.NbOutputs() + // if w.IsOutput() { + // return 1 + // } + //return w.nbUniqueOutputs +} + +func (w WireBundleEmulated[FR]) nbUniqueInputs() int { + set := make(map[*Wires]struct{}, len(w.Inputs)) + for _, in := range w.Inputs { + set[in] = struct{}{} + } + + return len(set) +} + +func (w WireBundleEmulated[FR]) noProof() bool { + return w.IsInput() // && w.NbClaims() == 1 +} + +type CircuitEmulated[FR emulated.FieldParams] []WireEmulated[FR] + +func (w WireEmulated[FR]) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w WireEmulated[FR]) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w WireEmulated[FR]) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w WireEmulated[FR]) nbUniqueInputs() int { + set := make(map[*WireEmulated[FR]]struct{}, len(w.Inputs)) + for _, in := range w.Inputs { + set[in] = struct{}{} + } + return len(set) +} + +func (w WireEmulated[FR]) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[string]sumcheck.NativeMultilinear + +type WireAssignmentBundle map[*WireBundle]WireAssignment + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignmentEmulated[FR emulated.FieldParams] map[string]polynomial.Multilinear[FR] + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignmentBundleEmulated[FR emulated.FieldParams] map[*WireBundleEmulated[FR]]WireAssignmentEmulated[FR] + +type Proofs[FR emulated.FieldParams] []sumcheck.Proof[FR] // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaimsEmulated[FR emulated.FieldParams] struct { + wire *Wires + commonGate GateEmulated[FR] + evaluationPoints [][]emulated.Element[FR] + claimedEvaluations []emulated.Element[FR] + manager *claimsManagerEmulated[FR] // WARNING: Circular references + verifier *GKRVerifier[FR] + engine *sumcheck.EmuEngine[FR] +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) NbClaims() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) NbVars() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) CombinedSum(a *emulated.Element[FR]) *emulated.Element[FR] { + evalsAsPoly := polynomial.Univariate[FR](e.claimedEvaluations) + return e.verifier.p.EvalUnivariate(evalsAsPoly, a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) Degree(int) int { + return 1 + e.commonGate.Degree() +} + +type eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR emulated.FieldParams] struct { + wireBundle *WireBundleEmulated[FR] + claimsMapOutputsLazy map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR] + claimsMapInputsLazy map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR] + verifier *GKRVerifier[FR] + engine *sumcheck.EmuEngine[FR] +} + +func (e 
*eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]) addOutput(wire *Wires, evaluationPoint []emulated.Element[FR], evaluation emulated.Element[FR]) { + claim := e.claimsMapOutputsLazy[wireKey(wire)] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +// todo assuming single claim per wire +func (e *eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]) NbClaims() int { + return len(e.claimsMapOutputsLazy) +} + +// to batch sumchecks in the bundle all claims should have the same number of variables - taking first outputwire +func (e *eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]) NbVars() int { + return len(e.claimsMapOutputsLazy[wireKey(e.wireBundle.Outputs[0])].evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]) CombinedSum(a *emulated.Element[FR]) *emulated.Element[FR] { + //dummy challenges only for testing + challengesRLC := make([]*emulated.Element[FR], len(e.claimsMapOutputsLazy)) + for i := range challengesRLC { + challengesRLC[i] = e.engine.Const(big.NewInt(int64(i+1))) // todo check this + } + acc := e.engine.Const(big.NewInt(0)) + for i, claim := range e.claimsMapOutputsLazy { + _, wireIndex := parseWireKey(i) + sum := claim.CombinedSum(a) + sumRLC := e.engine.Mul(sum, challengesRLC[wireIndex]) + acc = e.engine.Add(acc, sumRLC) + } + return acc +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]) Degree(int) int { + return 1 + e.wireBundle.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]) AssertEvaluation(r []*emulated.Element[FR], combinationCoeff, expectedValue *emulated.Element[FR], proof sumcheck.EvaluationProof) error { + inputEvaluationsNoRedundancy := proof.([]emulated.Element[FR]) + + field, err := emulated.NewField[FR](e.verifier.api) + if err != nil { + return fmt.Errorf("failed to create field: %w", err) + } + p, err := polynomial.New[FR](e.verifier.api) + if err != nil { + return err + } + + // dummy challenges for testing, get from transcript + challengesRLC := make([]*emulated.Element[FR], len(e.wireBundle.Outputs)) + for i := range challengesRLC { + challengesRLC[i] = e.engine.Const(big.NewInt(int64(i+1))) + } + + var evaluationFinal emulated.Element[FR] + // the eq terms + evaluationEq := make([]*emulated.Element[FR], len(e.claimsMapOutputsLazy)) + for k, claims := range e.claimsMapOutputsLazy { + _, wireIndex := parseWireKey(k) + numClaims := len(claims.evaluationPoints) + eval := p.EvalEqual(polynomial.FromSlice(claims.evaluationPoints[numClaims - 1]), r) // assuming single claim per wire + // for i := numClaims - 2; i >= 0; i-- { // todo change this to handle multiple claims per wire - assuming single claim per wire so don't need to combine + // eval = field.Mul(eval, combinationCoeff) + // eq := p.EvalEqual(polynomial.FromSlice(claims.evaluationPoints[i]), r) + // eval = field.Add(eval, eq) + // } + evaluationEq[wireIndex] = eval + } + + // the g(...) 
term + if e.wireBundle.IsInput() { // From previous impl - was not needed as this is already handled with noproof before initiating sumcheck verify + // for _, output := range e.wireBundle.Outputs { // doing on output as first layer is dummy layer with identity gate + // gateEvaluationsPtr, err := p.EvalMultilinear(r, e.claimsMapOutputsLazy[wireKey(output)].manager.assignment[wireKey(output)]) + // if err != nil { + // return err + // } + // gateEvaluations = append(gateEvaluations, *gateEvaluationsPtr) + // for i, s := range gateEvaluations { + // gateEvaluationRLC := e.engine.Mul(&s, challengesRLC[i]) + // gateEvaluation = *e.engine.Add(&gateEvaluation, gateEvaluationRLC) + // } + // } + } else { + inputEvaluations := make([]emulated.Element[FR], len(e.wireBundle.Inputs)) + indexesInProof := make(map[*Wires]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wireBundle.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.claimsMapInputsLazy[wireKey(in)].manager.add(in, polynomial.FromSliceReferences(r), inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluationOutputs := e.wireBundle.Gate.Evaluate(e.engine, polynomial.FromSlice(inputEvaluations)...) + + for i , s := range gateEvaluationOutputs { + evaluationRLC := e.engine.Mul(s, challengesRLC[i]) + evaluationFinal = *e.engine.Add(&evaluationFinal, evaluationRLC) + } + } + + evaluationFinal = *e.engine.Mul(&evaluationFinal, evaluationEq[0]) + + field.AssertIsEqual(&evaluationFinal, expectedValue) + return nil +} + +type claimsManagerEmulated[FR emulated.FieldParams] struct { + claimsMap map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR] + assignment WireAssignmentEmulated[FR] +} + +func (m *claimsManagerEmulated[FR]) add(wire *Wires, evaluationPoint []emulated.Element[FR], evaluation emulated.Element[FR]) { + claim := m.claimsMap[wireKey(wire)] + i := len(claim.evaluationPoints) //todo check this + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +type claimsManagerBundleEmulated[FR emulated.FieldParams] struct { + claimsMap map[string]*eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR] + assignment WireAssignmentBundleEmulated[FR] +} + +func newClaimsManagerBundleEmulated[FR emulated.FieldParams](c CircuitBundleEmulated[FR], assignment WireAssignmentBundleEmulated[FR], verifier GKRVerifier[FR]) (claims claimsManagerBundleEmulated[FR]) { + claims.assignment = assignment + claims.claimsMap = make(map[string]*eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR], len(c)) + engine, err := sumcheck.NewEmulatedEngine[FR](verifier.api) + if err != nil { + panic(err) + } + + for i := range c { + wireBundle := &c[i] + claimsMapOutputs := make(map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR], len(wireBundle.Outputs)) + claimsMapInputs := make(map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR], len(wireBundle.Inputs)) + + for _, wire := range wireBundle.Outputs { + inputClaimsManager := &claimsManagerEmulated[FR]{} + inputClaimsManager.assignment = assignment[wireBundle] + // todo we assume each individual wire has only one claim + 
inputClaimsManager.claimsMap = make(map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR], 1) + new_claim := &eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]{ + wire: wire, + commonGate: wireBundle.Gate, + evaluationPoints: make([][]emulated.Element[FR], 0, 1), // assuming single claim per wire + claimedEvaluations: make([]emulated.Element[FR], 1), + manager: inputClaimsManager, + verifier: &verifier, + engine: engine, + } + inputClaimsManager.claimsMap[wireKey(wire)] = new_claim + claimsMapOutputs[wireKey(wire)] = new_claim + } + for _, wire := range wireBundle.Inputs { + inputClaimsManager := &claimsManagerEmulated[FR]{} + inputClaimsManager.assignment = assignment[wireBundle] + // todo we assume each individual wire has only one claim + inputClaimsManager.claimsMap = make(map[string]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR], 1) + new_claim := &eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]{ + wire: wire, + commonGate: wireBundle.Gate, + evaluationPoints: make([][]emulated.Element[FR], 0, 1), // assuming single claim per wire + claimedEvaluations: make([]emulated.Element[FR], 1), + manager: inputClaimsManager, + verifier: &verifier, + engine: engine, + } + inputClaimsManager.claimsMap[wireKey(wire)] = new_claim + claimsMapInputs[wireKey(wire)] = new_claim + } + claims.claimsMap[bundleKeyEmulated(wireBundle)] = &eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR]{ + wireBundle: wireBundle, + claimsMapOutputsLazy: claimsMapOutputs, + claimsMapInputsLazy: claimsMapInputs, + verifier: &verifier, + engine: engine, + } + } + return +} + +func (m *claimsManagerBundleEmulated[FR]) getLazyClaim(wire *WireBundleEmulated[FR]) *eqTimesGateEvalSumcheckLazyClaimsBundleEmulated[FR] { + return m.claimsMap[bundleKeyEmulated(wire)] +} + +func (m *claimsManagerBundleEmulated[FR]) deleteClaim(wireBundle *WireBundleEmulated[FR], previousWireBundle *WireBundleEmulated[FR]) { + if !wireBundle.IsInput() { + sewnClaimsMapOutputs := m.claimsMap[bundleKeyEmulated(wireBundle)].claimsMapInputsLazy + m.claimsMap[bundleKeyEmulated(previousWireBundle)].claimsMapOutputsLazy = sewnClaimsMapOutputs + } + delete(m.claimsMap, bundleKeyEmulated(wireBundle)) +} + +type claimsManager struct { + claimsMap map[string]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment +} + +func wireKey(w *Wires) string { + return fmt.Sprintf("%d-%d", w.BundleIndex, w.WireIndex) +} + +func getOuputWireKey(w *Wires) string { + return fmt.Sprintf("%d-%d", w.BundleIndex + 1, w.WireIndex) +} + +func getInputWireKey(w *Wires) string { + return fmt.Sprintf("%d-%d", w.BundleIndex - 1, w.WireIndex) +} + +func parseWireKey(key string) (int, int) { + var bundleIndex, wireIndex int + _, err := fmt.Sscanf(key, "%d-%d", &bundleIndex, &wireIndex) + if err != nil { + panic(err) + } + return bundleIndex, wireIndex +} + +type claimsManagerBundle struct { + claimsMap map[string]*eqTimesGateEvalSumcheckLazyClaimsBundle // bundleKey(wireBundle) + assignment WireAssignmentBundle +} + +func newClaimsManagerBundle(c CircuitBundle, assignment WireAssignmentBundle) (claims claimsManagerBundle) { + claims.assignment = assignment + claims.claimsMap = make(map[string]*eqTimesGateEvalSumcheckLazyClaimsBundle, len(c)) + + for i := range c { + wireBundle := &c[i] + claimsMapOutputs := make(map[string]*eqTimesGateEvalSumcheckLazyClaims, len(wireBundle.Outputs)) + claimsMapInputs := make(map[string]*eqTimesGateEvalSumcheckLazyClaims, len(wireBundle.Inputs)) + for _, wire := range wireBundle.Outputs { + inputClaimsManager := &claimsManager{} + 
inputClaimsManager.assignment = assignment[wireBundle] + // todo we assume each individual wire has only one claim + inputClaimsManager.claimsMap = make(map[string]*eqTimesGateEvalSumcheckLazyClaims, len(wireBundle.Inputs)) + new_claim := &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]big.Int, 0, 1), //assuming single claim per wire + claimedEvaluations: make([]big.Int, 1), + manager: inputClaimsManager, + } + inputClaimsManager.claimsMap[wireKey(wire)] = new_claim + claimsMapOutputs[wireKey(wire)] = new_claim + } + for _, wire := range wireBundle.Inputs { + inputClaimsManager := &claimsManager{} + inputClaimsManager.assignment = assignment[wireBundle] + // todo we assume each individual wire has only one claim + inputClaimsManager.claimsMap = make(map[string]*eqTimesGateEvalSumcheckLazyClaims, len(wireBundle.Inputs)) + new_claim := &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]big.Int, 0, 1), //assuming single claim per wire + claimedEvaluations: make([]big.Int, 1), + manager: inputClaimsManager, + } + inputClaimsManager.claimsMap[wireKey(wire)] = new_claim + claimsMapInputs[wireKey(wire)] = new_claim + } + claims.claimsMap[bundleKey(wireBundle)] = &eqTimesGateEvalSumcheckLazyClaimsBundle{ + wireBundle: wireBundle, + claimsMapOutputsLazy: claimsMapOutputs, + claimsMapInputsLazy: claimsMapInputs, + } + } + return +} + +func (m *claimsManagerBundle) getClaim(engine *sumcheck.BigIntEngine, wireBundle *WireBundle) *eqTimesGateEvalSumcheckClaimsBundle { + lazyClaimsOutputs := m.claimsMap[bundleKey(wireBundle)].claimsMapOutputsLazy + lazyClaimsInputs := m.claimsMap[bundleKey(wireBundle)].claimsMapInputsLazy + claimsMapOutputs := make(map[string]*eqTimesGateEvalSumcheckClaims, len(lazyClaimsOutputs)) + claimsMapInputs := make(map[string]*eqTimesGateEvalSumcheckClaims, len(lazyClaimsInputs)) + + for _, lazyClaim := range lazyClaimsOutputs { + output_claim := &eqTimesGateEvalSumcheckClaims{ + wire: lazyClaim.wire, + evaluationPoints: lazyClaim.evaluationPoints, + claimedEvaluations: lazyClaim.claimedEvaluations, + manager: lazyClaim.manager, + engine: engine, + } + + claimsMapOutputs[wireKey(lazyClaim.wire)] = output_claim + + if wireBundle.IsInput() { + output_claim.inputPreprocessors = []sumcheck.NativeMultilinear{m.assignment[wireBundle][getInputWireKey(lazyClaim.wire)]} + } else { + output_claim.inputPreprocessors = make([]sumcheck.NativeMultilinear, 1) //change this + output_claim.inputPreprocessors[0] = m.assignment[wireBundle][getInputWireKey(lazyClaim.wire)].Clone() + + } + } + + for _, lazyClaim := range lazyClaimsInputs { + + input_claim := &eqTimesGateEvalSumcheckClaims{ + wire: lazyClaim.wire, + evaluationPoints: make([][]big.Int, 0, 1), + claimedEvaluations: make([]big.Int, 1), + manager: lazyClaim.manager, + engine: engine, + } + + if !wireBundle.IsOutput() { + input_claim.claimedEvaluations = lazyClaim.claimedEvaluations + input_claim.evaluationPoints = lazyClaim.evaluationPoints + } + + claimsMapInputs[wireKey(lazyClaim.wire)] = input_claim + } + + res := &eqTimesGateEvalSumcheckClaimsBundle{ + wireBundle: wireBundle, + claimsMapOutputs: claimsMapOutputs, + claimsMapInputs: claimsMapInputs, + claimsManagerBundle: m, + } + + return res +} + +// sews claimsInput to claimsOutput and deletes the claimsInput +func (m *claimsManagerBundle) deleteClaim(wireBundle *WireBundle, previousWireBundle *WireBundle) { + if !wireBundle.IsInput() { + sewnClaimsMapOutputs := m.claimsMap[bundleKey(wireBundle)].claimsMapInputsLazy + 
m.claimsMap[bundleKey(previousWireBundle)].claimsMapOutputsLazy = sewnClaimsMapOutputs + } + delete(m.claimsMap, bundleKey(wireBundle)) +} + +func (e *claimsManagerBundle) addInput(wireBundle *WireBundle, wire *Wires, evaluationPoint []big.Int, evaluation big.Int) { + claim := e.claimsMap[bundleKey(wireBundle)].claimsMapInputsLazy[wireKey(wire)] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wires + evaluationPoints [][]big.Int // x in the paper + claimedEvaluations []big.Int // y in the paper + manager *claimsManager +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wires + evaluationPoints [][]big.Int // x in the paper + claimedEvaluations []big.Int // y in the paper + manager *claimsManager + engine *sumcheck.BigIntEngine + inputPreprocessors []sumcheck.NativeMultilinear // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq sumcheck.NativeMultilinear // ∑_i τ_i eq(x_i, -) +} + +func (e *eqTimesGateEvalSumcheckClaims) NbClaims() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckClaims) NbVars() int { + return len(e.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) CombineWithoutComputeGJ(combinationCoeff *big.Int) { + varsNum := c.NbVars() + eqLength := 1 << varsNum + claimsNum := c.NbClaims() + + // initialize the eq tables + c.eq = make(sumcheck.NativeMultilinear, eqLength) + for i := 0; i < eqLength; i++ { + c.eq[i] = new(big.Int) + } + c.eq[0] = c.engine.One() + sumcheck.Eq(c.engine, c.eq, sumcheck.ReferenceBigIntSlice(c.evaluationPoints[0])) + + newEq := make(sumcheck.NativeMultilinear, eqLength) + for i := 0; i < eqLength; i++ { + newEq[i] = new(big.Int) + } + aI := new(big.Int).Set(combinationCoeff) + + for k := 1; k < claimsNum; k++ { // TODO: parallelizable? 
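// (editor's note) on exit from this loop c.eq(h) = Σₖ aᵏ·eq(xₖ, h), with a⁰ = 1 for the
// first claim: the claimsNum separate evaluation claims are folded into one random linear
// combination, matching the "∑_i τ_i eq(x_i, -)" comment on the eq field above.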
+ // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(aI) + sumcheck.EqAcc(c.engine, c.eq, newEq, sumcheck.ReferenceBigIntSlice(c.evaluationPoints[k])) + if k+1 < claimsNum { + aI.Mul(aI, combinationCoeff) + } + } +} + +type eqTimesGateEvalSumcheckLazyClaimsBundle struct { + wireBundle *WireBundle + claimsMapOutputsLazy map[string]*eqTimesGateEvalSumcheckLazyClaims + claimsMapInputsLazy map[string]*eqTimesGateEvalSumcheckLazyClaims +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsBundle) addOutput(wire *Wires, evaluationPoint []big.Int, evaluation big.Int) { + claim := e.claimsMapOutputsLazy[wireKey(wire)] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +type eqTimesGateEvalSumcheckClaimsBundle struct { + wireBundle *WireBundle + claimsMapOutputs map[string]*eqTimesGateEvalSumcheckClaims + claimsMapInputs map[string]*eqTimesGateEvalSumcheckClaims + claimsManagerBundle *claimsManagerBundle +} + +// assuming each individual wire has a single claim +func (e *eqTimesGateEvalSumcheckClaimsBundle) NbClaims() int { + return len(e.claimsMapOutputs) +} +// to batch sumchecks in the bundle all claims should have the same number of variables +func (e *eqTimesGateEvalSumcheckClaimsBundle) NbVars() int { + return len(e.claimsMapOutputs[wireKey(e.wireBundle.Outputs[0])].evaluationPoints[0]) +} + +func (cB *eqTimesGateEvalSumcheckClaimsBundle) Combine(combinationCoeff *big.Int) sumcheck.NativePolynomial { + for _, claim := range cB.claimsMapOutputs { + claim.CombineWithoutComputeGJ(combinationCoeff) + } + + // from this point on the claims are rather simple : g_i = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + // we batch sumchecks for g_i using RLC + return cB.bundleComputeGJFull() +} + +//todo optimise loops +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (cB *eqTimesGateEvalSumcheckClaimsBundle) bundleComputeGJFull() sumcheck.NativePolynomial { + degGJ := 1 + cB.wireBundle.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + batch := len(cB.claimsMapOutputs) + s := make([][]sumcheck.NativeMultilinear, batch) + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + for i, c := range cB.claimsMapOutputs { + _, wireIndex := parseWireKey(i) + s[wireIndex] = make([]sumcheck.NativeMultilinear, len(c.inputPreprocessors)+1) + s[wireIndex][0] = c.eq + s[wireIndex][1] = c.inputPreprocessors[0].Clone() + } + + // Perf-TODO: Collate once at claim "combination" time and not again. 
then, even folding can be done in one operation every time "next" is called + //nbInner := len(s[0]) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0][0]) / 2 + + challengesRLC := make([]*big.Int, batch) + for i := range challengesRLC { + challengesRLC[i] = big.NewInt(int64(i+1)) + } + + // Contains the output of the algo + evals := make([]*big.Int, degGJ) + for i := range evals { + evals[i] = new(big.Int) + } + evaluationBuffer := make([][]*big.Int, batch) + tmpEvals := make([][]*big.Int, nbOuter) + eqChunk := make([][]*big.Int, nbOuter) + tmpEqs := make([]*big.Int, nbOuter) + dEqs := make([]*big.Int, nbOuter) + for i := range dEqs { + dEqs[i] = new(big.Int) + } + tmpXs := make([][]*big.Int, batch) + for i := range tmpXs { + tmpXs[i] = make([]*big.Int, 2*nbOuter) + for j := range tmpXs[i] { + tmpXs[i][j] = new(big.Int) + } + } + dXs := make([][]*big.Int, nbOuter) + for i := range dXs { + dXs[i] = make([]*big.Int, batch) + for j := range dXs[i] { + dXs[i][j] = new(big.Int) + } + } + + engine := cB.claimsMapOutputs[wireKey(cB.wireBundle.Outputs[0])].engine + evalsVec := make([]*big.Int, nbOuter) + for i := range evalsVec { + evalsVec[i] = big.NewInt(0) + } + evalPtr := big.NewInt(0) + v := big.NewInt(0) + + // for g(0) -- for debuggin + // for i, _ := range cB.claimsMapOutputs { + // _, wireIndex := parseWireKey(i) + // // Redirect the evaluation table directly to inst + // // So we don't copy into tmpXs + // evaluationBuffer[wireIndex] = s[wireIndex][1][0:nbOuter] + // for i, q := range evaluationBuffer[wireIndex] { + // fmt.Println("evaluationBuffer0[", wireIndex, "][", i, "]", q.String()) + // } + // } + + // // evaluate the gate with inputs pointed to by the evaluation buffer + // for i := 0; i < nbOuter; i++ { + // inputs := make([]*big.Int, batch) + // tmpEvals[i] = make([]*big.Int, batch) + // for j := 0; j < batch; j++ { + // inputs[j] = evaluationBuffer[j][i] + // } + // tmpEvals[i] = cB.wireBundle.Gate.Evaluate(engine, inputs...) + // //fmt.Println("tmpEvals[", i, "]", tmpEvals[i]) + // } + + // for x := 0; x < nbOuter; x++ { + // eqChunk[x] = make([]*big.Int, batch) + // for i, _ := range cB.claimsMapOutputs { + // _, wireIndex := parseWireKey(i) + // eqChunk[x][wireIndex] = s[wireIndex][0][0:nbOuter][x] + // v = engine.Mul(eqChunk[x][wireIndex], tmpEvals[x][wireIndex]) + // v = engine.Mul(v, challengesRLC[wireIndex]) + // evalPtr = engine.Add(evalPtr, v) + // } + // } + // //fmt.Println("evalPtr", evalPtr) + + // // Then update the evalsValue + // evals[0] = evalPtr// 0 because t = 0 + + // Second special case : evaluation at t = 1 + evalPtr = big.NewInt(0) + for i, _ := range cB.claimsMapOutputs { + _, wireIndex := parseWireKey(i) + // Redirect the evaluation table directly to inst + // So we don't copy into tmpXs + evaluationBuffer[wireIndex] = s[wireIndex][1][nbOuter:nbOuter*2] + } + + for x := 0; x < nbOuter; x++ { + inputs := make([]*big.Int, batch) + tmpEvals[x] = make([]*big.Int, batch) + for j := 0; j < batch; j++ { + inputs[j] = evaluationBuffer[j][x] + } + tmpEvals[x] = cB.wireBundle.Gate.Evaluate(engine, inputs...) 
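+		// Added note (reviewer clarification): for this instance x, the gate outputs of the bundled
+		// claims are weighted by challengesRLC[wireIndex] (the batching coefficients 1..batch),
+		// summed, multiplied by the shared eq entry, and accumulated into evalPtr. After the loop,
+		// evalPtr is g_j(1), stored in evals[0]; g_j(0) is recovered by the verifier from
+		// g_j(0) + g_j(1) = g_{j-1}(r_{j-1}) as described above.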
+ + eqChunk[x] = make([]*big.Int, batch) + for i, _ := range cB.claimsMapOutputs { + _, wireIndex := parseWireKey(i) + v = engine.Mul(tmpEvals[x][wireIndex], challengesRLC[wireIndex]) + evalsVec[x] = engine.Add(evalsVec[x], v) + } + eqChunk[x][0] = s[0][0][nbOuter:nbOuter*2][x] + evalsVec[x] = engine.Mul(evalsVec[x], eqChunk[x][0]) + evalPtr = engine.Add(evalPtr, evalsVec[x]) + } + + // Then update the evalsValue + evals[0] = evalPtr // 1 because t = 1 + + // Then regular case t >= 2 + + // Initialize the eq and dEq table, at the value for t = 1 + // (We get the next values for t by adding dEqs) + // Initializes the dXs as P(t=1, x) - P(t=0, x) + // As for eq, we initialize each input table `X` with the value for t = 1 + // (We get the next values for t by adding dXs) + for x := 0; x < nbOuter; x++ { + tmpEqs[x] = s[0][0][nbOuter:nbOuter*2][x] + dEqs[x] = engine.Sub(s[0][0][nbOuter+x], s[0][0][x]) + for i, _ := range cB.claimsMapOutputs { + _, wireIndex := parseWireKey(i) + dXs[x][wireIndex] = engine.Sub(s[wireIndex][1][nbOuter+x], s[wireIndex][1][x]) + tmpXs[wireIndex][0:nbOuter][x] = s[wireIndex][1][nbOuter:nbOuter*2][x] + evaluationBuffer[wireIndex] = tmpXs[wireIndex][0:nbOuter] + } + } + + for t := 1; t < degGJ; t++ { + evalPtr = big.NewInt(0) + nInputsSubChunkLen := 1 * nbOuter // assuming single input per claim + // Update the value of tmpXs : as dXs and tmpXs have the same layout, + // no need to make a double loop on k : the index of the separate inputs + // We can do this, because P is multilinear so P(t+1,x) = P(t, x) + dX(x) + for i, _ := range cB.claimsMapOutputs { + _, wireIndex := parseWireKey(i) + for kx := 0; kx < nInputsSubChunkLen; kx++ { + tmpXs[wireIndex][kx] = engine.Add(tmpXs[wireIndex][kx], dXs[kx][wireIndex]) + } + } + + for x := 0; x < nbOuter; x++ { + evalsVec[x] = big.NewInt(0) + tmpEqs[x] = engine.Add(tmpEqs[x], dEqs[x]) + + inputs := make([]*big.Int, batch) + tmpEvals[x] = make([]*big.Int, batch) + for j := 0; j < batch; j++ { + inputs[j] = evaluationBuffer[j][x] + } + tmpEvals[x] = cB.wireBundle.Gate.Evaluate(engine, inputs...) + + for i, _ := range cB.claimsMapOutputs { + _, wireIndex := parseWireKey(i) + v = engine.Mul(tmpEvals[x][wireIndex], challengesRLC[wireIndex]) + evalsVec[x] = engine.Add(evalsVec[x], v) + } + evalsVec[x] = engine.Mul(evalsVec[x], tmpEqs[x]) + evalPtr = engine.Add(evalPtr, evalsVec[x]) + } + + evals[t] = evalPtr + + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. 
Won't help with MiMC though + // for _, eval := range evals { + // fmt.Println("evals", eval.String()) + // } + return evals +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaimsBundle) Next(element *big.Int) sumcheck.NativePolynomial { + eq := []*big.Int{} + for j, claim := range c.claimsMapOutputs { + _, wireIndex := parseWireKey(j) + for i := 0; i < len(claim.inputPreprocessors); i++ { + claim.inputPreprocessors[i] = sumcheck.Fold(claim.engine, claim.inputPreprocessors[i], element).Clone() + } + if wireIndex == 0 { + eq = sumcheck.Fold(claim.engine, claim.eq, element).Clone() + } + claim.eq = eq + } + + return c.bundleComputeGJFull() +} + +func (c *eqTimesGateEvalSumcheckClaimsBundle) ProverFinalEval(r []*big.Int) sumcheck.NativeEvaluationProof { + engine := c.claimsMapOutputs[wireKey(c.wireBundle.Outputs[0])].engine + //defer the proof, return list of claims + evaluations := make([]*big.Int, 0, len(c.wireBundle.Outputs)) + noMoreClaimsAllowed := make(map[*Wires]struct{}, len(c.claimsMapOutputs)) + for _, claim := range c.claimsMapOutputs { + noMoreClaimsAllowed[claim.wire] = struct{}{} + } + // each claim corresponds to a wireBundle, P_u is folded and added to corresponding claimBundle + for _, in := range c.wireBundle.Inputs { + puI := c.claimsMapOutputs[getOuputWireKey(in)].inputPreprocessors[0] //todo change this - maybe not required + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI = sumcheck.Fold(engine, puI, r[len(r)-1]) + puI0 := new(big.Int).Set(puI[0]) + c.claimsManagerBundle.addInput(c.wireBundle, in, sumcheck.DereferenceBigIntSlice(r), *puI0) + evaluations = append(evaluations, puI0) + } + } + + return evaluations +} + +func (e *eqTimesGateEvalSumcheckClaimsBundle) Degree(int) int { + return 1 + e.wireBundle.Gate.Degree() +} + +func setup(current *big.Int, target *big.Int, c CircuitBundle, assignment WireAssignmentBundle, options ...OptionGkr) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." 
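+		// Added example (illustrative only): with an empty prefix and logNbInstances = 2, the list
+		// starts with "fC.0", "fC.1"; a wire i carrying more than one claim then contributes
+		// "w<i>.comb" followed by its partial-sum challenges "w<i>.pSP.0", "w<i>.pSP.1".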
+ for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func ChallengeNamesBundle(sorted []*WireBundle, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func ChallengeNamesEmulated[FR emulated.FieldParams](sorted []*WireBundleEmulated[FR], logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func (v *GKRVerifier[FR]) getChallengesFr(transcript *fiatshamir.Transcript, names []string) (challenges []emulated.Element[FR], err error) { + challenges = make([]emulated.Element[FR], len(names)) + var challenge emulated.Element[FR] + var fr FR + for i, name := range names { + nativeChallenge, err := transcript.ComputeChallenge(name) + if err != nil { + return nil, fmt.Errorf("compute challenge %s: %w", names, err) + } + // TODO: when implementing better way (construct from limbs instead of bits) then change + chBts := bits.ToBinary(v.api, nativeChallenge, bits.WithNbDigits(fr.Modulus().BitLen())) + challenge = *v.f.FromBits(chBts...) 
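+		// Added note (reviewer clarification): the native challenge is decomposed into
+		// fr.Modulus().BitLen() bits and recomposed as an emulated Element[FR], so that a challenge
+		// sampled in the circuit's native field can be used over the non-native field FR
+		// (see the TODO above about a cheaper limb-based construction).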
+ challenges[i] = challenge + + } + return challenges, nil +} + +// Prove consistency of the claimed assignment +func Prove(current *big.Int, target *big.Int, c CircuitBundle, assignment WireAssignmentBundle, transcriptSettings fiatshamir.SettingsBigInt, options ...OptionGkr) (NativeProofs, error) { + be := sumcheck.NewBigIntEngine(target) + o, err := setup(current, target, c, assignment, options...) + if err != nil { + return nil, err + } + + claimBundle := newClaimsManagerBundle(c, assignment) + proof := make(NativeProofs, len(c)) + challengeNames := getFirstChallengeNames(o.nbVars, o.transcriptPrefix) + // firstChallenge called rho in the paper + firstChallenge := make([]*big.Int, len(challengeNames)) + for i := 0; i < len(challengeNames); i++ { + firstChallenge[i], _, err = sumcheck.DeriveChallengeProver(o.transcript, challengeNames[i:], nil) + if err != nil { + return nil, err + } + } + + var baseChallenge []*big.Int + for i := len(c) - 1; i >= 0; i-- { + wireBundle := o.sorted[i] + var previousWireBundle *WireBundle + if !wireBundle.IsInput() { + previousWireBundle = o.sorted[i-1] + } + claimBundleMap := claimBundle.claimsMap[bundleKey(wireBundle)] + + if wireBundle.IsOutput() { + for _ , outputs := range wireBundle.Outputs { + evaluation := sumcheck.Eval(be, assignment[wireBundle][wireKey(outputs)], firstChallenge) + claimBundleMap.addOutput(outputs, sumcheck.DereferenceBigIntSlice(firstChallenge), *evaluation) + } + } + + claimBundleSumcheck := claimBundle.getClaim(be, wireBundle) + var finalEvalProofLen int + + if wireBundle.noProof() { // input wires with one claim only + proof[i] = sumcheck.NativeProof{ + RoundPolyEvaluations: []sumcheck.NativePolynomial{}, + FinalEvalProof: sumcheck.NativeDeferredEvalProof([]big.Int{}), + } + } else { + proof[i], err = sumcheck.Prove( + current, target, claimBundleSumcheck, + ) + if err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof + switch finalEvalProof := finalEvalProof.(type) { + case nil: + finalEvalProofCasted := sumcheck.NativeDeferredEvalProof([]big.Int{}) + proof[i].FinalEvalProof = finalEvalProofCasted + case []*big.Int: + finalEvalProofLen = len(finalEvalProof) + finalEvalProofCasted := sumcheck.NativeDeferredEvalProof(sumcheck.DereferenceBigIntSlice(finalEvalProof)) + proof[i].FinalEvalProof = finalEvalProofCasted + default: + return nil, fmt.Errorf("finalEvalProof is not of type DeferredEvalProof") + } + + baseChallenge = make([]*big.Int, finalEvalProofLen) + for i := 0; i < finalEvalProofLen; i++ { + baseChallenge[i] = finalEvalProof.([]*big.Int)[i] + } + } + // the verifier checks a single claim about input wires itself + claimBundle.deleteClaim(wireBundle, previousWireBundle) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete, +// Use valueOfProof[FR](proof) to convert nativeproof by prover into nonnativeproof used by in-circuit verifier +func (v *GKRVerifier[FR]) Verify(api frontend.API, c CircuitBundleEmulated[FR], assignment WireAssignmentBundleEmulated[FR], proof Proofs[FR], transcriptSettings fiatshamir.SettingsEmulated[FR], options ...OptionEmulated[FR]) error { + o, err := v.setup(api, c, assignment, transcriptSettings, options...) 
+ if err != nil { + return err + } + sumcheck_verifier, err := sumcheck.NewVerifier[FR](api) + if err != nil { + return err + } + + claimBundle := newClaimsManagerBundleEmulated[FR](c, assignment, *v) + var firstChallenge []emulated.Element[FR] + firstChallenge, err = v.getChallengesFr(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + var baseChallenge []emulated.Element[FR] + for i := len(c) - 1; i >= 0; i-- { + wireBundle := o.sorted[i] + var previousWireBundle *WireBundleEmulated[FR] + if !wireBundle.IsInput() { + previousWireBundle = o.sorted[i-1] + } + claimBundleMap := claimBundle.claimsMap[bundleKeyEmulated(wireBundle)] + if wireBundle.IsOutput() { + for _, outputs := range wireBundle.Outputs { + var evaluation emulated.Element[FR] + evaluationPtr, err := v.p.EvalMultilinear(polynomial.FromSlice(firstChallenge), assignment[wireBundle][wireKey(outputs)]) + if err != nil { + return err + } + evaluation = *evaluationPtr + claimBundleMap.addOutput(outputs, firstChallenge, evaluation) + } + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof + claim := claimBundle.getLazyClaim(wireBundle) + + if wireBundle.noProof() { // input wires with one claim only + // make sure the proof is empty + // make sure finalevalproof is of type deferred for gkr + var proofLen int + switch proof := finalEvalProof.(type) { + case nil: //todo check this + proofLen = 0 + case []emulated.Element[FR]: + proofLen = len(sumcheck.DeferredEvalProof[FR](proof)) + default: + return fmt.Errorf("finalEvalProof is not of type DeferredEvalProof") + } + + if (finalEvalProof != nil && proofLen != 0) || len(proofW.RoundPolyEvaluations) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wireBundle.NbClaims() == len(wireBundle.Inputs) { // input wire // todo fix this + // simply evaluate and see if it matches + for _, output := range wireBundle.Outputs { + var evaluation emulated.Element[FR] + evaluationPtr, err := v.p.EvalMultilinear(polynomial.FromSlice(claim.claimsMapOutputsLazy[wireKey(output)].evaluationPoints[0]), assignment[wireBundle][getInputWireKey(output)]) + if err != nil { + return err + } + evaluation = *evaluationPtr + v.f.AssertIsEqual(&claim.claimsMapOutputsLazy[wireKey(output)].claimedEvaluations[0], &evaluation) + } + //todo input actual scalrbits from input testing only + scalarbits := v.f.ToBits(v.f.Modulus()) + nBInstances := 1 << o.nbVars + scalarbitsEmulatedAssignement := make([]emulated.Element[FR], nBInstances) + for i := range scalarbitsEmulatedAssignement { + scalarbitsEmulatedAssignement[i] = *v.f.NewElement(scalarbits[0]) + } + + challengesEval := make([]emulated.Element[FR], o.nbVars) + for i := 0; i < o.nbVars; i++ { + challengesEval[i] = *v.f.NewElement(uint64(i)) + } + for range scalarbits{ + _, err := v.p.EvalMultilinear(polynomial.FromSlice(challengesEval), polynomial.Multilinear[FR](scalarbitsEmulatedAssignement)) + if err != nil { + return err + } + } + + } + } else if err = sumcheck_verifier.Verify( + claim, proof[i], + ); err == nil { + switch proof := finalEvalProof.(type) { + case []emulated.Element[FR]: + baseChallenge = sumcheck.DeferredEvalProof[FR](proof) + default: + return fmt.Errorf("finalEvalProof is not of type DeferredEvalProof") + } + _ = baseChallenge + } else { + return err + } + claimBundle.deleteClaim(wireBundle, previousWireBundle) + } + return nil +} + +//todo reimplement for wireBundle - outputsList also sets the nbUniqueOutputs fields. 
It also sets the wire metadata. +func outputsList(c CircuitBundle, indexes map[*WireBundle]map[*Wires]int) [][][]int { + res := make([][][]int, len(c)) + for i := range c { + res[i] = make([][]int, len(c[i].Inputs)) + c[i].nbUniqueOutputs = 0 + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[&c[i]][in] + res[i][inI] = append(res[i][inI], len(c[i].Inputs)) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][][]int + status [][]int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*WireBundle]map[*Wires]int + leastReady int +} + +func (d *topSortData) markDone(i int, j int) { + d.status[i][j] = -1 + for _, outI := range d.outputs[i][j] { + d.status[j][outI]-- + if d.status[j][outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[i][d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c CircuitBundle) map[*WireBundle]map[*Wires]int { + res := make(map[*WireBundle]map[*Wires]int, len(c)) + for i := range c { + res[&c[i]] = make(map[*Wires]int, len(c[i].Inputs)) + for j := range c[i].Inputs { + res[&c[i]][c[i].Inputs[j]] = j + } + } + return res +} + +func statusList(c CircuitBundle) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, len(c[i].Inputs)) + for j := range c[i].Inputs { + if c[i].IsInput() { + res[i][j] = 0 + } else { + res[i][j] = len(c[i].Inputs) + } + } + + for range c[i].Outputs { + res[i] = append(res[i], len(c[i].Outputs)) + } + } + return res +} + +type IdentityGate[AE sumcheck.ArithEngine[E], E element] struct{ + Arity int +} + +func (gate IdentityGate[AE, E]) NbOutputs() int { + return gate.Arity +} + +func (IdentityGate[AE, E]) Evaluate(api AE, input ...E) []E { + return input +} + +func (IdentityGate[AE, E]) Degree() int { + return 1 +} + +func (gate IdentityGate[AE, E]) NbInputs() int { + return gate.Arity +} + +func (gate IdentityGate[AE, E]) GetName() string { + return "identity" +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsListEmulated[FR emulated.FieldParams](c CircuitEmulated[FR], indexes map[*WireEmulated[FR]]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortDataEmulated[FR emulated.FieldParams] struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. 
status = -1 means done + index map[*WireEmulated[FR]]int + leastReady int +} + +func (d *topSortDataEmulated[FR]) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMapEmulated[FR emulated.FieldParams](c CircuitEmulated[FR]) map[*WireEmulated[FR]]int { + res := make(map[*WireEmulated[FR]]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusListEmulated[FR emulated.FieldParams](c CircuitEmulated[FR]) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TODO: reimplement this for wirebundle, Have this use algo_utils.TopologicalSort underneath + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSortBundle(c CircuitBundle) []*WireBundle { + // var data topSortDataBundle + // data.index = indexMapBundle(c) + // data.outputs = outputsListBundle(c, data.index) + // data.status = statusListBundle(c) + // fmt.Println("data.status", data.status) + // sorted := make([]*WireBundle, len(c)) + + // data.leastReady = 0 + // for i := range c { + // fmt.Println("data.status[", i, "][", data.leastReady, "]", data.status[i][data.leastReady]) + // for data.leastReady < len(data.status[i]) - 1 && data.status[i][data.leastReady] != 0 { + // data.leastReady++ + // } + // fmt.Println("data.leastReady", data.leastReady) + // } + // // if data.leastReady < len(data.status[i]) - 1 && data.status[i][data.leastReady] != 0 { + // // break + // // } + + // for i := range c { + // fmt.Println("data.leastReady", data.leastReady) + // fmt.Println("i", i) + // sorted[i] = &c[i] // .wires[data.leastReady] + // data.markDone(i, data.leastReady) + // } + + //return sorted + + sorted := make([]*WireBundle, len(c)) + for i := range c { + sorted[i] = &c[i] + } + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignmentBundle) Complete(c CircuitBundle, target *big.Int) WireAssignmentBundle { + + engine := sumcheck.NewBigIntEngine(target) + sortedWires := topologicalSortBundle(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = utils.Max(maxNbIns, len(w.Inputs)) + for _, output := range w.Outputs { + if a[w][wireKey(output)] == nil { + a[w][wireKey(output)] = make(sumcheck.NativeMultilinear, nbInstances) + } + } + for _, input := range w.Inputs { + if a[w][wireKey(input)] == nil { + a[w][wireKey(input)] = make(sumcheck.NativeMultilinear, nbInstances) + } + } + } + + parallel.Execute(nbInstances, func(start, end int) { + ins := make([]*big.Int, maxNbIns) + sewWireOutputs := make([][]*big.Int, nbInstances) // assuming inputs outputs same + for i := start; i < end; i++ { + sewWireOutputs[i] = make([]*big.Int, len(sortedWires[0].Inputs)) + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + a[w][wireKey(in)][i] = 
sewWireOutputs[i][inI] + } + } + for inI, in := range w.Inputs { + ins[inI] = a[w][wireKey(in)][i] + } + if !w.IsOutput() { + res := w.Gate.Evaluate(engine, ins[:len(w.Inputs)]...) + for outputI, output := range w.Outputs { + a[w][wireKey(output)][i] = res[outputI] + sewWireOutputs[i][outputI] = a[w][wireKey(output)][i] + } + } + } + } + }) + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + if aW != nil { + return len(aW) + } + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + if aW != nil { + return aW.NumVars() + } + } + panic("empty assignment") +} + +func (a WireAssignmentBundle) NumInstances() int { + for _, aWBundle := range a { + for _, aW := range aWBundle { + if aW != nil { + return len(aW) + } + } + } + panic("empty assignment") +} + +func (a WireAssignmentBundle) NumVars() int { + for _, aW := range a { + if aW != nil { + return aW.NumVars() + } + } + panic("empty assignment") +} + +//todo complete this for wirebundle +func topologicalSortBundleEmulated[FR emulated.FieldParams](c CircuitBundleEmulated[FR]) []*WireBundleEmulated[FR] { + // var data topSortDataEmulated[FR] + // data.index = indexMapEmulated(c) + // data.outputs = outputsListEmulated(c, data.index) + // data.status = statusListEmulated(c) + // sorted := make([]*WireBundleEmulated[FR], len(c)) + + // for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + // } + + // for i := range c { + // sorted[i] = &c[data.leastReady] + // data.markDone(data.leastReady) + // } + + sorted := make([]*WireBundleEmulated[FR], len(c)) + for i := range c { + sorted[i] = &c[i] + } + return sorted +} + +func (a WireAssignmentEmulated[FR]) NumInstances() int { + for _, aW := range a { + if aW != nil { + return len(aW) + } + } + panic("empty assignment") +} + +func (a WireAssignmentEmulated[FR]) NumVars() int { + for _, aW := range a { + if aW != nil { + return aW.NumVars() + } + } + panic("empty assignment") +} + +func (a WireAssignmentBundleEmulated[FR]) NumInstances() int { + for _, aWBundle := range a { + for _, aW := range aWBundle { + if aW != nil { + return len(aW) + } + } + } + panic("empty assignment") +} + +func (a WireAssignmentBundleEmulated[FR]) NumVars() int { + for _, aW := range a { + if aW != nil { + return aW.NumVars() + } + } + panic("empty assignment") +} + +func (p Proofs[FR]) Serialize() []emulated.Element[FR] { + size := 0 + for i := range p { + for j := range p[i].RoundPolyEvaluations { + size += len(p[i].RoundPolyEvaluations[j]) + } + switch v := p[i].FinalEvalProof.(type) { + case sumcheck.DeferredEvalProof[FR]: + size += len(v) + } + } + + res := make([]emulated.Element[FR], 0, size) + for i := range p { + for j := range p[i].RoundPolyEvaluations { + res = append(res, p[i].RoundPolyEvaluations[j]...) + } + switch v := p[i].FinalEvalProof.(type) { + case sumcheck.DeferredEvalProof[FR]: + res = append(res, v...) 
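+			// Added note (reviewer clarification): the serialized layout is, per wire bundle, all
+			// round-polynomial evaluations flattened in order, followed by the deferred final-eval
+			// proof if present; DeserializeProofBundle below reads the elements back in this order.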
+ } + } + if len(res) != size { + panic("bug") // TODO: Remove + } + return res +} + +func computeLogNbInstancesBundle[FR emulated.FieldParams](wires []*WireBundleEmulated[FR], serializedProofLen int) int { + partialEvalElemsPerVar := 0 + for _, w := range wires { + if !w.noProof() { + partialEvalElemsPerVar += w.Gate.Degree() + 1 + serializedProofLen -= 1 //w.nbUniqueOutputs + } + } + return serializedProofLen / partialEvalElemsPerVar +} + +type variablesReader[FR emulated.FieldParams] []emulated.Element[FR] + +func (r *variablesReader[FR]) nextN(n int) []emulated.Element[FR] { + res := (*r)[:n] + *r = (*r)[n:] + return res +} + +func (r *variablesReader[FR]) hasNextN(n int) bool { + return len(*r) >= n +} + +func DeserializeProofBundle[FR emulated.FieldParams](sorted []*WireBundleEmulated[FR], serializedProof []emulated.Element[FR]) (Proofs[FR], error) { + proof := make(Proofs[FR], len(sorted)) + logNbInstances := computeLogNbInstancesBundle(sorted, len(serializedProof)) + + reader := variablesReader[FR](serializedProof) + for i, wI := range sorted { + if !wI.noProof() { + proof[i].RoundPolyEvaluations = make([]polynomial.Univariate[FR], logNbInstances) + for j := range proof[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[j] = reader.nextN(wI.Gate.Degree() + 1) + } + proof[i].FinalEvalProof = reader.nextN(wI.nbUniqueInputs()) + } + + } + if reader.hasNextN(1) { + return nil, fmt.Errorf("proof too long: expected %d encountered %d", len(serializedProof)-len(reader), len(serializedProof)) + } + return proof, nil +} + +type element any + +type MulGate[AE sumcheck.ArithEngine[E], E element] struct{} + +func (g MulGate[AE, E]) NbOutputs() int { + return 1 +} + +func (g MulGate[AE, E]) Evaluate(api AE, x ...E) []E { + if len(x) != 2 { + panic("mul has fan-in 2") + } + return []E{api.Mul(x[0], x[1])} +} + +// TODO: Degree must take nbInputs as an argument and return degree = nbInputs +func (g MulGate[AE, E]) Degree() int { + return 2 +} + +func (g MulGate[AE, E]) NbInputs() int { + return 2 +} + +func (g MulGate[AE, E]) GetName() string { + return "mul" +} + +type AddGate[AE sumcheck.ArithEngine[E], E element] struct{} + +func (a AddGate[AE, E]) Evaluate(api AE, v ...E) []E { + switch len(v) { + case 0: + return []E{api.Const(big.NewInt(0))} + case 1: + return []E{v[0]} + } + rest := v[2:] + res := api.Add(v[0], v[1]) + for _, e := range rest { + res = api.Add(res, e) + } + return []E{res} +} + +func (a AddGate[AE, E]) Degree() int { + return 1 +} + +func (a AddGate[AE, E]) NbInputs() int { + return 2 +} + +func (a AddGate[AE, E]) NbOutputs() int { + return 1 +} + +func (a AddGate[AE, E]) GetName() string { + return "add" +} \ No newline at end of file diff --git a/std/recursion/gkr/gkr_nonnative_test.go b/std/recursion/gkr/gkr_nonnative_test.go new file mode 100644 index 0000000000..b861790f19 --- /dev/null +++ b/std/recursion/gkr/gkr_nonnative_test.go @@ -0,0 +1,324 @@ +package gkrnonative + +import ( + "encoding/json" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + fpbn254 "github.com/consensys/gnark-crypto/ecc/bn254/fp" + + "github.com/consensys/gnark/frontend" + // "github.com/consensys/gnark/frontend/cs/scs" + // "github.com/consensys/gnark/profile" + fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/std/recursion" + 
"github.com/consensys/gnark/std/recursion/gkr/utils" + "github.com/consensys/gnark/std/recursion/sumcheck" + "github.com/consensys/gnark/test" +) + +type GkrVerifierCircuitEmulated[FR emulated.FieldParams] struct { + Input [][]emulated.Element[FR] + Output [][]emulated.Element[FR] `gnark:",public"` + SerializedProof []emulated.Element[FR] + ToFail bool + TestCaseName string +} + +func makeInOutAssignmentBundle[FR emulated.FieldParams](c CircuitBundleEmulated[FR], inputValues [][]emulated.Element[FR], outputValues [][]emulated.Element[FR]) WireAssignmentBundleEmulated[FR] { + sorted := topologicalSortBundleEmulated(c) + res := make(WireAssignmentBundleEmulated[FR], len(sorted)) + for _, w := range sorted { + if w.IsInput() { + res[w] = make(WireAssignmentEmulated[FR], len(w.Inputs)) + for _, wire := range w.Inputs { + res[w][wireKey(wire)] = inputValues[wire.WireIndex] + } + } else if w.IsOutput() { + res[w] = make(WireAssignmentEmulated[FR], len(w.Outputs)) + for _, wire := range w.Outputs { + res[w][wireKey(wire)] = outputValues[wire.WireIndex] + } + } + } + return res +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof [][]uint64 `json:"finalEvalProof"` + RoundPolyEvaluations [][][]uint64 `json:"roundPolyEvaluations"` +} + +func unmarshalProof(printable []PrintableSumcheckProof) (proof NativeProofs) { + proof = make(NativeProofs, len(printable)) + + for i := range printable { + if printable[i].FinalEvalProof != nil { + finalEvalProof := make(sumcheck.NativeDeferredEvalProof, len(printable[i].FinalEvalProof)) + for k, val := range printable[i].FinalEvalProof { + var temp big.Int + temp.SetUint64(val[0]) + for _, v := range val[1:] { + temp.Lsh(&temp, 64).Add(&temp, new(big.Int).SetUint64(v)) + } + finalEvalProof[k] = temp + } + proof[i].FinalEvalProof = finalEvalProof + } else { + proof[i].FinalEvalProof = nil + } + + proof[i].RoundPolyEvaluations = make([]sumcheck.NativePolynomial, len(printable[i].RoundPolyEvaluations)) + for k, evals := range printable[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[k] = make(sumcheck.NativePolynomial, len(evals)) + for j, eval := range evals { + var temp big.Int + temp.SetUint64(eval[0]) + for _, v := range eval[1:] { + temp.Lsh(&temp, 64).Add(&temp, new(big.Int).SetUint64(v)) + } + proof[i].RoundPolyEvaluations[k][j] = &temp + } + } + } + return proof +} + +func (p *PrintableSumcheckProof) UnmarshalJSON(data []byte) error { + var temp struct { + FinalEvalProof [][]uint64 `json:"finalEvalProof"` + RoundPolyEvaluations [][][]uint64 `json:"roundPolyEvaluations"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + p.FinalEvalProof = temp.FinalEvalProof + + p.RoundPolyEvaluations = make([][][]uint64, len(temp.RoundPolyEvaluations)) + for i, arr2D := range temp.RoundPolyEvaluations { + p.RoundPolyEvaluations[i] = make([][]uint64, len(arr2D)) + for j, arr1D := range arr2D { + p.RoundPolyEvaluations[i][j] = make([]uint64, len(arr1D)) + for k, v := range arr1D { + p.RoundPolyEvaluations[i][j][k] = uint64(v) + } + } + } + return nil +} + +type ProjAddGkrVerifierCircuit[FR emulated.FieldParams] struct { + Circuit CircuitBundleEmulated[FR] + Input [][]emulated.Element[FR] + Output [][]emulated.Element[FR] `gnark:",public"` + SerializedProof []emulated.Element[FR] +} + +func (c *ProjAddGkrVerifierCircuit[FR]) Define(api frontend.API) error { + var fr FR + var proof Proofs[FR] + var err error + + v, err := NewGKRVerifier[FR](api) + if err != nil { + return 
fmt.Errorf("new verifier: %w", err) + } + + sorted := topologicalSortBundleEmulated(c.Circuit) + + if proof, err = DeserializeProofBundle(sorted, c.SerializedProof); err != nil { + return err + } + assignment := makeInOutAssignmentBundle(c.Circuit, c.Input, c.Output) + // initiating hash in bitmode, since bn254 basefield is bigger than scalarfield + hsh, err := recursion.NewHash(api, fr.Modulus(), true) + if err != nil { + return err + } + + return v.Verify(api, c.Circuit, assignment, proof, fiatshamir.WithHashFr[FR](hsh)) +} + +func ElementToBigInt(element fpbn254.Element) *big.Int { + var temp big.Int + return element.BigInt(&temp) +} + +func testMultipleDblAddSelectGKRInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, target *big.Int, inputs [][]*big.Int, outputs [][]*big.Int, depth int) { + selector := []*big.Int{big.NewInt(1)} + c := make(CircuitBundle, depth + 1) + c[0] = InitFirstWireBundle(len(inputs), len(c)) + for i := 1; i < depth + 1; i++ { + c[i] = NewWireBundle( + sumcheck.DblAddSelectGateFullOutput[*sumcheck.BigIntEngine, *big.Int]{Selector: selector[0]}, + c[i-1].Outputs, + i, + len(c), + ) + } + + selectorEmulated := make([]emulated.Element[FR], len(selector)) + for i, f := range selector { + selectorEmulated[i] = emulated.ValueOf[FR](f) + } + + cEmulated := make(CircuitBundleEmulated[FR], len(c)) + cEmulated[0] = InitFirstWireBundleEmulated[FR](len(inputs), len(c)) + for i := 1; i < depth + 1; i++ { + cEmulated[i] = NewWireBundleEmulated( + sumcheck.DblAddSelectGateFullOutput[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{Selector: &selectorEmulated[0]}, + c[i-1].Outputs, + i, + len(c), + ) + } + + assert := test.NewAssert(t) + hash, err := recursion.NewShort(current, target) + if err != nil { + t.Errorf("new short hash: %v", err) + return + } + t.Log("Evaluating all circuit wires") + + fullAssignment := make(WireAssignmentBundle) + inOutAssignment := make(WireAssignmentBundle) + + sorted := topologicalSortBundle(c) + + inI, outI := 0, 0 + for _, w := range sorted { + assignmentRaw := make([][]*big.Int, len(w.Inputs)) + fullAssignment[w] = make(WireAssignment, len(w.Inputs)) + inOutAssignment[w] = make(WireAssignment, len(w.Inputs)) + + if w.IsInput() { + if inI == len(inputs) { + t.Errorf("fewer input in vector than in circuit") + return + } + copy(assignmentRaw, inputs) + for i, assignment := range assignmentRaw { + wireAssignment, err := utils.SliceToBigIntSlice(assignment) + assert.NoError(err) + fullAssignment[w][wireKey(w.Inputs[i])] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + inOutAssignment[w][wireKey(w.Inputs[i])] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + } + } else if w.IsOutput() { + if outI == len(outputs) { + t.Errorf("fewer output in vector than in circuit") + return + } + copy(assignmentRaw, outputs) + for i, assignment := range assignmentRaw { + wireAssignment, err := utils.SliceToBigIntSlice(assignment) + assert.NoError(err) + fullAssignment[w][wireKey(w.Outputs[i])] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + inOutAssignment[w][wireKey(w.Outputs[i])] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + } + } + } + + fullAssignment.Complete(c, target) + + t.Log("Circuit evaluation complete") + proof, err := Prove(current, target, c, fullAssignment, fiatshamir.WithHashBigInt(hash)) + assert.NoError(err) + t.Log("Proof complete") + + proofEmulated := make(Proofs[FR], len(proof)) + for i, proof := range proof { + 
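+		// Added note (reviewer clarification): the prover ran over native big.Int arithmetic, so each
+		// native proof is converted to its emulated counterpart here for the in-circuit verifier.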
proofEmulated[i] = sumcheck.ValueOfProof[FR](proof) + } + + validCircuit := &ProjAddGkrVerifierCircuit[FR]{ + Circuit: cEmulated, + Input: make([][]emulated.Element[FR], len(inputs)), + Output: make([][]emulated.Element[FR], len(outputs)), + SerializedProof: proofEmulated.Serialize(), + } + + validAssignment := &ProjAddGkrVerifierCircuit[FR]{ + Circuit: cEmulated, + Input: make([][]emulated.Element[FR], len(inputs)), + Output: make([][]emulated.Element[FR], len(outputs)), + SerializedProof: proofEmulated.Serialize(), + } + + for i := range inputs { + validCircuit.Input[i] = make([]emulated.Element[FR], len(inputs[i])) + validAssignment.Input[i] = make([]emulated.Element[FR], len(inputs[i])) + for j := range inputs[i] { + validAssignment.Input[i][j] = emulated.ValueOf[FR](inputs[i][j]) + } + } + + for i := range outputs { + validCircuit.Output[i] = make([]emulated.Element[FR], len(outputs[i])) + validAssignment.Output[i] = make([]emulated.Element[FR], len(outputs[i])) + for j := range outputs[i] { + validAssignment.Output[i][j] = emulated.ValueOf[FR](outputs[i][j]) + } + } + + err = test.IsSolved(validCircuit, validAssignment, current) + assert.NoError(err) + + // p := profile.Start() + // _, _ = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, validCircuit) + // p.Stop() + + // fmt.Println(p.NbConstraints()) +} + +func TestMultipleDblAddSelectGKR(t *testing.T) { + var P1 bn254.G1Affine + var one fpbn254.Element + one.SetOne() + var zero fpbn254.Element + zero.SetZero() + var random fpbn254.Element + + depth := 64 + arity := 6 + nBInstances := 2048 + var fp emparams.BN254Fp + be := sumcheck.NewBigIntEngine(fp.Modulus()) + gate := sumcheck.DblAddSelectGateFullOutput[*sumcheck.BigIntEngine, *big.Int]{Selector: big.NewInt(1)} + + res := make([][]*big.Int, nBInstances) + gateInputs := make([][]*big.Int, nBInstances) + for i := 0; i < nBInstances; i++ { + random.SetRandom() + element := P1.ScalarMultiplicationBase(random.BigInt(new(big.Int))) + gateInputs[i] = []*big.Int{ElementToBigInt(element.X), ElementToBigInt(element.Y), ElementToBigInt(one), ElementToBigInt(zero), ElementToBigInt(one), ElementToBigInt(zero)} + inputLayer := gateInputs[i] + for j := 0; j < depth; j++ { + res[i] = gate.Evaluate(be, inputLayer...) 
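+			// Added note (reviewer clarification): the output of depth j feeds depth j+1, so after the
+			// inner loop res[i] holds the output of the deepest layer for instance i.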
+ inputLayer = res[i] + } + } + + inputs := make([][]*big.Int, arity) + outputs := make([][]*big.Int, arity) + for i := 0; i < arity; i++ { + inputs[i] = make([]*big.Int, nBInstances) + outputs[i] = make([]*big.Int, nBInstances) + for j := 0; j < nBInstances; j++ { + inputs[i][j] = gateInputs[j][i] + outputs[i][j] = res[j][i] + } + } + + testMultipleDblAddSelectGKRInstance[emparams.BN254Fp](t, ecc.BN254.ScalarField(), fp.Modulus(), inputs, outputs, depth) + +} diff --git a/std/recursion/gkr/test_vectors/resources/mimc_five_levels.json b/std/recursion/gkr/test_vectors/resources/mimc_five_levels.json new file mode 100644 index 0000000000..3dd74f42b5 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_identity_gate.json b/std/recursion/gkr/test_vectors/resources/single_identity_gate.json new file mode 100644 index 0000000000..a44066c7b4 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_input_two_identity_gates.json b/std/recursion/gkr/test_vectors/resources/single_input_two_identity_gates.json new file mode 100644 index 0000000000..6181784fa8 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_input_two_outs.json b/std/recursion/gkr/test_vectors/resources/single_input_two_outs.json new file mode 100644 index 0000000000..c577c1cace --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_mimc_gate.json b/std/recursion/gkr/test_vectors/resources/single_mimc_gate.json new file mode 100644 index 0000000000..c89e7d52ae --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_mimc_gate_two_instances.json b/std/recursion/gkr/test_vectors/resources/single_mimc_gate_two_instances.json new file mode 100644 index 0000000000..a75ccccfef --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_mimc_gate_two_instances.json @@ -0,0 +1,89 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + 
"input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + }, + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 3445061460418080392, + 1582772968760438233, + 15430626802533927355, + 10677110232782539588 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_mul_gate.json b/std/recursion/gkr/test_vectors/resources/single_mul_gate.json new file mode 100644 index 0000000000..0f65a07edf --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json b/std/recursion/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json new file mode 100644 index 0000000000..26681c2f89 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json new file mode 100644 index 0000000000..cdbdb3b471 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate_two_instances.json b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 0000000000..05a2a421e4 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,65 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692424 + ], + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692426 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 
10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 811536341909029034, + 3521849244771172087, + 18082796152498666864, + 12957593377148931088 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_identity_gate_two_instances.json b/std/recursion/gkr/test_vectors/single_identity_gate_two_instances.json new file mode 100644 index 0000000000..420584f6fa --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_identity_gate_two_instances.json @@ -0,0 +1,51 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 3445061460418080392, + 1582772968760438233, + 15430626802533927355, + 10677110232782539588 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 405768170954514517, + 1760924622385586043, + 18264770113104109240, + 6478796688574465544 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/std/recursion/gkr/test_vectors/single_input_two_identity_gates_two_instances.json new file mode 100644 index 0000000000..1cf156c016 --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,96 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 1309801114600745759, + 3758563846819454073, + 10262009230221415359, + 16005847429194593330 + ], + [ + 1641562985788773784, + 10408495378109679862, + 1607731544356410364, + 2789460758528902269 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692426 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 811536341909029034, + 3521849244771172087, + 18082796152498666864, + 12957593377148931088 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692426 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 811536341909029034, + 3521849244771172087, + 18082796152498666864, + 12957593377148931088 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_input_two_outs_two_instances.json b/std/recursion/gkr/test_vectors/single_input_two_outs_two_instances.json new file mode 100644 index 0000000000..9f9bb7b4e4 --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -0,0 +1,102 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, 
+ 12779565606756425632, + 11770813148743782171 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 11552014468118848, + 6459316162880666778, + 5573794085540653091, + 12018926454163338051 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 270512113969676344, + 13471779130730091773, + 6027598717499555621, + 10468112483619494236 + ], + [ + 1825956769295315326, + 17147532837589913005, + 17627861250985060925, + 10707841024875543332 + ], + [ + 1923244012590556228, + 16346717705102969708, + 17401129836965471330, + 2115447990305160649 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 38772122160298693, + 2654177376158557373, + 666365361690475594, + 9065178994946100760 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 135256056984838172, + 6735889565365045886, + 12237171395604553618, + 14457428278664522926 + ], + [ + 608652256431771775, + 11864758970433154873, + 18173783132801388052, + 9718195032861698316 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_mul_gate_two_instances.json b/std/recursion/gkr/test_vectors/single_mul_gate_two_instances.json new file mode 100644 index 0000000000..128b57f3e1 --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_mul_gate_two_instances.json @@ -0,0 +1,71 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 3474249952014841962, + 12028090948092229382, + 15144988130097378949, + 4865233403516270609 + ], + [ + 12748314788128703, + 1253101003182465366, + 14218880088090055687, + 17914127541472937276 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 608652256431771775, + 11864758970433154873, + 18173783132801388052, + 9718195032861698319 + ], + [ + 1623072683818058068, + 7043698489542344175, + 17718848231287782113, + 7468442680588310560 + ], + [ + 1690700712310477154, + 10411643272224867119, + 5390689855380507306, + 14697156819920572021 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/std/recursion/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 0000000000..376025e4e9 --- /dev/null +++ b/std/recursion/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,77 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 3479106886554451955, + 541048341316977072, + 10578437981560588015, + 16173759560562137918 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 21302570947489481, + 677004288128798096, + 11618204988248521184, + 10639673014910314290 + ], + [ + 0, + 0, + 0, + 0 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 3465695695855481184, + 12604187663145896652, + 17745663229938913452, + 12139687930078893591 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 67628028492419086, + 3367944782682522943, + 6118585697802276809, + 7228714139332261463 + ], + [ + 0, + 0, + 0, + 0 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/utils/util.go b/std/recursion/gkr/utils/util.go new 
file mode 100644 index 0000000000..c7a9399d1d --- /dev/null +++ b/std/recursion/gkr/utils/util.go @@ -0,0 +1,195 @@ +package utils + +import ( + "fmt" + gohash "hash" + "math/big" + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/emulated" + "github.com/stretchr/testify/assert" +) + +func SliceToBigIntSlice[T any](slice []T) ([]big.Int, error) { + elementSlice := make([]big.Int, len(slice)) + for i, v := range slice { + switch v := any(v).(type) { + case *big.Int: + elementSlice[i] = *v + case float64: + elementSlice[i] = *big.NewInt(int64(v)) + default: + return nil, fmt.Errorf("unsupported type: %T", v) + } + } + return elementSlice, nil +} + +func ConvertToBigIntSlice(input []big.Int) []*big.Int { + output := make([]*big.Int, len(input)) + for i := range input { + output[i] = &input[i] + } + return output +} + +func SliceEqualsBigInt(a []big.Int, b []big.Int) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if a[i].Cmp(&b[i]) != 0 { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func ToVariableFr[FR emulated.FieldParams](v interface{}) emulated.Element[FR] { + switch vT := v.(type) { + case float64: + return *new(emulated.Field[FR]).NewElement(int(vT)) + default: + return *new(emulated.Field[FR]).NewElement(v) + } +} + +func ToVariableSliceFr[FR emulated.FieldParams, V any](slice []V) (variableSlice []emulated.Element[FR]) { + variableSlice = make([]emulated.Element[FR], len(slice)) + for i := range slice { + variableSlice[i] = ToVariableFr[FR](slice[i]) + } + return +} + +func ToVariableSliceSliceFr[FR emulated.FieldParams, V any](sliceSlice [][]V) (variableSliceSlice [][]emulated.Element[FR]) { + variableSliceSlice = make([][]emulated.Element[FR], len(sliceSlice)) + for i := range sliceSlice { + variableSliceSlice[i] = ToVariableSliceFr[FR](sliceSlice[i]) + } + return +} + +func AssertSliceEqual[T comparable](t *testing.T, expected, seen []T) { + assert.Equal(t, len(expected), len(seen)) + for i := range seen { + assert.True(t, expected[i] == seen[i], "@%d: %v != %v", i, expected[i], seen[i]) // assert.Equal is not strict enough when comparing pointers, i.e. 
it compares what they refer to + } +} + +func SliceEqual[T comparable](expected, seen []T) bool { + if len(expected) != len(seen) { + return false + } + for i := range seen { + if expected[i] != seen[i] { + return false + } + } + return true +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (gohash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + var temp big.Int + inputBlockSize := (len(p)-1)/len(temp.Bytes()) + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + var temp big.Int + inputBlockSize := (len(b)-1)/len(temp.Bytes()) + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res big.Int + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + var temp big.Int + return len(temp.Bytes()) +} + +func (m *MessageCounter) BlockSize() int { + var temp big.Int + return len(temp.Bytes()) +} + +func NewMessageCounter(startState, step int) gohash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() gohash.Hash { + return func() gohash.Hash { + return NewMessageCounter(startState, step) + } +} + +type MessageCounterEmulated struct { + startState int64 + state int64 + step int64 + + // cheap trick to avoid unconstrained input errors + api frontend.API + zero frontend.Variable +} + +func (m *MessageCounterEmulated) Write(data ...frontend.Variable) { + + for i := range data { + sq1, sq2 := m.api.Mul(data[i], data[i]), m.api.Mul(data[i], data[i]) + m.zero = m.api.Sub(sq1, sq2, m.zero) + } + + m.state += int64(len(data)) * m.step +} + +func (m *MessageCounterEmulated) Sum() frontend.Variable { + return m.api.Add(m.state, m.zero) +} + +func (m *MessageCounterEmulated) Reset() { + m.zero = 0 + m.state = m.startState +} + +func NewMessageCounterEmulated(api frontend.API, startState, step int) hash.FieldHasher { + transcript := &MessageCounterEmulated{startState: int64(startState), state: int64(startState), step: int64(step), api: api} + return transcript +} + +func NewMessageCounterGeneratorEmulated(startState, step int) func(frontend.API) hash.FieldHasher { + return func(api frontend.API) hash.FieldHasher { + return NewMessageCounterEmulated(api, startState, step) + } +} diff --git a/std/recursion/sumcheck/arithengine.go b/std/recursion/sumcheck/arithengine.go index e4de69ba0a..1ba3df3732 100644 --- a/std/recursion/sumcheck/arithengine.go +++ b/std/recursion/sumcheck/arithengine.go @@ -15,7 +15,7 @@ type element any // case of prover, it is initialized with a finite field arithmetic engine // defined over [*big.Int] or field arithmetic packages. In case of verifier, is // initialized with non-native arithmetic. 
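For orientation, a minimal self-contained sketch of the engine pattern described in the comment above (the names `engine`, `modEngine` and `toyGate` are illustrative only, not part of the package): a gate body written purely against the engine surface evaluates the same way whether the engine does native big.Int arithmetic for the prover or, in the real code, non-native in-circuit arithmetic for the verifier.

```go
package main

import (
	"fmt"
	"math/big"
)

// engine mirrors the Add/Mul/One part of the arithmetic engine surface
// discussed above (local illustrative copy, not the package's interface).
type engine[E any] interface {
	Add(a, b E) E
	Mul(a, b E) E
	One() E
}

// modEngine plays the role of the native prover-side engine: every result is
// reduced modulo a prime.
type modEngine struct{ mod *big.Int }

func (m modEngine) Add(a, b *big.Int) *big.Int {
	return new(big.Int).Mod(new(big.Int).Add(a, b), m.mod)
}

func (m modEngine) Mul(a, b *big.Int) *big.Int {
	return new(big.Int).Mod(new(big.Int).Mul(a, b), m.mod)
}

func (m modEngine) One() *big.Int { return big.NewInt(1) }

// toyGate computes x*y + 1 purely through the engine; the identical body could
// be instantiated with a verifier-side, non-native engine instead.
func toyGate[E any](api engine[E], x, y E) E {
	return api.Add(api.Mul(x, y), api.One())
}

func main() {
	eng := modEngine{mod: big.NewInt(97)}
	fmt.Println(toyGate[*big.Int](eng, big.NewInt(10), big.NewInt(20))) // (10*20+1) mod 97 = 7
}
```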
-type arithEngine[E element] interface { +type ArithEngine[E element] interface { Add(a, b E) E Mul(a, b E) E Sub(a, b E) E @@ -24,74 +24,84 @@ type arithEngine[E element] interface { Const(i *big.Int) E } -// bigIntEngine performs computation reducing with given modulus. -type bigIntEngine struct { +// BigIntEngine performs computation reducing with given modulus. +type BigIntEngine struct { mod *big.Int // TODO: we should also add pools for more efficient memory management. } -func (be *bigIntEngine) Add(a, b *big.Int) *big.Int { +func (be *BigIntEngine) Add(a, b *big.Int) *big.Int { dst := new(big.Int) dst.Add(a, b) dst.Mod(dst, be.mod) return dst } -func (be *bigIntEngine) Mul(a, b *big.Int) *big.Int { +func (be *BigIntEngine) Mul(a, b *big.Int) *big.Int { dst := new(big.Int) dst.Mul(a, b) dst.Mod(dst, be.mod) return dst } -func (be *bigIntEngine) Sub(a, b *big.Int) *big.Int { +func (be *BigIntEngine) Sub(a, b *big.Int) *big.Int { dst := new(big.Int) dst.Sub(a, b) dst.Mod(dst, be.mod) return dst } -func (be *bigIntEngine) One() *big.Int { +func (be *BigIntEngine) One() *big.Int { return big.NewInt(1) } -func (be *bigIntEngine) Const(i *big.Int) *big.Int { +func (be *BigIntEngine) Const(i *big.Int) *big.Int { return new(big.Int).Set(i) } -func newBigIntEngine(mod *big.Int) *bigIntEngine { - return &bigIntEngine{mod: new(big.Int).Set(mod)} +func NewBigIntEngine(mod *big.Int) *BigIntEngine { + return &BigIntEngine{mod: new(big.Int).Set(mod)} } -// emuEngine uses non-native arithmetic for operations. -type emuEngine[FR emulated.FieldParams] struct { +// EmuEngine uses non-native arithmetic for operations. +type EmuEngine[FR emulated.FieldParams] struct { f *emulated.Field[FR] } -func (ee *emuEngine[FR]) Add(a, b *emulated.Element[FR]) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Add(a, b *emulated.Element[FR]) *emulated.Element[FR] { return ee.f.Add(a, b) } -func (ee *emuEngine[FR]) Mul(a, b *emulated.Element[FR]) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Mul(a, b *emulated.Element[FR]) *emulated.Element[FR] { return ee.f.Mul(a, b) } -func (ee *emuEngine[FR]) Sub(a, b *emulated.Element[FR]) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Sub(a, b *emulated.Element[FR]) *emulated.Element[FR] { return ee.f.Sub(a, b) } -func (ee *emuEngine[FR]) One() *emulated.Element[FR] { +func (ee *EmuEngine[FR]) One() *emulated.Element[FR] { return ee.f.One() } -func (ee *emuEngine[FR]) Const(i *big.Int) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Const(i *big.Int) *emulated.Element[FR] { return ee.f.NewElement(i) } -func newEmulatedEngine[FR emulated.FieldParams](api frontend.API) (*emuEngine[FR], error) { +func NewEmulatedEngine[FR emulated.FieldParams](api frontend.API) (*EmuEngine[FR], error) { f, err := emulated.NewField[FR](api) if err != nil { return nil, fmt.Errorf("new field: %w", err) } - return &emuEngine[FR]{f: f}, nil + return &EmuEngine[FR]{f: f}, nil } + + +// noopEngine is a no-operation arithmetic engine. Can be used to access methods of the gates without performing any computation. 
+type noopEngine struct{} + +func (ne *noopEngine) Add(a, b element) element { panic("noop engine: Add called") } +func (ne *noopEngine) Mul(a, b element) element { panic("noop engine: Mul called") } +func (ne *noopEngine) Sub(a, b element) element { panic("noop engine: Sub called") } +func (ne *noopEngine) One() element { panic("noop engine: One called") } +func (ne *noopEngine) Const(i *big.Int) element { panic("noop engine: Const called") } \ No newline at end of file diff --git a/std/recursion/sumcheck/challenge.go b/std/recursion/sumcheck/challenge.go index fb9e87ee4c..3a8759e346 100644 --- a/std/recursion/sumcheck/challenge.go +++ b/std/recursion/sumcheck/challenge.go @@ -25,7 +25,7 @@ func getChallengeNames(prefix string, nbClaims int, nbVars int) []string { } // bindChallengeProver binds the values for challengeName using native Fiat-Shamir transcript. -func bindChallengeProver(fs *cryptofiatshamir.Transcript, challengeName string, values []*big.Int) error { +func BindChallengeProver(fs *cryptofiatshamir.Transcript, challengeName string, values []*big.Int) error { for i := range values { buf := make([]byte, 32) values[i].FillBytes(buf) @@ -39,8 +39,8 @@ func bindChallengeProver(fs *cryptofiatshamir.Transcript, challengeName string, // deriveChallengeProver binds the values for challengeName and then returns the // challenge using native Fiat-Shamir transcript. It also returns the rest of // the challenge names for used in the protocol. -func deriveChallengeProver(fs *cryptofiatshamir.Transcript, challengeNames []string, values []*big.Int) (challenge *big.Int, restChallengeNames []string, err error) { - if err = bindChallengeProver(fs, challengeNames[0], values); err != nil { +func DeriveChallengeProver(fs *cryptofiatshamir.Transcript, challengeNames []string, values []*big.Int) (challenge *big.Int, restChallengeNames []string, err error) { + if err = BindChallengeProver(fs, challengeNames[0], values); err != nil { return nil, nil, fmt.Errorf("bind: %w", err) } nativeChallenge, err := fs.ComputeChallenge(challengeNames[0]) @@ -51,6 +51,7 @@ func deriveChallengeProver(fs *cryptofiatshamir.Transcript, challengeNames []str return challenge, challengeNames[1:], nil } +// todo change this bind as limbs instead of bits, ask @arya if necessary // bindChallenge binds the values for challengeName using in-circuit Fiat-Shamir transcript. func (v *Verifier[FR]) bindChallenge(fs *fiatshamir.Transcript, challengeName string, values []emulated.Element[FR]) error { for i := range values { diff --git a/std/recursion/sumcheck/claim_intf.go b/std/recursion/sumcheck/claim_intf.go index d2df83aea6..731234debd 100644 --- a/std/recursion/sumcheck/claim_intf.go +++ b/std/recursion/sumcheck/claim_intf.go @@ -28,12 +28,12 @@ type claims interface { NbVars() int // Combine combines separate claims into a single sumcheckable claim using // the coefficient coeff. - Combine(coeff *big.Int) nativePolynomial + Combine(coeff *big.Int) NativePolynomial // Next fixes the next free variable to r, keeps the next variable free and // sums over a hypercube for the last variables. Instead of returning the // polynomial in coefficient form, it returns the evaluations at degree // different points. - Next(r *big.Int) nativePolynomial + Next(r *big.Int) NativePolynomial // ProverFinalEval returns the (lazy) evaluation proof. 
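Stepping back to the challenge binding above: BindChallengeProver serialises each value into a fixed-width 32-byte big-endian buffer before feeding it to the transcript. A rough standard-library sketch of that pattern (sha256 stands in for the recursion transcript; `bindAndDerive` is an assumed name, not an API of this package):

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"math/big"
)

// bindAndDerive hashes each value as a fixed-width 32-byte big-endian buffer
// and interprets the digest as a challenge (std-lib stand-in for the transcript).
func bindAndDerive(values []*big.Int) *big.Int {
	h := sha256.New()
	for _, v := range values {
		buf := make([]byte, 32)
		v.FillBytes(buf) // fixed-width encoding, like the FillBytes call above
		h.Write(buf)
	}
	return new(big.Int).SetBytes(h.Sum(nil))
}

func main() {
	c := bindAndDerive([]*big.Int{big.NewInt(3), big.NewInt(5)})
	fmt.Println(c.BitLen() <= 256) // true
}
```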
- ProverFinalEval(r []*big.Int) nativeEvaluationProof + ProverFinalEval(r []*big.Int) NativeEvaluationProof } diff --git a/std/recursion/sumcheck/claimable_gate.go b/std/recursion/sumcheck/claimable_gate.go index 04884388ee..b6c8d86ff3 100644 --- a/std/recursion/sumcheck/claimable_gate.go +++ b/std/recursion/sumcheck/claimable_gate.go @@ -11,11 +11,11 @@ import ( ) // gate defines a multivariate polynomial which can be sumchecked. -type gate[AE arithEngine[E], E element] interface { +type gate[AE ArithEngine[E], E element] interface { // NbInputs is the number of inputs the gate takes. NbInputs() int // Evaluate evaluates the gate at inputs vars. - Evaluate(api AE, vars ...E) E + Evaluate(api AE, vars ...E) []E // Degree returns the maximum degree of the variables. Degree() int // TODO: return degree of variable for optimized verification } @@ -27,9 +27,9 @@ type gate[AE arithEngine[E], E element] interface { type gateClaim[FR emulated.FieldParams] struct { f *emulated.Field[FR] p *polynomial.Polynomial[FR] - engine *emuEngine[FR] + engine *EmuEngine[FR] - gate gate[*emuEngine[FR], *emulated.Element[FR]] + gate gate[*EmuEngine[FR], *emulated.Element[FR]] evaluationPoints [][]*emulated.Element[FR] claimedEvaluations []*emulated.Element[FR] @@ -48,7 +48,7 @@ type gateClaim[FR emulated.FieldParams] struct { // evaluationPoints is the random coefficients for ensuring the consistency of // the inputs during the final round and claimedEvals is the claimed evaluation // values with the inputs combined at the evaluationPoints. -func newGate[FR emulated.FieldParams](api frontend.API, gate gate[*emuEngine[FR], *emulated.Element[FR]], +func newGate[FR emulated.FieldParams](api frontend.API, gate gate[*EmuEngine[FR], *emulated.Element[FR]], inputs [][]*emulated.Element[FR], evaluationPoints [][]*emulated.Element[FR], claimedEvals []*emulated.Element[FR]) (LazyClaims[FR], error) { nbInputs := gate.NbInputs() @@ -71,7 +71,7 @@ func newGate[FR emulated.FieldParams](api frontend.API, gate gate[*emuEngine[FR] if err != nil { return nil, fmt.Errorf("new polynomial: %w", err) } - engine, err := newEmulatedEngine[FR](api) + engine, err := NewEmulatedEngine[FR](api) if err != nil { return nil, fmt.Errorf("new emulated engine: %w", err) } @@ -146,15 +146,15 @@ func (g *gateClaim[FR]) AssertEvaluation(r []*emulated.Element[FR], combinationC // now, we can evaluate the gate at the random input. gateEval := g.gate.Evaluate(g.engine, inputEvals...) - res := g.f.Mul(eqEval, gateEval) + res := g.f.Mul(eqEval, gateEval[0]) g.f.AssertIsEqual(res, expectedValue) return nil } type nativeGateClaim struct { - engine *bigIntEngine + engine *BigIntEngine - gate gate[*bigIntEngine, *big.Int] + gate gate[*BigIntEngine, *big.Int] evaluationPoints [][]*big.Int claimedEvaluations []*big.Int @@ -163,13 +163,13 @@ type nativeGateClaim struct { // multi-instance input id to the instance value. This allows running // sumcheck over the hypercube. Every element in the slice represents the // input. 
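Alongside these input preprocessors, the claim also keeps an `eq` table (see the field just below): the multilinear extension eq(x, q) tabulated over the boolean hypercube for an evaluation point q. A small sketch of what that table contains, under the convention eq(x, q) = ∏ᵢ (xᵢqᵢ + (1−xᵢ)(1−qᵢ)) and with an assumed helper name `eqTable` (the package builds the same table incrementally via Eq/EqAcc):

```go
package main

import (
	"fmt"
	"math/big"
)

// eqTable tabulates eq(x, q) = prod_i (x_i*q_i + (1-x_i)*(1-q_i)) over all
// boolean x in {0,1}^len(q), reduced mod p.
func eqTable(q []*big.Int, p *big.Int) []*big.Int {
	one := big.NewInt(1)
	tab := make([]*big.Int, 1<<len(q))
	for x := range tab {
		v := new(big.Int).Set(one)
		for i := range q {
			term := new(big.Int)
			if x>>i&1 == 1 {
				term.Set(q[i]) // x_i = 1 contributes q_i
			} else {
				term.Sub(one, q[i]) // x_i = 0 contributes 1 - q_i
			}
			v.Mul(v, term)
			v.Mod(v, p)
		}
		tab[x] = v
	}
	return tab
}

func main() {
	p := big.NewInt(101)
	q := []*big.Int{big.NewInt(2), big.NewInt(3)}
	// the entries always sum to 1 mod p, a quick sanity property of eq
	fmt.Println(eqTable(q, p)) // [2 97 98 6]
}
```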
- inputPreprocessors []nativeMultilinear + inputPreprocessors []NativeMultilinear - eq nativeMultilinear + eq NativeMultilinear } -func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [][]*big.Int, evaluationPoints [][]*big.Int) (claim claims, evaluations []*big.Int, err error) { - be := newBigIntEngine(target) +func newNativeGate(target *big.Int, gate gate[*BigIntEngine, *big.Int], inputs [][]*big.Int, evaluationPoints [][]*big.Int) (claim claims, evaluations []*big.Int, err error) { + be := &BigIntEngine{mod: new(big.Int).Set(target)} nbInputs := gate.NbInputs() if len(inputs) != nbInputs { return nil, nil, fmt.Errorf("expected %d inputs got %d", nbInputs, len(inputs)) @@ -184,7 +184,7 @@ func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [ evalInput := make([][]*big.Int, nbInstances) // TODO: pad input to power of two for i := range evalInput { - evalInput[i] = make(nativeMultilinear, nbInputs) + evalInput[i] = make(NativeMultilinear, nbInputs) for j := range evalInput[i] { evalInput[i][j] = new(big.Int).Set(inputs[j][i]) } @@ -193,12 +193,12 @@ func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [ evaluations = make([]*big.Int, nbInstances) for i := range evaluations { evaluations[i] = new(big.Int) - evaluations[i] = gate.Evaluate(be, evalInput[i]...) + evaluations[i] = gate.Evaluate(be, evalInput[i]...)[0] } // construct the mapping (inputIdx, instanceIdx) -> inputVal - inputPreprocessors := make([]nativeMultilinear, nbInputs) + inputPreprocessors := make([]NativeMultilinear, nbInputs) for i := range inputs { - inputPreprocessors[i] = make(nativeMultilinear, nbInstances) + inputPreprocessors[i] = make(NativeMultilinear, nbInstances) for j := range inputs[i] { inputPreprocessors[i][j] = new(big.Int).Set(inputs[i][j]) } @@ -211,7 +211,7 @@ func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [ // compute the random linear combinations of the evaluation values of the gate claimedEvaluations := make([]*big.Int, len(evaluationPoints)) for i := range claimedEvaluations { - claimedEvaluations[i] = eval(be, evaluations, evaluationPoints[i]) + claimedEvaluations[i] = Eval(be, evaluations, evaluationPoints[i]) } return &nativeGateClaim{ engine: be, @@ -231,19 +231,19 @@ func (g *nativeGateClaim) NbVars() int { return len(g.evaluationPoints[0]) } -func (g *nativeGateClaim) Combine(coeff *big.Int) nativePolynomial { +func (g *nativeGateClaim) Combine(coeff *big.Int) NativePolynomial { nbVars := g.NbVars() eqLength := 1 << nbVars nbClaims := g.NbClaims() - g.eq = make(nativeMultilinear, eqLength) + g.eq = make(NativeMultilinear, eqLength) g.eq[0] = g.engine.One() for i := 1; i < eqLength; i++ { g.eq[i] = new(big.Int) } - g.eq = eq(g.engine, g.eq, g.evaluationPoints[0]) + g.eq = Eq(g.engine, g.eq, g.evaluationPoints[0]) - newEq := make(nativeMultilinear, eqLength) + newEq := make(NativeMultilinear, eqLength) for i := 1; i < eqLength; i++ { newEq[i] = new(big.Int) } @@ -251,7 +251,7 @@ func (g *nativeGateClaim) Combine(coeff *big.Int) nativePolynomial { for k := 1; k < nbClaims; k++ { newEq[0] = g.engine.One() - g.eq = eqAcc(g.engine, g.eq, newEq, g.evaluationPoints[k]) + g.eq = EqAcc(g.engine, g.eq, newEq, g.evaluationPoints[k]) if k+1 < nbClaims { aI = g.engine.Mul(aI, coeff) } @@ -260,32 +260,32 @@ func (g *nativeGateClaim) Combine(coeff *big.Int) nativePolynomial { return g.computeGJ() } -func (g *nativeGateClaim) Next(r *big.Int) nativePolynomial { +func (g *nativeGateClaim) Next(r 
*big.Int) NativePolynomial { for i := range g.inputPreprocessors { - g.inputPreprocessors[i] = fold(g.engine, g.inputPreprocessors[i], r) + g.inputPreprocessors[i] = Fold(g.engine, g.inputPreprocessors[i], r) } - g.eq = fold(g.engine, g.eq, r) + g.eq = Fold(g.engine, g.eq, r) return g.computeGJ() } -func (g *nativeGateClaim) ProverFinalEval(r []*big.Int) nativeEvaluationProof { +func (g *nativeGateClaim) ProverFinalEval(r []*big.Int) NativeEvaluationProof { // verifier computes the value of the gate (times the eq) itself return nil } -func (g *nativeGateClaim) computeGJ() nativePolynomial { +func (g *nativeGateClaim) computeGJ() NativePolynomial { // returns the polynomial GJ through its evaluations degGJ := 1 + g.gate.Degree() nbGateIn := len(g.inputPreprocessors) - s := make([]nativeMultilinear, nbGateIn+1) + s := make([]NativeMultilinear, nbGateIn+1) s[0] = g.eq copy(s[1:], g.inputPreprocessors) nbInner := len(s) nbOuter := len(s[0]) / 2 - gJ := make(nativePolynomial, degGJ) + gJ := make(NativePolynomial, degGJ) for i := range gJ { gJ[i] = new(big.Int) } @@ -314,7 +314,7 @@ func (g *nativeGateClaim) computeGJ() nativePolynomial { _s := 0 _e := nbInner for d := 0; d < degGJ; d++ { - summand := g.gate.Evaluate(g.engine, operands[_s+1:_e]...) + summand := g.gate.Evaluate(g.engine, operands[_s+1:_e]...)[0] summand = g.engine.Mul(summand, operands[_s]) res[d] = g.engine.Add(res[d], summand) _s, _e = _e, _e+nbInner diff --git a/std/recursion/sumcheck/claimable_multilinear.go b/std/recursion/sumcheck/claimable_multilinear.go index c73395514f..7bb4b43918 100644 --- a/std/recursion/sumcheck/claimable_multilinear.go +++ b/std/recursion/sumcheck/claimable_multilinear.go @@ -62,7 +62,7 @@ func (fn *multilinearClaim[FR]) AssertEvaluation(r []*emulated.Element[FR], comb } type nativeMultilinearClaim struct { - be *bigIntEngine + be *BigIntEngine ml []*big.Int } @@ -71,7 +71,7 @@ func newNativeMultilinearClaim(target *big.Int, ml []*big.Int) (claim claims, hy if bits.OnesCount(uint(len(ml))) != 1 { return nil, nil, fmt.Errorf("expecting power of two coeffs") } - be := newBigIntEngine(target) + be := NewBigIntEngine(target) hypersum = new(big.Int) for i := range ml { hypersum = be.Add(hypersum, ml[i]) @@ -91,16 +91,16 @@ func (fn *nativeMultilinearClaim) NbVars() int { return bits.Len(uint(len(fn.ml))) - 1 } -func (fn *nativeMultilinearClaim) Combine(coeff *big.Int) nativePolynomial { +func (fn *nativeMultilinearClaim) Combine(coeff *big.Int) NativePolynomial { return []*big.Int{hypersumX1One(fn.be, fn.ml)} } -func (fn *nativeMultilinearClaim) Next(r *big.Int) nativePolynomial { - fn.ml = fold(fn.be, fn.ml, r) +func (fn *nativeMultilinearClaim) Next(r *big.Int) NativePolynomial { + fn.ml = Fold(fn.be, fn.ml, r) return []*big.Int{hypersumX1One(fn.be, fn.ml)} } -func (fn *nativeMultilinearClaim) ProverFinalEval(r []*big.Int) nativeEvaluationProof { +func (fn *nativeMultilinearClaim) ProverFinalEval(r []*big.Int) NativeEvaluationProof { // verifier computes the value of the multilinear function itself return nil } diff --git a/std/recursion/sumcheck/fullscalarmul_test.go b/std/recursion/sumcheck/fullscalarmul_test.go new file mode 100644 index 0000000000..b7eecd7180 --- /dev/null +++ b/std/recursion/sumcheck/fullscalarmul_test.go @@ -0,0 +1,471 @@ +package sumcheck + +import ( + "crypto/rand" + "fmt" + "math/big" + stdbits "math/bits" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + fr_secp256k1 
"github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + cryptofs "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/profile" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/algebra" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/std/math/polynomial" + "github.com/consensys/gnark/std/recursion" + "github.com/consensys/gnark/test" +) + +type ProjectivePoint[Base emulated.FieldParams] struct { + X, Y, Z emulated.Element[Base] +} + +type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct { + Points []sw_emulated.AffinePoint[Base] + Scalars []emulated.Element[Scalars] + + nbScalarBits int +} + +func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { + var fp B + nbInputs := len(c.Points) + if len(c.Points) != len(c.Scalars) { + return fmt.Errorf("len(inputs) != len(scalars)") + } + baseApi, err := emulated.NewField[B](api) + if err != nil { + return fmt.Errorf("new base field: %w", err) + } + scalarApi, err := emulated.NewField[S](api) + if err != nil { + return fmt.Errorf("new scalar field: %w", err) + } + poly, err := polynomial.New[B](api) + if err != nil { + return fmt.Errorf("new polynomial: %w", err) + } + // we use curve for marshaling points and scalars + curve, err := algebra.GetCurve[S, sw_emulated.AffinePoint[B]](api) + if err != nil { + return fmt.Errorf("get curve: %w", err) + } + fs, err := recursion.NewTranscript(api, fp.Modulus(), []string{"alpha", "beta"}) + if err != nil { + return fmt.Errorf("new transcript: %w", err) + } + // compute the all double-and-add steps for each scalar multiplication + // var results, accs []ProjectivePoint[B] + for i := range c.Points { + if err := fs.Bind("alpha", curve.MarshalG1(c.Points[i])); err != nil { + return fmt.Errorf("bind point %d alpha: %w", i, err) + } + if err := fs.Bind("alpha", curve.MarshalScalar(c.Scalars[i])); err != nil { + return fmt.Errorf("bind scalar %d alpha: %w", i, err) + } + } + result, acc, proof, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points, c.Scalars) + if err != nil { + return fmt.Errorf("hint scalar mul steps: %w", err) + } + + // derive the randomness for random linear combination + alphaNative, err := fs.ComputeChallenge("alpha") + if err != nil { + return fmt.Errorf("compute challenge alpha: %w", err) + } + alphaBts := bits.ToBinary(api, alphaNative, bits.WithNbDigits(fp.Modulus().BitLen())) + alphas := make([]*emulated.Element[B], 6) + alphas[0] = baseApi.One() + alphas[1] = baseApi.FromBits(alphaBts...) 
+ for i := 2; i < len(alphas); i++ { + alphas[i] = baseApi.Mul(alphas[i-1], alphas[1]) + } + claimed := make([]*emulated.Element[B], nbInputs*c.nbScalarBits) + // compute the random linear combinations of the intermediate results provided by the hint + for i := 0; i < nbInputs; i++ { + for j := 0; j < c.nbScalarBits; j++ { + claimed[i*c.nbScalarBits+j] = baseApi.Sum( + &acc[i][j].X, + baseApi.MulNoReduce(alphas[1], &acc[i][j].Y), + baseApi.MulNoReduce(alphas[2], &acc[i][j].Z), + baseApi.MulNoReduce(alphas[3], &result[i][j].X), + baseApi.MulNoReduce(alphas[4], &result[i][j].Y), + baseApi.MulNoReduce(alphas[5], &result[i][j].Z), + ) + } + } + // derive the randomness for folding + betaNative, err := fs.ComputeChallenge("beta") + if err != nil { + return fmt.Errorf("compute challenge alpha: %w", err) + } + betaBts := bits.ToBinary(api, betaNative, bits.WithNbDigits(fp.Modulus().BitLen())) + evalPoints := make([]*emulated.Element[B], stdbits.Len(uint(len(claimed)))-1) + evalPoints[0] = baseApi.FromBits(betaBts...) + for i := 1; i < len(evalPoints); i++ { + evalPoints[i] = baseApi.Mul(evalPoints[i-1], evalPoints[0]) + } + // compute the polynomial evaluation + claimedPoly := polynomial.FromSliceReferences(claimed) + evaluation, err := poly.EvalMultilinear(evalPoints, claimedPoly) + if err != nil { + return fmt.Errorf("eval multilinear: %w", err) + } + + inputs := make([][]*emulated.Element[B], 7) + for i := range inputs { + inputs[i] = make([]*emulated.Element[B], nbInputs*c.nbScalarBits) + } + for i := 0; i < nbInputs; i++ { + scalarBts := scalarApi.ToBits(&c.Scalars[i]) + inputs[0][i*c.nbScalarBits] = &c.Points[i].X + inputs[1][i*c.nbScalarBits] = &c.Points[i].Y + inputs[2][i*c.nbScalarBits] = baseApi.One() + inputs[3][i*c.nbScalarBits] = baseApi.Zero() + inputs[4][i*c.nbScalarBits] = baseApi.One() + inputs[5][i*c.nbScalarBits] = baseApi.Zero() + inputs[6][i*c.nbScalarBits] = baseApi.NewElement(scalarBts[0]) + for j := 1; j < c.nbScalarBits; j++ { + inputs[0][i*c.nbScalarBits+j] = &acc[i][j-1].X + inputs[1][i*c.nbScalarBits+j] = &acc[i][j-1].Y + inputs[2][i*c.nbScalarBits+j] = &acc[i][j-1].Z + inputs[3][i*c.nbScalarBits+j] = &result[i][j-1].X + inputs[4][i*c.nbScalarBits+j] = &result[i][j-1].Y + inputs[5][i*c.nbScalarBits+j] = &result[i][j-1].Z + inputs[6][i*c.nbScalarBits+j] = baseApi.NewElement(scalarBts[j]) + } + } + gate := DblAddSelectGate[*EmuEngine[B], *emulated.Element[B]]{Folding: alphas} + claim, err := newGate[B](api, gate, inputs, [][]*emulated.Element[B]{evalPoints}, []*emulated.Element[B]{evaluation}) + v, err := NewVerifier[B](api) + if err != nil { + return fmt.Errorf("new sumcheck verifier: %w", err) + } + if err = v.Verify(claim, proof); err != nil { + return fmt.Errorf("verify sumcheck: %w", err) + } + + return nil +} + +func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, + baseApi *emulated.Field[B], scalarApi *emulated.Field[S], + nbScalarBits int, + points []sw_emulated.AffinePoint[B], scalars []emulated.Element[S]) (results [][]ProjectivePoint[B], accumulators [][]ProjectivePoint[B], proof Proof[B], err error) { + var fp B + var fr S + nbInputs := len(points) + inputs := []frontend.Variable{nbInputs, nbScalarBits, fp.BitsPerLimb(), fp.NbLimbs(), fr.BitsPerLimb(), fr.NbLimbs()} + inputs = append(inputs, baseApi.Modulus().Limbs...) + inputs = append(inputs, scalarApi.Modulus().Limbs...) + for i := range points { + inputs = append(inputs, points[i].X.Limbs...) + inputs = append(inputs, points[i].Y.Limbs...) 
+ inputs = append(inputs, scalars[i].Limbs...) + } + // steps part + nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 * nbInputs + // proof part + nbRes += int(fp.NbLimbs()) * (stdbits.Len(uint(nbInputs*nbScalarBits)) - 1) * (DblAddSelectGate[*noopEngine, element]{}.Degree() + 1) + hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...) + if err != nil { + return nil, nil, proof, fmt.Errorf("new hint: %w", err) + } + res := make([][]ProjectivePoint[B], nbInputs) + acc := make([][]ProjectivePoint[B], nbInputs) + for i := 0; i < nbInputs; i++ { + res[i] = make([]ProjectivePoint[B], nbScalarBits) + acc[i] = make([]ProjectivePoint[B], nbScalarBits) + } + for i := 0; i < nbInputs; i++ { + inputRes := hintRes[i*(6*int(fp.NbLimbs())*nbScalarBits) : (i+1)*(6*int(fp.NbLimbs())*nbScalarBits)] + for j := 0; j < nbScalarBits; j++ { + coords := make([]*emulated.Element[B], 6) + for k := range coords { + limbs := inputRes[j*(6*int(fp.NbLimbs()))+k*int(fp.NbLimbs()) : j*(6*int(fp.NbLimbs()))+(k+1)*int(fp.NbLimbs())] + coords[k] = baseApi.NewElement(limbs) + } + res[i][j] = ProjectivePoint[B]{ + X: *coords[0], + Y: *coords[1], + Z: *coords[2], + } + acc[i][j] = ProjectivePoint[B]{ + X: *coords[3], + Y: *coords[4], + Z: *coords[5], + } + } + } + proof.RoundPolyEvaluations = make([]polynomial.Univariate[B], stdbits.Len(uint(nbInputs*nbScalarBits))-1) + ptr := nbInputs * 6 * int(fp.NbLimbs()) * nbScalarBits + for i := range proof.RoundPolyEvaluations { + proof.RoundPolyEvaluations[i] = make(polynomial.Univariate[B], DblAddSelectGate[*noopEngine, element]{}.Degree()+1) + for j := range proof.RoundPolyEvaluations[i] { + limbs := hintRes[ptr : ptr+int(fp.NbLimbs())] + el := baseApi.NewElement(limbs) + proof.RoundPolyEvaluations[i][j] = *el + ptr += int(fp.NbLimbs()) + } + } + return res, acc, proof, nil +} + +func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + nbInputs := int(inputs[0].Int64()) + scalarLength := int(inputs[1].Int64()) + nbBits := int(inputs[2].Int64()) + nbLimbs := int(inputs[3].Int64()) + nbScalarBits := int(inputs[4].Int64()) + nbScalarLimbs := int(inputs[5].Int64()) + fpLimbs := inputs[6 : 6+nbLimbs] + frLimbs := inputs[6+nbLimbs : 6+nbLimbs+nbScalarLimbs] + fp := new(big.Int) + fr := new(big.Int) + if err := recompose(fpLimbs, uint(nbBits), fp); err != nil { + return fmt.Errorf("recompose fp: %w", err) + } + if err := recompose(frLimbs, uint(nbScalarBits), fr); err != nil { + return fmt.Errorf("recompose fr: %w", err) + } + ptr := 6 + nbLimbs + nbScalarLimbs + xs := make([]*big.Int, nbInputs) + ys := make([]*big.Int, nbInputs) + scalars := make([]*big.Int, nbInputs) + for i := 0; i < nbInputs; i++ { + xLimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + yLimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + scalarLimbs := inputs[ptr : ptr+nbScalarLimbs] + ptr += nbScalarLimbs + xs[i] = new(big.Int) + ys[i] = new(big.Int) + scalars[i] = new(big.Int) + if err := recompose(xLimbs, uint(nbBits), xs[i]); err != nil { + return fmt.Errorf("recompose x: %w", err) + } + if err := recompose(yLimbs, uint(nbBits), ys[i]); err != nil { + return fmt.Errorf("recompose y: %w", err) + } + if err := recompose(scalarLimbs, uint(nbScalarBits), scalars[i]); err != nil { + return fmt.Errorf("recompose scalar: %w", err) + } + } + + // first, we need to provide the steps of the scalar multiplication to the + // verifier. As the output of one step is an input of the next step, we need + // to provide the results and the accumulators. 
By checking the consistency + // of the inputs related to the outputs (inputs using multilinear evaluation + // in the final round of the sumcheck and outputs by requiring the verifier + // to construct the claim itself), we can ensure that the final step is the + // actual scalar multiplication result. + api := NewBigIntEngine(fp) + selector := new(big.Int) + outPtr := 0 + proofInput := make([][]*big.Int, 7) + for i := range proofInput { + proofInput[i] = make([]*big.Int, nbInputs*scalarLength) + } + for i := 0; i < nbInputs; i++ { + scalar := new(big.Int).Set(scalars[i]) + x := xs[i] + y := ys[i] + accX := new(big.Int).Set(x) + accY := new(big.Int).Set(y) + accZ := big.NewInt(1) + resultX := big.NewInt(0) + resultY := big.NewInt(1) + resultZ := big.NewInt(0) + for j := 0; j < scalarLength; j++ { + selector.And(scalar, big.NewInt(1)) + scalar.Rsh(scalar, 1) + proofInput[0][i*scalarLength+j] = new(big.Int).Set(accX) + proofInput[1][i*scalarLength+j] = new(big.Int).Set(accY) + proofInput[2][i*scalarLength+j] = new(big.Int).Set(accZ) + proofInput[3][i*scalarLength+j] = new(big.Int).Set(resultX) + proofInput[4][i*scalarLength+j] = new(big.Int).Set(resultY) + proofInput[5][i*scalarLength+j] = new(big.Int).Set(resultZ) + proofInput[6][i*scalarLength+j] = new(big.Int).Set(selector) + tmpX, tmpY, tmpZ := ProjAdd(api, accX, accY, accZ, resultX, resultY, resultZ) + resultX, resultY, resultZ = ProjSelect(api, selector, tmpX, tmpY, tmpZ, resultX, resultY, resultZ) + accX, accY, accZ = ProjDbl(api, accX, accY, accZ) + if err := decompose(resultX, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultX: %w", err) + } + outPtr += nbLimbs + if err := decompose(resultY, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultY: %w", err) + } + outPtr += nbLimbs + if err := decompose(resultZ, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultZ: %w", err) + } + outPtr += nbLimbs + if err := decompose(accX, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose accX: %w", err) + } + outPtr += nbLimbs + if err := decompose(accY, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose accY: %w", err) + } + outPtr += nbLimbs + if err := decompose(accZ, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose accZ: %w", err) + } + outPtr += nbLimbs + } + } + + // now, we construct the sumcheck proof. For that we first need to compute + // the challenges for computing the random linear combination of the + // double-and-add outputs and for the claim polynomial evaluation. + h, err := recursion.NewShort(mod, fp) + if err != nil { + return fmt.Errorf("new short hash: %w", err) + } + fs := cryptofs.NewTranscript(h, "alpha", "beta") + for i := range xs { + var P secp256k1.G1Affine + var s fr_secp256k1.Element + P.X.SetBigInt(xs[i]) + P.Y.SetBigInt(ys[i]) + raw := P.RawBytes() + if err := fs.Bind("alpha", raw[:]); err != nil { + return fmt.Errorf("bind alpha point: %w", err) + } + s.SetBigInt(scalars[i]) + if err := fs.Bind("alpha", s.Marshal()); err != nil { + return fmt.Errorf("bind alpha scalar: %w", err) + } + } + // challenges. 
+ // alpha is used for the random linear combination of the double-and-add + alpha, err := fs.ComputeChallenge("alpha") + if err != nil { + return fmt.Errorf("compute challenge alpha: %w", err) + } + alphas := make([]*big.Int, 6) + alphas[0] = big.NewInt(1) + alphas[1] = new(big.Int).SetBytes(alpha) + for i := 2; i < len(alphas); i++ { + alphas[i] = new(big.Int).Mul(alphas[i-1], alphas[1]) + } + + // beta is used for the claim polynomial evaluation + beta, err := fs.ComputeChallenge("beta") + if err != nil { + return fmt.Errorf("compute challenge beta: %w", err) + } + betas := make([]*big.Int, stdbits.Len(uint(nbInputs*scalarLength))-1) + betas[0] = new(big.Int).SetBytes(beta) + for i := 1; i < len(betas); i++ { + betas[i] = new(big.Int).Mul(betas[i-1], betas[0]) + } + + nativeGate := DblAddSelectGate[*BigIntEngine, *big.Int]{Folding: alphas} + claim, evals, err := newNativeGate(fp, nativeGate, proofInput, [][]*big.Int{betas}) + if err != nil { + return fmt.Errorf("new native gate: %w", err) + } + proof, err := Prove(mod, fp, claim) + if err != nil { + return fmt.Errorf("prove: %w", err) + } + for _, pl := range proof.RoundPolyEvaluations { + for j := range pl { + if err := decompose(pl[j], uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose claim: %w", err) + } + outPtr += nbLimbs + } + } + // verifier computes the evaluation itself for consistency. We do not pass + // it through the hint. Explicitly ignore. + _ = evals + return nil +} + +func recompose(inputs []*big.Int, nbBits uint, res *big.Int) error { + if len(inputs) == 0 { + return fmt.Errorf("zero length slice input") + } + if res == nil { + return fmt.Errorf("result not initialized") + } + res.SetUint64(0) + for i := range inputs { + res.Lsh(res, nbBits) + res.Add(res, inputs[len(inputs)-i-1]) + } + // TODO @gbotrel mod reduce ? 
+ return nil +} + +func decompose(input *big.Int, nbBits uint, res []*big.Int) error { + // limb modulus + if input.BitLen() > len(res)*int(nbBits) { + return fmt.Errorf("decomposed integer does not fit into res") + } + for _, r := range res { + if r == nil { + return fmt.Errorf("result slice element uninitalized") + } + } + base := new(big.Int).Lsh(big.NewInt(1), nbBits) + tmp := new(big.Int).Set(input) + for i := 0; i < len(res); i++ { + res[i].Mod(tmp, base) + tmp.Rsh(tmp, nbBits) + } + return nil +} + +func TestScalarMul(t *testing.T) { + assert := test.NewAssert(t) + type B = emparams.Secp256k1Fp + type S = emparams.Secp256k1Fr + var P secp256k1.G1Affine + var s fr_secp256k1.Element + nbInputs := 1 << 2 + nbScalarBits := 256 + scalarBound := new(big.Int).Lsh(big.NewInt(1), uint(nbScalarBits)) + points := make([]sw_emulated.AffinePoint[B], nbInputs) + scalars := make([]emulated.Element[S], nbInputs) + for i := range points { + P.ScalarMultiplicationBase(big.NewInt(1)) + s.SetRandom() + P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) + sc, _ := rand.Int(rand.Reader, scalarBound) + points[i] = sw_emulated.AffinePoint[B]{ + X: emulated.ValueOf[B](P.X), + Y: emulated.ValueOf[B](P.Y), + } + scalars[i] = emulated.ValueOf[S](sc) + } + circuit := ScalarMulCircuit[B, S]{ + Points: make([]sw_emulated.AffinePoint[B], nbInputs), + Scalars: make([]emulated.Element[S], nbInputs), + nbScalarBits: nbScalarBits, + } + witness := ScalarMulCircuit[B, S]{ + Points: points, + Scalars: scalars, + } + err := test.IsSolved(&circuit, &witness, ecc.BLS12_377.ScalarField()) + assert.NoError(err) + p := profile.Start() + _, _ = frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + p.Stop() + fmt.Println(p.NbConstraints()) + +} \ No newline at end of file diff --git a/std/recursion/sumcheck/polynomial.go b/std/recursion/sumcheck/polynomial.go index aaeb318fe4..dae8488ca4 100644 --- a/std/recursion/sumcheck/polynomial.go +++ b/std/recursion/sumcheck/polynomial.go @@ -2,15 +2,58 @@ package sumcheck import ( "math/big" + "math/bits" ) -type nativePolynomial []*big.Int -type nativeMultilinear []*big.Int +type NativePolynomial []*big.Int +type NativeMultilinear []*big.Int // helper functions for multilinear polynomial evaluations -func fold(api *bigIntEngine, ml nativeMultilinear, r *big.Int) nativeMultilinear { - // NB! it modifies ml in-place and also returns +// Clone returns a deep copy of p. +// If capacity is provided, the new coefficient slice capacity will be set accordingly. 
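The exported helpers that follow (Fold, Eq, Eval, EqAcc) revolve around one halving step: fixing a variable of a multilinear evaluation table to a challenge r. A self-contained sketch of that step (`foldOnce` is an assumed name; the package's Fold performs the same update in place on a NativeMultilinear):

```go
package main

import (
	"fmt"
	"math/big"
)

// foldOnce fixes the variable selecting between the two halves of the table to
// r: each pair (bottom, top) collapses to bottom + r*(top-bottom) mod p, so the
// table shrinks by half per sumcheck round.
func foldOnce(ml []*big.Int, r, p *big.Int) []*big.Int {
	mid := len(ml) / 2
	out := make([]*big.Int, mid)
	for i := 0; i < mid; i++ {
		t := new(big.Int).Sub(ml[mid+i], ml[i]) // top - bottom
		t.Mul(t, r)
		t.Add(t, ml[i])
		out[i] = t.Mod(t, p)
	}
	return out
}

func main() {
	p := big.NewInt(97)
	ml := []*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}
	fmt.Println(foldOnce(ml, big.NewInt(5), p)) // [11 12], e.g. 1 + 5*(3-1) = 11
}
```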
+func (p NativeMultilinear) Clone(capacity ...int) NativeMultilinear { + var newCapacity int + if len(capacity) > 0 { + newCapacity = capacity[0] + } else { + newCapacity = len(p) + } + + res := make(NativeMultilinear, len(p), newCapacity) + for i, v := range p { + res[i] = new(big.Int).Set(v) + } + return res +} + +func DereferenceBigIntSlice(ptrs []*big.Int) []big.Int { + vals := make([]big.Int, len(ptrs)) + for i, ptr := range ptrs { + vals[i] = *ptr + } + return vals +} + +func ReferenceBigIntSlice(vals []big.Int) []*big.Int { + ptrs := make([]*big.Int, len(vals)) + for i := range ptrs { + ptrs[i] = &vals[i] + } + return ptrs +} + +func BatchRLC(api *BigIntEngine, mlpolys []NativeMultilinear, r []*big.Int) NativeMultilinear { + res := make(NativeMultilinear, len(mlpolys[0])) + for j := 0; j < len(mlpolys[0]); j++ { + for i := 0; i < len(mlpolys); i++ { + res[j] = api.Add(res[j], api.Mul(mlpolys[i][j], r[i])) + } + } + return res +} + +func Fold(api *BigIntEngine, ml NativeMultilinear, r *big.Int) NativeMultilinear { mid := len(ml) / 2 bottom, top := ml[:mid], ml[mid:] var t *big.Int @@ -22,7 +65,7 @@ func fold(api *bigIntEngine, ml nativeMultilinear, r *big.Int) nativeMultilinear return ml[:mid] } -func hypersumX1One(api *bigIntEngine, ml nativeMultilinear) *big.Int { +func hypersumX1One(api *BigIntEngine, ml NativeMultilinear) *big.Int { sum := ml[len(ml)/2] for i := len(ml)/2 + 1; i < len(ml); i++ { sum = api.Add(sum, ml[i]) @@ -30,7 +73,7 @@ func hypersumX1One(api *bigIntEngine, ml nativeMultilinear) *big.Int { return sum } -func eq(api *bigIntEngine, ml nativeMultilinear, q []*big.Int) nativeMultilinear { +func Eq(api *BigIntEngine, ml NativeMultilinear, q []*big.Int) NativeMultilinear { if (1 << len(q)) != len(ml) { panic("scalar length mismatch") } @@ -46,20 +89,20 @@ func eq(api *bigIntEngine, ml nativeMultilinear, q []*big.Int) nativeMultilinear return ml } -func eval(api *bigIntEngine, ml nativeMultilinear, r []*big.Int) *big.Int { - mlCopy := make(nativeMultilinear, len(ml)) +func Eval(api *BigIntEngine, ml NativeMultilinear, r []*big.Int) *big.Int { + mlCopy := make(NativeMultilinear, len(ml)) for i := range mlCopy { mlCopy[i] = new(big.Int).Set(ml[i]) } for _, ri := range r { - mlCopy = fold(api, mlCopy, ri) + mlCopy = Fold(api, mlCopy, ri) } return mlCopy[0] } -func eqAcc(api *bigIntEngine, e nativeMultilinear, m nativeMultilinear, q []*big.Int) nativeMultilinear { +func EqAcc(api *BigIntEngine, e NativeMultilinear, m NativeMultilinear, q []*big.Int) NativeMultilinear { if len(e) != len(m) { panic("length mismatch") } @@ -83,3 +126,7 @@ func eqAcc(api *bigIntEngine, e nativeMultilinear, m nativeMultilinear, q []*big } return e } + +func (m NativeMultilinear) NumVars() int { + return bits.TrailingZeros(uint(len(m))) +} diff --git a/std/recursion/sumcheck/proof.go b/std/recursion/sumcheck/proof.go index cdba88cc7d..9738cbb161 100644 --- a/std/recursion/sumcheck/proof.go +++ b/std/recursion/sumcheck/proof.go @@ -1,6 +1,8 @@ package sumcheck import ( + "math/big" + "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/polynomial" ) @@ -14,9 +16,9 @@ type Proof[FR emulated.FieldParams] struct { FinalEvalProof EvaluationProof } -type nativeProof struct { - RoundPolyEvaluations []nativePolynomial - FinalEvalProof nativeEvaluationProof +type NativeProof struct { + RoundPolyEvaluations []NativePolynomial + FinalEvalProof NativeEvaluationProof } // EvaluationProof is proof for allowing the sumcheck verifier to perform the @@ -27,16 +29,32 @@ type 
nativeProof struct { // - if it is deferred, then it is a slice. type EvaluationProof any -type nativeEvaluationProof any +// evaluationProof for gkr +type DeferredEvalProof[FR emulated.FieldParams] []emulated.Element[FR] +type NativeDeferredEvalProof []big.Int + +type NativeEvaluationProof any -func valueOfProof[FR emulated.FieldParams](nproof nativeProof) Proof[FR] { +func ValueOfProof[FR emulated.FieldParams](nproof NativeProof) Proof[FR] { rps := make([]polynomial.Univariate[FR], len(nproof.RoundPolyEvaluations)) + finaleval := nproof.FinalEvalProof + if finaleval != nil { + switch v := finaleval.(type) { + case NativeDeferredEvalProof: + deferredEval := make(DeferredEvalProof[FR], len(v)) + for i := range v { + deferredEval[i] = emulated.ValueOf[FR](v[i]) + } + finaleval = deferredEval + } + } for i := range nproof.RoundPolyEvaluations { rps[i] = polynomial.ValueOfUnivariate[FR](nproof.RoundPolyEvaluations[i]) } - // TODO: type switch FinalEvalProof when it is not-nil + return Proof[FR]{ RoundPolyEvaluations: rps, + FinalEvalProof: finaleval, } } diff --git a/std/recursion/sumcheck/prover.go b/std/recursion/sumcheck/prover.go index c075cf1530..a1e154c5d7 100644 --- a/std/recursion/sumcheck/prover.go +++ b/std/recursion/sumcheck/prover.go @@ -15,13 +15,6 @@ type proverConfig struct { type proverOption func(*proverConfig) error -func withProverPrefix(prefix string) proverOption { - return func(pc *proverConfig) error { - pc.prefix = prefix - return nil - } -} - func newProverConfig(opts ...proverOption) (*proverConfig, error) { ret := new(proverConfig) for i := range opts { @@ -32,58 +25,57 @@ func newProverConfig(opts ...proverOption) (*proverConfig, error) { return ret, nil } -func prove(current *big.Int, target *big.Int, claims claims, opts ...proverOption) (nativeProof, error) { - var proof nativeProof +func Prove(current *big.Int, target *big.Int, claims claims, opts ...proverOption) (NativeProof, error) { + var proof NativeProof cfg, err := newProverConfig(opts...) if err != nil { return proof, fmt.Errorf("parse options: %w", err) } - challengeNames := getChallengeNames(cfg.prefix, claims.NbClaims(), claims.NbVars()) + challengeNames := getChallengeNames(cfg.prefix, 1, claims.NbVars()) // claims.NbClaims() fshash, err := recursion.NewShort(current, target) if err != nil { return proof, fmt.Errorf("new short hash: %w", err) } fs := fiatshamir.NewTranscript(fshash, challengeNames...) - if err != nil { - return proof, fmt.Errorf("new transcript: %w", err) - } // bind challenge from previous round if it is a continuation - if err = bindChallengeProver(fs, challengeNames[0], cfg.baseChallenges); err != nil { + if err = BindChallengeProver(fs, challengeNames[0], cfg.baseChallenges); err != nil { return proof, fmt.Errorf("base: %w", err) } combinationCoef := big.NewInt(0) - if claims.NbClaims() >= 2 { - if combinationCoef, challengeNames, err = deriveChallengeProver(fs, challengeNames, nil); err != nil { - return proof, fmt.Errorf("derive combination coef: %w", err) - } - } + // if claims.NbClaims() >= 2 { // todo change this to handle multiple claims per wire - assuming single claim per wire so don't need to combine + // if combinationCoef, challengeNames, err = DeriveChallengeProver(fs, challengeNames, nil); err != nil { + // return proof, fmt.Errorf("derive combination coef: %w", err) + // } // todo change this nbclaims give 6 results in combination coeff + // } + // in sumcheck we run a round for every variable. So the number of variables // defines the number of rounds. 
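For reference, the invariant each of those rounds must preserve, in a toy coefficient-form sketch (the package ships evaluation points rather than coefficients, so this only shows the shape of the check the verifier later performs; `evalPoly` is an assumed helper):

```go
package main

import (
	"fmt"
	"math/big"
)

// evalPoly evaluates sum_i coeffs[i]*x^i mod p using Horner's rule.
func evalPoly(coeffs []*big.Int, x, p *big.Int) *big.Int {
	res := new(big.Int)
	for i := len(coeffs) - 1; i >= 0; i-- {
		res.Mul(res, x)
		res.Add(res, coeffs[i])
		res.Mod(res, p)
	}
	return res
}

func main() {
	p := big.NewInt(97)
	claim := big.NewInt(10)                       // running claimed sum
	g := []*big.Int{big.NewInt(3), big.NewInt(4)} // round polynomial g(X) = 3 + 4X
	r := big.NewInt(7)                            // verifier challenge for this round

	// each round: g(0) + g(1) must equal the running claim ...
	sum := new(big.Int).Add(evalPoly(g, big.NewInt(0), p), evalPoly(g, big.NewInt(1), p))
	sum.Mod(sum, p)
	fmt.Println(sum.Cmp(claim) == 0) // true: 3 + 7 = 10

	// ... and the claim carried into the next round becomes g(r)
	fmt.Println(evalPoly(g, r, p)) // 31
}
```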
nbVars := claims.NbVars() - proof.RoundPolyEvaluations = make([]nativePolynomial, nbVars) + proof.RoundPolyEvaluations = make([]NativePolynomial, nbVars) // the first round in the sumcheck is without verifier challenge. Combine challenges and provers sends the first polynomial proof.RoundPolyEvaluations[0] = claims.Combine(combinationCoef) - challenges := make([]*big.Int, nbVars) // we iterate over all variables. However, we omit the last round as the // final evaluation is possibly deferred. for j := 0; j < nbVars-1; j++ { // compute challenge for the next round - if challenges[j], challengeNames, err = deriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[j]); err != nil { + if challenges[j], challengeNames, err = DeriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[j]); err != nil { return proof, fmt.Errorf("derive challenge: %w", err) } // compute the univariate polynomial with first j variables fixed. proof.RoundPolyEvaluations[j+1] = claims.Next(challenges[j]) + } - if challenges[nbVars-1], challengeNames, err = deriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[nbVars-1]); err != nil { + if challenges[nbVars-1], challengeNames, err = DeriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[nbVars-1]); err != nil { return proof, fmt.Errorf("derive challenge: %w", err) } if len(challengeNames) > 0 { return proof, fmt.Errorf("excessive challenges") } + proof.FinalEvalProof = claims.ProverFinalEval(challenges) return proof, nil diff --git a/std/recursion/sumcheck/scalarmul_gates_test.go b/std/recursion/sumcheck/scalarmul_gates.go similarity index 74% rename from std/recursion/sumcheck/scalarmul_gates_test.go rename to std/recursion/sumcheck/scalarmul_gates.go index 30ff77e1ad..f652253e1b 100644 --- a/std/recursion/sumcheck/scalarmul_gates_test.go +++ b/std/recursion/sumcheck/scalarmul_gates.go @@ -8,19 +8,21 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/profile" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/emulated/emparams" "github.com/consensys/gnark/std/math/polynomial" "github.com/consensys/gnark/test" ) -type projAddGate[AE arithEngine[E], E element] struct { - folding E +type ProjAddGate[AE ArithEngine[E], E element] struct { + Folding E } -func (m projAddGate[AE, E]) NbInputs() int { return 6 } -func (m projAddGate[AE, E]) Degree() int { return 4 } -func (m projAddGate[AE, E]) Evaluate(api AE, vars ...E) E { +func (m ProjAddGate[AE, E]) NbInputs() int { return 6 } +func (m ProjAddGate[AE, E]) Degree() int { return 4 } +func (m ProjAddGate[AE, E]) Evaluate(api AE, vars ...E) []E { if len(vars) != m.NbInputs() { panic("incorrect nb of inputs") } @@ -61,11 +63,11 @@ func (m projAddGate[AE, E]) Evaluate(api AE, vars ...E) E { Z3 = api.Mul(Z3, t4) Z3 = api.Add(Z3, t0) - res := api.Mul(m.folding, Z3) + res := api.Mul(m.Folding, Z3) res = api.Add(res, Y3) - res = api.Mul(m.folding, res) + res = api.Mul(m.Folding, res) res = api.Add(res, X3) - return res + return []E{res} } type ProjAddSumcheckCircuit[FR emulated.FieldParams] struct { @@ -102,7 +104,7 @@ func (c *ProjAddSumcheckCircuit[FR]) Define(api frontend.API) error { for i := range c.EvaluationPoints { evalPoints[i] = polynomial.FromSlice[FR](c.EvaluationPoints[i]) } - claim, err := newGate[FR](api, projAddGate[*emuEngine[FR], *emulated.Element[FR]]{f.NewElement(123)}, inputs, evalPoints, claimedEvals) + claim, err 
:= newGate[FR](api, ProjAddGate[*EmuEngine[FR], *emulated.Element[FR]]{f.NewElement(123)}, inputs, evalPoints, claimedEvals) if err != nil { return fmt.Errorf("new gate claim: %w", err) } @@ -114,7 +116,7 @@ func (c *ProjAddSumcheckCircuit[FR]) Define(api frontend.API) error { func testProjAddSumCheckInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, inputs [][]int) { var fr FR - nativeGate := projAddGate[*bigIntEngine, *big.Int]{folding: big.NewInt(123)} + nativeGate := ProjAddGate[*BigIntEngine, *big.Int]{Folding: big.NewInt(123)} assert := test.NewAssert(t) inputB := make([][]*big.Int, len(inputs)) for i := range inputB { @@ -126,7 +128,7 @@ func testProjAddSumCheckInstance[FR emulated.FieldParams](t *testing.T, current evalPointsB, evalPointsPH, evalPointsC := getChallengeEvaluationPoints[FR](inputB) claim, evals, err := newNativeGate(fr.Modulus(), nativeGate, inputB, evalPointsB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(inputs[0]))) - 1 circuit := &ProjAddSumcheckCircuit[FR]{ @@ -137,7 +139,7 @@ func testProjAddSumCheckInstance[FR emulated.FieldParams](t *testing.T, current } assignment := &ProjAddSumcheckCircuit[FR]{ Inputs: make([][]emulated.Element[FR], len(inputs)), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), EvaluationPoints: evalPointsC, Claimed: []emulated.Element[FR]{emulated.ValueOf[FR](evals[0])}, } @@ -168,12 +170,12 @@ func TestProjAddSumCheckSumcheck(t *testing.T) { testProjAddSumCheckInstance[emparams.BN254Fr](t, ecc.BN254.ScalarField(), inputs) } -type dblAddSelectGate[AE arithEngine[E], E element] struct { - folding []E +type DblAddSelectGate[AE ArithEngine[E], E element] struct { + Folding []E } -func projAdd[AE arithEngine[E], E element](api AE, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { - b3 := api.Const(big.NewInt(21)) +func ProjAdd[AE ArithEngine[E], E element](api AE, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { + b3 := api.Const(big.NewInt(9)) //todo hardcoded for bn254, b3 = 3*b t0 := api.Mul(X1, X2) t1 := api.Mul(Y1, Y2) t2 := api.Mul(Z1, Z2) @@ -210,7 +212,7 @@ func projAdd[AE arithEngine[E], E element](api AE, X1, Y1, Z1, X2, Y2, Z2 E) (X3 return } -func projSelect[AE arithEngine[E], E element](api AE, selector, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { +func ProjSelect[AE ArithEngine[E], E element](api AE, selector, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { X3 = api.Sub(X1, X2) X3 = api.Mul(selector, X3) X3 = api.Add(X3, X2) @@ -225,8 +227,8 @@ func projSelect[AE arithEngine[E], E element](api AE, selector, X1, Y1, Z1, X2, return } -func projDbl[AE arithEngine[E], E element](api AE, X, Y, Z E) (X3, Y3, Z3 E) { - b3 := api.Const(big.NewInt(21)) +func ProjDbl[AE ArithEngine[E], E element](api AE, X, Y, Z E) (X3, Y3, Z3 E) { + b3 := api.Const(big.NewInt(9)) //todo hardcoded for bn254, b3 = 3*b t0 := api.Mul(Y, Y) Z3 = api.Add(t0, t0) Z3 = api.Add(Z3, Z3) @@ -248,14 +250,15 @@ func projDbl[AE arithEngine[E], E element](api AE, X, Y, Z E) (X3, Y3, Z3 E) { return } -func (m dblAddSelectGate[AE, E]) NbInputs() int { return 7 } -func (m dblAddSelectGate[AE, E]) Degree() int { return 5 } -func (m dblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) E { +func (m DblAddSelectGate[AE, E]) NbInputs() int { return 7 } +func (m DblAddSelectGate[AE, E]) NbOutputs() int { return 1 } +func (m DblAddSelectGate[AE, E]) Degree() int { return 5 } +func (m DblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) []E 
{ if len(vars) != m.NbInputs() { panic("incorrect nb of inputs") } - if len(m.folding) != m.NbInputs()-1 { - panic("incorrect nb of folding vars") + if len(m.Folding) != m.NbInputs()-1 { + panic("incorrect nb of Folding vars") } // X1, Y1, Z1 == accumulator X1, Y1, Z1 := vars[0], vars[1], vars[2] @@ -263,29 +266,79 @@ func (m dblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) E { X2, Y2, Z2 := vars[3], vars[4], vars[5] selector := vars[6] - tmpX, tmpY, tmpZ := projAdd(api, X1, Y1, Z1, X2, Y2, Z2) - ResX, ResY, ResZ := projSelect(api, selector, tmpX, tmpY, tmpZ, X2, Y2, Z2) - AccX, AccY, AccZ := projDbl(api, X1, Y1, Z1) - - // folding part - f0 := api.Mul(m.folding[0], AccX) - f1 := api.Mul(m.folding[1], AccY) - f2 := api.Mul(m.folding[2], AccZ) - f3 := api.Mul(m.folding[3], ResX) - f4 := api.Mul(m.folding[4], ResY) - f5 := api.Mul(m.folding[5], ResZ) + tmpX, tmpY, tmpZ := ProjAdd(api, X1, Y1, Z1, X2, Y2, Z2) + ResX, ResY, ResZ := ProjSelect(api, selector, tmpX, tmpY, tmpZ, X2, Y2, Z2) + AccX, AccY, AccZ := ProjDbl(api, X1, Y1, Z1) + + // Folding part + f0 := api.Mul(m.Folding[0], AccX) + f1 := api.Mul(m.Folding[1], AccY) + f2 := api.Mul(m.Folding[2], AccZ) + f3 := api.Mul(m.Folding[3], ResX) + f4 := api.Mul(m.Folding[4], ResY) + f5 := api.Mul(m.Folding[5], ResZ) res := api.Add(f0, f1) res = api.Add(res, f2) res = api.Add(res, f3) res = api.Add(res, f4) res = api.Add(res, f5) - return res + return []E{res} +} + +type MultipleDblAddSelectGate[AE ArithEngine[E], E any] struct { + selector []E +} + +func (m MultipleDblAddSelectGate[AE, E]) NbInputs() int { return 7 } +func (m MultipleDblAddSelectGate[AE, E]) Degree() int { return 5 } +func (m MultipleDblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) []E { + if len(vars) != m.NbInputs() { + panic("incorrect nb of inputs") + } + // X1, Y1, Z1 == accumulator + X1, Y1, Z1 := vars[0], vars[1], vars[2] + // X2, Y2, Z2 == result + X2, Y2, Z2 := vars[3], vars[4], vars[5] + selector := vars[6] + + tmpX, tmpY, tmpZ := ProjAdd(api, X1, Y1, Z1, X2, Y2, Z2) + ResX, ResY, ResZ := ProjSelect(api, selector, tmpX, tmpY, tmpZ, X2, Y2, Z2) + AccX, AccY, AccZ := ProjDbl(api, X1, Y1, Z1) + + return []E{AccX, AccY, AccZ, ResX, ResY, ResZ} +} + +type DblAddSelectGateFullOutput[AE ArithEngine[E], E any] struct { + Selector E +} + +func (m DblAddSelectGateFullOutput[AE, E]) NbInputs() int { return 6 } +func (m DblAddSelectGateFullOutput[AE, E]) NbOutputs() int { return 6 } +func (m DblAddSelectGateFullOutput[AE, E]) Degree() int { return 5 } +func (m DblAddSelectGateFullOutput[AE, E]) GetName() string { + return "dbl_add_select_full_output" +} +func (m DblAddSelectGateFullOutput[AE, E]) Evaluate(api AE, vars ...E) []E { + if len(vars) != m.NbInputs() { + panic("incorrect nb of inputs") + } + // X1, Y1, Z1 == accumulator + X1, Y1, Z1 := vars[0], vars[1], vars[2] + // X2, Y2, Z2 == result + X2, Y2, Z2 := vars[3], vars[4], vars[5] + selector := m.Selector //vars[6] + + tmpX, tmpY, tmpZ := ProjAdd(api, X1, Y1, Z1, X2, Y2, Z2) + ResX, ResY, ResZ := ProjSelect(api, selector, tmpX, tmpY, tmpZ, X2, Y2, Z2) + AccX, AccY, AccZ := ProjDbl(api, X1, Y1, Z1) + + return []E{AccX, AccY, AccZ, ResX, ResY, ResZ} } func TestDblAndAddGate(t *testing.T) { assert := test.NewAssert(t) - nativeGate := dblAddSelectGate[*bigIntEngine, *big.Int]{folding: []*big.Int{ + nativeGate := DblAddSelectGate[*BigIntEngine, *big.Int]{Folding: []*big.Int{ big.NewInt(1), big.NewInt(2), big.NewInt(3), @@ -299,7 +352,7 @@ func TestDblAndAddGate(t *testing.T) { assert.True(ok) secpfp, ok := 
new(big.Int).SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) assert.True(ok) - eng := newBigIntEngine(secpfp) + eng := NewBigIntEngine(secpfp) res := nativeGate.Evaluate(eng, px, py, big.NewInt(1), big.NewInt(0), big.NewInt(1), big.NewInt(0), big.NewInt(1)) t.Log(res) _ = res @@ -339,9 +392,9 @@ func (c *ProjDblAddSelectSumcheckCircuit[FR]) Define(api frontend.API) error { for i := range c.EvaluationPoints { evalPoints[i] = polynomial.FromSlice[FR](c.EvaluationPoints[i]) } - claim, err := newGate[FR](api, dblAddSelectGate[*emuEngine[FR], + claim, err := newGate[FR](api, DblAddSelectGate[*EmuEngine[FR], *emulated.Element[FR]]{ - folding: []*emulated.Element[FR]{ + Folding: []*emulated.Element[FR]{ f.NewElement(1), f.NewElement(2), f.NewElement(3), @@ -361,7 +414,7 @@ func (c *ProjDblAddSelectSumcheckCircuit[FR]) Define(api frontend.API) error { func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, inputs [][]int) { var fr FR - nativeGate := dblAddSelectGate[*bigIntEngine, *big.Int]{folding: []*big.Int{ + nativeGate := DblAddSelectGate[*BigIntEngine, *big.Int]{Folding: []*big.Int{ big.NewInt(1), big.NewInt(2), big.NewInt(3), @@ -377,10 +430,11 @@ func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, inputB[i][j] = big.NewInt(int64(inputs[i][j])) } } + evalPointsB, evalPointsPH, evalPointsC := getChallengeEvaluationPoints[FR](inputB) claim, evals, err := newNativeGate(fr.Modulus(), nativeGate, inputB, evalPointsB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(inputs[0]))) - 1 circuit := &ProjDblAddSelectSumcheckCircuit[FR]{ @@ -391,7 +445,7 @@ func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, } assignment := &ProjDblAddSelectSumcheckCircuit[FR]{ Inputs: make([][]emulated.Element[FR], len(inputs)), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), EvaluationPoints: evalPointsC, Claimed: []emulated.Element[FR]{emulated.ValueOf[FR](evals[0])}, } @@ -404,9 +458,14 @@ func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, } err = test.IsSolved(circuit, assignment, current) assert.NoError(err) + p := profile.Start() + _, _ = frontend.Compile(current, scs.NewBuilder, circuit) + p.Stop() + fmt.Println(p.NbConstraints()) } -func TestProjDblAddSelectSumCheckSumcheck(t *testing.T) { +//todo used this as Flattened SC benchmarks +func TestProjDblAddSelectSumCheck(t *testing.T) { // testProjDblAddSelectSumCheckInstance[emparams.BN254Fr](t, ecc.BN254.ScalarField(), [][]int{{4, 3}, {2, 3}, {3, 6}, {4, 9}, {13, 3}, {31, 9}}) // testProjDblAddSelectSumCheckInstance[emparams.BN254Fr](t, ecc.BN254.ScalarField(), [][]int{{1, 2, 3, 4}, {5, 6, 7, 8}}) // testProjDblAddSelectSumCheckInstance[emparams.BN254Fr](t, ecc.BN254.ScalarField(), [][]int{{1, 2, 3, 4, 5, 6, 7, 8}, {11, 12, 13, 14, 15, 16, 17, 18}}) @@ -420,5 +479,5 @@ func TestProjDblAddSelectSumCheckSumcheck(t *testing.T) { inputs[5] = append(inputs[5], (inputs[4][i-1]+5)*4) inputs[6] = append(inputs[6], (inputs[5][i-1]+6)*3) } - testProjDblAddSelectSumCheckInstance[emparams.BN254Fr](t, ecc.BN254.ScalarField(), inputs) + testProjDblAddSelectSumCheckInstance[emparams.BN254Fp](t, ecc.BN254.ScalarField(), inputs) } diff --git a/std/recursion/sumcheck/sumcheck_test.go b/std/recursion/sumcheck/sumcheck.go similarity index 93% rename from 
std/recursion/sumcheck/sumcheck_test.go rename to std/recursion/sumcheck/sumcheck.go index 1127e46e88..6dd611c02f 100644 --- a/std/recursion/sumcheck/sumcheck_test.go +++ b/std/recursion/sumcheck/sumcheck.go @@ -46,7 +46,7 @@ func testMultilinearSumcheckInstance[FR emulated.FieldParams](t *testing.T, curr claim, value, err := newNativeMultilinearClaim(fr.Modulus(), mleB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(mle))) - 1 circuit := &MultilinearSumcheckCircuit[FR]{ @@ -56,7 +56,7 @@ func testMultilinearSumcheckInstance[FR emulated.FieldParams](t *testing.T, curr assignment := &MultilinearSumcheckCircuit[FR]{ Function: polynomial.ValueOfMultilinear[FR](mleB), Claim: emulated.ValueOf[FR](value), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), } err = test.IsSolved(circuit, assignment, current) assert.NoError(err) @@ -92,15 +92,15 @@ func getChallengeEvaluationPoints[FR emulated.FieldParams](inputs [][]*big.Int) return } -type mulGate1[AE arithEngine[E], E element] struct{} +type mulGate1[AE ArithEngine[E], E element] struct{} func (m mulGate1[AE, E]) NbInputs() int { return 2 } func (m mulGate1[AE, E]) Degree() int { return 2 } -func (m mulGate1[AE, E]) Evaluate(api AE, vars ...E) E { +func (m mulGate1[AE, E]) Evaluate(api AE, vars ...E) []E { if len(vars) != m.NbInputs() { panic("incorrect nb of inputs") } - return api.Mul(vars[0], vars[1]) + return []E{api.Mul(vars[0], vars[1])} } type MulGateSumcheck[FR emulated.FieldParams] struct { @@ -133,7 +133,7 @@ func (c *MulGateSumcheck[FR]) Define(api frontend.API) error { for i := range c.EvaluationPoints { evalPoints[i] = polynomial.FromSlice[FR](c.EvaluationPoints[i]) } - claim, err := newGate[FR](api, mulGate1[*emuEngine[FR], *emulated.Element[FR]]{}, inputs, evalPoints, claimedEvals) + claim, err := newGate[FR](api, mulGate1[*EmuEngine[FR], *emulated.Element[FR]]{}, inputs, evalPoints, claimedEvals) if err != nil { return fmt.Errorf("new gate claim: %w", err) } @@ -145,7 +145,7 @@ func (c *MulGateSumcheck[FR]) Define(api frontend.API) error { func testMulGate1SumcheckInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, inputs [][]int) { var fr FR - var nativeGate mulGate1[*bigIntEngine, *big.Int] + var nativeGate mulGate1[*BigIntEngine, *big.Int] assert := test.NewAssert(t) inputB := make([][]*big.Int, len(inputs)) for i := range inputB { @@ -157,7 +157,7 @@ func testMulGate1SumcheckInstance[FR emulated.FieldParams](t *testing.T, current evalPointsB, evalPointsPH, evalPointsC := getChallengeEvaluationPoints[FR](inputB) claim, evals, err := newNativeGate(fr.Modulus(), nativeGate, inputB, evalPointsB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(inputs[0]))) - 1 circuit := &MulGateSumcheck[FR]{ @@ -168,7 +168,7 @@ func testMulGate1SumcheckInstance[FR emulated.FieldParams](t *testing.T, current } assignment := &MulGateSumcheck[FR]{ Inputs: make([][]emulated.Element[FR], len(inputs)), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), EvaluationPoints: evalPointsC, Claimed: []emulated.Element[FR]{emulated.ValueOf[FR](evals[0])}, } diff --git a/std/recursion/sumcheck/test_vectors/mimc_five_levels_two_instances._json b/std/recursion/sumcheck/test_vectors/mimc_five_levels_two_instances._json new file mode 100644 index 0000000000..446d23fdb2 --- 
/dev/null +++ b/std/recursion/sumcheck/test_vectors/mimc_five_levels_two_instances._json @@ -0,0 +1,7 @@ +{ + "hash": {"type": "const", "val": -1}, + "circuit": "resources/mimc_five_levels.json", + "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], + "output": [[4, 3]], + "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/mimc_five_levels.json b/std/recursion/sumcheck/test_vectors/resources/mimc_five_levels.json new file mode 100644 index 0000000000..3dd74f42b5 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_identity_gate.json b/std/recursion/sumcheck/test_vectors/resources/single_identity_gate.json new file mode 100644 index 0000000000..a44066c7b4 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_input_two_identity_gates.json b/std/recursion/sumcheck/test_vectors/resources/single_input_two_identity_gates.json new file mode 100644 index 0000000000..6181784fa8 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_input_two_outs.json b/std/recursion/sumcheck/test_vectors/resources/single_input_two_outs.json new file mode 100644 index 0000000000..c577c1cace --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_mimc_gate.json b/std/recursion/sumcheck/test_vectors/resources/single_mimc_gate.json new file mode 100644 index 0000000000..c89e7d52ae --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_mul_gate.json b/std/recursion/sumcheck/test_vectors/resources/single_mul_gate.json new file mode 100644 index 0000000000..0f65a07edf --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": 
"mul", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/two_identity_gates_composed_single_input.json b/std/recursion/sumcheck/test_vectors/resources/two_identity_gates_composed_single_input.json new file mode 100644 index 0000000000..26681c2f89 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/two_inputs_select-input-3_gate.json b/std/recursion/sumcheck/test_vectors/resources/two_inputs_select-input-3_gate.json new file mode 100644 index 0000000000..cdbdb3b471 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_identity_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/single_identity_gate_two_instances.json new file mode 100644 index 0000000000..ce326d0a63 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_identity_gate_two_instances.json @@ -0,0 +1,36 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5 + ], + "partialSumPolys": [ + [ + -3, + -8 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_input_two_identity_gates_two_instances.json b/std/recursion/sumcheck/test_vectors/single_input_two_identity_gates_two_instances.json new file mode 100644 index 0000000000..2c95f044f2 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,56 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_input_two_outs_two_instances.json b/std/recursion/sumcheck/test_vectors/single_input_two_outs_two_instances.json new file mode 100644 index 0000000000..d348303d0e --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_input_two_outs_two_instances.json @@ -0,0 +1,57 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -4, + -36, + -112 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -2, + -12 + ] + ] + } + ] +} \ No newline at end of file diff --git 
a/std/recursion/sumcheck/test_vectors/single_mimc_gate_four_instances.json b/std/recursion/sumcheck/test_vectors/single_mimc_gate_four_instances.json new file mode 100644 index 0000000000..525459ecb1 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_mimc_gate_four_instances.json @@ -0,0 +1,67 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1, + 2, + 1 + ], + [ + 1, + 2, + 2, + 1 + ] + ], + "output": [ + [ + 128, + 2187, + 16384, + 128 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + -3 + ], + "partialSumPolys": [ + [ + -32640, + -2239484, + -29360128, + "-200000010", + "-931628672", + "-3373267120", + "-10200858624", + "-26939400158" + ], + [ + -81920, + -41943040, + "-1254113280", + "-13421772800", + "-83200000000", + "-366917713920", + "-1281828208640", + "-3779571220480" + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_mimc_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/single_mimc_gate_two_instances.json new file mode 100644 index 0000000000..7fa23ce4b1 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_mimc_gate_two_instances.json @@ -0,0 +1,51 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 1, + 0 + ], + "partialSumPolys": [ + [ + -2187, + -65536, + -546875, + -2799360, + -10706059, + -33554432, + -90876411, + "-220000000" + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_mul_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/single_mul_gate_two_instances.json new file mode 100644 index 0000000000..75c1d59c3d --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_mul_gate_two_instances.json @@ -0,0 +1,46 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5, + 1 + ], + "partialSumPolys": [ + [ + -9, + -32, + -35 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/std/recursion/sumcheck/test_vectors/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 0000000000..10e5f1ff3c --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,47 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + } + ] +} \ No newline at end of file diff --git 
a/std/recursion/sumcheck/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 0000000000..19e127df71 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,45 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/verifier.go b/std/recursion/sumcheck/verifier.go index 6674453ea8..fb523e4f65 100644 --- a/std/recursion/sumcheck/verifier.go +++ b/std/recursion/sumcheck/verifier.go @@ -9,25 +9,25 @@ import ( "github.com/consensys/gnark/std/recursion" ) -type config struct { +type Config struct { prefix string } // Option allows to alter the sumcheck verifier behaviour. -type Option func(c *config) error +type Option func(c *Config) error // WithClaimPrefix prepends the given string to the challenge names when // computing the challenges inside the sumcheck verifier. The option is used in // a higher level protocols to ensure that sumcheck claims are not interchanged. func WithClaimPrefix(prefix string) Option { - return func(c *config) error { + return func(c *Config) error { c.prefix = prefix return nil } } -func newConfig(opts ...Option) (*config, error) { - cfg := new(config) +func NewConfig(opts ...Option) (*Config, error) { + cfg := new(Config) for i := range opts { if err := opts[i](cfg); err != nil { return nil, fmt.Errorf("apply option %d: %w", i, err) @@ -37,7 +37,7 @@ func newConfig(opts ...Option) (*config, error) { } type verifyCfg[FR emulated.FieldParams] struct { - baseChallenges []emulated.Element[FR] + BaseChallenges []emulated.Element[FR] } // VerifyOption allows to alter the behaviour of the single sumcheck proof verification. @@ -48,13 +48,13 @@ type VerifyOption[FR emulated.FieldParams] func(c *verifyCfg[FR]) error func WithBaseChallenges[FR emulated.FieldParams](baseChallenges []*emulated.Element[FR]) VerifyOption[FR] { return func(c *verifyCfg[FR]) error { for i := range baseChallenges { - c.baseChallenges = append(c.baseChallenges, *baseChallenges[i]) + c.BaseChallenges = append(c.BaseChallenges, *baseChallenges[i]) } return nil } } -func newVerificationConfig[FR emulated.FieldParams](opts ...VerifyOption[FR]) (*verifyCfg[FR], error) { +func NewVerificationConfig[FR emulated.FieldParams](opts ...VerifyOption[FR]) (*verifyCfg[FR], error) { cfg := new(verifyCfg[FR]) for i := range opts { if err := opts[i](cfg); err != nil { @@ -69,14 +69,14 @@ type Verifier[FR emulated.FieldParams] struct { api frontend.API f *emulated.Field[FR] p *polynomial.Polynomial[FR] - *config + *Config } // NewVerifier initializes a new sumcheck verifier for the parametric emulated // field FR. It returns an error if the given options are invalid or when // initializing emulated arithmetic fails. func NewVerifier[FR emulated.FieldParams](api frontend.API, opts ...Option) (*Verifier[FR], error) { - cfg, err := newConfig(opts...) + cfg, err := NewConfig(opts...) 
if err != nil { return nil, fmt.Errorf("new configuration: %w", err) } @@ -92,41 +92,41 @@ func NewVerifier[FR emulated.FieldParams](api frontend.API, opts ...Option) (*Ve api: api, f: f, p: p, - config: cfg, + Config: cfg, }, nil } // Verify verifies the sumcheck proof for the given (lazy) claims. func (v *Verifier[FR]) Verify(claims LazyClaims[FR], proof Proof[FR], opts ...VerifyOption[FR]) error { var fr FR - cfg, err := newVerificationConfig(opts...) + cfg, err := NewVerificationConfig(opts...) if err != nil { return fmt.Errorf("verification opts: %w", err) } - challengeNames := getChallengeNames(v.prefix, claims.NbClaims(), claims.NbVars()) + challengeNames := getChallengeNames(v.prefix, 1, claims.NbVars()) //claims.NbClaims() todo change this fs, err := recursion.NewTranscript(v.api, fr.Modulus(), challengeNames) if err != nil { return fmt.Errorf("new transcript: %w", err) } // bind challenge from previous round if it is a continuation - if err = v.bindChallenge(fs, challengeNames[0], cfg.baseChallenges); err != nil { + if err = v.bindChallenge(fs, challengeNames[0], cfg.BaseChallenges); err != nil { return fmt.Errorf("base: %w", err) } - + nbVars := claims.NbVars() combinationCoef := v.f.Zero() - if claims.NbClaims() >= 2 { - if combinationCoef, challengeNames, err = v.deriveChallenge(fs, challengeNames, nil); err != nil { - return fmt.Errorf("derive combination coef: %w", err) - } - } - challenges := make([]*emulated.Element[FR], claims.NbVars()) + // if claims.NbClaims() >= 2 { // todo change this to handle multiple claims per wire - assuming single claim per wire so don't need to combine + // if combinationCoef, challengeNames, err = v.deriveChallenge(fs, challengeNames, nil); err != nil { + // return fmt.Errorf("derive combination coef: %w", err) + // } + // } + challenges := make([]*emulated.Element[FR], nbVars) //claims.NbVars() // gJR is the claimed value. In case of multiple claims it is combined // claimed value we're going to check against. gJR := claims.CombinedSum(combinationCoef) // sumcheck rounds - for j := 0; j < claims.NbVars(); j++ { + for j := 0; j < nbVars; j++ { // instead of sending the polynomials themselves, the provers sends n evaluations of the round polynomial: // // g_j(X_j) = \sum_{x_{j+1},...\x_k \in {0,1}} g(r_1, ..., r_{j-1}, X_j, x_{j+1}, ...) @@ -141,12 +141,14 @@ func (v *Verifier[FR]) Verify(claims LazyClaims[FR], proof Proof[FR], opts ...Ve if len(evals) != degree { return fmt.Errorf("expected len %d, got %d", degree, len(evals)) } - // computes g_{j-1}(r) - g_j(1) as missing evaluation + gj0 := v.f.Sub(gJR, &evals[0]) + // construct the n+1 evaluations for interpolation gJ := []*emulated.Element[FR]{gj0} for i := range evals { gJ = append(gJ, &evals[i]) + } // we derive the challenge from prover message. @@ -159,6 +161,7 @@ func (v *Verifier[FR]) Verify(claims LazyClaims[FR], proof Proof[FR], opts ...Ve // directly. gJR = v.p.InterpolateLDE(challenges[j], gJ) + // we do not directly need to check gJR now - as in the next round we // compute new evaluation point from gJR then the check is performed // implicitly. diff --git a/std/sumcheck/sumcheck.go b/std/sumcheck/sumcheck.go index de3689cbb8..9a278bf7e7 100644 --- a/std/sumcheck/sumcheck.go +++ b/std/sumcheck/sumcheck.go @@ -20,8 +20,8 @@ type LazyClaims interface { // Proof of a multi-sumcheck statement. 
type Proof struct { - PartialSumPolys []polynomial.Polynomial - FinalEvalProof interface{} + RoundPolyEvaluations []polynomial.Polynomial + FinalEvalProof interface{} } func setupTranscript(api frontend.API, claimsNum int, varsNum int, settings *fiatshamir.Settings) ([]string, error) { @@ -83,18 +83,17 @@ func Verify(api frontend.API, claims LazyClaims, proof Proof, transcriptSettings gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(api, combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { - partialSumPoly := proof.PartialSumPolys[j] //proof.PartialSumPolys(j) - if len(partialSumPoly) != claims.Degree(j) { + roundPolyEvaluation := proof.RoundPolyEvaluations[j] //proof.RoundPolyEvaluations(j) + if len(roundPolyEvaluation) != claims.Degree(j) { return fmt.Errorf("malformed proof") //Malformed proof } - copy(gJ[1:], partialSumPoly) - gJ[0] = api.Sub(gJR, partialSumPoly[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + copy(gJ[1:], roundPolyEvaluation) + gJ[0] = api.Sub(gJR, roundPolyEvaluation[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) // gJ is ready //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + if r[j], err = next(transcript, proof.RoundPolyEvaluations[j], &remainingChallengeNames); err != nil { return err }
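
For reference, the "Folding part" of the dbl-add-select gate hunk above (the f0..f5 products followed by the Add chain) is a linear combination of the six projective outputs (AccX, AccY, AccZ, ResX, ResY, ResZ) with the gate's Folding coefficients, so a multi-output gate can still feed a single sumcheck claim. Below is a minimal, self-contained sketch of that step over plain big.Int values; it deliberately sits outside the gnark ArithEngine API, and foldOutputs plus the toy modulus are illustrative names, not part of the patch.

```go
package main

import (
	"fmt"
	"math/big"
)

// foldOutputs compresses the gate outputs into a single field element as
// sum_i folding[i] * outputs[i] (mod p), mirroring the f0..f5/Add chain in
// the DblAddSelectGate Evaluate method.
func foldOutputs(p *big.Int, folding, outputs []*big.Int) (*big.Int, error) {
	if len(folding) != len(outputs) {
		return nil, fmt.Errorf("expected %d folding coefficients, got %d", len(outputs), len(folding))
	}
	res := new(big.Int)
	tmp := new(big.Int)
	for i := range outputs {
		tmp.Mul(folding[i], outputs[i])
		res.Add(res, tmp)
	}
	return res.Mod(res, p), nil
}

func main() {
	// toy modulus and values; the real gate works over the secp256k1 base field
	p := big.NewInt(97)
	folding := []*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), big.NewInt(5), big.NewInt(6)}
	outputs := []*big.Int{big.NewInt(7), big.NewInt(8), big.NewInt(9), big.NewInt(10), big.NewInt(11), big.NewInt(12)}
	folded, err := foldOutputs(p, folding, outputs)
	if err != nil {
		panic(err)
	}
	fmt.Println(folded) // (1*7 + 2*8 + 3*9 + 4*10 + 5*11 + 6*12) mod 97 = 23
}
```

The full-output variants (MultipleDblAddSelectGate, DblAddSelectGateFullOutput) skip this compression and return the six coordinates directly, leaving any combination to the caller.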
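
The JSON fixtures added under test_vectors all share one shape: a constant-hash descriptor, a path to a circuit layout under resources/, per-wire input and output assignments, and one {partialSumPolys, finalEvalProof} entry per proof round, with large values quoted as strings to avoid precision loss. The sketch below shows how such a fixture could be decoded in Go; the struct and field names are inferred from the JSON in the patch, not taken from the package's own loader.

```go
package main

import (
	"encoding/json"
	"fmt"
	"math/big"
	"strings"
)

// BigIntJSON accepts both bare JSON numbers and quoted decimal strings,
// which the fixtures mix.
type BigIntJSON struct{ big.Int }

func (b *BigIntJSON) UnmarshalJSON(data []byte) error {
	s := strings.Trim(string(data), `"`)
	if _, ok := b.SetString(s, 10); !ok {
		return fmt.Errorf("invalid integer %q", s)
	}
	return nil
}

type HashDesc struct {
	Type string `json:"type"`
	Val  int64  `json:"val"`
}

type RoundProof struct {
	PartialSumPolys [][]BigIntJSON `json:"partialSumPolys"`
	FinalEvalProof  []BigIntJSON   `json:"finalEvalProof"`
}

type TestCase struct {
	Hash    HashDesc       `json:"hash"`
	Circuit string         `json:"circuit"`
	Input   [][]BigIntJSON `json:"input"`
	Output  [][]BigIntJSON `json:"output"`
	Proof   []RoundProof   `json:"proof"`
}

func main() {
	raw := `{
	  "hash": {"type": "const", "val": -1},
	  "circuit": "resources/single_mul_gate.json",
	  "input": [[4, 3], [2, 3]],
	  "output": [[8, 9]],
	  "proof": [
	    {"finalEvalProof": [], "partialSumPolys": []},
	    {"finalEvalProof": [], "partialSumPolys": []},
	    {"finalEvalProof": [5, 1], "partialSumPolys": [[-9, -32, -35]]}
	  ]
	}`
	var tc TestCase
	if err := json.Unmarshal([]byte(raw), &tc); err != nil {
		panic(err)
	}
	fmt.Println(tc.Circuit, len(tc.Proof), tc.Proof[2].FinalEvalProof[0].String())
	// resources/single_mul_gate.json 3 5
}
```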
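
Both verifier hunks rely on the same sumcheck round identity: the prover sends the round polynomial g_j evaluated at 1..d (the field renamed to RoundPolyEvaluations), the verifier recovers g_j(0) from g_j(0) + g_j(1) = g_{j-1}(r_{j-1}), i.e. from the previous claim, and the next claim is g_j(r_j) obtained by interpolating over the points 0..d at the freshly derived challenge. That is what gj0 := v.f.Sub(gJR, &evals[0]) together with v.p.InterpolateLDE computes in the emulated verifier, and what gJ[0] = api.Sub(gJR, roundPolyEvaluation[0]) computes in std/sumcheck. The following plain-Go sketch performs that arithmetic over a toy prime field; nextClaim is an illustrative helper, not the package API.

```go
package main

import (
	"fmt"
	"math/big"
)

// nextClaim runs one verifier round: recover g_j(0) from the previous claim,
// then evaluate g_j at the challenge r by Lagrange interpolation over 0..d.
// evals holds g_j(1), ..., g_j(d) as sent by the prover.
func nextClaim(p, prevClaim, r *big.Int, evals []*big.Int) *big.Int {
	d := len(evals) // degree of the round polynomial
	// values of g_j at x = 0, 1, ..., d
	ys := make([]*big.Int, d+1)
	ys[0] = new(big.Int).Sub(prevClaim, evals[0]) // g_j(0) = prevClaim - g_j(1)
	ys[0].Mod(ys[0], p)
	copy(ys[1:], evals)

	// Lagrange interpolation of g_j at r over the nodes 0..d
	res := new(big.Int)
	for i := 0; i <= d; i++ {
		num := big.NewInt(1)
		den := big.NewInt(1)
		for k := 0; k <= d; k++ {
			if k == i {
				continue
			}
			num.Mul(num, new(big.Int).Sub(r, big.NewInt(int64(k)))).Mod(num, p)
			den.Mul(den, big.NewInt(int64(i-k))).Mod(den, p)
		}
		term := new(big.Int).Mul(ys[i], num)
		term.Mul(term, new(big.Int).ModInverse(den, p))
		res.Add(res, term).Mod(res, p)
	}
	return res
}

func main() {
	p := big.NewInt(101)
	// toy round polynomial g_j(x) = 3x + 2 over F_101, so g_j(0) + g_j(1) = 2 + 5 = 7
	prevClaim := big.NewInt(7)
	evals := []*big.Int{big.NewInt(5)} // g_j(1)
	r := big.NewInt(10)
	fmt.Println(nextClaim(p, prevClaim, r, evals)) // 3*10 + 2 = 32
}
```

On this toy input the recovered g_j(0) is 2 and the next claim at r = 10 is 32, matching direct evaluation of g_j; the in-circuit code checks nothing explicitly per round because, as the retained comment says, the final evaluation check catches any inconsistency.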