diff --git a/accounts/smart_account.go b/accounts/smart_account.go index cb105d3..81d308f 100644 --- a/accounts/smart_account.go +++ b/accounts/smart_account.go @@ -75,14 +75,39 @@ func (a *SmartAccount) Address() common.Address { return a.address } -// Balance returns the balance of the specified token that can be either ETH or any ERC20 token. +// Balance returns the balance of the specified token that can be either base token or any ERC20 token. // The block number can be nil, in which case the balance is taken from the latest known block. func (a *SmartAccount) Balance(ctx context.Context, token common.Address, at *big.Int) (*big.Int, error) { err := a.cacheData(ctx) if err != nil { return nil, err } - if token == utils.LegacyEthAddress || token == a.baseToken { + + if token == utils.LegacyEthAddress { + token = utils.EthAddressInContracts + } + + isEthBasedChain, err := a.client.IsEthBasedChain(ensureContext(ctx)) + if err != nil { + return nil, err + } + + if token == utils.EthAddressInContracts && !isEthBasedChain { + l2EthAddress, l2TokenAddressErr := a.client.L2TokenAddress(ensureContext(ctx), utils.EthAddressInContracts) + if l2TokenAddressErr != nil { + return nil, l2TokenAddressErr + } + token = l2EthAddress + } else if token == utils.EthAddressInContracts && isEthBasedChain { + token = utils.L2BaseTokenAddress + } + + isBaseToken, err := a.client.IsBaseToken(ctx, token) + if err != nil { + return nil, err + } + + if isBaseToken { return a.client.BalanceAt(ensureContext(ctx), a.Address(), at) } erc20Token, err := erc20.NewIERC20(token, a.client) diff --git a/accounts/wallet_l1.go b/accounts/wallet_l1.go index 082a625..21ee09c 100644 --- a/accounts/wallet_l1.go +++ b/accounts/wallet_l1.go @@ -183,6 +183,7 @@ func (a *WalletL1) IsEthBasedChain(ctx context.Context) (bool, error) { // BalanceL1 returns the balance of the specified token on L1 that can be // either ETH or any ERC20 token. func (a *WalletL1) BalanceL1(opts *CallOpts, token common.Address) (*big.Int, error) { + // TODO fix this callOpts := ensureCallOpts(opts).ToCallOpts(a.auth.From) if token == utils.LegacyEthAddress || token == utils.L2BaseTokenAddress || token == utils.EthAddressInContracts { return a.clientL1.BalanceAt(callOpts.Context, a.auth.From, callOpts.BlockNumber) diff --git a/accounts/wallet_l2.go b/accounts/wallet_l2.go index f79db9d..9bd9601 100644 --- a/accounts/wallet_l2.go +++ b/accounts/wallet_l2.go @@ -106,6 +106,25 @@ func (a *WalletL2) Signer() Signer { // Balance returns the balance of the specified token that can be either base token or any ERC20 token. // The block number can be nil, in which case the balance is taken from the latest known block. func (a *WalletL2) Balance(ctx context.Context, token common.Address, at *big.Int) (*big.Int, error) { + if token == utils.LegacyEthAddress { + token = utils.EthAddressInContracts + } + + isEthBasedChain, err := a.client.IsEthBasedChain(ensureContext(ctx)) + if err != nil { + return nil, err + } + + if token == utils.EthAddressInContracts && !isEthBasedChain { + l2EthAddress, l2TokenAddressErr := a.client.L2TokenAddress(ensureContext(ctx), utils.EthAddressInContracts) + if l2TokenAddressErr != nil { + return nil, l2TokenAddressErr + } + token = l2EthAddress + } else if token == utils.EthAddressInContracts && isEthBasedChain { + token = utils.L2BaseTokenAddress + } + isBaseToken, err := a.IsBaseToken(ctx, token) if err != nil { return nil, err diff --git a/test/setup_test.go b/test/setup_test.go index 9bc35b5..b22c601 100644 --- a/test/setup_test.go +++ b/test/setup_test.go @@ -108,6 +108,23 @@ func deployMultisigAccount(wallet *accounts.Wallet, client clients.Client) { if err != nil { log.Fatal(err) } + + if !IsEthBasedChain { + // transfer base token to multisig account + transferTx, err = wallet.Transfer(nil, accounts.TransferTransaction{ + To: multisigAccountAddress, + Amount: big.NewInt(2_000_000_000_000_000_000), + Token: utils.L2BaseTokenAddress, + }) + if err != nil { + log.Fatal(err) + } + + _, err = client.WaitMined(context.Background(), transferTx.Hash()) + if err != nil { + log.Fatal(err) + } + } } func deployPaymasterAndToken(wallet *accounts.Wallet, client clients.Client) { diff --git a/test/smart_account_test.go b/test/smart_account_test.go index 824142e..4f93dc7 100644 --- a/test/smart_account_test.go +++ b/test/smart_account_test.go @@ -673,18 +673,19 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEth(t *testing.T) { defer client.Close() assert.NoError(t, err, "clients.DialBase should not return an error") - account := accounts.NewECDSASmartAccount(Address1, PrivateKey1, client) + sender := accounts.NewECDSASmartAccount(Address1, PrivateKey1, client) + receiver := accounts.NewECDSASmartAccount(Address2, PrivateKey2, client) l2EthAddress, err := client.L2TokenAddress(context.Background(), utils.EthAddressInContracts) assert.NoError(t, err, "L2TokenAddress should not return an error") - balanceBeforeTransferSender, err := account.Balance(context.Background(), l2EthAddress, nil) + balanceBeforeTransferSender, err := sender.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "Balance should not return an error") balanceBeforeTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) assert.NoError(t, err, "BalanceAt should not return an error") - txHash, err := account.Transfer(nil, accounts.TransferTransaction{ + txHash, err := sender.Transfer(nil, accounts.TransferTransaction{ To: Address2, Amount: amount, Token: utils.LegacyEthAddress, // l2EthAddress @@ -695,14 +696,14 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEth(t *testing.T) { assert.NoError(t, err, "client.WaitMined should not return an error") assert.NotNil(t, receipt.BlockHash, "Transaction should be mined") - balanceAfterTransferSender, err := account.Balance(context.Background(), l2EthAddress, nil) + balanceAfterTransferSender, err := sender.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "Balance should not return an error") - balanceAfterTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceAfterTransferReceiver, err := receiver.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") assert.True(t, new(big.Int).Sub(balanceBeforeTransferSender, balanceAfterTransferSender).Cmp(amount) >= 0, "Sender balance should be decreased") - assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Address2 balance should be increased") + assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Receiver balance should be increased") } func TestIntegration_EthBasedChain_SmartAccount_TransferEthUsingPaymaster(t *testing.T) { @@ -786,7 +787,8 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEthUsingPaymaster(t * defer client.Close() assert.NoError(t, err, "clients.DialBase should not return an error") - account := accounts.NewECDSASmartAccount(Address1, PrivateKey1, client) + sender := accounts.NewECDSASmartAccount(Address1, PrivateKey1, client) + receiver := accounts.NewECDSASmartAccount(Address2, PrivateKey2, client) l2EthAddress, err := client.L2TokenAddress(context.Background(), utils.EthAddressInContracts) assert.NoError(t, err, "L2TokenAddress should not return an error") @@ -794,13 +796,13 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEthUsingPaymaster(t * approvalToken, err := erc20.NewIERC20(ApprovalToken, client) assert.NoError(t, err, "NewIERC20 should not return an error") - balanceBeforeTransferSender, err := account.Balance(context.Background(), l2EthAddress, nil) + balanceBeforeTransferSender, err := sender.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "Balance should not return an error") approvalTokenBalanceBeforeTransferSender, err := approvalToken.BalanceOf(nil, Address1) assert.NoError(t, err, "BalanceOf should not return an error") - balanceBeforeTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceBeforeTransferReceiver, err := receiver.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") balanceBeforeTransferPaymaster, err := client.BalanceAt(context.Background(), Paymaster, nil) @@ -818,7 +820,7 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEthUsingPaymaster(t * }) assert.NoError(t, err, "GetPaymasterParams should not return an error") - txHash, err := account.Transfer(nil, accounts.TransferTransaction{ + txHash, err := sender.Transfer(nil, accounts.TransferTransaction{ To: Address2, Amount: amount, Token: utils.LegacyEthAddress, // or l2EthAddress @@ -830,13 +832,13 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEthUsingPaymaster(t * assert.NoError(t, err, "client.WaitMined should not return an error") assert.NotNil(t, receipt.BlockHash, "Transaction should be mined") - balanceAfterTransferSender, err := account.Balance(context.Background(), l2EthAddress, nil) + balanceAfterTransferSender, err := sender.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "Balance should not return an error") approvalTokenBalanceAfterTransferSender, err := approvalToken.BalanceOf(nil, Address1) assert.NoError(t, err, "BalanceOf should not return an error") - balanceAfterTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceAfterTransferReceiver, err := receiver.Balance(context.Background(), l2EthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") balanceAfterTransferPaymaster, err := client.BalanceAt(context.Background(), Paymaster, nil) @@ -851,7 +853,7 @@ func TestIntegration_NonEthBasedChain_SmartAccount_TransferEthUsingPaymaster(t * assert.True(t, new(big.Int).Sub(balanceBeforeTransferSender, balanceAfterTransferSender).Cmp(amount) >= 0, "Sender balance should be decreased") assert.True(t, new(big.Int).Sub(approvalTokenBalanceBeforeTransferSender, minimalAllowance).Cmp(approvalTokenBalanceAfterTransferSender) == 0, "Sender approval token balance should be decreased") - assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Address2 balance should be increased") + assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Receiver balance should be increased") } func TestIntegration_NonEthBasedChain_SmartAccount_TransferBaseToken(t *testing.T) { @@ -1467,22 +1469,23 @@ func TestIntegrationMultisigSmartAccount_WithdrawTokenUsingPaymaster(t *testing. assert.True(t, new(big.Int).Sub(approvalTokenBalanceBeforeWithdrawal, minimalAllowance).Cmp(approvalTokenBalanceAfterWithdrawal) == 0, "Sender approval token balance should be decreased") } -func TestIntegrationMultisigSmartAccount_Transfer(t *testing.T) { +func TestIntegrationMultisigSmartAccount_TransferEth(t *testing.T) { amount := big.NewInt(7_000_000_000) client, err := clients.DialBase(L2ChainURL) defer client.Close() assert.NoError(t, err, "clients.DialBase should not return an error") - account := accounts.NewMultisigECDSASmartAccount(MultisigAccount, []string{PrivateKey1, PrivateKey2}, client) + sender := accounts.NewMultisigECDSASmartAccount(MultisigAccount, []string{PrivateKey1, PrivateKey2}, client) + receiver := accounts.NewECDSASmartAccount(Address2, PrivateKey2, client) - balanceBeforeTransferSender, err := account.Balance(context.Background(), utils.LegacyEthAddress, nil) + balanceBeforeTransferSender, err := sender.Balance(context.Background(), utils.LegacyEthAddress, nil) assert.NoError(t, err, "Balance should not return an error") - balanceBeforeTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceBeforeTransferReceiver, err := receiver.Balance(context.Background(), utils.LegacyEthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") - txHash, err := account.Transfer(nil, accounts.TransferTransaction{ + txHash, err := sender.Transfer(nil, accounts.TransferTransaction{ To: Address2, Amount: amount, Token: utils.LegacyEthAddress, @@ -1493,17 +1496,17 @@ func TestIntegrationMultisigSmartAccount_Transfer(t *testing.T) { assert.NoError(t, err, "client.WaitMined should not return an error") assert.NotNil(t, receipt.BlockHash, "Transaction should be mined") - balanceAfterTransferSender, err := account.Balance(context.Background(), utils.LegacyEthAddress, nil) + balanceAfterTransferSender, err := sender.Balance(context.Background(), utils.LegacyEthAddress, nil) assert.NoError(t, err, "Balance should not return an error") - balanceAfterTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceAfterTransferReceiver, err := receiver.Balance(context.Background(), utils.LegacyEthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") assert.True(t, new(big.Int).Sub(balanceBeforeTransferSender, balanceAfterTransferSender).Cmp(amount) >= 0, "Sender balance should be decreased") - assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Address2 balance should be increased") + assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Receiver balance should be increased") } -func TestIntegrationMultisigSmartAccount_TransferUsingPaymaster(t *testing.T) { +func TestIntegrationMultisigSmartAccount_TransferEthUsingPaymaster(t *testing.T) { amount := big.NewInt(7_000_000_000) minimalAllowance := big.NewInt(1) @@ -1512,6 +1515,7 @@ func TestIntegrationMultisigSmartAccount_TransferUsingPaymaster(t *testing.T) { assert.NoError(t, err, "clients.DialBase should not return an error") account := accounts.NewMultisigECDSASmartAccount(MultisigAccount, []string{PrivateKey1, PrivateKey2}, client) + receiver := accounts.NewECDSASmartAccount(Address2, PrivateKey2, client) approvalToken, err := erc20.NewIERC20(ApprovalToken, client) assert.NoError(t, err, "NewIERC20 should not return an error") @@ -1522,7 +1526,7 @@ func TestIntegrationMultisigSmartAccount_TransferUsingPaymaster(t *testing.T) { approvalTokenBalanceBeforeTransferSender, err := approvalToken.BalanceOf(nil, MultisigAccount) assert.NoError(t, err, "BalanceOf should not return an error") - balanceBeforeTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceBeforeTransferReceiver, err := receiver.Balance(context.Background(), utils.LegacyEthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") balanceBeforeTransferPaymaster, err := client.BalanceAt(context.Background(), Paymaster, nil) @@ -1558,7 +1562,7 @@ func TestIntegrationMultisigSmartAccount_TransferUsingPaymaster(t *testing.T) { approvalTokenBalanceAfterTransferSender, err := approvalToken.BalanceOf(nil, MultisigAccount) assert.NoError(t, err, "BalanceOf should not return an error") - balanceAfterTransferReceiver, err := client.BalanceAt(context.Background(), Address2, nil) + balanceAfterTransferReceiver, err := receiver.Balance(context.Background(), utils.LegacyEthAddress, nil) assert.NoError(t, err, "BalanceAt should not return an error") balanceAfterTransferPaymaster, err := client.BalanceAt(context.Background(), Paymaster, nil) @@ -1573,7 +1577,7 @@ func TestIntegrationMultisigSmartAccount_TransferUsingPaymaster(t *testing.T) { assert.True(t, new(big.Int).Sub(balanceBeforeTransferSender, balanceAfterTransferSender).Cmp(amount) >= 0, "Sender balance should be decreased") assert.True(t, new(big.Int).Sub(approvalTokenBalanceBeforeTransferSender, minimalAllowance).Cmp(approvalTokenBalanceAfterTransferSender) == 0, "Sender approval token balance should be decreased") - assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Address2 balance should be increased") + assert.True(t, new(big.Int).Sub(balanceAfterTransferReceiver, balanceBeforeTransferReceiver).Cmp(amount) >= 0, "Receiver balance should be increased") } func TestIntegrationMultisigSmartAccount_TransferToken(t *testing.T) { diff --git a/test/wallet_test.go b/test/wallet_test.go index 588a78c..36ba954 100644 --- a/test/wallet_test.go +++ b/test/wallet_test.go @@ -516,13 +516,13 @@ func TestIntegration_NonEthBasedChain_Wallet_WithdrawBaseToken(t *testing.T) { baseToken, err := wallet.BaseToken(nil) assert.NoError(t, err, "BaseToken should not return an error") - l2BalanceBeforeWithdrawal, err := wallet.Balance(context.Background(), baseToken, nil) + l2BalanceBeforeWithdrawal, err := wallet.Balance(context.Background(), utils.L2BaseTokenAddress, nil) assert.NoError(t, err, "Balance should not return an error") withdrawTx, err := wallet.Withdraw(nil, accounts.WithdrawalTransaction{ To: wallet.Address(), Amount: amount, - Token: utils.LegacyEthAddress, + Token: utils.L2BaseTokenAddress, }) assert.NoError(t, err, "Withdraw should not return an error")