From b8aa6521a8c0d6a79c8ab0015a538bf6930e4dfe Mon Sep 17 00:00:00 2001 From: Manuel Saelices Date: Thu, 18 Jul 2024 13:31:55 -0700 Subject: [PATCH] [External] [stdlib] Implement `Dict.setdefault(key, default)` (#42558) [External] [stdlib] Implement `Dict.setdefault(key, default)` Implement `Dict.setdefault` which returns a reference with the value of the item with the specified key. If the key does not exist, it is inserted with the specified value into the `Dict`. Co-authored-by: Manuel Saelices Closes modularml/mojo#2803 MODULAR_ORIG_COMMIT_REV_ID: a5514f823c3b7dc978607435d8082f1efe14aa13 --- docs/changelog.md | 5 +++++ stdlib/src/collections/dict.mojo | 18 ++++++++++++++++++ stdlib/test/collections/test_dict.mojo | 23 +++++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 206ea2aae4..d4b4344f67 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -157,6 +157,11 @@ what we publish. - The `Reference` type (and many iterators) now use "inferred" parameters to represent the mutability of their lifetime, simplifying the interface. +- `Dict` now implements `setdefault`, to get a value from the dictionary by + key, or set it to a default if it doesn't exist + ([PR #2803](https://github.com/modularml/mojo/pull/2803) + by [@msaelices](https://github.com/msaelices)) + - Added new `ExplicitlyCopyable` trait, to mark types that can be copied explicitly, but which might not be implicitly copyable. diff --git a/stdlib/src/collections/dict.mojo b/stdlib/src/collections/dict.mojo index 1b74ccbfa0..37b2d10c13 100644 --- a/stdlib/src/collections/dict.mojo +++ b/stdlib/src/collections/dict.mojo @@ -893,6 +893,24 @@ struct Dict[K: KeyElement, V: CollectionElement]( self._entries = Self._new_entries(Self._initial_reservation) self._index = _DictIndex(self._reserved()) + fn setdefault( + ref [_]self: Self, key: K, owned default: V + ) raises -> Reference[V, __lifetime_of(self)]: + """Get a value from the dictionary by key, or set it to a default if it doesn't exist. + + Args: + key: The key to search for in the dictionary. + default: The default value to set if the key is not present. + + Returns: + The value associated with the key, or the default value if it wasn't present. + """ + try: + return self._find_ref(key) + except KeyError: + self[key] = default^ + return self._find_ref(key) + @staticmethod @always_inline fn _new_entries(reserve_at_least: Int) -> List[Optional[DictEntry[K, V]]]: diff --git a/stdlib/test/collections/test_dict.mojo b/stdlib/test/collections/test_dict.mojo index fbe8823671..579f8bd887 100644 --- a/stdlib/test/collections/test_dict.mojo +++ b/stdlib/test/collections/test_dict.mojo @@ -563,6 +563,28 @@ def test_init_initial_capacity(): assert_equal(y._reserved(), 64) +fn test_dict_setdefault() raises: + var some_dict = Dict[String, Int]() + some_dict["key1"] = 1 + some_dict["key2"] = 2 + assert_equal(some_dict.setdefault("key1", 0)[], 1) + assert_equal(some_dict.setdefault("key2", 0)[], 2) + assert_equal(some_dict.setdefault("not_key", 0)[], 0) + assert_equal(some_dict["not_key"], 0) + + # Check that there is no copy of the default value, so it's performant + var other_dict = Dict[String, CopyCounter]() + var a = CopyCounter() + var a_def = CopyCounter() + var b_def = CopyCounter() + other_dict["a"] = a^ + assert_equal(1, other_dict["a"].copy_count) + _ = other_dict.setdefault("a", a_def^) + _ = other_dict.setdefault("b", b_def^) + assert_equal(1, other_dict["a"].copy_count) + assert_equal(1, other_dict["b"].copy_count) + + def main(): test_dict() test_dict_fromkeys() @@ -574,3 +596,4 @@ def main(): test_find_get() test_clear() test_init_initial_capacity() + test_dict_setdefault()