Skip to content

Commit

Permalink
add bitwise builtin support
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Oct 9, 2023
1 parent f7b2424 commit 7a1e828
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 10 deletions.
18 changes: 18 additions & 0 deletions integration_tests/builtin_tests/bitwise_builtin_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
%builtins bitwise
from starkware.cairo.common.bitwise import bitwise_and, bitwise_xor, bitwise_or, bitwise_operations
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin

func main{bitwise_ptr: BitwiseBuiltin*}() {
let (and_a) = bitwise_and(12, 10); // Binary (1100, 1010).
assert and_a = 8; // Binary 1000.
let (xor_a) = bitwise_xor(12, 10);
assert xor_a = 6;
let (or_a) = bitwise_or(12, 10);
assert or_a = 14;

let (and_b, xor_b, or_b) = bitwise_operations(9, 3);
assert and_b = 1;
assert xor_b = 10;
assert or_b = 11;
return ();
}
10 changes: 10 additions & 0 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,13 @@ func TestFailingRangeCheck(t *testing.T) {

clean("./failing_contracts/")
}

func TestBitwise(t *testing.T) {
compiledOutput, err := compileZeroCode("./builtin_tests/bitwise_builtin_test.cairo")
require.NoError(t, err)

_, _, err = runVm(compiledOutput)
require.NoError(t, err)

clean("./builtin_tests/")
}
87 changes: 87 additions & 0 deletions pkg/vm/builtins/bitwise.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package builtins

import (
"errors"

"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

const cellsPerBitwise = 5
const inputCellsPerBitwise = 2

type Bitwise struct{}

func (b *Bitwise) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error {
return nil
}

func (b *Bitwise) InferValue(segment *memory.Segment, offset uint64) error {
bitwiseIndex := offset % cellsPerBitwise
// input cell
if bitwiseIndex < inputCellsPerBitwise {
return errors.New("cannot infer value")
}

xOffset := offset - bitwiseIndex
yOffset := xOffset + 1

xValue, err := segment.Read(xOffset)
if err != nil {
return err
}

yValue, err := segment.Read(yOffset)
if err != nil {
return err
}

xFelt, err := xValue.FieldElement()
if err != nil {
return err
}

yFelt, err := yValue.FieldElement()
if err != nil {
return err
}

xBytes := xFelt.Bytes()
yBytes := yFelt.Bytes()

var bitwiseValue memory.MemoryValue
var bitwiseFelt fp.Element
var bitwiseBytes [32]byte
for i := 0; i < 32; i++ {
bitwiseBytes[i] = xBytes[i] & yBytes[i]
}
bitwiseFelt.SetBytes(bitwiseBytes[:])
bitwiseValue = memory.MemoryValueFromFieldElement(&bitwiseFelt)
if err := segment.Write(xOffset+2, &bitwiseValue); err != nil {
return err
}

for i := 0; i < 32; i++ {
bitwiseBytes[i] = xBytes[i] ^ yBytes[i]
}
bitwiseFelt.SetBytes(bitwiseBytes[:])
bitwiseValue = memory.MemoryValueFromFieldElement(&bitwiseFelt)
if err := segment.Write(xOffset+3, &bitwiseValue); err != nil {
return err
}

for i := 0; i < 32; i++ {
bitwiseBytes[i] = xBytes[i] | yBytes[i]
}
bitwiseFelt.SetBytes(bitwiseBytes[:])
bitwiseValue = memory.MemoryValueFromFieldElement(&bitwiseFelt)
if err := segment.Write(xOffset+4, &bitwiseValue); err != nil {
return err
}

return nil
}

func (b *Bitwise) String() string {
return "bitwise"
}
41 changes: 41 additions & 0 deletions pkg/vm/builtins/bitwise_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package builtins

import (
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBitwise(t *testing.T) {
bitwise := &Bitwise{}
segment := memory.EmptySegmentWithLength(5)
segment.WithBuiltinRunner(bitwise)

x, _ := new(fp.Element).SetString("0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
y, _ := new(fp.Element).SetString("0xBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")
xValue := memory.MemoryValueFromFieldElement(x)
yValue := memory.MemoryValueFromFieldElement(y)
require.NoError(t, segment.Write(0, &xValue))
require.NoError(t, segment.Write(1, &yValue))

xAndY, err := segment.Read(2)
require.NoError(t, err)
xAndYFelt, err := xAndY.FieldElement()
require.NoError(t, err)
assert.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", xAndYFelt.Text(16))

xXorY, err := segment.Read(3)
require.NoError(t, err)
xXorYFelt, err := xXorY.FieldElement()
require.NoError(t, err)
assert.Equal(t, "11111111111111111111111111111111111111111111111111111111111111", xXorYFelt.Text(16))

xOrY, err := segment.Read(4)
require.NoError(t, err)
xOrYFelt, err := xOrY.FieldElement()
require.NoError(t, err)
assert.Equal(t, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", xOrYFelt.Text(16))
}
2 changes: 1 addition & 1 deletion pkg/vm/builtins/builtin_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func Runner(name starknetParser.Builtin) memory.BuiltinRunner {
case starknetParser.Keccak:
panic("Not implemented")
case starknetParser.Bitwise:
panic("Not implemented")
return &Bitwise{}
case starknetParser.ECOP:
panic("Not implemented")
case starknetParser.Poseidon:
Expand Down
18 changes: 9 additions & 9 deletions pkg/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@ func (vm *VirtualMachine) inferOperand(
return mem.MemoryValue{}, nil
}

dstValue, err := vm.Memory.PeekFromAddress(dstAddr)
if err != nil {
return mem.MemoryValue{}, fmt.Errorf("cannot read dst: %w", err)
}

if !dstValue.Known() {
return mem.MemoryValue{}, nil // let computeRes try to handle it
}

op0Value, err := vm.Memory.PeekFromAddress(op0Addr)
if err != nil {
return mem.MemoryValue{}, fmt.Errorf("cannot read op0: %w", err)
Expand All @@ -286,15 +295,6 @@ func (vm *VirtualMachine) inferOperand(
return mem.MemoryValue{}, nil
}

dstValue, err := vm.Memory.PeekFromAddress(dstAddr)
if err != nil {
return mem.MemoryValue{}, fmt.Errorf("cannot read dst: %w", err)
}

if !dstValue.Known() {
return mem.MemoryValue{}, fmt.Errorf("value at dst is unknown")
}

if instruction.Res == Op1 && !op1Value.Known() {
if err = vm.Memory.WriteToAddress(op1Addr, &dstValue); err != nil {
return mem.MemoryValue{}, err
Expand Down

0 comments on commit 7a1e828

Please sign in to comment.