Skip to content

Commit

Permalink
feat: add support for natspec map (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 11, 2024
1 parent 1a6ef43 commit 5f678f4
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 38 deletions.
75 changes: 75 additions & 0 deletions ethpm_types/contract_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,38 @@ def structs(self) -> ABIList:
"""
return self._get_abis(filter_fn=lambda a: isinstance(a, StructABI))

@property
def natspecs(self) -> Dict[str, str]:
"""
A mapping of ABI selectors to their natspec documentation.
"""
return {
**self._method_natspecs,
**self._event_natspecs,
**self._error_natspecs,
**self._struct_natspecs,
}

@cached_property
def _method_natspecs(self) -> Dict[str, str]:
# NOTE: Both Solidity and Vyper support this!
return _extract_natspec(self.devdoc or {}, "methods", self.methods)

@cached_property
def _event_natspecs(self) -> Dict[str, str]:
# NOTE: Only supported in Solidity (at time of writing this).
return _extract_natspec(self.devdoc or {}, "events", self.events)

@cached_property
def _error_natspecs(self) -> Dict[str, str]:
# NOTE: Only supported in Solidity (at time of writing this).
return _extract_natspec(self.devdoc or {}, "errors", self.errors)

@cached_property
def _struct_natspecs(self) -> Dict[str, str]:
# NOTE: Not supported in Solidity or Vyper at the time of writing this.
return _extract_natspec(self.devdoc or {}, "structs", self.structs)

@classmethod
def _selector_hash_fn(cls, selector: str) -> bytes:
# keccak is the default on most ecosystems, other ecosystems can subclass to override it
Expand Down Expand Up @@ -523,3 +555,46 @@ def get_id(aitem: ABI_W_SELECTOR_T) -> str:
List[ABI_W_SELECTOR_T], [x for x in self.abi if hasattr(x, "selector")]
)
return [(x, get_id(x)) for x in abis_with_selector]


def _extract_natspec(devdoc: dict, devdoc_key: str, abis: ABIList) -> Dict[str, str]:
result: Dict[str, str] = {}
devdocs = devdoc.get(devdoc_key, {})
for abi in abis:
dev_fields = devdocs.get(abi.selector, {})
if isinstance(dev_fields, dict):
if spec := _extract_natspec_from_dict(dev_fields, abi):
result[abi.selector] = "\n".join(spec)

elif isinstance(dev_fields, list):
for dev_field_ls_item in dev_fields:
if not isinstance(dev_field_ls_item, dict):
# Not sure.
continue

if spec := _extract_natspec_from_dict(dev_field_ls_item, abi):
result[abi.selector] = "\n".join(spec)

return result


def _extract_natspec_from_dict(data: Dict, abi: ABI) -> List[str]:
info_parts: list[str] = []

for field_key, field_doc in data.items():
if isinstance(field_doc, str):
info_parts.append(f"@{field_key} {field_doc}")
elif isinstance(field_doc, dict):
if field_key != "params":
# Not sure!
continue

for param_name, param_doc in field_doc.items():
param_type_matches = [i for i in getattr(abi, "inputs", []) if i.name == param_name]
if not param_type_matches:
continue # Unlikely?

param_type = str(param_type_matches[0].type)
info_parts.append(f"@param {param_name} {param_type} {param_doc}")

return info_parts
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ def fn(name: str) -> ContractType:
@pytest.fixture
def get_source_path():
def fn(name: str, base: Path = SOURCE_BASE) -> Path:
for path in base.iterdir():
contracts_path = base / "contracts"
if not contracts_path.is_dir():
raise AssertionError("test setup failed - contracts directory not found")

for path in contracts_path.iterdir():
if path.stem == name:
return path

raise AssertionError("test setup failed - path not found")
raise AssertionError("test setup failed - test file '{name}' not found")

return fn

Expand Down
2 changes: 1 addition & 1 deletion tests/data/Compiled/SolidityContract.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/Compiled/VyperContract.json

Large diffs are not rendered by default.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ contract SolidityContract {

uint256 constant MAX_FOO = 5;

/**
* @dev This is a doc for an error
*/
error ACustomError();

/**
* @dev Emitted when number is changed.
*
* `newNum` is the new number from the call.
* Expected every time number changes.
*/
event NumberChange(
bytes32 b,
uint256 prevNum,
Expand All @@ -32,9 +43,28 @@ contract SolidityContract {
uint256 indexed bar
);

event EventWithStruct(
MyStruct a_struct
);

event EventWithAddressArray(
uint32 indexed some_id,
address indexed some_address,
address[] participants,
address[1] agents
);

event EventWithUintArray(
uint256[1] agents
);

/**
* @dev This is the doc for MyStruct
**/
struct MyStruct {
address a;
bytes32 b;
uint256 c;
}

struct NestedStruct1 {
Expand Down Expand Up @@ -85,6 +115,15 @@ contract SolidityContract {
emit BarHappened(1);
}


/**
* @notice Sets a new number, with restrictions and event emission
* @dev Only the owner can call this function. The new number cannot be 5.
* @param num The new number to be set
* @custom:require num Must not be equal to 5
* @custom:modifies Sets the `myNumber` state variable
* @custom:emits Emits a `NumberChange` event with the previous number, the new number, and the previous block hash
*/
function setNumber(uint256 num) public onlyOwner {
require(num != 5);
prevNumber = myNumber;
Expand All @@ -111,7 +150,7 @@ contract SolidityContract {
}

function getStruct() public view returns(MyStruct memory) {
return MyStruct(msg.sender, blockhash(block.number - 1));
return MyStruct(msg.sender, blockhash(block.number - 1), 244);
}

function getNestedStruct1() public view returns(NestedStruct1 memory) {
Expand Down Expand Up @@ -278,4 +317,22 @@ contract SolidityContract {
function setStructArray(MyStruct[2] memory _my_struct_array) public pure {

}

function logStruct() public {
bytes32 _bytes = 0x1234567890abcdef0123456789abcdef0123456789abcdef0123456789abcdef;
MyStruct memory _struct = MyStruct(msg.sender, _bytes, 244);
emit EventWithStruct(_struct);
}

function logAddressArray() public {
address[] memory ppl = new address[](1);
ppl[0] = msg.sender;
address[1] memory agts = [msg.sender];
emit EventWithAddressArray(1001, msg.sender, ppl, agts);
}

function logUintArray() public {
uint256[1] memory agts = [uint256(1)];
emit EventWithUintArray(agts);
}
}
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# @version 0.3.7
# @version 0.3.9

# @dev Emitted when number is changed.
#
# `newNum` is the new number from the call.
# Expected every time number changes.
event NumberChange:
b: bytes32
prevNum: uint256
Expand All @@ -16,9 +20,23 @@ event FooHappened:
event BarHappened:
bar: indexed(uint256)

event EventWithStruct:
a_struct: MyStruct

event EventWithAddressArray:
some_id: uint256
some_address: address
participants: DynArray[address, 1024]
agents: address[1]

event EventWithUintArray:
agents: uint256[1]

# @dev This is the doc for MyStruct
struct MyStruct:
a: address
b: bytes32
c: uint256

struct NestedStruct1:
t: MyStruct
Expand Down Expand Up @@ -69,6 +87,14 @@ def fooAndBar():

@external
def setNumber(num: uint256):
"""
@notice Sets a new number, with restrictions and event emission
@dev Only the owner can call this function. The new number cannot be 5.
@param num The new number to be set
@custom:require num Must not be equal to 5
@custom:modifies Sets the `myNumber` state variable
@custom:emits Emits a `NumberChange` event with the previous number, the new number, and the previous block hash
"""
assert msg.sender == self.owner, "!authorized"
assert num != 5
self.prevNumber = self.myNumber
Expand All @@ -87,27 +113,27 @@ def setBalance(_address: address, bal: uint256):
@view
@external
def getStruct() -> MyStruct:
return MyStruct({a: msg.sender, b: block.prevhash})
return MyStruct({a: msg.sender, b: block.prevhash, c: 244})

@view
@external
def getNestedStruct1() -> NestedStruct1:
return NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash}), foo: 1})
return NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash, c: 244}), foo: 1})

@view
@external
def getNestedStruct2() -> NestedStruct2:
return NestedStruct2({foo: 2, t: MyStruct({a: msg.sender, b: block.prevhash})})
return NestedStruct2({foo: 2, t: MyStruct({a: msg.sender, b: block.prevhash, c: 244})})

@view
@external
def getNestedStructWithTuple1() -> (NestedStruct1, uint256):
return (NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash}), foo: 1}), 1)
return (NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash, c: 244}), foo: 1}), 1)

@view
@external
def getNestedStructWithTuple2() -> (uint256, NestedStruct2):
return (2, NestedStruct2({foo: 2, t: MyStruct({a: msg.sender, b: block.prevhash})}))
return (2, NestedStruct2({foo: 2, t: MyStruct({a: msg.sender, b: block.prevhash, c: 244})}))

@pure
@external
Expand Down Expand Up @@ -149,8 +175,8 @@ def getStructWithArray() -> WithArray:
{
foo: 1,
arr: [
MyStruct({a: msg.sender, b: block.prevhash}),
MyStruct({a: msg.sender, b: block.prevhash})
MyStruct({a: msg.sender, b: block.prevhash, c: 244}),
MyStruct({a: msg.sender, b: block.prevhash, c: 244})
],
bar: 2
}
Expand Down Expand Up @@ -180,16 +206,16 @@ def getAddressArray() -> DynArray[address, 2]:
@external
def getDynamicStructArray() -> DynArray[NestedStruct1, 2]:
return [
NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash}), foo: 1}),
NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash}), foo: 2})
NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash, c: 244}), foo: 1}),
NestedStruct1({t: MyStruct({a: msg.sender, b: block.prevhash, c: 244}), foo: 2})
]

@view
@external
def getStaticStructArray() -> NestedStruct2[2]:
return [
NestedStruct2({foo: 1, t: MyStruct({a: msg.sender, b: block.prevhash})}),
NestedStruct2({foo: 2, t: MyStruct({a: msg.sender, b: block.prevhash})})
NestedStruct2({foo: 1, t: MyStruct({a: msg.sender, b: block.prevhash, c: 244})}),
NestedStruct2({foo: 2, t: MyStruct({a: msg.sender, b: block.prevhash, c: 244})})
]

@pure
Expand Down Expand Up @@ -270,3 +296,25 @@ def setStruct(_my_struct: MyStruct):
@external
def setStructArray(_my_struct_array: MyStruct[2]):
pass

@external
def logStruct():
_bytes: bytes32 = 0x1234567890abcdef0123456789abcdef0123456789abcdef0123456789abcdef
_struct: MyStruct = MyStruct({
a: msg.sender,
b: _bytes,
c: 244
})
log EventWithStruct(_struct)

@external
def logAddressArray():
ppl: DynArray[address, 1024] = []
ppl.append(msg.sender)
agts: address[1] = [msg.sender]
log EventWithAddressArray(1001, msg.sender, ppl, agts)

@external
def logUintArray():
agts: uint256[1] = [1]
log EventWithUintArray(agts)
Loading

0 comments on commit 5f678f4

Please sign in to comment.