diff --git a/stdlib/benchmarks/utils/bench_memmem.mojo b/stdlib/benchmarks/utils/bench_memmem.mojo
index d72a4bddb1..6fd0b16a89 100644
--- a/stdlib/benchmarks/utils/bench_memmem.mojo
+++ b/stdlib/benchmarks/utils/bench_memmem.mojo
@@ -18,7 +18,7 @@ from sys import simdwidthof
 from benchmark import Bench, BenchConfig, Bencher, BenchId, Unit, keep, run
 from bit import count_trailing_zeros
 from builtin.dtype import _uint_type_of_width
-from memory import memcmp, bitcast, UnsafePointer
+from memory import memcmp, bitcast, UnsafePointer, pack_bits
 from utils.stringref import _align_down, _memchr, _memmem
 
 
@@ -168,7 +168,7 @@ fn _memmem_baseline[
     )
     for i in range(0, vectorized_end, bool_mask_width):
         var bool_mask = haystack.load[width=bool_mask_width](i) == first_needle
-        var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
+        var mask = pack_bits(bool_mask)
         while mask:
             var offset = int(i + count_trailing_zeros(mask))
             if memcmp(haystack + offset + 1, needle + 1, needle_len - 1) == 0:
diff --git a/stdlib/src/memory/__init__.mojo b/stdlib/src/memory/__init__.mojo
index cc226348fd..1074e5983a 100644
--- a/stdlib/src/memory/__init__.mojo
+++ b/stdlib/src/memory/__init__.mojo
@@ -16,5 +16,5 @@ from .arc import Arc
 from .box import Box
 from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation
 from .pointer import AddressSpace, Pointer
-from .unsafe import bitcast
+from .unsafe import bitcast, pack_bits
 from .unsafe_pointer import UnsafePointer
diff --git a/stdlib/src/memory/unsafe.mojo b/stdlib/src/memory/unsafe.mojo
index 93d7e2266b..82e995de53 100644
--- a/stdlib/src/memory/unsafe.mojo
+++ b/stdlib/src/memory/unsafe.mojo
@@ -28,18 +28,21 @@ from sys import bitwidthof
 
 @always_inline("nodebug")
 fn bitcast[
-    new_type: DType, new_width: Int, src_type: DType, src_width: Int
-](val: SIMD[src_type, src_width]) -> SIMD[new_type, new_width]:
+    type: DType,
+    width: Int, //,
+    new_type: DType,
+    new_width: Int = width,
+](val: SIMD[type, width]) -> SIMD[new_type, new_width]:
     """Bitcasts a SIMD value to another SIMD value.
 
     Constraints:
         The bitwidth of the two types must be the same.
 
     Parameters:
+        type: The source type.
+        width: The source width.
         new_type: The target type.
         new_width: The target width.
-        src_type: The source type.
-        src_width: The source width.
 
     Args:
         val: The source value.
@@ -49,13 +52,13 @@ fn bitcast[
         source SIMD value.
     """
     constrained[
-        bitwidthof[SIMD[src_type, src_width]]()
+        bitwidthof[SIMD[type, width]]()
         == bitwidthof[SIMD[new_type, new_width]](),
         "the source and destination types must have the same bitwidth",
     ]()
 
     @parameter
-    if new_type == src_type:
+    if new_type == type:
         return rebind[SIMD[new_type, new_width]](val)
     return __mlir_op.`pop.bitcast`[
         _type = __mlir_type[
@@ -65,45 +68,31 @@ fn bitcast[
 
 
 @always_inline("nodebug")
-fn bitcast[
-    new_type: DType, src_type: DType
-](val: SIMD[src_type, 1]) -> SIMD[new_type, 1]:
-    """Bitcasts a SIMD value to another SIMD value.
-
-    Constraints:
-        The bitwidth of the two types must be the same.
-
-    Parameters:
-        new_type: The target type.
-        src_type: The source type.
-
-    Args:
-        val: The source value.
-
-    Returns:
-        A new SIMD value with the specified type and width with a bitcopy of the
-        source SIMD value.
-    """
-    constrained[
-        bitwidthof[SIMD[src_type, 1]]() == bitwidthof[SIMD[new_type, 1]](),
-        "the source and destination types must have the same bitwidth",
-    ]()
-
-    return bitcast[new_type, 1, src_type, 1](val)
+fn _uint(n: Int) -> DType:
+    if n == 8:
+        return DType.uint8
+    elif n == 16:
+        return DType.uint16
+    elif n == 32:
+        return DType.uint32
+    else:
+        return DType.uint64
 
 
 @always_inline("nodebug")
-fn bitcast[
-    new_type: DType, src_width: Int
-](val: SIMD[DType.bool, src_width]) -> Scalar[new_type]:
+fn pack_bits[
+    width: Int, //,
+    new_type: DType = _uint(width),
+](val: SIMD[DType.bool, width]) -> Scalar[new_type]:
     """Packs a SIMD bool into an integer.
 
     Constraints:
-        The bitwidth of the two types must be the same.
+        The width of the bool vector must be the same as the bitwidth of the
+        target type.
 
     Parameters:
+        width: The source width.
         new_type: The target type.
-        src_width: The source width.
 
     Args:
         val: The source value.
@@ -112,8 +101,11 @@ fn bitcast[
         A new integer scalar which has the same bitwidth as the bool vector.
     """
     constrained[
-        src_width == bitwidthof[Scalar[new_type]](),
-        "the source and destination types must have the same bitwidth",
+        width == bitwidthof[Scalar[new_type]](),
+        (
+            "the width of the bool vector must be the same as the bitwidth of"
+            " the target type"
+        ),
     ]()
 
     return __mlir_op.`pop.bitcast`[
diff --git a/stdlib/src/utils/stringref.mojo b/stdlib/src/utils/stringref.mojo
index 87732fc351..89c0bb2711 100644
--- a/stdlib/src/utils/stringref.mojo
+++ b/stdlib/src/utils/stringref.mojo
@@ -17,7 +17,7 @@ from bit import count_trailing_zeros
 from builtin.dtype import _uint_type_of_width
 from collections.string import _atol, _isspace
 from hashlib._hasher import _HashableWithHasher, _Hasher
-from memory import UnsafePointer, memcmp, bitcast
+from memory import UnsafePointer, memcmp, pack_bits
 from memory.memory import _memcmp_impl_unconstrained
 from utils import StringSlice
 from sys.ffi import c_char
@@ -698,7 +698,7 @@ fn _memchr[
 
     for i in range(0, vectorized_end, bool_mask_width):
         var bool_mask = source.load[width=bool_mask_width](i) == first_needle
-        var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
+        var mask = pack_bits(bool_mask)
         if mask:
             return source + int(i + count_trailing_zeros(mask))
 
@@ -742,7 +742,7 @@ fn _memmem[
             var eq_last = last_needle == last_block
 
             var bool_mask = eq_first & eq_last
-            var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
+            var mask = pack_bits(bool_mask)
             while mask:
                 var offset = int(i + count_trailing_zeros(mask))