Skip to content

Commit

Permalink
Implement RandomEcPoint hint (#513)
Browse files Browse the repository at this point in the history
* Implement RandomEcPoint

* Add unit test

* Update test assert

* Add comment

* Some test changes

* Update pkg/hintrunner/zero/zerohint_ec.go

Co-authored-by: Tristan <[email protected]>

---------

Co-authored-by: Tristan <[email protected]>
  • Loading branch information
har777 and TAdev0 authored Jul 11, 2024
1 parent 08d12e6 commit 050d0e0
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 75 deletions.
96 changes: 49 additions & 47 deletions integration_tests/BenchMarks.txt

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ func TestCairoFiles(t *testing.T) {
errorExpected := false
if name == "range_check.small.cairo" {
errorExpected = true
} else if name == "ecop.starknet_with_keccak.cairo" {
// temporary, being fixed in another PR soon
continue
}

path := filepath.Join(root, name)
Expand Down
16 changes: 15 additions & 1 deletion pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func IsQuadResidue(x *fp.Element) bool {
return x.IsZero() || x.IsOne() || x.Legendre() == 1
}

func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
func ySquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
// Computes y^2 using the curve equation:
// y^2 = x^3 + alpha * x + beta (mod field_prime)
// We ignore alpha as it is a constant with a value of 1
Expand All @@ -171,3 +171,17 @@ func Sqrt(x, p *big.Int) *big.Int {

return m
}

func RecoverY(x, beta, fieldPrime *big.Int) (*big.Int, error) {
ySquared := ySquaredFromX(x, beta, fieldPrime)
if IsQuadResidue(new(fp.Element).SetBigInt(ySquared)) {
return Sqrt(ySquared, fieldPrime), nil
}
return nil, fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquared.String())
}

func GetCairoPrime() (big.Int, bool) {
// 2**251 + 17 * 2**192 + 1
cairoPrime, ok := new(big.Int).SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10)
return *cairoPrime, ok
}
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"
recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"
randomEcPointCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import random_ec_point\nfrom starkware.python.utils import to_bytes\n\n# Define a seed for random_ec_point that's dependent on all the input, so that:\n# (1) The added point s is deterministic.\n# (2) It's hard to choose inputs for which the builtin will fail.\nseed = b\"\".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y]))\nids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed)"

// ------ Signature hints related code ------
verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createIsZeroDivModHinter()
case recoverYCode:
return createRecoverYHinter(resolver)
case randomEcPointCode:
return createRandomEcPointHinter(resolver)
// Blake hints
case blake2sAddUint256BigendCode:
return createBlake2sAddUint256Hinter(resolver, true)
Expand Down
165 changes: 147 additions & 18 deletions pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package zero

import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/big"

Expand Down Expand Up @@ -901,33 +904,25 @@ func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter {
return err
}

const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665"
betaBigInt, ok := new(big.Int).SetString(betaString, 10)
if !ok {
panic("failed to convert BETA string to big.Int")
}
betaBigInt := new(big.Int)
utils.Beta.BigInt(betaBigInt)

const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10)
fieldPrimeBigInt, ok := secp_utils.GetCairoPrime()
if !ok {
panic("failed to convert FIELD_PRIME string to big.Int")
return fmt.Errorf("GetCairoPrime failed")
}

xBigInt := new(big.Int)
xFelt.BigInt(xBigInt)

// y^2 = x^3 + alpha * x + beta (mod field_prime)
ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt)
ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt)

if secp_utils.IsQuadResidue(ySquaredFelt) {
result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt))
value := mem.MemoryValueFromFieldElement(result)
return vm.Memory.WriteToAddress(&pYAddr, &value)
} else {
ySquaredString := ySquaredBigInt.String()
return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString)
resultBigInt, err := secp_utils.RecoverY(xBigInt, betaBigInt, &fieldPrimeBigInt)
if err != nil {
return err
}
resultFelt := new(fp.Element).SetBigInt(resultBigInt)
resultMv := mem.MemoryValueFromFieldElement(resultFelt)
return vm.Memory.WriteToAddress(&pYAddr, &resultMv)
},
}
}
Expand All @@ -945,3 +940,137 @@ func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error)

