From e83ba7df6b276cba4555f99a1687551c55a884ee Mon Sep 17 00:00:00 2001 From: Kingter <83567446+kingster-will@users.noreply.github.com> Date: Fri, 21 Jun 2024 06:03:04 -0700 Subject: [PATCH] Introducing IP Graph Support --- core/genesis.go | 1 + core/vm/contracts.go | 78 ++++---- core/vm/contracts_fuzz_test.go | 12 +- core/vm/contracts_test.go | 36 +++- core/vm/evm.go | 10 +- core/vm/ipgraph.go | 339 +++++++++++++++++++++++++++++++++ 6 files changed, 430 insertions(+), 46 deletions(-) create mode 100644 core/vm/ipgraph.go diff --git a/core/genesis.go b/core/genesis.go index 4ca24807fccd..1d27f028ffac 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -593,6 +593,7 @@ func DeveloperGenesisBlock(gasLimit uint64, faucet *common.Address) *Genesis { common.BytesToAddress([]byte{7}): {Balance: big.NewInt(1)}, // ECScalarMul common.BytesToAddress([]byte{8}): {Balance: big.NewInt(1)}, // ECPairing common.BytesToAddress([]byte{9}): {Balance: big.NewInt(1)}, // BLAKE2b + common.BytesToAddress([]byte{26}): {Balance: big.NewInt(1)}, // ipGraph // Pre-deploy EIP-4788 system contract params.BeaconRootsAddress: {Nonce: 1, Code: params.BeaconRootsCode, Balance: common.Big0}, }, diff --git a/core/vm/contracts.go b/core/vm/contracts.go index dd71a9729f34..e61f7137bfff 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/ethereum/go-ethereum/crypto" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -30,10 +31,10 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core/tracing" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/blake2b" "github.com/ethereum/go-ethereum/crypto/bn256" "github.com/ethereum/go-ethereum/crypto/kzg4844" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "golang.org/x/crypto/ripemd160" ) @@ -42,8 +43,8 @@ import ( // requires a deterministic gas count based on the input size of the Run method of the // contract. type PrecompiledContract interface { - RequiredGas(input []byte) uint64 // RequiredPrice calculates the contract gas use - Run(input []byte) ([]byte, error) // Run runs the precompiled contract + RequiredGas(input []byte) uint64 // RequiredPrice calculates the contract gas use + Run(evm *EVM, input []byte) ([]byte, error) // Run runs the precompiled contract } // PrecompiledContractsHomestead contains the default set of pre-compiled Ethereum @@ -99,16 +100,17 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ // PrecompiledContractsCancun contains the default set of pre-compiled Ethereum // contracts used in the Cancun release. var PrecompiledContractsCancun = map[common.Address]PrecompiledContract{ - common.BytesToAddress([]byte{0x1}): &ecrecover{}, - common.BytesToAddress([]byte{0x2}): &sha256hash{}, - common.BytesToAddress([]byte{0x3}): &ripemd160hash{}, - common.BytesToAddress([]byte{0x4}): &dataCopy{}, - common.BytesToAddress([]byte{0x5}): &bigModExp{eip2565: true}, - common.BytesToAddress([]byte{0x6}): &bn256AddIstanbul{}, - common.BytesToAddress([]byte{0x7}): &bn256ScalarMulIstanbul{}, - common.BytesToAddress([]byte{0x8}): &bn256PairingIstanbul{}, - common.BytesToAddress([]byte{0x9}): &blake2F{}, - common.BytesToAddress([]byte{0xa}): &kzgPointEvaluation{}, + common.BytesToAddress([]byte{0x1}): &ecrecover{}, + common.BytesToAddress([]byte{0x2}): &sha256hash{}, + common.BytesToAddress([]byte{0x3}): &ripemd160hash{}, + common.BytesToAddress([]byte{0x4}): &dataCopy{}, + common.BytesToAddress([]byte{0x5}): &bigModExp{eip2565: true}, + common.BytesToAddress([]byte{0x6}): &bn256AddIstanbul{}, + common.BytesToAddress([]byte{0x7}): &bn256ScalarMulIstanbul{}, + common.BytesToAddress([]byte{0x8}): &bn256PairingIstanbul{}, + common.BytesToAddress([]byte{0x9}): &blake2F{}, + common.BytesToAddress([]byte{0xa}): &kzgPointEvaluation{}, + common.BytesToAddress([]byte{0x1a}): &ipGraph{}, } // PrecompiledContractsPrague contains the set of pre-compiled Ethereum @@ -192,7 +194,8 @@ func ActivePrecompiles(rules params.Rules) []common.Address { // - the returned bytes, // - the _remaining_ gas, // - any error that occurred -func RunPrecompiledContract(p PrecompiledContract, input []byte, suppliedGas uint64, logger *tracing.Hooks) (ret []byte, remainingGas uint64, err error) { +func RunPrecompiledContract(evm *EVM, p PrecompiledContract, input []byte, suppliedGas uint64, logger *tracing.Hooks) (ret []byte, remainingGas uint64, err error) { + log.Info("RunPrecompiledContract", "input", input, "suppliedGas", suppliedGas) gasCost := p.RequiredGas(input) if suppliedGas < gasCost { return nil, 0, ErrOutOfGas @@ -201,7 +204,8 @@ func RunPrecompiledContract(p PrecompiledContract, input []byte, suppliedGas uin logger.OnGasChange(suppliedGas, suppliedGas-gasCost, tracing.GasChangeCallPrecompiledContract) } suppliedGas -= gasCost - output, err := p.Run(input) + output, err := p.Run(evm, input) + log.Info("RunPrecompiledContract", "output", output, "err", err) return output, suppliedGas, err } @@ -212,7 +216,7 @@ func (c *ecrecover) RequiredGas(input []byte) uint64 { return params.EcrecoverGas } -func (c *ecrecover) Run(input []byte) ([]byte, error) { +func (c *ecrecover) Run(evm *EVM, input []byte) ([]byte, error) { const ecRecoverInputLength = 128 input = common.RightPadBytes(input, ecRecoverInputLength) @@ -253,7 +257,7 @@ type sha256hash struct{} func (c *sha256hash) RequiredGas(input []byte) uint64 { return uint64(len(input)+31)/32*params.Sha256PerWordGas + params.Sha256BaseGas } -func (c *sha256hash) Run(input []byte) ([]byte, error) { +func (c *sha256hash) Run(evm *EVM, input []byte) ([]byte, error) { h := sha256.Sum256(input) return h[:], nil } @@ -268,7 +272,7 @@ type ripemd160hash struct{} func (c *ripemd160hash) RequiredGas(input []byte) uint64 { return uint64(len(input)+31)/32*params.Ripemd160PerWordGas + params.Ripemd160BaseGas } -func (c *ripemd160hash) Run(input []byte) ([]byte, error) { +func (c *ripemd160hash) Run(evm *EVM, input []byte) ([]byte, error) { ripemd := ripemd160.New() ripemd.Write(input) return common.LeftPadBytes(ripemd.Sum(nil), 32), nil @@ -284,7 +288,7 @@ type dataCopy struct{} func (c *dataCopy) RequiredGas(input []byte) uint64 { return uint64(len(input)+31)/32*params.IdentityPerWordGas + params.IdentityBaseGas } -func (c *dataCopy) Run(in []byte) ([]byte, error) { +func (c *dataCopy) Run(evm *EVM, in []byte) ([]byte, error) { return common.CopyBytes(in), nil } @@ -406,7 +410,7 @@ func (c *bigModExp) RequiredGas(input []byte) uint64 { return gas.Uint64() } -func (c *bigModExp) Run(input []byte) ([]byte, error) { +func (c *bigModExp) Run(evm *EVM, input []byte) ([]byte, error) { var ( baseLen = new(big.Int).SetBytes(getData(input, 0, 32)).Uint64() expLen = new(big.Int).SetBytes(getData(input, 32, 32)).Uint64() @@ -486,7 +490,7 @@ func (c *bn256AddIstanbul) RequiredGas(input []byte) uint64 { return params.Bn256AddGasIstanbul } -func (c *bn256AddIstanbul) Run(input []byte) ([]byte, error) { +func (c *bn256AddIstanbul) Run(evm *EVM, input []byte) ([]byte, error) { return runBn256Add(input) } @@ -499,7 +503,7 @@ func (c *bn256AddByzantium) RequiredGas(input []byte) uint64 { return params.Bn256AddGasByzantium } -func (c *bn256AddByzantium) Run(input []byte) ([]byte, error) { +func (c *bn256AddByzantium) Run(evm *EVM, input []byte) ([]byte, error) { return runBn256Add(input) } @@ -524,7 +528,7 @@ func (c *bn256ScalarMulIstanbul) RequiredGas(input []byte) uint64 { return params.Bn256ScalarMulGasIstanbul } -func (c *bn256ScalarMulIstanbul) Run(input []byte) ([]byte, error) { +func (c *bn256ScalarMulIstanbul) Run(evm *EVM, input []byte) ([]byte, error) { return runBn256ScalarMul(input) } @@ -537,7 +541,7 @@ func (c *bn256ScalarMulByzantium) RequiredGas(input []byte) uint64 { return params.Bn256ScalarMulGasByzantium } -func (c *bn256ScalarMulByzantium) Run(input []byte) ([]byte, error) { +func (c *bn256ScalarMulByzantium) Run(evm *EVM, input []byte) ([]byte, error) { return runBn256ScalarMul(input) } @@ -592,7 +596,7 @@ func (c *bn256PairingIstanbul) RequiredGas(input []byte) uint64 { return params.Bn256PairingBaseGasIstanbul + uint64(len(input)/192)*params.Bn256PairingPerPointGasIstanbul } -func (c *bn256PairingIstanbul) Run(input []byte) ([]byte, error) { +func (c *bn256PairingIstanbul) Run(evm *EVM, input []byte) ([]byte, error) { return runBn256Pairing(input) } @@ -605,7 +609,7 @@ func (c *bn256PairingByzantium) RequiredGas(input []byte) uint64 { return params.Bn256PairingBaseGasByzantium + uint64(len(input)/192)*params.Bn256PairingPerPointGasByzantium } -func (c *bn256PairingByzantium) Run(input []byte) ([]byte, error) { +func (c *bn256PairingByzantium) Run(evm *EVM, input []byte) ([]byte, error) { return runBn256Pairing(input) } @@ -631,7 +635,7 @@ var ( errBlake2FInvalidFinalFlag = errors.New("invalid final flag") ) -func (c *blake2F) Run(input []byte) ([]byte, error) { +func (c *blake2F) Run(evm *EVM, input []byte) ([]byte, error) { // Make sure the input is valid (correct length and final flag) if len(input) != blake2FInputLength { return nil, errBlake2FInvalidInputLength @@ -685,7 +689,7 @@ func (c *bls12381G1Add) RequiredGas(input []byte) uint64 { return params.Bls12381G1AddGas } -func (c *bls12381G1Add) Run(input []byte) ([]byte, error) { +func (c *bls12381G1Add) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 G1Add precompile. // > G1 addition call expects `256` bytes as an input that is interpreted as byte concatenation of two G1 points (`128` bytes each). // > Output is an encoding of addition operation result - single G1 point (`128` bytes). @@ -721,7 +725,7 @@ func (c *bls12381G1Mul) RequiredGas(input []byte) uint64 { return params.Bls12381G1MulGas } -func (c *bls12381G1Mul) Run(input []byte) ([]byte, error) { +func (c *bls12381G1Mul) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 G1Mul precompile. // > G1 multiplication call expects `160` bytes as an input that is interpreted as byte concatenation of encoding of G1 point (`128` bytes) and encoding of a scalar value (`32` bytes). // > Output is an encoding of multiplication operation result - single G1 point (`128` bytes). @@ -773,7 +777,7 @@ func (c *bls12381G1MultiExp) RequiredGas(input []byte) uint64 { return (uint64(k) * params.Bls12381G1MulGas * discount) / 1000 } -func (c *bls12381G1MultiExp) Run(input []byte) ([]byte, error) { +func (c *bls12381G1MultiExp) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 G1MultiExp precompile. // G1 multiplication call expects `160*k` bytes as an input that is interpreted as byte concatenation of `k` slices each of them being a byte concatenation of encoding of G1 point (`128` bytes) and encoding of a scalar value (`32` bytes). // Output is an encoding of multiexponentiation operation result - single G1 point (`128` bytes). @@ -819,7 +823,7 @@ func (c *bls12381G2Add) RequiredGas(input []byte) uint64 { return params.Bls12381G2AddGas } -func (c *bls12381G2Add) Run(input []byte) ([]byte, error) { +func (c *bls12381G2Add) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 G2Add precompile. // > G2 addition call expects `512` bytes as an input that is interpreted as byte concatenation of two G2 points (`256` bytes each). // > Output is an encoding of addition operation result - single G2 point (`256` bytes). @@ -856,7 +860,7 @@ func (c *bls12381G2Mul) RequiredGas(input []byte) uint64 { return params.Bls12381G2MulGas } -func (c *bls12381G2Mul) Run(input []byte) ([]byte, error) { +func (c *bls12381G2Mul) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 G2MUL precompile logic. // > G2 multiplication call expects `288` bytes as an input that is interpreted as byte concatenation of encoding of G2 point (`256` bytes) and encoding of a scalar value (`32` bytes). // > Output is an encoding of multiplication operation result - single G2 point (`256` bytes). @@ -908,7 +912,7 @@ func (c *bls12381G2MultiExp) RequiredGas(input []byte) uint64 { return (uint64(k) * params.Bls12381G2MulGas * discount) / 1000 } -func (c *bls12381G2MultiExp) Run(input []byte) ([]byte, error) { +func (c *bls12381G2MultiExp) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 G2MultiExp precompile logic // > G2 multiplication call expects `288*k` bytes as an input that is interpreted as byte concatenation of `k` slices each of them being a byte concatenation of encoding of G2 point (`256` bytes) and encoding of a scalar value (`32` bytes). // > Output is an encoding of multiexponentiation operation result - single G2 point (`256` bytes). @@ -954,7 +958,7 @@ func (c *bls12381Pairing) RequiredGas(input []byte) uint64 { return params.Bls12381PairingBaseGas + uint64(len(input)/384)*params.Bls12381PairingPerPairGas } -func (c *bls12381Pairing) Run(input []byte) ([]byte, error) { +func (c *bls12381Pairing) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 Pairing precompile logic. // > Pairing call expects `384*k` bytes as an inputs that is interpreted as byte concatenation of `k` slices. Each slice has the following structure: // > - `128` bytes of G1 point encoding @@ -1106,7 +1110,7 @@ func (c *bls12381MapG1) RequiredGas(input []byte) uint64 { return params.Bls12381MapG1Gas } -func (c *bls12381MapG1) Run(input []byte) ([]byte, error) { +func (c *bls12381MapG1) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 Map_To_G1 precompile. // > Field-to-curve call expects an `64` bytes input that is interpreted as an element of the base field. // > Output of this call is `128` bytes and is G1 point following respective encoding rules. @@ -1135,7 +1139,7 @@ func (c *bls12381MapG2) RequiredGas(input []byte) uint64 { return params.Bls12381MapG2Gas } -func (c *bls12381MapG2) Run(input []byte) ([]byte, error) { +func (c *bls12381MapG2) Run(evm *EVM, input []byte) ([]byte, error) { // Implements EIP-2537 Map_FP2_TO_G2 precompile logic. // > Field-to-curve call expects an `128` bytes input that is interpreted as an element of the quadratic extension field. // > Output of this call is `256` bytes and is G2 point following respective encoding rules. @@ -1181,7 +1185,7 @@ var ( ) // Run executes the point evaluation precompile. -func (b *kzgPointEvaluation) Run(input []byte) ([]byte, error) { +func (b *kzgPointEvaluation) Run(evm *EVM, input []byte) ([]byte, error) { if len(input) != blobVerifyInputLength { return nil, errBlobVerifyInvalidInputLength } diff --git a/core/vm/contracts_fuzz_test.go b/core/vm/contracts_fuzz_test.go index 1e5cc8007471..c4bed0bd8879 100644 --- a/core/vm/contracts_fuzz_test.go +++ b/core/vm/contracts_fuzz_test.go @@ -17,6 +17,11 @@ package vm import ( + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" + "github.com/holiman/uint256" "testing" "github.com/ethereum/go-ethereum/common" @@ -35,8 +40,13 @@ func FuzzPrecompiledContracts(f *testing.F) { if gas > 10_000_000 { return } + vmctx := BlockContext{ + Transfer: func(StateDB, common.Address, common.Address, *uint256.Int) {}, + } + statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + evm := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{}) inWant := string(input) - RunPrecompiledContract(p, input, gas, nil) + RunPrecompiledContract(evm, p, input, gas, nil) if inHave := string(input); inWant != inHave { t.Errorf("Precompiled %v modified input data", a) } diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index fff5c966f34f..bfd980eaad50 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -20,6 +20,11 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" + "github.com/holiman/uint256" "os" "testing" "time" @@ -97,8 +102,13 @@ func testPrecompiled(addr string, test precompiledTest, t *testing.T) { p := allPrecompiles[common.HexToAddress(addr)] in := common.Hex2Bytes(test.Input) gas := p.RequiredGas(in) + vmctx := BlockContext{ + Transfer: func(StateDB, common.Address, common.Address, *uint256.Int) {}, + } + statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + evm := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{}) t.Run(fmt.Sprintf("%s-Gas=%d", test.Name, gas), func(t *testing.T) { - if res, _, err := RunPrecompiledContract(p, in, gas, nil); err != nil { + if res, _, err := RunPrecompiledContract(evm, p, in, gas, nil); err != nil { t.Error(err) } else if common.Bytes2Hex(res) != test.Expected { t.Errorf("Expected %v, got %v", test.Expected, common.Bytes2Hex(res)) @@ -119,8 +129,14 @@ func testPrecompiledOOG(addr string, test precompiledTest, t *testing.T) { in := common.Hex2Bytes(test.Input) gas := p.RequiredGas(in) - 1 + vmctx := BlockContext{ + Transfer: func(StateDB, common.Address, common.Address, *uint256.Int) {}, + } + statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + evm := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{}) + t.Run(fmt.Sprintf("%s-Gas=%d", test.Name, gas), func(t *testing.T) { - _, _, err := RunPrecompiledContract(p, in, gas, nil) + _, _, err := RunPrecompiledContract(evm, p, in, gas, nil) if err.Error() != "out of gas" { t.Errorf("Expected error [out of gas], got [%v]", err) } @@ -136,8 +152,15 @@ func testPrecompiledFailure(addr string, test precompiledFailureTest, t *testing p := allPrecompiles[common.HexToAddress(addr)] in := common.Hex2Bytes(test.Input) gas := p.RequiredGas(in) + + vmctx := BlockContext{ + Transfer: func(StateDB, common.Address, common.Address, *uint256.Int) {}, + } + statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + evm := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{}) + t.Run(test.Name, func(t *testing.T) { - _, _, err := RunPrecompiledContract(p, in, gas, nil) + _, _, err := RunPrecompiledContract(evm, p, in, gas, nil) if err.Error() != test.ExpectedError { t.Errorf("Expected error [%v], got [%v]", test.ExpectedError, err) } @@ -168,8 +191,13 @@ func benchmarkPrecompiled(addr string, test precompiledTest, bench *testing.B) { start := time.Now() bench.ResetTimer() for i := 0; i < bench.N; i++ { + vmctx := BlockContext{ + Transfer: func(StateDB, common.Address, common.Address, *uint256.Int) {}, + } + statedb, _ := state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + evm := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{}) copy(data, in) - res, _, err = RunPrecompiledContract(p, data, reqGas, nil) + res, _, err = RunPrecompiledContract(evm, p, data, reqGas, nil) } bench.StopTimer() elapsed := uint64(time.Since(start)) diff --git a/core/vm/evm.go b/core/vm/evm.go index 1944189b5da2..dbdff98bee4e 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/core/tracing" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "github.com/holiman/uint256" ) @@ -226,7 +227,8 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas evm.Context.Transfer(evm.StateDB, caller.Address(), addr, value) if isPrecompile { - ret, gas, err = RunPrecompiledContract(p, input, gas, evm.Config.Tracer) + ret, gas, err = RunPrecompiledContract(evm, p, input, gas, evm.Config.Tracer) + log.Info("Call", "precompile", true, "ret", ret, "gas", gas, "err", err) } else { // Initialise a new contract and set the code that is to be used by the EVM. // The contract is a scoped environment for this execution context only. @@ -295,7 +297,7 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte, // It is allowed to call precompiles, even via delegatecall if p, isPrecompile := evm.precompile(addr); isPrecompile { - ret, gas, err = RunPrecompiledContract(p, input, gas, evm.Config.Tracer) + ret, gas, err = RunPrecompiledContract(evm, p, input, gas, evm.Config.Tracer) } else { addrCopy := addr // Initialise a new contract and set the code that is to be used by the EVM. @@ -346,7 +348,7 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by // It is allowed to call precompiles, even via delegatecall if p, isPrecompile := evm.precompile(addr); isPrecompile { - ret, gas, err = RunPrecompiledContract(p, input, gas, evm.Config.Tracer) + ret, gas, err = RunPrecompiledContract(evm, p, input, gas, evm.Config.Tracer) } else { addrCopy := addr // Initialise a new contract and make initialise the delegate values @@ -400,7 +402,7 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte evm.StateDB.AddBalance(addr, new(uint256.Int), tracing.BalanceChangeTouchAccount) if p, isPrecompile := evm.precompile(addr); isPrecompile { - ret, gas, err = RunPrecompiledContract(p, input, gas, evm.Config.Tracer) + ret, gas, err = RunPrecompiledContract(evm, p, input, gas, evm.Config.Tracer) } else { // At this point, we use a copy of address. If we don't, the go compiler will // leak the 'contract' to the outer scope, and make allocation for 'contract' diff --git a/core/vm/ipgraph.go b/core/vm/ipgraph.go new file mode 100644 index 000000000000..fc4a4d55126c --- /dev/null +++ b/core/vm/ipgraph.go @@ -0,0 +1,339 @@ +package vm + +import ( + "bytes" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" +) + +var ( + ipGraphAddress = common.HexToAddress("0x000000000000000000000000000000000000001A") + addParentIpSelector = crypto.Keccak256Hash([]byte("addParentIp(address,address[])")).Bytes()[:4] + hasParentIpSelector = crypto.Keccak256Hash([]byte("hasParentIp(address,address)")).Bytes()[:4] + getParentIpsSelector = crypto.Keccak256Hash([]byte("getParentIps(address)")).Bytes()[:4] + getParentIpsCountSelector = crypto.Keccak256Hash([]byte("getParentIpsCount(address)")).Bytes()[:4] + getAncestorIpsSelector = crypto.Keccak256Hash([]byte("getAncestorIps(address)")).Bytes()[:4] + getAncestorIpsCountSelector = crypto.Keccak256Hash([]byte("getAncestorIpsCount(address)")).Bytes()[:4] + hasAncestorIpsSelector = crypto.Keccak256Hash([]byte("hasAncestorIp(address,address)")).Bytes()[:4] + setRoyaltySelector = crypto.Keccak256Hash([]byte("setRoyalty(address,address,uint256)")).Bytes()[:4] + getRoyaltySelector = crypto.Keccak256Hash([]byte("getRoyalty(address,address)")).Bytes()[:4] + getRoyaltyStackSelector = crypto.Keccak256Hash([]byte("getRoyaltyStack(address)")).Bytes()[:4] +) + +type ipGraph struct{} + +func (c *ipGraph) RequiredGas(input []byte) uint64 { + return uint64(1) +} + +func (c *ipGraph) Run(evm *EVM, input []byte) ([]byte, error) { + log.Info("ipGraph.Run", "input", input) + + if len(input) < 4 { + return nil, fmt.Errorf("input too short") + } + + selector := input[:4] + args := input[4:] + + switch { + case bytes.Equal(selector, addParentIpSelector): + return c.addParentIp(args, evm) + case bytes.Equal(selector, hasParentIpSelector): + return c.hasParentIp(args, evm) + case bytes.Equal(selector, getParentIpsSelector): + return c.getParentIps(args, evm) + case bytes.Equal(selector, getParentIpsCountSelector): + return c.getParentIpsCount(args, evm) + case bytes.Equal(selector, getAncestorIpsSelector): + return c.getAncestorIps(args, evm) + case bytes.Equal(selector, getAncestorIpsCountSelector): + return c.getAncestorIpsCount(args, evm) + case bytes.Equal(selector, hasAncestorIpsSelector): + return c.hasAncestorIp(args, evm) + case bytes.Equal(selector, setRoyaltySelector): + return c.setRoyalty(args, evm) + case bytes.Equal(selector, getRoyaltySelector): + return c.getRoyalty(args, evm) + case bytes.Equal(selector, getRoyaltyStackSelector): + return c.getRoyaltyStack(args, evm) + default: + return nil, fmt.Errorf("unknown selector") + } +} + +func (c *ipGraph) addParentIp(input []byte, evm *EVM) ([]byte, error) { + log.Info("addParentIp", "input", input) + if len(input) < 96 { + return nil, fmt.Errorf("input too short for addParentIp") + } + ipId := common.BytesToAddress(input[0:32]) + log.Info("addParentIp", "ipId", ipId) + parentCount := new(big.Int).SetBytes(getData(input, 64, 32)) + log.Info("addParentIp", "parentCount", parentCount) + + if len(input) < int(96+parentCount.Uint64()*32) { + return nil, fmt.Errorf("input too short for parent IPs") + } + + for i := 0; i < int(parentCount.Uint64()); i++ { + parentIpId := common.BytesToAddress(input[96+i*32 : 96+(i+1)*32]) + index := uint64(i) + slot := crypto.Keccak256Hash(ipId.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(index)) + log.Info("addParentIp", "ipId", ipId, "parentIpId", parentIpId, "slot", slot) + evm.StateDB.SetState(ipGraphAddress, common.BigToHash(slot), common.BytesToHash(parentIpId.Bytes())) + } + + log.Info("addParentIp", "ipId", ipId, "parentCount", parentCount) + evm.StateDB.SetState(ipGraphAddress, common.BytesToHash(ipId.Bytes()), common.BigToHash(parentCount)) + + return nil, nil +} + +func (c *ipGraph) hasParentIp(input []byte, evm *EVM) ([]byte, error) { + if len(input) < 64 { + return nil, fmt.Errorf("input too short for hasParentIp") + } + ipId := common.BytesToAddress(input[0:32]) + parentIpId := common.BytesToAddress(input[32:64]) + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(ipId.Bytes())) + currentLength := currentLengthHash.Big() + log.Info("hasParentIp", "ipId", ipId, "parentIpId", parentIpId, "currentLength", currentLength) + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(ipId.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + storedParent := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)) + log.Info("hasParentIp", "storedParent", storedParent, "parentIpId", parentIpId) + if common.BytesToAddress(storedParent.Bytes()) == parentIpId { + log.Info("hasParentIp", "found", true) + return common.LeftPadBytes([]byte{1}, 32), nil + } + } + log.Info("hasParentIp", "found", false) + return common.LeftPadBytes([]byte{0}, 32), nil +} + +func (c *ipGraph) getParentIps(input []byte, evm *EVM) ([]byte, error) { + log.Info("getParentIps", "input", input) + if len(input) < 32 { + return nil, fmt.Errorf("input too short for getParentIps") + } + ipId := common.BytesToAddress(input[0:32]) + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(ipId.Bytes())) + currentLength := currentLengthHash.Big() + + output := make([]byte, 64+currentLength.Uint64()*32) + copy(output[0:32], common.BigToHash(new(big.Int).SetUint64(32)).Bytes()) + copy(output[32:64], common.BigToHash(currentLength).Bytes()) + + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(ipId.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + storedParent := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)) + copy(output[64+i*32:], storedParent.Bytes()) + } + log.Info("getParentIps", "output", output) + return output, nil +} + +func (c *ipGraph) getParentIpsCount(input []byte, evm *EVM) ([]byte, error) { + log.Info("getParentIpsCount", "input", input) + if len(input) < 32 { + return nil, fmt.Errorf("input too short for getParentIpsCount") + } + ipId := common.BytesToAddress(input[0:32]) + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(ipId.Bytes())) + currentLength := currentLengthHash.Big() + + log.Info("getParentIpsCount", "ipId", ipId, "currentLength", currentLength) + return common.BigToHash(currentLength).Bytes(), nil +} + +func (c *ipGraph) getAncestorIps(input []byte, evm *EVM) ([]byte, error) { + log.Info("getAncestorIps", "input", input) + if len(input) < 32 { + return nil, fmt.Errorf("input too short for getAncestorIps") + } + ipId := common.BytesToAddress(input[0:32]) + ancestors := c.findAncestors(ipId, evm) + + output := make([]byte, 64+len(ancestors)*32) + copy(output[0:32], common.BigToHash(new(big.Int).SetUint64(32)).Bytes()) + copy(output[32:64], common.BigToHash(new(big.Int).SetUint64(uint64(len(ancestors)))).Bytes()) + + i := 0 + for ancestor := range ancestors { + copy(output[64+i*32:], common.LeftPadBytes(ancestor.Bytes(), 32)) + i++ + } + + log.Info("getAncestorIps", "output", output) + return output, nil +} + +func (c *ipGraph) getAncestorIpsCount(input []byte, evm *EVM) ([]byte, error) { + log.Info("getAncestorIpsCount", "input", input) + if len(input) < 32 { + return nil, fmt.Errorf("input too short for getAncestorIpsCount") + } + ipId := common.BytesToAddress(input[0:32]) + ancestors := c.findAncestors(ipId, evm) + + count := new(big.Int).SetUint64(uint64(len(ancestors))) + log.Info("getAncestorIpsCount", "ipId", ipId, "count", count) + return common.BigToHash(count).Bytes(), nil +} + +func (c *ipGraph) hasAncestorIp(input []byte, evm *EVM) ([]byte, error) { + if len(input) < 64 { + return nil, fmt.Errorf("input too short for hasAncestorIp") + } + ipId := common.BytesToAddress(input[0:32]) + parentIpId := common.BytesToAddress(input[32:64]) + ancestors := c.findAncestors(ipId, evm) + + if _, found := ancestors[parentIpId]; found { + log.Info("hasAncestorIp", "found", true) + return common.LeftPadBytes([]byte{1}, 32), nil + } + log.Info("hasAncestorIp", "found", false) + return common.LeftPadBytes([]byte{0}, 32), nil +} + +func (c *ipGraph) findAncestors(ipId common.Address, evm *EVM) map[common.Address]struct{} { + ancestors := make(map[common.Address]struct{}) + var stack []common.Address + stack = append(stack, ipId) + for len(stack) > 0 { + node := stack[len(stack)-1] + stack = stack[:len(stack)-1] + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(node.Bytes())) + currentLength := currentLengthHash.Big() + + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(node.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + storedParent := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)) + parentIpId := common.BytesToAddress(storedParent.Bytes()) + + if _, found := ancestors[parentIpId]; !found { + ancestors[parentIpId] = struct{}{} + stack = append(stack, parentIpId) + } + } + } + return ancestors +} + +func (c *ipGraph) setRoyalty(input []byte, evm *EVM) ([]byte, error) { + log.Info("setRoyalty", "input", input) + if len(input) < 96 { + return nil, fmt.Errorf("input too short for setRoyalty") + } + ipId := common.BytesToAddress(input[0:32]) + parentIpId := common.BytesToAddress(input[32:64]) + royalty := new(big.Int).SetBytes(getData(input, 64, 32)) + slot := crypto.Keccak256Hash(ipId.Bytes(), parentIpId.Bytes()).Big() + log.Info("setRoyalty", "ipId", ipId, "parentIpId", parentIpId, "royalty", royalty, "slot", slot) + evm.StateDB.SetState(ipGraphAddress, common.BigToHash(slot), common.BigToHash(royalty)) + + return nil, nil +} + +func (c *ipGraph) getRoyalty(input []byte, evm *EVM) ([]byte, error) { + log.Info("getRoyalty", "input", input) + if len(input) < 64 { + return nil, fmt.Errorf("input too short for getRoyalty") + } + ipId := common.BytesToAddress(input[0:32]) + ancestorIpId := common.BytesToAddress(input[32:64]) + ancestors := c.findAncestors(ipId, evm) + totalRoyalty := big.NewInt(0) + for ancestor := range ancestors { + if ancestor == ancestorIpId { + // Traverse the graph to accumulate royalties + totalRoyalty.Add(totalRoyalty, c.getRoyaltyForAncestor(ipId, ancestorIpId, evm)) + } + } + + log.Info("getRoyalty", "ipId", ipId, "ancestorIpId", ancestorIpId, "totalRoyalty", totalRoyalty) + return common.BigToHash(totalRoyalty).Bytes(), nil +} + +func (c *ipGraph) getRoyaltyForAncestor(ipId, ancestorIpId common.Address, evm *EVM) *big.Int { + ancestors := make(map[common.Address]struct{}) + totalRoyalty := big.NewInt(0) + var stack []common.Address + stack = append(stack, ipId) + for len(stack) > 0 { + node := stack[len(stack)-1] + stack = stack[:len(stack)-1] + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(node.Bytes())) + currentLength := currentLengthHash.Big() + + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(node.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + storedParent := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)) + parentIpId := common.BytesToAddress(storedParent.Bytes()) + + if _, found := ancestors[parentIpId]; !found { + ancestors[parentIpId] = struct{}{} + stack = append(stack, parentIpId) + } + + if parentIpId == ancestorIpId { + royaltySlot := crypto.Keccak256Hash(node.Bytes(), ancestorIpId.Bytes()).Big() + royalty := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(royaltySlot)).Big() + totalRoyalty.Add(totalRoyalty, royalty) + } + } + } + return totalRoyalty +} + +func (c *ipGraph) getRoyaltyStack(input []byte, evm *EVM) ([]byte, error) { + log.Info("getRoyaltyStack", "input", input) + if len(input) < 32 { + return nil, fmt.Errorf("input too short for getRoyaltyStack") + } + ipId := common.BytesToAddress(input[0:32]) + ancestors := make(map[common.Address]struct{}) + totalRoyalty := big.NewInt(0) + var stack []common.Address + stack = append(stack, ipId) + for len(stack) > 0 { + node := stack[len(stack)-1] + stack = stack[:len(stack)-1] + + currentLengthHash := evm.StateDB.GetState(ipGraphAddress, common.BytesToHash(node.Bytes())) + currentLength := currentLengthHash.Big() + + for i := uint64(0); i < currentLength.Uint64(); i++ { + slot := crypto.Keccak256Hash(node.Bytes()).Big() + slot.Add(slot, new(big.Int).SetUint64(i)) + storedParent := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(slot)) + parentIpId := common.BytesToAddress(storedParent.Bytes()) + + if _, found := ancestors[parentIpId]; !found { + ancestors[parentIpId] = struct{}{} + stack = append(stack, parentIpId) + } + + royaltySlot := crypto.Keccak256Hash(node.Bytes(), parentIpId.Bytes()).Big() + royalty := evm.StateDB.GetState(ipGraphAddress, common.BigToHash(royaltySlot)).Big() + totalRoyalty.Add(totalRoyalty, royalty) + } + } + return common.BigToHash(totalRoyalty).Bytes(), nil +}