Skip to content

Commit

Permalink
chore: Add new domain abstraction (#86)
Browse files Browse the repository at this point in the history
* add new Domain abstraction

* refactor rest of code to use new domain package

* fix linter

* chore: add back `TestSRSConversion`
  • Loading branch information
kevaundray authored Aug 20, 2024
1 parent 4aa46d3 commit d969d3a
Show file tree
Hide file tree
Showing 17 changed files with 201 additions and 129 deletions.
13 changes: 7 additions & 6 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goethkzg
import (
"encoding/json"

"github.com/crate-crypto/go-eth-kzg/internal/domain"
"github.com/crate-crypto/go-eth-kzg/internal/kzg"
kzgmulti "github.com/crate-crypto/go-eth-kzg/internal/kzg_multi"
"github.com/crate-crypto/go-eth-kzg/internal/kzg_multi/fk20"
Expand All @@ -13,8 +14,8 @@ import (
// Note: We could marshall this object so that clients won't need to process the SRS each time. The time to process is
// about 2-5 seconds.
type Context struct {
domain *kzg.Domain
domainExtended *kzg.Domain
domain *domain.Domain
domainExtended *domain.Domain
commitKeyLagrange *kzg.CommitKey
commitKeyMonomial *kzg.CommitKey
openKey *kzg.OpeningKey
Expand Down Expand Up @@ -121,20 +122,20 @@ func NewContext4096(trustedSetup *JSONTrustedSetup) (*Context, error) {
G2: setupG2Points,
}

domain := kzg.NewDomain(ScalarsPerBlob)
domainBlobLen := domain.NewDomain(ScalarsPerBlob)
// Bit-Reverse the roots and the trusted setup according to the specs
// The bit reversal is not needed for simple KZG however it was
// implemented to make the step for full dank-sharding easier.
commitKeyLagrange.ReversePoints()
domain.ReverseRoots()
domainBlobLen.ReverseRoots()

domainExtended := kzg.NewDomain(scalarsPerExtBlob)
domainExtended := domain.NewDomain(scalarsPerExtBlob)
domainExtended.ReverseRoots()

fk20 := fk20.NewFK20(commitKeyMonomial.G1, scalarsPerExtBlob, scalarsPerCell)

return &Context{
domain: domain,
domain: domainBlobLen,
domainExtended: domainExtended,
commitKeyLagrange: &commitKeyLagrange,
commitKeyMonomial: &commitKeyMonomial,
Expand Down
8 changes: 4 additions & 4 deletions api_eip7594.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"slices"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-eth-kzg/internal/kzg"
"github.com/crate-crypto/go-eth-kzg/internal/domain"
kzgmulti "github.com/crate-crypto/go-eth-kzg/internal/kzg_multi"
)

Expand All @@ -16,7 +16,7 @@ func (ctx *Context) ComputeCellsAndKZGProofs(blob *Blob, numGoRoutines int) ([Ce
}

// Bit reverse the polynomial representing the Blob so that it is in normal order
kzg.BitReverse(polynomial)
domain.BitReverse(polynomial)

// Convert the polynomial in lagrange form to a polynomial in monomial form
polyCoeff := ctx.domain.IfftFr(polynomial)
Expand Down Expand Up @@ -100,7 +100,7 @@ func (ctx *Context) RecoverCellsAndComputeKZGProofs(cellIDs []uint64, cells []*C
missingCellIds := make([]uint64, 0, CellsPerExtBlob)
for cellID := uint64(0); cellID < CellsPerExtBlob; cellID++ {
if !slices.Contains(cellIDs, cellID) {
missingCellIds = append(missingCellIds, (kzg.BitReverseInt(cellID, CellsPerExtBlob)))
missingCellIds = append(missingCellIds, (domain.BitReverseInt(cellID, CellsPerExtBlob)))
}
}

Expand All @@ -119,7 +119,7 @@ func (ctx *Context) RecoverCellsAndComputeKZGProofs(cellIDs []uint64, cells []*C
copy(extendedBlob[cellID*scalarsPerCell:], cellEvals)
}
// Bit reverse the extendedBlob so that it is in normal order
kzg.BitReverse(extendedBlob)
domain.BitReverse(extendedBlob)

polyCoeff, err := ctx.dataRecovery.RecoverPolynomialCoefficients(extendedBlob, missingCellIds)
if err != nil {
Expand Down
67 changes: 67 additions & 0 deletions internal/domain/coset_fft.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package domain

import (
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

// FFTCoset represents a coset for Fast Fourier Transform operations.
// It contains the generator of the coset and its inverse.
type FFTCoset struct {
// CosetGen is the generator element of the coset.
// It's used to shift the domain for coset FFT operations.
CosetGen fr.Element

// InvCosetGen is the inverse of the coset generator.
// It's used in inverse coset FFT operations to shift back to the original domain.
InvCosetGen fr.Element
}

// CosetDomain represents a domain for performing FFT operations over a coset.
// It combines a standard FFT domain with coset information for efficient coset FFT computations.
type CosetDomain struct {
// domain is the underlying FFT domain.
domain *Domain

// coset contains the coset generator and its inverse for this domain.
coset FFTCoset
}

// NewCosetDomain creates a new CosetDomain with the given Domain and FFTCoset.
func NewCosetDomain(domain *Domain, fft_coset FFTCoset) *CosetDomain {
return &CosetDomain{
domain: domain,
coset: fft_coset,
}
}

// CosetFFtFr performs a forward coset FFT on the input values.
//
// It first scales the input values by powers of the coset generator,
// then performs a standard FFT on the scaled values.
func (d *CosetDomain) CosetFFtFr(values []fr.Element) []fr.Element {
result := make([]fr.Element, len(values))

cosetScale := fr.One()
for i := 0; i < len(values); i++ {
result[i].Mul(&values[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.coset.CosetGen)
}

return d.domain.FftFr(result)
}

// CosetIFFtFr performs an inverse coset FFT on the input values.
//
// It first performs a standard inverse FFT, then scales the results
// by powers of the inverse coset generator to shift back to the original domain.
func (d *CosetDomain) CosetIFFtFr(values []fr.Element) []fr.Element {
result := d.domain.IfftFr(values)

cosetScale := fr.One()
for i := 0; i < len(result); i++ {
result[i].Mul(&result[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.coset.InvCosetGen)
}

return result
}
25 changes: 8 additions & 17 deletions internal/kzg/domain.go → internal/domain/domain.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package kzg
package domain

import (
"fmt"
Expand Down Expand Up @@ -39,12 +39,6 @@ type Domain struct {
// f(x)/g(x) where g(x) is a linear polynomial
// which vanishes on a point on the domain
PreComputedInverses []fr.Element

// CosetGenerator is the generator for the coset domain.
CosetGenerator fr.Element

// CosetGeneratorInv is the inverse of the generator for the coset domain.
CosetGeneratorInv fr.Element
}

// NewDomain returns a new domain with the desired number of points x.
Expand Down Expand Up @@ -100,9 +94,6 @@ func NewDomain(x uint64) *Domain {
// We use BatchInvert instead of the above for clarity.
domain.PreComputedInverses = fr.BatchInvert(domain.Roots)

domain.CosetGenerator = fr.NewElement(7)
domain.CosetGeneratorInv.Inverse(&domain.CosetGenerator)

return domain
}

Expand Down Expand Up @@ -165,11 +156,11 @@ func (domain *Domain) ReverseRoots() {
BitReverse(domain.PreComputedInverses)
}

// findRootIndex returns the index of the element in the domain or -1 if not found.
// FindRootIndex returns the index of the element in the domain or -1 if not found.
//
// - If point is in the domain (meaning that point is a domain.Cardinality'th root of unity), returns the index of the point in the domain.
// - If point is not in the domain, returns -1.
func (domain *Domain) findRootIndex(point fr.Element) int64 {
func (domain *Domain) FindRootIndex(point fr.Element) int64 {
for i := int64(0); i < int64(domain.Cardinality); i++ {
if point.Equal(&domain.Roots[i]) {
return i
Expand All @@ -185,21 +176,21 @@ func (domain *Domain) findRootIndex(point fr.Element) int64 {
// If len(poly) != domain.Cardinality, returns an error.
//
// [evaluate_polynomial_in_evaluation_form]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#evaluate_polynomial_in_evaluation_form
func (domain *Domain) EvaluateLagrangePolynomial(poly Polynomial, evalPoint fr.Element) (*fr.Element, error) {
outputPoint, _, err := domain.evaluateLagrangePolynomial(poly, evalPoint)
func (domain *Domain) EvaluateLagrangePolynomial(poly []fr.Element, evalPoint fr.Element) (*fr.Element, error) {
outputPoint, _, err := domain.EvaluateLagrangePolynomialWithIndex(poly, evalPoint)

return outputPoint, err
}

// evaluateLagrangePolynomial is the implementation for [EvaluateLagrangePolynomial].
// EvaluateLagrangePolynomialWithIndex is the implementation for [EvaluateLagrangePolynomial].
//
// It evaluates a Lagrange polynomial at the given point of evaluation and reports whether the given point was among the points of the domain:
// - The input polynomial is given in evaluation form, that is, a list of evaluations at the points in the domain.
// - The evaluationResult is the result of evaluation at evalPoint.
// - indexInDomain is the index inside domain.Roots, if evalPoint is among them, -1 otherwise
//
// This semantics was copied from the go library, see: https://cs.opensource.google/go/x/exp/+/522b1b58:slices/slices.go;l=117
func (domain *Domain) evaluateLagrangePolynomial(poly Polynomial, evalPoint fr.Element) (*fr.Element, int64, error) {
func (domain *Domain) EvaluateLagrangePolynomialWithIndex(poly []fr.Element, evalPoint fr.Element) (*fr.Element, int64, error) {
var indexInDomain int64 = -1

if domain.Cardinality != uint64(len(poly)) {
Expand All @@ -210,7 +201,7 @@ func (domain *Domain) evaluateLagrangePolynomial(poly Polynomial, evalPoint fr.E
// then evaluation of the polynomial in lagrange form
// is the same as indexing it with the position
// that the evaluation point is in, in the domain
indexInDomain = domain.findRootIndex(evalPoint)
indexInDomain = domain.FindRootIndex(evalPoint)
if indexInDomain != -1 {
return &poly[indexInDomain], indexInDomain, nil
}
Expand Down
10 changes: 5 additions & 5 deletions internal/kzg/domain_test.go → internal/domain/domain_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package kzg
package domain

import (
"crypto/rand"
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestEvalPolynomialSmoke(t *testing.T) {

// lagrangePoly are the evaluations of the coefficient polynomial over
// `domain`
lagrangePoly := make(Polynomial, domain.Cardinality)
lagrangePoly := make([]fr.Element, domain.Cardinality)
for i := 0; i < int(domain.Cardinality); i++ {
x := domain.Roots[i]
lagrangePoly[i] = f(x)
Expand All @@ -113,7 +113,7 @@ func TestEvalPolynomialSmoke(t *testing.T) {
for i := int64(0); i < int64(domain.Cardinality); i++ {
inputPoint := domain.Roots[i]

gotOutputPoint, indexInDomain, err := domain.evaluateLagrangePolynomial(lagrangePoly, inputPoint)
gotOutputPoint, indexInDomain, err := domain.EvaluateLagrangePolynomialWithIndex(lagrangePoly, inputPoint)
if err != nil {
t.Error(err)
}
Expand All @@ -137,7 +137,7 @@ func TestEvalPolynomialSmoke(t *testing.T) {
// Sample some random point
inputPoint := samplePointOutsideDomain(*domain)

gotOutputPoint, indexInDomain, err := domain.evaluateLagrangePolynomial(lagrangePoly, *inputPoint)
gotOutputPoint, indexInDomain, err := domain.EvaluateLagrangePolynomialWithIndex(lagrangePoly, *inputPoint)
if err != nil {
t.Errorf(err.Error(), inputPoint.Bytes())
}
Expand All @@ -161,7 +161,7 @@ func samplePointOutsideDomain(domain Domain) *fr.Element {

for {
randElement.SetUint64(randUint64())
if domain.findRootIndex(randElement) == -1 {
if domain.FindRootIndex(randElement) == -1 {
break
}
}
Expand Down
5 changes: 5 additions & 0 deletions internal/domain/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package domain

import "errors"

var ErrPolynomialMismatchedSizeDomain = errors.New("domain size does not equal the number of evaluations in the polynomial")
26 changes: 1 addition & 25 deletions internal/kzg/fft.go → internal/domain/fft.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package kzg
package domain

import (
"math/big"
Expand Down Expand Up @@ -92,30 +92,6 @@ func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1A
return evaluations
}

func (d *Domain) CosetFFtFr(values []fr.Element) []fr.Element {
result := make([]fr.Element, len(values))

cosetScale := fr.One()
for i := 0; i < len(values); i++ {
result[i].Mul(&values[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.CosetGenerator)
}

return d.FftFr(result)
}

func (d *Domain) CosetIFFtFr(values []fr.Element) []fr.Element {
result := d.IfftFr(values)

cosetScale := fr.One()
for i := 0; i < len(result); i++ {
result[i].Mul(&result[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.CosetGeneratorInv)
}

return result
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
return fftFr(values, d.Generator)
}
Expand Down
34 changes: 8 additions & 26 deletions internal/kzg/fft_test.go → internal/domain/fft_test.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
package kzg
package domain

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

func TestSRSConversion(t *testing.T) {
n := uint64(4096)
domain := NewDomain(n)
secret := big.NewInt(100)
srsMonomial, err := newMonomialSRSInsecureUint64(n, secret)
if err != nil {
t.Error(err)
}
srsLagrange, err := newLagrangeSRSInsecure(*domain, secret)
if err != nil {
t.Error(err)
}

lagrangeSRS := domain.IfftG1(srsMonomial.CommitKey.G1)

for i := uint64(0); i < n; i++ {
if !lagrangeSRS[i].Equal(&srsLagrange.CommitKey.G1[i]) {
t.Fatalf("conversion incorrect")
}
}
}

func TestFFt(t *testing.T) {
n := uint64(8)
polyMonomial := []fr.Element{
Expand All @@ -53,8 +30,13 @@ func TestFFt(t *testing.T) {
}
}

polyLagrangeCoset := d.CosetFFtFr(polyMonomial)
gotPolyMonomial = d.CosetIFFtFr(polyLagrangeCoset)
fftCoset := FFTCoset{}
fftCoset.CosetGen = fr.NewElement(7)
fftCoset.InvCosetGen.Inverse(&fftCoset.CosetGen)
cosetDomain := NewCosetDomain(d, fftCoset)

polyLagrangeCoset := cosetDomain.CosetFFtFr(polyMonomial)
gotPolyMonomial = cosetDomain.CosetIFFtFr(polyLagrangeCoset)

for i := uint64(0); i < n; i++ {
if !polyMonomial[i].Equal(&gotPolyMonomial[i]) {
Expand Down
Loading

0 comments on commit d969d3a

Please sign in to comment.