From 79a371e0a3cfb8053a07f6893a268659e96f2078 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Wed, 11 Dec 2024 11:08:19 +0100 Subject: [PATCH 01/12] initial framework for lookup --- prover/protocol/compiler/lookup/compiler.go | 72 +++---- prover/protocol/compiler/lookup/context.go | 33 +-- prover/protocol/compiler/lookup/prover.go | 4 +- prover/protocol/compiler/lookup/utils.go | 28 +-- prover/protocol/compiler/lookup/z_packing.go | 12 +- .../compiler/inclusion/inclusion.go | 160 +++++++++++++- .../compiler/inclusion/log_deriv_sum.go | 204 ++++++++++++++++++ .../inclusion/log_derivative_sum_verif.go | 59 +++++ .../compiler/inclusion/lookup_discoverer.go | 117 ++++++++++ .../compiler/inclusion/preparation.go | 9 + .../distributed/compiler/inclusion/prover.go | 125 +++++++++++ 11 files changed, 742 insertions(+), 81 deletions(-) create mode 100644 prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go create mode 100644 prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go create mode 100644 prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go create mode 100644 prover/protocol/distributed/compiler/inclusion/preparation.go create mode 100644 prover/protocol/distributed/compiler/inclusion/prover.go diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index 6b4f8fea6..3bc9c4366 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -33,7 +33,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // zCatalog stores a mapping (round, size) into ZCtx and helps finding // which Z context should be used to handle a part of a given permutation // query. - zCatalog = map[[2]int]*zCtx{} + zCatalog = map[[2]int]*ZCtx{} zEntries = [][2]int{} // verifier actions va = finalEvaluationCheck{} @@ -41,25 +41,25 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // Skip the compilation phase if no lookup constraint is being used. Otherwise // it will register a verifier action that is not required and will be bugged. - if len(mainLookupCtx.lookupTables) == 0 { + if len(mainLookupCtx.LookupTables) == 0 { return } // Step 1. construct the "per table" contexts and pack the Sigma's into // zCatalog. - for _, lookupTable := range mainLookupCtx.lookupTables { + for _, lookupTable := range mainLookupCtx.LookupTables { var ( // get checkedTables, rounds, Filters by lookupTableName - lookupTableName = nameTable(lookupTable) - checkTable = mainLookupCtx.checkedTables[lookupTableName] - round = mainLookupCtx.rounds[lookupTableName] - includedFilters = mainLookupCtx.includedFilters[lookupTableName] + lookupTableName = NameTable(lookupTable) + checkTable = mainLookupCtx.CheckedTables[lookupTableName] + round = mainLookupCtx.Rounds[lookupTableName] + includedFilters = mainLookupCtx.IncludedFilters[lookupTableName] tableCtx = compileLookupTable(comp, round, lookupTable, checkTable, includedFilters) ) // push to zCatalog - tableCtx.pushToZCatalog(zCatalog) + tableCtx.PushToZCatalog(zCatalog) proverActions[round].pushMAssignment( mAssignmentTask{ @@ -131,13 +131,13 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // The function also implictly reduces the conditionals over the Including table // be appending a "one" column on the included side and the filter on the // including side. -func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { +func captureLookupTables(comp *wizard.CompiledIOP) MainLookupCtx { - ctx := mainLookupCtx{ - lookupTables: [][]table{}, - checkedTables: map[string][]table{}, - includedFilters: map[string][]ifaces.Column{}, - rounds: map[string]int{}, + ctx := MainLookupCtx{ + LookupTables: [][]table{}, + CheckedTables: map[string][]table{}, + IncludedFilters: map[string][]ifaces.Column{}, + Rounds: map[string]int{}, } // Collect all the lookup queries into "lookups" @@ -157,8 +157,8 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { var ( // checkedTable corresponds to the "included" table and lookupTable // corresponds to the including table. - checkedTable, lookupTable = getTableCanonicalOrder(lookup) - tableName = nameTable(lookupTable) + checkedTable, lookupTable = GetTableCanonicalOrder(lookup) + tableName = NameTable(lookupTable) // includedFilters stores the query.IncludedFilter parameter. If the // query has no includedFilters on the Included side. Then this is // left as nil. @@ -176,7 +176,7 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { lookupTable[frag] = append([]ifaces.Column{lookup.IncludingFilter[frag]}, lookupTable[frag]...) } - tableName = nameTable(lookupTable) + tableName = NameTable(lookupTable) } if lookup.IsFilteredOnIncluded() { @@ -185,16 +185,16 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { // In case this is the first iteration where we encounter the lookupTable // we need to add entries in the registering maps. - if _, ok := ctx.checkedTables[tableName]; !ok { - ctx.includedFilters[tableName] = []ifaces.Column{} - ctx.checkedTables[tableName] = []table{} - ctx.lookupTables = append(ctx.lookupTables, lookupTable) - ctx.rounds[tableName] = 0 + if _, ok := ctx.CheckedTables[tableName]; !ok { + ctx.IncludedFilters[tableName] = []ifaces.Column{} + ctx.CheckedTables[tableName] = []table{} + ctx.LookupTables = append(ctx.LookupTables, lookupTable) + ctx.Rounds[tableName] = 0 } - ctx.includedFilters[tableName] = append(ctx.includedFilters[tableName], includedFilter) - ctx.checkedTables[tableName] = append(ctx.checkedTables[tableName], checkedTable) - ctx.rounds[tableName] = max(ctx.rounds[tableName], comp.QueriesNoParams.Round(lookup.ID)) + ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) + ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) + ctx.Rounds[tableName] = max(ctx.Rounds[tableName], comp.QueriesNoParams.Round(lookup.ID)) } @@ -219,10 +219,10 @@ func compileLookupTable( lookupTable []table, checkedTables []table, includedFilters []ifaces.Column, -) (ctx singleTableCtx) { +) (ctx SingleTableCtx) { - ctx = singleTableCtx{ - TableName: nameTable(lookupTable), + ctx = SingleTableCtx{ + TableName: NameTable(lookupTable), S: make([]*symbolic.Expression, len(checkedTables)), SFilters: includedFilters, T: make([]*symbolic.Expression, len(lookupTable)), @@ -240,7 +240,7 @@ func compileLookupTable( ctx.T[frag] = symbolic.NewVariable(lookupTable[frag][0]) ctx.M[frag] = comp.InsertCommit( round, - deriveTableNameWithIndex[ifaces.ColID](logDerivativePrefix, lookupTable, frag, "M"), + DeriveTableNameWithIndex[ifaces.ColID](LogDerivativePrefix, lookupTable, frag, "M"), lookupTable[frag][0].Size(), ) @@ -257,7 +257,7 @@ func compileLookupTable( // columns of T and S when they are (both) multi-columns. alpha := comp.InsertCoin( round+1, - deriveTableName[coin.Name](logDerivativePrefix, lookupTable, "ALPHA"), + DeriveTableName[coin.Name](LogDerivativePrefix, lookupTable, "ALPHA"), coin.Field, ) @@ -265,7 +265,7 @@ func compileLookupTable( ctx.T[frag] = wizardutils.RandLinCombColSymbolic(alpha, lookupTable[frag]) ctx.M[frag] = comp.InsertCommit( round, - deriveTableNameWithIndex[ifaces.ColID](logDerivativePrefix, lookupTable, frag, "M"), + DeriveTableNameWithIndex[ifaces.ColID](LogDerivativePrefix, lookupTable, frag, "M"), lookupTable[frag][0].Size(), ) } @@ -277,16 +277,16 @@ func compileLookupTable( ctx.Gamma = comp.InsertCoin( round+1, - deriveTableName[coin.Name](logDerivativePrefix, lookupTable, "GAMMA"), + DeriveTableName[coin.Name](LogDerivativePrefix, lookupTable, "GAMMA"), coin.Field, ) return ctx } -// pushToZCatalog constructs the numerators and denominators for S and T of the +// PushToZCatalog constructs the numerators and denominators for S and T of the // stc into zCatalog for their corresponding rounds and size. -func (stc *singleTableCtx) pushToZCatalog(zCatalog map[[2]int]*zCtx) { +func (stc *SingleTableCtx) PushToZCatalog(zCatalog map[[2]int]*ZCtx) { var ( round = stc.Gamma.Round @@ -299,7 +299,7 @@ func (stc *singleTableCtx) pushToZCatalog(zCatalog map[[2]int]*zCtx) { key := [2]int{round, size} if zCatalog[key] == nil { - zCatalog[key] = &zCtx{ + zCatalog[key] = &ZCtx{ Size: size, Round: round, Name: stc.TableName, @@ -324,7 +324,7 @@ func (stc *singleTableCtx) pushToZCatalog(zCatalog map[[2]int]*zCtx) { key := [2]int{round, size} if zCatalog[key] == nil { - zCatalog[key] = &zCtx{ + zCatalog[key] = &ZCtx{ Size: size, Round: round, Name: stc.TableName, diff --git a/prover/protocol/compiler/lookup/context.go b/prover/protocol/compiler/lookup/context.go index 4ec961da1..cfeb5d7c2 100644 --- a/prover/protocol/compiler/lookup/context.go +++ b/prover/protocol/compiler/lookup/context.go @@ -6,13 +6,13 @@ import ( "github.com/consensys/linea-monorepo/prover/symbolic" ) -// mainLookupCtx stores the compilation context of all the lookup queries +// MainLookupCtx stores the compilation context of all the lookup queries // altogether. -type mainLookupCtx struct { +type MainLookupCtx struct { - // lookupTables stores all the lookup table the compiler encounters. They are + // LookupTables stores all the lookup table the compiler encounters. They are // sorted in canonical order. This used to derive a determistic ordering - // of the lookup lookupTables. (We want to ensure the compiler yields always + // of the lookup LookupTables. (We want to ensure the compiler yields always // exactly the same result for replicability). // // To illustrates its structure, the following sub-statement @@ -20,31 +20,34 @@ type mainLookupCtx struct { // table[numTable][frag] // // refers to to the fragment #frag of the the table #numTable. - lookupTables [][]table + LookupTables [][]table - // checkedTables stores all the checked column by lookup table. The key is + // CheckedTables stores all the checked column by lookup table. The key is // obtained as nameTable(lookupTable) where lookup is sorted in // canonical order. - checkedTables map[string][]table + CheckedTables map[string][]table - // includedFilters stores all the filters for the checked columns and `nil` + // IncludedFilters stores all the filters for the checked columns and `nil` // if no filter is applied. As for [checkedTables] they are stored by // lookup table name and in the same order for each key. - includedFilters map[string][]ifaces.Column + IncludedFilters map[string][]ifaces.Column - // rounds stores the interaction round assigned to each lookupTable. The - // round is obtained by taking the max of the declaration rounds of the + // Rounds stores the interaction round assigned to each lookupTable. The + // round is obtained by taking the max of the declaration Rounds of the // Inclusion queries using the corresponding lookup table. - rounds map[string]int + Rounds map[string]int + + // the name of the tables + AllTableNames []string } -// singleTableCtx stores the compilation context for a single lookup query +// SingleTableCtx stores the compilation context for a single lookup query // when it is compiled using the log-derivative lookup technique. // -// A singleTableCtx relates to a lookup table rather than a lookup query. This +// A SingleTableCtx relates to a lookup table rather than a lookup query. This // means that multiple lookup queries that are related to the same table will be // grouped into the same context. This allows optimizing the -type singleTableCtx struct { +type SingleTableCtx struct { // TableName reflects the name of the lookup table being compiled. TableName string diff --git a/prover/protocol/compiler/lookup/prover.go b/prover/protocol/compiler/lookup/prover.go index 063fa5a6b..12a73f94f 100644 --- a/prover/protocol/compiler/lookup/prover.go +++ b/prover/protocol/compiler/lookup/prover.go @@ -281,7 +281,7 @@ func (a mAssignmentTask) run(run *wizard.ProverRuntime) { } utils.Panic( "entry %v of the table %v is not included in the table. tableRow=%v", - k, nameTable([][]ifaces.Column{a.S[i]}), vector.Prettify(tableRow), + k, NameTable([][]ifaces.Column{a.S[i]}), vector.Prettify(tableRow), ) } @@ -300,7 +300,7 @@ func (a mAssignmentTask) run(run *wizard.ProverRuntime) { // zAssignmentTask represents a prover task of assignming the columns // SigmaS and SigmaT for a specific lookup table. // sigmaAssignment -type zAssignmentTask zCtx +type zAssignmentTask ZCtx func (z zAssignmentTask) run(run *wizard.ProverRuntime) { parallel.Execute(len(z.ZDenominatorBoarded), func(start, stop int) { diff --git a/prover/protocol/compiler/lookup/utils.go b/prover/protocol/compiler/lookup/utils.go index a6d4eeb27..17cbe8858 100644 --- a/prover/protocol/compiler/lookup/utils.go +++ b/prover/protocol/compiler/lookup/utils.go @@ -10,12 +10,12 @@ import ( ) const ( - // logDerivativePrefix is a prefix that we commonly use to derive query, + // LogDerivativePrefix is a prefix that we commonly use to derive query, // coin or column names that are introduced by the compiler. - logDerivativePrefix = "LOGDERIVATIVE" + LogDerivativePrefix = "LOGDERIVATIVE" ) -// getTableCanonicalOrder extracts the lookup table and the queried tables +// GetTableCanonicalOrder extracts the lookup table and the queried tables // from `q` and rearrange them conjointly so that the names of T are returned // in alphabetical order. // @@ -28,7 +28,7 @@ const ( // can always make sure to specify the table in the same order all the time. // // Importantly, the function allocates its own result. -func getTableCanonicalOrder(q query.Inclusion) ([]ifaces.Column, [][]ifaces.Column) { +func GetTableCanonicalOrder(q query.Inclusion) ([]ifaces.Column, [][]ifaces.Column) { if len(q.Including) > 1 { // The append here are performing a deep-copy of the slice within the @@ -67,8 +67,8 @@ func getTableCanonicalOrder(q query.Inclusion) ([]ifaces.Column, [][]ifaces.Colu return checked, [][]ifaces.Column{table} } -// deriveName constructs a generic name -func deriveName[R ~string](args ...any) R { +// DeriveName constructs a generic name +func DeriveName[R ~string](args ...any) R { argStr := []string{"LOOKUP_LOGDERIVATIVE"} for _, arg := range args { argStr = append(argStr, fmt.Sprintf("%v", arg)) @@ -76,25 +76,25 @@ func deriveName[R ~string](args ...any) R { return R(strings.Join(argStr, "_")) } -// deriveTableName constructs a name for the table `t`. The caller may provide +// DeriveTableName constructs a name for the table `t`. The caller may provide // a context and a suffix to the name. If `t` is empty, the name is the // concatenation of `context` and `name` separated by an underscore. -func deriveTableName[R ~string](context string, t [][]ifaces.Column, name string) R { - res := fmt.Sprintf("%v_%v_%v", nameTable(t), context, name) +func DeriveTableName[R ~string](context string, t [][]ifaces.Column, name string) R { + res := fmt.Sprintf("%v_%v_%v", NameTable(t), context, name) return R(res) } -// deriveTableNameWithIndex is as [deriveTableName] but additionally allows +// DeriveTableNameWithIndex is as [deriveTableName] but additionally allows // appending an integer index in the name. -func deriveTableNameWithIndex[R ~string](context string, t [][]ifaces.Column, index int, name string) R { - res := fmt.Sprintf("%v_%v_%v_%v", nameTable(t), index, context, name) +func DeriveTableNameWithIndex[R ~string](context string, t [][]ifaces.Column, index int, name string) R { + res := fmt.Sprintf("%v_%v_%v_%v", NameTable(t), index, context, name) return R(res) } -// nameTable returns a unique name corresponding to the provided +// NameTable returns a unique name corresponding to the provided // sequence of columns `t`. The unique name is constructed by appending the // name of all the column separated by an underscore. -func nameTable(t []table) string { +func NameTable(t []table) string { // This single fragment case is managed as a special case although it is // not really one. This is for backwards compatibility. if len(t) == 1 { diff --git a/prover/protocol/compiler/lookup/z_packing.go b/prover/protocol/compiler/lookup/z_packing.go index d6d12ffbe..2838f4834 100644 --- a/prover/protocol/compiler/lookup/z_packing.go +++ b/prover/protocol/compiler/lookup/z_packing.go @@ -26,7 +26,7 @@ const ( // T: lookupTable, // SFilter: includedFilters, -type zCtx struct { +type ZCtx struct { Round, Size int SigmaNumerator []*sym.Expression // T -> -M, S -> +Filter SigmaDenominator []*sym.Expression // S or T -> ({S,T} + X) @@ -44,7 +44,7 @@ type zCtx struct { // and change T -> -M, S -> +Filter // S or T -> ({S,T} + X) // compile should be called inside CompileGrandSum -func (z *zCtx) compile(comp *wizard.CompiledIOP) { +func (z *ZCtx) compile(comp *wizard.CompiledIOP) { var ( numZs = utils.DivCeil( @@ -82,14 +82,14 @@ func (z *zCtx) compile(comp *wizard.CompiledIOP) { z.Zs[i] = comp.InsertCommit( z.Round, - deriveName[ifaces.ColID]("Z", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.ColID]("Z", comp.SelfRecursionCount, z.Round, z.Size, i), z.Size, ) // initial condition comp.InsertLocal( z.Round, - deriveName[ifaces.QueryID]("Z_CONSISTENCY_START", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.QueryID]("Z_CONSISTENCY_START", comp.SelfRecursionCount, z.Round, z.Size, i), sym.Sub( zNumerator, sym.Mul( @@ -102,7 +102,7 @@ func (z *zCtx) compile(comp *wizard.CompiledIOP) { // consistency check comp.InsertGlobal( z.Round, - deriveName[ifaces.QueryID]("Z_CONSISTENCY", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.QueryID]("Z_CONSISTENCY", comp.SelfRecursionCount, z.Round, z.Size, i), sym.Sub( zNumerator, sym.Mul( @@ -115,7 +115,7 @@ func (z *zCtx) compile(comp *wizard.CompiledIOP) { // local opening of the final value of the Z polynomial z.ZOpenings[i] = comp.InsertLocalOpening( z.Round, - deriveName[ifaces.QueryID]("Z_FINAL", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.QueryID]("Z_FINAL", comp.SelfRecursionCount, z.Round, z.Size, i), column.Shift(z.Zs[i], -1), ) } diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index 43af9174f..6d18aea3e 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -2,16 +2,160 @@ package inclusion import ( "github.com/consensys/linea-monorepo/prover/protocol/coin" + lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" + "github.com/consensys/linea-monorepo/prover/symbolic" ) -// CompileDist compiles inclusion queries distributedly. -// It receives a compiledIOP object relevant to a segment. -// The seed is a random coin from randomness beacon (FS of all LPP commitments). -// All the compilation steps are similar to the permutation compilation apart from: -// - random coins \alpha and \gamma are generated from the seed (and the tableName). -// - no verifierAction is needed over the ZOpening. -// - ZOpenings are declared as public input. -func CompileDist(comp *wizard.CompiledIOP, seed coin.Info) { +// CompileLogDerivative scans `comp`, looking for Inclusion queries and compiles +// them using the LogDerivativeLookup technique. The compiler attempts to group +// queries relating to the same table. This allows saving in commitment because +// when grouping is possible, then we only need to commit to a single +// extract the table and the checked from the lookup query and ensures that the +// table are in canonical order. That is because we want to group lookups into +// the same columns but not in the same order. +func CompileLogDerivative(comp *wizard.CompiledIOP) { + var ( + mainLookupCtx = captureModuleLookupTables(comp) + lastRound = comp.NumRounds() - 1 + + // zCatalog stores a mapping (round, size) into ZCtx and helps finding + // which Z context should be used to handle a part of a given permutation + // query. + zCatalog = map[[2]int]*lookUp.ZCtx{} + ) + + // Skip the compilation phase if no lookup constraint is being used. Otherwise + // it will register a verifier action that is not required and will be bugged. + if len(mainLookupCtx.LookupTables) == 0 && len(mainLookupCtx.CheckedTables) == 0 { + return + } + + // Step 1. construct the "per table" contexts and pack the Sigma's into + // zCatalog. + for tableName, checkedTables := range mainLookupCtx.CheckedTables { + lookupTable := findLookupTableByName(mainLookupCtx.LookupTables, tableName) + var ( + round = mainLookupCtx.Rounds[tableName] + includedFilters = mainLookupCtx.IncludedFilters[tableName] + tableCtx = compileLookupTable(comp, round, lookupTable, checkedTables, includedFilters) + ) + + // push to zCatalog + tableCtx.PushToZCatalog(zCatalog) + } + + // Handle cases where only T part is present + for _, lookupTable := range mainLookupCtx.LookupTables { + tableName := lookUp.NameTable(lookupTable) + if _, ok := mainLookupCtx.CheckedTables[tableName]; ok { + continue + } + + var ( + round = mainLookupCtx.Rounds[tableName] + tableCtx = compileLookupTable(comp, round, lookupTable, nil, nil) + ) + + // push to zCatalog + tableCtx.PushToZCatalog(zCatalog) + } + + InsertLogDerivativeSum(comp, lastRound, zCatalog) +} + +// findLookupTableByName searches for a lookup table by its name in the list of lookup tables. +func findLookupTableByName(lookupTables [][]table, name string) []table { + for _, lookupTable := range lookupTables { + if lookUp.NameTable(lookupTable) == name { + return lookupTable + } + } + return nil +} + +// compileLookupTable applies the log-derivative lookup compilation context to +// the supplied table. round denotes the interaction round in which to start the +// compilation. +// +// It registers the following queries +// - (1) The verifier queries that $\sum_{k=0\ldots n-1} (\Sigma_{S,k})[|S_k| - 1] == (\Sigma_T)[|T| - 1]$. Namely, the sum of the last entry of all $\Sigma_{S,k}$ equals the last entry of $\Sigma_T$ +// - (2) **(For all k)** the verifier makes a `Local` query : $(\Sigma_{S,k})[0] = \frac{1}{S_{k,0} + \gamma}$ +// - (3) The verifier makes a `Local` query : $(\Sigma_T)[0] = \frac{M_0}{T_0 + \gamma}$ +// - (4) **(For all k)** The verifier makes a `Global` query : $\left((\Sigma_{S,k})[i] - (\Sigma_{S,k})[i-1]\right)(S_{k,i} + \gamma) = 1$ +// - (5) The verifier makes a `Global` query : $\left((\Sigma_T)[i] - (\Sigma_T)[i-1]\right)(T_i + \gamma) = M_i$ + +// here we are looking up set of columns S in a single column T +func compileLookupTable( + comp *wizard.CompiledIOP, + round int, + lookupTable []table, + checkedTables []table, + includedFilters []ifaces.Column, +) (ctx lookUp.SingleTableCtx) { + + ctx = lookUp.SingleTableCtx{ + TableName: lookUp.NameTable(lookupTable), + S: make([]*symbolic.Expression, len(checkedTables)), + SFilters: includedFilters, + T: make([]*symbolic.Expression, len(lookupTable)), + M: make([]ifaces.Column, len(lookupTable)), + } + + var ( + // isMultiColumn indicates whether the lookup table (and thus the + // checked tables) have the same number of + isMultiColumn = (len(lookupTable) > 0 && len(lookupTable[0]) > 1) || (len(checkedTables) > 0 && len(checkedTables[0]) > 1) + ) + + if !isMultiColumn { + for frag := range ctx.T { + ctx.T[frag] = symbolic.NewVariable(lookupTable[frag][0]) + ctx.M[frag] = comp.InsertCommit( + round, + lookUp.DeriveTableNameWithIndex[ifaces.ColID](lookUp.LogDerivativePrefix, lookupTable, frag, "M"), + lookupTable[frag][0].Size(), + ) + + } + + for i := range ctx.S { + ctx.S[i] = symbolic.NewVariable(checkedTables[i][0]) + } + } + + if isMultiColumn { + + // alpha is the coin used to compute the linear combination of the + // columns of T and S when they are (both) multi-columns. + alpha := comp.InsertCoin( + round+1, + lookUp.DeriveTableName[coin.Name](lookUp.LogDerivativePrefix, lookupTable, "ALPHA"), + coin.Field, + ) + + for frag := range ctx.T { + ctx.T[frag] = wizardutils.RandLinCombColSymbolic(alpha, lookupTable[frag]) + ctx.M[frag] = comp.InsertCommit( + round, + lookUp.DeriveTableNameWithIndex[ifaces.ColID](lookUp.LogDerivativePrefix, lookupTable, frag, "M"), + lookupTable[frag][0].Size(), + ) + } + + for i := range ctx.S { + ctx.S[i] = wizardutils.RandLinCombColSymbolic(alpha, checkedTables[i]) + } + } + + ctx.Gamma = comp.InsertCoin( + round+1, + lookUp.DeriveTableName[coin.Name](lookUp.LogDerivativePrefix, lookupTable, "GAMMA"), + coin.Field, + ) + + return ctx } diff --git a/prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go b/prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go new file mode 100644 index 000000000..094c9a55e --- /dev/null +++ b/prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go @@ -0,0 +1,204 @@ +package inclusion + +import ( + "slices" + + "github.com/consensys/linea-monorepo/prover/protocol/column" + lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + sym "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +const ( + packingArity = 3 +) + +// zCtx is an internal compilation structure responsible for grouping the +// "Sigma" columns together so that we can trade-off commitment complexity with +// global constraint degree. +// +// All the grouped "Sigma" columns are of the same size and same round for the +// same context. +// M is computing the appearance of the rows of T in the rows of S. + +// M: tableCtx.M, +// S: checkTable, +// T: lookupTable, +// SFilter: includedFilters, + +type logDerivativeSumCtx struct { + zCtx *lookUp.ZCtx + // the global sum in LogDerivative + PI ifaces.Column +} + +// check permutation and see how/where compile is called (see how to constracut z there) +// when constructing z, check if z is T or S +// and change T -> -M, S -> +Filter +// S or T -> ({S,T} + X) +// compile should be called inside CompileGrandSum +func (lgSum *logDerivativeSumCtx) compileLogDerivativeSum(comp *wizard.CompiledIOP) { + + var ( + z = lgSum.zCtx + numZs = utils.DivCeil( + len(z.SigmaDenominator), + packingArity, + ) + ) + z.Zs = make([]ifaces.Column, numZs) + z.ZOpenings = make([]query.LocalOpening, numZs) + z.ZNumeratorBoarded = make([]sym.ExpressionBoard, numZs) + z.ZDenominatorBoarded = make([]sym.ExpressionBoard, numZs) + + lgSum.PI = comp.InsertColumn( + z.Round, + lookUp.DeriveName[ifaces.ColID]("PI", comp.SelfRecursionCount, z.Round, z.Size), + 1, + column.PublicInput, + ) + + for i := range z.Zs { + + var ( + packedNum = safeAnySubSlice(z.SigmaNumerator, i*packingArity, (i+1)*packingArity) + packedDen = safeAnySubSlice(z.SigmaDenominator, i*packingArity, (i+1)*packingArity) + + zNumerator = sym.NewConstant(0) + zDenominator = sym.Mul(packedDen...) + ) + + for j := range packedNum { + term := packedNum[j] + for k := range packedDen { + if k != j { + term = sym.Mul(term, packedDen[k]) + } + } + zNumerator = sym.Add(zNumerator, term) + } + + z.ZNumeratorBoarded[i] = zNumerator.Board() + z.ZDenominatorBoarded[i] = zDenominator.Board() + + z.Zs[i] = comp.InsertCommit( + z.Round, + lookUp.DeriveName[ifaces.ColID]("Z", comp.SelfRecursionCount, z.Round, z.Size, i), + z.Size, + ) + + // initial condition + comp.InsertLocal( + z.Round, + lookUp.DeriveName[ifaces.QueryID]("Z_CONSISTENCY_START", comp.SelfRecursionCount, z.Round, z.Size, i), + sym.Sub( + zNumerator, + sym.Mul( + z.Zs[i], + zDenominator, + ), + ), + ) + + // consistency check + comp.InsertGlobal( + z.Round, + lookUp.DeriveName[ifaces.QueryID]("Z_CONSISTENCY", comp.SelfRecursionCount, z.Round, z.Size, i), + sym.Sub( + zNumerator, + sym.Mul( + sym.Sub(z.Zs[i], column.Shift(z.Zs[i], -1)), + zDenominator, + ), + ), + ) + + // local opening of the final value of the Z polynomial + z.ZOpenings[i] = comp.InsertLocalOpening( + z.Round, + lookUp.DeriveName[ifaces.QueryID]("Z_FINAL", comp.SelfRecursionCount, z.Round, z.Size, i), + column.Shift(z.Zs[i], -1), + ) + } +} + +// attempt to take the subslice of a slice, and truncates or returns an empty +// slice if the parameters are out of bounds. +func safeAnySubSlice[T any](t []T, start, stop int) []any { + + if stop < start { + panic("invalid argument") + } + + var tmp []T + + switch { + case start >= len(t): + return []any{} + case stop >= len(t): + tmp = t[start:] + default: + tmp = t[start:stop] + } + + res := make([]any, len(tmp)) + for i := range res { + res[i] = tmp[i] + } + + return res +} + +// IntoLogDerivativeSum handles the remaining process after PushToZCatalog. +func InsertLogDerivativeSum(comp *wizard.CompiledIOP, lastRound int, zCatalog map[[2]int]*lookUp.ZCtx) { + var ( + zEntries = [][2]int{} + va = finalEvaluationCheck{} + ) + + // This loop is necessary to build a sorted list of the entries of zCatalog. + // Without it, if we tried to loop over zCatalog directly, the entries would + // be processed in a non-deterministic order. The sorting order itself is + // without importance, what matters is that zEntries is in deterministic + // order. + for entry := range zCatalog { + zEntries = append(zEntries, entry) + } + + slices.SortFunc(zEntries, func(a, b [2]int) int { + switch { + case a[0] < b[0]: + return -1 + case a[0] > b[0]: + return 1 + case a[1] < b[1]: + return -1 + case a[1] > b[1]: + return 1 + default: + return 0 + } + }) + + // compile zCatalog + for _, entry := range zEntries { + zC := zCatalog[entry] + logDerivSumCtx := + logDerivativeSumCtx{ + zCtx: zC, + } + // z-packing compile + logDerivSumCtx.compileLogDerivativeSum(comp) + // entry[0]:round, entry[1]: size + // the round that Gamma was registered. + + // pushZAssignment(zAssignmentTask(*zC)) + va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...) + va.Name = zC.Name + } + + comp.RegisterVerifierAction(lastRound, &va) +} diff --git a/prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go b/prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go new file mode 100644 index 000000000..cccb5a3ac --- /dev/null +++ b/prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go @@ -0,0 +1,59 @@ +package inclusion + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +// finalEvaluationCheck implements the [wizard.VerifierAction] interface. It +// represents the consistency check between SigmaT and the SigmaS[i]. This +// corresponds in the check #1 in the doc, where the verifier ensures that +// summing the ending values of the SigmaS[i]s yields the same result as the +// ending value of SigmaT. +// +// The current implementation is for packed Zs +type finalEvaluationCheck struct { + // the name of a lookupTable in the pack, this can help for debugging. + Name string + // ZOpenings lists all the openings of all the zCtx + ZOpenings []query.LocalOpening + // the global sum in LogDerivative + PI query.LocalOpening +} + +// Run implements the [wizard.VerifierAction] +func (f *finalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { + + // zSum stores the sum of the ending values of the zs as queried + // in the protocol via the local opening queries. + zSum := field.Zero() + for k := range f.ZOpenings { + temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y + zSum.Add(&zSum, &temp) + } + + PI := run.GetLocalPointEvalParams(f.PI.ID).Y + if zSum != PI { + return fmt.Errorf("log-derivate lookup, the final evaluation check failed for %v,", f.Name) + } + + return nil +} + +// RunGnark implements the [wizard.VerifierAction] +func (f *finalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { + + // SigmaSKSum stores the sum of the ending values of the SigmaSs as queried + // in the protocol via the + zSum := frontend.Variable(field.Zero()) + for k := range f.ZOpenings { + temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y + zSum = api.Add(zSum, temp) + } + + api.AssertIsEqual(zSum, 0) +} diff --git a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go new file mode 100644 index 000000000..4859d8a40 --- /dev/null +++ b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go @@ -0,0 +1,117 @@ +package inclusion + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" + lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +// table is an alias for a list of column. We use it in the scope of the lookup +// compiler as a shorthand to make the code more eye-parseable. +type table = []ifaces.Column + +// captureModuleLookupTables inspects comp and looks for Inclusion queries that are relevant to the module. +// It groups the matched queries by lookup table and marks them as ignored. +// The input is a compiledIOP object that stores the columns relevant to the module (in its Column field) +// Note that for a lookup query the module may contain only S or T table (and not necessarily both). +func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { + var ( + ctx = lookUp.MainLookupCtx{ + LookupTables: [][]table{}, + CheckedTables: map[string][]table{}, + IncludedFilters: map[string][]ifaces.Column{}, + Rounds: map[string]int{}, + } + ) + + // Collect all the lookup queries into "lookups" + for _, qName := range comp.QueriesNoParams.AllUnignoredKeys() { + + // Filter out non-lookup queries + lookup, ok := comp.QueriesNoParams.Data(qName).(query.Inclusion) + if !ok { + continue + } + + // Determine if the query is relevant to the module + relevantPart := determineRelevantPart(lookup, comp.Columns) + if relevantPart == "" { + continue + } + + // This ensures that the lookup query is not used again in the + // compilation process. We know that the query was already ignored at + // the beginning because we are iterating over the unignored keys. + comp.QueriesNoParams.MarkAsIgnored(qName) + + var ( + // checkedTable corresponds to the "included" table and lookupTable + // corresponds to the including table. + checkedTable, lookupTable = lookUp.GetTableCanonicalOrder(lookup) + tableName = lookUp.NameTable(lookupTable) + // includedFilters stores the query.IncludedFilter parameter. If the + // query has no includedFilters on the Included side. Then this is + // left as nil. + includedFilter ifaces.Column + ) + + if lookup.IsFilteredOnIncluding() { + var ( + checkedLen = checkedTable[0].Size() + ones = verifiercol.NewConstantCol(field.One(), checkedLen) + ) + + checkedTable = append([]ifaces.Column{ones}, checkedTable...) + for frag := range lookupTable { + lookupTable[frag] = append([]ifaces.Column{lookup.IncludingFilter[frag]}, lookupTable[frag]...) + } + + tableName = lookUp.NameTable(lookupTable) + } + + if lookup.IsFilteredOnIncluded() { + includedFilter = lookup.IncludedFilter + } + + // In case this is the first iteration where we encounter the lookupTable + // we need to add entries in the registering maps. + if _, ok := ctx.CheckedTables[tableName]; !ok { + ctx.IncludedFilters[tableName] = []ifaces.Column{} + ctx.CheckedTables[tableName] = []table{} + ctx.LookupTables = [][]table{} + ctx.Rounds[tableName] = 0 + } + + // Add only the relevant part to the context + if relevantPart == "S" { + ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) + ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) + } else if relevantPart == "T" { + ctx.LookupTables = append(ctx.LookupTables, lookupTable) + } + + ctx.Rounds[tableName] = max(ctx.Rounds[tableName], comp.QueriesNoParams.Round(lookup.ID)) + + } + + return ctx +} + +// determineRelevantPart checks if the lookup query involves columns from the module and returns the relevant part (S or T). +func determineRelevantPart(lookup query.Inclusion, moduleColumns column.Store) string { + // Check if any column in S part is in the module + if moduleColumns.Exists(lookup.Included[0].GetColID()) { + return "S" + } + + // Check if any column in T part is in the module + if moduleColumns.Exists(lookup.Including[0][0].GetColID()) { + return "T" + } + + return "" +} diff --git a/prover/protocol/distributed/compiler/inclusion/preparation.go b/prover/protocol/distributed/compiler/inclusion/preparation.go new file mode 100644 index 000000000..0fb53adf8 --- /dev/null +++ b/prover/protocol/distributed/compiler/inclusion/preparation.go @@ -0,0 +1,9 @@ +package inclusion + +import "github.com/consensys/linea-monorepo/prover/protocol/wizard" + +// It scans the compiledIOP and group all the checkedTables related to the same lookupTable, +// It creates a frequency column M for all such pairs (checkedTables, lookupTable). +func prepareLookup(comp *wizard.CompiledIOP) { + +} diff --git a/prover/protocol/distributed/compiler/inclusion/prover.go b/prover/protocol/distributed/compiler/inclusion/prover.go new file mode 100644 index 000000000..2cc2a8886 --- /dev/null +++ b/prover/protocol/distributed/compiler/inclusion/prover.go @@ -0,0 +1,125 @@ +package inclusion + +import ( + "runtime/debug" + "sync" + + sv "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/vector" + "github.com/consensys/linea-monorepo/prover/maths/field" + lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/parallel" +) + +// proverTaskAtRound implements the [wizard.ProverAction] interface. It gathers +// all the operations related to all compiled tables altogether that have to be +// done at a particular round. +// +// Namely, if applied to the round N. The action will be responsible for +// assigning the M column for tables compiled on round N and the SigmaS/SigmaT +// and their respective LocalOpening for the tables compiled at round N-1. +// +// All these actions are performed in parallel. +type proverTaskAtRound struct { + + // ZAssignmentTasks lists all the tasks consisting of assigning the + // columns SigmaS and SigmaT for the given round. + ZAssignmentTasks []zAssignmentTask +} + +// Run implements the [wizard.ProverAction interface]. The tasks will spawn +// a goroutine for each tasks and wait for all of them to finish. The approach +// for parallelization can be justified if the number of go-routines stays low +// (e.g. less than 1000s). +func (p proverTaskAtRound) Run(run *wizard.ProverRuntime) { + + wg := &sync.WaitGroup{} + wg.Add(p.numTasks()) + + var ( + panicTrace []byte + panicMsg any + panicOnce = &sync.Once{} + ) + + for i := range p.ZAssignmentTasks { + // the passing of the index `i` is there to ensure that the go-routine + // is running over a local copy of `i` which is not incremented every + // time the loop goes to the next iteration. + go func(i int) { + + // In case the subtask panics, we recover so that we can repanic in + // the main goroutine. Simplifying the process of tracing back the + // error and allowing to test the panics. + defer func() { + if r := recover(); r != nil { + panicOnce.Do(func() { + panicMsg = r + panicTrace = debug.Stack() + }) + } + + wg.Done() + }() + + p.ZAssignmentTasks[i].run(run) + }(i) + } + + wg.Wait() + + if len(panicTrace) > 0 { + utils.Panic("Had a panic: %v\nStack: %v\n", panicMsg, string(panicTrace)) + } +} + +// pushZAssignment appends an [sigmaAssignmentTask] to the list of tasks +func (p *proverTaskAtRound) pushZAssignment(s zAssignmentTask) { + p.ZAssignmentTasks = append(p.ZAssignmentTasks, s) +} + +// numTasks returns the total number of tasks that are scheduled in the +// [proverTaskAtRound]. +func (p *proverTaskAtRound) numTasks() int { + return len(p.ZAssignmentTasks) +} + +// zAssignmentTask represents a prover task of assignming the columns +// SigmaS and SigmaT for a specific lookup table. +// sigmaAssignment +type zAssignmentTask lookUp.ZCtx + +func (z zAssignmentTask) run(run *wizard.ProverRuntime) { + parallel.Execute(len(z.ZDenominatorBoarded), func(start, stop int) { + for frag := start; frag < stop; frag++ { + + var ( + numeratorMetadata = z.ZNumeratorBoarded[frag].ListVariableMetadata() + denominator = wizardutils.EvalExprColumn(run, z.ZDenominatorBoarded[frag]).IntoRegVecSaveAlloc() + numerator []field.Element + packedZ = field.BatchInvert(denominator) + ) + + if len(numeratorMetadata) == 0 { + numerator = vector.Repeat(field.One(), z.Size) + } + + if len(numeratorMetadata) > 0 { + numerator = wizardutils.EvalExprColumn(run, z.ZNumeratorBoarded[frag]).IntoRegVecSaveAlloc() + } + + for k := range packedZ { + packedZ[k].Mul(&numerator[k], &packedZ[k]) + if k > 0 { + packedZ[k].Add(&packedZ[k], &packedZ[k-1]) + } + } + + run.AssignColumn(z.Zs[frag].GetColID(), sv.NewRegular(packedZ)) + run.AssignLocalPoint(z.ZOpenings[frag].ID, packedZ[len(packedZ)-1]) + } + }) +} From 4ab05032d0e7486016c094ca2fc51e1cd6918f60 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Wed, 11 Dec 2024 16:43:40 +0100 Subject: [PATCH 02/12] added logderivative_sum query --- .../compiler/inclusion/inclusion.go | 61 +++++- .../compiler/inclusion/log_deriv_sum.go | 204 ------------------ .../inclusion/log_derivative_sum_verif.go | 59 ----- prover/protocol/query/logderiv_sum.go | 66 ++++++ prover/protocol/wizard/compiled.go | 11 + 5 files changed, 134 insertions(+), 267 deletions(-) delete mode 100644 prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go delete mode 100644 prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go create mode 100644 prover/protocol/query/logderiv_sum.go diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index 6d18aea3e..d0ff89911 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -4,6 +4,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/coin" lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" "github.com/consensys/linea-monorepo/prover/symbolic" @@ -25,7 +26,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // zCatalog stores a mapping (round, size) into ZCtx and helps finding // which Z context should be used to handle a part of a given permutation // query. - zCatalog = map[[2]int]*lookUp.ZCtx{} + zCatalog = map[[2]int]*query.LogDerivativeSumInput{} ) // Skip the compilation phase if no lookup constraint is being used. Otherwise @@ -45,7 +46,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { ) // push to zCatalog - tableCtx.PushToZCatalog(zCatalog) + PushToZCatalog(tableCtx, zCatalog) } // Handle cases where only T part is present @@ -61,10 +62,11 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { ) // push to zCatalog - tableCtx.PushToZCatalog(zCatalog) + PushToZCatalog(tableCtx, zCatalog) } - InsertLogDerivativeSum(comp, lastRound, zCatalog) + // insert a LogDerivativeSum for all the Sigma Columns . + comp.InsertLogDerivativeSum(lastRound, "LogDerivativeSum", zCatalog) } // findLookupTableByName searches for a lookup table by its name in the list of lookup tables. @@ -159,3 +161,54 @@ func compileLookupTable( return ctx } + +// PushToZCatalog constructs the numerators and denominators for S and T of the +// stc into zCatalog for their corresponding rounds and size. +func PushToZCatalog(stc lookUp.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) { + + var ( + round = stc.Gamma.Round + ) + + // tableCtx push to -> zCtx + // Process the T columns + for frag := range stc.T { + size := stc.M[frag].Size() + + key := [2]int{round, size} + if zCatalog[key] == nil { + zCatalog[key] = &query.LogDerivativeSumInput{ + Size: size, + Round: round, + } + } + + zCtxEntry := zCatalog[key] + zCtxEntry.Numerator = append(zCtxEntry.Numerator, symbolic.Neg(stc.M[frag])) // no functions for num, denom here + zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.T[frag])) + } + + // Process the S columns + for table := range stc.S { + var ( + _, _, size = wizardutils.AsExpr(stc.S[table]) + sFilter = symbolic.NewConstant(1) + ) + + if stc.SFilters[table] != nil { + sFilter = symbolic.NewVariable(stc.SFilters[table]) + } + + key := [2]int{round, size} + if zCatalog[key] == nil { + zCatalog[key] = &query.LogDerivativeSumInput{ + Size: size, + Round: round, + } + } + + zCtxEntry := zCatalog[key] + zCtxEntry.Numerator = append(zCtxEntry.Numerator, sFilter) + zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.S[table])) + } +} diff --git a/prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go b/prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go deleted file mode 100644 index 094c9a55e..000000000 --- a/prover/protocol/distributed/compiler/inclusion/log_deriv_sum.go +++ /dev/null @@ -1,204 +0,0 @@ -package inclusion - -import ( - "slices" - - "github.com/consensys/linea-monorepo/prover/protocol/column" - lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/query" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" - sym "github.com/consensys/linea-monorepo/prover/symbolic" - "github.com/consensys/linea-monorepo/prover/utils" -) - -const ( - packingArity = 3 -) - -// zCtx is an internal compilation structure responsible for grouping the -// "Sigma" columns together so that we can trade-off commitment complexity with -// global constraint degree. -// -// All the grouped "Sigma" columns are of the same size and same round for the -// same context. -// M is computing the appearance of the rows of T in the rows of S. - -// M: tableCtx.M, -// S: checkTable, -// T: lookupTable, -// SFilter: includedFilters, - -type logDerivativeSumCtx struct { - zCtx *lookUp.ZCtx - // the global sum in LogDerivative - PI ifaces.Column -} - -// check permutation and see how/where compile is called (see how to constracut z there) -// when constructing z, check if z is T or S -// and change T -> -M, S -> +Filter -// S or T -> ({S,T} + X) -// compile should be called inside CompileGrandSum -func (lgSum *logDerivativeSumCtx) compileLogDerivativeSum(comp *wizard.CompiledIOP) { - - var ( - z = lgSum.zCtx - numZs = utils.DivCeil( - len(z.SigmaDenominator), - packingArity, - ) - ) - z.Zs = make([]ifaces.Column, numZs) - z.ZOpenings = make([]query.LocalOpening, numZs) - z.ZNumeratorBoarded = make([]sym.ExpressionBoard, numZs) - z.ZDenominatorBoarded = make([]sym.ExpressionBoard, numZs) - - lgSum.PI = comp.InsertColumn( - z.Round, - lookUp.DeriveName[ifaces.ColID]("PI", comp.SelfRecursionCount, z.Round, z.Size), - 1, - column.PublicInput, - ) - - for i := range z.Zs { - - var ( - packedNum = safeAnySubSlice(z.SigmaNumerator, i*packingArity, (i+1)*packingArity) - packedDen = safeAnySubSlice(z.SigmaDenominator, i*packingArity, (i+1)*packingArity) - - zNumerator = sym.NewConstant(0) - zDenominator = sym.Mul(packedDen...) - ) - - for j := range packedNum { - term := packedNum[j] - for k := range packedDen { - if k != j { - term = sym.Mul(term, packedDen[k]) - } - } - zNumerator = sym.Add(zNumerator, term) - } - - z.ZNumeratorBoarded[i] = zNumerator.Board() - z.ZDenominatorBoarded[i] = zDenominator.Board() - - z.Zs[i] = comp.InsertCommit( - z.Round, - lookUp.DeriveName[ifaces.ColID]("Z", comp.SelfRecursionCount, z.Round, z.Size, i), - z.Size, - ) - - // initial condition - comp.InsertLocal( - z.Round, - lookUp.DeriveName[ifaces.QueryID]("Z_CONSISTENCY_START", comp.SelfRecursionCount, z.Round, z.Size, i), - sym.Sub( - zNumerator, - sym.Mul( - z.Zs[i], - zDenominator, - ), - ), - ) - - // consistency check - comp.InsertGlobal( - z.Round, - lookUp.DeriveName[ifaces.QueryID]("Z_CONSISTENCY", comp.SelfRecursionCount, z.Round, z.Size, i), - sym.Sub( - zNumerator, - sym.Mul( - sym.Sub(z.Zs[i], column.Shift(z.Zs[i], -1)), - zDenominator, - ), - ), - ) - - // local opening of the final value of the Z polynomial - z.ZOpenings[i] = comp.InsertLocalOpening( - z.Round, - lookUp.DeriveName[ifaces.QueryID]("Z_FINAL", comp.SelfRecursionCount, z.Round, z.Size, i), - column.Shift(z.Zs[i], -1), - ) - } -} - -// attempt to take the subslice of a slice, and truncates or returns an empty -// slice if the parameters are out of bounds. -func safeAnySubSlice[T any](t []T, start, stop int) []any { - - if stop < start { - panic("invalid argument") - } - - var tmp []T - - switch { - case start >= len(t): - return []any{} - case stop >= len(t): - tmp = t[start:] - default: - tmp = t[start:stop] - } - - res := make([]any, len(tmp)) - for i := range res { - res[i] = tmp[i] - } - - return res -} - -// IntoLogDerivativeSum handles the remaining process after PushToZCatalog. -func InsertLogDerivativeSum(comp *wizard.CompiledIOP, lastRound int, zCatalog map[[2]int]*lookUp.ZCtx) { - var ( - zEntries = [][2]int{} - va = finalEvaluationCheck{} - ) - - // This loop is necessary to build a sorted list of the entries of zCatalog. - // Without it, if we tried to loop over zCatalog directly, the entries would - // be processed in a non-deterministic order. The sorting order itself is - // without importance, what matters is that zEntries is in deterministic - // order. - for entry := range zCatalog { - zEntries = append(zEntries, entry) - } - - slices.SortFunc(zEntries, func(a, b [2]int) int { - switch { - case a[0] < b[0]: - return -1 - case a[0] > b[0]: - return 1 - case a[1] < b[1]: - return -1 - case a[1] > b[1]: - return 1 - default: - return 0 - } - }) - - // compile zCatalog - for _, entry := range zEntries { - zC := zCatalog[entry] - logDerivSumCtx := - logDerivativeSumCtx{ - zCtx: zC, - } - // z-packing compile - logDerivSumCtx.compileLogDerivativeSum(comp) - // entry[0]:round, entry[1]: size - // the round that Gamma was registered. - - // pushZAssignment(zAssignmentTask(*zC)) - va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...) - va.Name = zC.Name - } - - comp.RegisterVerifierAction(lastRound, &va) -} diff --git a/prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go b/prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go deleted file mode 100644 index cccb5a3ac..000000000 --- a/prover/protocol/distributed/compiler/inclusion/log_derivative_sum_verif.go +++ /dev/null @@ -1,59 +0,0 @@ -package inclusion - -import ( - "fmt" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/linea-monorepo/prover/maths/field" - "github.com/consensys/linea-monorepo/prover/protocol/query" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" -) - -// finalEvaluationCheck implements the [wizard.VerifierAction] interface. It -// represents the consistency check between SigmaT and the SigmaS[i]. This -// corresponds in the check #1 in the doc, where the verifier ensures that -// summing the ending values of the SigmaS[i]s yields the same result as the -// ending value of SigmaT. -// -// The current implementation is for packed Zs -type finalEvaluationCheck struct { - // the name of a lookupTable in the pack, this can help for debugging. - Name string - // ZOpenings lists all the openings of all the zCtx - ZOpenings []query.LocalOpening - // the global sum in LogDerivative - PI query.LocalOpening -} - -// Run implements the [wizard.VerifierAction] -func (f *finalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { - - // zSum stores the sum of the ending values of the zs as queried - // in the protocol via the local opening queries. - zSum := field.Zero() - for k := range f.ZOpenings { - temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y - zSum.Add(&zSum, &temp) - } - - PI := run.GetLocalPointEvalParams(f.PI.ID).Y - if zSum != PI { - return fmt.Errorf("log-derivate lookup, the final evaluation check failed for %v,", f.Name) - } - - return nil -} - -// RunGnark implements the [wizard.VerifierAction] -func (f *finalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { - - // SigmaSKSum stores the sum of the ending values of the SigmaSs as queried - // in the protocol via the - zSum := frontend.Variable(field.Zero()) - for k := range f.ZOpenings { - temp := run.GetLocalPointEvalParams(f.ZOpenings[k].ID).Y - zSum = api.Add(zSum, temp) - } - - api.AssertIsEqual(zSum, 0) -} diff --git a/prover/protocol/query/logderiv_sum.go b/prover/protocol/query/logderiv_sum.go new file mode 100644 index 000000000..bbdacbb2d --- /dev/null +++ b/prover/protocol/query/logderiv_sum.go @@ -0,0 +1,66 @@ +package query + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + sym "github.com/consensys/linea-monorepo/prover/symbolic" +) + +// LogDerivativeSumInput stores the input to the query +type LogDerivativeSumInput struct { + Round, Size int + Numerator []*sym.Expression // T -> -M, S -> +Filter + Denominator []*sym.Expression // S or T -> ({S,T} + X) +} + +// LogDerivativeSum is the context of LogDerivativeSum. +// The fields are maps from [round, size]. +type LogDerivativeSum struct { + Inputs map[[2]int]*LogDerivativeSumInput + + ZNumeratorBoarded, ZDenominatorBoarded map[[2]int][]sym.ExpressionBoard + + Zs map[[2]int][]ifaces.Column + // ZOpenings are the opening queries to the end of each Z. + ZOpenings map[[2]int][]LocalOpening + + ID ifaces.QueryID +} + +// the result of the global Sum +type LogDerivSumParams struct { + Sum field.Element // the sum of all the ZOpenings from different [round,size]. +} + +// NewLogDerivativeSum creates the new context LogDerivativeSum. +func NewLogDerivativeSum(inp map[[2]int]*LogDerivativeSumInput) LogDerivativeSum { + + // add some sanity checks here + + return LogDerivativeSum{ + Inputs: inp, + } + +} + +// Name implements the [ifaces.Query] interface +func (r LogDerivativeSum) Name() ifaces.QueryID { + return r.ID +} + +// Constructor for the query parameters/result +func NewLogDeriveSumParams(sum field.Element) LogDerivSumParams { + return LogDerivSumParams{Sum: sum} +} + +// Test that global sum is correct +func (r LogDerivativeSum) Check(run ifaces.Runtime) error { + + return nil +} + +// Test that global sum is correct +func (r LogDerivativeSum) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) { + +} diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index 6cf5c37a3..f0b0efd87 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -579,6 +579,17 @@ func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifa return q } +// InsertLogDerivativeSum registers a new LogDerivativeSum query [query.LogDerivativeSum]. +// It generates a single global summation for many Sigma Columns from Lookup compilation. +// The sigma columns are categorized by [round,size]. +func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[[2]int]*query.LogDerivativeSumInput) query.LogDerivativeSum { + c.assertConsistentRound(lastRound) + q := query.NewLogDerivativeSum(in) + // Finally registers the query + c.QueriesParams.AddToRound(lastRound, id, q) + return q +} + // assertConsistentRound compares the round passed as an argument and panic if it greater than // coin.Round. This helps ensuring that we do not have "useless" rounds. func (c *CompiledIOP) assertConsistentRound(round int) { From d39fb716bbcb51e21d4c384fae636d64d284cb3b Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Wed, 11 Dec 2024 17:34:49 +0100 Subject: [PATCH 03/12] added documentation --- .../compiler/inclusion/inclusion.go | 55 ++++++------------- .../compiler/inclusion/lookup_discoverer.go | 36 +++++++++--- .../compiler/inclusion/preparation.go | 6 +- 3 files changed, 48 insertions(+), 49 deletions(-) diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index d0ff89911..f21737c94 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -12,20 +12,22 @@ import ( // CompileLogDerivative scans `comp`, looking for Inclusion queries and compiles // them using the LogDerivativeLookup technique. The compiler attempts to group -// queries relating to the same table. This allows saving in commitment because -// when grouping is possible, then we only need to commit to a single -// extract the table and the checked from the lookup query and ensures that the -// table are in canonical order. That is because we want to group lookups into -// the same columns but not in the same order. -func CompileLogDerivative(comp *wizard.CompiledIOP) { +// queries relating to the same table (as such groups needs the same randomness). +// +// The input is a wizard.CompiledIOP object relevant to the module. +// It contains a list of the columns relevant to the module (inside its Columns field). +// +// For each T column inside the module, it also contains the M column. +// +// Note that for a lookup query the module may contain only the S or T columns (and not both). +func CompileLogDerivative(moduleComp *wizard.CompiledIOP) { var ( - mainLookupCtx = captureModuleLookupTables(comp) - lastRound = comp.NumRounds() - 1 + mainLookupCtx = captureModuleLookupTables(moduleComp) + lastRound = moduleComp.NumRounds() - 1 - // zCatalog stores a mapping (round, size) into ZCtx and helps finding - // which Z context should be used to handle a part of a given permutation - // query. + // zCatalog stores a mapping (round, size) into [query.LogDerivativeSumInput]. + // it packs the sigma columns from the same (round,size) together. zCatalog = map[[2]int]*query.LogDerivativeSumInput{} ) @@ -42,7 +44,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { var ( round = mainLookupCtx.Rounds[tableName] includedFilters = mainLookupCtx.IncludedFilters[tableName] - tableCtx = compileLookupTable(comp, round, lookupTable, checkedTables, includedFilters) + tableCtx = collapsMultiColsToSingleCol(moduleComp, round, lookupTable, checkedTables, includedFilters) ) // push to zCatalog @@ -58,7 +60,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { var ( round = mainLookupCtx.Rounds[tableName] - tableCtx = compileLookupTable(comp, round, lookupTable, nil, nil) + tableCtx = collapsMultiColsToSingleCol(moduleComp, round, lookupTable, nil, nil) ) // push to zCatalog @@ -66,7 +68,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { } // insert a LogDerivativeSum for all the Sigma Columns . - comp.InsertLogDerivativeSum(lastRound, "LogDerivativeSum", zCatalog) + moduleComp.InsertLogDerivativeSum(lastRound, "LogDerivativeSum", zCatalog) } // findLookupTableByName searches for a lookup table by its name in the list of lookup tables. @@ -79,19 +81,8 @@ func findLookupTableByName(lookupTables [][]table, name string) []table { return nil } -// compileLookupTable applies the log-derivative lookup compilation context to -// the supplied table. round denotes the interaction round in which to start the -// compilation. -// -// It registers the following queries -// - (1) The verifier queries that $\sum_{k=0\ldots n-1} (\Sigma_{S,k})[|S_k| - 1] == (\Sigma_T)[|T| - 1]$. Namely, the sum of the last entry of all $\Sigma_{S,k}$ equals the last entry of $\Sigma_T$ -// - (2) **(For all k)** the verifier makes a `Local` query : $(\Sigma_{S,k})[0] = \frac{1}{S_{k,0} + \gamma}$ -// - (3) The verifier makes a `Local` query : $(\Sigma_T)[0] = \frac{M_0}{T_0 + \gamma}$ -// - (4) **(For all k)** The verifier makes a `Global` query : $\left((\Sigma_{S,k})[i] - (\Sigma_{S,k})[i-1]\right)(S_{k,i} + \gamma) = 1$ -// - (5) The verifier makes a `Global` query : $\left((\Sigma_T)[i] - (\Sigma_T)[i-1]\right)(T_i + \gamma) = M_i$ - -// here we are looking up set of columns S in a single column T -func compileLookupTable( +// It collapses the tables of MultiColumns to single columns. It also sample the Gamma coin for the rest of the compilation. +func collapsMultiColsToSingleCol( comp *wizard.CompiledIOP, round int, lookupTable []table, @@ -116,11 +107,6 @@ func compileLookupTable( if !isMultiColumn { for frag := range ctx.T { ctx.T[frag] = symbolic.NewVariable(lookupTable[frag][0]) - ctx.M[frag] = comp.InsertCommit( - round, - lookUp.DeriveTableNameWithIndex[ifaces.ColID](lookUp.LogDerivativePrefix, lookupTable, frag, "M"), - lookupTable[frag][0].Size(), - ) } @@ -141,11 +127,6 @@ func compileLookupTable( for frag := range ctx.T { ctx.T[frag] = wizardutils.RandLinCombColSymbolic(alpha, lookupTable[frag]) - ctx.M[frag] = comp.InsertCommit( - round, - lookUp.DeriveTableNameWithIndex[ifaces.ColID](lookUp.LogDerivativePrefix, lookupTable, frag, "M"), - lookupTable[frag][0].Size(), - ) } for i := range ctx.S { diff --git a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go index 4859d8a40..2ccdba64c 100644 --- a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go +++ b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go @@ -14,11 +14,13 @@ import ( // compiler as a shorthand to make the code more eye-parseable. type table = []ifaces.Column -// captureModuleLookupTables inspects comp and looks for Inclusion queries that are relevant to the module. +// captureModuleLookupTables inspects comp and looks for Inclusion queries. // It groups the matched queries by lookup table and marks them as ignored. +// It creates a list of lookupTables and checkedTables present in the module. +// // The input is a compiledIOP object that stores the columns relevant to the module (in its Column field) // Note that for a lookup query the module may contain only S or T table (and not necessarily both). -func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { +func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) lookUp.MainLookupCtx { var ( ctx = lookUp.MainLookupCtx{ LookupTables: [][]table{}, @@ -29,16 +31,16 @@ func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { ) // Collect all the lookup queries into "lookups" - for _, qName := range comp.QueriesNoParams.AllUnignoredKeys() { + for _, qName := range moduleComp.QueriesNoParams.AllUnignoredKeys() { // Filter out non-lookup queries - lookup, ok := comp.QueriesNoParams.Data(qName).(query.Inclusion) + lookup, ok := moduleComp.QueriesNoParams.Data(qName).(query.Inclusion) if !ok { continue } // Determine if the query is relevant to the module - relevantPart := determineRelevantPart(lookup, comp.Columns) + relevantPart := determineRelevantPart(lookup, moduleComp.Columns) if relevantPart == "" { continue } @@ -46,7 +48,7 @@ func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { // This ensures that the lookup query is not used again in the // compilation process. We know that the query was already ignored at // the beginning because we are iterating over the unignored keys. - comp.QueriesNoParams.MarkAsIgnored(qName) + moduleComp.QueriesNoParams.MarkAsIgnored(qName) var ( // checkedTable corresponds to the "included" table and lookupTable @@ -87,6 +89,11 @@ func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { } // Add only the relevant part to the context + if relevantPart == "Both" { + ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) + ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) + ctx.LookupTables = append(ctx.LookupTables, lookupTable) + } if relevantPart == "S" { ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) @@ -94,7 +101,7 @@ func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { ctx.LookupTables = append(ctx.LookupTables, lookupTable) } - ctx.Rounds[tableName] = max(ctx.Rounds[tableName], comp.QueriesNoParams.Round(lookup.ID)) + ctx.Rounds[tableName] = max(ctx.Rounds[tableName], moduleComp.QueriesNoParams.Round(lookup.ID)) } @@ -103,15 +110,26 @@ func captureModuleLookupTables(comp *wizard.CompiledIOP) lookUp.MainLookupCtx { // determineRelevantPart checks if the lookup query involves columns from the module and returns the relevant part (S or T). func determineRelevantPart(lookup query.Inclusion, moduleColumns column.Store) string { + var hasT, hasS bool // Check if any column in S part is in the module if moduleColumns.Exists(lookup.Included[0].GetColID()) { - return "S" + hasS = true } // Check if any column in T part is in the module if moduleColumns.Exists(lookup.Including[0][0].GetColID()) { + hasT = true + } + if hasT && hasS { + return "Both" + } + if hasT && !hasS { return "T" } + if hasS && !hasT { + return "S" + } else { + panic("the module contains a query that is not relevant to any column") + } - return "" } diff --git a/prover/protocol/distributed/compiler/inclusion/preparation.go b/prover/protocol/distributed/compiler/inclusion/preparation.go index 0fb53adf8..3b4bdfff6 100644 --- a/prover/protocol/distributed/compiler/inclusion/preparation.go +++ b/prover/protocol/distributed/compiler/inclusion/preparation.go @@ -2,8 +2,8 @@ package inclusion import "github.com/consensys/linea-monorepo/prover/protocol/wizard" -// It scans the compiledIOP and group all the checkedTables related to the same lookupTable, +// It scans the initial compiledIOP and group all the checkedTables related to the same lookupTable, // It creates a frequency column M for all such pairs (checkedTables, lookupTable). -func prepareLookup(comp *wizard.CompiledIOP) { - +func prepareLookup(initalComp *wizard.CompiledIOP) { + panic("unimplemented") } From d60cc9f3323fc0d9b79c4a53739da04479904291 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Thu, 12 Dec 2024 10:57:31 +0100 Subject: [PATCH 04/12] getting the multiplicity table from compild iop --- prover/protocol/compiler/lookup/compiler.go | 4 +- prover/protocol/compiler/lookup/context.go | 7 +--- .../distributed/compiler/inclusion/context.go | 41 +++++++++++++++++++ .../compiler/inclusion/inclusion.go | 26 ++++++++---- .../compiler/inclusion/lookup_discoverer.go | 30 ++++++++++---- 5 files changed, 83 insertions(+), 25 deletions(-) create mode 100644 prover/protocol/distributed/compiler/inclusion/context.go diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index 3bc9c4366..56cb6b343 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -131,9 +131,9 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // The function also implictly reduces the conditionals over the Including table // be appending a "one" column on the included side and the filter on the // including side. -func captureLookupTables(comp *wizard.CompiledIOP) MainLookupCtx { +func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { - ctx := MainLookupCtx{ + ctx := mainLookupCtx{ LookupTables: [][]table{}, CheckedTables: map[string][]table{}, IncludedFilters: map[string][]ifaces.Column{}, diff --git a/prover/protocol/compiler/lookup/context.go b/prover/protocol/compiler/lookup/context.go index cfeb5d7c2..3f30808ad 100644 --- a/prover/protocol/compiler/lookup/context.go +++ b/prover/protocol/compiler/lookup/context.go @@ -6,9 +6,9 @@ import ( "github.com/consensys/linea-monorepo/prover/symbolic" ) -// MainLookupCtx stores the compilation context of all the lookup queries +// mainLookupCtx stores the compilation context of all the lookup queries // altogether. -type MainLookupCtx struct { +type mainLookupCtx struct { // LookupTables stores all the lookup table the compiler encounters. They are // sorted in canonical order. This used to derive a determistic ordering @@ -36,9 +36,6 @@ type MainLookupCtx struct { // round is obtained by taking the max of the declaration Rounds of the // Inclusion queries using the corresponding lookup table. Rounds map[string]int - - // the name of the tables - AllTableNames []string } // SingleTableCtx stores the compilation context for a single lookup query diff --git a/prover/protocol/distributed/compiler/inclusion/context.go b/prover/protocol/distributed/compiler/inclusion/context.go new file mode 100644 index 000000000..ceb39a79a --- /dev/null +++ b/prover/protocol/distributed/compiler/inclusion/context.go @@ -0,0 +1,41 @@ +package inclusion + +import "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + +// MainLookupCtx stores the compilation context of all the lookup queries +// altogether. +type mainLookupCtx struct { + + // LookupTables stores all the lookup table the compiler encounters. They are + // sorted in canonical order. This used to derive a determistic ordering + // of the lookup LookupTables. (We want to ensure the compiler yields always + // exactly the same result for replicability). + // + // To illustrates its structure, the following sub-statement + // + // table[numTable][frag] + // + // refers to to the fragment #frag of the the table #numTable. + LookupTables [][]table + + // CheckedTables stores all the checked column by lookup table. The key is + // obtained as nameTable(lookupTable) where lookup is sorted in + // canonical order. + CheckedTables map[string][]table + + // IncludedFilters stores all the filters for the checked columns and `nil` + // if no filter is applied. As for [checkedTables] they are stored by + // lookup table name and in the same order for each key. + IncludedFilters map[string][]ifaces.Column + + // Rounds stores the interaction round assigned to each lookupTable. The + // round is obtained by taking the max of the declaration Rounds of the + // Inclusion queries using the corresponding lookup table. + Rounds map[string]int + + // it stores the multiplicity of T is S. + // For the multiColum case it collapse T and S and then counts the multiplicity. + // note that here the collapsing does not need the same randomness as the compilation. + // since the multiplicity is the same w.r.t any randomness. + mTables map[string]table +} diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index f21737c94..513f926f9 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -23,10 +23,11 @@ import ( func CompileLogDerivative(moduleComp *wizard.CompiledIOP) { var ( + // capture the S and T columns relevant to the module. mainLookupCtx = captureModuleLookupTables(moduleComp) lastRound = moduleComp.NumRounds() - 1 - // zCatalog stores a mapping (round, size) into [query.LogDerivativeSumInput]. + // zCatalog stores a mapping from (round, size) into [query.LogDerivativeSumInput]. // it packs the sigma columns from the same (round,size) together. zCatalog = map[[2]int]*query.LogDerivativeSumInput{} ) @@ -44,14 +45,16 @@ func CompileLogDerivative(moduleComp *wizard.CompiledIOP) { var ( round = mainLookupCtx.Rounds[tableName] includedFilters = mainLookupCtx.IncludedFilters[tableName] - tableCtx = collapsMultiColsToSingleCol(moduleComp, round, lookupTable, checkedTables, includedFilters) + mTable = mainLookupCtx.mTables[tableName] + // it collapses multiColumn tables to single columns. + tableCtx = collapsMultiColToSingleCol(moduleComp, round, lookupTable, checkedTables, includedFilters, mTable) ) // push to zCatalog PushToZCatalog(tableCtx, zCatalog) } - // Handle cases where only T part is present + // Handle cases where only T part is present in the module for _, lookupTable := range mainLookupCtx.LookupTables { tableName := lookUp.NameTable(lookupTable) if _, ok := mainLookupCtx.CheckedTables[tableName]; ok { @@ -60,14 +63,15 @@ func CompileLogDerivative(moduleComp *wizard.CompiledIOP) { var ( round = mainLookupCtx.Rounds[tableName] - tableCtx = collapsMultiColsToSingleCol(moduleComp, round, lookupTable, nil, nil) + mTable = mainLookupCtx.mTables[tableName] + tableCtx = collapsMultiColToSingleCol(moduleComp, round, lookupTable, nil, nil, mTable) ) // push to zCatalog PushToZCatalog(tableCtx, zCatalog) } - // insert a LogDerivativeSum for all the Sigma Columns . + // insert a LogDerivativeSum for all the Sigma Columns in the module. moduleComp.InsertLogDerivativeSum(lastRound, "LogDerivativeSum", zCatalog) } @@ -81,13 +85,15 @@ func findLookupTableByName(lookupTables [][]table, name string) []table { return nil } -// It collapses the tables of MultiColumns to single columns. It also sample the Gamma coin for the rest of the compilation. -func collapsMultiColsToSingleCol( +// It collapses the tables of MultiColumns to single columns. +// It also sample the Gamma coin for the rest of the compilation. +func collapsMultiColToSingleCol( comp *wizard.CompiledIOP, round int, lookupTable []table, checkedTables []table, includedFilters []ifaces.Column, + mTable table, ) (ctx lookUp.SingleTableCtx) { ctx = lookUp.SingleTableCtx{ @@ -134,6 +140,8 @@ func collapsMultiColsToSingleCol( } } + ctx.M = mTable + ctx.Gamma = comp.InsertCoin( round+1, lookUp.DeriveTableName[coin.Name](lookUp.LogDerivativePrefix, lookupTable, "GAMMA"), @@ -143,8 +151,8 @@ func collapsMultiColsToSingleCol( return ctx } -// PushToZCatalog constructs the numerators and denominators for S and T of the -// stc into zCatalog for their corresponding rounds and size. +// PushToZCatalog constructs the numerators and denominators for the collapsed S and T +// into zCatalog, for their corresponding rounds and size. func PushToZCatalog(stc lookUp.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) { var ( diff --git a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go index 2ccdba64c..52114278e 100644 --- a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go +++ b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go @@ -14,16 +14,17 @@ import ( // compiler as a shorthand to make the code more eye-parseable. type table = []ifaces.Column -// captureModuleLookupTables inspects comp and looks for Inclusion queries. +// captureModuleLookupTables inspects moduleComp and looks for Inclusion queries. // It groups the matched queries by lookup table and marks them as ignored. -// It creates a list of lookupTables and checkedTables present in the module. +// It creates a list of lookupTables, mTables and checkedTables present in the module. // // The input is a compiledIOP object that stores the columns relevant to the module (in its Column field) // Note that for a lookup query the module may contain only S or T table (and not necessarily both). -func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) lookUp.MainLookupCtx { +func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) mainLookupCtx { var ( - ctx = lookUp.MainLookupCtx{ + ctx = mainLookupCtx{ LookupTables: [][]table{}, + mTables: map[string]table{}, CheckedTables: map[string][]table{}, IncludedFilters: map[string][]ifaces.Column{}, Rounds: map[string]int{}, @@ -39,11 +40,8 @@ func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) lookUp.MainLookup continue } - // Determine if the query is relevant to the module + // Determine the part of the query that is relevant to the module relevantPart := determineRelevantPart(lookup, moduleComp.Columns) - if relevantPart == "" { - continue - } // This ensures that the lookup query is not used again in the // compilation process. We know that the query was already ignored at @@ -85,6 +83,7 @@ func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) lookUp.MainLookup ctx.IncludedFilters[tableName] = []ifaces.Column{} ctx.CheckedTables[tableName] = []table{} ctx.LookupTables = [][]table{} + ctx.mTables[tableName] = table{} ctx.Rounds[tableName] = 0 } @@ -93,12 +92,17 @@ func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) lookUp.MainLookup ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) ctx.LookupTables = append(ctx.LookupTables, lookupTable) + // get the M table from the moduleComp and add it to the mTables. + ctx.mTables[tableName] = mTable(moduleComp, lookupTable) + } if relevantPart == "S" { ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) } else if relevantPart == "T" { ctx.LookupTables = append(ctx.LookupTables, lookupTable) + // get the M table from the moduleComp and add it to the mTables. + ctx.mTables[tableName] = mTable(moduleComp, lookupTable) } ctx.Rounds[tableName] = max(ctx.Rounds[tableName], moduleComp.QueriesNoParams.Round(lookup.ID)) @@ -129,7 +133,15 @@ func determineRelevantPart(lookup query.Inclusion, moduleColumns column.Store) s if hasS && !hasT { return "S" } else { - panic("the module contains a query that is not relevant to any column") + panic("the module contains a query that is not relevant to any module-column") } } + +func mTable(comp *wizard.CompiledIOP, t [][]ifaces.Column) (m []ifaces.Column) { + for frag := range t[0] { + id := ifaces.ColIDf("%v_%v_%v", lookUp.NameTable(t), "M", frag) + m = append(m, comp.Columns.GetHandle(id)) + } + return m +} From dceec4fbf22946a1ac973907af869d23dc90e89e Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Thu, 12 Dec 2024 15:31:51 +0100 Subject: [PATCH 05/12] added preparation for lookup --- prover/protocol/compiler/lookup/compiler.go | 6 +-- prover/protocol/distributed/distributed.go | 14 ------- prover/protocol/distributed/preparation.go | 43 +++++++++++++++++++++ 3 files changed, 46 insertions(+), 17 deletions(-) create mode 100644 prover/protocol/distributed/preparation.go diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index 56cb6b343..40dcb3735 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -27,7 +27,7 @@ type table = []ifaces.Column func CompileLogDerivative(comp *wizard.CompiledIOP) { var ( - mainLookupCtx = captureLookupTables(comp) + mainLookupCtx = CaptureLookupTables(comp) lastRound = comp.NumRounds() - 1 proverActions = make([]proverTaskAtRound, comp.NumRounds()+1) // zCatalog stores a mapping (round, size) into ZCtx and helps finding @@ -119,7 +119,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { comp.RegisterVerifierAction(lastRound, &va) } -// captureLookupTables inspects comp and look for Inclusion queries that are not +// CaptureLookupTables inspects comp and look for Inclusion queries that are not // marked as ignored yet. All the queries matched queries are grouped by look-up // table (e.g. all the queries that use the same lookup table). All the matched // queries are marked as ignored. The function returns the thereby-initialized @@ -131,7 +131,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // The function also implictly reduces the conditionals over the Including table // be appending a "one" column on the included side and the filter on the // including side. -func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { +func CaptureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { ctx := mainLookupCtx{ LookupTables: [][]table{}, diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index dbe01f6f2..58b3ce257 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -1,9 +1,6 @@ package distributed import ( - "github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct" - "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" - "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -118,14 +115,3 @@ func addToGlobalLocal(comp *wizard.CompiledIOP, q ifaces.Query) { func aggregator(distModules []DistributedModule, maxNumSegments int) *wizard.CompiledIOP { panic("unimplemented") } - -// prepare reduces any query to LPP or GL. -// it prepares the columns that depends on whole the witness,e.g., M column for lookups. -func prepare(comp *wizard.CompiledIOP) { - mimc.CompileMiMC(comp) - specialqueries.RangeProof(comp) - specialqueries.CompileFixedPermutations(comp) - innerproduct.Compile(comp) - - // prepareLookup(comp) -} diff --git a/prover/protocol/distributed/preparation.go b/prover/protocol/distributed/preparation.go new file mode 100644 index 000000000..15510c6f4 --- /dev/null +++ b/prover/protocol/distributed/preparation.go @@ -0,0 +1,43 @@ +package distributed + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +// prepare reduces any query to LPP or GL. +// it prepares the columns that depends on whole the witness,e.g., M column for lookups. +func prepare(comp *wizard.CompiledIOP) { + mimc.CompileMiMC(comp) + specialqueries.RangeProof(comp) + specialqueries.CompileFixedPermutations(comp) + innerproduct.Compile(comp) + + prepareLookup(comp) +} + +// It scans the initial compiledIOP and group all the checkedTables related to the same lookupTable, +// It creates a multiplicity column M for all such pairs (checkedTables, lookupTable). +func prepareLookup(comp *wizard.CompiledIOP) { + mainLookupCtx := lookup.CaptureLookupTables(comp) + + for _, lookupTable := range mainLookupCtx.LookupTables { + var ( + tableName = lookup.NameTable(lookupTable) + round = mainLookupCtx.Rounds[tableName] + size = lookupTable[0][0].Size() + fragNum = len(lookupTable[0]) + mTable = make([]ifaces.Column, fragNum) + ) + for frag := range lookupTable[0] { + + mTableID := ifaces.ColIDf("%v_%v_%v", tableName, "M", frag) + mTable[frag] = comp.InsertCommit(round, mTableID, size) + } + + } +} From 17dd068bd13ed455217214c21280aad22d0c852e Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Thu, 12 Dec 2024 15:35:19 +0100 Subject: [PATCH 06/12] minor --- .../distributed/compiler/inclusion/preparation.go | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 prover/protocol/distributed/compiler/inclusion/preparation.go diff --git a/prover/protocol/distributed/compiler/inclusion/preparation.go b/prover/protocol/distributed/compiler/inclusion/preparation.go deleted file mode 100644 index 3b4bdfff6..000000000 --- a/prover/protocol/distributed/compiler/inclusion/preparation.go +++ /dev/null @@ -1,9 +0,0 @@ -package inclusion - -import "github.com/consensys/linea-monorepo/prover/protocol/wizard" - -// It scans the initial compiledIOP and group all the checkedTables related to the same lookupTable, -// It creates a frequency column M for all such pairs (checkedTables, lookupTable). -func prepareLookup(initalComp *wizard.CompiledIOP) { - panic("unimplemented") -} From 08c76d802e22f4ac6b149167ee892db6744dd118 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Fri, 13 Dec 2024 16:17:11 +0100 Subject: [PATCH 07/12] added IntoLogDerivationSum during the preparation --- prover/protocol/compiler/lookup/compiler.go | 4 +- .../distributed/compiler/inclusion/context.go | 41 ---- .../compiler/inclusion/inclusion.go | 229 ++++-------------- .../compiler/inclusion/lookup_discoverer.go | 147 ----------- .../distributed/compiler/inclusion/prover.go | 125 ---------- prover/protocol/distributed/distributed.go | 2 + prover/protocol/distributed/preparation.go | 105 ++++++-- 7 files changed, 144 insertions(+), 509 deletions(-) delete mode 100644 prover/protocol/distributed/compiler/inclusion/context.go delete mode 100644 prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go delete mode 100644 prover/protocol/distributed/compiler/inclusion/prover.go diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index 40dcb3735..02123da11 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -55,7 +55,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { checkTable = mainLookupCtx.CheckedTables[lookupTableName] round = mainLookupCtx.Rounds[lookupTableName] includedFilters = mainLookupCtx.IncludedFilters[lookupTableName] - tableCtx = compileLookupTable(comp, round, lookupTable, checkTable, includedFilters) + tableCtx = CompileLookupTable(comp, round, lookupTable, checkTable, includedFilters) ) // push to zCatalog @@ -213,7 +213,7 @@ func CaptureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { // - (5) The verier makes a `Global` query : $\left((\Sigma_T)[i] - (\Sigma_T)[i-1]\right)(T_i + \gamma) = M_i$ // here we are looking up set of columns S in a single column T -func compileLookupTable( +func CompileLookupTable( comp *wizard.CompiledIOP, round int, lookupTable []table, diff --git a/prover/protocol/distributed/compiler/inclusion/context.go b/prover/protocol/distributed/compiler/inclusion/context.go deleted file mode 100644 index ceb39a79a..000000000 --- a/prover/protocol/distributed/compiler/inclusion/context.go +++ /dev/null @@ -1,41 +0,0 @@ -package inclusion - -import "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - -// MainLookupCtx stores the compilation context of all the lookup queries -// altogether. -type mainLookupCtx struct { - - // LookupTables stores all the lookup table the compiler encounters. They are - // sorted in canonical order. This used to derive a determistic ordering - // of the lookup LookupTables. (We want to ensure the compiler yields always - // exactly the same result for replicability). - // - // To illustrates its structure, the following sub-statement - // - // table[numTable][frag] - // - // refers to to the fragment #frag of the the table #numTable. - LookupTables [][]table - - // CheckedTables stores all the checked column by lookup table. The key is - // obtained as nameTable(lookupTable) where lookup is sorted in - // canonical order. - CheckedTables map[string][]table - - // IncludedFilters stores all the filters for the checked columns and `nil` - // if no filter is applied. As for [checkedTables] they are stored by - // lookup table name and in the same order for each key. - IncludedFilters map[string][]ifaces.Column - - // Rounds stores the interaction round assigned to each lookupTable. The - // round is obtained by taking the max of the declaration Rounds of the - // Inclusion queries using the corresponding lookup table. - Rounds map[string]int - - // it stores the multiplicity of T is S. - // For the multiColum case it collapse T and S and then counts the multiplicity. - // note that here the collapsing does not need the same randomness as the compilation. - // since the multiplicity is the same w.r.t any randomness. - mTables map[string]table -} diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index 513f926f9..f30d503e8 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -1,203 +1,74 @@ package inclusion import ( - "github.com/consensys/linea-monorepo/prover/protocol/coin" - lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" "github.com/consensys/linea-monorepo/prover/symbolic" ) -// CompileLogDerivative scans `comp`, looking for Inclusion queries and compiles -// them using the LogDerivativeLookup technique. The compiler attempts to group -// queries relating to the same table (as such groups needs the same randomness). -// -// The input is a wizard.CompiledIOP object relevant to the module. -// It contains a list of the columns relevant to the module (inside its Columns field). -// -// For each T column inside the module, it also contains the M column. -// -// Note that for a lookup query the module may contain only the S or T columns (and not both). -func CompileLogDerivative(moduleComp *wizard.CompiledIOP) { - - var ( - // capture the S and T columns relevant to the module. - mainLookupCtx = captureModuleLookupTables(moduleComp) - lastRound = moduleComp.NumRounds() - 1 - - // zCatalog stores a mapping from (round, size) into [query.LogDerivativeSumInput]. - // it packs the sigma columns from the same (round,size) together. - zCatalog = map[[2]int]*query.LogDerivativeSumInput{} - ) - - // Skip the compilation phase if no lookup constraint is being used. Otherwise - // it will register a verifier action that is not required and will be bugged. - if len(mainLookupCtx.LookupTables) == 0 && len(mainLookupCtx.CheckedTables) == 0 { - return - } - - // Step 1. construct the "per table" contexts and pack the Sigma's into - // zCatalog. - for tableName, checkedTables := range mainLookupCtx.CheckedTables { - lookupTable := findLookupTableByName(mainLookupCtx.LookupTables, tableName) - var ( - round = mainLookupCtx.Rounds[tableName] - includedFilters = mainLookupCtx.IncludedFilters[tableName] - mTable = mainLookupCtx.mTables[tableName] - // it collapses multiColumn tables to single columns. - tableCtx = collapsMultiColToSingleCol(moduleComp, round, lookupTable, checkedTables, includedFilters, mTable) - ) - - // push to zCatalog - PushToZCatalog(tableCtx, zCatalog) - } - - // Handle cases where only T part is present in the module - for _, lookupTable := range mainLookupCtx.LookupTables { - tableName := lookUp.NameTable(lookupTable) - if _, ok := mainLookupCtx.CheckedTables[tableName]; ok { - continue - } - - var ( - round = mainLookupCtx.Rounds[tableName] - mTable = mainLookupCtx.mTables[tableName] - tableCtx = collapsMultiColToSingleCol(moduleComp, round, lookupTable, nil, nil, mTable) - ) - - // push to zCatalog - PushToZCatalog(tableCtx, zCatalog) - } +const ( + LogDerivativeSum = "LOGDERIVATIVE_SUM" +) - // insert a LogDerivativeSum for all the Sigma Columns in the module. - moduleComp.InsertLogDerivativeSum(lastRound, "LogDerivativeSum", zCatalog) +// distributionInputs stores the input required for the distribution of a LogDerivativeSum query. +type distributionInputs struct { + comp *wizard.CompiledIOP + // module Discoverer used to detect the relevant part of the query to the module + disc distributed.ModuleDiscoverer + // Name of the module + moduleName string + // query is supposed to be the global LogDerivativeSum. + queryID ifaces.QueryID } -// findLookupTableByName searches for a lookup table by its name in the list of lookup tables. -func findLookupTableByName(lookupTables [][]table, name string) []table { - for _, lookupTable := range lookupTables { - if lookUp.NameTable(lookupTable) == name { - return lookupTable - } - } - return nil -} - -// It collapses the tables of MultiColumns to single columns. -// It also sample the Gamma coin for the rest of the compilation. -func collapsMultiColToSingleCol( - comp *wizard.CompiledIOP, - round int, - lookupTable []table, - checkedTables []table, - includedFilters []ifaces.Column, - mTable table, -) (ctx lookUp.SingleTableCtx) { - - ctx = lookUp.SingleTableCtx{ - TableName: lookUp.NameTable(lookupTable), - S: make([]*symbolic.Expression, len(checkedTables)), - SFilters: includedFilters, - T: make([]*symbolic.Expression, len(lookupTable)), - M: make([]ifaces.Column, len(lookupTable)), - } - +// DistributeLogDerivativeSum distribute the LogDerivativeSum among the modules. +// It detect the relevant share of the module from the global LogDerivativeSum. +// It generates a new LogDerivateSum query relevant to the module. +func DistributeLogDerivativeSum(in distributionInputs) { var ( - // isMultiColumn indicates whether the lookup table (and thus the - // checked tables) have the same number of - isMultiColumn = (len(lookupTable) > 0 && len(lookupTable[0]) > 1) || (len(checkedTables) > 0 && len(checkedTables[0]) > 1) + comp = in.comp + numerator []*symbolic.Expression + denominator []*symbolic.Expression + zCatalog map[[2]int]*query.LogDerivativeSumInput + lastRound = in.comp.NumRounds() - 1 ) - - if !isMultiColumn { - for frag := range ctx.T { - ctx.T[frag] = symbolic.NewVariable(lookupTable[frag][0]) - - } - - for i := range ctx.S { - ctx.S[i] = symbolic.NewVariable(checkedTables[i][0]) - } + // check that the given query is a valid LogDerivateSum query in the CompiledIOP. + logDeriv, ok := comp.QueriesParams.Data(in.queryID).(query.LogDerivativeSum) + if !ok { + panic("the given query is not a valid LogDerivativeSum from the compiledIOP") } - if isMultiColumn { - - // alpha is the coin used to compute the linear combination of the - // columns of T and S when they are (both) multi-columns. - alpha := comp.InsertCoin( - round+1, - lookUp.DeriveTableName[coin.Name](lookUp.LogDerivativePrefix, lookupTable, "ALPHA"), - coin.Field, - ) - - for frag := range ctx.T { - ctx.T[frag] = wizardutils.RandLinCombColSymbolic(alpha, lookupTable[frag]) - } - - for i := range ctx.S { - ctx.S[i] = wizardutils.RandLinCombColSymbolic(alpha, checkedTables[i]) - } - } - - ctx.M = mTable - - ctx.Gamma = comp.InsertCoin( - round+1, - lookUp.DeriveTableName[coin.Name](lookUp.LogDerivativePrefix, lookupTable, "GAMMA"), - coin.Field, - ) - - return ctx -} - -// PushToZCatalog constructs the numerators and denominators for the collapsed S and T -// into zCatalog, for their corresponding rounds and size. -func PushToZCatalog(stc lookUp.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) { + // This ensures that the logDerivative query is not used again in the + // compilation process. + comp.QueriesNoParams.MarkAsIgnored(in.queryID) - var ( - round = stc.Gamma.Round - ) - - // tableCtx push to -> zCtx - // Process the T columns - for frag := range stc.T { - size := stc.M[frag].Size() - - key := [2]int{round, size} - if zCatalog[key] == nil { - zCatalog[key] = &query.LogDerivativeSumInput{ - Size: size, - Round: round, + // extract the share of the module from the global sum. + for key := range logDeriv.Inputs { + for i := range logDeriv.Inputs[key].Numerator { + if in.disc.ExpressionIsInModule(logDeriv.Inputs[key].Numerator[i], in.moduleName) { + numerator = append(numerator, logDeriv.Inputs[key].Numerator[i]) + } + if in.disc.ExpressionIsInModule(logDeriv.Inputs[key].Denominator[i], in.moduleName) { + denominator = append(denominator, logDeriv.Inputs[key].Denominator[i]) } } - zCtxEntry := zCatalog[key] - zCtxEntry.Numerator = append(zCtxEntry.Numerator, symbolic.Neg(stc.M[frag])) // no functions for num, denom here - zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.T[frag])) - } - - // Process the S columns - for table := range stc.S { - var ( - _, _, size = wizardutils.AsExpr(stc.S[table]) - sFilter = symbolic.NewConstant(1) - ) - - if stc.SFilters[table] != nil { - sFilter = symbolic.NewVariable(stc.SFilters[table]) - } - - key := [2]int{round, size} - if zCatalog[key] == nil { - zCatalog[key] = &query.LogDerivativeSumInput{ - Size: size, - Round: round, - } + // zCatalog specific to the module + zCatalog[key] = &query.LogDerivativeSumInput{ + Round: key[0], + Size: key[1], + Numerator: numerator, + Denominator: denominator, } - zCtxEntry := zCatalog[key] - zCtxEntry.Numerator = append(zCtxEntry.Numerator, sFilter) - zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.S[table])) } + + // insert a LogDerivativeSum specific to the module. + comp.InsertLogDerivativeSum( + lastRound, + ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.moduleName), + zCatalog, + ) } diff --git a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go b/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go deleted file mode 100644 index 52114278e..000000000 --- a/prover/protocol/distributed/compiler/inclusion/lookup_discoverer.go +++ /dev/null @@ -1,147 +0,0 @@ -package inclusion - -import ( - "github.com/consensys/linea-monorepo/prover/maths/field" - "github.com/consensys/linea-monorepo/prover/protocol/column" - "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" - lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/query" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" -) - -// table is an alias for a list of column. We use it in the scope of the lookup -// compiler as a shorthand to make the code more eye-parseable. -type table = []ifaces.Column - -// captureModuleLookupTables inspects moduleComp and looks for Inclusion queries. -// It groups the matched queries by lookup table and marks them as ignored. -// It creates a list of lookupTables, mTables and checkedTables present in the module. -// -// The input is a compiledIOP object that stores the columns relevant to the module (in its Column field) -// Note that for a lookup query the module may contain only S or T table (and not necessarily both). -func captureModuleLookupTables(moduleComp *wizard.CompiledIOP) mainLookupCtx { - var ( - ctx = mainLookupCtx{ - LookupTables: [][]table{}, - mTables: map[string]table{}, - CheckedTables: map[string][]table{}, - IncludedFilters: map[string][]ifaces.Column{}, - Rounds: map[string]int{}, - } - ) - - // Collect all the lookup queries into "lookups" - for _, qName := range moduleComp.QueriesNoParams.AllUnignoredKeys() { - - // Filter out non-lookup queries - lookup, ok := moduleComp.QueriesNoParams.Data(qName).(query.Inclusion) - if !ok { - continue - } - - // Determine the part of the query that is relevant to the module - relevantPart := determineRelevantPart(lookup, moduleComp.Columns) - - // This ensures that the lookup query is not used again in the - // compilation process. We know that the query was already ignored at - // the beginning because we are iterating over the unignored keys. - moduleComp.QueriesNoParams.MarkAsIgnored(qName) - - var ( - // checkedTable corresponds to the "included" table and lookupTable - // corresponds to the including table. - checkedTable, lookupTable = lookUp.GetTableCanonicalOrder(lookup) - tableName = lookUp.NameTable(lookupTable) - // includedFilters stores the query.IncludedFilter parameter. If the - // query has no includedFilters on the Included side. Then this is - // left as nil. - includedFilter ifaces.Column - ) - - if lookup.IsFilteredOnIncluding() { - var ( - checkedLen = checkedTable[0].Size() - ones = verifiercol.NewConstantCol(field.One(), checkedLen) - ) - - checkedTable = append([]ifaces.Column{ones}, checkedTable...) - for frag := range lookupTable { - lookupTable[frag] = append([]ifaces.Column{lookup.IncludingFilter[frag]}, lookupTable[frag]...) - } - - tableName = lookUp.NameTable(lookupTable) - } - - if lookup.IsFilteredOnIncluded() { - includedFilter = lookup.IncludedFilter - } - - // In case this is the first iteration where we encounter the lookupTable - // we need to add entries in the registering maps. - if _, ok := ctx.CheckedTables[tableName]; !ok { - ctx.IncludedFilters[tableName] = []ifaces.Column{} - ctx.CheckedTables[tableName] = []table{} - ctx.LookupTables = [][]table{} - ctx.mTables[tableName] = table{} - ctx.Rounds[tableName] = 0 - } - - // Add only the relevant part to the context - if relevantPart == "Both" { - ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) - ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) - ctx.LookupTables = append(ctx.LookupTables, lookupTable) - // get the M table from the moduleComp and add it to the mTables. - ctx.mTables[tableName] = mTable(moduleComp, lookupTable) - - } - if relevantPart == "S" { - ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) - ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) - } else if relevantPart == "T" { - ctx.LookupTables = append(ctx.LookupTables, lookupTable) - // get the M table from the moduleComp and add it to the mTables. - ctx.mTables[tableName] = mTable(moduleComp, lookupTable) - } - - ctx.Rounds[tableName] = max(ctx.Rounds[tableName], moduleComp.QueriesNoParams.Round(lookup.ID)) - - } - - return ctx -} - -// determineRelevantPart checks if the lookup query involves columns from the module and returns the relevant part (S or T). -func determineRelevantPart(lookup query.Inclusion, moduleColumns column.Store) string { - var hasT, hasS bool - // Check if any column in S part is in the module - if moduleColumns.Exists(lookup.Included[0].GetColID()) { - hasS = true - } - - // Check if any column in T part is in the module - if moduleColumns.Exists(lookup.Including[0][0].GetColID()) { - hasT = true - } - if hasT && hasS { - return "Both" - } - if hasT && !hasS { - return "T" - } - if hasS && !hasT { - return "S" - } else { - panic("the module contains a query that is not relevant to any module-column") - } - -} - -func mTable(comp *wizard.CompiledIOP, t [][]ifaces.Column) (m []ifaces.Column) { - for frag := range t[0] { - id := ifaces.ColIDf("%v_%v_%v", lookUp.NameTable(t), "M", frag) - m = append(m, comp.Columns.GetHandle(id)) - } - return m -} diff --git a/prover/protocol/distributed/compiler/inclusion/prover.go b/prover/protocol/distributed/compiler/inclusion/prover.go deleted file mode 100644 index 2cc2a8886..000000000 --- a/prover/protocol/distributed/compiler/inclusion/prover.go +++ /dev/null @@ -1,125 +0,0 @@ -package inclusion - -import ( - "runtime/debug" - "sync" - - sv "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" - "github.com/consensys/linea-monorepo/prover/maths/common/vector" - "github.com/consensys/linea-monorepo/prover/maths/field" - lookUp "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" - "github.com/consensys/linea-monorepo/prover/utils" - "github.com/consensys/linea-monorepo/prover/utils/parallel" -) - -// proverTaskAtRound implements the [wizard.ProverAction] interface. It gathers -// all the operations related to all compiled tables altogether that have to be -// done at a particular round. -// -// Namely, if applied to the round N. The action will be responsible for -// assigning the M column for tables compiled on round N and the SigmaS/SigmaT -// and their respective LocalOpening for the tables compiled at round N-1. -// -// All these actions are performed in parallel. -type proverTaskAtRound struct { - - // ZAssignmentTasks lists all the tasks consisting of assigning the - // columns SigmaS and SigmaT for the given round. - ZAssignmentTasks []zAssignmentTask -} - -// Run implements the [wizard.ProverAction interface]. The tasks will spawn -// a goroutine for each tasks and wait for all of them to finish. The approach -// for parallelization can be justified if the number of go-routines stays low -// (e.g. less than 1000s). -func (p proverTaskAtRound) Run(run *wizard.ProverRuntime) { - - wg := &sync.WaitGroup{} - wg.Add(p.numTasks()) - - var ( - panicTrace []byte - panicMsg any - panicOnce = &sync.Once{} - ) - - for i := range p.ZAssignmentTasks { - // the passing of the index `i` is there to ensure that the go-routine - // is running over a local copy of `i` which is not incremented every - // time the loop goes to the next iteration. - go func(i int) { - - // In case the subtask panics, we recover so that we can repanic in - // the main goroutine. Simplifying the process of tracing back the - // error and allowing to test the panics. - defer func() { - if r := recover(); r != nil { - panicOnce.Do(func() { - panicMsg = r - panicTrace = debug.Stack() - }) - } - - wg.Done() - }() - - p.ZAssignmentTasks[i].run(run) - }(i) - } - - wg.Wait() - - if len(panicTrace) > 0 { - utils.Panic("Had a panic: %v\nStack: %v\n", panicMsg, string(panicTrace)) - } -} - -// pushZAssignment appends an [sigmaAssignmentTask] to the list of tasks -func (p *proverTaskAtRound) pushZAssignment(s zAssignmentTask) { - p.ZAssignmentTasks = append(p.ZAssignmentTasks, s) -} - -// numTasks returns the total number of tasks that are scheduled in the -// [proverTaskAtRound]. -func (p *proverTaskAtRound) numTasks() int { - return len(p.ZAssignmentTasks) -} - -// zAssignmentTask represents a prover task of assignming the columns -// SigmaS and SigmaT for a specific lookup table. -// sigmaAssignment -type zAssignmentTask lookUp.ZCtx - -func (z zAssignmentTask) run(run *wizard.ProverRuntime) { - parallel.Execute(len(z.ZDenominatorBoarded), func(start, stop int) { - for frag := start; frag < stop; frag++ { - - var ( - numeratorMetadata = z.ZNumeratorBoarded[frag].ListVariableMetadata() - denominator = wizardutils.EvalExprColumn(run, z.ZDenominatorBoarded[frag]).IntoRegVecSaveAlloc() - numerator []field.Element - packedZ = field.BatchInvert(denominator) - ) - - if len(numeratorMetadata) == 0 { - numerator = vector.Repeat(field.One(), z.Size) - } - - if len(numeratorMetadata) > 0 { - numerator = wizardutils.EvalExprColumn(run, z.ZNumeratorBoarded[frag]).IntoRegVecSaveAlloc() - } - - for k := range packedZ { - packedZ[k].Mul(&numerator[k], &packedZ[k]) - if k > 0 { - packedZ[k].Add(&packedZ[k], &packedZ[k-1]) - } - } - - run.AssignColumn(z.Zs[frag].GetColID(), sv.NewRegular(packedZ)) - run.AssignLocalPoint(z.ZOpenings[frag].ID, packedZ[len(packedZ)-1]) - } - }) -} diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index 58b3ce257..63097eabc 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -4,6 +4,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" ) type moduleName = string @@ -33,6 +34,7 @@ type ModuleDiscoverer interface { FindModule(col ifaces.Column) moduleName // given a query and a module name it checks if the query is inside the module QueryIsInModule(ifaces.Query, moduleName) bool + ExpressionIsInModule(*symbolic.Expression, moduleName) bool } // This transforms the initial wizard. So it is not really the initial diff --git a/prover/protocol/distributed/preparation.go b/prover/protocol/distributed/preparation.go index 15510c6f4..9784129b4 100644 --- a/prover/protocol/distributed/preparation.go +++ b/prover/protocol/distributed/preparation.go @@ -5,39 +5,114 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" + "github.com/consensys/linea-monorepo/prover/symbolic" ) // prepare reduces any query to LPP or GL. -// it prepares the columns that depends on whole the witness,e.g., M column for lookups. +// For Inclusion query, it push the compilation one step further: +// all the inclusion queries are compiled into a LogDarivativeSum query, +// This is required due to the challenge that table M depends on whole the witness. func prepare(comp *wizard.CompiledIOP) { mimc.CompileMiMC(comp) specialqueries.RangeProof(comp) specialqueries.CompileFixedPermutations(comp) innerproduct.Compile(comp) - prepareLookup(comp) + IntoLogDerivativeSum(comp) } -// It scans the initial compiledIOP and group all the checkedTables related to the same lookupTable, -// It creates a multiplicity column M for all such pairs (checkedTables, lookupTable). -func prepareLookup(comp *wizard.CompiledIOP) { - mainLookupCtx := lookup.CaptureLookupTables(comp) +// IIntoLogDerivativeSum compiles all the inclusion queries to a single LogDerivativeSum query that is ready for the split. +// This step is necessary for inclusion, +// as the M table depends on the whole witness and so can not be handled modules-wise without changing the API of WizardIOP. +func IntoLogDerivativeSum(comp *wizard.CompiledIOP) { + var ( + mainLookupCtx = lookup.CaptureLookupTables(comp) + lastRound = comp.NumRounds() - 1 + // zCatalog stores a mapping (round, size) into query.LogDerivativeSumInput and helps finding + // which Z context should be used to handle a part of a given inclusion + // query. + zCatalog = map[[2]int]*query.LogDerivativeSumInput{} + ) + + // Skip the compilation phase if no lookup constraint is being used. Otherwise + // it will register a verifier action that is not required and will be bugged. + if len(mainLookupCtx.LookupTables) == 0 { + return + } + + // Step 1. construct the "per table" contexts and pack the Sigma's into + // zCatalog. for _, lookupTable := range mainLookupCtx.LookupTables { + + var ( + // get checkedTables, rounds, Filters by lookupTableName + lookupTableName = lookup.NameTable(lookupTable) + checkTable = mainLookupCtx.CheckedTables[lookupTableName] + round = mainLookupCtx.Rounds[lookupTableName] + includedFilters = mainLookupCtx.IncludedFilters[lookupTableName] + // collapse multiColumns to single Columns + tableCtx = lookup.CompileLookupTable(comp, round, lookupTable, checkTable, includedFilters) + ) + + // push single-columns into zCatalog + PushToZCatalog(tableCtx, zCatalog) + + // insert a single LogDerivativeSum query for the global zCatalog. + comp.InsertLogDerivativeSum(lastRound, "GlobalLogDerivativeSum", zCatalog) + } +} + +// PushToZCatalog constructs the numerators and denominators for the collapsed S and T +// into zCatalog, for their corresponding rounds and size. +func PushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) { + + var ( + round = stc.Gamma.Round + ) + + // tableCtx push to -> zCtx + // Process the T columns + for frag := range stc.T { + size := stc.M[frag].Size() + + key := [2]int{round, size} + if zCatalog[key] == nil { + zCatalog[key] = &query.LogDerivativeSumInput{ + Size: size, + Round: round, + } + } + + zCtxEntry := zCatalog[key] + zCtxEntry.Numerator = append(zCtxEntry.Numerator, symbolic.Neg(stc.M[frag])) // no functions for num, denom here + zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.T[frag])) + } + + // Process the S columns + for table := range stc.S { var ( - tableName = lookup.NameTable(lookupTable) - round = mainLookupCtx.Rounds[tableName] - size = lookupTable[0][0].Size() - fragNum = len(lookupTable[0]) - mTable = make([]ifaces.Column, fragNum) + _, _, size = wizardutils.AsExpr(stc.S[table]) + sFilter = symbolic.NewConstant(1) ) - for frag := range lookupTable[0] { - mTableID := ifaces.ColIDf("%v_%v_%v", tableName, "M", frag) - mTable[frag] = comp.InsertCommit(round, mTableID, size) + if stc.SFilters[table] != nil { + sFilter = symbolic.NewVariable(stc.SFilters[table]) + } + + key := [2]int{round, size} + if zCatalog[key] == nil { + zCatalog[key] = &query.LogDerivativeSumInput{ + Size: size, + Round: round, + } } + zCtxEntry := zCatalog[key] + zCtxEntry.Numerator = append(zCtxEntry.Numerator, sFilter) + zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.S[table])) } } From 02d74d8c519bda762b044fa1ccb710cd1bc371d1 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Fri, 13 Dec 2024 17:50:45 +0100 Subject: [PATCH 08/12] added the check for LogDerivativeSum query --- prover/protocol/query/logderiv_sum.go | 141 +++++++++++++++++++++++++- prover/protocol/wizard/compiled.go | 2 +- 2 files changed, 138 insertions(+), 5 deletions(-) diff --git a/prover/protocol/query/logderiv_sum.go b/prover/protocol/query/logderiv_sum.go index bbdacbb2d..2ed367b1d 100644 --- a/prover/protocol/query/logderiv_sum.go +++ b/prover/protocol/query/logderiv_sum.go @@ -1,10 +1,19 @@ package query import ( + "fmt" + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/symbolic" sym "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" ) // LogDerivativeSumInput stores the input to the query @@ -14,7 +23,7 @@ type LogDerivativeSumInput struct { Denominator []*sym.Expression // S or T -> ({S,T} + X) } -// LogDerivativeSum is the context of LogDerivativeSum. +// LogDerivativeSum is the context of LogDerivativeSum query. // The fields are maps from [round, size]. type LogDerivativeSum struct { Inputs map[[2]int]*LogDerivativeSumInput @@ -33,13 +42,32 @@ type LogDerivSumParams struct { Sum field.Element // the sum of all the ZOpenings from different [round,size]. } -// NewLogDerivativeSum creates the new context LogDerivativeSum. -func NewLogDerivativeSum(inp map[[2]int]*LogDerivativeSumInput) LogDerivativeSum { +// Updates a Fiat-Shamir state +func (l LogDerivSumParams) UpdateFS(fs *fiatshamir.State) { + fs.Update(l.Sum) +} - // add some sanity checks here +// NewLogDerivativeSum creates the new context LogDerivativeSum. +func NewLogDerivativeSum(inp map[[2]int]*LogDerivativeSumInput, id ifaces.QueryID) LogDerivativeSum { + + // check the length consistency + for key := range inp { + if len(inp[key].Numerator) != len(inp[key].Denominator) || len(inp[key].Numerator) == 0 { + panic("Numerator and Denominator should have the same (no-zero) length") + } + for i := range inp[key].Numerator { + if err := inp[key].Numerator[i].Validate(); err != nil { + utils.Panic(" Numerator[%v] is not a valid expression", i) + } + if err := inp[key].Denominator[i].Validate(); err != nil { + utils.Panic(" Denominator[%v] is not a valid expression", i) + } + } + } return LogDerivativeSum{ Inputs: inp, + ID: id, } } @@ -56,6 +84,42 @@ func NewLogDeriveSumParams(sum field.Element) LogDerivSumParams { // Test that global sum is correct func (r LogDerivativeSum) Check(run ifaces.Runtime) error { + params := run.GetParams(r.ID).(LogDerivSumParams) + // compute the actual sum from the Numerator and Denominator + actualSum := field.Zero() + for key := range r.Inputs { + for i, num := range r.Inputs[key].Numerator { + + var ( + numBoard = num.Board() + denBoard = r.Inputs[key].Denominator[i].Board() + numeratorMetadata = numBoard.ListVariableMetadata() + denominator = EvalExprColumn(run, denBoard).IntoRegVecSaveAlloc() + numerator []field.Element + packedZ = field.BatchInvert(denominator) + ) + + if len(numeratorMetadata) == 0 { + numerator = vector.Repeat(field.One(), r.Inputs[key].Size) + } + + if len(numeratorMetadata) > 0 { + numerator = EvalExprColumn(run, numBoard).IntoRegVecSaveAlloc() + } + + for k := range packedZ { + packedZ[k].Mul(&numerator[k], &packedZ[k]) + if k > 0 { + packedZ[k].Add(&packedZ[k], &packedZ[k-1]) + } + } + actualSum.Add(&actualSum, &packedZ[len(packedZ)-1]) + } + } + + if actualSum != params.Sum { + return fmt.Errorf("expected LogDerivativeSum = %s but got %s for the query %v", params.Sum.String(), actualSum.String(), r.ID) + } return nil } @@ -64,3 +128,72 @@ func (r LogDerivativeSum) Check(run ifaces.Runtime) error { func (r LogDerivativeSum) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) { } + +func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvectors.SmartVector { + + var ( + metadata = board.ListVariableMetadata() + inputs = make([]smartvectors.SmartVector, len(metadata)) + length = ExprIsOnSameLengthHandles(&board) + ) + + // Attempt to recover the size of the + for i := range inputs { + switch m := metadata[i].(type) { + case ifaces.Column: + inputs[i] = m.GetColAssignment(run) + case coin.Info: + v := run.GetRandomCoinField(m.Name) + inputs[i] = smartvectors.NewConstant(v, length) + case ifaces.Accessor: + v := m.GetVal(run) + inputs[i] = smartvectors.NewConstant(v, length) + case variables.PeriodicSample: + v := m.EvalCoset(length, 0, 1, false) + inputs[i] = v + case variables.X: + v := m.EvalCoset(length, 0, 1, false) + inputs[i] = v + } + } + + return board.Evaluate(inputs) +} + +// ExprIsOnSameLengthHandles checks that all the variables of the expression +// that are [ifaces.Column] have the same size (and panics if it does not), then +// returns the match. +func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { + + var ( + metadatas = board.ListVariableMetadata() + length = 0 + ) + + for _, m := range metadatas { + switch metadata := m.(type) { + case ifaces.Column: + // Initialize the length with the first commitment + if length == 0 { + length = metadata.Size() + } + + // Sanity-check the vector should all have the same length + if length != metadata.Size() { + utils.Panic("Inconsistent length for %v (has size %v, but expected %v)", metadata.GetColID(), metadata.Size(), length) + } + // The expression can involve random coins + case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: + // Do nothing + default: + utils.Panic("unknown type %T", metadata) + } + } + + // No commitment were found in the metadata, thus this call is broken + if length == 0 { + utils.Panic("declared a handle from an expression which does not contains any handle") + } + + return length +} diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index f0b0efd87..2ccc7f554 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -584,7 +584,7 @@ func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifa // The sigma columns are categorized by [round,size]. func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[[2]int]*query.LogDerivativeSumInput) query.LogDerivativeSum { c.assertConsistentRound(lastRound) - q := query.NewLogDerivativeSum(in) + q := query.NewLogDerivativeSum(in, id) // Finally registers the query c.QueriesParams.AddToRound(lastRound, id, q) return q From 157f641a0bd05c8dafc451c735f7f6fccbb6646c Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Mon, 16 Dec 2024 17:43:03 +0100 Subject: [PATCH 09/12] added the test for LogDerivativeSum --- prover/protocol/query/gnark_params.go | 13 ++++ prover/protocol/query/logderiv_sum.go | 11 +++- prover/protocol/query/logderiv_sum_test.go | 71 ++++++++++++++++++++++ prover/protocol/wizard/gnark_verifier.go | 20 ++++++ prover/protocol/wizard/prover.go | 37 +++++++++++ prover/protocol/wizard/verifier.go | 5 ++ 6 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 prover/protocol/query/logderiv_sum_test.go diff --git a/prover/protocol/query/gnark_params.go b/prover/protocol/query/gnark_params.go index 76c60b3b0..fc2ddabbf 100644 --- a/prover/protocol/query/gnark_params.go +++ b/prover/protocol/query/gnark_params.go @@ -15,6 +15,14 @@ func (p LocalOpeningParams) GnarkAssign() GnarkLocalOpeningParams { return GnarkLocalOpeningParams{Y: p.Y} } +type GnarkLogDerivSumParams struct { + Y frontend.Variable +} + +func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams { + return GnarkLogDerivSumParams{Y: p.Sum} +} + // A gnark circuit version of InnerProductParams type GnarkInnerProductParams struct { Ys []frontend.Variable @@ -54,6 +62,11 @@ func (p GnarkLocalOpeningParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Y) } +// Update the fiat-shamir state with the the present parameters +func (p GnarkLogDerivSumParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { + fs.Update(p.Y) +} + // Update the fiat-shamir state with the the present parameters func (p GnarkUnivariateEvalParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Ys...) diff --git a/prover/protocol/query/logderiv_sum.go b/prover/protocol/query/logderiv_sum.go index 2ed367b1d..774ddb088 100644 --- a/prover/protocol/query/logderiv_sum.go +++ b/prover/protocol/query/logderiv_sum.go @@ -25,6 +25,10 @@ type LogDerivativeSumInput struct { // LogDerivativeSum is the context of LogDerivativeSum query. // The fields are maps from [round, size]. +// the aim of the query is to compute: +// \sum_{i,j} N_{i,j}/D_{i,j} where +// N_{i,j} is the i-th element of the underlying column of j-th Numerator +// D_{i,j} is the i-th element of the underlying column of j-th Denominator type LogDerivativeSum struct { Inputs map[[2]int]*LogDerivativeSumInput @@ -78,7 +82,7 @@ func (r LogDerivativeSum) Name() ifaces.QueryID { } // Constructor for the query parameters/result -func NewLogDeriveSumParams(sum field.Element) LogDerivSumParams { +func NewLogDerivSumParams(sum field.Element) LogDerivSumParams { return LogDerivSumParams{Sum: sum} } @@ -126,7 +130,10 @@ func (r LogDerivativeSum) Check(run ifaces.Runtime) error { // Test that global sum is correct func (r LogDerivativeSum) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) { - + /*params := run.GetParams(r.ID).(GnarkLogDerivSumParams) + actualY := TBD + api.AssertIsEqual(params.Y, actualY) + */ } func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvectors.SmartVector { diff --git a/prover/protocol/query/logderiv_sum_test.go b/prover/protocol/query/logderiv_sum_test.go new file mode 100644 index 000000000..1a6283f4f --- /dev/null +++ b/prover/protocol/query/logderiv_sum_test.go @@ -0,0 +1,71 @@ +package query_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/stretchr/testify/require" +) + +func TestLogDerivSum(t *testing.T) { + + define := func(b *wizard.Builder) { + var ( + comp = b.CompiledIOP + ) + + p0 := b.RegisterCommit("Num_0", 4) + p1 := b.RegisterCommit("Num_1", 4) + p2 := b.RegisterCommit("Num_2", 4) + + q0 := b.RegisterCommit("Den_0", 4) + q1 := b.RegisterCommit("Den_1", 4) + q2 := b.RegisterCommit("Den_2", 4) + + numerators := []*symbolic.Expression{ + symbolic.Mul(p0, -1), + ifaces.ColumnAsVariable(p1), + symbolic.Mul(p2, p0, 2), + } + + denominators := []*symbolic.Expression{ + ifaces.ColumnAsVariable(q0), + ifaces.ColumnAsVariable(q1), + ifaces.ColumnAsVariable(q2), + } + + key := [2]int{0, 0} + zCat1 := map[[2]int]*query.LogDerivativeSumInput{} + zCat1[key] = &query.LogDerivativeSumInput{ + Numerator: numerators, + Denominator: denominators, + } + comp.InsertLogDerivativeSum(0, "LogDerivSum_Test", zCat1) + + } + + prover := func(run *wizard.ProverRuntime) { + + run.AssignColumn("Num_0", smartvectors.ForTest(1, 1, 1, 1)) + run.AssignColumn("Num_1", smartvectors.ForTest(2, 3, 7, 9)) + run.AssignColumn("Num_2", smartvectors.ForTest(5, 6, 1, 1)) + + run.AssignColumn("Den_0", smartvectors.ForTest(1, 1, 1, 1)) + run.AssignColumn("Den_1", smartvectors.ForTest(2, 3, 7, 9)) + run.AssignColumn("Den_2", smartvectors.ForTest(5, 6, 1, 1)) + + run.AssignLogDerivSum("LogDerivSum_Test", field.NewElement(8)) + + } + + compiled := wizard.Compile(define, dummy.Compile) + proof := wizard.Prove(compiled, prover) + valid := wizard.Verify(compiled, proof) + require.NoError(t, valid) +} diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go index 841c3755c..efe06be28 100644 --- a/prover/protocol/wizard/gnark_verifier.go +++ b/prover/protocol/wizard/gnark_verifier.go @@ -47,6 +47,8 @@ type WizardVerifierCircuit struct { innerProductIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` // Same for local-opening query localOpeningIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` + // Same for logDerivativeSum query + logDerivSumIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` // Columns stores the gnark witness part corresponding to the columns // provided in the proof and in the VerifyingKey. @@ -63,6 +65,9 @@ type WizardVerifierCircuit struct { // LocalOpeningParams stores an assignment for each [query.LocalOpeningParams] // from the proof. It is part of the witness of the gnark circuit. LocalOpeningParams []query.GnarkLocalOpeningParams `gnark:",secret"` + // LogDerivSumParams stores an assignment for each [query.LogDerivSumParams] + // from the proof. It is part of the witness of the gnark circuit. + LogDerivSumParams []query.GnarkLogDerivSumParams `gnark:",secret"` // FS is the Fiat-Shamir state, mirroring [VerifierRuntime.FS]. The same // cautionnary rules apply to it; e.g. don't use it externally when @@ -311,6 +316,14 @@ func (c *WizardVerifierCircuit) GetLocalPointEvalParams(name ifaces.QueryID) que return c.LocalOpeningParams[qID] } +// GetLogDerivSumParams returns the parameters for the requested +// [query.LogDerivativeSum] query. Its work mirrors the function +// [VerifierRuntime.GetLogDerivSumParams] +func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLocalOpeningParams { + qID := c.localOpeningIDs.MustGet(name) + return c.LocalOpeningParams[qID] +} + // GetColumns returns the gnark assignment of a column in a gnark circuit. It // mirrors the function [VerifierRuntime.GetColumn] func (c *WizardVerifierCircuit) GetColumn(name ifaces.ColID) []frontend.Variable { @@ -350,11 +363,13 @@ func newWizardVerifierCircuit() *WizardVerifierCircuit { res.columnsIDs = collection.NewMapping[ifaces.ColID, int]() res.univariateParamsIDs = collection.NewMapping[ifaces.QueryID, int]() res.localOpeningIDs = collection.NewMapping[ifaces.QueryID, int]() + res.logDerivSumIDs = collection.NewMapping[ifaces.QueryID, int]() res.innerProductIDs = collection.NewMapping[ifaces.QueryID, int]() res.Columns = [][]frontend.Variable{} res.UnivariateParams = make([]query.GnarkUnivariateEvalParams, 0) res.InnerProductParams = make([]query.GnarkInnerProductParams, 0) res.LocalOpeningParams = make([]query.GnarkLocalOpeningParams, 0) + res.LogDerivSumParams = make([]query.GnarkLogDerivSumParams, 0) res.Coins = collection.NewMapping[coin.Name, interface{}]() return res } @@ -418,6 +433,9 @@ func GetWizardVerifierCircuitAssignment(comp *CompiledIOP, proof Proof) *WizardV case query.LocalOpeningParams: res.localOpeningIDs.InsertNew(qName, len(res.LocalOpeningParams)) res.LocalOpeningParams = append(res.LocalOpeningParams, params.GnarkAssign()) + case query.LogDerivSumParams: + res.logDerivSumIDs.InsertNew(qName, len(res.LogDerivSumParams)) + res.LogDerivSumParams = append(res.LogDerivSumParams, params.GnarkAssign()) default: utils.Panic("unknow type %T", params) @@ -435,6 +453,8 @@ func (c *WizardVerifierCircuit) GetParams(id ifaces.QueryID) ifaces.GnarkQueryPa return c.GetUnivariateParams(id) case query.LocalOpening: return c.GetLocalPointEvalParams(id) + case query.LogDerivativeSum: + return c.GetLogDerivSumParams(id) case query.InnerProduct: return c.GetInnerProductParams(id) default: diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index caa4df4aa..7d42cdb61 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -691,6 +691,43 @@ func (run *ProverRuntime) GetLocalPointEvalParams(name ifaces.QueryID) query.Loc return run.QueriesParams.MustGet(name).(query.LocalOpeningParams) } +// AssignLogDerivSum assign the claimed values for a logDeriveSum +// The function will panic if: +// - the parameters were already assigned +// - the specified query is not registered +// - the assignment round is incorrect +func (run *ProverRuntime) AssignLogDerivSum(name ifaces.QueryID, y field.Element) { + + // Global prover locks for accessing the maps + run.lock.Lock() + defer run.lock.Unlock() + + // Make sure, it is done at the right round + run.Spec.QueriesParams.MustBeInRound(run.currRound, name) + + // Adds it to the assignments + params := query.NewLogDerivSumParams(y) + run.QueriesParams.InsertNew(name, params) +} + +// GetLogDeriveSum gets the metadata of a [query.LogDerivativeSum] query. Panic if not found. +func (run *ProverRuntime) GetLogDeriveSum(name ifaces.QueryID) query.LogDerivativeSum { + // Global prover locks for accessing the maps + run.lock.Lock() + defer run.lock.Unlock() + return run.Spec.QueriesParams.Data(name).(query.LogDerivativeSum) +} + +// GetLogDerivSumParams returns the parameters of [query.LogDerivativeSum] +func (run *ProverRuntime) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { + + // Global prover's lock for accessing params + run.lock.Lock() + defer run.lock.Unlock() + + return run.QueriesParams.MustGet(name).(query.LogDerivSumParams) +} + // GetParams generically extracts the parameters of a query. Will panic if no // parameters are found func (run *ProverRuntime) GetParams(name ifaces.QueryID) ifaces.QueryParams { diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index 1a355796e..2c5b8da5f 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -308,6 +308,11 @@ func (run *VerifierRuntime) GetLocalPointEvalParams(name ifaces.QueryID) query.L return run.QueriesParams.MustGet(name).(query.LocalOpeningParams) } +// GetLogDerivSumParams returns the parameters of a [query.LogDerivativeSum] +func (run *VerifierRuntime) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { + return run.QueriesParams.MustGet(name).(query.LogDerivSumParams) +} + /* CopyColumnInto implements `column.GetWitness` Copies the witness into a slice From eb9fa7b3a1057fa73ad04ea14f41e71798c8eab3 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Tue, 17 Dec 2024 08:58:24 +0100 Subject: [PATCH 10/12] moving two functions from wizardutils to the column package --- prover/protocol/column/column.go | 73 +++++++++++++++++ prover/protocol/compiler/globalcs/merging.go | 4 +- .../protocol/compiler/innerproduct/prover.go | 4 +- prover/protocol/compiler/lookup/prover.go | 5 +- prover/protocol/compiler/mimc/manual.go | 3 +- .../protocol/compiler/permutation/prover.go | 6 +- .../protocol/dedicated/bigrange/bigrange.go | 3 +- .../protocol/dedicated/byte32cmp/decompose.go | 2 +- .../dedicated/byte32cmp/multi_limb_cmp.go | 3 +- .../dedicated/expr_handle/expr_handle.go | 3 +- prover/protocol/dedicated/is_zero.go | 7 +- .../dedicated/projection/projection.go | 4 +- prover/protocol/query/logderiv_sum.go | 78 +------------------ prover/protocol/wizardutils/evaluation.go | 40 ---------- prover/protocol/wizardutils/utils.go | 41 +--------- .../protocol/wizardutils/wizardutils_test.go | 9 ++- 16 files changed, 108 insertions(+), 177 deletions(-) diff --git a/prover/protocol/column/column.go b/prover/protocol/column/column.go index 3df39cfee..d44d4cc16 100644 --- a/prover/protocol/column/column.go +++ b/prover/protocol/column/column.go @@ -3,7 +3,11 @@ package column import ( "reflect" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -91,3 +95,72 @@ func NbLeaves(h ifaces.Column) int { } panic("unreachable") } + +func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvectors.SmartVector { + + var ( + metadata = board.ListVariableMetadata() + inputs = make([]smartvectors.SmartVector, len(metadata)) + length = ExprIsOnSameLengthHandles(&board) + ) + + // Attempt to recover the size of the + for i := range inputs { + switch m := metadata[i].(type) { + case ifaces.Column: + inputs[i] = m.GetColAssignment(run) + case coin.Info: + v := run.GetRandomCoinField(m.Name) + inputs[i] = smartvectors.NewConstant(v, length) + case ifaces.Accessor: + v := m.GetVal(run) + inputs[i] = smartvectors.NewConstant(v, length) + case variables.PeriodicSample: + v := m.EvalCoset(length, 0, 1, false) + inputs[i] = v + case variables.X: + v := m.EvalCoset(length, 0, 1, false) + inputs[i] = v + } + } + + return board.Evaluate(inputs) +} + +// ExprIsOnSameLengthHandles checks that all the variables of the expression +// that are [ifaces.Column] have the same size (and panics if it does not), then +// returns the match. +func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { + + var ( + metadatas = board.ListVariableMetadata() + length = 0 + ) + + for _, m := range metadatas { + switch metadata := m.(type) { + case ifaces.Column: + // Initialize the length with the first commitment + if length == 0 { + length = metadata.Size() + } + + // Sanity-check the vector should all have the same length + if length != metadata.Size() { + utils.Panic("Inconsistent length for %v (has size %v, but expected %v)", metadata.GetColID(), metadata.Size(), length) + } + // The expression can involve random coins + case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: + // Do nothing + default: + utils.Panic("unknown type %T", metadata) + } + } + + // No commitment were found in the metadata, thus this call is broken + if length == 0 { + utils.Panic("declared a handle from an expression which does not contains any handle") + } + + return length +} diff --git a/prover/protocol/compiler/globalcs/merging.go b/prover/protocol/compiler/globalcs/merging.go index f7b04c453..8d6e22b0b 100644 --- a/prover/protocol/compiler/globalcs/merging.go +++ b/prover/protocol/compiler/globalcs/merging.go @@ -7,11 +7,11 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/fft" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -178,7 +178,7 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression func getExprRatio(expr *symbolic.Expression) int { var ( board = expr.Board() - domainSize = wizardutils.ExprIsOnSameLengthHandles(&board) + domainSize = column.ExprIsOnSameLengthHandles(&board) exprDegree = board.Degree(GetDegree(domainSize)) quotientSize = exprDegree - domainSize + 1 ratio = utils.DivCeil(quotientSize, domainSize) diff --git a/prover/protocol/compiler/innerproduct/prover.go b/prover/protocol/compiler/innerproduct/prover.go index 6cc508b23..41ed30887 100644 --- a/prover/protocol/compiler/innerproduct/prover.go +++ b/prover/protocol/compiler/innerproduct/prover.go @@ -5,8 +5,8 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" ) // proverTask implements the [wizard.ProverAction] interface and as such @@ -39,7 +39,7 @@ func (ctx *contextForSize) run(run *wizard.ProverRuntime) { var ( size = ctx.Summation.Size() - collapsed = wizardutils.EvalExprColumn(run, ctx.CollapsedBoard).IntoRegVecSaveAlloc() + collapsed = column.EvalExprColumn(run, ctx.CollapsedBoard).IntoRegVecSaveAlloc() summation = make([]field.Element, size) ) diff --git a/prover/protocol/compiler/lookup/prover.go b/prover/protocol/compiler/lookup/prover.go index 12a73f94f..72f7d44bc 100644 --- a/prover/protocol/compiler/lookup/prover.go +++ b/prover/protocol/compiler/lookup/prover.go @@ -7,6 +7,7 @@ import ( sv "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" @@ -308,7 +309,7 @@ func (z zAssignmentTask) run(run *wizard.ProverRuntime) { var ( numeratorMetadata = z.ZNumeratorBoarded[frag].ListVariableMetadata() - denominator = wizardutils.EvalExprColumn(run, z.ZDenominatorBoarded[frag]).IntoRegVecSaveAlloc() + denominator = column.EvalExprColumn(run, z.ZDenominatorBoarded[frag]).IntoRegVecSaveAlloc() numerator []field.Element packedZ = field.BatchInvert(denominator) ) @@ -318,7 +319,7 @@ func (z zAssignmentTask) run(run *wizard.ProverRuntime) { } if len(numeratorMetadata) > 0 { - numerator = wizardutils.EvalExprColumn(run, z.ZNumeratorBoarded[frag]).IntoRegVecSaveAlloc() + numerator = column.EvalExprColumn(run, z.ZNumeratorBoarded[frag]).IntoRegVecSaveAlloc() } for k := range packedZ { diff --git a/prover/protocol/compiler/mimc/manual.go b/prover/protocol/compiler/mimc/manual.go index a204392ab..50885feb6 100644 --- a/prover/protocol/compiler/mimc/manual.go +++ b/prover/protocol/compiler/mimc/manual.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/consensys/linea-monorepo/prover/crypto/mimc" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" @@ -155,7 +156,7 @@ func mimcExprHandle(comp *wizard.CompiledIOP, expr *symbolic.Expression, name .. maxRound := wizardutils.LastRoundToEval(expr) board := expr.Board() - length := wizardutils.ExprIsOnSameLengthHandles(&board) + length := column.ExprIsOnSameLengthHandles(&board) handleName := fmt.Sprintf("SYMBOLIC_%v", expr.ESHash.String()) if len(name) > 0 { diff --git a/prover/protocol/compiler/permutation/prover.go b/prover/protocol/compiler/permutation/prover.go index ab5d0e79a..190da0ce2 100644 --- a/prover/protocol/compiler/permutation/prover.go +++ b/prover/protocol/compiler/permutation/prover.go @@ -6,8 +6,8 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" ) // proverTaskAtRound implements the [wizard.ProverAction] interface and is @@ -48,13 +48,13 @@ func (z *ZCtx) run(run *wizard.ProverRuntime) { ) if packingArity*i < len(z.NumeratorFactors) { - numerator = wizardutils.EvalExprColumn(run, z.NumeratorFactorsBoarded[i]).IntoRegVecSaveAlloc() + numerator = column.EvalExprColumn(run, z.NumeratorFactorsBoarded[i]).IntoRegVecSaveAlloc() } else { numerator = vector.Repeat(field.One(), z.Size) } if packingArity*i < len(z.DenominatorFactors) { - denominator = wizardutils.EvalExprColumn(run, z.DenominatorFactorsBoarded[i]).IntoRegVecSaveAlloc() + denominator = column.EvalExprColumn(run, z.DenominatorFactorsBoarded[i]).IntoRegVecSaveAlloc() } else { denominator = vector.Repeat(field.One(), z.Size) } diff --git a/prover/protocol/dedicated/bigrange/bigrange.go b/prover/protocol/dedicated/bigrange/bigrange.go index 60244ca90..c5192a819 100644 --- a/prover/protocol/dedicated/bigrange/bigrange.go +++ b/prover/protocol/dedicated/bigrange/bigrange.go @@ -11,6 +11,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/fft" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -41,7 +42,7 @@ func BigRange(comp *wizard.CompiledIOP, expr *symbolic.Expression, numLimbs, bit limbs = make([]ifaces.Column, numLimbs) round = wizardutils.LastRoundToEval(expr) boarded = expr.Board() - size = wizardutils.ExprIsOnSameLengthHandles(&boarded) + size = column.ExprIsOnSameLengthHandles(&boarded) totalNumBits = numLimbs * bitPerLimbs ) diff --git a/prover/protocol/dedicated/byte32cmp/decompose.go b/prover/protocol/dedicated/byte32cmp/decompose.go index 6058cca73..538412887 100644 --- a/prover/protocol/dedicated/byte32cmp/decompose.go +++ b/prover/protocol/dedicated/byte32cmp/decompose.go @@ -170,7 +170,7 @@ func (d *decompositionCtx) Run(run *wizard.ProverRuntime) { numLimbs = len(d.decomposed.Limbs) bitPerLimbs = d.decomposed.LimbBitSize totalNumBits = numLimbs * bitPerLimbs - original = wizardutils.EvalExprColumn(run, d.original) + original = column.EvalExprColumn(run, d.original) limbsWitness = make([][]field.Element, numLimbs) size = original.Len() ) diff --git a/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go b/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go index 8ff17c5a6..35d96e11d 100644 --- a/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go +++ b/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/dedicated" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -192,7 +193,7 @@ func (mCmp *multiLimbCmp) Run(run *wizard.ProverRuntime) { }() var ( - syndrom = wizardutils.EvalExprColumn(run, mCmp.syndromBoard) + syndrom = column.EvalExprColumn(run, mCmp.syndromBoard) isGreater = make([]field.Element, mCmp.isGreater.Size()) isLower = make([]field.Element, mCmp.isLower.Size()) nnSyndrom = make([]field.Element, mCmp.isLower.Size()) diff --git a/prover/protocol/dedicated/expr_handle/expr_handle.go b/prover/protocol/dedicated/expr_handle/expr_handle.go index 4e72c2259..7e72eae06 100644 --- a/prover/protocol/dedicated/expr_handle/expr_handle.go +++ b/prover/protocol/dedicated/expr_handle/expr_handle.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/fft" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -25,7 +26,7 @@ func ExprHandle(comp *wizard.CompiledIOP, expr *symbolic.Expression, name ...str var ( boarded = expr.Board() maxRound = wizardutils.LastRoundToEval(expr) - length = wizardutils.ExprIsOnSameLengthHandles(&boarded) + length = column.ExprIsOnSameLengthHandles(&boarded) handleName = fmt.Sprintf("SYMBOLIC_%v", expr.ESHash.String()) ) diff --git a/prover/protocol/dedicated/is_zero.go b/prover/protocol/dedicated/is_zero.go index 208886c7f..68df50089 100644 --- a/prover/protocol/dedicated/is_zero.go +++ b/prover/protocol/dedicated/is_zero.go @@ -2,6 +2,7 @@ package dedicated import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" @@ -44,7 +45,7 @@ func IsZero(comp *wizard.CompiledIOP, c any) (ifaces.Column, wizard.ProverAction case *sym.Expression: board := c1.Board() ctx.c = c1 - ctx.size = wizardutils.ExprIsOnSameLengthHandles(&board) + ctx.size = column.ExprIsOnSameLengthHandles(&board) ctx.round = wizardutils.LastRoundToEval(c1) } @@ -123,13 +124,13 @@ func compileIsZeroWithSize(comp *wizard.CompiledIOP, ctx *isZeroCtx) { // Run implements the [wizard.ProverAction] interface func (ctx *isZeroCtx) Run(run *wizard.ProverRuntime) { var ( - c = wizardutils.EvalExprColumn(run, ctx.c.Board()) + c = column.EvalExprColumn(run, ctx.c.Board()) invOrZero = smartvectors.BatchInvert(c) isZero = smartvectors.IsZero(c) ) if ctx.mask != nil { - mask := wizardutils.EvalExprColumn(run, ctx.mask.Board()) + mask := column.EvalExprColumn(run, ctx.mask.Board()) invOrZero = smartvectors.Mul(invOrZero, mask) isZero = smartvectors.Mul(isZero, mask) } diff --git a/prover/protocol/dedicated/projection/projection.go b/prover/protocol/dedicated/projection/projection.go index efa137fb4..8141e01da 100644 --- a/prover/protocol/dedicated/projection/projection.go +++ b/prover/protocol/dedicated/projection/projection.go @@ -219,8 +219,8 @@ func InsertProjection( func (pa projectionProverAction) Run(run *wizard.ProverRuntime) { var ( - a = wizardutils.EvalExprColumn(run, pa.ABoard).IntoRegVecSaveAlloc() - b = wizardutils.EvalExprColumn(run, pa.BBoard).IntoRegVecSaveAlloc() + a = column.EvalExprColumn(run, pa.ABoard).IntoRegVecSaveAlloc() + b = column.EvalExprColumn(run, pa.BBoard).IntoRegVecSaveAlloc() fA = pa.FilterA.GetColAssignment(run).IntoRegVecSaveAlloc() fB = pa.FilterB.GetColAssignment(run).IntoRegVecSaveAlloc() x = run.GetRandomCoinField(pa.EvalCoin.Name) diff --git a/prover/protocol/query/logderiv_sum.go b/prover/protocol/query/logderiv_sum.go index 774ddb088..8617809ab 100644 --- a/prover/protocol/query/logderiv_sum.go +++ b/prover/protocol/query/logderiv_sum.go @@ -5,13 +5,10 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" - "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" - "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/variables" - "github.com/consensys/linea-monorepo/prover/symbolic" sym "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -98,7 +95,7 @@ func (r LogDerivativeSum) Check(run ifaces.Runtime) error { numBoard = num.Board() denBoard = r.Inputs[key].Denominator[i].Board() numeratorMetadata = numBoard.ListVariableMetadata() - denominator = EvalExprColumn(run, denBoard).IntoRegVecSaveAlloc() + denominator = column.EvalExprColumn(run, denBoard).IntoRegVecSaveAlloc() numerator []field.Element packedZ = field.BatchInvert(denominator) ) @@ -108,7 +105,7 @@ func (r LogDerivativeSum) Check(run ifaces.Runtime) error { } if len(numeratorMetadata) > 0 { - numerator = EvalExprColumn(run, numBoard).IntoRegVecSaveAlloc() + numerator = column.EvalExprColumn(run, numBoard).IntoRegVecSaveAlloc() } for k := range packedZ { @@ -135,72 +132,3 @@ func (r LogDerivativeSum) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) api.AssertIsEqual(params.Y, actualY) */ } - -func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvectors.SmartVector { - - var ( - metadata = board.ListVariableMetadata() - inputs = make([]smartvectors.SmartVector, len(metadata)) - length = ExprIsOnSameLengthHandles(&board) - ) - - // Attempt to recover the size of the - for i := range inputs { - switch m := metadata[i].(type) { - case ifaces.Column: - inputs[i] = m.GetColAssignment(run) - case coin.Info: - v := run.GetRandomCoinField(m.Name) - inputs[i] = smartvectors.NewConstant(v, length) - case ifaces.Accessor: - v := m.GetVal(run) - inputs[i] = smartvectors.NewConstant(v, length) - case variables.PeriodicSample: - v := m.EvalCoset(length, 0, 1, false) - inputs[i] = v - case variables.X: - v := m.EvalCoset(length, 0, 1, false) - inputs[i] = v - } - } - - return board.Evaluate(inputs) -} - -// ExprIsOnSameLengthHandles checks that all the variables of the expression -// that are [ifaces.Column] have the same size (and panics if it does not), then -// returns the match. -func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { - - var ( - metadatas = board.ListVariableMetadata() - length = 0 - ) - - for _, m := range metadatas { - switch metadata := m.(type) { - case ifaces.Column: - // Initialize the length with the first commitment - if length == 0 { - length = metadata.Size() - } - - // Sanity-check the vector should all have the same length - if length != metadata.Size() { - utils.Panic("Inconsistent length for %v (has size %v, but expected %v)", metadata.GetColID(), metadata.Size(), length) - } - // The expression can involve random coins - case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: - // Do nothing - default: - utils.Panic("unknown type %T", metadata) - } - } - - // No commitment were found in the metadata, thus this call is broken - if length == 0 { - utils.Panic("declared a handle from an expression which does not contains any handle") - } - - return length -} diff --git a/prover/protocol/wizardutils/evaluation.go b/prover/protocol/wizardutils/evaluation.go index 7381060ac..a3b93acc8 100644 --- a/prover/protocol/wizardutils/evaluation.go +++ b/prover/protocol/wizardutils/evaluation.go @@ -5,50 +5,10 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/symbolic" ) -// EvalExprColumn resolves an expression to a column assignment. The expression -// must be converted to a board prior to evaluating the expression. -// -// - If the expression does not uses ifaces.Column as metadata, the function -// will panic. -// -// - If the expression contains several columns and they don't contain all -// have the same size. -func EvalExprColumn(run *wizard.ProverRuntime, board symbolic.ExpressionBoard) smartvectors.SmartVector { - - var ( - metadata = board.ListVariableMetadata() - inputs = make([]smartvectors.SmartVector, len(metadata)) - length = ExprIsOnSameLengthHandles(&board) - ) - - // Attempt to recover the size of the - for i := range inputs { - switch m := metadata[i].(type) { - case ifaces.Column: - inputs[i] = m.GetColAssignment(run) - case coin.Info: - v := run.GetRandomCoinField(m.Name) - inputs[i] = smartvectors.NewConstant(v, length) - case ifaces.Accessor: - v := m.GetVal(run) - inputs[i] = smartvectors.NewConstant(v, length) - case variables.PeriodicSample: - v := m.EvalCoset(length, 0, 1, false) - inputs[i] = v - case variables.X: - v := m.EvalCoset(length, 0, 1, false) - inputs[i] = v - } - } - - return board.Evaluate(inputs) -} - // returns the symbolic expression of a column obtained as a random linear combinations of differents handles // without committing to the column itself func RandLinCombColSymbolic(x coin.Info, hs []ifaces.Column) *symbolic.Expression { diff --git a/prover/protocol/wizardutils/utils.go b/prover/protocol/wizardutils/utils.go index ba58ac339..764092a1f 100644 --- a/prover/protocol/wizardutils/utils.go +++ b/prover/protocol/wizardutils/utils.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/symbolic" @@ -41,44 +42,6 @@ func LastRoundToEval(expr *symbolic.Expression) int { return maxRound } -// ExprIsOnSameLengthHandles checks that all the variables of the expression -// that are [ifaces.Column] have the same size (and panics if it does not), then -// returns the match. -func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { - - var ( - metadatas = board.ListVariableMetadata() - length = 0 - ) - - for _, m := range metadatas { - switch metadata := m.(type) { - case ifaces.Column: - // Initialize the length with the first commitment - if length == 0 { - length = metadata.Size() - } - - // Sanity-check the vector should all have the same length - if length != metadata.Size() { - utils.Panic("Inconsistent length for %v (has size %v, but expected %v)", metadata.GetColID(), metadata.Size(), length) - } - // The expression can involve random coins - case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: - // Do nothing - default: - utils.Panic("unknown type %T", metadata) - } - } - - // No commitment were found in the metadata, thus this call is broken - if length == 0 { - utils.Panic("declared a handle from an expression which does not contains any handle") - } - - return length -} - // maximal round of declaration for a list of commitment func MaxRound(handles ...ifaces.Column) int { res := 0 @@ -128,7 +91,7 @@ func AsExpr(x any) (e *symbolic.Expression, round, size int) { case *symbolic.Expression: board := c1.Board() e = c1 - size = ExprIsOnSameLengthHandles(&board) + size = column.ExprIsOnSameLengthHandles(&board) round = LastRoundToEval(c1) return e, round, size } diff --git a/prover/protocol/wizardutils/wizardutils_test.go b/prover/protocol/wizardutils/wizardutils_test.go index 417a0a85a..fd2ac86bc 100644 --- a/prover/protocol/wizardutils/wizardutils_test.go +++ b/prover/protocol/wizardutils/wizardutils_test.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" @@ -48,16 +49,16 @@ func TestWizarldutils(t *testing.T) { run.GetRandomCoinField(coin.Namef("Coin")) - res1Wit := EvalExprColumn(run, res1.Board()).IntoRegVecSaveAlloc() - res11Wit := EvalExprColumn(run, res11.Board()).IntoRegVecSaveAlloc() + res1Wit := column.EvalExprColumn(run, res1.Board()).IntoRegVecSaveAlloc() + res11Wit := column.EvalExprColumn(run, res11.Board()).IntoRegVecSaveAlloc() for i := range res11Wit { if res1Wit[i].Cmp(&res11Wit[i]) != 0 { panic("err") } } - res2Wit := EvalExprColumn(run, res2.Board()).IntoRegVecSaveAlloc() - res22Wit := EvalExprColumn(run, res22.Board()).IntoRegVecSaveAlloc() + res2Wit := column.EvalExprColumn(run, res2.Board()).IntoRegVecSaveAlloc() + res22Wit := column.EvalExprColumn(run, res22.Board()).IntoRegVecSaveAlloc() for i := range res11Wit { if res2Wit[i].Cmp(&res22Wit[i]) != 0 { panic("err") From 523767491dde2f3c4375b5025dc582276373d191 Mon Sep 17 00:00:00 2001 From: arijitdutta67 Date: Thu, 12 Dec 2024 17:08:13 +0530 Subject: [PATCH 11/12] adding module discoverer --- .../protocol/distributed/module_discover.go | 25 ------- .../module_discoverer/module_discoverer.go | 69 +++++++++++++++++++ 2 files changed, 69 insertions(+), 25 deletions(-) delete mode 100644 prover/protocol/distributed/module_discover.go create mode 100644 prover/protocol/distributed/module_discoverer/module_discoverer.go diff --git a/prover/protocol/distributed/module_discover.go b/prover/protocol/distributed/module_discover.go deleted file mode 100644 index 41393e0d2..000000000 --- a/prover/protocol/distributed/module_discover.go +++ /dev/null @@ -1,25 +0,0 @@ -package distributed - -import ( - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" -) - -// it implement [ModuleDiscoverer], it splits the compiler horizontally. -type HorizontalSplitting struct { - modules []string -} - -func (split HorizontalSplitting) Analyze(comp *wizard.CompiledIOP) { - -} - -func (split HorizontalSplitting) Split(comp *wizard.CompiledIOP) { - -} - -// QueryIsInModule checks if the given query is inside the given module -func (split HorizontalSplitting) QueryIsInModule(ifaces.Query, moduleName) bool { - panic("unimplemented") - -} diff --git a/prover/protocol/distributed/module_discoverer/module_discoverer.go b/prover/protocol/distributed/module_discoverer/module_discoverer.go new file mode 100644 index 000000000..88f343107 --- /dev/null +++ b/prover/protocol/distributed/module_discoverer/module_discoverer.go @@ -0,0 +1,69 @@ +package distributed + +import ( + "strings" + + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +type ModuleName string + +// Example struct implementing ModuleDiscoverer +type PeriodSeperatingModuleDiscoverer struct { + modules map[ModuleName][]ifaces.Column +} + +// Analyze groups columns into modules +func (p *PeriodSeperatingModuleDiscoverer) Analyze(comp *wizard.CompiledIOP) { + p.modules = make(map[ModuleName][]ifaces.Column) + numRounds := comp.NumRounds() + for i := range numRounds { + for _, col := range comp.Columns.AllHandlesAtRound(i) { // Assume comp.Columns exists + module := periodLogicToDetermineModule(col) + p.modules[module] = append(p.modules[module], col) + } + } +} + +func periodLogicToDetermineModule(col ifaces.Column) ModuleName { + colName := col.GetColID() + return ModuleName(periodSeparator(string(colName))) +} + +func periodSeparator(name string) string { + // Find the index of the first occurrence of a period + index := strings.Index(name, ".") + if index == -1 { + // If no period is found, return the original string + return name + } + // Return the substring before the first period + return name[:index] +} + +// NbModules returns the number of modules +func (p *PeriodSeperatingModuleDiscoverer) NbModules() int { + return len(p.modules) +} + +// ModuleList returns the list of module names +func (p *PeriodSeperatingModuleDiscoverer) ModuleList(comp *wizard.CompiledIOP) []ModuleName { + moduleNames := make([]ModuleName, 0, len(p.modules)) + for moduleName := range p.modules { + moduleNames = append(moduleNames, moduleName) + } + return moduleNames +} + +// FindModule finds the module name for a given column +func (p *PeriodSeperatingModuleDiscoverer) FindModule(col ifaces.Column) ModuleName { + for moduleName, columns := range p.modules { + for _, c := range columns { + if c == col { + return moduleName + } + } + } + return "no column found" // Return a default or error value +} From ab6b7fbed46cb3f8aa93d2cea93e98139bd191e4 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Tue, 17 Dec 2024 09:57:47 +0100 Subject: [PATCH 12/12] added more methodes to module discoverer --- .../module_discoverer/module_discoverer.go | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/prover/protocol/distributed/module_discoverer/module_discoverer.go b/prover/protocol/distributed/module_discoverer/module_discoverer.go index 88f343107..7d7a911ae 100644 --- a/prover/protocol/distributed/module_discoverer/module_discoverer.go +++ b/prover/protocol/distributed/module_discoverer/module_discoverer.go @@ -5,6 +5,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" ) type ModuleName string @@ -67,3 +68,40 @@ func (p *PeriodSeperatingModuleDiscoverer) FindModule(col ifaces.Column) ModuleN } return "no column found" // Return a default or error value } + +// ColumnIsInModule checks that the given column is inside the given module. +func (p *PeriodSeperatingModuleDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + for _, c := range p.modules[name] { + if c.GetColID() == col.GetColID() { + return true + } + } + return false +} + +// ExpressionIsInModule checks that all the columns in the expression are from the given module. +// +// It does not check the presence of the coins and other metadata. +func (p *PeriodSeperatingModuleDiscoverer) ExpressionIsInModule(expr symbolic.Expression, name ModuleName) bool { + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + b = true + cols []ifaces.Column + ) + + for _, m := range metadata { + switch v := m.(type) { + case ifaces.Column: + if !p.ColumnIsInModule(v, name) { + b = b && false + cols = append(cols, v) + } + } + } + if len(cols) == 0 { + panic("could not find any column in the expression") + } else { + return b + } +}