Skip to content

Commit

Permalink
pkg/hintrunner/zero: allow multiplication binary ops in references (#196
Browse files Browse the repository at this point in the history
)

References may include the multiplication inside them.

Let's take this like of code for the example:

    https://github.com/starkware-libs/cairo-lang/blob/caba294d82eeeccc3d86a158adb8ba209bf2d8fc/src/starkware/cairo/common/math.cairo#L193

It will produce a reference like this:
```json
    {
        "cairo_type": "felt",
        "full_name": "starkware.cairo.common.math.assert_le_felt.arc_prod",
        "references": [
            {
                "ap_tracking_data": {
                    "group": 1,
                    "offset": 8
                },
                "pc": 14,
                "value": "cast([ap + (-5)] * [ap + (-1)], felt)"
            }
        ],
        "type": "reference"
    }
```
  • Loading branch information
quasilyte authored Feb 8, 2024
1 parent 88587ae commit fb1008c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
33 changes: 30 additions & 3 deletions pkg/hintrunner/zero/hintparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zero
import (
"fmt"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
op "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/alecthomas/participle/v2"
)
Expand All @@ -22,6 +23,7 @@ var parser *participle.Parser[IdentifierExp] = participle.MustBuild[IdentifierEx
// 2 dereferences off1 omitted: cast([reg] + [reg + off2], type)
// 2 dereferences off2 omitted: cast([reg + off1] + [reg], type)
// 2 dereferences both offs omitted: cast([reg] + [reg], type)
// 2 dereferences with multiplication: cast([reg + off1] * [reg + off2], felt)
// Reference no dereference 2 offsets - + : cast(reg - off1 + off2, type)

// Note: The same cases apply with an external dereference. Example: [cast(number, type)]
Expand Down Expand Up @@ -62,7 +64,8 @@ type DerefExp struct {
}

type BinOpExp struct {
LeftExp *LeftExp `@@ "+"`
LeftExp *LeftExp `@@`
Operator string `@("+" | "*")`
RightExp *RightExp `@@`
}

Expand All @@ -83,10 +86,12 @@ type RightExp struct {

type DerefOffset struct {
Deref op.Deref
Op op.Operator
Offset *int
}
type DerefDeref struct {
LeftDeref op.Deref
Op op.Operator
RightDeref op.Deref
}

Expand Down Expand Up @@ -141,8 +146,9 @@ func (expression CastExp) Evaluate() (any, error) {
return result, nil
case DerefOffset:
return op.BinaryOp{
Operator: 0,
Operator: result.Op,
Lhs: result.Deref.Deref,
// TODO: why we're not using something like f.NewElement here?
Rhs: op.Immediate{
uint64(0),
uint64(0),
Expand All @@ -152,7 +158,7 @@ func (expression CastExp) Evaluate() (any, error) {
}, nil
case DerefDeref:
return op.BinaryOp{
Operator: 0,
Operator: result.Op,
Lhs: result.LeftDeref.Deref,
Rhs: result.RightDeref,
}, nil
Expand Down Expand Up @@ -238,8 +244,16 @@ func (expression BinOpExp) Evaluate() (any, error) {
return nil, err
}

operation, err := parseOperator(expression.Operator)
if err != nil {
return nil, err
}

switch lResult := leftExp.(type) {
case op.CellRefer:
// Right now we assume that there is no expression like `reg - off1 * off2`,
// but if there are, we would need to come up with an idea how to handle it.
// Right now we only cover `off1 + off2` expressions here.
offset, ok := rightExp.(*int)
if !ok {
return nil, fmt.Errorf("invalid type operation")
Expand Down Expand Up @@ -267,11 +281,13 @@ func (expression BinOpExp) Evaluate() (any, error) {
case op.Deref:
return DerefDeref{
lResult,
operation,
rResult,
}, nil
case *int:
return DerefOffset{
lResult,
operation,
rResult,
}, nil
}
Expand Down Expand Up @@ -308,3 +324,14 @@ func ParseIdentifier(value string) (any, error) {

return identifierExp.Evaluate()
}

func parseOperator(op string) (hinter.Operator, error) {
switch op {
case "+":
return hinter.Add, nil
case "*":
return hinter.Mul, nil
default:
return 0, fmt.Errorf("unexpected op: %q", op)
}
}
20 changes: 20 additions & 0 deletions pkg/hintrunner/zero/hintparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ func TestHintParser(t *testing.T) {
},
},
},
{
Parameter: "cast([ap + (-5)] * [ap + (-1)], felt)",
ExpectedCellRefer: nil,
ExpectedResOperander: hinter.BinaryOp{
Operator: hinter.Mul,
Lhs: hinter.ApCellRef(-5),
Rhs: hinter.Deref{
Deref: hinter.ApCellRef(-1),
},
},
},
{
Parameter: "cast([ap] * 3, felt)",
ExpectedCellRefer: nil,
ExpectedResOperander: hinter.BinaryOp{
Operator: hinter.Mul,
Lhs: hinter.ApCellRef(0),
Rhs: hinter.Immediate{0, 0, 0, 3},
},
},
}

for _, test := range testSet {
Expand Down

0 comments on commit fb1008c

Please sign in to comment.