diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index b1ab1fda8a..0d41dd7478 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -400,15 +400,40 @@ alias _hex_table = SIMD[DType.uint8, 16]( @always_inline -fn _hex_digits_to_hex_chars(x: Scalar, ptr: UnsafePointer[Byte]): - alias size = x.type.sizeof() +fn _hex_digits_to_hex_chars(ptr: UnsafePointer[Byte], decimal: Scalar): + """Write a fixed width hexadecimal value into an uninitialized pointer + location, assumed to be large enough for the value to be written. + + Examples: + + ```mojo + %# from memory import memset_zero + %# from testing import assert_equal + %# from utils import StringSlice + %# from utils.write import _write_hex + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) + assert_equal("0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) + assert_equal("4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) + assert_equal("xd6", S(ptr=ptr, length=4)) + ``` + . + """ + + alias size = decimal.type.sizeof() var data: SIMD[DType.uint8, size] @parameter if size == 1: - data = bitcast[DType.uint8, size](x) + data = bitcast[DType.uint8, size](decimal) else: - data = bitcast[DType.uint8, size](byte_swap(x)) + data = bitcast[DType.uint8, size](byte_swap(decimal)) var nibbles = (data >> 4).interleave(data & 0xF) ptr.store(_hex_table._dynamic_shuffle(nibbles)) @@ -452,10 +477,10 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): @parameter if amnt_hex_bytes == 2: (p + 1).init_pointee_move(`x`) - _hex_digits_to_hex_chars(Scalar[DType.uint8](decimal), p + 2) + _hex_digits_to_hex_chars(p + 2, UInt8(decimal)) elif amnt_hex_bytes == 4: (p + 1).init_pointee_move(`u`) - _hex_digits_to_hex_chars(Scalar[DType.uint16](decimal), p + 2) + _hex_digits_to_hex_chars(p + 2, UInt16(decimal)) else: (p + 1).init_pointee_move(`U`) - _hex_digits_to_hex_chars(Scalar[DType.uint32](decimal), p + 2) + _hex_digits_to_hex_chars(p + 2, UInt32(decimal)) diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index 3703599ba5..b4cfab20c0 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -98,6 +98,32 @@ def test_write_int_padded(): assert_equal(s2, "12345") +def test_hex_digits_to_hex_chars(): + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) + assert_equal("0001f525", S(ptr=ptr, length=8)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) + assert_equal("4f60", S(ptr=ptr, length=4)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) + assert_equal("d6", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, UInt8(0)) + assert_equal("00", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, UInt16(0)) + assert_equal("0000", S(ptr=ptr, length=4)) + _hex_digits_to_hex_chars(ptr, UInt32(0)) + assert_equal("00000000", S(ptr=ptr, length=8)) + _hex_digits_to_hex_chars(ptr, ~UInt8(0)) + assert_equal("ff", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, ~UInt16(0)) + assert_equal("ffff", S(ptr=ptr, length=4)) + _hex_digits_to_hex_chars(ptr, ~UInt32(0)) + assert_equal("ffffffff", S(ptr=ptr, length=8)) + + def test_write_hex(): items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] @@ -121,4 +147,5 @@ def main(): test_write_int_padded() + test_hex_digits_to_hex_chars() test_write_hex()