[stdlib] Clean up b64encode (2/N)
Signed-off-by: Yiwu Chen <[email protected]>
soraros committed Nov 5, 2024
1 parent 98067b5 commit dfbb83f
Showing 2 changed files with 28 additions and 94 deletions.
stdlib/src/base64/_b64encode.mojo: 120 changes (27 additions, 93 deletions)

@@ -26,20 +26,13 @@ https://arxiv.org/abs/1704.00605

 from builtin.simd import _sub_with_saturation
 from collections import InlineArray
-from math.math import _compile_time_iota
+from math.math import _iota
 from memory import memcpy, bitcast, UnsafePointer
 from utils import IndexList
 
 alias Bytes = SIMD[DType.uint8, _]
 
 
-fn _base64_simd_mask[
-    simd_width: Int
-](nb_value_to_load: Int) -> SIMD[DType.bool, simd_width]:
-    alias mask = _compile_time_iota[DType.uint8, simd_width]()
-    return mask < UInt8(nb_value_to_load)
-
-
 # |                |---- byte 2 ----|---- byte 1 ----|---- byte 0 ----|
 # |                |c₁c₀d₅d₄d₃d₂d₁d₀|b₃b₂b₁b₀c₅c₄c₃c₂|a₅a₄a₃a₂a₁a₀b₅b₄|
 # <----------------|----------------|----------------|----------------|
@@ -115,99 +108,20 @@ fn _to_b64_ascii[width: Int, //](input: Bytes[width]) -> Bytes[width]:
     return abcd + OFFSETS._dynamic_shuffle(offset_indices)
 
 
-fn _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load[
-    simd_width: Int
-]() -> SIMD[DType.uint8, simd_width]:
-    """This is a lookup table to know how many bytes we need to store in the output buffer
-    for a given number of bytes to encode in base64. Including the '=' sign.
-    This table lookup is smaller than the simd size, because we only use it for the last chunk.
-    This should be called at compile time, otherwise it's quite slow.
-    """
-    var result = SIMD[DType.uint8, simd_width](0)
-    for i in range(1, simd_width):
-        # We have "i" bytes to encode in base64, how many bytes do
-        # we need to store in the output buffer? Including the '=' sign.
-
-        # math.ceil cannot be called at compile time, this is a workaround
-        var group_of_3_bytes = i // 3
-        if i % 3 != 0:
-            group_of_3_bytes += 1
-
-        result[i] = group_of_3_bytes * 4
-    return result
-
-
 fn _get_number_of_bytes_to_store_from_number_of_bytes_to_load[
     max_size: Int
 ](nb_of_elements_to_load: Int) -> Int:
-    alias table = _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load[
-        max_size
-    ]()
+    alias table = _ceildiv(_iota[DType.uint8, max_size](), 3) * 4
     return int(table[nb_of_elements_to_load])
 
 
-fn _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[
-    simd_width: Int
-]() -> SIMD[DType.uint8, simd_width]:
-    """This is a lookup table to know how many bytes we need to store in the output buffer
-    for a given number of bytes to encode in base64. This is **not** including the '=' sign.
-    This table lookup is smaller than the simd size, because we only use it for the last chunk.
-    This should be called at compile time, otherwise it's quite slow.
-    """
-    var result = SIMD[DType.uint8, simd_width]()
-    for i in range(simd_width):
-        # We have "i" bytes to encode in base64, how many bytes do
-        # we need to store in the output buffer? NOT including the '=' sign.
-        # We count the number of groups of 6 bits and we add 1 byte if there is an incomplete group.
-        var number_of_bits = i * 8
-        var complete_groups_of_6_bits = number_of_bits // 6
-        var incomplete_groups_of_6_bits: Int
-        if i * 8 % 6 == 0:
-            incomplete_groups_of_6_bits = 0
-        else:
-            incomplete_groups_of_6_bits = 1
-
-        result[i] = complete_groups_of_6_bits + incomplete_groups_of_6_bits
-    return result
-
-
 fn _get_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[
     max_size: Int
 ](nb_of_elements_to_load: Int) -> Int:
-    alias table = _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[
-        max_size
-    ]()
+    alias table = _ceildiv(_iota[DType.uint8, max_size]() * 8, 6)
     return int(table[nb_of_elements_to_load])


-fn load_incomplete_simd[
-    simd_width: Int
-](pointer: UnsafePointer[UInt8], nb_of_elements_to_load: Int) -> SIMD[
-    DType.uint8, simd_width
-]:
-    var result = SIMD[DType.uint8, simd_width](0)
-    var tmp_buffer_pointer = UnsafePointer.address_of(result).bitcast[UInt8]()
-    memcpy(dest=tmp_buffer_pointer, src=pointer, count=nb_of_elements_to_load)
-    return result
-
-
-fn store_incomplete_simd[
-    simd_width: Int
-](
-    pointer: UnsafePointer[UInt8],
-    owned simd_vector: SIMD[DType.uint8, simd_width],
-    nb_of_elements_to_store: Int,
-):
-    var tmp_buffer_pointer = UnsafePointer.address_of(simd_vector).bitcast[
-        UInt8
-    ]()
-
-    memcpy(dest=pointer, src=tmp_buffer_pointer, count=nb_of_elements_to_store)
-    _ = simd_vector  # We make it live long enough
-
-
 # TODO: Use Span instead of List as input when Span is easier to use
 @no_inline
 fn b64encode_with_buffers(
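Aside (a sketch, not part of the diff): the two deleted table builders reduce to closed-form arithmetic. For i input bytes, padded base64 output occupies ceil(i / 3) * 4 bytes (every group of 3 input bytes, complete or not, becomes 4 output characters), while unpadded output occupies ceil(i * 8 / 6) bytes (one character per 6-bit group, rounding up). A minimal standalone check of both formulas using plain integer arithmetic:

```mojo
# Sketch only: mirrors the logic of the removed loops and the new
# _ceildiv expressions for the first few input sizes.
fn main():
    for i in range(1, 16):
        # ceil(i / 3) * 4, matching _ceildiv(_iota(), 3) * 4
        var with_padding = ((i + 3 - 1) // 3) * 4
        # ceil(i * 8 / 6), matching _ceildiv(_iota() * 8, 6)
        var without_padding = (i * 8 + 6 - 1) // 6
        print(i, "->", with_padding, "with '=',", without_padding, "without")
```

For example, 1 input byte gives 4 output bytes with padding ("XX==") and 2 without, exactly what the removed loop-built tables stored.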
@@ -242,9 +156,9 @@ fn b64encode_with_buffers(
         )
 
         # We don't want to read past the input buffer
-        var input_vector = load_incomplete_simd[simd_width](
+        var input_vector = load_simd[simd_width](
             start_of_input_chunk,
-            nb_of_elements_to_load=nb_of_elements_to_load,
+            nb_of_elements_to_load,
         )
 
         result_vector = _to_b64_ascii(input_vector)
@@ -255,7 +169,9 @@
         ](
             nb_of_elements_to_load
         )
-        var equal_mask = _base64_simd_mask[simd_width](non_equal_chars_number)
+        var equal_mask = _iota[
+            DType.uint8, simd_width
+        ]() < non_equal_chars_number
 
         var result_vector_with_equals = equal_mask.select(
             result_vector, equal_vector
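The inlined comparison reproduces what the removed _base64_simd_mask helper computed: comparing an iota vector [0, 1, 2, ...] against a scalar bound yields a mask that is True for exactly the first n lanes, which select() then uses to keep base64 characters and substitute '=' elsewhere. A minimal sketch of the idiom, assuming access to the private _iota helper imported at the top of this file:

```mojo
from math.math import _iota

# Sketch only: a mask that is True in lanes 0..n-1 and False elsewhere.
fn _prefix_mask[width: Int](n: Int) -> SIMD[DType.bool, width]:
    # E.g. width=8, n=3 -> [True, True, True, False, False, False, False, False]
    return _iota[DType.uint8, width]() < n
```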
@@ -266,7 +182,7 @@
         ](
             nb_of_elements_to_load
         )
-        store_incomplete_simd(
+        store_simd(
             result.unsafe_ptr() + len(result),
             result_vector_with_equals,
             nb_of_elements_to_store,
@@ -278,6 +194,10 @@
 # Utility functions
 
 
+fn _ceildiv(a: Bytes, b: __type_of(a)) -> __type_of(a):
+    return (a + b - 1) // b
+
+
 fn _repeat_until[width: Int](v: SIMD) -> SIMD[v.type, width]:
     constrained[width >= v.size, "width must be at least v.size"]()
 
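The new _ceildiv applies the classic identity ceil(a / b) = (a + b - 1) // b lane-wise: a lane holding 5 with b = 3 gives (5 + 3 - 1) // 3 = 7 // 3 = 2 = ceil(5 / 3). A usage sketch, assuming it runs inside this module where _ceildiv and the Bytes alias are visible:

```mojo
fn main():
    # Sketch only: lane-wise ceiling division on byte counts.
    var counts = SIMD[DType.uint8, 4](0, 1, 3, 4)
    print(_ceildiv(counts, 3))  # [0, 1, 1, 2]
```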
@@ -291,3 +211,17 @@ fn _rshift_bits_in_u16[shift: Int](input: Bytes) -> __type_of(input):
     var u16 = bitcast[DType.uint16, input.size // 2](input)
     var res = bit.rotate_bits_right[shift](u16)
     return bitcast[DType.uint8, input.size](res)
+
+
+fn load_simd[
+    width: Int
+](pointer: UnsafePointer[Byte], len: Int) -> Bytes[width]:
+    var result = Bytes[width]()
+    var buffer_ptr = UnsafePointer.address_of(result).bitcast[Byte]()
+    memcpy(dest=buffer_ptr, src=pointer, count=len)
+    return result
+
+
+fn store_simd(ptr: UnsafePointer[Byte], owned v: Bytes, len: Int):
+    var buffer_ptr = UnsafePointer.address_of(v).bitcast[Byte]()
+    memcpy(dest=ptr, src=buffer_ptr, count=len)
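The renamed helpers keep the masked load/store behavior of their predecessors: load_simd memcpys only len bytes into a stack-allocated vector (the remaining lanes appear to hold the default constructor's zero fill, matching the explicit (0) initializer in the removed load_incomplete_simd), and store_simd writes back only len bytes, so neither reads nor writes past the valid region. A sketch of how they pair up on a partial tail chunk, assuming in-module scope; copy_tail and both pointer names are hypothetical, not from the commit:

```mojo
fn copy_tail(src: UnsafePointer[Byte], dst: UnsafePointer[Byte]):
    # Read only 5 valid bytes into a 16-lane vector; lanes 5..15 keep
    # the constructor's zero fill.
    var v = load_simd[16](src, 5)
    # Write back only the 5 meaningful bytes; dst[5..] is never touched.
    store_simd(dst, v, 5)
```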
stdlib/src/math/math.mojo: 2 changes (1 addition, 1 deletion)

@@ -1049,7 +1049,7 @@ fn isclose[


 # TODO: Remove this when `iota` works at compile-time
-fn _compile_time_iota[type: DType, simd_width: Int]() -> SIMD[type, simd_width]:
+fn _iota[type: DType, simd_width: Int]() -> SIMD[type, simd_width]:
     constrained[
         type.is_integral(),
         "_compile_time_iota can only be used with integer types.",