Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Implement Dict.setdefault(key, default) #2803

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ what we publish.

- The `Reference` type (and many iterators) now use "inferred" parameters to
represent the mutability of their lifetime, simplifying the interface.

- `Dict` now implements `setdefault`, to get a value from the dictionary by
key, or set it to a default if it doesn't exist
([PR #2803](https://github.com/modularml/mojo/pull/2803)
by [@msaelices](https://github.com/msaelices))

- Added new `ExplicitlyCopyable` trait, to mark types that can be copied
explicitly, but which might not be implicitly copyable.
Expand Down
18 changes: 18 additions & 0 deletions stdlib/src/collections/dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,24 @@ struct Dict[K: KeyElement, V: CollectionElement](
self._entries = Self._new_entries(Self._initial_reservation)
self._index = _DictIndex(self._reserved())

fn setdefault(
ref [_]self: Self, key: K, owned default: V
) raises -> Reference[V, __lifetime_of(self)]:
"""Get a value from the dictionary by key, or set it to a default if it doesn't exist.

Args:
key: The key to search for in the dictionary.
default: The default value to set if the key is not present.

Returns:
The value associated with the key, or the default value if it wasn't present.
"""
try:
return self.__get_ref(key)
except KeyError:
self[key] = default^
return self.__get_ref(key)

@staticmethod
@always_inline
fn _new_entries(reserve_at_least: Int) -> List[Optional[DictEntry[K, V]]]:
Expand Down
23 changes: 23 additions & 0 deletions stdlib/test/collections/test_dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,33 @@ fn test_clear() raises:
assert_equal(len(some_dict), 0)


fn test_dict_setdefault() raises:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add another test that verifies that no copies are happening when calling setdefault ? We can use the CopyCounter struct for this purpose. There is an example here: https://github.com/modularml/mojo/blob/nightly/stdlib/test/collections/test_list.mojo#L568

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do it when I manage to fix the compiler issue above. Thank you!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is still this new test to add

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done here: msaelices@7e53b09

It's curious. I thought the case of the "b" key would have 0 copies but it's 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you should use the transfer operator in the test ^? Maybe that's why there is a copy. Also I think calling __getitem__ triggers a copy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it but no luck: msaelices@91d9ee5

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be a bug in the dict implementation. let's skip this problem for now. I think it needs to be investigated later on, but that would be out of scope for this PR.

var some_dict = Dict[String, Int]()
some_dict["key1"] = 1
some_dict["key2"] = 2
assert_equal(some_dict.setdefault("key1", 0)[], 1)
assert_equal(some_dict.setdefault("key2", 0)[], 2)
assert_equal(some_dict.setdefault("not_key", 0)[], 0)
assert_equal(some_dict["not_key"], 0)

# Check that there is no copy of the default value, so it's performant
var other_dict = Dict[String, CopyCounter]()
var a = CopyCounter()
var a_def = CopyCounter()
var b_def = CopyCounter()
other_dict["a"] = a^
assert_equal(1, other_dict["a"].copy_count)
_ = other_dict.setdefault("a", a_def^)
_ = other_dict.setdefault("b", b_def^)
assert_equal(1, other_dict["a"].copy_count)
assert_equal(1, other_dict["b"].copy_count)


def main():
test_dict()
test_dict_fromkeys()
test_dict_fromkeys_optional()
test_dict_setdefault()
test_dict_string_representation_string_int()
test_dict_string_representation_int_int()
test_owned_kwargs_dict()
Expand Down