Skip to content

Commit

Permalink
Implement U256InvModN hint (#631)
Browse files Browse the repository at this point in the history
* Wrote Impl

* nit

* fmt

* removed bug

* Added hint support for starkent

* Refactored Implementation

* Added tests

* nit

---------

Co-authored-by: MaksymMalicki <[email protected]>
  • Loading branch information
Sh0g0-1758 and MaksymMalicki authored Aug 21, 2024
1 parent 47baa8f commit 6ebeffd
Show file tree
Hide file tree
Showing 5 changed files with 509 additions and 9 deletions.
227 changes: 227 additions & 0 deletions pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,233 @@ func (hint DivMod) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) e
return nil
}

type U256InvModN struct {
B0 hinter.ResOperander
B1 hinter.ResOperander
N0 hinter.ResOperander
N1 hinter.ResOperander
G0OrNoInv hinter.CellRefer
G1Option hinter.CellRefer
SOrR0 hinter.CellRefer
SOrR1 hinter.CellRefer
TOrK0 hinter.CellRefer
TOrK1 hinter.CellRefer
}

func (hint U256InvModN) String() string {
return "U256InvModN"
}

func (hint U256InvModN) Execute(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
B0, err := hint.B0.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve B0 operand %s: %v", hint.B0, err)
}

B1, err := hint.B1.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve B1 operand %s: %v", hint.B1, err)
}

N0, err := hint.N0.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve N0 operand %s: %v", hint.N0, err)
}

N1, err := hint.N1.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve N1 operand %s: %v", hint.N1, err)
}

g0OrNoInvAddr, err := hint.G0OrNoInv.Get(vm)
if err != nil {
return fmt.Errorf("get G0OrNoInv address %s: %w", g0OrNoInvAddr, err)
}

g1OptionAddr, err := hint.G1Option.Get(vm)
if err != nil {
return fmt.Errorf("get G1Option address %s: %w", g1OptionAddr, err)
}

sOrR0Addr, err := hint.SOrR0.Get(vm)
if err != nil {
return fmt.Errorf("get SOrR0 address %s: %w", sOrR0Addr, err)
}

sOrR1Addr, err := hint.SOrR1.Get(vm)
if err != nil {
return fmt.Errorf("get SOrR1 address %s: %w", sOrR1Addr, err)
}

tOrK0Addr, err := hint.TOrK0.Get(vm)
if err != nil {
return fmt.Errorf("get TOrK0 address %s: %w", tOrK0Addr, err)
}

tOrK1Addr, err := hint.TOrK1.Get(vm)
if err != nil {
return fmt.Errorf("get TOrK1 address %s: %w", tOrK1Addr, err)
}

B0Felt, err := B0.FieldElement()
if err != nil {
return err
}
B1Felt, err := B1.FieldElement()
if err != nil {
return err
}
N0Felt, err := N0.FieldElement()
if err != nil {
return err
}
N1Felt, err := N1.FieldElement()
if err != nil {
return err
}

var B0BigInt big.Int
B0Felt.BigInt(&B0BigInt)

var B1BigInt big.Int
B1Felt.BigInt(&B1BigInt)

var N0BigInt big.Int
N0Felt.BigInt(&N0BigInt)

var N1BigInt big.Int
N1Felt.BigInt(&N1BigInt)

b := new(big.Int).Lsh(&B1BigInt, 128)
b.Add(b, &B0BigInt)

n := new(big.Int).Lsh(&N1BigInt, 128)
n.Add(n, &N0BigInt)

_, r, g := u.Igcdex(n, b)
mask := new(big.Int).Lsh(big.NewInt(1), 128)
mask.Sub(mask, big.NewInt(1))

