Skip to content

Commit

Permalink
[External] [stdlib] Add List[Scalar[D]].extend() from SIMD and `S…
Browse files Browse the repository at this point in the history
…pan[Scalar[D]]` (#52584)

[External] [stdlib] Add `List[Scalar[D]].extend()` from `SIMD` and
`Span[Scalar[D]]`

Add `List[Scalar[D]].extend()` from `SIMD` and `Span[Scalar[D]]`

Split off from modularml#3814. This is needed to enable efficient
appending of scalar value sequences to a `List` without having to resort
to `UnsafePointer` manually.

ORIGINAL_AUTHOR=martinvuyk
<[email protected]>
PUBLIC_PR_LINK=modularml#3854

Co-authored-by: martinvuyk <[email protected]>
Co-authored-by: Connor Gray <[email protected]>
Closes modularml#3854
MODULAR_ORIG_COMMIT_REV_ID: 7d0c724497ba0671ae660f4de5758d6c4baad7bc
  • Loading branch information
3 people authored and msaelices committed Dec 21, 2024
1 parent e69f028 commit c84c329
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 69 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ what we publish.
var ptr2 = list2.unsafe_ptr()
```
- Added new `List.extend()` overloads taking `SIMD` and `Span`. These enable
growing a `List[Scalar[..]]` by copying the elements of a `SIMD` vector or
`Span[Scalar[..]]`, simplifying the writing of some optimized SIMD-aware
functionality.
- The `ExplicitlyCopyable` trait has changed to require a
`fn copy(self) -> Self` method. Previously, an initializer with the signature
`fn __init__(out self, *, other: Self)` had been required by
Expand Down
28 changes: 2 additions & 26 deletions stdlib/src/base64/_b64encode.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,6 @@ fn load_incomplete_simd[
return result


fn store_incomplete_simd[
    simd_width: Int
](
    pointer: UnsafePointer[UInt8],
    owned simd_vector: SIMD[DType.uint8, simd_width],
    nb_of_elements_to_store: Int,
):
    """Copies the first bytes of a SIMD vector to a destination pointer.

    Parameters:
        simd_width: The number of lanes in `simd_vector`.

    Args:
        pointer: The destination the bytes are written to.
        simd_vector: The vector whose first bytes are copied.
        nb_of_elements_to_store: How many bytes to copy, counted from the
            start of `simd_vector`.
    """
    # View the vector's own storage as a plain byte buffer.
    var src_bytes = UnsafePointer.address_of(simd_vector).bitcast[UInt8]()
    memcpy(dest=pointer, src=src_bytes, count=nb_of_elements_to_store)
    # Keep `simd_vector` alive until the copy above has completed.
    _ = simd_vector


# TODO: Use Span instead of List as input when Span is easier to use
@no_inline
fn b64encode_with_buffers(
Expand All @@ -229,11 +214,7 @@ fn b64encode_with_buffers(

var input_vector = start_of_input_chunk.load[width=simd_width]()

result_vector = _to_b64_ascii(input_vector)

(result.unsafe_ptr() + len(result)).store(result_vector)

result.size += simd_width
result.extend(_to_b64_ascii(input_vector))
input_index += input_simd_width

# We handle the last 0, 1 or 2 chunks
Expand Down Expand Up @@ -268,12 +249,7 @@ fn b64encode_with_buffers(
](
nb_of_elements_to_load
)
store_incomplete_simd(
result.unsafe_ptr() + len(result),
result_vector_with_equals,
nb_of_elements_to_store,
)
result.size += nb_of_elements_to_store
result.extend(result_vector_with_equals, count=nb_of_elements_to_store)
input_index += input_simd_width


Expand Down
64 changes: 63 additions & 1 deletion stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,13 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False](
Args:
value: The value to append.
Notes:
If there is no capacity left, resizes to twice the current capacity,
except when the capacity is 0, in which case it is set to 1.
"""
if self.size >= self.capacity:
self._realloc(max(1, self.capacity * 2))
self._realloc(self.capacity * 2 | int(self.capacity == 0))
(self.data + self.size).init_pointee_move(value^)
self.size += 1

Expand Down Expand Up @@ -591,6 +595,64 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False](
# list.
self.size = final_size

fn extend[
    D: DType, //
](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]):
    """Appends every element of a SIMD vector to the end of this list.

    Parameters:
        D: The DType.

    Args:
        value: The vector whose elements are appended.

    Notes:
        Reserves capacity for `len(self) + value.size` elements before
        writing the vector.
    """
    var new_size = self.size + value.size
    self.reserve(new_size)
    # Write the whole vector directly past the current last element.
    (self.data + self.size).store(value)
    self.size = new_size

fn extend[
    D: DType, //
](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], *, count: Int):
    """Extends this list with `count` number of elements from a vector.

    Parameters:
        D: The DType.

    Args:
        value: The value to append.
        count: The amount of items to append. Must be non-negative and less
            than or equal to `value.size`.

    Notes:
        Reserves capacity for `len(self) + count` elements before copying.
    """
    # A negative `count` would be handed to `memcpy` and would shrink
    # `self.size`, so both bounds are checked.
    debug_assert(
        count >= 0 and count <= value.size,
        "count must be >= 0 and <= value.size",
    )
    self.reserve(self.size + count)
    # Reinterpret the vector's storage as a run of scalars so that only the
    # first `count` of them are copied.
    var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]()
    memcpy(self.data + self.size, v_ptr, count)
    self.size += count

fn extend[
    D: DType, //
](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]):
    """Appends every element of a `Span` to the end of this list.

    Parameters:
        D: The DType.

    Args:
        value: The span whose elements are appended.

    Notes:
        Reserves capacity for `len(self) + len(value)` elements before
        copying.
    """
    var added = len(value)
    self.reserve(self.size + added)
    # Copy the span's elements directly past the current last element.
    memcpy(self.data + self.size, value.unsafe_ptr(), added)
    self.size += added

fn pop(mut self, i: Int = -1) -> T:
"""Pops a value from the list at the given index.
Expand Down
39 changes: 38 additions & 1 deletion stdlib/src/testing/testing.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main():
"""
from collections import Optional
from math import isclose

from memory import memcmp
from builtin._location import __call_location, _SourceLocation

# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -236,6 +236,43 @@ fn assert_equal[
)


@always_inline
fn assert_equal[
    D: DType
](
    lhs: List[Scalar[D]],
    rhs: List[Scalar[D]],
    msg: String = "",
    *,
    location: Optional[_SourceLocation] = None,
) raises:
    """Asserts that two lists of scalars are equal.

    The lengths are compared first; the underlying buffers are then compared
    with `memcmp`.

    Parameters:
        D: A DType.

    Args:
        lhs: The left-hand side list.
        rhs: The right-hand side list.
        msg: The message to be printed if the assertion fails.
        location: The location of the error (defaults to the
            `__call_location`).

    Raises:
        An Error with the provided message if the assertion fails.
    """
    var length = len(lhs)
    # The buffers are only compared once the lengths are known to match.
    if length == len(rhs):
        if memcmp(lhs.unsafe_ptr(), rhs.unsafe_ptr(), length) == 0:
            return

    raise _assert_cmp_error["`left == right` comparison"](
        lhs.__str__(),
        rhs.__str__(),
        msg=msg,
        loc=location.or_else(__call_location()),
    )


@always_inline
fn assert_not_equal[
T: Testable
Expand Down
25 changes: 6 additions & 19 deletions stdlib/src/utils/inline_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -147,28 +147,15 @@ struct InlineString(Sized, Stringable, CollectionElement, CollectionElementNew):
# Begin by heap allocating enough space to store the combined
# string.
var buffer = List[UInt8](capacity=total_len)

# Copy the bytes from the current small string layout
memcpy(
dest=buffer.unsafe_ptr(),
src=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(),
count=len(self),
var span_self = Span[Byte, __origin_of(self)](
ptr=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(),
length=len(self),
)

buffer.extend(span_self)
# Copy the bytes from the additional string.
memcpy(
dest=buffer.unsafe_ptr() + len(self),
src=str_slice.unsafe_ptr(),
count=str_slice.byte_length(),
)

# Record that we've initialized `total_len` count of elements
# in `buffer`
buffer.size = total_len

# Add the NUL byte
buffer.append(0)

buffer.extend(str_slice.as_bytes())
buffer.append(0) # Add the NUL byte
self._storage = Self.Layout(String(buffer^))

fn __add__(self, other: StringLiteral) -> Self:
Expand Down
48 changes: 26 additions & 22 deletions stdlib/test/collections/test_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -437,32 +437,35 @@ def test_list_index():
_ = test_list_b.index(20, start=4, stop=5)


def test_list_extend():
#
# Test extending the list [1, 2, 3] with itself
#
def test_list_append():
# Append three values one at a time, then compare the result with the
# expected list in a single assertion.
var items = List[UInt32]()
items.append(1)
items.append(2)
items.append(3)
assert_equal(items, List[UInt32](1, 2, 3))

vec = List[Int]()
vec.append(1)
vec.append(2)
vec.append(3)

assert_equal(len(vec), 3)
assert_equal(vec[0], 1)
assert_equal(vec[1], 2)
assert_equal(vec[2], 3)
def test_list_extend():
var items = List[UInt32](1, 2, 3)
var copy = items
items.extend(copy)
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3))

items = List[UInt32](1, 2, 3)
copy = List[UInt32](1, 2, 3)

var copy = vec
vec.extend(copy)
# Extend with span
items.extend(Span(copy))
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3))

# vec == [1, 2, 3, 1, 2, 3]
assert_equal(len(vec), 6)
assert_equal(vec[0], 1)
assert_equal(vec[1], 2)
assert_equal(vec[2], 3)
assert_equal(vec[3], 1)
assert_equal(vec[4], 2)
assert_equal(vec[5], 3)
# Extend with whole SIMD
items = List[UInt32](1, 2, 3)
items.extend(SIMD[DType.uint32, 4](1, 2, 3, 4))
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3, 4))
# Extend with part of SIMD
items = List[UInt32](1, 2, 3)
items.extend(SIMD[DType.uint32, 4](1, 2, 3, 4), count=3)
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3))


def test_list_extend_non_trivial():
Expand Down Expand Up @@ -952,6 +955,7 @@ def main():
test_list_reverse_move_count()
test_list_insert()
test_list_index()
test_list_append()
test_list_extend()
test_list_extend_non_trivial()
test_list_explicit_copy()
Expand Down

0 comments on commit c84c329

Please sign in to comment.