Skip to content

Commit

Permalink
feat: allow repeat with dtype=dict
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jun 11, 2024
1 parent 3167a5b commit 5c74333
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 12 deletions.
48 changes: 42 additions & 6 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Argument:
If given, `dtype` is assumed to be dict, and its items are determined
by the `Variant`s in the given list and the value of their flag keys.
repeat: bool, optional
If true, `dtype` is assume to be list of dict and each dict consists
If true, `dtype` is assume to be list of dict or dict of dict, and each dict consists
of sub fields and sub variants described above. Defaults to false.
optional: bool, optional
If true, consider the current argument to be optional in checking.
Expand Down Expand Up @@ -235,7 +235,17 @@ def _reorg_dtype(
}
# check conner cases
if self.sub_fields or self.sub_variants:
dtype.add(list if self.repeat else dict)
if not self.repeat:
dtype.add(dict)
else:
# convert dtypes to unsubscripted types
unsubscripted_dtype = {
get_origin(dt) if get_origin(dt) is not None else dt for dt in dtype
}
if dict not in unsubscripted_dtype:
# only add list (compatible with old behaviors) if no dict in dtype
dtype.add(list)

if (
self.optional
and self.default is not _Flags.NONE
Expand Down Expand Up @@ -347,11 +357,11 @@ def traverse_value(
# in the condition where there is no leading key
if path is None:
path = []
if isinstance(value, dict):
if not self.repeat and isinstance(value, dict):
self._traverse_sub(
value, key_hook, value_hook, sub_hook, variant_hook, path
)
if isinstance(value, list) and self.repeat:
elif self.repeat and isinstance(value, list):
for idx, item in enumerate(value):
self._traverse_sub(
item,
Expand All @@ -361,6 +371,16 @@ def traverse_value(
variant_hook,
[*path, str(idx)],
)
elif self.repeat and isinstance(value, dict):
for kk, item in value.items():
self._traverse_sub(
item,
key_hook,
value_hook,
sub_hook,
variant_hook,
[*path, kk],
)

def _traverse_sub(
self,
Expand Down Expand Up @@ -653,9 +673,25 @@ def gen_doc_body(self, path: list[str] | None = None, **kwargs) -> str:
body_list.append(self.doc + "\n")
if not self.fold_subdoc:
if self.repeat:
unsubscripted_dtype = {
get_origin(dt) if get_origin(dt) is not None else dt
for dt in self.dtype
}
allowed_types = []
allowed_element = []
if list in unsubscripted_dtype:
allowed_types.append("list")
allowed_element.append("element")
elif dict in unsubscripted_dtype:
allowed_types.append("dict")
allowed_element.append("key-value pair")
else:
raise ValueError(

Check warning on line 689 in dargs/dargs.py

View check run for this annotation

Codecov / codecov/patch

dargs/dargs.py#L689

Added line #L689 was not covered by tests
"When `repeat` is True, `dtype` should contain `dict` OR `list`."
)
body_list.append(
"This argument takes a list with "
"each element containing the following: \n"
f"This argument takes a {' or '.join(allowed_types)} with "
f"each {' or '.join(allowed_element)} containing the following: \n"
)
if self.sub_fields:
# body_list.append("") # genetate a blank line
Expand Down
15 changes: 13 additions & 2 deletions dargs/sphinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _test_argument() -> Argument:
doc=doc_test,
sub_fields=[
Argument(
"test_repeat",
"test_repeat_list",
dtype=list,
repeat=True,
doc=doc_test,
Expand All @@ -199,7 +199,18 @@ def _test_argument() -> Argument:
"test_repeat_item", dtype=bool, doc=doc_test
),
],
)
),
Argument(
"test_repeat_dict",
dtype=dict,
repeat=True,
doc=doc_test,
sub_fields=[
Argument(
"test_repeat_item", dtype=bool, doc=doc_test
),
],
),
],
),
],
Expand Down
37 changes: 35 additions & 2 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def test_sub_fields(self):
with self.assertRaises(ValueError):
Argument("base", dict, [Argument("sub1", int), Argument("sub1", int)])

