Skip to content

Commit

Permalink
[External] [stdlib] Clean up memory.unsafe (#48797)
Browse files Browse the repository at this point in the history
[External] [stdlib] Clean up `memory.unsafe`

- Make more things infer-only
- Remove unnecessary overload
- Rename the `bitcast` overload that performs the 'movemask' operation
to `pack_mask`. This change is intended to prevent issues where the
wrong overload might be selected, and since there is an implicit
conversion from scalar to simd at return, the user won't get a type
mismatch error to warn them about that.

Co-authored-by: soraros <[email protected]>
Closes #3588
MODULAR_ORIG_COMMIT_REV_ID: 18e72968cbdadf593e6bc9bba0794e326bada362
  • Loading branch information
soraros authored and modularbot committed Nov 2, 2024
1 parent b58d54c commit bb4a60c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 44 deletions.
4 changes: 2 additions & 2 deletions stdlib/benchmarks/utils/bench_memmem.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ from sys import simdwidthof
from benchmark import Bench, BenchConfig, Bencher, BenchId, Unit, keep, run
from bit import count_trailing_zeros
from builtin.dtype import _uint_type_of_width
from memory import memcmp, bitcast, UnsafePointer
from memory import memcmp, bitcast, UnsafePointer, pack_bits

from utils.stringref import _align_down, _memchr, _memmem

Expand Down Expand Up @@ -168,7 +168,7 @@ fn _memmem_baseline[
)
for i in range(0, vectorized_end, bool_mask_width):
var bool_mask = haystack.load[width=bool_mask_width](i) == first_needle
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_bits(bool_mask)
while mask:
var offset = int(i + count_trailing_zeros(mask))
if memcmp(haystack + offset + 1, needle + 1, needle_len - 1) == 0:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/memory/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ from .arc import Arc
from .box import Box
from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation
from .pointer import AddressSpace, Pointer
from .unsafe import bitcast
from .unsafe import bitcast, pack_bits
from .unsafe_pointer import UnsafePointer
68 changes: 30 additions & 38 deletions stdlib/src/memory/unsafe.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ from sys import bitwidthof

@always_inline("nodebug")
fn bitcast[
new_type: DType, new_width: Int, src_type: DType, src_width: Int
](val: SIMD[src_type, src_width]) -> SIMD[new_type, new_width]:
type: DType,
width: Int, //,
new_type: DType,
new_width: Int = width,
](val: SIMD[type, width]) -> SIMD[new_type, new_width]:
"""Bitcasts a SIMD value to another SIMD value.
Constraints:
The bitwidth of the two types must be the same.
Parameters:
type: The source type.
width: The source width.
new_type: The target type.
new_width: The target width.
src_type: The source type.
src_width: The source width.
Args:
val: The source value.
Expand All @@ -49,13 +52,13 @@ fn bitcast[
source SIMD value.
"""
constrained[
bitwidthof[SIMD[src_type, src_width]]()
bitwidthof[SIMD[type, width]]()
== bitwidthof[SIMD[new_type, new_width]](),
"the source and destination types must have the same bitwidth",
]()

@parameter
if new_type == src_type:
if new_type == type:
return rebind[SIMD[new_type, new_width]](val)
return __mlir_op.`pop.bitcast`[
_type = __mlir_type[
Expand All @@ -65,45 +68,31 @@ fn bitcast[


@always_inline("nodebug")
fn bitcast[
new_type: DType, src_type: DType
](val: SIMD[src_type, 1]) -> SIMD[new_type, 1]:
"""Bitcasts a SIMD value to another SIMD value.
Constraints:
The bitwidth of the two types must be the same.
Parameters:
new_type: The target type.
src_type: The source type.
Args:
val: The source value.
Returns:
A new SIMD value with the specified type and width with a bitcopy of the
source SIMD value.
"""
constrained[
bitwidthof[SIMD[src_type, 1]]() == bitwidthof[SIMD[new_type, 1]](),
"the source and destination types must have the same bitwidth",
]()

return bitcast[new_type, 1, src_type, 1](val)
fn _uint(n: Int) -> DType:
if n == 8:
return DType.uint8
elif n == 16:
return DType.uint16
elif n == 32:
return DType.uint32
else:
return DType.uint64


@always_inline("nodebug")
fn bitcast[
new_type: DType, src_width: Int
](val: SIMD[DType.bool, src_width]) -> Scalar[new_type]:
fn pack_bits[
width: Int, //,
new_type: DType = _uint(width),
](val: SIMD[DType.bool, width]) -> Scalar[new_type]:
"""Packs a SIMD bool into an integer.
Constraints:
The bitwidth of the two types must be the same.
The width of the bool vector must be the same as the bitwidth of the
target type.
Parameters:
width: The source width.
new_type: The target type.
src_width: The source width.
Args:
val: The source value.
Expand All @@ -112,8 +101,11 @@ fn bitcast[
A new integer scalar which has the same bitwidth as the bool vector.
"""
constrained[
src_width == bitwidthof[Scalar[new_type]](),
"the source and destination types must have the same bitwidth",
width == bitwidthof[Scalar[new_type]](),
(
"the width of the bool vector must be the same as the bitwidth of"
" the target type"
),
]()

return __mlir_op.`pop.bitcast`[
Expand Down
6 changes: 3 additions & 3 deletions stdlib/src/utils/stringref.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from bit import count_trailing_zeros
from builtin.dtype import _uint_type_of_width
from collections.string import _atol, _isspace
from hashlib._hasher import _HashableWithHasher, _Hasher
from memory import UnsafePointer, memcmp, bitcast
from memory import UnsafePointer, memcmp, pack_bits
from memory.memory import _memcmp_impl_unconstrained
from utils import StringSlice
from sys.ffi import c_char
Expand Down Expand Up @@ -698,7 +698,7 @@ fn _memchr[

for i in range(0, vectorized_end, bool_mask_width):
var bool_mask = source.load[width=bool_mask_width](i) == first_needle
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_bits(bool_mask)
if mask:
return source + int(i + count_trailing_zeros(mask))

Expand Down Expand Up @@ -742,7 +742,7 @@ fn _memmem[
var eq_last = last_needle == last_block

var bool_mask = eq_first & eq_last
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_bits(bool_mask)

while mask:
var offset = int(i + count_trailing_zeros(mask))
Expand Down

0 comments on commit bb4a60c

Please sign in to comment.