Skip to content

Commit

Permalink
added the test for LogDerivativeSum
Browse files Browse the repository at this point in the history
  • Loading branch information
Soleimani193 committed Dec 16, 2024
1 parent 02d74d8 commit 157f641
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 2 deletions.
13 changes: 13 additions & 0 deletions prover/protocol/query/gnark_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ func (p LocalOpeningParams) GnarkAssign() GnarkLocalOpeningParams {
return GnarkLocalOpeningParams{Y: p.Y}
}

type GnarkLogDerivSumParams struct {
Y frontend.Variable
}

func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams {
return GnarkLogDerivSumParams{Y: p.Sum}
}

// A gnark circuit version of InnerProductParams
type GnarkInnerProductParams struct {
Ys []frontend.Variable
Expand Down Expand Up @@ -54,6 +62,11 @@ func (p GnarkLocalOpeningParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) {
fs.Update(p.Y)
}

// Update the fiat-shamir state with the the present parameters
func (p GnarkLogDerivSumParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) {
fs.Update(p.Y)
}

// Update the fiat-shamir state with the the present parameters
func (p GnarkUnivariateEvalParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) {
fs.Update(p.Ys...)
Expand Down
11 changes: 9 additions & 2 deletions prover/protocol/query/logderiv_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type LogDerivativeSumInput struct {

// LogDerivativeSum is the context of LogDerivativeSum query.
// The fields are maps from [round, size].
// the aim of the query is to compute:
// \sum_{i,j} N_{i,j}/D_{i,j} where
// N_{i,j} is the i-th element of the underlying column of j-th Numerator
// D_{i,j} is the i-th element of the underlying column of j-th Denominator
type LogDerivativeSum struct {
Inputs map[[2]int]*LogDerivativeSumInput

Expand Down Expand Up @@ -78,7 +82,7 @@ func (r LogDerivativeSum) Name() ifaces.QueryID {
}

// Constructor for the query parameters/result
func NewLogDeriveSumParams(sum field.Element) LogDerivSumParams {
func NewLogDerivSumParams(sum field.Element) LogDerivSumParams {
return LogDerivSumParams{Sum: sum}
}

Expand Down Expand Up @@ -126,7 +130,10 @@ func (r LogDerivativeSum) Check(run ifaces.Runtime) error {

// Test that global sum is correct
func (r LogDerivativeSum) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) {

/*params := run.GetParams(r.ID).(GnarkLogDerivSumParams)
actualY := TBD
api.AssertIsEqual(params.Y, actualY)
*/
}

func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvectors.SmartVector {
Expand Down
71 changes: 71 additions & 0 deletions prover/protocol/query/logderiv_sum_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package query_test

import (
"testing"

"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/stretchr/testify/require"
)

func TestLogDerivSum(t *testing.T) {

define := func(b *wizard.Builder) {
var (
comp = b.CompiledIOP
)

p0 := b.RegisterCommit("Num_0", 4)
p1 := b.RegisterCommit("Num_1", 4)
p2 := b.RegisterCommit("Num_2", 4)

q0 := b.RegisterCommit("Den_0", 4)
q1 := b.RegisterCommit("Den_1", 4)
q2 := b.RegisterCommit("Den_2", 4)

numerators := []*symbolic.Expression{
symbolic.Mul(p0, -1),
ifaces.ColumnAsVariable(p1),
symbolic.Mul(p2, p0, 2),
}

denominators := []*symbolic.Expression{
ifaces.ColumnAsVariable(q0),
ifaces.ColumnAsVariable(q1),
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 0}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Numerator: numerators,
Denominator: denominators,
}
comp.InsertLogDerivativeSum(0, "LogDerivSum_Test", zCat1)

}

prover := func(run *wizard.ProverRuntime) {

run.AssignColumn("Num_0", smartvectors.ForTest(1, 1, 1, 1))
run.AssignColumn("Num_1", smartvectors.ForTest(2, 3, 7, 9))
run.AssignColumn("Num_2", smartvectors.ForTest(5, 6, 1, 1))

run.AssignColumn("Den_0", smartvectors.ForTest(1, 1, 1, 1))
run.AssignColumn("Den_1", smartvectors.ForTest(2, 3, 7, 9))
run.AssignColumn("Den_2", smartvectors.ForTest(5, 6, 1, 1))

run.AssignLogDerivSum("LogDerivSum_Test", field.NewElement(8))

}

compiled := wizard.Compile(define, dummy.Compile)
proof := wizard.Prove(compiled, prover)
valid := wizard.Verify(compiled, proof)
require.NoError(t, valid)
}
20 changes: 20 additions & 0 deletions prover/protocol/wizard/gnark_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type WizardVerifierCircuit struct {
innerProductIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"`
// Same for local-opening query
localOpeningIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"`
// Same for logDerivativeSum query
logDerivSumIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"`

// Columns stores the gnark witness part corresponding to the columns
// provided in the proof and in the VerifyingKey.
Expand All @@ -63,6 +65,9 @@ type WizardVerifierCircuit struct {
// LocalOpeningParams stores an assignment for each [query.LocalOpeningParams]
// from the proof. It is part of the witness of the gnark circuit.
LocalOpeningParams []query.GnarkLocalOpeningParams `gnark:",secret"`
// LogDerivSumParams stores an assignment for each [query.LogDerivSumParams]
// from the proof. It is part of the witness of the gnark circuit.
LogDerivSumParams []query.GnarkLogDerivSumParams `gnark:",secret"`

// FS is the Fiat-Shamir state, mirroring [VerifierRuntime.FS]. The same
// cautionnary rules apply to it; e.g. don't use it externally when
Expand Down Expand Up @@ -311,6 +316,14 @@ func (c *WizardVerifierCircuit) GetLocalPointEvalParams(name ifaces.QueryID) que
return c.LocalOpeningParams[qID]
}

// GetLogDerivSumParams returns the parameters for the requested
// [query.LogDerivativeSum] query. Its work mirrors the function
// [VerifierRuntime.GetLogDerivSumParams]
func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLocalOpeningParams {
qID := c.localOpeningIDs.MustGet(name)
return c.LocalOpeningParams[qID]
}

// GetColumns returns the gnark assignment of a column in a gnark circuit. It
// mirrors the function [VerifierRuntime.GetColumn]
func (c *WizardVerifierCircuit) GetColumn(name ifaces.ColID) []frontend.Variable {
Expand Down Expand Up @@ -350,11 +363,13 @@ func newWizardVerifierCircuit() *WizardVerifierCircuit {
res.columnsIDs = collection.NewMapping[ifaces.ColID, int]()
res.univariateParamsIDs = collection.NewMapping[ifaces.QueryID, int]()
res.localOpeningIDs = collection.NewMapping[ifaces.QueryID, int]()
res.logDerivSumIDs = collection.NewMapping[ifaces.QueryID, int]()
res.innerProductIDs = collection.NewMapping[ifaces.QueryID, int]()
res.Columns = [][]frontend.Variable{}
res.UnivariateParams = make([]query.GnarkUnivariateEvalParams, 0)
res.InnerProductParams = make([]query.GnarkInnerProductParams, 0)
res.LocalOpeningParams = make([]query.GnarkLocalOpeningParams, 0)
res.LogDerivSumParams = make([]query.GnarkLogDerivSumParams, 0)
res.Coins = collection.NewMapping[coin.Name, interface{}]()
return res
}
Expand Down Expand Up @@ -418,6 +433,9 @@ func GetWizardVerifierCircuitAssignment(comp *CompiledIOP, proof Proof) *WizardV
case query.LocalOpeningParams:
res.localOpeningIDs.InsertNew(qName, len(res.LocalOpeningParams))
res.LocalOpeningParams = append(res.LocalOpeningParams, params.GnarkAssign())
case query.LogDerivSumParams:
res.logDerivSumIDs.InsertNew(qName, len(res.LogDerivSumParams))
res.LogDerivSumParams = append(res.LogDerivSumParams, params.GnarkAssign())

default:
utils.Panic("unknow type %T", params)
Expand All @@ -435,6 +453,8 @@ func (c *WizardVerifierCircuit) GetParams(id ifaces.QueryID) ifaces.GnarkQueryPa
return c.GetUnivariateParams(id)
case query.LocalOpening:
return c.GetLocalPointEvalParams(id)
case query.LogDerivativeSum:
return c.GetLogDerivSumParams(id)
case query.InnerProduct:
return c.GetInnerProductParams(id)
default:
Expand Down
37 changes: 37 additions & 0 deletions prover/protocol/wizard/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,43 @@ func (run *ProverRuntime) GetLocalPointEvalParams(name ifaces.QueryID) query.Loc
return run.QueriesParams.MustGet(name).(query.LocalOpeningParams)
}

// AssignLogDerivSum assign the claimed values for a logDeriveSum
// The function will panic if:
// - the parameters were already assigned
// - the specified query is not registered
// - the assignment round is incorrect
func (run *ProverRuntime) AssignLogDerivSum(name ifaces.QueryID, y field.Element) {

// Global prover locks for accessing the maps
run.lock.Lock()
defer run.lock.Unlock()

// Make sure, it is done at the right round
run.Spec.QueriesParams.MustBeInRound(run.currRound, name)

// Adds it to the assignments
params := query.NewLogDerivSumParams(y)
run.QueriesParams.InsertNew(name, params)
}

// GetLogDeriveSum gets the metadata of a [query.LogDerivativeSum] query. Panic if not found.
func (run *ProverRuntime) GetLogDeriveSum(name ifaces.QueryID) query.LogDerivativeSum {
// Global prover locks for accessing the maps
run.lock.Lock()
defer run.lock.Unlock()
return run.Spec.QueriesParams.Data(name).(query.LogDerivativeSum)
}

// GetLogDerivSumParams returns the parameters of [query.LogDerivativeSum]
func (run *ProverRuntime) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams {

// Global prover's lock for accessing params
run.lock.Lock()
defer run.lock.Unlock()

return run.QueriesParams.MustGet(name).(query.LogDerivSumParams)
}

// GetParams generically extracts the parameters of a query. Will panic if no
// parameters are found
func (run *ProverRuntime) GetParams(name ifaces.QueryID) ifaces.QueryParams {
Expand Down
5 changes: 5 additions & 0 deletions prover/protocol/wizard/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ func (run *VerifierRuntime) GetLocalPointEvalParams(name ifaces.QueryID) query.L
return run.QueriesParams.MustGet(name).(query.LocalOpeningParams)
}

// GetLogDerivSumParams returns the parameters of a [query.LogDerivativeSum]
func (run *VerifierRuntime) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams {
return run.QueriesParams.MustGet(name).(query.LogDerivSumParams)
}

/*
CopyColumnInto implements `column.GetWitness`
Copies the witness into a slice
Expand Down

0 comments on commit 157f641

Please sign in to comment.