diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index e6b7814f73..b1ab1fda8a 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -12,10 +12,11 @@ # ===----------------------------------------------------------------------=== # """Establishes the contract between `Writer` and `Writable` types.""" +from bit import byte_swap from collections import InlineArray from sys.info import is_gpu -from memory import UnsafePointer, memcpy, Span +from memory import UnsafePointer, memcpy, Span, bitcast from utils import StaticString @@ -389,41 +390,27 @@ fn write_buffered[ # ===-----------------------------------------------------------------------===# -@always_inline -fn _hex_digit_to_hex_char(b: Byte) -> Byte: - alias values = SIMD[DType.uint8, 16]( - Byte(ord("0")), - Byte(ord("1")), - Byte(ord("2")), - Byte(ord("3")), - Byte(ord("4")), - Byte(ord("5")), - Byte(ord("6")), - Byte(ord("7")), - Byte(ord("8")), - Byte(ord("9")), - Byte(ord("a")), - Byte(ord("b")), - Byte(ord("c")), - Byte(ord("d")), - Byte(ord("e")), - Byte(ord("f")), - ) - return values[int(b)] +# fmt: off +alias _hex_table = SIMD[DType.uint8, 16]( + ord("0"), ord("1"), ord("2"), ord("3"), ord("4"), ord("5"), ord("6"), + ord("7"), ord("8"), ord("9"), ord("a"), ord("b"), ord("c"), ord("d"), + ord("e"), ord("f"), +) +# fmt: on @always_inline -fn _hex_digits_to_hex_chars(b: SIMD[DType.uint8, _]) -> __type_of(b): - alias `0` = Byte(ord("0")) - alias `9` = Byte(ord("9")) - alias `a` = Byte(ord("a")) - alias I8 = DType.int8 - alias U8 = DType.uint8 - return ( - `0` - + b - + (((b <= 9).cast[I8]() - 1) & (`a` - `9` - 1).cast[I8]()).cast[U8]() - ) +fn _hex_digits_to_hex_chars(x: Scalar, ptr: UnsafePointer[Byte]): + alias size = x.type.sizeof() + var data: SIMD[DType.uint8, size] + + @parameter + if size == 1: + data = bitcast[DType.uint8, size](x) + else: + data = bitcast[DType.uint8, size](byte_swap(x)) + var nibbles = (data >> 4).interleave(data & 0xF) + ptr.store(_hex_table._dynamic_shuffle(nibbles)) @always_inline @@ -465,16 +452,10 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): @parameter if amnt_hex_bytes == 2: (p + 1).init_pointee_move(`x`) + _hex_digits_to_hex_chars(Scalar[DType.uint8](decimal), p + 2) elif amnt_hex_bytes == 4: (p + 1).init_pointee_move(`u`) + _hex_digits_to_hex_chars(Scalar[DType.uint16](decimal), p + 2) else: (p + 1).init_pointee_move(`U`) - - var idx = 0 - - @parameter - for i in reversed(range(amnt_hex_bytes)): - (p + 2 + idx).init_pointee_move( - _hex_digit_to_hex_char((decimal // (16**i)) % 16) - ) - idx += 1 + _hex_digits_to_hex_chars(Scalar[DType.uint32](decimal), p + 2) diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index 790ad6382b..3703599ba5 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -20,7 +20,6 @@ from utils.write import ( Writable, Writer, _write_hex, - _hex_digit_to_hex_char, _hex_digits_to_hex_chars, ) from utils.inline_string import _FixedString @@ -100,30 +99,6 @@ def test_write_int_padded(): def test_write_hex(): - values = List[Byte]( - ord("0"), - ord("1"), - ord("2"), - ord("3"), - ord("4"), - ord("5"), - ord("6"), - ord("7"), - ord("8"), - ord("9"), - ord("a"), - ord("b"), - ord("c"), - ord("d"), - ord("e"), - ord("f"), - ) - idx = 0 - for value in values: - assert_equal(_hex_digit_to_hex_char(idx), value[]) - assert_equal(_hex_digits_to_hex_chars(Byte(idx)), value[]) - idx += 1 - items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] ptr = items.unsafe_ptr()