return newRecoverYHint(x, p), nil
}

// RandomEcPoint hint returns a random non-zero point on the elliptic curve
// y^2 = x^3 + alpha * x + beta (mod field_prime).
// The point is created deterministically from the seed.
//
// `newRandomEcPointHint` takes 4 operanders as arguments
// - `p` is an EC point used for seed generation
// - `m` the multiplication coefficient of Q used for seed generation
// - `q` an EC point used for seed generation
// - `s` is where the generated random EC point is written to
func newRandomEcPointHint(p, m, q, s hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "RandomEcPoint",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME
//> from starkware.python.math_utils import random_ec_point
//> from starkware.python.utils import to_bytes
//>
//> # Define a seed for random_ec_point that's dependent on all the input, so that:
//> # (1) The added point s is deterministic.
//> # (2) It's hard to choose inputs for which the builtin will fail.
//> seed = b"".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y]))
//> ids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed)

pAddr, err := p.GetAddress(vm)
if err != nil {
return err
}
pValues, err := vm.Memory.ResolveAsEcPoint(pAddr)
if err != nil {
return err
}
mFelt, err := hinter.ResolveAsFelt(vm, m)
if err != nil {
return err
}
qAddr, err := q.GetAddress(vm)
if err != nil {
return err
}
qValues, err := vm.Memory.ResolveAsEcPoint(qAddr)
if err != nil {
return err
}

var bytesArray []byte
writeFeltToBytesArray := func(n *fp.Element) {
for _, byteValue := range n.Bytes() {
bytesArray = append(bytesArray, byteValue)
}
}
for _, felt := range pValues {
writeFeltToBytesArray(felt)
}
writeFeltToBytesArray(mFelt)
for _, felt := range qValues {
writeFeltToBytesArray(felt)
}
seed := sha256.Sum256(bytesArray)

alphaBig := new(big.Int)
utils.Alpha.BigInt(alphaBig)
betaBig := new(big.Int)
utils.Beta.BigInt(betaBig)
fieldPrime, ok := secp_utils.GetCairoPrime()
if !ok {
return fmt.Errorf("GetCairoPrime failed")
}

for i := uint64(0); i < 100; i++ {
iBytes := make([]byte, 10)
binary.LittleEndian.PutUint64(iBytes, i)
concatenated := append(seed[1:], iBytes...)
hash := sha256.Sum256(concatenated)
hashHex := hex.EncodeToString(hash[:])
x := new(big.Int)
x.SetString(hashHex, 16)

yCoef := big.NewInt(1)
if seed[0]&1 == 1 {
yCoef.Neg(yCoef)
}

// Try to recover y
if !ok {
return fmt.Errorf("failed to get field prime value")
}
if y, err := secp_utils.RecoverY(x, betaBig, &fieldPrime); err == nil {
y.Mul(yCoef, y)
y.Mod(y, &fieldPrime)

sAddr, err := s.GetAddress(vm)
if err != nil {
return err
}

sXFelt := new(fp.Element).SetBigInt(x)
sYFelt := new(fp.Element).SetBigInt(y)
sXMv := mem.MemoryValueFromFieldElement(sXFelt)
sYMv := mem.MemoryValueFromFieldElement(sYFelt)

err = vm.Memory.WriteToNthStructField(sAddr, sXMv, 0)
if err != nil {
return err
}
return vm.Memory.WriteToNthStructField(sAddr, sYMv, 1)
}
}

return fmt.Errorf("could not find a point on the curve")
},
}
}

func createRandomEcPointHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
p, err := resolver.GetResOperander("p")
if err != nil {
return nil, err
}
m, err := resolver.GetResOperander("m")
if err != nil {
return nil, err
}
q, err := resolver.GetResOperander("q")
if err != nil {
return nil, err
}
s, err := resolver.GetResOperander("s")
if err != nil {
return nil, err
}

