Skip to content

Commit

Permalink
fix: use emulated arithmetic for GLV decomp (#1167)
Browse files Browse the repository at this point in the history
* fix: use emulated arithmetic for GLV decomp

* fix: use emulated GLV decomp everywhere
  • Loading branch information
ivokub authored Jul 18, 2024
1 parent c2da0b0 commit 4cc13fd
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 111 deletions.
36 changes: 3 additions & 33 deletions std/algebra/native/sw_bls24315/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,7 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl
// curve.
cc := getInnerCurveConfig(api.Compiler().Field())

// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod r,
// where λ is third root of one in 𝔽_r.
sd, err := api.Compiler().NewHint(decomposeScalarG1Simple, 3, s)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2 := sd[0], sd[1]

// s1 + λ * s2 == s mod r
api.AssertIsEqual(
api.Add(s1, api.Mul(s2, cc.lambda)),
api.Add(s, api.Mul(cc.fr, sd[2])),
)
s1, s2 := callDecomposeScalar(api, s, true)

nbits := 127
s1bits := api.ToBinary(s1, nbits)
Expand Down Expand Up @@ -451,24 +437,8 @@ func (P *G1Affine) jointScalarMul(api frontend.API, Q, R G1Affine, s, t frontend
// P = [s]Q + [t]R using Shamir's trick
func (P *G1Affine) jointScalarMulUnsafe(api frontend.API, Q, R G1Affine, s, t frontend.Variable) *G1Affine {
cc := getInnerCurveConfig(api.Compiler().Field())

sd, err := api.Compiler().NewHint(decomposeScalarG1, 3, s)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2 := sd[0], sd[1]

td, err := api.Compiler().NewHint(decomposeScalarG1, 3, t)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
t1, t2 := td[0], td[1]

api.AssertIsEqual(api.Add(s1, api.Mul(s2, cc.lambda)), api.Add(s, api.Mul(cc.fr, sd[2])))
api.AssertIsEqual(api.Add(t1, api.Mul(t2, cc.lambda)), api.Add(t, api.Mul(cc.fr, td[2])))

s1, s2 := callDecomposeScalar(api, s, false)
t1, t2 := callDecomposeScalar(api, t, false)
nbits := cc.lambda.BitLen() + 1

s1bits := api.ToBinary(s1, nbits)
Expand Down
16 changes: 1 addition & 15 deletions std/algebra/native/sw_bls24315/g2.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,7 @@ func (P *g2AffP) varScalarMul(api frontend.API, Q g2AffP, s frontend.Variable, o
// curve.
cc := getInnerCurveConfig(api.Compiler().Field())

// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod r,
// where λ is third root of one in 𝔽_r.
sd, err := api.Compiler().NewHint(decomposeScalarG1Simple, 3, s)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2 := sd[0], sd[1]

// s1 + λ * s2 == s mod r,
api.AssertIsEqual(
api.Add(s1, api.Mul(s2, cc.lambda)),
api.Add(s, api.Mul(cc.fr, sd[2])),
)
s1, s2 := callDecomposeScalar(api, s, true)

nbits := 127
s1bits := api.ToBinary(s1, nbits)
Expand Down
151 changes: 98 additions & 53 deletions std/algebra/native/sw_bls24315/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,77 +6,122 @@ import (

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
)

func GetHints() []solver.Hint {
return []solver.Hint{
decomposeScalarG1,
decomposeScalarG1Simple,
decomposeScalarG2,
decomposeScalar,
decomposeScalarSimple,
decompose,
}
}

func init() {
solver.RegisterHint(GetHints()...)
}

func decomposeScalarG1Simple(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(outputs) != 3 {
return fmt.Errorf("expecting three outputs")
}
cc := getInnerCurveConfig(scalarField)
sp := ecc.SplitScalar(inputs[0], cc.glvBasis)
outputs[0].Set(&(sp[0]))
outputs[1].Set(&(sp[1]))
// figure out how many times we have overflowed
outputs[2].Mul(outputs[1], cc.lambda).Add(outputs[2], outputs[0])
outputs[2].Sub(outputs[2], inputs[0])
outputs[2].Div(outputs[2], cc.fr)
func decomposeScalarSimple(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nnMod *big.Int, nninputs, nnOutputs []*big.Int) error {
if len(nninputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(nnOutputs) != 2 {
return fmt.Errorf("expecting two outputs")
}
cc := getInnerCurveConfig(nativeMod)
sp := ecc.SplitScalar(nninputs[0], cc.glvBasis)
nnOutputs[0].Set(&(sp[0]))
nnOutputs[1].Set(&(sp[1]))

return nil
return nil
})
}

func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error {
cc := getInnerCurveConfig(scalarField)
sp := ecc.SplitScalar(inputs[0], cc.glvBasis)
res[0].Set(&(sp[0]))
res[1].Set(&(sp[1]))
one := big.NewInt(1)
// add (lambda+1, lambda) until scalar compostion is over Fr to ensure that
// the high bits are set in decomposition.
for res[0].Cmp(cc.lambda) < 1 && res[1].Cmp(cc.lambda) < 1 {
res[0].Add(res[0], cc.lambda)
res[0].Add(res[0], one)
res[1].Add(res[1], cc.lambda)
}
// figure out how many times we have overflowed
res[2].Mul(res[1], cc.lambda).Add(res[2], res[0])
res[2].Sub(res[2], inputs[0])
res[2].Div(res[2], cc.fr)
func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nnMod *big.Int, nninputs, nnOutputs []*big.Int) error {
if len(nninputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(nnOutputs) != 2 {
return fmt.Errorf("expecting two outputs")
}
cc := getInnerCurveConfig(nativeMod)
sp := ecc.SplitScalar(nninputs[0], cc.glvBasis)
nnOutputs[0].Set(&(sp[0]))
nnOutputs[1].Set(&(sp[1]))
one := big.NewInt(1)
// add (lambda+1, lambda) until scalar compostion is over Fr to ensure that
// the high bits are set in decomposition.
for nnOutputs[0].Cmp(cc.lambda) < 1 && nnOutputs[1].Cmp(cc.lambda) < 1 {
nnOutputs[0].Add(nnOutputs[0], cc.lambda)
nnOutputs[0].Add(nnOutputs[0], one)
nnOutputs[1].Add(nnOutputs[1], cc.lambda)
}

return nil
return nil
})
}

func decomposeScalarG2(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error {
cc := getInnerCurveConfig(scalarField)
sp := ecc.SplitScalar(inputs[0], cc.glvBasis)
res[0].Set(&(sp[0]))
res[1].Set(&(sp[1]))
one := big.NewInt(1)
// add (lambda+1, lambda) until scalar compostion is over Fr to ensure that
// the high bits are set in decomposition.
for res[0].Cmp(cc.lambda) < 1 && res[1].Cmp(cc.lambda) < 1 {
res[0].Add(res[0], cc.lambda)
res[0].Add(res[0], one)
res[1].Add(res[1], cc.lambda)
func callDecomposeScalar(api frontend.API, s frontend.Variable, simple bool) (s1, s2 frontend.Variable) {
var fr emparams.BLS24315Fr
cc := getInnerCurveConfig(api.Compiler().Field())
sapi, err := emulated.NewField[emparams.BLS24315Fr](api)
if err != nil {
panic(err)
}
var hintFn solver.Hint
if simple {
hintFn = decomposeScalarSimple
} else {
hintFn = decomposeScalar
}
// compute the decomposition using a hint. We have to use the emulated
// version which takes native input and outputs non-native outputs.
//
// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod r,
// where λ is third root of one in 𝔽_r.
sd, err := sapi.NewHintWithNativeInput(hintFn, 2, s)
if err != nil {
panic(err)
}
// lambda as nonnative element
lambdaEmu := sapi.NewElement(cc.lambda)
// the scalar as nonnative element. We need to split at 64 bits.
limbs, err := api.NewHint(decompose, int(fr.NbLimbs()), s)
if err != nil {
panic(err)
}
semu := sapi.NewElement(limbs)
// s1 + λ * s2 == s mod r
lhs := sapi.MulNoReduce(sd[1], lambdaEmu)
lhs = sapi.Add(lhs, sd[0])

sapi.AssertIsEqual(lhs, semu)

s1 = 0
s2 = 0
b := big.NewInt(1)
for i := range sd[0].Limbs {
s1 = api.Add(s1, api.Mul(sd[0].Limbs[i], b))
s2 = api.Add(s2, api.Mul(sd[1].Limbs[i], b))
b.Lsh(b, 64)
}
// figure out how many times we have overflowed
res[2].Mul(res[1], cc.lambda).Add(res[2], res[0])
res[2].Sub(res[2], inputs[0])
res[2].Div(res[2], cc.fr)
return s1, s2
}

func decompose(mod *big.Int, inputs, outputs []*big.Int) error {
if len(inputs) != 1 && len(outputs) != 4 {
return fmt.Errorf("input/output length mismatch")
}
tmp := new(big.Int).Set(inputs[0])
mask := new(big.Int).SetUint64(^uint64(0))
for i := 0; i < 4; i++ {
outputs[i].And(tmp, mask)
tmp.Rsh(tmp, 64)
}
return nil
}
11 changes: 1 addition & 10 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,7 @@ func (c *Curve) MultiScalarMul(P []*G1Affine, scalars []*Scalar, opts ...algopts
}
gamma := c.packScalarToVar(scalars[0])
// decompose gamma in the endomorphism eigenvalue basis and bit-decompose the sub-scalars
cc := getInnerCurveConfig(c.api.Compiler().Field())
sd, err := c.api.Compiler().NewHint(decomposeScalarG1Simple, 3, gamma)
if err != nil {
panic(err)
}
gamma1, gamma2 := sd[0], sd[1]
c.api.AssertIsEqual(
c.api.Add(gamma1, c.api.Mul(gamma2, cc.lambda)),
c.api.Add(gamma, c.api.Mul(cc.fr, sd[2])),
)
gamma1, gamma2 := callDecomposeScalar(c.api, gamma, true)
nbits := 127
gamma1Bits := c.api.ToBinary(gamma1, nbits)
gamma2Bits := c.api.ToBinary(gamma2, nbits)
Expand Down

0 comments on commit 4cc13fd

Please sign in to comment.