if n.Cmp(big.NewInt(1)) == 0 {
mv := mem.MemoryValueFromFieldElement(B0Felt)
err = vm.Memory.WriteToAddress(&sOrR0Addr, &mv)
if err != nil {
return fmt.Errorf("write to SOrR0 address %s: %w", sOrR0Addr, err)
}

mv = mem.MemoryValueFromFieldElement(B1Felt)
err = vm.Memory.WriteToAddress(&sOrR1Addr, &mv)
if err != nil {
return fmt.Errorf("write to SOrR1 address %s: %w", sOrR1Addr, err)
}

mv = mem.MemoryValueFromFieldElement(&utils.FeltOne)
err = vm.Memory.WriteToAddress(&tOrK0Addr, &mv)
if err != nil {
return fmt.Errorf("write to TOrK0 address %s: %w", tOrK0Addr, err)
}

mv = mem.MemoryValueFromFieldElement(&utils.FeltZero)
err = vm.Memory.WriteToAddress(&tOrK1Addr, &mv)
if err != nil {
return fmt.Errorf("write to TOrK1 address %s: %w", tOrK1Addr, err)
}

mv = mem.MemoryValueFromFieldElement(&utils.FeltOne)
err = vm.Memory.WriteToAddress(&g0OrNoInvAddr, &mv)
if err != nil {
return fmt.Errorf("write to G0OrNoInv address %s: %w", g0OrNoInvAddr, err)
}

mv = mem.MemoryValueFromFieldElement(&utils.FeltZero)
err = vm.Memory.WriteToAddress(&g1OptionAddr, &mv)
if err != nil {
return fmt.Errorf("write to G1Option address %s: %w", g1OptionAddr, err)
}
} else if g.Cmp(big.NewInt(1)) != 0 {
if new(big.Int).Rem(&g, big.NewInt(2)).Cmp(big.NewInt(0)) == 0 {
g = *big.NewInt(2)
}

s := new(big.Int).Div(b, &g)
t := new(big.Int).Div(n, &g)

mv := mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).And(s, mask)))
err = vm.Memory.WriteToAddress(&sOrR0Addr, &mv)
if err != nil {
return fmt.Errorf("write to SOrR0 address %s: %w", sOrR0Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).Rsh(s, 128)))
err = vm.Memory.WriteToAddress(&sOrR1Addr, &mv)
if err != nil {
return fmt.Errorf("write to SOrR1 address %s: %w", sOrR1Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).And(t, mask)))
err = vm.Memory.WriteToAddress(&tOrK0Addr, &mv)
if err != nil {
return fmt.Errorf("write to TOrK0 address %s: %w", tOrK0Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).Rsh(t, 128)))
err = vm.Memory.WriteToAddress(&tOrK1Addr, &mv)
if err != nil {
return fmt.Errorf("write to TOrK1 address %s: %w", tOrK1Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).And(&g, mask)))
err = vm.Memory.WriteToAddress(&g0OrNoInvAddr, &mv)
if err != nil {
return fmt.Errorf("write to G0OrNoInv address %s: %w", g0OrNoInvAddr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).Rsh(&g, 128)))
err = vm.Memory.WriteToAddress(&g1OptionAddr, &mv)
if err != nil {
return fmt.Errorf("write to G1Option address %s: %w", g1OptionAddr, err)
}
} else {
r.Rem(&r, n)

k := new(big.Int).Mul(&r, b)
k.Sub(k, big.NewInt(1))
k.Div(k, n)

mv := mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).And(&r, mask)))
err = vm.Memory.WriteToAddress(&sOrR0Addr, &mv)
if err != nil {
return fmt.Errorf("write to SOrR0 address %s: %w", sOrR0Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).Rsh(&r, 128)))
err = vm.Memory.WriteToAddress(&sOrR1Addr, &mv)
if err != nil {
return fmt.Errorf("write to SOrR1 address %s: %w", sOrR1Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).And(k, mask)))
err = vm.Memory.WriteToAddress(&tOrK0Addr, &mv)
if err != nil {
return fmt.Errorf("write to TOrK0 address %s: %w", tOrK0Addr, err)
}

mv = mem.MemoryValueFromFieldElement(new(f.Element).SetBigInt(new(big.Int).Rsh(k, 128)))
err = vm.Memory.WriteToAddress(&tOrK1Addr, &mv)
if err != nil {
return fmt.Errorf("write to TOrK1 address %s: %w", tOrK1Addr, err)
}

mv = mem.MemoryValueFromFieldElement(&utils.FeltZero)
err = vm.Memory.WriteToAddress(&g0OrNoInvAddr, &mv)
if err != nil {
return fmt.Errorf("write to G0OrNoInv address %s: %w", g0OrNoInvAddr, err)
}
}

return nil
}

type Uint256DivMod struct {
dividend0 hinter.ResOperander
dividend1 hinter.ResOperander
Expand Down
Loading

0 comments on commit 6ebeffd

Please sign in to comment.