Commit 7edd561

[External] [stdlib] Clean up `b64encode` (2/N) (#50831)

Co-authored-by: soraros <[email protected]>
Closes #3746
MODULAR_ORIG_COMMIT_REV_ID: e5bf916a6cc953c18bcb23b482de4a86778d5f52
soraros authored and modularbot committed Nov 15, 2024
1 parent 02aba74 commit 7edd561
Showing 2 changed files with 29 additions and 95 deletions.
120 changes: 27 additions & 93 deletions stdlib/src/base64/_b64encode.mojo
@@ -26,20 +26,13 @@ https://arxiv.org/abs/1704.00605

 from builtin.simd import _sub_with_saturation
 from collections import InlineArray
-from math.math import _compile_time_iota
+from math.math import _iota
 from memory import memcpy, bitcast, UnsafePointer
 from utils import IndexList
 
 alias Bytes = SIMD[DType.uint8, _]
 
 
-fn _base64_simd_mask[
-    simd_width: Int
-](nb_value_to_load: Int) -> SIMD[DType.bool, simd_width]:
-    alias mask = _compile_time_iota[DType.uint8, simd_width]()
-    return mask < UInt8(nb_value_to_load)
-
-
 # |                |---- byte 2 ----|---- byte 1 ----|---- byte 0 ----|
 # |                |c₁c₀d₅d₄d₃d₂d₁d₀|b₃b₂b₁b₀c₅c₄c₃c₂|a₅a₄a₃a₂a₁a₀b₅b₄|
 # <----------------|----------------|----------------|----------------|
@@ -115,99 +108,20 @@ fn _to_b64_ascii[width: Int, //](input: Bytes[width]) -> Bytes[width]:
     return abcd + OFFSETS._dynamic_shuffle(offset_indices)
 
 
-fn _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load[
-    simd_width: Int
-]() -> SIMD[DType.uint8, simd_width]:
-    """This is a lookup table to know how many bytes we need to store in the output buffer
-    for a given number of bytes to encode in base64. Including the '=' sign.
-    This table lookup is smaller than the simd size, because we only use it for the last chunk.
-    This should be called at compile time, otherwise it's quite slow.
-    """
-    var result = SIMD[DType.uint8, simd_width](0)
-    for i in range(1, simd_width):
-        # We have "i" bytes to encode in base64, how many bytes do
-        # we need to store in the output buffer? Including the '=' sign.
-
-        # math.ceil cannot be called at compile time, this is a workaround
-        var group_of_3_bytes = i // 3
-        if i % 3 != 0:
-            group_of_3_bytes += 1
-
-        result[i] = group_of_3_bytes * 4
-    return result
-
-
 fn _get_number_of_bytes_to_store_from_number_of_bytes_to_load[
     max_size: Int
 ](nb_of_elements_to_load: Int) -> Int:
-    alias table = _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load[
-        max_size
-    ]()
+    alias table = _ceildiv_u8(_iota[DType.uint8, max_size](), 3) * 4
     return int(table[nb_of_elements_to_load])
 
 
-fn _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[
-    simd_width: Int
-]() -> SIMD[DType.uint8, simd_width]:
-    """This is a lookup table to know how many bytes we need to store in the output buffer
-    for a given number of bytes to encode in base64. This is **not** including the '=' sign.
-    This table lookup is smaller than the simd size, because we only use it for the last chunk.
-    This should be called at compile time, otherwise it's quite slow.
-    """
-    var result = SIMD[DType.uint8, simd_width]()
-    for i in range(simd_width):
-        # We have "i" bytes to encode in base64, how many bytes do
-        # we need to store in the output buffer? NOT including the '=' sign.
-        # We count the number of groups of 6 bits and we add 1 byte if there is an incomplete group.
-        var number_of_bits = i * 8
-        var complete_groups_of_6_bits = number_of_bits // 6
-        var incomplete_groups_of_6_bits: Int
-        if i * 8 % 6 == 0:
-            incomplete_groups_of_6_bits = 0
-        else:
-            incomplete_groups_of_6_bits = 1
-
-        result[i] = complete_groups_of_6_bits + incomplete_groups_of_6_bits
-    return result
-
-
 fn _get_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[
     max_size: Int
 ](nb_of_elements_to_load: Int) -> Int:
-    alias table = _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[
-        max_size
-    ]()
+    alias table = _ceildiv_u8(_iota[DType.uint8, max_size]() * 8, 6)
     return int(table[nb_of_elements_to_load])
 
 
-fn load_incomplete_simd[
-    simd_width: Int
-](pointer: UnsafePointer[UInt8], nb_of_elements_to_load: Int) -> SIMD[
-    DType.uint8, simd_width
-]:
-    var result = SIMD[DType.uint8, simd_width](0)
-    var tmp_buffer_pointer = UnsafePointer.address_of(result).bitcast[UInt8]()
-    memcpy(dest=tmp_buffer_pointer, src=pointer, count=nb_of_elements_to_load)
-    return result
-
-
-fn store_incomplete_simd[
-    simd_width: Int
-](
-    pointer: UnsafePointer[UInt8],
-    owned simd_vector: SIMD[DType.uint8, simd_width],
-    nb_of_elements_to_store: Int,
-):
-    var tmp_buffer_pointer = UnsafePointer.address_of(simd_vector).bitcast[
-        UInt8
-    ]()
-
-    memcpy(dest=pointer, src=tmp_buffer_pointer, count=nb_of_elements_to_store)
-    _ = simd_vector  # We make it live long enough
-
-
 # TODO: Use Span instead of List as input when Span is easier to use
 @no_inline
 fn b64encode_with_buffers(
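For context: the two `alias table` rewrites in the hunk above replace the deleted loop-built lookup tables with closed forms. For a final chunk of n input bytes, the encoder stores ceildiv(n, 3) * 4 bytes including '=' padding, and ceildiv(n * 8, 6) bytes without it. A minimal sketch of that arithmetic, assuming a Mojo toolchain contemporary with this commit; the `main` function and the scalar formulas are illustrative, not part of the change:

```mojo
fn main():
    for n in range(1, 7):
        # ceildiv(n, 3) * 4: output length including '=' padding.
        var padded = ((n + 2) // 3) * 4
        # ceildiv(n * 8, 6): output length without '=' padding.
        var unpadded = (n * 8 + 5) // 6
        print(n, "bytes in ->", padded, "padded /", unpadded, "unpadded")
```

For n = 1..6 this prints 4/2, 4/3, 4/4, 8/6, 8/7, 8/8, matching the values the deleted loops produced.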
@@ -242,9 +156,9 @@ fn b64encode_with_buffers(
         )
 
         # We don't want to read past the input buffer
-        var input_vector = load_incomplete_simd[simd_width](
+        var input_vector = load_simd[simd_width](
             start_of_input_chunk,
-            nb_of_elements_to_load=nb_of_elements_to_load,
+            nb_of_elements_to_load,
         )
 
         result_vector = _to_b64_ascii(input_vector)
@@ -255,7 +169,9 @@
         ](
             nb_of_elements_to_load
         )
-        var equal_mask = _base64_simd_mask[simd_width](non_equal_chars_number)
+        var equal_mask = _iota[
+            DType.uint8, simd_width
+        ]() < non_equal_chars_number
 
         var result_vector_with_equals = equal_mask.select(
             result_vector, equal_vector
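For context: `_base64_simd_mask` disappears here because comparing an iota vector against a scalar count yields the same prefix mask directly; lanes whose index is below `non_equal_chars_number` keep their encoded character and the rest become '='. A small sketch of the idiom with a fixed width of 8 (the values are illustrative, not from the commit):

```mojo
fn main():
    var iota = SIMD[DType.uint8, 8](0, 1, 2, 3, 4, 5, 6, 7)
    var chars = SIMD[DType.uint8, 8](ord("A"))  # stand-in for encoded bytes
    var equals = SIMD[DType.uint8, 8](ord("="))
    var mask = iota < 5  # [True x 5, False x 3]
    print(mask.select(chars, equals))  # A A A A A = = =
```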
@@ -266,7 +182,7 @@
         ](
             nb_of_elements_to_load
         )
-        store_incomplete_simd(
+        store_simd(
             result.unsafe_ptr() + len(result),
             result_vector_with_equals,
             nb_of_elements_to_store,
@@ -278,6 +194,10 @@
 # Utility functions
 
 
+fn _ceildiv_u8(a: Bytes, b: __type_of(a)) -> __type_of(a):
+    return (a + b - 1) / b
+
+
 fn _repeat_until[width: Int](v: SIMD) -> SIMD[v.type, width]:
     constrained[width >= v.size, "width must be at least v.size"]()
 
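On `_ceildiv_u8` above: it is the standard branch-free ceiling division. It relies on `/` performing elementwise integer division for integral SIMD types (which the declared uint8 return type implies) and on `a + b - 1` not wrapping past 255; both hold here because the inputs are lane indices bounded by the SIMD width. A sketch with illustrative values, under those same assumptions:

```mojo
fn main():
    var a = SIMD[DType.uint8, 4](0, 1, 3, 4)
    var b = SIMD[DType.uint8, 4](3)
    # Lane-wise ceildiv: ceil(0/3), ceil(1/3), ceil(3/3), ceil(4/3)
    print((a + b - 1) / b)  # [0, 1, 1, 2]
```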
@@ -291,3 +211,17 @@ fn _rshift_bits_in_u16[shift: Int](input: Bytes) -> __type_of(input):
     var u16 = bitcast[DType.uint16, input.size // 2](input)
     var res = bit.rotate_bits_right[shift](u16)
     return bitcast[DType.uint8, input.size](res)
+
+
+fn load_simd[
+    width: Int
+](pointer: UnsafePointer[Byte], len: Int) -> Bytes[width]:
+    var result = Bytes[width]()
+    var buffer_ptr = UnsafePointer.address_of(result).bitcast[Byte]()
+    memcpy(dest=buffer_ptr, src=pointer, count=len)
+    return result
+
+
+fn store_simd(ptr: UnsafePointer[Byte], owned v: Bytes, len: Int):
+    var buffer_ptr = UnsafePointer.address_of(v).bitcast[Byte]()
+    memcpy(dest=ptr, src=buffer_ptr, count=len)
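The renamed `load_simd` and `store_simd` keep the earlier partial-copy trick: `memcpy` through a stack temporary touches exactly `len` bytes, so the last chunk never reads or writes past the buffer, and the unused tail lanes of a load come back zeroed. A standalone sketch of the load side, assuming a contemporary Mojo toolchain; the `partial_load` name and the buffer setup are illustrative, and the temporary is zero-initialized explicitly:

```mojo
from memory import memcpy, UnsafePointer

fn partial_load[width: Int](ptr: UnsafePointer[UInt8], len: Int) -> SIMD[DType.uint8, width]:
    var result = SIMD[DType.uint8, width](0)  # zeroed, so tail lanes stay 0
    var buffer_ptr = UnsafePointer.address_of(result).bitcast[UInt8]()
    memcpy(dest=buffer_ptr, src=ptr, count=len)  # copy only the valid bytes
    return result

fn main():
    var data = UnsafePointer[UInt8].alloc(3)
    for i in range(3):
        data[i] = UInt8(i + 1)
    print(partial_load[8](data, 3))  # [1, 2, 3, 0, 0, 0, 0, 0]
    data.free()
```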
4 changes: 2 additions & 2 deletions stdlib/src/math/math.mojo
@@ -1049,10 +1049,10 @@ fn isclose[


 # TODO: Remove this when `iota` works at compile-time
-fn _compile_time_iota[type: DType, simd_width: Int]() -> SIMD[type, simd_width]:
+fn _iota[type: DType, simd_width: Int]() -> SIMD[type, simd_width]:
     constrained[
         type.is_integral(),
-        "_compile_time_iota can only be used with integer types.",
+        "_iota can only be used with integer types.",
     ]()
     var a = SIMD[type, simd_width](0)
     for i in range(simd_width):
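And the rename in math.mojo: `_iota` is unchanged apart from its name, a loop-built [0, 1, 2, ...] vector that, unlike `iota`, can run at compile time, which is what lets the new `alias table` expressions in _b64encode.mojo fold to constants. A usage sketch, assuming the private `math.math` import is visible from the call site:

```mojo
from math.math import _iota  # private stdlib helper, per this commit

fn main():
    alias v = _iota[DType.uint8, 4]()  # evaluated at compile time
    print(v)  # [0, 1, 2, 3]
```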