From 7a1e82826184c1d3d1615c3358b5c3830e5a2e14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20IRMAK?= Date: Fri, 6 Oct 2023 16:42:09 +0300 Subject: [PATCH] add bitwise builtin support --- .../builtin_tests/bitwise_builtin_test.cairo | 18 ++++ integration_tests/cairozero_test.go | 10 +++ pkg/vm/builtins/bitwise.go | 87 +++++++++++++++++++ pkg/vm/builtins/bitwise_test.go | 41 +++++++++ pkg/vm/builtins/builtin_runner.go | 2 +- pkg/vm/vm.go | 18 ++-- 6 files changed, 166 insertions(+), 10 deletions(-) create mode 100644 integration_tests/builtin_tests/bitwise_builtin_test.cairo create mode 100644 pkg/vm/builtins/bitwise.go create mode 100644 pkg/vm/builtins/bitwise_test.go diff --git a/integration_tests/builtin_tests/bitwise_builtin_test.cairo b/integration_tests/builtin_tests/bitwise_builtin_test.cairo new file mode 100644 index 000000000..041bb1fa9 --- /dev/null +++ b/integration_tests/builtin_tests/bitwise_builtin_test.cairo @@ -0,0 +1,18 @@ +%builtins bitwise +from starkware.cairo.common.bitwise import bitwise_and, bitwise_xor, bitwise_or, bitwise_operations +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin + +func main{bitwise_ptr: BitwiseBuiltin*}() { + let (and_a) = bitwise_and(12, 10); // Binary (1100, 1010). + assert and_a = 8; // Binary 1000. + let (xor_a) = bitwise_xor(12, 10); + assert xor_a = 6; + let (or_a) = bitwise_or(12, 10); + assert or_a = 14; + + let (and_b, xor_b, or_b) = bitwise_operations(9, 3); + assert and_b = 1; + assert xor_b = 10; + assert or_b = 11; + return (); +} diff --git a/integration_tests/cairozero_test.go b/integration_tests/cairozero_test.go index 7c217c6dd..866d80fe2 100644 --- a/integration_tests/cairozero_test.go +++ b/integration_tests/cairozero_test.go @@ -248,3 +248,13 @@ func TestFailingRangeCheck(t *testing.T) { clean("./failing_contracts/") } + +func TestBitwise(t *testing.T) { + compiledOutput, err := compileZeroCode("./builtin_tests/bitwise_builtin_test.cairo") + require.NoError(t, err) + + _, _, err = runVm(compiledOutput) + require.NoError(t, err) + + clean("./builtin_tests/") +} diff --git a/pkg/vm/builtins/bitwise.go b/pkg/vm/builtins/bitwise.go new file mode 100644 index 000000000..0aa1ed7f1 --- /dev/null +++ b/pkg/vm/builtins/bitwise.go @@ -0,0 +1,87 @@ +package builtins + +import ( + "errors" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +const cellsPerBitwise = 5 +const inputCellsPerBitwise = 2 + +type Bitwise struct{} + +func (b *Bitwise) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error { + return nil +} + +func (b *Bitwise) InferValue(segment *memory.Segment, offset uint64) error { + bitwiseIndex := offset % cellsPerBitwise + // input cell + if bitwiseIndex < inputCellsPerBitwise { + return errors.New("cannot infer value") + } + + xOffset := offset - bitwiseIndex + yOffset := xOffset + 1 + + xValue, err := segment.Read(xOffset) + if err != nil { + return err + } + + yValue, err := segment.Read(yOffset) + if err != nil { + return err + } + + xFelt, err := xValue.FieldElement() + if err != nil { + return err + } + + yFelt, err := yValue.FieldElement() + if err != nil { + return err + } + + xBytes := xFelt.Bytes() + yBytes := yFelt.Bytes() + + var bitwiseValue memory.MemoryValue + var bitwiseFelt fp.Element + var bitwiseBytes [32]byte + for i := 0; i < 32; i++ { + bitwiseBytes[i] = xBytes[i] & yBytes[i] + } + bitwiseFelt.SetBytes(bitwiseBytes[:]) + bitwiseValue = memory.MemoryValueFromFieldElement(&bitwiseFelt) + if err := segment.Write(xOffset+2, &bitwiseValue); err != nil { + return err + } + + for i := 0; i < 32; i++ { + bitwiseBytes[i] = xBytes[i] ^ yBytes[i] + } + bitwiseFelt.SetBytes(bitwiseBytes[:]) + bitwiseValue = memory.MemoryValueFromFieldElement(&bitwiseFelt) + if err := segment.Write(xOffset+3, &bitwiseValue); err != nil { + return err + } + + for i := 0; i < 32; i++ { + bitwiseBytes[i] = xBytes[i] | yBytes[i] + } + bitwiseFelt.SetBytes(bitwiseBytes[:]) + bitwiseValue = memory.MemoryValueFromFieldElement(&bitwiseFelt) + if err := segment.Write(xOffset+4, &bitwiseValue); err != nil { + return err + } + + return nil +} + +func (b *Bitwise) String() string { + return "bitwise" +} diff --git a/pkg/vm/builtins/bitwise_test.go b/pkg/vm/builtins/bitwise_test.go new file mode 100644 index 000000000..0842950f1 --- /dev/null +++ b/pkg/vm/builtins/bitwise_test.go @@ -0,0 +1,41 @@ +package builtins + +import ( + "testing" + + "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBitwise(t *testing.T) { + bitwise := &Bitwise{} + segment := memory.EmptySegmentWithLength(5) + segment.WithBuiltinRunner(bitwise) + + x, _ := new(fp.Element).SetString("0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + y, _ := new(fp.Element).SetString("0xBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB") + xValue := memory.MemoryValueFromFieldElement(x) + yValue := memory.MemoryValueFromFieldElement(y) + require.NoError(t, segment.Write(0, &xValue)) + require.NoError(t, segment.Write(1, &yValue)) + + xAndY, err := segment.Read(2) + require.NoError(t, err) + xAndYFelt, err := xAndY.FieldElement() + require.NoError(t, err) + assert.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", xAndYFelt.Text(16)) + + xXorY, err := segment.Read(3) + require.NoError(t, err) + xXorYFelt, err := xXorY.FieldElement() + require.NoError(t, err) + assert.Equal(t, "11111111111111111111111111111111111111111111111111111111111111", xXorYFelt.Text(16)) + + xOrY, err := segment.Read(4) + require.NoError(t, err) + xOrYFelt, err := xOrY.FieldElement() + require.NoError(t, err) + assert.Equal(t, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", xOrYFelt.Text(16)) +} diff --git a/pkg/vm/builtins/builtin_runner.go b/pkg/vm/builtins/builtin_runner.go index df5dffa50..8d5444260 100644 --- a/pkg/vm/builtins/builtin_runner.go +++ b/pkg/vm/builtins/builtin_runner.go @@ -18,7 +18,7 @@ func Runner(name starknetParser.Builtin) memory.BuiltinRunner { case starknetParser.Keccak: panic("Not implemented") case starknetParser.Bitwise: - panic("Not implemented") + return &Bitwise{} case starknetParser.ECOP: panic("Not implemented") case starknetParser.Poseidon: diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 5fb6e54be..daa7aef94 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -273,6 +273,15 @@ func (vm *VirtualMachine) inferOperand( return mem.MemoryValue{}, nil } + dstValue, err := vm.Memory.PeekFromAddress(dstAddr) + if err != nil { + return mem.MemoryValue{}, fmt.Errorf("cannot read dst: %w", err) + } + + if !dstValue.Known() { + return mem.MemoryValue{}, nil // let computeRes try to handle it + } + op0Value, err := vm.Memory.PeekFromAddress(op0Addr) if err != nil { return mem.MemoryValue{}, fmt.Errorf("cannot read op0: %w", err) @@ -286,15 +295,6 @@ func (vm *VirtualMachine) inferOperand( return mem.MemoryValue{}, nil } - dstValue, err := vm.Memory.PeekFromAddress(dstAddr) - if err != nil { - return mem.MemoryValue{}, fmt.Errorf("cannot read dst: %w", err) - } - - if !dstValue.Known() { - return mem.MemoryValue{}, fmt.Errorf("value at dst is unknown") - } - if instruction.Res == Op1 && !op1Value.Known() { if err = vm.Memory.WriteToAddress(op1Addr, &dstValue); err != nil { return mem.MemoryValue{}, err