Skip to content

Commit

Permalink
add the implementation from #3694 by @soraros
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Dec 19, 2024
1 parent 6b7c940 commit ac5c11b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 67 deletions.
65 changes: 23 additions & 42 deletions stdlib/src/utils/write.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# ===----------------------------------------------------------------------=== #
"""Establishes the contract between `Writer` and `Writable` types."""

from bit import byte_swap
from collections import InlineArray
from sys.info import is_gpu

from memory import UnsafePointer, memcpy, Span
from memory import UnsafePointer, memcpy, Span, bitcast

from utils import StaticString

Expand Down Expand Up @@ -389,41 +390,27 @@ fn write_buffered[
# ===-----------------------------------------------------------------------===#


@always_inline
fn _hex_digit_to_hex_char(b: Byte) -> Byte:
alias values = SIMD[DType.uint8, 16](
Byte(ord("0")),
Byte(ord("1")),
Byte(ord("2")),
Byte(ord("3")),
Byte(ord("4")),
Byte(ord("5")),
Byte(ord("6")),
Byte(ord("7")),
Byte(ord("8")),
Byte(ord("9")),
Byte(ord("a")),
Byte(ord("b")),
Byte(ord("c")),
Byte(ord("d")),
Byte(ord("e")),
Byte(ord("f")),
)
return values[int(b)]
# fmt: off
alias _hex_table = SIMD[DType.uint8, 16](
ord("0"), ord("1"), ord("2"), ord("3"), ord("4"), ord("5"), ord("6"),
ord("7"), ord("8"), ord("9"), ord("a"), ord("b"), ord("c"), ord("d"),
ord("e"), ord("f"),
)
# fmt: on


@always_inline
fn _hex_digits_to_hex_chars(b: SIMD[DType.uint8, _]) -> __type_of(b):
alias `0` = Byte(ord("0"))
alias `9` = Byte(ord("9"))
alias `a` = Byte(ord("a"))
alias I8 = DType.int8
alias U8 = DType.uint8
return (
`0`
+ b
+ (((b <= 9).cast[I8]() - 1) & (`a` - `9` - 1).cast[I8]()).cast[U8]()
)
fn _hex_digits_to_hex_chars(x: Scalar, ptr: UnsafePointer[Byte]):
alias size = x.type.sizeof()
var data: SIMD[DType.uint8, size]

@parameter
if size == 1:
data = bitcast[DType.uint8, size](x)
else:
data = bitcast[DType.uint8, size](byte_swap(x))
var nibbles = (data >> 4).interleave(data & 0xF)
ptr.store(_hex_table._dynamic_shuffle(nibbles))


@always_inline
Expand Down Expand Up @@ -465,16 +452,10 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int):
@parameter
if amnt_hex_bytes == 2:
(p + 1).init_pointee_move(`x`)
_hex_digits_to_hex_chars(Scalar[DType.uint8](decimal), p + 2)
elif amnt_hex_bytes == 4:
(p + 1).init_pointee_move(`u`)
_hex_digits_to_hex_chars(Scalar[DType.uint16](decimal), p + 2)
else:
(p + 1).init_pointee_move(`U`)

var idx = 0

@parameter
for i in reversed(range(amnt_hex_bytes)):
(p + 2 + idx).init_pointee_move(
_hex_digit_to_hex_char((decimal // (16**i)) % 16)
)
idx += 1
_hex_digits_to_hex_chars(Scalar[DType.uint32](decimal), p + 2)
25 changes: 0 additions & 25 deletions stdlib/test/utils/test_write.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ from utils.write import (
Writable,
Writer,
_write_hex,
_hex_digit_to_hex_char,
_hex_digits_to_hex_chars,
)
from utils.inline_string import _FixedString
Expand Down Expand Up @@ -100,30 +99,6 @@ def test_write_int_padded():


def test_write_hex():
values = List[Byte](
ord("0"),
ord("1"),
ord("2"),
ord("3"),
ord("4"),
ord("5"),
ord("6"),
ord("7"),
ord("8"),
ord("9"),
ord("a"),
ord("b"),
ord("c"),
ord("d"),
ord("e"),
ord("f"),
)
idx = 0
for value in values:
assert_equal(_hex_digit_to_hex_char(idx), value[])
assert_equal(_hex_digits_to_hex_chars(Byte(idx)), value[])
idx += 1

items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0)
alias S = StringSlice[__origin_of(items)]
ptr = items.unsafe_ptr()
Expand Down

0 comments on commit ac5c11b

Please sign in to comment.