diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 6d68951b9b..4aeada7b87 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -12,10 +12,11 @@ # ===----------------------------------------------------------------------=== # """Establishes the contract between `Writer` and `Writable` types.""" +from bit import byte_swap from collections import InlineArray from sys.info import is_gpu -from memory import UnsafePointer, memcpy, Span +from memory import UnsafePointer, memcpy, Span, bitcast from utils import StaticString @@ -382,3 +383,104 @@ fn write_buffered[ var buffer = _WriteBufferStack[buffer_size](writer^) write_args(buffer, args, sep=sep, end=end) buffer.flush() + + +# ===-----------------------------------------------------------------------===# +# Utils +# ===-----------------------------------------------------------------------===# + + +# fmt: off +alias _hex_table = SIMD[DType.uint8, 16]( + ord("0"), ord("1"), ord("2"), ord("3"), ord("4"), ord("5"), ord("6"), + ord("7"), ord("8"), ord("9"), ord("a"), ord("b"), ord("c"), ord("d"), + ord("e"), ord("f"), +) +# fmt: on + + +@always_inline +fn _hex_digits_to_hex_chars(ptr: UnsafePointer[Byte], decimal: Scalar): + """Write a fixed width hexadecimal value into an uninitialized pointer + location, assumed to be large enough for the value to be written. + + Examples: + + ```mojo + %# from memory import memset_zero + %# from testing import assert_equal + %# from utils import StringSlice + %# from utils.write import _hex_digits_to_hex_chars + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) + assert_equal("0001f525", S(ptr=ptr, length=8)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) + assert_equal("4f60", S(ptr=ptr, length=4)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) + assert_equal("d6", S(ptr=ptr, length=2)) + ``` + . + """ + + alias size = decimal.type.sizeof() + var data: SIMD[DType.uint8, size] + + @parameter + if size == 1: + data = bitcast[DType.uint8, size](decimal) + else: + data = bitcast[DType.uint8, size](byte_swap(decimal)) + var nibbles = (data >> 4).interleave(data & 0xF) + ptr.store(_hex_table._dynamic_shuffle(nibbles)) + + +@always_inline +fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): + """Write a python compliant hexadecimal value into an uninitialized pointer + location, assumed to be large enough for the value to be written. + + Examples: + + ```mojo + %# from memory import memset_zero + %# from testing import assert_equal + %# from utils import StringSlice + %# from utils.write import _write_hex + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _write_hex[8](ptr, ord("🔥")) + assert_equal(r"\\U0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _write_hex[4](ptr, ord("你")) + assert_equal(r"\\u4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _write_hex[2](ptr, ord("Ö")) + assert_equal(r"\\xd6", S(ptr=ptr, length=4)) + ``` + . + """ + + constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() + + alias `\\` = Byte(ord("\\")) + alias `x` = Byte(ord("x")) + alias `u` = Byte(ord("u")) + alias `U` = Byte(ord("U")) + + p.init_pointee_move(`\\`) + + @parameter + if amnt_hex_bytes == 2: + (p + 1).init_pointee_move(`x`) + _hex_digits_to_hex_chars(p + 2, UInt8(decimal)) + elif amnt_hex_bytes == 4: + (p + 1).init_pointee_move(`u`) + _hex_digits_to_hex_chars(p + 2, UInt16(decimal)) + else: + (p + 1).init_pointee_move(`U`) + _hex_digits_to_hex_chars(p + 2, UInt32(decimal)) diff --git a/stdlib/test/utils/test_format.mojo b/stdlib/test/utils/test_write.mojo similarity index 52% rename from stdlib/test/utils/test_format.mojo rename to stdlib/test/utils/test_write.mojo index 975d26464b..b4cfab20c0 100644 --- a/stdlib/test/utils/test_format.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -14,20 +14,17 @@ from testing import assert_equal -from utils import Writable, Writer +from memory.memory import memset_zero +from utils import StringSlice +from utils.write import ( + Writable, + Writer, + _write_hex, + _hex_digits_to_hex_chars, +) from utils.inline_string import _FixedString -fn main() raises: - test_writer_of_string() - test_string_format_seq() - test_stringable_based_on_format() - - test_writer_of_fixed_string() - - test_write_int_padded() - - @value struct Point(Writable, Stringable): var x: Int @@ -42,7 +39,7 @@ struct Point(Writable, Stringable): return String.write(self) -fn test_writer_of_string() raises: +def test_writer_of_string(): # # Test write_to(String) # @@ -58,7 +55,7 @@ fn test_writer_of_string() raises: assert_equal(s2, "Point(3, 8)") -fn test_string_format_seq() raises: +def test_string_write_seq(): var s1 = String.write("Hello, ", "World!") assert_equal(s1, "Hello, World!") @@ -69,17 +66,17 @@ fn test_string_format_seq() raises: assert_equal(s3, "") -fn test_stringable_based_on_format() raises: +def test_stringable_based_on_format(): assert_equal(str(Point(10, 11)), "Point(10, 11)") -fn test_writer_of_fixed_string() raises: +def test_writer_of_fixed_string(): var s1 = _FixedString[100]() s1.write("Hello, World!") assert_equal(str(s1), "Hello, World!") -fn test_write_int_padded() raises: +def test_write_int_padded(): var s1 = String() Int(5).write_padded(s1, width=5) @@ -99,3 +96,56 @@ fn test_write_int_padded() raises: Int(12345).write_padded(s2, width=3) assert_equal(s2, "12345") + + +def test_hex_digits_to_hex_chars(): + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) + assert_equal("0001f525", S(ptr=ptr, length=8)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) + assert_equal("4f60", S(ptr=ptr, length=4)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) + assert_equal("d6", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, UInt8(0)) + assert_equal("00", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, UInt16(0)) + assert_equal("0000", S(ptr=ptr, length=4)) + _hex_digits_to_hex_chars(ptr, UInt32(0)) + assert_equal("00000000", S(ptr=ptr, length=8)) + _hex_digits_to_hex_chars(ptr, ~UInt8(0)) + assert_equal("ff", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, ~UInt16(0)) + assert_equal("ffff", S(ptr=ptr, length=4)) + _hex_digits_to_hex_chars(ptr, ~UInt32(0)) + assert_equal("ffffffff", S(ptr=ptr, length=8)) + + +def test_write_hex(): + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _write_hex[8](ptr, ord("🔥")) + assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _write_hex[4](ptr, ord("你")) + assert_equal(r"\u4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _write_hex[2](ptr, ord("Ö")) + assert_equal(r"\xd6", S(ptr=ptr, length=4)) + + +def main(): + test_writer_of_string() + test_string_write_seq() + test_stringable_based_on_format() + + test_writer_of_fixed_string() + + test_write_int_padded() + + test_hex_digits_to_hex_chars() + test_write_hex()