def test_sub_repeat(self):
def test_sub_repeat_list(self):
ca = Argument(
"base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True
"base", list, [Argument("sub1", int), Argument("sub2", str)], repeat=True
)
test_dict1 = {
"base": [{"sub1": 10, "sub2": "hello"}, {"sub1": 11, "sub2": "world"}]
Expand All @@ -124,6 +124,39 @@ def test_sub_repeat(self):
with self.assertRaises(ArgumentTypeError):
ca.check(err_dict2)

def test_sub_repeat_dict(self):
ca = Argument(
"base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True
)
test_dict1 = {
"base": {
"item1": {"sub1": 10, "sub2": "hello"},
"item2": {"sub1": 11, "sub2": "world"},
}
}
ca.check(test_dict1)
ca.check_value(test_dict1["base"])
err_dict1 = {
"base": {
"item1": {"sub1": 10, "sub2": "hello"},
"item2": {"sub1": 11, "sub3": "world"},
}
}
with self.assertRaises(ArgumentKeyError):
ca.check(err_dict1)
err_dict1["base"]["item2"]["sub2"] = "world too"
ca.check(err_dict1) # now should pass
with self.assertRaises(ArgumentKeyError):
ca.check(err_dict1, strict=True) # but should fail when strict
err_dict2 = {
"base": {
"item1": {"sub1": 10, "sub2": "hello"},
"item2": {"sub1": 11, "sub2": None},
}
}
with self.assertRaises(ArgumentTypeError):
ca.check(err_dict2)

def test_sub_variants(self):
ca = Argument(
"base",
Expand Down
30 changes: 29 additions & 1 deletion tests/test_docgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_sub_fields(self):
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_sub_repeat(self):
def test_sub_repeat_list(self):
ca = Argument(
"base",
list,
Expand Down Expand Up @@ -70,6 +70,34 @@ def test_sub_repeat(self):
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_sub_repeat_dict(self):
ca = Argument(
"base",
dict,
[
Argument("sub1", int, doc="sub doc." * 5),
Argument(
"sub2",
[None, str, dict],
[
Argument("subsub1", int, doc="subsub doc." * 5, optional=True),
Argument(
"subsub2",
dict,
[Argument("subsubsub1", int, doc="subsubsub doc." * 5)],
doc="subsub doc." * 5,
repeat=True,
),
],
doc="sub doc." * 5,
),
],
doc="Base doc. " * 10,
repeat=True,
)
docstr = ca.gen_doc()
jsonstr = json.dumps(ca, cls=ArgumentEncoder)

def test_sub_variants(self):
ca = Argument(
"base",
Expand Down
12 changes: 11 additions & 1 deletion tests/test_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def test_complicated(self):
repeat=True,
alias=["sub2a"],
),
Argument(
"sub2_dict",
dict,
[Argument("ss1", int, optional=True, default=21, alias=["ss1a"])],
repeat=True,
alias=["sub2a_dict"],
),
],
[
Variant(
Expand Down Expand Up @@ -145,11 +152,12 @@ def test_complicated(self):
)
],
)
beg1 = {"base": {"sub2": [{}, {}]}}
beg1 = {"base": {"sub2": [{}, {}], "sub2_dict": {"item1": {}, "item2": {}}}}
ref1 = {
"base": {
"sub1": 1,
"sub2": [{"ss1": 21}, {"ss1": 21}],
"sub2_dict": {"item1": {"ss1": 21}, "item2": {"ss1": 21}},
"vnt_flag": "type1",
"shared": -1,
"vnt1": 111,
Expand All @@ -161,6 +169,7 @@ def test_complicated(self):
"base": {
"sub1a": 2,
"sub2a": [{"ss1a": 22}, {"_comment1": None}],
"sub2a_dict": {"item1": {"ss1a": 22}, "item2": {"_comment1": None}},
"vnt_flag": "type3",
"sharedb": -3,
"vnt2a": 223,
Expand All @@ -172,6 +181,7 @@ def test_complicated(self):
"base": {
"sub1": 2,
"sub2": [{"ss1": 22}, {"ss1": 21}],
"sub2_dict": {"item1": {"ss1": 22}, "item2": {"ss1": 21}},
"vnt_flag": "type2",
"shared": -3,
"vnt2": 223,
Expand Down

0 comments on commit 5c74333

Please sign in to comment.