From c84c3292f783d67fda0beb1debf0a70f50460630 Mon Sep 17 00:00:00 2001 From: modularbot <116839051+modularbot@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:19:36 -0600 Subject: [PATCH] [External] [stdlib] Add `List[Scalar[D]].extend()` from `SIMD` and `Span[Scalar[D]]` (#52584) [External] [stdlib] Add `List[Scalar[D]].extend()` from `SIMD` and `Span[Scalar[D]]` Add `List[Scalar[D]].extend()` from `SIMD` and `Span[Scalar[D]]` Split off from modularml/mojo#3814. This is needed to enable efficient appending of scalar value sequences to a `List` without having to resort to `UnsafePointer` manually. ORIGINAL_AUTHOR=martinvuyk <110240700+martinvuyk@users.noreply.github.com> PUBLIC_PR_LINK=modularml/mojo#3854 Co-authored-by: martinvuyk <110240700+martinvuyk@users.noreply.github.com> Co-authored-by: Connor Gray Closes modularml/mojo#3854 MODULAR_ORIG_COMMIT_REV_ID: 7d0c724497ba0671ae660f4de5758d6c4baad7bc --- docs/changelog.md | 5 ++ stdlib/src/base64/_b64encode.mojo | 28 +---------- stdlib/src/collections/list.mojo | 64 +++++++++++++++++++++++++- stdlib/src/testing/testing.mojo | 39 +++++++++++++++- stdlib/src/utils/inline_string.mojo | 25 +++------- stdlib/test/collections/test_list.mojo | 48 ++++++++++--------- 6 files changed, 140 insertions(+), 69 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 63593a6da6..ad92ac49c1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -71,6 +71,11 @@ what we publish. var ptr2 = list2.unsafe_ptr() ``` +- Added new `List.extend()` overloads taking `SIMD` and `Span`. These enable + growing a `List[Scalar[..]]` by copying the elements of a `SIMD` vector or + `Span[Scalar[..]]`, simplifying the writing of some optimized SIMD-aware + functionality. + - The `ExplicitlyCopyable` trait has changed to require a `fn copy(self) -> Self` method. 
Previously, an initializer with the signature `fn __init__(out self, *, other: Self)` had been required by diff --git a/stdlib/src/base64/_b64encode.mojo b/stdlib/src/base64/_b64encode.mojo index 74b8c31501..35b7f4c5d9 100644 --- a/stdlib/src/base64/_b64encode.mojo +++ b/stdlib/src/base64/_b64encode.mojo @@ -195,21 +195,6 @@ fn load_incomplete_simd[ return result -fn store_incomplete_simd[ - simd_width: Int -]( - pointer: UnsafePointer[UInt8], - owned simd_vector: SIMD[DType.uint8, simd_width], - nb_of_elements_to_store: Int, -): - var tmp_buffer_pointer = UnsafePointer.address_of(simd_vector).bitcast[ - UInt8 - ]() - - memcpy(dest=pointer, src=tmp_buffer_pointer, count=nb_of_elements_to_store) - _ = simd_vector # We make it live long enough - - # TODO: Use Span instead of List as input when Span is easier to use @no_inline fn b64encode_with_buffers( @@ -229,11 +214,7 @@ fn b64encode_with_buffers( var input_vector = start_of_input_chunk.load[width=simd_width]() - result_vector = _to_b64_ascii(input_vector) - - (result.unsafe_ptr() + len(result)).store(result_vector) - - result.size += simd_width + result.extend(_to_b64_ascii(input_vector)) input_index += input_simd_width # We handle the last 0, 1 or 2 chunks @@ -268,12 +249,7 @@ fn b64encode_with_buffers( ]( nb_of_elements_to_load ) - store_incomplete_simd( - result.unsafe_ptr() + len(result), - result_vector_with_equals, - nb_of_elements_to_store, - ) - result.size += nb_of_elements_to_store + result.extend(result_vector_with_equals, count=nb_of_elements_to_store) input_index += input_simd_width diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 5732bc8da5..f537f7151f 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -497,9 +497,13 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( Args: value: The value to append. + + Notes: + If there is no capacity left, resizes to twice the current capacity. 
+ Except for 0 capacity where it sets 1. """ if self.size >= self.capacity: - self._realloc(max(1, self.capacity * 2)) + self._realloc(self.capacity * 2 | int(self.capacity == 0)) (self.data + self.size).init_pointee_move(value^) self.size += 1 @@ -591,6 +595,64 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( # list. self.size = final_size + fn extend[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]): + """Extends this list with the elements of a vector. + + Parameters: + D: The DType. + + Args: + value: The value to append. + + Notes: + If there is no capacity left, resizes to `len(self) + value.size`. + """ + self.reserve(self.size + value.size) + (self.data + self.size).store(value) + self.size += value.size + + fn extend[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], *, count: Int): + """Extends this list with `count` number of elements from a vector. + + Parameters: + D: The DType. + + Args: + value: The value to append. + count: The amount of items to append. Must be less than or equal to + `value.size`. + + Notes: + If there is no capacity left, resizes to `len(self) + count`. + """ + debug_assert(count <= value.size, "count must be <= value.size") + self.reserve(self.size + count) + var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]() + memcpy(self.data + self.size, v_ptr, count) + self.size += count + + fn extend[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]): + """Extends this list with the elements of a `Span`. + + Parameters: + D: The DType. + + Args: + value: The value to append. + + Notes: + If there is no capacity left, resizes to `len(self) + len(value)`. + """ + self.reserve(self.size + len(value)) + memcpy(self.data + self.size, value.unsafe_ptr(), len(value)) + self.size += len(value) + fn pop(mut self, i: Int = -1) -> T: """Pops a value from the list at the given index. 
diff --git a/stdlib/src/testing/testing.mojo b/stdlib/src/testing/testing.mojo index 20173be736..ef39769ff8 100644 --- a/stdlib/src/testing/testing.mojo +++ b/stdlib/src/testing/testing.mojo @@ -32,7 +32,7 @@ def main(): """ from collections import Optional from math import isclose - +from memory import memcmp from builtin._location import __call_location, _SourceLocation # ===----------------------------------------------------------------------=== # @@ -236,6 +236,43 @@ fn assert_equal[ ) +@always_inline +fn assert_equal[ + D: DType +]( + lhs: List[Scalar[D]], + rhs: List[Scalar[D]], + msg: String = "", + *, + location: Optional[_SourceLocation] = None, +) raises: + """Asserts that two lists are equal. + + Parameters: + D: A DType. + + Args: + lhs: The left-hand side list. + rhs: The right-hand side list. + msg: The message to be printed if the assertion fails. + location: The location of the error (default to the `__call_location`). + + Raises: + An Error with the provided message if assert fails and `None` otherwise. + """ + var length = len(lhs) + if ( + length != len(rhs) + or memcmp(lhs.unsafe_ptr(), rhs.unsafe_ptr(), length) != 0 + ): + raise _assert_cmp_error["`left == right` comparison"]( + lhs.__str__(), + rhs.__str__(), + msg=msg, + loc=location.or_else(__call_location()), + ) + + @always_inline fn assert_not_equal[ T: Testable diff --git a/stdlib/src/utils/inline_string.mojo b/stdlib/src/utils/inline_string.mojo index 5fb4687089..9be3326019 100644 --- a/stdlib/src/utils/inline_string.mojo +++ b/stdlib/src/utils/inline_string.mojo @@ -147,28 +147,15 @@ struct InlineString(Sized, Stringable, CollectionElement, CollectionElementNew): # Begin by heap allocating enough space to store the combined # string. 
var buffer = List[UInt8](capacity=total_len) - # Copy the bytes from the current small string layout - memcpy( - dest=buffer.unsafe_ptr(), - src=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(), - count=len(self), + var span_self = Span[Byte, __origin_of(self)]( + ptr=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(), + length=len(self), ) - + buffer.extend(span_self) # Copy the bytes from the additional string. - memcpy( - dest=buffer.unsafe_ptr() + len(self), - src=str_slice.unsafe_ptr(), - count=str_slice.byte_length(), - ) - - # Record that we've initialized `total_len` count of elements - # in `buffer` - buffer.size = total_len - - # Add the NUL byte - buffer.append(0) - + buffer.extend(str_slice.as_bytes()) + buffer.append(0) # Add the NUL byte self._storage = Self.Layout(String(buffer^)) fn __add__(self, other: StringLiteral) -> Self: diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index 9f45d66f5a..c6b3e4fb1e 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -437,32 +437,35 @@ def test_list_index(): _ = test_list_b.index(20, start=4, stop=5) -def test_list_extend(): - # - # Test extending the list [1, 2, 3] with itself - # +def test_list_append(): + var items = List[UInt32]() + items.append(1) + items.append(2) + items.append(3) + assert_equal(items, List[UInt32](1, 2, 3)) - vec = List[Int]() - vec.append(1) - vec.append(2) - vec.append(3) - assert_equal(len(vec), 3) - assert_equal(vec[0], 1) - assert_equal(vec[1], 2) - assert_equal(vec[2], 3) +def test_list_extend(): + var items = List[UInt32](1, 2, 3) + var copy = items + items.extend(copy) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3)) + + items = List[UInt32](1, 2, 3) + copy = List[UInt32](1, 2, 3) - var copy = vec - vec.extend(copy) + # Extend with span + items.extend(Span(copy)) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3)) - # vec == [1, 2, 3, 1, 2, 3] - assert_equal(len(vec), 
6) - assert_equal(vec[0], 1) - assert_equal(vec[1], 2) - assert_equal(vec[2], 3) - assert_equal(vec[3], 1) - assert_equal(vec[4], 2) - assert_equal(vec[5], 3) + # Extend with whole SIMD + items = List[UInt32](1, 2, 3) + items.extend(SIMD[DType.uint32, 4](1, 2, 3, 4)) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3, 4)) + # Extend with part of SIMD + items = List[UInt32](1, 2, 3) + items.extend(SIMD[DType.uint32, 4](1, 2, 3, 4), count=3) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3)) def test_list_extend_non_trivial(): @@ -952,6 +955,7 @@ def main(): test_list_reverse_move_count() test_list_insert() test_list_index() + test_list_append() test_list_extend() test_list_extend_non_trivial() test_list_explicit_copy()