diff --git a/prover/protocol/column/column.go b/prover/protocol/column/column.go index 3df39cfee..d44d4cc16 100644 --- a/prover/protocol/column/column.go +++ b/prover/protocol/column/column.go @@ -3,7 +3,11 @@ package column import ( "reflect" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -91,3 +95,72 @@ func NbLeaves(h ifaces.Column) int { } panic("unreachable") } + +func EvalExprColumn(run ifaces.Runtime, board symbolic.ExpressionBoard) smartvectors.SmartVector { + + var ( + metadata = board.ListVariableMetadata() + inputs = make([]smartvectors.SmartVector, len(metadata)) + length = ExprIsOnSameLengthHandles(&board) + ) + + // Attempt to recover the size of the + for i := range inputs { + switch m := metadata[i].(type) { + case ifaces.Column: + inputs[i] = m.GetColAssignment(run) + case coin.Info: + v := run.GetRandomCoinField(m.Name) + inputs[i] = smartvectors.NewConstant(v, length) + case ifaces.Accessor: + v := m.GetVal(run) + inputs[i] = smartvectors.NewConstant(v, length) + case variables.PeriodicSample: + v := m.EvalCoset(length, 0, 1, false) + inputs[i] = v + case variables.X: + v := m.EvalCoset(length, 0, 1, false) + inputs[i] = v + } + } + + return board.Evaluate(inputs) +} + +// ExprIsOnSameLengthHandles checks that all the variables of the expression +// that are [ifaces.Column] have the same size (and panics if it does not), then +// returns the match. +func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { + + var ( + metadatas = board.ListVariableMetadata() + length = 0 + ) + + for _, m := range metadatas { + switch metadata := m.(type) { + case ifaces.Column: + // Initialize the length with the first commitment + if length == 0 { + length = metadata.Size() + } + + // Sanity-check the vector should all have the same length + if length != metadata.Size() { + utils.Panic("Inconsistent length for %v (has size %v, but expected %v)", metadata.GetColID(), metadata.Size(), length) + } + // The expression can involve random coins + case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: + // Do nothing + default: + utils.Panic("unknown type %T", metadata) + } + } + + // No commitment were found in the metadata, thus this call is broken + if length == 0 { + utils.Panic("declared a handle from an expression which does not contains any handle") + } + + return length +} diff --git a/prover/protocol/compiler/globalcs/merging.go b/prover/protocol/compiler/globalcs/merging.go index f7b04c453..8d6e22b0b 100644 --- a/prover/protocol/compiler/globalcs/merging.go +++ b/prover/protocol/compiler/globalcs/merging.go @@ -7,11 +7,11 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/fft" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -178,7 +178,7 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression func getExprRatio(expr *symbolic.Expression) int { var ( board = expr.Board() - domainSize = wizardutils.ExprIsOnSameLengthHandles(&board) + domainSize = column.ExprIsOnSameLengthHandles(&board) exprDegree = board.Degree(GetDegree(domainSize)) quotientSize = exprDegree - domainSize + 1 ratio = utils.DivCeil(quotientSize, domainSize) diff --git a/prover/protocol/compiler/innerproduct/prover.go b/prover/protocol/compiler/innerproduct/prover.go index 6cc508b23..41ed30887 100644 --- a/prover/protocol/compiler/innerproduct/prover.go +++ b/prover/protocol/compiler/innerproduct/prover.go @@ -5,8 +5,8 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" ) // proverTask implements the [wizard.ProverAction] interface and as such @@ -39,7 +39,7 @@ func (ctx *contextForSize) run(run *wizard.ProverRuntime) { var ( size = ctx.Summation.Size() - collapsed = wizardutils.EvalExprColumn(run, ctx.CollapsedBoard).IntoRegVecSaveAlloc() + collapsed = column.EvalExprColumn(run, ctx.CollapsedBoard).IntoRegVecSaveAlloc() summation = make([]field.Element, size) ) diff --git a/prover/protocol/compiler/lookup/compiler.go b/prover/protocol/compiler/lookup/compiler.go index 6b4f8fea6..02123da11 100644 --- a/prover/protocol/compiler/lookup/compiler.go +++ b/prover/protocol/compiler/lookup/compiler.go @@ -27,13 +27,13 @@ type table = []ifaces.Column func CompileLogDerivative(comp *wizard.CompiledIOP) { var ( - mainLookupCtx = captureLookupTables(comp) + mainLookupCtx = CaptureLookupTables(comp) lastRound = comp.NumRounds() - 1 proverActions = make([]proverTaskAtRound, comp.NumRounds()+1) // zCatalog stores a mapping (round, size) into ZCtx and helps finding // which Z context should be used to handle a part of a given permutation // query. - zCatalog = map[[2]int]*zCtx{} + zCatalog = map[[2]int]*ZCtx{} zEntries = [][2]int{} // verifier actions va = finalEvaluationCheck{} @@ -41,25 +41,25 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // Skip the compilation phase if no lookup constraint is being used. Otherwise // it will register a verifier action that is not required and will be bugged. - if len(mainLookupCtx.lookupTables) == 0 { + if len(mainLookupCtx.LookupTables) == 0 { return } // Step 1. construct the "per table" contexts and pack the Sigma's into // zCatalog. - for _, lookupTable := range mainLookupCtx.lookupTables { + for _, lookupTable := range mainLookupCtx.LookupTables { var ( // get checkedTables, rounds, Filters by lookupTableName - lookupTableName = nameTable(lookupTable) - checkTable = mainLookupCtx.checkedTables[lookupTableName] - round = mainLookupCtx.rounds[lookupTableName] - includedFilters = mainLookupCtx.includedFilters[lookupTableName] - tableCtx = compileLookupTable(comp, round, lookupTable, checkTable, includedFilters) + lookupTableName = NameTable(lookupTable) + checkTable = mainLookupCtx.CheckedTables[lookupTableName] + round = mainLookupCtx.Rounds[lookupTableName] + includedFilters = mainLookupCtx.IncludedFilters[lookupTableName] + tableCtx = CompileLookupTable(comp, round, lookupTable, checkTable, includedFilters) ) // push to zCatalog - tableCtx.pushToZCatalog(zCatalog) + tableCtx.PushToZCatalog(zCatalog) proverActions[round].pushMAssignment( mAssignmentTask{ @@ -119,7 +119,7 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { comp.RegisterVerifierAction(lastRound, &va) } -// captureLookupTables inspects comp and look for Inclusion queries that are not +// CaptureLookupTables inspects comp and look for Inclusion queries that are not // marked as ignored yet. All the queries matched queries are grouped by look-up // table (e.g. all the queries that use the same lookup table). All the matched // queries are marked as ignored. The function returns the thereby-initialized @@ -131,13 +131,13 @@ func CompileLogDerivative(comp *wizard.CompiledIOP) { // The function also implictly reduces the conditionals over the Including table // be appending a "one" column on the included side and the filter on the // including side. -func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { +func CaptureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { ctx := mainLookupCtx{ - lookupTables: [][]table{}, - checkedTables: map[string][]table{}, - includedFilters: map[string][]ifaces.Column{}, - rounds: map[string]int{}, + LookupTables: [][]table{}, + CheckedTables: map[string][]table{}, + IncludedFilters: map[string][]ifaces.Column{}, + Rounds: map[string]int{}, } // Collect all the lookup queries into "lookups" @@ -157,8 +157,8 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { var ( // checkedTable corresponds to the "included" table and lookupTable // corresponds to the including table. - checkedTable, lookupTable = getTableCanonicalOrder(lookup) - tableName = nameTable(lookupTable) + checkedTable, lookupTable = GetTableCanonicalOrder(lookup) + tableName = NameTable(lookupTable) // includedFilters stores the query.IncludedFilter parameter. If the // query has no includedFilters on the Included side. Then this is // left as nil. @@ -176,7 +176,7 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { lookupTable[frag] = append([]ifaces.Column{lookup.IncludingFilter[frag]}, lookupTable[frag]...) } - tableName = nameTable(lookupTable) + tableName = NameTable(lookupTable) } if lookup.IsFilteredOnIncluded() { @@ -185,16 +185,16 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { // In case this is the first iteration where we encounter the lookupTable // we need to add entries in the registering maps. - if _, ok := ctx.checkedTables[tableName]; !ok { - ctx.includedFilters[tableName] = []ifaces.Column{} - ctx.checkedTables[tableName] = []table{} - ctx.lookupTables = append(ctx.lookupTables, lookupTable) - ctx.rounds[tableName] = 0 + if _, ok := ctx.CheckedTables[tableName]; !ok { + ctx.IncludedFilters[tableName] = []ifaces.Column{} + ctx.CheckedTables[tableName] = []table{} + ctx.LookupTables = append(ctx.LookupTables, lookupTable) + ctx.Rounds[tableName] = 0 } - ctx.includedFilters[tableName] = append(ctx.includedFilters[tableName], includedFilter) - ctx.checkedTables[tableName] = append(ctx.checkedTables[tableName], checkedTable) - ctx.rounds[tableName] = max(ctx.rounds[tableName], comp.QueriesNoParams.Round(lookup.ID)) + ctx.IncludedFilters[tableName] = append(ctx.IncludedFilters[tableName], includedFilter) + ctx.CheckedTables[tableName] = append(ctx.CheckedTables[tableName], checkedTable) + ctx.Rounds[tableName] = max(ctx.Rounds[tableName], comp.QueriesNoParams.Round(lookup.ID)) } @@ -213,16 +213,16 @@ func captureLookupTables(comp *wizard.CompiledIOP) mainLookupCtx { // - (5) The verier makes a `Global` query : $\left((\Sigma_T)[i] - (\Sigma_T)[i-1]\right)(T_i + \gamma) = M_i$ // here we are looking up set of columns S in a single column T -func compileLookupTable( +func CompileLookupTable( comp *wizard.CompiledIOP, round int, lookupTable []table, checkedTables []table, includedFilters []ifaces.Column, -) (ctx singleTableCtx) { +) (ctx SingleTableCtx) { - ctx = singleTableCtx{ - TableName: nameTable(lookupTable), + ctx = SingleTableCtx{ + TableName: NameTable(lookupTable), S: make([]*symbolic.Expression, len(checkedTables)), SFilters: includedFilters, T: make([]*symbolic.Expression, len(lookupTable)), @@ -240,7 +240,7 @@ func compileLookupTable( ctx.T[frag] = symbolic.NewVariable(lookupTable[frag][0]) ctx.M[frag] = comp.InsertCommit( round, - deriveTableNameWithIndex[ifaces.ColID](logDerivativePrefix, lookupTable, frag, "M"), + DeriveTableNameWithIndex[ifaces.ColID](LogDerivativePrefix, lookupTable, frag, "M"), lookupTable[frag][0].Size(), ) @@ -257,7 +257,7 @@ func compileLookupTable( // columns of T and S when they are (both) multi-columns. alpha := comp.InsertCoin( round+1, - deriveTableName[coin.Name](logDerivativePrefix, lookupTable, "ALPHA"), + DeriveTableName[coin.Name](LogDerivativePrefix, lookupTable, "ALPHA"), coin.Field, ) @@ -265,7 +265,7 @@ func compileLookupTable( ctx.T[frag] = wizardutils.RandLinCombColSymbolic(alpha, lookupTable[frag]) ctx.M[frag] = comp.InsertCommit( round, - deriveTableNameWithIndex[ifaces.ColID](logDerivativePrefix, lookupTable, frag, "M"), + DeriveTableNameWithIndex[ifaces.ColID](LogDerivativePrefix, lookupTable, frag, "M"), lookupTable[frag][0].Size(), ) } @@ -277,16 +277,16 @@ func compileLookupTable( ctx.Gamma = comp.InsertCoin( round+1, - deriveTableName[coin.Name](logDerivativePrefix, lookupTable, "GAMMA"), + DeriveTableName[coin.Name](LogDerivativePrefix, lookupTable, "GAMMA"), coin.Field, ) return ctx } -// pushToZCatalog constructs the numerators and denominators for S and T of the +// PushToZCatalog constructs the numerators and denominators for S and T of the // stc into zCatalog for their corresponding rounds and size. -func (stc *singleTableCtx) pushToZCatalog(zCatalog map[[2]int]*zCtx) { +func (stc *SingleTableCtx) PushToZCatalog(zCatalog map[[2]int]*ZCtx) { var ( round = stc.Gamma.Round @@ -299,7 +299,7 @@ func (stc *singleTableCtx) pushToZCatalog(zCatalog map[[2]int]*zCtx) { key := [2]int{round, size} if zCatalog[key] == nil { - zCatalog[key] = &zCtx{ + zCatalog[key] = &ZCtx{ Size: size, Round: round, Name: stc.TableName, @@ -324,7 +324,7 @@ func (stc *singleTableCtx) pushToZCatalog(zCatalog map[[2]int]*zCtx) { key := [2]int{round, size} if zCatalog[key] == nil { - zCatalog[key] = &zCtx{ + zCatalog[key] = &ZCtx{ Size: size, Round: round, Name: stc.TableName, diff --git a/prover/protocol/compiler/lookup/context.go b/prover/protocol/compiler/lookup/context.go index 4ec961da1..3f30808ad 100644 --- a/prover/protocol/compiler/lookup/context.go +++ b/prover/protocol/compiler/lookup/context.go @@ -10,9 +10,9 @@ import ( // altogether. type mainLookupCtx struct { - // lookupTables stores all the lookup table the compiler encounters. They are + // LookupTables stores all the lookup table the compiler encounters. They are // sorted in canonical order. This used to derive a determistic ordering - // of the lookup lookupTables. (We want to ensure the compiler yields always + // of the lookup LookupTables. (We want to ensure the compiler yields always // exactly the same result for replicability). // // To illustrates its structure, the following sub-statement @@ -20,31 +20,31 @@ type mainLookupCtx struct { // table[numTable][frag] // // refers to to the fragment #frag of the the table #numTable. - lookupTables [][]table + LookupTables [][]table - // checkedTables stores all the checked column by lookup table. The key is + // CheckedTables stores all the checked column by lookup table. The key is // obtained as nameTable(lookupTable) where lookup is sorted in // canonical order. - checkedTables map[string][]table + CheckedTables map[string][]table - // includedFilters stores all the filters for the checked columns and `nil` + // IncludedFilters stores all the filters for the checked columns and `nil` // if no filter is applied. As for [checkedTables] they are stored by // lookup table name and in the same order for each key. - includedFilters map[string][]ifaces.Column + IncludedFilters map[string][]ifaces.Column - // rounds stores the interaction round assigned to each lookupTable. The - // round is obtained by taking the max of the declaration rounds of the + // Rounds stores the interaction round assigned to each lookupTable. The + // round is obtained by taking the max of the declaration Rounds of the // Inclusion queries using the corresponding lookup table. - rounds map[string]int + Rounds map[string]int } -// singleTableCtx stores the compilation context for a single lookup query +// SingleTableCtx stores the compilation context for a single lookup query // when it is compiled using the log-derivative lookup technique. // -// A singleTableCtx relates to a lookup table rather than a lookup query. This +// A SingleTableCtx relates to a lookup table rather than a lookup query. This // means that multiple lookup queries that are related to the same table will be // grouped into the same context. This allows optimizing the -type singleTableCtx struct { +type SingleTableCtx struct { // TableName reflects the name of the lookup table being compiled. TableName string diff --git a/prover/protocol/compiler/lookup/prover.go b/prover/protocol/compiler/lookup/prover.go index 063fa5a6b..72f7d44bc 100644 --- a/prover/protocol/compiler/lookup/prover.go +++ b/prover/protocol/compiler/lookup/prover.go @@ -7,6 +7,7 @@ import ( sv "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" @@ -281,7 +282,7 @@ func (a mAssignmentTask) run(run *wizard.ProverRuntime) { } utils.Panic( "entry %v of the table %v is not included in the table. tableRow=%v", - k, nameTable([][]ifaces.Column{a.S[i]}), vector.Prettify(tableRow), + k, NameTable([][]ifaces.Column{a.S[i]}), vector.Prettify(tableRow), ) } @@ -300,7 +301,7 @@ func (a mAssignmentTask) run(run *wizard.ProverRuntime) { // zAssignmentTask represents a prover task of assignming the columns // SigmaS and SigmaT for a specific lookup table. // sigmaAssignment -type zAssignmentTask zCtx +type zAssignmentTask ZCtx func (z zAssignmentTask) run(run *wizard.ProverRuntime) { parallel.Execute(len(z.ZDenominatorBoarded), func(start, stop int) { @@ -308,7 +309,7 @@ func (z zAssignmentTask) run(run *wizard.ProverRuntime) { var ( numeratorMetadata = z.ZNumeratorBoarded[frag].ListVariableMetadata() - denominator = wizardutils.EvalExprColumn(run, z.ZDenominatorBoarded[frag]).IntoRegVecSaveAlloc() + denominator = column.EvalExprColumn(run, z.ZDenominatorBoarded[frag]).IntoRegVecSaveAlloc() numerator []field.Element packedZ = field.BatchInvert(denominator) ) @@ -318,7 +319,7 @@ func (z zAssignmentTask) run(run *wizard.ProverRuntime) { } if len(numeratorMetadata) > 0 { - numerator = wizardutils.EvalExprColumn(run, z.ZNumeratorBoarded[frag]).IntoRegVecSaveAlloc() + numerator = column.EvalExprColumn(run, z.ZNumeratorBoarded[frag]).IntoRegVecSaveAlloc() } for k := range packedZ { diff --git a/prover/protocol/compiler/lookup/utils.go b/prover/protocol/compiler/lookup/utils.go index a6d4eeb27..17cbe8858 100644 --- a/prover/protocol/compiler/lookup/utils.go +++ b/prover/protocol/compiler/lookup/utils.go @@ -10,12 +10,12 @@ import ( ) const ( - // logDerivativePrefix is a prefix that we commonly use to derive query, + // LogDerivativePrefix is a prefix that we commonly use to derive query, // coin or column names that are introduced by the compiler. - logDerivativePrefix = "LOGDERIVATIVE" + LogDerivativePrefix = "LOGDERIVATIVE" ) -// getTableCanonicalOrder extracts the lookup table and the queried tables +// GetTableCanonicalOrder extracts the lookup table and the queried tables // from `q` and rearrange them conjointly so that the names of T are returned // in alphabetical order. // @@ -28,7 +28,7 @@ const ( // can always make sure to specify the table in the same order all the time. // // Importantly, the function allocates its own result. -func getTableCanonicalOrder(q query.Inclusion) ([]ifaces.Column, [][]ifaces.Column) { +func GetTableCanonicalOrder(q query.Inclusion) ([]ifaces.Column, [][]ifaces.Column) { if len(q.Including) > 1 { // The append here are performing a deep-copy of the slice within the @@ -67,8 +67,8 @@ func getTableCanonicalOrder(q query.Inclusion) ([]ifaces.Column, [][]ifaces.Colu return checked, [][]ifaces.Column{table} } -// deriveName constructs a generic name -func deriveName[R ~string](args ...any) R { +// DeriveName constructs a generic name +func DeriveName[R ~string](args ...any) R { argStr := []string{"LOOKUP_LOGDERIVATIVE"} for _, arg := range args { argStr = append(argStr, fmt.Sprintf("%v", arg)) @@ -76,25 +76,25 @@ func deriveName[R ~string](args ...any) R { return R(strings.Join(argStr, "_")) } -// deriveTableName constructs a name for the table `t`. The caller may provide +// DeriveTableName constructs a name for the table `t`. The caller may provide // a context and a suffix to the name. If `t` is empty, the name is the // concatenation of `context` and `name` separated by an underscore. -func deriveTableName[R ~string](context string, t [][]ifaces.Column, name string) R { - res := fmt.Sprintf("%v_%v_%v", nameTable(t), context, name) +func DeriveTableName[R ~string](context string, t [][]ifaces.Column, name string) R { + res := fmt.Sprintf("%v_%v_%v", NameTable(t), context, name) return R(res) } -// deriveTableNameWithIndex is as [deriveTableName] but additionally allows +// DeriveTableNameWithIndex is as [deriveTableName] but additionally allows // appending an integer index in the name. -func deriveTableNameWithIndex[R ~string](context string, t [][]ifaces.Column, index int, name string) R { - res := fmt.Sprintf("%v_%v_%v_%v", nameTable(t), index, context, name) +func DeriveTableNameWithIndex[R ~string](context string, t [][]ifaces.Column, index int, name string) R { + res := fmt.Sprintf("%v_%v_%v_%v", NameTable(t), index, context, name) return R(res) } -// nameTable returns a unique name corresponding to the provided +// NameTable returns a unique name corresponding to the provided // sequence of columns `t`. The unique name is constructed by appending the // name of all the column separated by an underscore. -func nameTable(t []table) string { +func NameTable(t []table) string { // This single fragment case is managed as a special case although it is // not really one. This is for backwards compatibility. if len(t) == 1 { diff --git a/prover/protocol/compiler/lookup/z_packing.go b/prover/protocol/compiler/lookup/z_packing.go index d6d12ffbe..2838f4834 100644 --- a/prover/protocol/compiler/lookup/z_packing.go +++ b/prover/protocol/compiler/lookup/z_packing.go @@ -26,7 +26,7 @@ const ( // T: lookupTable, // SFilter: includedFilters, -type zCtx struct { +type ZCtx struct { Round, Size int SigmaNumerator []*sym.Expression // T -> -M, S -> +Filter SigmaDenominator []*sym.Expression // S or T -> ({S,T} + X) @@ -44,7 +44,7 @@ type zCtx struct { // and change T -> -M, S -> +Filter // S or T -> ({S,T} + X) // compile should be called inside CompileGrandSum -func (z *zCtx) compile(comp *wizard.CompiledIOP) { +func (z *ZCtx) compile(comp *wizard.CompiledIOP) { var ( numZs = utils.DivCeil( @@ -82,14 +82,14 @@ func (z *zCtx) compile(comp *wizard.CompiledIOP) { z.Zs[i] = comp.InsertCommit( z.Round, - deriveName[ifaces.ColID]("Z", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.ColID]("Z", comp.SelfRecursionCount, z.Round, z.Size, i), z.Size, ) // initial condition comp.InsertLocal( z.Round, - deriveName[ifaces.QueryID]("Z_CONSISTENCY_START", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.QueryID]("Z_CONSISTENCY_START", comp.SelfRecursionCount, z.Round, z.Size, i), sym.Sub( zNumerator, sym.Mul( @@ -102,7 +102,7 @@ func (z *zCtx) compile(comp *wizard.CompiledIOP) { // consistency check comp.InsertGlobal( z.Round, - deriveName[ifaces.QueryID]("Z_CONSISTENCY", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.QueryID]("Z_CONSISTENCY", comp.SelfRecursionCount, z.Round, z.Size, i), sym.Sub( zNumerator, sym.Mul( @@ -115,7 +115,7 @@ func (z *zCtx) compile(comp *wizard.CompiledIOP) { // local opening of the final value of the Z polynomial z.ZOpenings[i] = comp.InsertLocalOpening( z.Round, - deriveName[ifaces.QueryID]("Z_FINAL", comp.SelfRecursionCount, z.Round, z.Size, i), + DeriveName[ifaces.QueryID]("Z_FINAL", comp.SelfRecursionCount, z.Round, z.Size, i), column.Shift(z.Zs[i], -1), ) } diff --git a/prover/protocol/compiler/mimc/manual.go b/prover/protocol/compiler/mimc/manual.go index a204392ab..50885feb6 100644 --- a/prover/protocol/compiler/mimc/manual.go +++ b/prover/protocol/compiler/mimc/manual.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/consensys/linea-monorepo/prover/crypto/mimc" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" @@ -155,7 +156,7 @@ func mimcExprHandle(comp *wizard.CompiledIOP, expr *symbolic.Expression, name .. maxRound := wizardutils.LastRoundToEval(expr) board := expr.Board() - length := wizardutils.ExprIsOnSameLengthHandles(&board) + length := column.ExprIsOnSameLengthHandles(&board) handleName := fmt.Sprintf("SYMBOLIC_%v", expr.ESHash.String()) if len(name) > 0 { diff --git a/prover/protocol/compiler/permutation/prover.go b/prover/protocol/compiler/permutation/prover.go index ab5d0e79a..190da0ce2 100644 --- a/prover/protocol/compiler/permutation/prover.go +++ b/prover/protocol/compiler/permutation/prover.go @@ -6,8 +6,8 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" ) // proverTaskAtRound implements the [wizard.ProverAction] interface and is @@ -48,13 +48,13 @@ func (z *ZCtx) run(run *wizard.ProverRuntime) { ) if packingArity*i < len(z.NumeratorFactors) { - numerator = wizardutils.EvalExprColumn(run, z.NumeratorFactorsBoarded[i]).IntoRegVecSaveAlloc() + numerator = column.EvalExprColumn(run, z.NumeratorFactorsBoarded[i]).IntoRegVecSaveAlloc() } else { numerator = vector.Repeat(field.One(), z.Size) } if packingArity*i < len(z.DenominatorFactors) { - denominator = wizardutils.EvalExprColumn(run, z.DenominatorFactorsBoarded[i]).IntoRegVecSaveAlloc() + denominator = column.EvalExprColumn(run, z.DenominatorFactorsBoarded[i]).IntoRegVecSaveAlloc() } else { denominator = vector.Repeat(field.One(), z.Size) } diff --git a/prover/protocol/dedicated/bigrange/bigrange.go b/prover/protocol/dedicated/bigrange/bigrange.go index 60244ca90..c5192a819 100644 --- a/prover/protocol/dedicated/bigrange/bigrange.go +++ b/prover/protocol/dedicated/bigrange/bigrange.go @@ -11,6 +11,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/fft" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -41,7 +42,7 @@ func BigRange(comp *wizard.CompiledIOP, expr *symbolic.Expression, numLimbs, bit limbs = make([]ifaces.Column, numLimbs) round = wizardutils.LastRoundToEval(expr) boarded = expr.Board() - size = wizardutils.ExprIsOnSameLengthHandles(&boarded) + size = column.ExprIsOnSameLengthHandles(&boarded) totalNumBits = numLimbs * bitPerLimbs ) diff --git a/prover/protocol/dedicated/byte32cmp/decompose.go b/prover/protocol/dedicated/byte32cmp/decompose.go index 6058cca73..538412887 100644 --- a/prover/protocol/dedicated/byte32cmp/decompose.go +++ b/prover/protocol/dedicated/byte32cmp/decompose.go @@ -170,7 +170,7 @@ func (d *decompositionCtx) Run(run *wizard.ProverRuntime) { numLimbs = len(d.decomposed.Limbs) bitPerLimbs = d.decomposed.LimbBitSize totalNumBits = numLimbs * bitPerLimbs - original = wizardutils.EvalExprColumn(run, d.original) + original = column.EvalExprColumn(run, d.original) limbsWitness = make([][]field.Element, numLimbs) size = original.Len() ) diff --git a/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go b/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go index 8ff17c5a6..35d96e11d 100644 --- a/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go +++ b/prover/protocol/dedicated/byte32cmp/multi_limb_cmp.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/dedicated" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -192,7 +193,7 @@ func (mCmp *multiLimbCmp) Run(run *wizard.ProverRuntime) { }() var ( - syndrom = wizardutils.EvalExprColumn(run, mCmp.syndromBoard) + syndrom = column.EvalExprColumn(run, mCmp.syndromBoard) isGreater = make([]field.Element, mCmp.isGreater.Size()) isLower = make([]field.Element, mCmp.isLower.Size()) nnSyndrom = make([]field.Element, mCmp.isLower.Size()) diff --git a/prover/protocol/dedicated/expr_handle/expr_handle.go b/prover/protocol/dedicated/expr_handle/expr_handle.go index 4e72c2259..7e72eae06 100644 --- a/prover/protocol/dedicated/expr_handle/expr_handle.go +++ b/prover/protocol/dedicated/expr_handle/expr_handle.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/fft" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -25,7 +26,7 @@ func ExprHandle(comp *wizard.CompiledIOP, expr *symbolic.Expression, name ...str var ( boarded = expr.Board() maxRound = wizardutils.LastRoundToEval(expr) - length = wizardutils.ExprIsOnSameLengthHandles(&boarded) + length = column.ExprIsOnSameLengthHandles(&boarded) handleName = fmt.Sprintf("SYMBOLIC_%v", expr.ESHash.String()) ) diff --git a/prover/protocol/dedicated/is_zero.go b/prover/protocol/dedicated/is_zero.go index 208886c7f..68df50089 100644 --- a/prover/protocol/dedicated/is_zero.go +++ b/prover/protocol/dedicated/is_zero.go @@ -2,6 +2,7 @@ package dedicated import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" @@ -44,7 +45,7 @@ func IsZero(comp *wizard.CompiledIOP, c any) (ifaces.Column, wizard.ProverAction case *sym.Expression: board := c1.Board() ctx.c = c1 - ctx.size = wizardutils.ExprIsOnSameLengthHandles(&board) + ctx.size = column.ExprIsOnSameLengthHandles(&board) ctx.round = wizardutils.LastRoundToEval(c1) } @@ -123,13 +124,13 @@ func compileIsZeroWithSize(comp *wizard.CompiledIOP, ctx *isZeroCtx) { // Run implements the [wizard.ProverAction] interface func (ctx *isZeroCtx) Run(run *wizard.ProverRuntime) { var ( - c = wizardutils.EvalExprColumn(run, ctx.c.Board()) + c = column.EvalExprColumn(run, ctx.c.Board()) invOrZero = smartvectors.BatchInvert(c) isZero = smartvectors.IsZero(c) ) if ctx.mask != nil { - mask := wizardutils.EvalExprColumn(run, ctx.mask.Board()) + mask := column.EvalExprColumn(run, ctx.mask.Board()) invOrZero = smartvectors.Mul(invOrZero, mask) isZero = smartvectors.Mul(isZero, mask) } diff --git a/prover/protocol/dedicated/projection/projection.go b/prover/protocol/dedicated/projection/projection.go index efa137fb4..8141e01da 100644 --- a/prover/protocol/dedicated/projection/projection.go +++ b/prover/protocol/dedicated/projection/projection.go @@ -219,8 +219,8 @@ func InsertProjection( func (pa projectionProverAction) Run(run *wizard.ProverRuntime) { var ( - a = wizardutils.EvalExprColumn(run, pa.ABoard).IntoRegVecSaveAlloc() - b = wizardutils.EvalExprColumn(run, pa.BBoard).IntoRegVecSaveAlloc() + a = column.EvalExprColumn(run, pa.ABoard).IntoRegVecSaveAlloc() + b = column.EvalExprColumn(run, pa.BBoard).IntoRegVecSaveAlloc() fA = pa.FilterA.GetColAssignment(run).IntoRegVecSaveAlloc() fB = pa.FilterB.GetColAssignment(run).IntoRegVecSaveAlloc() x = run.GetRandomCoinField(pa.EvalCoin.Name) diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index 43af9174f..f30d503e8 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -1,17 +1,74 @@ package inclusion import ( - "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" ) -// CompileDist compiles inclusion queries distributedly. -// It receives a compiledIOP object relevant to a segment. -// The seed is a random coin from randomness beacon (FS of all LPP commitments). -// All the compilation steps are similar to the permutation compilation apart from: -// - random coins \alpha and \gamma are generated from the seed (and the tableName). -// - no verifierAction is needed over the ZOpening. -// - ZOpenings are declared as public input. -func CompileDist(comp *wizard.CompiledIOP, seed coin.Info) { +const ( + LogDerivativeSum = "LOGDERIVATIVE_SUM" +) + +// distributionInputs stores the input required for the distribution of a LogDerivativeSum query. +type distributionInputs struct { + comp *wizard.CompiledIOP + // module Discoverer used to detect the relevant part of the query to the module + disc distributed.ModuleDiscoverer + // Name of the module + moduleName string + // query is supposed to be the global LogDerivativeSum. + queryID ifaces.QueryID +} + +// DistributeLogDerivativeSum distribute the LogDerivativeSum among the modules. +// It detect the relevant share of the module from the global LogDerivativeSum. +// It generates a new LogDerivateSum query relevant to the module. +func DistributeLogDerivativeSum(in distributionInputs) { + var ( + comp = in.comp + numerator []*symbolic.Expression + denominator []*symbolic.Expression + zCatalog map[[2]int]*query.LogDerivativeSumInput + lastRound = in.comp.NumRounds() - 1 + ) + // check that the given query is a valid LogDerivateSum query in the CompiledIOP. + logDeriv, ok := comp.QueriesParams.Data(in.queryID).(query.LogDerivativeSum) + if !ok { + panic("the given query is not a valid LogDerivativeSum from the compiledIOP") + } + + // This ensures that the logDerivative query is not used again in the + // compilation process. + comp.QueriesNoParams.MarkAsIgnored(in.queryID) + + // extract the share of the module from the global sum. + for key := range logDeriv.Inputs { + for i := range logDeriv.Inputs[key].Numerator { + if in.disc.ExpressionIsInModule(logDeriv.Inputs[key].Numerator[i], in.moduleName) { + numerator = append(numerator, logDeriv.Inputs[key].Numerator[i]) + } + if in.disc.ExpressionIsInModule(logDeriv.Inputs[key].Denominator[i], in.moduleName) { + denominator = append(denominator, logDeriv.Inputs[key].Denominator[i]) + } + } + + // zCatalog specific to the module + zCatalog[key] = &query.LogDerivativeSumInput{ + Round: key[0], + Size: key[1], + Numerator: numerator, + Denominator: denominator, + } + + } + // insert a LogDerivativeSum specific to the module. + comp.InsertLogDerivativeSum( + lastRound, + ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.moduleName), + zCatalog, + ) } diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index dbe01f6f2..63097eabc 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -1,12 +1,10 @@ package distributed import ( - "github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct" - "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" - "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" ) type moduleName = string @@ -36,6 +34,7 @@ type ModuleDiscoverer interface { FindModule(col ifaces.Column) moduleName // given a query and a module name it checks if the query is inside the module QueryIsInModule(ifaces.Query, moduleName) bool + ExpressionIsInModule(*symbolic.Expression, moduleName) bool } // This transforms the initial wizard. So it is not really the initial @@ -118,14 +117,3 @@ func addToGlobalLocal(comp *wizard.CompiledIOP, q ifaces.Query) { func aggregator(distModules []DistributedModule, maxNumSegments int) *wizard.CompiledIOP { panic("unimplemented") } - -// prepare reduces any query to LPP or GL. -// it prepares the columns that depends on whole the witness,e.g., M column for lookups. -func prepare(comp *wizard.CompiledIOP) { - mimc.CompileMiMC(comp) - specialqueries.RangeProof(comp) - specialqueries.CompileFixedPermutations(comp) - innerproduct.Compile(comp) - - // prepareLookup(comp) -} diff --git a/prover/protocol/distributed/module_discover.go b/prover/protocol/distributed/module_discover.go deleted file mode 100644 index 41393e0d2..000000000 --- a/prover/protocol/distributed/module_discover.go +++ /dev/null @@ -1,25 +0,0 @@ -package distributed - -import ( - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" -) - -// it implement [ModuleDiscoverer], it splits the compiler horizontally. -type HorizontalSplitting struct { - modules []string -} - -func (split HorizontalSplitting) Analyze(comp *wizard.CompiledIOP) { - -} - -func (split HorizontalSplitting) Split(comp *wizard.CompiledIOP) { - -} - -// QueryIsInModule checks if the given query is inside the given module -func (split HorizontalSplitting) QueryIsInModule(ifaces.Query, moduleName) bool { - panic("unimplemented") - -} diff --git a/prover/protocol/distributed/module_discoverer/module_discoverer.go b/prover/protocol/distributed/module_discoverer/module_discoverer.go new file mode 100644 index 000000000..7d7a911ae --- /dev/null +++ b/prover/protocol/distributed/module_discoverer/module_discoverer.go @@ -0,0 +1,107 @@ +package distributed + +import ( + "strings" + + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" +) + +type ModuleName string + +// Example struct implementing ModuleDiscoverer +type PeriodSeperatingModuleDiscoverer struct { + modules map[ModuleName][]ifaces.Column +} + +// Analyze groups columns into modules +func (p *PeriodSeperatingModuleDiscoverer) Analyze(comp *wizard.CompiledIOP) { + p.modules = make(map[ModuleName][]ifaces.Column) + numRounds := comp.NumRounds() + for i := range numRounds { + for _, col := range comp.Columns.AllHandlesAtRound(i) { // Assume comp.Columns exists + module := periodLogicToDetermineModule(col) + p.modules[module] = append(p.modules[module], col) + } + } +} + +func periodLogicToDetermineModule(col ifaces.Column) ModuleName { + colName := col.GetColID() + return ModuleName(periodSeparator(string(colName))) +} + +func periodSeparator(name string) string { + // Find the index of the first occurrence of a period + index := strings.Index(name, ".") + if index == -1 { + // If no period is found, return the original string + return name + } + // Return the substring before the first period + return name[:index] +} + +// NbModules returns the number of modules +func (p *PeriodSeperatingModuleDiscoverer) NbModules() int { + return len(p.modules) +} + +// ModuleList returns the list of module names +func (p *PeriodSeperatingModuleDiscoverer) ModuleList(comp *wizard.CompiledIOP) []ModuleName { + moduleNames := make([]ModuleName, 0, len(p.modules)) + for moduleName := range p.modules { + moduleNames = append(moduleNames, moduleName) + } + return moduleNames +} + +// FindModule finds the module name for a given column +func (p *PeriodSeperatingModuleDiscoverer) FindModule(col ifaces.Column) ModuleName { + for moduleName, columns := range p.modules { + for _, c := range columns { + if c == col { + return moduleName + } + } + } + return "no column found" // Return a default or error value +} + +// ColumnIsInModule checks that the given column is inside the given module. +func (p *PeriodSeperatingModuleDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + for _, c := range p.modules[name] { + if c.GetColID() == col.GetColID() { + return true + } + } + return false +} + +// ExpressionIsInModule checks that all the columns in the expression are from the given module. +// +// It does not check the presence of the coins and other metadata. +func (p *PeriodSeperatingModuleDiscoverer) ExpressionIsInModule(expr symbolic.Expression, name ModuleName) bool { + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + b = true + cols []ifaces.Column + ) + + for _, m := range metadata { + switch v := m.(type) { + case ifaces.Column: + if !p.ColumnIsInModule(v, name) { + b = b && false + cols = append(cols, v) + } + } + } + if len(cols) == 0 { + panic("could not find any column in the expression") + } else { + return b + } +} diff --git a/prover/protocol/distributed/preparation.go b/prover/protocol/distributed/preparation.go new file mode 100644 index 000000000..9784129b4 --- /dev/null +++ b/prover/protocol/distributed/preparation.go @@ -0,0 +1,118 @@ +package distributed + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" + "github.com/consensys/linea-monorepo/prover/symbolic" +) + +// prepare reduces any query to LPP or GL. +// For Inclusion query, it push the compilation one step further: +// all the inclusion queries are compiled into a LogDarivativeSum query, +// This is required due to the challenge that table M depends on whole the witness. +func prepare(comp *wizard.CompiledIOP) { + mimc.CompileMiMC(comp) + specialqueries.RangeProof(comp) + specialqueries.CompileFixedPermutations(comp) + innerproduct.Compile(comp) + + IntoLogDerivativeSum(comp) +} + +// IIntoLogDerivativeSum compiles all the inclusion queries to a single LogDerivativeSum query that is ready for the split. +// This step is necessary for inclusion, +// as the M table depends on the whole witness and so can not be handled modules-wise without changing the API of WizardIOP. +func IntoLogDerivativeSum(comp *wizard.CompiledIOP) { + + var ( + mainLookupCtx = lookup.CaptureLookupTables(comp) + lastRound = comp.NumRounds() - 1 + // zCatalog stores a mapping (round, size) into query.LogDerivativeSumInput and helps finding + // which Z context should be used to handle a part of a given inclusion + // query. + zCatalog = map[[2]int]*query.LogDerivativeSumInput{} + ) + + // Skip the compilation phase if no lookup constraint is being used. Otherwise + // it will register a verifier action that is not required and will be bugged. + if len(mainLookupCtx.LookupTables) == 0 { + return + } + + // Step 1. construct the "per table" contexts and pack the Sigma's into + // zCatalog. + for _, lookupTable := range mainLookupCtx.LookupTables { + + var ( + // get checkedTables, rounds, Filters by lookupTableName + lookupTableName = lookup.NameTable(lookupTable) + checkTable = mainLookupCtx.CheckedTables[lookupTableName] + round = mainLookupCtx.Rounds[lookupTableName] + includedFilters = mainLookupCtx.IncludedFilters[lookupTableName] + // collapse multiColumns to single Columns + tableCtx = lookup.CompileLookupTable(comp, round, lookupTable, checkTable, includedFilters) + ) + + // push single-columns into zCatalog + PushToZCatalog(tableCtx, zCatalog) + + // insert a single LogDerivativeSum query for the global zCatalog. + comp.InsertLogDerivativeSum(lastRound, "GlobalLogDerivativeSum", zCatalog) + } +} + +// PushToZCatalog constructs the numerators and denominators for the collapsed S and T +// into zCatalog, for their corresponding rounds and size. +func PushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) { + + var ( + round = stc.Gamma.Round + ) + + // tableCtx push to -> zCtx + // Process the T columns + for frag := range stc.T { + size := stc.M[frag].Size() + + key := [2]int{round, size} + if zCatalog[key] == nil { + zCatalog[key] = &query.LogDerivativeSumInput{ + Size: size, + Round: round, + } + } + + zCtxEntry := zCatalog[key] + zCtxEntry.Numerator = append(zCtxEntry.Numerator, symbolic.Neg(stc.M[frag])) // no functions for num, denom here + zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.T[frag])) + } + + // Process the S columns + for table := range stc.S { + var ( + _, _, size = wizardutils.AsExpr(stc.S[table]) + sFilter = symbolic.NewConstant(1) + ) + + if stc.SFilters[table] != nil { + sFilter = symbolic.NewVariable(stc.SFilters[table]) + } + + key := [2]int{round, size} + if zCatalog[key] == nil { + zCatalog[key] = &query.LogDerivativeSumInput{ + Size: size, + Round: round, + } + } + + zCtxEntry := zCatalog[key] + zCtxEntry.Numerator = append(zCtxEntry.Numerator, sFilter) + zCtxEntry.Denominator = append(zCtxEntry.Denominator, symbolic.Add(stc.Gamma, stc.S[table])) + } +} diff --git a/prover/protocol/query/gnark_params.go b/prover/protocol/query/gnark_params.go index 76c60b3b0..fc2ddabbf 100644 --- a/prover/protocol/query/gnark_params.go +++ b/prover/protocol/query/gnark_params.go @@ -15,6 +15,14 @@ func (p LocalOpeningParams) GnarkAssign() GnarkLocalOpeningParams { return GnarkLocalOpeningParams{Y: p.Y} } +type GnarkLogDerivSumParams struct { + Y frontend.Variable +} + +func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams { + return GnarkLogDerivSumParams{Y: p.Sum} +} + // A gnark circuit version of InnerProductParams type GnarkInnerProductParams struct { Ys []frontend.Variable @@ -54,6 +62,11 @@ func (p GnarkLocalOpeningParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Y) } +// Update the fiat-shamir state with the the present parameters +func (p GnarkLogDerivSumParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { + fs.Update(p.Y) +} + // Update the fiat-shamir state with the the present parameters func (p GnarkUnivariateEvalParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Ys...) diff --git a/prover/protocol/query/logderiv_sum.go b/prover/protocol/query/logderiv_sum.go new file mode 100644 index 000000000..8617809ab --- /dev/null +++ b/prover/protocol/query/logderiv_sum.go @@ -0,0 +1,134 @@ +package query + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/maths/common/vector" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + sym "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// LogDerivativeSumInput stores the input to the query +type LogDerivativeSumInput struct { + Round, Size int + Numerator []*sym.Expression // T -> -M, S -> +Filter + Denominator []*sym.Expression // S or T -> ({S,T} + X) +} + +// LogDerivativeSum is the context of LogDerivativeSum query. +// The fields are maps from [round, size]. +// the aim of the query is to compute: +// \sum_{i,j} N_{i,j}/D_{i,j} where +// N_{i,j} is the i-th element of the underlying column of j-th Numerator +// D_{i,j} is the i-th element of the underlying column of j-th Denominator +type LogDerivativeSum struct { + Inputs map[[2]int]*LogDerivativeSumInput + + ZNumeratorBoarded, ZDenominatorBoarded map[[2]int][]sym.ExpressionBoard + + Zs map[[2]int][]ifaces.Column + // ZOpenings are the opening queries to the end of each Z. + ZOpenings map[[2]int][]LocalOpening + + ID ifaces.QueryID +} + +// the result of the global Sum +type LogDerivSumParams struct { + Sum field.Element // the sum of all the ZOpenings from different [round,size]. +} + +// Updates a Fiat-Shamir state +func (l LogDerivSumParams) UpdateFS(fs *fiatshamir.State) { + fs.Update(l.Sum) +} + +// NewLogDerivativeSum creates the new context LogDerivativeSum. +func NewLogDerivativeSum(inp map[[2]int]*LogDerivativeSumInput, id ifaces.QueryID) LogDerivativeSum { + + // check the length consistency + for key := range inp { + if len(inp[key].Numerator) != len(inp[key].Denominator) || len(inp[key].Numerator) == 0 { + panic("Numerator and Denominator should have the same (no-zero) length") + } + for i := range inp[key].Numerator { + if err := inp[key].Numerator[i].Validate(); err != nil { + utils.Panic(" Numerator[%v] is not a valid expression", i) + } + if err := inp[key].Denominator[i].Validate(); err != nil { + utils.Panic(" Denominator[%v] is not a valid expression", i) + } + } + } + + return LogDerivativeSum{ + Inputs: inp, + ID: id, + } + +} + +// Name implements the [ifaces.Query] interface +func (r LogDerivativeSum) Name() ifaces.QueryID { + return r.ID +} + +// Constructor for the query parameters/result +func NewLogDerivSumParams(sum field.Element) LogDerivSumParams { + return LogDerivSumParams{Sum: sum} +} + +// Test that global sum is correct +func (r LogDerivativeSum) Check(run ifaces.Runtime) error { + params := run.GetParams(r.ID).(LogDerivSumParams) + // compute the actual sum from the Numerator and Denominator + actualSum := field.Zero() + for key := range r.Inputs { + for i, num := range r.Inputs[key].Numerator { + + var ( + numBoard = num.Board() + denBoard = r.Inputs[key].Denominator[i].Board() + numeratorMetadata = numBoard.ListVariableMetadata() + denominator = column.EvalExprColumn(run, denBoard).IntoRegVecSaveAlloc() + numerator []field.Element + packedZ = field.BatchInvert(denominator) + ) + + if len(numeratorMetadata) == 0 { + numerator = vector.Repeat(field.One(), r.Inputs[key].Size) + } + + if len(numeratorMetadata) > 0 { + numerator = column.EvalExprColumn(run, numBoard).IntoRegVecSaveAlloc() + } + + for k := range packedZ { + packedZ[k].Mul(&numerator[k], &packedZ[k]) + if k > 0 { + packedZ[k].Add(&packedZ[k], &packedZ[k-1]) + } + } + actualSum.Add(&actualSum, &packedZ[len(packedZ)-1]) + } + } + + if actualSum != params.Sum { + return fmt.Errorf("expected LogDerivativeSum = %s but got %s for the query %v", params.Sum.String(), actualSum.String(), r.ID) + } + + return nil +} + +// Test that global sum is correct +func (r LogDerivativeSum) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) { + /*params := run.GetParams(r.ID).(GnarkLogDerivSumParams) + actualY := TBD + api.AssertIsEqual(params.Y, actualY) + */ +} diff --git a/prover/protocol/query/logderiv_sum_test.go b/prover/protocol/query/logderiv_sum_test.go new file mode 100644 index 000000000..1a6283f4f --- /dev/null +++ b/prover/protocol/query/logderiv_sum_test.go @@ -0,0 +1,71 @@ +package query_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/stretchr/testify/require" +) + +func TestLogDerivSum(t *testing.T) { + + define := func(b *wizard.Builder) { + var ( + comp = b.CompiledIOP + ) + + p0 := b.RegisterCommit("Num_0", 4) + p1 := b.RegisterCommit("Num_1", 4) + p2 := b.RegisterCommit("Num_2", 4) + + q0 := b.RegisterCommit("Den_0", 4) + q1 := b.RegisterCommit("Den_1", 4) + q2 := b.RegisterCommit("Den_2", 4) + + numerators := []*symbolic.Expression{ + symbolic.Mul(p0, -1), + ifaces.ColumnAsVariable(p1), + symbolic.Mul(p2, p0, 2), + } + + denominators := []*symbolic.Expression{ + ifaces.ColumnAsVariable(q0), + ifaces.ColumnAsVariable(q1), + ifaces.ColumnAsVariable(q2), + } + + key := [2]int{0, 0} + zCat1 := map[[2]int]*query.LogDerivativeSumInput{} + zCat1[key] = &query.LogDerivativeSumInput{ + Numerator: numerators, + Denominator: denominators, + } + comp.InsertLogDerivativeSum(0, "LogDerivSum_Test", zCat1) + + } + + prover := func(run *wizard.ProverRuntime) { + + run.AssignColumn("Num_0", smartvectors.ForTest(1, 1, 1, 1)) + run.AssignColumn("Num_1", smartvectors.ForTest(2, 3, 7, 9)) + run.AssignColumn("Num_2", smartvectors.ForTest(5, 6, 1, 1)) + + run.AssignColumn("Den_0", smartvectors.ForTest(1, 1, 1, 1)) + run.AssignColumn("Den_1", smartvectors.ForTest(2, 3, 7, 9)) + run.AssignColumn("Den_2", smartvectors.ForTest(5, 6, 1, 1)) + + run.AssignLogDerivSum("LogDerivSum_Test", field.NewElement(8)) + + } + + compiled := wizard.Compile(define, dummy.Compile) + proof := wizard.Prove(compiled, prover) + valid := wizard.Verify(compiled, proof) + require.NoError(t, valid) +} diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index 6cf5c37a3..2ccc7f554 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -579,6 +579,17 @@ func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifa return q } +// InsertLogDerivativeSum registers a new LogDerivativeSum query [query.LogDerivativeSum]. +// It generates a single global summation for many Sigma Columns from Lookup compilation. +// The sigma columns are categorized by [round,size]. +func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[[2]int]*query.LogDerivativeSumInput) query.LogDerivativeSum { + c.assertConsistentRound(lastRound) + q := query.NewLogDerivativeSum(in, id) + // Finally registers the query + c.QueriesParams.AddToRound(lastRound, id, q) + return q +} + // assertConsistentRound compares the round passed as an argument and panic if it greater than // coin.Round. This helps ensuring that we do not have "useless" rounds. func (c *CompiledIOP) assertConsistentRound(round int) { diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go index 841c3755c..efe06be28 100644 --- a/prover/protocol/wizard/gnark_verifier.go +++ b/prover/protocol/wizard/gnark_verifier.go @@ -47,6 +47,8 @@ type WizardVerifierCircuit struct { innerProductIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` // Same for local-opening query localOpeningIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` + // Same for logDerivativeSum query + logDerivSumIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` // Columns stores the gnark witness part corresponding to the columns // provided in the proof and in the VerifyingKey. @@ -63,6 +65,9 @@ type WizardVerifierCircuit struct { // LocalOpeningParams stores an assignment for each [query.LocalOpeningParams] // from the proof. It is part of the witness of the gnark circuit. LocalOpeningParams []query.GnarkLocalOpeningParams `gnark:",secret"` + // LogDerivSumParams stores an assignment for each [query.LogDerivSumParams] + // from the proof. It is part of the witness of the gnark circuit. + LogDerivSumParams []query.GnarkLogDerivSumParams `gnark:",secret"` // FS is the Fiat-Shamir state, mirroring [VerifierRuntime.FS]. The same // cautionnary rules apply to it; e.g. don't use it externally when @@ -311,6 +316,14 @@ func (c *WizardVerifierCircuit) GetLocalPointEvalParams(name ifaces.QueryID) que return c.LocalOpeningParams[qID] } +// GetLogDerivSumParams returns the parameters for the requested +// [query.LogDerivativeSum] query. Its work mirrors the function +// [VerifierRuntime.GetLogDerivSumParams] +func (c *WizardVerifierCircuit) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLocalOpeningParams { + qID := c.localOpeningIDs.MustGet(name) + return c.LocalOpeningParams[qID] +} + // GetColumns returns the gnark assignment of a column in a gnark circuit. It // mirrors the function [VerifierRuntime.GetColumn] func (c *WizardVerifierCircuit) GetColumn(name ifaces.ColID) []frontend.Variable { @@ -350,11 +363,13 @@ func newWizardVerifierCircuit() *WizardVerifierCircuit { res.columnsIDs = collection.NewMapping[ifaces.ColID, int]() res.univariateParamsIDs = collection.NewMapping[ifaces.QueryID, int]() res.localOpeningIDs = collection.NewMapping[ifaces.QueryID, int]() + res.logDerivSumIDs = collection.NewMapping[ifaces.QueryID, int]() res.innerProductIDs = collection.NewMapping[ifaces.QueryID, int]() res.Columns = [][]frontend.Variable{} res.UnivariateParams = make([]query.GnarkUnivariateEvalParams, 0) res.InnerProductParams = make([]query.GnarkInnerProductParams, 0) res.LocalOpeningParams = make([]query.GnarkLocalOpeningParams, 0) + res.LogDerivSumParams = make([]query.GnarkLogDerivSumParams, 0) res.Coins = collection.NewMapping[coin.Name, interface{}]() return res } @@ -418,6 +433,9 @@ func GetWizardVerifierCircuitAssignment(comp *CompiledIOP, proof Proof) *WizardV case query.LocalOpeningParams: res.localOpeningIDs.InsertNew(qName, len(res.LocalOpeningParams)) res.LocalOpeningParams = append(res.LocalOpeningParams, params.GnarkAssign()) + case query.LogDerivSumParams: + res.logDerivSumIDs.InsertNew(qName, len(res.LogDerivSumParams)) + res.LogDerivSumParams = append(res.LogDerivSumParams, params.GnarkAssign()) default: utils.Panic("unknow type %T", params) @@ -435,6 +453,8 @@ func (c *WizardVerifierCircuit) GetParams(id ifaces.QueryID) ifaces.GnarkQueryPa return c.GetUnivariateParams(id) case query.LocalOpening: return c.GetLocalPointEvalParams(id) + case query.LogDerivativeSum: + return c.GetLogDerivSumParams(id) case query.InnerProduct: return c.GetInnerProductParams(id) default: diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index caa4df4aa..7d42cdb61 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -691,6 +691,43 @@ func (run *ProverRuntime) GetLocalPointEvalParams(name ifaces.QueryID) query.Loc return run.QueriesParams.MustGet(name).(query.LocalOpeningParams) } +// AssignLogDerivSum assign the claimed values for a logDeriveSum +// The function will panic if: +// - the parameters were already assigned +// - the specified query is not registered +// - the assignment round is incorrect +func (run *ProverRuntime) AssignLogDerivSum(name ifaces.QueryID, y field.Element) { + + // Global prover locks for accessing the maps + run.lock.Lock() + defer run.lock.Unlock() + + // Make sure, it is done at the right round + run.Spec.QueriesParams.MustBeInRound(run.currRound, name) + + // Adds it to the assignments + params := query.NewLogDerivSumParams(y) + run.QueriesParams.InsertNew(name, params) +} + +// GetLogDeriveSum gets the metadata of a [query.LogDerivativeSum] query. Panic if not found. +func (run *ProverRuntime) GetLogDeriveSum(name ifaces.QueryID) query.LogDerivativeSum { + // Global prover locks for accessing the maps + run.lock.Lock() + defer run.lock.Unlock() + return run.Spec.QueriesParams.Data(name).(query.LogDerivativeSum) +} + +// GetLogDerivSumParams returns the parameters of [query.LogDerivativeSum] +func (run *ProverRuntime) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { + + // Global prover's lock for accessing params + run.lock.Lock() + defer run.lock.Unlock() + + return run.QueriesParams.MustGet(name).(query.LogDerivSumParams) +} + // GetParams generically extracts the parameters of a query. Will panic if no // parameters are found func (run *ProverRuntime) GetParams(name ifaces.QueryID) ifaces.QueryParams { diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index 1a355796e..2c5b8da5f 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -308,6 +308,11 @@ func (run *VerifierRuntime) GetLocalPointEvalParams(name ifaces.QueryID) query.L return run.QueriesParams.MustGet(name).(query.LocalOpeningParams) } +// GetLogDerivSumParams returns the parameters of a [query.LogDerivativeSum] +func (run *VerifierRuntime) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { + return run.QueriesParams.MustGet(name).(query.LogDerivSumParams) +} + /* CopyColumnInto implements `column.GetWitness` Copies the witness into a slice diff --git a/prover/protocol/wizardutils/evaluation.go b/prover/protocol/wizardutils/evaluation.go index 7381060ac..a3b93acc8 100644 --- a/prover/protocol/wizardutils/evaluation.go +++ b/prover/protocol/wizardutils/evaluation.go @@ -5,50 +5,10 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/symbolic" ) -// EvalExprColumn resolves an expression to a column assignment. The expression -// must be converted to a board prior to evaluating the expression. -// -// - If the expression does not uses ifaces.Column as metadata, the function -// will panic. -// -// - If the expression contains several columns and they don't contain all -// have the same size. -func EvalExprColumn(run *wizard.ProverRuntime, board symbolic.ExpressionBoard) smartvectors.SmartVector { - - var ( - metadata = board.ListVariableMetadata() - inputs = make([]smartvectors.SmartVector, len(metadata)) - length = ExprIsOnSameLengthHandles(&board) - ) - - // Attempt to recover the size of the - for i := range inputs { - switch m := metadata[i].(type) { - case ifaces.Column: - inputs[i] = m.GetColAssignment(run) - case coin.Info: - v := run.GetRandomCoinField(m.Name) - inputs[i] = smartvectors.NewConstant(v, length) - case ifaces.Accessor: - v := m.GetVal(run) - inputs[i] = smartvectors.NewConstant(v, length) - case variables.PeriodicSample: - v := m.EvalCoset(length, 0, 1, false) - inputs[i] = v - case variables.X: - v := m.EvalCoset(length, 0, 1, false) - inputs[i] = v - } - } - - return board.Evaluate(inputs) -} - // returns the symbolic expression of a column obtained as a random linear combinations of differents handles // without committing to the column itself func RandLinCombColSymbolic(x coin.Info, hs []ifaces.Column) *symbolic.Expression { diff --git a/prover/protocol/wizardutils/utils.go b/prover/protocol/wizardutils/utils.go index ba58ac339..764092a1f 100644 --- a/prover/protocol/wizardutils/utils.go +++ b/prover/protocol/wizardutils/utils.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/symbolic" @@ -41,44 +42,6 @@ func LastRoundToEval(expr *symbolic.Expression) int { return maxRound } -// ExprIsOnSameLengthHandles checks that all the variables of the expression -// that are [ifaces.Column] have the same size (and panics if it does not), then -// returns the match. -func ExprIsOnSameLengthHandles(board *symbolic.ExpressionBoard) int { - - var ( - metadatas = board.ListVariableMetadata() - length = 0 - ) - - for _, m := range metadatas { - switch metadata := m.(type) { - case ifaces.Column: - // Initialize the length with the first commitment - if length == 0 { - length = metadata.Size() - } - - // Sanity-check the vector should all have the same length - if length != metadata.Size() { - utils.Panic("Inconsistent length for %v (has size %v, but expected %v)", metadata.GetColID(), metadata.Size(), length) - } - // The expression can involve random coins - case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: - // Do nothing - default: - utils.Panic("unknown type %T", metadata) - } - } - - // No commitment were found in the metadata, thus this call is broken - if length == 0 { - utils.Panic("declared a handle from an expression which does not contains any handle") - } - - return length -} - // maximal round of declaration for a list of commitment func MaxRound(handles ...ifaces.Column) int { res := 0 @@ -128,7 +91,7 @@ func AsExpr(x any) (e *symbolic.Expression, round, size int) { case *symbolic.Expression: board := c1.Board() e = c1 - size = ExprIsOnSameLengthHandles(&board) + size = column.ExprIsOnSameLengthHandles(&board) round = LastRoundToEval(c1) return e, round, size } diff --git a/prover/protocol/wizardutils/wizardutils_test.go b/prover/protocol/wizardutils/wizardutils_test.go index 417a0a85a..fd2ac86bc 100644 --- a/prover/protocol/wizardutils/wizardutils_test.go +++ b/prover/protocol/wizardutils/wizardutils_test.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" @@ -48,16 +49,16 @@ func TestWizarldutils(t *testing.T) { run.GetRandomCoinField(coin.Namef("Coin")) - res1Wit := EvalExprColumn(run, res1.Board()).IntoRegVecSaveAlloc() - res11Wit := EvalExprColumn(run, res11.Board()).IntoRegVecSaveAlloc() + res1Wit := column.EvalExprColumn(run, res1.Board()).IntoRegVecSaveAlloc() + res11Wit := column.EvalExprColumn(run, res11.Board()).IntoRegVecSaveAlloc() for i := range res11Wit { if res1Wit[i].Cmp(&res11Wit[i]) != 0 { panic("err") } } - res2Wit := EvalExprColumn(run, res2.Board()).IntoRegVecSaveAlloc() - res22Wit := EvalExprColumn(run, res22.Board()).IntoRegVecSaveAlloc() + res2Wit := column.EvalExprColumn(run, res2.Board()).IntoRegVecSaveAlloc() + res22Wit := column.EvalExprColumn(run, res22.Board()).IntoRegVecSaveAlloc() for i := range res11Wit { if res2Wit[i].Cmp(&res22Wit[i]) != 0 { panic("err")