Skip to content

Commit

Permalink
Optimized local variable allocation scheme (#684)
Browse files Browse the repository at this point in the history
* working through a unit test

* working through a unit test

* an attempt to fix local variable over allocation causing compilation disagreement

* remove excessive allocations

* renmaing file name to be sth more appropriate

* a bit abstraction on abi allocation in subroutine eval

* multi proc testing

* revert test related code

* testcase to ensure that the fix works

* testcase that breaks on master

* feels frustrated, debugged whole morning don't know why

* declare new instance in context

* minor

* remove unnecessary ctx manager for proto

* minor

---------

Co-authored-by: Zeph Grunschlag <[email protected]>
  • Loading branch information
ahangsu and Zeph Grunschlag authored Mar 7, 2023
1 parent b6a8172 commit 42e6981
Show file tree
Hide file tree
Showing 46 changed files with 2,429 additions and 1,896 deletions.
52 changes: 37 additions & 15 deletions pyteal/ast/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def get_declaration_by_option(
decl = self.option_map[fp_option]
if decl is not None:
return decl
self.option_map[fp_option] = self.option_method[fp_option](self.subroutine)
self.option_map[fp_option] = self.option_method[fp_option].evaluate(
self.subroutine
)
return cast(SubroutineDeclaration, self.option_map[fp_option])

def __probe_info(self, fp_option: bool) -> tuple[bool, TealType]:
Expand Down Expand Up @@ -829,7 +831,7 @@ def __call__(self, fn_implementation: Callable[..., Expr]) -> SubroutineFnWrappe


@contextmanager
def _frame_pointer_context(proto: Proto):
def _frame_pointer_context(proto: Proto | None):
tmp, SubroutineEval._current_proto = SubroutineEval._current_proto, proto
yield proto
SubroutineEval._current_proto = tmp
Expand Down Expand Up @@ -911,6 +913,23 @@ class SubroutineEval:
use_frame_pt: bool = False
_current_proto: ClassVar[Optional[Proto]] = None

@staticmethod
def _new_abi_instance_from_storage(
spec: abi.TypeSpec, storage: FrameVar
) -> abi.BaseType:
"""
This hidden method generates new ABI instance that is tied to the storage: FrameVar as follows:
- generates new instance that is based on scratch vars
- rewind the new instance to be using storage: FrameVar
- rewind the state changed by scratch slot allocation
"""
current_scratch_id = ScratchSlot.nextSlotId
with _frame_pointer_context(None):
instance = spec.new_instance()
instance._stored_value = storage
ScratchSlot.reset_slot_numbering(current_scratch_id)
return instance

@staticmethod
def var_n_loaded_scratch(
subroutine: SubroutineDefinition,
Expand All @@ -924,7 +943,8 @@ def var_n_loaded_scratch(
argument_var = DynamicScratchVar(TealType.anytype)
loaded_var = argument_var
elif param in subroutine.abi_args:
internal_abi_var = subroutine.abi_args[param].new_instance()
with _frame_pointer_context(None):
internal_abi_var = subroutine.abi_args[param].new_instance()
argument_var = cast(ScratchVar, internal_abi_var._stored_value)
loaded_var = internal_abi_var
else:
Expand All @@ -951,11 +971,12 @@ def var_n_loaded_fp(
argument_var = DynamicScratchVar(TealType.anytype)
loaded_var = argument_var
elif param in subroutine.abi_args:
internal_abi_var = subroutine.abi_args[param].new_instance()
dig_index = (
subroutine.arguments().index(param) - subroutine.argument_count()
)
internal_abi_var._stored_value = FrameVar(proto, dig_index)
internal_abi_var = SubroutineEval._new_abi_instance_from_storage(
subroutine.abi_args[param], FrameVar(proto, dig_index)
)
argument_var = None
loaded_var = internal_abi_var
else:
Expand Down Expand Up @@ -1002,7 +1023,7 @@ def __proto(subroutine: SubroutineDefinition) -> Proto:

return Proto(subroutine.argument_count(), num_stack_outputs, mem_layout=layout)

def __call__(self, subroutine: SubroutineDefinition) -> SubroutineDeclaration:
def evaluate(self, subroutine: SubroutineDefinition) -> SubroutineDeclaration:
proto = self.__proto(subroutine)

args = subroutine.arguments()
Expand All @@ -1021,22 +1042,23 @@ def __call__(self, subroutine: SubroutineDefinition) -> SubroutineDeclaration:
output_carrying_abi: Optional[abi.BaseType] = None

if output_kwarg_info:
output_carrying_abi = output_kwarg_info.abi_type.new_instance()
if self.use_frame_pt:
output_carrying_abi._stored_value = FrameVar(proto, 0)
if not self.use_frame_pt:
with _frame_pointer_context(None):
output_carrying_abi = output_kwarg_info.abi_type.new_instance()
else:
output_carrying_abi = SubroutineEval._new_abi_instance_from_storage(
output_kwarg_info.abi_type, FrameVar(proto, 0)
)

abi_output_kwargs[output_kwarg_info.name] = output_carrying_abi

# Arg usage "B" supplied to build an AST from the user-defined PyTEAL function:
subroutine_body: Expr
if not self.use_frame_pt:

with _frame_pointer_context(proto if self.use_frame_pt else None):
subroutine_body = subroutine.implementation(
*loaded_args, **abi_output_kwargs
)
else:
with _frame_pointer_context(proto):
subroutine_body = subroutine.implementation(
*loaded_args, **abi_output_kwargs
)

if not isinstance(subroutine_body, Expr):
raise TealInputError(
Expand Down
91 changes: 83 additions & 8 deletions pyteal/ast/subroutine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass

import pyteal as pt
from pyteal.ast.frame import Proto, ProtoStackLayout, FrameBury, FrameDig
from pyteal.ast.frame import FrameVar, Proto, ProtoStackLayout, FrameBury, FrameDig
from pyteal.ast.subroutine import ABIReturnSubroutine, SubroutineEval
from pyteal.compiler.compiler import FRAME_POINTERS_VERSION

Expand Down Expand Up @@ -374,7 +374,7 @@ def var_abi_output_impl(*, output: pt.abi.Uint16):
)

# Now we get to _validate_annotation():
one_vanilla = mock_subroutine_definition(lambda x: pt.Return(pt.Int(1)))
one_vanilla = mock_subroutine_definition(lambda _: pt.Return(pt.Int(1)))

params, anns, arg_types, byrefs, abi_args, output_kwarg = one_vanilla._validate()
assert len(params) == 1
Expand Down Expand Up @@ -1159,7 +1159,7 @@ def mySubroutine():
definition = pt.SubroutineDefinition(mySubroutine, return_type)
evaluate_subroutine = SubroutineEval.normal_evaluator()

declaration = evaluate_subroutine(definition)
declaration = evaluate_subroutine.evaluate(definition)

assert isinstance(declaration, pt.SubroutineDeclaration)
assert declaration.subroutine is definition
Expand Down Expand Up @@ -1187,7 +1187,7 @@ def mySubroutine(a1):
definition = pt.SubroutineDefinition(mySubroutine, return_type)

evaluate_subroutine = SubroutineEval.normal_evaluator()
declaration = evaluate_subroutine(definition)
declaration = evaluate_subroutine.evaluate(definition)

assert isinstance(declaration, pt.SubroutineDeclaration)
assert declaration.subroutine is definition
Expand Down Expand Up @@ -1224,7 +1224,7 @@ def mySubroutine(a1, a2):
definition = pt.SubroutineDefinition(mySubroutine, return_type)

evaluate_subroutine = SubroutineEval.normal_evaluator()
declaration = evaluate_subroutine(definition)
declaration = evaluate_subroutine.evaluate(definition)

assert isinstance(declaration, pt.SubroutineDeclaration)
assert declaration.subroutine is definition
Expand Down Expand Up @@ -1264,7 +1264,7 @@ def mySubroutine(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10):
definition = pt.SubroutineDefinition(mySubroutine, return_type)

evaluate_subroutine = SubroutineEval.normal_evaluator()
declaration = evaluate_subroutine(definition)
declaration = evaluate_subroutine.evaluate(definition)

assert isinstance(declaration, pt.SubroutineDeclaration)
assert declaration.subroutine is definition
Expand Down Expand Up @@ -1315,7 +1315,7 @@ def mySubroutine_arg_10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10):
definition = pt.SubroutineDefinition(subr, return_type)
evaluate_subroutine = SubroutineEval.fp_evaluator()

declaration = evaluate_subroutine(definition)
declaration = evaluate_subroutine.evaluate(definition)

assert isinstance(declaration, pt.SubroutineDeclaration)
assert declaration.subroutine is definition
Expand Down Expand Up @@ -1790,7 +1790,7 @@ def test_evaluate_subroutine_local_variables(test_case: LocalVariableTestCase):
(SubroutineEval.normal_evaluator(), test_case.expected_body_normal_evaluator),
(SubroutineEval.fp_evaluator(), test_case.expected_body_fp_evaluator),
):
declaration = evaluator(definition)
declaration = evaluator.evaluate(definition)

evaluator_type = "fp" if evaluator.use_frame_pt else "normal"
failure_msg = f"assertion failed for {evaluator_type} evaluator"
Expand Down Expand Up @@ -2044,3 +2044,78 @@ def test_frame_option_version_range_well_formed():
assert (
pt.Op.callsub.min_version < FRAME_POINTERS_VERSION < pt.MAX_PROGRAM_VERSION + 1
)


def test_new_abi_instance_from_storage():
current_proto = Proto(num_args=2, num_returns=1)

current_scratch_slot_id = pt.ScratchSlot.nextSlotId

arg_storage = FrameVar(current_proto, -1)
some_arg_from_proto = SubroutineEval._new_abi_instance_from_storage(
pt.abi.Uint64TypeSpec(),
arg_storage,
)

assert some_arg_from_proto._stored_value == arg_storage
assert current_scratch_slot_id == pt.ScratchSlot.nextSlotId

ret_storage = FrameVar(current_proto, 0)
ret_from_proto = SubroutineEval._new_abi_instance_from_storage(
pt.abi.AddressTypeSpec(),
ret_storage,
)

assert ret_from_proto._stored_value == ret_storage
assert current_scratch_slot_id == pt.ScratchSlot.nextSlotId


def test_subroutine_evaluation_local_allocation_correct():
foo = pt.abi.Uint64()

@pt.ABIReturnSubroutine
def get(
x: pt.abi.Uint64, y: pt.abi.Uint8, *, output: pt.abi.DynamicBytes
) -> pt.Expr:
return pt.Seq(
output.set(pt.Bytes("")),
)

@pt.ABIReturnSubroutine
def get_fie(y: pt.abi.Uint8, *, output: pt.abi.Uint64) -> pt.Expr:
data = pt.abi.make(pt.abi.DynamicBytes)
return pt.Seq(
data.set(get(foo, y)),
output.set(pt.Btoi(data.get())),
)

@pt.ABIReturnSubroutine
def set_(x: pt.abi.Uint64, y: pt.abi.Uint8) -> pt.Expr:
return pt.Seq()

router = pt.Router("Jane Doe")

@router.method
def fie(y: pt.abi.Uint8) -> pt.Expr:
old_amount = pt.abi.Uint64()

return pt.Seq(
old_amount.set(get_fie(y)),
set_(foo, y),
)

evaluator = SubroutineEval.fp_evaluator()

evaluated_fie = evaluator.evaluate(cast(pt.ABIReturnSubroutine, fie).subroutine)
layout_fie = cast(Proto, cast(pt.Seq, evaluated_fie.body).args[0]).mem_layout

assert len(layout_fie.local_stack_types) == 1

evaluated_get_fie = evaluator.evaluate(
cast(pt.ABIReturnSubroutine, get_fie).subroutine
)
layout_get_fie = cast(
Proto, cast(pt.Seq, evaluated_get_fie.body).args[0]
).mem_layout

assert len(layout_get_fie.local_stack_types) == 2
8 changes: 4 additions & 4 deletions tests/abi_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ def string_reverse(x: self.annotation, *, output: self.annotation): # type: ign

def tuple_comp_factory(self) -> pt.ABIReturnSubroutine: # type: ignore[name-defined]
value_type_specs: list[abi.TypeSpec] = self.type_spec.value_type_specs() # type: ignore[attr-defined]
insts = [vts.new_instance() for vts in value_type_specs]
roundtrips: list[ABIRoundtrip[T]] = [
ABIRoundtrip(inst, length=None) for inst in insts # type: ignore[arg-type]
]

@pt.ABIReturnSubroutine
def tuple_complement(x: self.annotation, *, output: self.annotation): # type: ignore[name-defined]
insts = [vts.new_instance() for vts in value_type_specs]
roundtrips: list[ABIRoundtrip[T]] = [
ABIRoundtrip(inst, length=None) for inst in insts # type: ignore[arg-type]
]
setters = [inst.set(x[i]) for i, inst in enumerate(insts)] # type: ignore[attr-defined]
comp_funcs = [rtrip.mutator_factory() for rtrip in roundtrips]
compers = [inst.set(comp_funcs[i](inst)) for i, inst in enumerate(insts)] # type: ignore[attr-defined]
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/abi_roundtrip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ class NamedTupleInherit(abi.NamedTuple):
],
2,
),
abi.Tuple1[
abi.Tuple3[
abi.Uint64,
abi.DynamicBytes,
abi.StaticArray[abi.Uint64, Literal[1]],
],
],
NamedTupleInherit,
]

Expand Down
Loading

0 comments on commit 42e6981

Please sign in to comment.