Skip to content

Commit

Permalink
keccak integration tests (#497)
Browse files Browse the repository at this point in the history
* unsafe keccak + unsafe keccak finalize

* solve many issues

* fmt

* fmt

---------

Co-authored-by: Shourya Goel <[email protected]>
  • Loading branch information
TAdev0 and Sh0g0-1758 authored Jul 4, 2024
1 parent 6cab0ef commit 312d061
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 95 deletions.
21 changes: 21 additions & 0 deletions integration_tests/cairo_zero_hint_tests/unsafe_keccak.small.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
%builtins output

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.keccak import unsafe_keccak

func main{output_ptr: felt*}() {
alloc_locals;

let (data: felt*) = alloc();

assert data[0] = 500;
assert data[1] = 2;
assert data[2] = 3;
assert data[3] = 6;
assert data[4] = 1;
assert data[5] = 4444;

let (low: felt, high: felt) = unsafe_keccak(data, 6);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
%builtins output

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.keccak import unsafe_keccak_finalize, KeccakState
from starkware.cairo.common.uint256 import Uint256

func main{output_ptr: felt*}() {
alloc_locals;

let (data: felt*) = alloc();

assert data[0] = 0;
assert data[1] = 1;
assert data[2] = 2;

let keccak_state = KeccakState(start_ptr=data, end_ptr=data + 2);

let res: Uint256 = unsafe_keccak_finalize(keccak_state);

return ();
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
%builtins range_check bitwise

from starkware.cairo.common.cairo_keccak.keccak import cairo_keccak, finalize_keccak
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (keccak_ptr: felt*) = alloc();
let keccak_ptr_start = keccak_ptr;

let (inputs: felt*) = alloc();

assert inputs[0] = 8031924123371070792;
assert inputs[1] = 560229490;

let n_bytes = 16;

let (res: Uint256) = cairo_keccak{keccak_ptr=keccak_ptr}(inputs=inputs, n_bytes=n_bytes);

assert res.low = 293431514620200399776069983710520819074;
assert res.high = 317109767021952548743448767588473366791;

finalize_keccak(keccak_ptr_start=keccak_ptr_start, keccak_ptr_end=keccak_ptr);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
%builtins output range_check bitwise

from starkware.cairo.common.cairo_keccak.keccak import _keccak
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word

func fill_array(array: felt*, base: felt, array_length: felt, iterator: felt) {
if (iterator == array_length) {
return ();
}

assert array[iterator] = base;

return fill_array(array, base, array_length, iterator + 1);
}

func main{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (output: felt*) = alloc();
let keccak_output = output;

let (inputs: felt*) = alloc();
let inputs_start = inputs;
fill_array(inputs, 9, 3, 0);

let (state: felt*) = alloc();
let state_start = state;
fill_array(state, 5, 25, 0);

let n_bytes = 24;

let (res: felt*) = _keccak{keccak_ptr=keccak_output}(
inputs=inputs_start, n_bytes=n_bytes, state=state_start
);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
%builtins output range_check bitwise

from starkware.cairo.common.keccak_utils.keccak_utils import keccak_add_uint256
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word

func main{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (inputs) = alloc();
let inputs_start = inputs;

let num = Uint256(34623634663146736, 598249824422424658356);

keccak_add_uint256{inputs=inputs_start}(num=num, bigend=0);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
%builtins range_check bitwise keccak
from starkware.cairo.common.cairo_builtins import KeccakBuiltin, BitwiseBuiltin
from starkware.cairo.common.builtin_keccak.keccak import keccak_uint256s
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.uint256 import Uint256

func main{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*}() {
let elements: Uint256* = alloc();
assert elements[0] = Uint256(713458135386519, 18359173571);
assert elements[1] = Uint256(1536741637546373185, 84357893467438914);
assert elements[2] = Uint256(2842949328439284983294, 39248298942938492384);
assert elements[3] = Uint256(27518568234293478923754395731931, 981587843715983274);
assert elements[4] = Uint256(326848123647324823482, 93453458349589345);
let (res) = keccak_uint256s(5, elements);

return ();
}
15 changes: 4 additions & 11 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,11 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
blake2sComputeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import compute_blake2s_func\ncompute_blake2s_func(segments=segments, output_ptr=ids.output)"

// ------ Keccak hints related code ------
unsafeKeccakFinalizeCode string = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
cairoKeccakFinalizeCode string = `# Add dummy pairs of input and output.
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
_block_size = int(ids.BLOCK_SIZE)
assert 0 <= _keccak_state_size_felts < 100
assert 0 <= _block_size < 10
inp = [0] * _keccak_state_size_felts
padding = (inp + keccak_func(inp)) * _block_size
segments.write_arg(ids.keccak_ptr_end, padding)`
unsafeKeccakFinalizeCode string = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
cairoKeccakFinalizeCode string = "# Add dummy pairs of input and output.\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\n_block_size = int(ids.BLOCK_SIZE)\nassert 0 <= _keccak_state_size_felts < 100\nassert 0 <= _block_size < 10\ninp = [0] * _keccak_state_size_felts\npadding = (inp + keccak_func(inp)) * _block_size\nsegments.write_arg(ids.keccak_ptr_end, padding)"
keccakWriteArgsCode string = "segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64])\nsegments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64])"
blockPermutationCode string = "from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\nassert 0 <= _keccak_state_size_felts < 100\noutput_values = keccak_func(memory.get_range(\nids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))\nsegments.write_arg(ids.keccak_ptr, output_values)"
blockPermutationCode string = "from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\nassert 0 <= _keccak_state_size_felts < 100\n\noutput_values = keccak_func(memory.get_range(\n ids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))\nsegments.write_arg(ids.keccak_ptr, output_values)"
compareBytesInWordCode string = "memory[ap] = to_felt_or_relocatable(ids.n_bytes < ids.BYTES_IN_WORD)"
compareKeccakFullRateInBytesCode string = "memory[ap] = to_felt_or_relocatable(ids.n_bytes >= ids.KECCAK_FULL_RATE_IN_BYTES)"
splitInput3Code string = "ids.high3, ids.low3 = divmod(memory[ids.inputs + 3], 256)"
Expand Down

0 comments on commit 312d061

Please sign in to comment.