diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo
index 2a0c2d020b..0ee5bff233 100644
--- a/stdlib/src/memory/span.mojo
+++ b/stdlib/src/memory/span.mojo
@@ -406,3 +406,99 @@ struct Span[
         return Span[T, ImmutableOrigin.cast_from[origin].result](
             ptr=self._data, length=self._len
         )
+
+    fn apply[
+        D: DType,
+        O: MutableOrigin, //,
+        func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
+    ](mut self: Span[Scalar[D], O]):
+        """Apply the function to every element of the `Span` in place.
+
+        Parameters:
+            D: The DType.
+            O: The origin of the `Span`.
+            func: The function to evaluate.
+
+        Notes:
+            `func` is evaluated on chunks of several SIMD widths and on single
+            elements for the tail, so it should treat every lane independently.
+        """
+
+        # Candidate chunk widths, widest first; widths above the target's
+        # native SIMD width for `D` are discarded at compile time.
+        alias widths = (256, 128, 64, 32, 16, 8, 4, 2)
+        var ptr = self.unsafe_ptr()
+        var length = len(self)
+        var processed = 0
+
+        @parameter
+        for i in range(len(widths)):
+            alias w = widths.get[i, Int]()
+
+            @parameter
+            if simdwidthof[D]() >= w:
+                # Consume as many whole `w`-wide chunks as still fit.
+                for _ in range((length - processed) // w):
+                    var p_curr = ptr + processed
+                    p_curr.store(func(p_curr.load[width=w]()))
+                    processed += w
+
+        # Scalar tail: whatever the SIMD chunks did not cover.
+        for j in range(length - processed):
+            var idx = processed + j
+            # The slot already holds a value, so assign instead of using
+            # `init_pointee_move`, which is meant for uninitialized memory.
+            ptr[idx] = func(ptr[idx])
+
+    fn apply[
+        D: DType,
+        O: MutableOrigin, //,
+        func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
+        *,
+        where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w],
+    ](mut self: Span[Scalar[D], O]):
+        """Apply the function to the `Span` in place where the condition is
+        `True`.
+
+        Parameters:
+            D: The DType.
+            O: The origin of the `Span`.
+            func: The function to evaluate.
+            where: The condition to apply the function.
+
+        Notes:
+            In the SIMD chunks `func` is evaluated on every lane and the result
+            is kept only where `where` is `True`; both functions should treat
+            every lane independently.
+        """
+
+        # Candidate chunk widths, widest first; widths above the target's
+        # native SIMD width for `D` are discarded at compile time.
+        alias widths = (256, 128, 64, 32, 16, 8, 4, 2)
+        var ptr = self.unsafe_ptr()
+        var length = len(self)
+        var processed = 0
+
+        @parameter
+        for i in range(len(widths)):
+            alias w = widths.get[i, Int]()
+
+            @parameter
+            if simdwidthof[D]() >= w:
+                # Consume as many whole `w`-wide chunks as still fit.
+                for _ in range((length - processed) // w):
+                    var p_curr = ptr + processed
+                    var vec = p_curr.load[width=w]()
+                    # Lane-wise blend: `func(vec)` where the mask is set,
+                    # the untouched input elsewhere.
+                    p_curr.store(where(vec).select(func(vec), vec))
+                    processed += w
+
+        # Scalar tail: whatever the SIMD chunks did not cover.
+        for j in range(length - processed):
+            var idx = processed + j
+            var value = ptr[idx]
+            if where(value):
+                # The slot already holds a value, so assign instead of using
+                # `init_pointee_move`, which is meant for uninitialized memory.
+                ptr[idx] = func(value)
diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo
index 4a3b6dd980..14a82aa830 100644
--- a/stdlib/test/memory/test_span.mojo
+++ b/stdlib/test/memory/test_span.mojo
@@ -208,6 +208,49 @@ def test_reversed():
         i += 1
 
 
+def test_apply():
+    fn _twice[D: DType, w: Int](x: SIMD[D, w]) -> SIMD[D, w]:
+        return x * 2
+
+    fn _where[D: DType, w: Int](x: SIMD[D, w]) -> SIMD[DType.bool, w]:
+        return x % 2 == 0
+
+    def _test[D: DType]():
+        # 19 elements: not a multiple of any chunk width, so the scalar tail
+        # of `apply` is exercised as well.
+        items = List[Scalar[D]](
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19
+        )
+        # double every element
+        twice = items
+        span = Span(twice)
+        span.apply[func = _twice[D]]()
+        for i in range(len(items)):
+            assert_true(span[i] == items[i] * 2)
+
+        # double only the even numbers
+        twice = items
+        span = Span(twice)
+        span.apply[func = _twice[D], where = _where[D]]()
+        for i in range(len(items)):
+            if items[i] % 2 == 0:
+                assert_true(span[i] == items[i] * 2)
+            else:
+                assert_true(span[i] == items[i])
+
+    _test[DType.uint8]()
+    _test[DType.uint16]()
+    _test[DType.uint32]()
+    _test[DType.uint64]()
+    _test[DType.int8]()
+    _test[DType.int16]()
+    _test[DType.int32]()
+    _test[DType.int64]()
+    _test[DType.float16]()
+    _test[DType.float32]()
+    _test[DType.float64]()
+
+
 def main():
     test_span_list_int()
     test_span_list_str()
@@ -221,3 +264,4 @@ def main():
     test_fill()
     test_ref()
     test_reversed()
+    test_apply()