return newRandomEcPointHint(p, m, q, s), nil
}
46 changes: 46 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,52 @@ func TestZeroHintEc(t *testing.T) {
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
},
"RandomEcPoint": {
{
operanders: []*hintOperander{
{Name: "p.x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")},
{Name: "p.y", Kind: apRelative, Value: feltString("3232266734070744637901977159303149980795588196503166389060831401046564401743")},
{Name: "m", Kind: apRelative, Value: feltUint64(34)},
{Name: "q.x", Kind: apRelative, Value: feltString("2864041794633455918387139831609347757720597354645583729611044800117714995244")},
{Name: "q.y", Kind: apRelative, Value: feltString("2252415379535459416893084165764951913426528160630388985542241241048300343256")},
{Name: "s.x", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRandomEcPointHint(
ctx.operanders["p.x"],
ctx.operanders["m"],
ctx.operanders["q.x"],
ctx.operanders["s.x"],
)
},
check: consecutiveVarValueEquals("s.x", []*fp.Element{
feltString("96578541406087262240552119423829615463800550101008760434566010168435227837635"),
feltString("3412645436898503501401619513420382337734846074629040678138428701431530606439"),
}),
},
{
operanders: []*hintOperander{
{Name: "p.x", Kind: apRelative, Value: feltUint64(12345)},
{Name: "p.y", Kind: apRelative, Value: feltUint64(6789)},
{Name: "m", Kind: apRelative, Value: feltUint64(101)},
{Name: "q.x", Kind: apRelative, Value: feltUint64(98765)},
{Name: "q.y", Kind: apRelative, Value: feltUint64(4321)},
{Name: "s.x", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRandomEcPointHint(
ctx.operanders["p.x"],
ctx.operanders["m"],
ctx.operanders["q.x"],
ctx.operanders["s.x"],
)
},
check: consecutiveVarValueEquals("s.x", []*fp.Element{
feltString("39190969885360777615413526676655883809466222002423777590585892821354159079496"),
feltString("533983185449702770508526175744869430974740140562200547506631069957329272485"),
}),
},
},
},
)
}
11 changes: 5 additions & 6 deletions pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -1154,21 +1154,20 @@ func newIsQuadResidueHint(x, y hinter.ResOperander) hinter.Hinter {
var value = memory.MemoryValue{}
var result *fp.Element = new(fp.Element)

const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
primeBigInt, ok := new(big.Int).SetString(primeString, 10)
primeBigInt, ok := math_utils.GetCairoPrime()
if !ok {
panic("failed to convert prime string to big.Int")
return fmt.Errorf("GetCairoPrime failed")
}

if math_utils.IsQuadResidue(x) {
result.SetBigInt(math_utils.Sqrt(&xBigInt, primeBigInt))
result.SetBigInt(math_utils.Sqrt(&xBigInt, &primeBigInt))
} else {
y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), primeBigInt)
y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), &primeBigInt)
if err != nil {
return err
}

result.SetBigInt(math_utils.Sqrt(&y, primeBigInt))
result.SetBigInt(math_utils.Sqrt(&y, &primeBigInt))
}

value = memory.MemoryValueFromFieldElement(result)
Expand Down
18 changes: 18 additions & 0 deletions pkg/vm/memory/memory_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,21 @@ func (memory *Memory) ResolveAsBigInt3(valAddr MemoryAddress) ([3]*f.Element, er

return valValues, nil
}

func (memory *Memory) ResolveAsEcPoint(valAddr MemoryAddress) ([2]*f.Element, error) {
valMemoryValues, err := memory.GetConsecutiveMemoryValues(valAddr, int16(2))
if err != nil {
return [2]*f.Element{}, err
}

var valValues [2]*f.Element
for i := 0; i < 2; i++ {
valValue, err := valMemoryValues[i].FieldElement()
if err != nil {
return [2]*f.Element{}, err
}
valValues[i] = valValue
}

return valValues, nil
}

0 comments on commit 050d0e0

Please sign in to comment.