-
Notifications
You must be signed in to change notification settings - Fork 180
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6232 from onflow/ramtin/evm-refactor-precompiled-…
…call-tracker [Flow EVM] Refactoring precompiled contract call tracker
- Loading branch information
Showing
13 changed files
with
401 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package emulator | ||
|
||
import ( | ||
"bytes" | ||
"sort" | ||
|
||
"github.com/onflow/flow-go/fvm/evm/types" | ||
) | ||
|
||
// CallTracker captures precompiled calls | ||
type CallTracker struct { | ||
callsByAddress map[types.Address]*types.PrecompiledCalls | ||
} | ||
|
||
// NewCallTracker constructs a new CallTracker | ||
func NewCallTracker() *CallTracker { | ||
return &CallTracker{} | ||
} | ||
|
||
// RegisterPrecompiledContract registers a precompiled contract for tracking | ||
func (ct *CallTracker) RegisterPrecompiledContract(pc types.PrecompiledContract) types.PrecompiledContract { | ||
return &WrappedPrecompiledContract{ | ||
pc: pc, | ||
ct: ct, | ||
} | ||
} | ||
|
||
// CaptureRequiredGas captures a required gas call | ||
func (ct *CallTracker) CaptureRequiredGas(address types.Address, input []byte, output uint64) { | ||
if ct.callsByAddress == nil { | ||
ct.callsByAddress = make(map[types.Address]*types.PrecompiledCalls) | ||
} | ||
calls, found := ct.callsByAddress[address] | ||
if !found { | ||
calls = &types.PrecompiledCalls{ | ||
Address: address, | ||
} | ||
ct.callsByAddress[address] = calls | ||
} | ||
|
||
calls.RequiredGasCalls = append(calls.RequiredGasCalls, types.RequiredGasCall{ | ||
Input: input, | ||
Output: output, | ||
}) | ||
} | ||
|
||
// CaptureRun captures a run calls | ||
func (ct *CallTracker) CaptureRun(address types.Address, input []byte, output []byte, err error) { | ||
if ct.callsByAddress == nil { | ||
ct.callsByAddress = make(map[types.Address]*types.PrecompiledCalls) | ||
} | ||
calls, found := ct.callsByAddress[address] | ||
if !found { | ||
calls = &types.PrecompiledCalls{ | ||
Address: address, | ||
} | ||
ct.callsByAddress[address] = calls | ||
} | ||
errMsg := "" | ||
if err != nil { | ||
errMsg = err.Error() | ||
} | ||
calls.RunCalls = append(calls.RunCalls, types.RunCall{ | ||
Input: input, | ||
Output: output, | ||
ErrorMsg: errMsg, | ||
}) | ||
} | ||
|
||
// IsCalled returns true if any calls has been captured | ||
func (ct *CallTracker) IsCalled() bool { | ||
return len(ct.callsByAddress) != 0 | ||
} | ||
|
||
// Encoded | ||
func (ct *CallTracker) CapturedCalls() ([]byte, error) { | ||
if !ct.IsCalled() { | ||
return nil, nil | ||
} | ||
// else constructs an aggregated precompiled calls | ||
apc := make(types.AggregatedPrecompiledCalls, 0) | ||
|
||
sortedAddresses := make([]types.Address, 0, len(ct.callsByAddress)) | ||
// we need to sort by address to stay deterministic | ||
for addr := range ct.callsByAddress { | ||
sortedAddresses = append(sortedAddresses, addr) | ||
} | ||
|
||
sort.Slice(sortedAddresses, | ||
func(i, j int) bool { | ||
return bytes.Compare(sortedAddresses[i][:], sortedAddresses[j][:]) < 0 | ||
}) | ||
|
||
for _, addr := range sortedAddresses { | ||
apc = append(apc, *ct.callsByAddress[addr]) | ||
} | ||
|
||
return apc.Encode() | ||
} | ||
|
||
// Resets the tracker | ||
func (ct *CallTracker) Reset() { | ||
ct.callsByAddress = nil | ||
} | ||
|
||
type WrappedPrecompiledContract struct { | ||
pc types.PrecompiledContract | ||
ct *CallTracker | ||
} | ||
|
||
func (wpc *WrappedPrecompiledContract) Address() types.Address { | ||
return wpc.pc.Address() | ||
} | ||
func (wpc *WrappedPrecompiledContract) RequiredGas(input []byte) uint64 { | ||
output := wpc.pc.RequiredGas(input) | ||
wpc.ct.CaptureRequiredGas(wpc.pc.Address(), input, output) | ||
return output | ||
} | ||
|
||
func (wpc *WrappedPrecompiledContract) Run(input []byte) ([]byte, error) { | ||
output, err := wpc.pc.Run(input) | ||
wpc.ct.CaptureRun(wpc.pc.Address(), input, output, err) | ||
return output, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package emulator_test | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/onflow/flow-go/fvm/evm/emulator" | ||
"github.com/onflow/flow-go/fvm/evm/testutils" | ||
"github.com/onflow/flow-go/fvm/evm/types" | ||
) | ||
|
||
func TestTracker(t *testing.T) { | ||
apc := testutils.AggregatedPrecompiledCallsFixture(t) | ||
var runCallCounter int | ||
var requiredGasCallCounter int | ||
pc := &MockedPrecompiled{ | ||
AddressFunc: func() types.Address { | ||
return apc[0].Address | ||
}, | ||
RequiredGasFunc: func(input []byte) uint64 { | ||
res := apc[0].RequiredGasCalls[requiredGasCallCounter] | ||
require.Equal(t, res.Input, input) | ||
requiredGasCallCounter += 1 | ||
return res.Output | ||
}, | ||
RunFunc: func(input []byte) ([]byte, error) { | ||
res := apc[0].RunCalls[runCallCounter] | ||
require.Equal(t, res.Input, input) | ||
runCallCounter += 1 | ||
var err error | ||
if len(res.ErrorMsg) > 0 { | ||
err = errors.New(res.ErrorMsg) | ||
} | ||
return res.Output, err | ||
}, | ||
} | ||
tracker := emulator.NewCallTracker() | ||
wpc := tracker.RegisterPrecompiledContract(pc) | ||
|
||
require.Equal(t, apc[0].Address, wpc.Address()) | ||
for _, pc := range apc { | ||
for _, call := range pc.RequiredGasCalls { | ||
require.Equal(t, call.Output, wpc.RequiredGas(call.Input)) | ||
} | ||
for _, call := range pc.RunCalls { | ||
ret, err := wpc.Run(call.Input) | ||
require.Equal(t, call.Output, ret) | ||
errMsg := "" | ||
if err != nil { | ||
errMsg = err.Error() | ||
} | ||
require.Equal(t, call.ErrorMsg, errMsg) | ||
} | ||
|
||
} | ||
require.True(t, tracker.IsCalled()) | ||
|
||
expectedEncoded, err := apc.Encode() | ||
require.NoError(t, err) | ||
encoded, err := tracker.CapturedCalls() | ||
require.NoError(t, err) | ||
require.Equal(t, expectedEncoded, encoded) | ||
} |
Oops, something went wrong.