From f411b5bef079a95c782986acd924a7054d6286fe Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 10:18:18 -0300 Subject: [PATCH 1/7] Add vectorized Span.apply() Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 72 +++++++++++++++++++++++++++++++ stdlib/test/memory/test_span.mojo | 41 ++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 03f860f899..55d9981f9c 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -23,6 +23,7 @@ from memory import Span from collections import InlineArray from memory import Pointer, UnsafePointer +from sys.info import simdwidthof trait AsBytes: @@ -371,3 +372,74 @@ struct Span[ return Span[T, ImmutableOrigin.cast_from[origin].result]( ptr=self._data, length=self._len ) + + fn apply[ + D: DType, + O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], + ](mut self: Span[Scalar[D], O]): + """Apply the function to the `Span` inplace. + + Parameters: + D: The DType. + O: The origin of the `Span`. + func: The function to evaluate. + """ + + alias widths = (256, 128, 64, 32, 16, 8, 4) + var ptr = self.unsafe_ptr() + var length = len(self) + var processed = 0 + + @parameter + for i in range(len(widths)): + alias w = widths.get[i, Int]() + + @parameter + if simdwidthof[D]() >= w: + for _ in range((length - processed) // w): + var p_curr = ptr + processed + p_curr.store(func(p_curr.load[width=w]())) + processed += w + + for i in range(length - processed): + (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) + + fn apply[ + D: DType, + O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], + *, + where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], + ](mut self: Span[Scalar[D], O]): + """Apply the function to the `Span` inplace where the condition is + `True`. + + Parameters: + D: The DType. + O: The origin of the `Span`. + func: The function to evaluate. + where: The condition to apply the function. + """ + + alias widths = (256, 128, 64, 32, 16, 8, 4) + var ptr = self.unsafe_ptr() + var length = len(self) + var processed = 0 + + @parameter + for i in range(len(widths)): + alias w = widths.get[i, Int]() + + @parameter + if simdwidthof[D]() >= w: + for _ in range((length - processed) // w): + var p_curr = ptr + processed + var vec = p_curr.load[width=w]() + p_curr.store(where(vec).select(func(vec), vec)) + processed += w + + for i in range(length - processed): + var vec = ptr[processed + i] + if where(vec): + (ptr + processed + i).init_pointee_move(func(vec)) diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo index 92c49210c6..3eb10980ae 100644 --- a/stdlib/test/memory/test_span.mojo +++ b/stdlib/test/memory/test_span.mojo @@ -199,6 +199,46 @@ def test_reversed(): i += 1 +def test_apply(): + fn _twice[D: DType, w: Int](x: SIMD[D, w]) -> SIMD[D, w]: + return x * 2 + + fn _where[D: DType, w: Int](x: SIMD[D, w]) -> SIMD[DType.bool, w]: + return x % 2 == 0 + + def _test[D: DType](): + items = List[Scalar[D]]( + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 + ) + twice = items + span = Span(twice) + span.apply[func = _twice[D]]() + for i in range(len(items)): + assert_true(span[i] == items[i] * 2) + + # twice only even numbers + twice = items + span = Span(twice) + span.apply[func = _twice[D], where = _where[D]]() + for i in range(len(items)): + if items[i] % 2 == 0: + assert_true(span[i] == items[i] * 2) + else: + assert_true(span[i] == items[i]) + + _test[DType.uint8]() + _test[DType.uint16]() + _test[DType.uint32]() + _test[DType.uint64]() + _test[DType.int8]() + _test[DType.int16]() + _test[DType.int32]() + _test[DType.int64]() + _test[DType.float16]() + _test[DType.float32]() + _test[DType.float64]() + + def main(): test_span_list_int() test_span_list_str() @@ -211,3 +251,4 @@ def main(): test_fill() test_ref() test_reversed() + test_apply() From 18ede12ef5eb05ba5b71f33c8c832d7bf7fe17cb Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 10:30:57 -0300 Subject: [PATCH 2/7] move functions to argument space instead of param Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 18 ++++++++++++------ stdlib/test/memory/test_span.mojo | 4 ++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 55d9981f9c..6001743e6b 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -374,15 +374,18 @@ struct Span[ ) fn apply[ - D: DType, - O: MutableOrigin, //, + D: DType, O: MutableOrigin, // + ]( + mut self: Span[Scalar[D], O], func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], - ](mut self: Span[Scalar[D], O]): + ): """Apply the function to the `Span` inplace. Parameters: D: The DType. O: The origin of the `Span`. + + Args: func: The function to evaluate. """ @@ -406,18 +409,21 @@ struct Span[ (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) fn apply[ - D: DType, - O: MutableOrigin, //, + D: DType, O: MutableOrigin, // + ]( + mut self: Span[Scalar[D], O], func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], *, where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], - ](mut self: Span[Scalar[D], O]): + ): """Apply the function to the `Span` inplace where the condition is `True`. Parameters: D: The DType. O: The origin of the `Span`. + + Args: func: The function to evaluate. where: The condition to apply the function. """ diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo index 3eb10980ae..8990a8a3d2 100644 --- a/stdlib/test/memory/test_span.mojo +++ b/stdlib/test/memory/test_span.mojo @@ -212,14 +212,14 @@ def test_apply(): ) twice = items span = Span(twice) - span.apply[func = _twice[D]]() + span.apply(_twice[D]) for i in range(len(items)): assert_true(span[i] == items[i] * 2) # twice only even numbers twice = items span = Span(twice) - span.apply[func = _twice[D], where = _where[D]]() + span.apply(_twice[D], where=_where[D]) for i in range(len(items)): if items[i] % 2 == 0: assert_true(span[i] == items[i] * 2) From 22f064353bbbe6733571c8632ab1bf7d2385d677 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 10:38:12 -0300 Subject: [PATCH 3/7] revert take funcs back to parameter space Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 14 +++++--------- stdlib/test/memory/test_span.mojo | 4 ++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 6001743e6b..37eda6a8f6 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -374,18 +374,16 @@ struct Span[ ) fn apply[ - D: DType, O: MutableOrigin, // + D: DType, O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], ]( mut self: Span[Scalar[D], O], - func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], ): """Apply the function to the `Span` inplace. Parameters: D: The DType. O: The origin of the `Span`. - - Args: func: The function to evaluate. """ @@ -409,12 +407,12 @@ struct Span[ (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) fn apply[ - D: DType, O: MutableOrigin, // - ]( - mut self: Span[Scalar[D], O], + D: DType, O: MutableOrigin, //, func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], *, where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], + ]( + mut self: Span[Scalar[D], O], ): """Apply the function to the `Span` inplace where the condition is `True`. @@ -422,8 +420,6 @@ struct Span[ Parameters: D: The DType. O: The origin of the `Span`. - - Args: func: The function to evaluate. where: The condition to apply the function. """ diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo index 8990a8a3d2..8db120951b 100644 --- a/stdlib/test/memory/test_span.mojo +++ b/stdlib/test/memory/test_span.mojo @@ -212,14 +212,14 @@ def test_apply(): ) twice = items span = Span(twice) - span.apply(_twice[D]) + span.apply[func=_twice[D]]() for i in range(len(items)): assert_true(span[i] == items[i] * 2) # twice only even numbers twice = items span = Span(twice) - span.apply(_twice[D], where=_where[D]) + span.apply[func=_twice[D], where=_where[D]]() for i in range(len(items)): if items[i] % 2 == 0: assert_true(span[i] == items[i] * 2) From 15df64b202000f05d4923325a1c437620edda233 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 10:38:55 -0300 Subject: [PATCH 4/7] mojo format Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 37eda6a8f6..8e70d8807a 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -376,9 +376,7 @@ struct Span[ fn apply[ D: DType, O: MutableOrigin, //, func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], - ]( - mut self: Span[Scalar[D], O], - ): + ](mut self: Span[Scalar[D], O]): """Apply the function to the `Span` inplace. Parameters: @@ -411,9 +409,7 @@ struct Span[ func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], *, where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], - ]( - mut self: Span[Scalar[D], O], - ): + ](mut self: Span[Scalar[D], O]): """Apply the function to the `Span` inplace where the condition is `True`. From a674810ec67b7d73794f689ad62aa59c12d8b552 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 10:48:25 -0300 Subject: [PATCH 5/7] mojo format Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 123 ++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 60 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 8e70d8807a..71f3c9570f 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -373,71 +373,74 @@ struct Span[ ptr=self._data, length=self._len ) - fn apply[ - D: DType, O: MutableOrigin, //, - func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], - ](mut self: Span[Scalar[D], O]): - """Apply the function to the `Span` inplace. +fn apply[ + D: DType, + O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], +](mut self: Span[Scalar[D], O]): + """Apply the function to the `Span` inplace. - Parameters: - D: The DType. - O: The origin of the `Span`. - func: The function to evaluate. - """ + Parameters: + D: The DType. + O: The origin of the `Span`. + func: The function to evaluate. + """ + + alias widths = (256, 128, 64, 32, 16, 8, 4) + var ptr = self.unsafe_ptr() + var length = len(self) + var processed = 0 - alias widths = (256, 128, 64, 32, 16, 8, 4) - var ptr = self.unsafe_ptr() - var length = len(self) - var processed = 0 + @parameter + for i in range(len(widths)): + alias w = widths.get[i, Int]() @parameter - for i in range(len(widths)): - alias w = widths.get[i, Int]() - - @parameter - if simdwidthof[D]() >= w: - for _ in range((length - processed) // w): - var p_curr = ptr + processed - p_curr.store(func(p_curr.load[width=w]())) - processed += w - - for i in range(length - processed): - (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) - - fn apply[ - D: DType, O: MutableOrigin, //, - func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], - *, - where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], - ](mut self: Span[Scalar[D], O]): - """Apply the function to the `Span` inplace where the condition is - `True`. + if simdwidthof[D]() >= w: + for _ in range((length - processed) // w): + var p_curr = ptr + processed + p_curr.store(func(p_curr.load[width=w]())) + processed += w + + for i in range(length - processed): + (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) + + +fn apply[ + D: DType, + O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], + *, + where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], +](mut self: Span[Scalar[D], O]): + """Apply the function to the `Span` inplace where the condition is + `True`. - Parameters: - D: The DType. - O: The origin of the `Span`. - func: The function to evaluate. - where: The condition to apply the function. - """ + Parameters: + D: The DType. + O: The origin of the `Span`. + func: The function to evaluate. + where: The condition to apply the function. + """ + + alias widths = (256, 128, 64, 32, 16, 8, 4) + var ptr = self.unsafe_ptr() + var length = len(self) + var processed = 0 - alias widths = (256, 128, 64, 32, 16, 8, 4) - var ptr = self.unsafe_ptr() - var length = len(self) - var processed = 0 + @parameter + for i in range(len(widths)): + alias w = widths.get[i, Int]() @parameter - for i in range(len(widths)): - alias w = widths.get[i, Int]() - - @parameter - if simdwidthof[D]() >= w: - for _ in range((length - processed) // w): - var p_curr = ptr + processed - var vec = p_curr.load[width=w]() - p_curr.store(where(vec).select(func(vec), vec)) - processed += w - - for i in range(length - processed): - var vec = ptr[processed + i] - if where(vec): - (ptr + processed + i).init_pointee_move(func(vec)) + if simdwidthof[D]() >= w: + for _ in range((length - processed) // w): + var p_curr = ptr + processed + var vec = p_curr.load[width=w]() + p_curr.store(where(vec).select(func(vec), vec)) + processed += w + + for i in range(length - processed): + var vec = ptr[processed + i] + if where(vec): + (ptr + processed + i).init_pointee_move(func(vec)) From 4bb7c83e1367f40a8a9242d6bd08a5ecc4da3777 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 10:48:36 -0300 Subject: [PATCH 6/7] mojo format Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 126 ++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 71f3c9570f..f4a0a44cad 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -373,74 +373,74 @@ struct Span[ ptr=self._data, length=self._len ) -fn apply[ - D: DType, - O: MutableOrigin, //, - func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], -](mut self: Span[Scalar[D], O]): - """Apply the function to the `Span` inplace. + fn apply[ + D: DType, + O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], + ](mut self: Span[Scalar[D], O]): + """Apply the function to the `Span` inplace. - Parameters: - D: The DType. - O: The origin of the `Span`. - func: The function to evaluate. - """ - - alias widths = (256, 128, 64, 32, 16, 8, 4) - var ptr = self.unsafe_ptr() - var length = len(self) - var processed = 0 + Parameters: + D: The DType. + O: The origin of the `Span`. + func: The function to evaluate. + """ - @parameter - for i in range(len(widths)): - alias w = widths.get[i, Int]() + alias widths = (256, 128, 64, 32, 16, 8, 4) + var ptr = self.unsafe_ptr() + var length = len(self) + var processed = 0 @parameter - if simdwidthof[D]() >= w: - for _ in range((length - processed) // w): - var p_curr = ptr + processed - p_curr.store(func(p_curr.load[width=w]())) - processed += w - - for i in range(length - processed): - (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) - - -fn apply[ - D: DType, - O: MutableOrigin, //, - func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], - *, - where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], -](mut self: Span[Scalar[D], O]): - """Apply the function to the `Span` inplace where the condition is - `True`. + for i in range(len(widths)): + alias w = widths.get[i, Int]() + + @parameter + if simdwidthof[D]() >= w: + for _ in range((length - processed) // w): + var p_curr = ptr + processed + p_curr.store(func(p_curr.load[width=w]())) + processed += w + + for i in range(length - processed): + (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) + + + fn apply[ + D: DType, + O: MutableOrigin, //, + func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w], + *, + where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w], + ](mut self: Span[Scalar[D], O]): + """Apply the function to the `Span` inplace where the condition is + `True`. - Parameters: - D: The DType. - O: The origin of the `Span`. - func: The function to evaluate. - where: The condition to apply the function. - """ - - alias widths = (256, 128, 64, 32, 16, 8, 4) - var ptr = self.unsafe_ptr() - var length = len(self) - var processed = 0 + Parameters: + D: The DType. + O: The origin of the `Span`. + func: The function to evaluate. + where: The condition to apply the function. + """ - @parameter - for i in range(len(widths)): - alias w = widths.get[i, Int]() + alias widths = (256, 128, 64, 32, 16, 8, 4) + var ptr = self.unsafe_ptr() + var length = len(self) + var processed = 0 @parameter - if simdwidthof[D]() >= w: - for _ in range((length - processed) // w): - var p_curr = ptr + processed - var vec = p_curr.load[width=w]() - p_curr.store(where(vec).select(func(vec), vec)) - processed += w - - for i in range(length - processed): - var vec = ptr[processed + i] - if where(vec): - (ptr + processed + i).init_pointee_move(func(vec)) + for i in range(len(widths)): + alias w = widths.get[i, Int]() + + @parameter + if simdwidthof[D]() >= w: + for _ in range((length - processed) // w): + var p_curr = ptr + processed + var vec = p_curr.load[width=w]() + p_curr.store(where(vec).select(func(vec), vec)) + processed += w + + for i in range(length - processed): + var vec = ptr[processed + i] + if where(vec): + (ptr + processed + i).init_pointee_move(func(vec)) From 42e0bf936d150c5870fd481ddb04fdc05bcb92a6 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Fri, 13 Dec 2024 14:26:14 -0300 Subject: [PATCH 7/7] mojo format Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 1 - stdlib/test/memory/test_span.mojo | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index f4a0a44cad..55d9981f9c 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -405,7 +405,6 @@ struct Span[ for i in range(length - processed): (ptr + processed + i).init_pointee_move(func(ptr[processed + i])) - fn apply[ D: DType, O: MutableOrigin, //, diff --git a/stdlib/test/memory/test_span.mojo b/stdlib/test/memory/test_span.mojo index 8db120951b..3eb10980ae 100644 --- a/stdlib/test/memory/test_span.mojo +++ b/stdlib/test/memory/test_span.mojo @@ -212,14 +212,14 @@ def test_apply(): ) twice = items span = Span(twice) - span.apply[func=_twice[D]]() + span.apply[func = _twice[D]]() for i in range(len(items)): assert_true(span[i] == items[i] * 2) # twice only even numbers twice = items span = Span(twice) - span.apply[func=_twice[D], where=_where[D]]() + span.apply[func = _twice[D], where = _where[D]]() for i in range(len(items)): if items[i] % 2 == 0: assert_true(span[i] == items[i] * 2)