Skip to content

Commit

Permalink
feat: add validation for multi message execution wasm (#2092)
Browse files Browse the repository at this point in the history
* feat: add validation for multi message execution wasm

* chore: changelog

* merge conflicts + simplify

---------

Co-authored-by: Unique-Divine <[email protected]>
  • Loading branch information
matthiasmatt and Unique-Divine authored Oct 30, 2024
1 parent 9d0dbd5 commit 7b7beb7
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ reverts inside of a try-catch.
for (1) ERC20 transfers with tokens that return false success values instead of
throwing an error and (2) ERC20 transfers with other operations that don't bring
about the expected resulting balance for the transfer recipient.
- [#2092](https://github.com/NibiruChain/nibiru/pull/2092) - feat(evm): add validation for wasm multi message execution

#### Nibiru EVM | Before Audit 1 - 2024-10-18

Expand Down
12 changes: 9 additions & 3 deletions x/evm/precompile/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/NibiruChain/nibiru/v2/x/evm/embeds"

wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper"
wasm "github.com/CosmWasm/wasmd/x/wasm/types"
gethabi "github.com/ethereum/go-ethereum/accounts/abi"
gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/vm"
Expand Down Expand Up @@ -275,10 +276,15 @@ func (p precompileWasm) executeMulti(
callerBech32 := eth.EthAddrToNibiruAddr(caller)

var responses [][]byte
for _, m := range wasmExecMsgs {
for i, m := range wasmExecMsgs {
wasmContract, e := sdk.AccAddressFromBech32(m.ContractAddr)
if e != nil {
err = fmt.Errorf("Execute failed: %w", e)
err = fmt.Errorf("Execute failed at index %d: %w", i, e)
return
}
msgArgsCopy := wasm.RawContractMessage(m.MsgArgs)
if e := msgArgsCopy.ValidateBasic(); e != nil {
err = fmt.Errorf("Execute failed at index %d: error parsing msg args: %w", i, e)
return
}
var funds sdk.Coins
Expand All @@ -290,7 +296,7 @@ func (p precompileWasm) executeMulti(
}
respBz, e := p.Wasm.Execute(ctx, wasmContract, callerBech32, m.MsgArgs, funds)
if e != nil {
err = e
err = fmt.Errorf("Execute failed at index %d: %w", i, e)
return
}
responses = append(responses, respBz)
Expand Down
183 changes: 183 additions & 0 deletions x/evm/precompile/wasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
wasm "github.com/CosmWasm/wasmd/x/wasm/types"

"github.com/NibiruChain/nibiru/v2/x/common/testutil"
"github.com/NibiruChain/nibiru/v2/x/common/testutil/testapp"
"github.com/NibiruChain/nibiru/v2/x/evm/embeds"
"github.com/NibiruChain/nibiru/v2/x/evm/evmtest"
"github.com/NibiruChain/nibiru/v2/x/evm/precompile"
Expand Down Expand Up @@ -313,3 +314,185 @@ func (s *WasmSuite) TestSadArgsExecute() {
})
}
}

type WasmExecuteMsg struct {
ContractAddr string `json:"contractAddr"`
MsgArgs []byte `json:"msgArgs"`
Funds []precompile.WasmBankCoin `json:"funds"`
}

func (s *WasmSuite) TestExecuteMultiValidation() {
deps := evmtest.NewTestDeps()

s.Require().NoError(testapp.FundAccount(
deps.App.BankKeeper,
deps.Ctx,
deps.Sender.NibiruAddr,
sdk.NewCoins(sdk.NewCoin("unibi", sdk.NewInt(100))),
))

wasmContracts := test.SetupWasmContracts(&deps, &s.Suite)
wasmContract := wasmContracts[1] // hello_world_counter.wasm

invalidMsgArgsBz := []byte(`{"invalid": "json"}`) // Invalid message format
validMsgArgsBz := []byte(`{"increment": {}}`) // Valid increment message

var emptyFunds []precompile.WasmBankCoin
validFunds := []precompile.WasmBankCoin{{
Denom: "unibi",
Amount: big.NewInt(100),
}}
invalidFunds := []precompile.WasmBankCoin{{
Denom: "invalid!denom",
Amount: big.NewInt(100),
}}

testCases := []struct {
name string
executeMsgs []WasmExecuteMsg
wantError string
}{
{
name: "valid - single message",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "",
},
{
name: "valid - multiple messages",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: validFunds,
},
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "",
},
{
name: "invalid - malformed contract address",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: "invalid-address",
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "decoding bech32 failed",
},
{
name: "invalid - malformed message args",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: invalidMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "unknown variant",
},
{
name: "invalid - malformed funds",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: invalidFunds,
},
},
wantError: "invalid coins",
},
{
name: "invalid - second message fails validation",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
{
ContractAddr: wasmContract.String(),
MsgArgs: invalidMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "unknown variant",
},
}

for _, tc := range testCases {
s.Run(tc.name, func() {
callArgs := []any{tc.executeMsgs}
input, err := embeds.SmartContract_Wasm.ABI.Pack(
string(precompile.WasmMethod_executeMulti),
callArgs...,
)
s.Require().NoError(err)

ethTxResp, _, err := deps.EvmKeeper.CallContractWithInput(
deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input,
)

if tc.wantError != "" {
s.Require().ErrorContains(err, tc.wantError)
return
}
s.Require().NoError(err)
s.NotNil(ethTxResp)
s.NotEmpty(ethTxResp.Ret)
})
}
}

// TestExecuteMultiPartialExecution ensures that no state changes occur if any message
// in the batch fails validation
func (s *WasmSuite) TestExecuteMultiPartialExecution() {
deps := evmtest.NewTestDeps()
wasmContracts := test.SetupWasmContracts(&deps, &s.Suite)
wasmContract := wasmContracts[1] // hello_world_counter.wasm

// First verify initial state is 0
test.AssertWasmCounterState(&s.Suite, deps, wasmContract, 0)

// Create a batch where the second message will fail validation
executeMsgs := []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: []byte(`{"increment": {}}`),
Funds: []precompile.WasmBankCoin{},
},
{
ContractAddr: wasmContract.String(),
MsgArgs: []byte(`{"invalid": "json"}`), // This will fail validation
Funds: []precompile.WasmBankCoin{},
},
}

callArgs := []any{executeMsgs}
input, err := embeds.SmartContract_Wasm.ABI.Pack(
string(precompile.WasmMethod_executeMulti),
callArgs...,
)
s.Require().NoError(err)

ethTxResp, _, err := deps.EvmKeeper.CallContractWithInput(
deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input,
)

// Verify that the call failed
s.Require().Error(err, "ethTxResp: ", ethTxResp)
s.Require().Contains(err.Error(), "unknown variant")

// Verify that no state changes occurred
test.AssertWasmCounterState(&s.Suite, deps, wasmContract, 0)
}
3 changes: 0 additions & 3 deletions x/evm/statedb/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,6 @@ func (s *StateDB) Snapshot() int {

// RevertToSnapshot reverts all state changes made since the given revision.
func (s *StateDB) RevertToSnapshot(revid int) {
fmt.Printf("len(s.validRevisions): %d\n", len(s.validRevisions))
fmt.Printf("s.validRevisions: %v\n", s.validRevisions)

// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(s.validRevisions), func(i int) bool {
return s.validRevisions[i].id >= revid
Expand Down

0 comments on commit 7b7beb7

Please sign in to comment.