Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add List[Scalar[D]] append SIMD and Span[Scalar[D]] #3854

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions stdlib/src/base64/_b64encode.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,6 @@ fn load_incomplete_simd[
return result


fn store_incomplete_simd[
simd_width: Int
](
pointer: UnsafePointer[UInt8],
owned simd_vector: SIMD[DType.uint8, simd_width],
nb_of_elements_to_store: Int,
):
var tmp_buffer_pointer = UnsafePointer.address_of(simd_vector).bitcast[
UInt8
]()

memcpy(dest=pointer, src=tmp_buffer_pointer, count=nb_of_elements_to_store)
_ = simd_vector # We make it live long enough


# TODO: Use Span instead of List as input when Span is easier to use
@no_inline
fn b64encode_with_buffers(
Expand All @@ -229,11 +214,7 @@ fn b64encode_with_buffers(

var input_vector = start_of_input_chunk.load[width=simd_width]()

result_vector = _to_b64_ascii(input_vector)

(result.unsafe_ptr() + len(result)).store(result_vector)

result.size += simd_width
result.append(_to_b64_ascii(input_vector))
input_index += input_simd_width

# We handle the last 0, 1 or 2 chunks
Expand Down Expand Up @@ -268,12 +249,7 @@ fn b64encode_with_buffers(
](
nb_of_elements_to_load
)
store_incomplete_simd(
result.unsafe_ptr() + len(result),
result_vector_with_equals,
nb_of_elements_to_store,
)
result.size += nb_of_elements_to_store
result.append(result_vector_with_equals, nb_of_elements_to_store)
input_index += input_simd_width


Expand Down
63 changes: 62 additions & 1 deletion stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,73 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False](

Args:
value: The value to append.

Notes:
If there is no capacity left, resizes to twice the current capacity.
Except for 0 capacity where it sets 1.
"""
if self.size >= self.capacity:
self._realloc(max(1, self.capacity * 2))
self._realloc(self.capacity * 2 | int(self.capacity == 0))
(self.data + self.size).init_pointee_move(value^)
self.size += 1

fn append[
D: DType, //
](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]):
"""Appends a vector to this list.

Parameters:
D: The DType.

Args:
value: The value to append.

Notes:
If there is no capacity left, resizes to `len(self) + value.size`.
"""
self.reserve(self.size + value.size)
(self.data + self.size).store(value)
self.size += value.size

fn append[
D: DType, //
](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], count: Int):
"""Appends a vector to this list.

Parameters:
D: The DType.

Args:
value: The value to append.
count: The ammount of items to append.

Notes:
If there is no capacity left, resizes to `len(self) + count`.
"""
debug_assert(count <= value.size, "count must be <= value.size")
self.reserve(self.size + count)
var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]()
memcpy(self.data + self.size, v_ptr, count)
martinvuyk marked this conversation as resolved.
Show resolved Hide resolved
self.size += count

fn append[
D: DType, //
](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]):
"""Appends a Span to this list.

Parameters:
D: The DType.

Args:
value: The value to append.

Notes:
If there is no capacity left, resizes to `len(self) + len(value)`.
"""
self.reserve(self.size + len(value))
memcpy(self.data + self.size, value.unsafe_ptr(), len(value))
self.size += len(value)

fn insert(mut self, i: Int, owned value: T):
"""Inserts a value to the list at the given index.
`a.insert(len(a), value)` is equivalent to `a.append(value)`.
Expand Down
39 changes: 38 additions & 1 deletion stdlib/src/testing/testing.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main():
"""
from collections import Optional
from math import isclose

from memory import memcmp
from builtin._location import __call_location, _SourceLocation

# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -236,6 +236,43 @@ fn assert_equal[
)


@always_inline
fn assert_equal[
D: DType
](
lhs: List[Scalar[D]],
rhs: List[Scalar[D]],
msg: String = "",
*,
location: Optional[_SourceLocation] = None,
) raises:
"""Asserts that two lists are equal.

Parameters:
D: A DType.

Args:
lhs: The left-hand side list.
rhs: The right-hand side list.
msg: The message to be printed if the assertion fails.
location: The location of the error (default to the `__call_location`).

Raises:
An Error with the provided message if assert fails and `None` otherwise.
"""
var length = len(lhs)
if (
length != len(rhs)
or memcmp(lhs.unsafe_ptr(), rhs.unsafe_ptr(), length) != 0
):
raise _assert_cmp_error["`left == right` comparison"](
lhs.__str__(),
rhs.__str__(),
msg=msg,
loc=location.or_else(__call_location()),
)


@always_inline
fn assert_not_equal[
T: Testable
Expand Down
25 changes: 6 additions & 19 deletions stdlib/src/utils/inline_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -147,28 +147,15 @@ struct InlineString(Sized, Stringable, CollectionElement, CollectionElementNew):
# Begin by heap allocating enough space to store the combined
# string.
var buffer = List[UInt8](capacity=total_len)

# Copy the bytes from the current small string layout
memcpy(
dest=buffer.unsafe_ptr(),
src=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(),
count=len(self),
var span_self = Span[Byte, __origin_of(self)](
ptr=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(),
length=len(self),
)

buffer.append(span_self)
# Copy the bytes from the additional string.
memcpy(
dest=buffer.unsafe_ptr() + len(self),
src=str_slice.unsafe_ptr(),
count=str_slice.byte_length(),
)

# Record that we've initialized `total_len` count of elements
# in `buffer`
buffer.size = total_len

# Add the NUL byte
buffer.append(0)

buffer.append(str_slice.as_bytes())
buffer.append(0) # Add the NUL byte
Copy link
Collaborator

@ConnorGray ConnorGray Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is a great cleanup 🙂

In theory this might be slightly less efficient since the append() will need to perform redundant bounds-checking. But I think that's a small cost that we shouldn't worry too much about optimizing yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory this might be slightly less efficient since the append() will need to perform redundant bounds-checking.

We could add a unsafe_no_checks parameter to all the append functions at some point. That way we'd stop having to resort to UnsafePointer so much for such scenarios.

self._storage = Self.Layout(String(buffer^))

fn __add__(self, other: StringLiteral) -> Self:
Expand Down
50 changes: 26 additions & 24 deletions stdlib/test/collections/test_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -437,32 +437,33 @@ def test_list_index():
_ = test_list_b.index(20, start=4, stop=5)


def test_list_extend():
#
# Test extending the list [1, 2, 3] with itself
#
def test_list_append():
items = List[UInt32]()
items.append(1)
items.append(2)
items.append(3)
assert_equal(items, List[UInt32](1, 2, 3))

# append span
copy = items
items.append(Span(copy))
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3))

# whole SIMD
items = List[UInt32](1, 2, 3)
items.append(SIMD[DType.uint32, 4](1, 2, 3, 4))
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3, 4))
# part of SIMD
items = List[UInt32](1, 2, 3)
items.append(SIMD[DType.uint32, 4](1, 2, 3, 4), 3)
assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3))

vec = List[Int]()
vec.append(1)
vec.append(2)
vec.append(3)

assert_equal(len(vec), 3)
assert_equal(vec[0], 1)
assert_equal(vec[1], 2)
assert_equal(vec[2], 3)

var copy = vec
vec.extend(copy)

# vec == [1, 2, 3, 1, 2, 3]
assert_equal(len(vec), 6)
assert_equal(vec[0], 1)
assert_equal(vec[1], 2)
assert_equal(vec[2], 3)
assert_equal(vec[3], 1)
assert_equal(vec[4], 2)
assert_equal(vec[5], 3)
def test_list_extend():
items = List[Int](1, 2, 3)
copy = items
items.extend(copy)
assert_equal(items, List[Int](1, 2, 3, 1, 2, 3))


def test_list_extend_non_trivial():
Expand Down Expand Up @@ -952,6 +953,7 @@ def main():
test_list_reverse_move_count()
test_list_insert()
test_list_index()
test_list_append()
test_list_extend()
test_list_extend_non_trivial()
test_list_explicit_copy()
Expand Down
3 changes: 2 additions & 1 deletion stdlib/test/python/my_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, bar):

class AbstractPerson(ABC):
@abstractmethod
def method(self): ...
def method(self):
...


def my_function(name):
Expand Down
Loading