From dfbb83f7b013a2118786c569c3ff3f50a6c9ccc2 Mon Sep 17 00:00:00 2001 From: Yiwu Chen <210at85@gmail.com> Date: Tue, 5 Nov 2024 21:45:11 +0000 Subject: [PATCH] [stdlib] Clean up `b64encode` (2/N) Signed-off-by: Yiwu Chen <210at85@gmail.com> --- stdlib/src/base64/_b64encode.mojo | 120 +++++++----------------------- stdlib/src/math/math.mojo | 2 +- 2 files changed, 28 insertions(+), 94 deletions(-) diff --git a/stdlib/src/base64/_b64encode.mojo b/stdlib/src/base64/_b64encode.mojo index d867a91be5a..369593d97b9 100644 --- a/stdlib/src/base64/_b64encode.mojo +++ b/stdlib/src/base64/_b64encode.mojo @@ -26,20 +26,13 @@ https://arxiv.org/abs/1704.00605 from builtin.simd import _sub_with_saturation from collections import InlineArray -from math.math import _compile_time_iota +from math.math import _iota from memory import memcpy, bitcast, UnsafePointer from utils import IndexList alias Bytes = SIMD[DType.uint8, _] -fn _base64_simd_mask[ - simd_width: Int -](nb_value_to_load: Int) -> SIMD[DType.bool, simd_width]: - alias mask = _compile_time_iota[DType.uint8, simd_width]() - return mask < UInt8(nb_value_to_load) - - # | |---- byte 2 ----|---- byte 1 ----|---- byte 0 ----| # | |c₁c₀d₅d₄d₃d₂d₁d₀|b₃b₂b₁b₀c₅c₄c₃c₂|a₅a₄a₃a₂a₁a₀b₅b₄| # <----------------|----------------|----------------|----------------| @@ -115,99 +108,20 @@ fn _to_b64_ascii[width: Int, //](input: Bytes[width]) -> Bytes[width]: return abcd + OFFSETS._dynamic_shuffle(offset_indices) -fn _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load[ - simd_width: Int -]() -> SIMD[DType.uint8, simd_width]: - """This is a lookup table to know how many bytes we need to store in the output buffer - for a given number of bytes to encode in base64. Including the '=' sign. - - This table lookup is smaller than the simd size, because we only use it for the last chunk. - This should be called at compile time, otherwise it's quite slow. - """ - var result = SIMD[DType.uint8, simd_width](0) - for i in range(1, simd_width): - # We have "i" bytes to encode in base64, how many bytes do - # we need to store in the output buffer? Including the '=' sign. - - # math.ceil cannot be called at compile time, this is a workaround - var group_of_3_bytes = i // 3 - if i % 3 != 0: - group_of_3_bytes += 1 - - result[i] = group_of_3_bytes * 4 - return result - - fn _get_number_of_bytes_to_store_from_number_of_bytes_to_load[ max_size: Int ](nb_of_elements_to_load: Int) -> Int: - alias table = _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load[ - max_size - ]() + alias table = _ceildiv(_iota[DType.uint8, max_size](), 3) * 4 return int(table[nb_of_elements_to_load]) -fn _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[ - simd_width: Int -]() -> SIMD[DType.uint8, simd_width]: - """This is a lookup table to know how many bytes we need to store in the output buffer - for a given number of bytes to encode in base64. This is **not** including the '=' sign. - - This table lookup is smaller than the simd size, because we only use it for the last chunk. - This should be called at compile time, otherwise it's quite slow. - """ - var result = SIMD[DType.uint8, simd_width]() - for i in range(simd_width): - # We have "i" bytes to encode in base64, how many bytes do - # we need to store in the output buffer? NOT including the '=' sign. - # We count the number of groups of 6 bits and we add 1 byte if there is an incomplete group. - var number_of_bits = i * 8 - var complete_groups_of_6_bits = number_of_bits // 6 - var incomplete_groups_of_6_bits: Int - if i * 8 % 6 == 0: - incomplete_groups_of_6_bits = 0 - else: - incomplete_groups_of_6_bits = 1 - - result[i] = complete_groups_of_6_bits + incomplete_groups_of_6_bits - return result - - fn _get_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[ max_size: Int ](nb_of_elements_to_load: Int) -> Int: - alias table = _get_table_number_of_bytes_to_store_from_number_of_bytes_to_load_without_equal_sign[ - max_size - ]() + alias table = _ceildiv(_iota[DType.uint8, max_size]() * 8, 6) return int(table[nb_of_elements_to_load]) -fn load_incomplete_simd[ - simd_width: Int -](pointer: UnsafePointer[UInt8], nb_of_elements_to_load: Int) -> SIMD[ - DType.uint8, simd_width -]: - var result = SIMD[DType.uint8, simd_width](0) - var tmp_buffer_pointer = UnsafePointer.address_of(result).bitcast[UInt8]() - memcpy(dest=tmp_buffer_pointer, src=pointer, count=nb_of_elements_to_load) - return result - - -fn store_incomplete_simd[ - simd_width: Int -]( - pointer: UnsafePointer[UInt8], - owned simd_vector: SIMD[DType.uint8, simd_width], - nb_of_elements_to_store: Int, -): - var tmp_buffer_pointer = UnsafePointer.address_of(simd_vector).bitcast[ - UInt8 - ]() - - memcpy(dest=pointer, src=tmp_buffer_pointer, count=nb_of_elements_to_store) - _ = simd_vector # We make it live long enough - - # TODO: Use Span instead of List as input when Span is easier to use @no_inline fn b64encode_with_buffers( @@ -242,9 +156,9 @@ fn b64encode_with_buffers( ) # We don't want to read past the input buffer - var input_vector = load_incomplete_simd[simd_width]( + var input_vector = load_simd[simd_width]( start_of_input_chunk, - nb_of_elements_to_load=nb_of_elements_to_load, + nb_of_elements_to_load, ) result_vector = _to_b64_ascii(input_vector) @@ -255,7 +169,9 @@ fn b64encode_with_buffers( ]( nb_of_elements_to_load ) - var equal_mask = _base64_simd_mask[simd_width](non_equal_chars_number) + var equal_mask = _iota[ + DType.uint8, simd_width + ]() < non_equal_chars_number var result_vector_with_equals = equal_mask.select( result_vector, equal_vector @@ -266,7 +182,7 @@ fn b64encode_with_buffers( ]( nb_of_elements_to_load ) - store_incomplete_simd( + store_simd( result.unsafe_ptr() + len(result), result_vector_with_equals, nb_of_elements_to_store, @@ -278,6 +194,10 @@ fn b64encode_with_buffers( # Utility functions +fn _ceildiv(a: Bytes, b: __type_of(a)) -> __type_of(a): + return (a + b - 1) // b + + fn _repeat_until[width: Int](v: SIMD) -> SIMD[v.type, width]: constrained[width >= v.size, "width must be at least v.size"]() @@ -291,3 +211,17 @@ fn _rshift_bits_in_u16[shift: Int](input: Bytes) -> __type_of(input): var u16 = bitcast[DType.uint16, input.size // 2](input) var res = bit.rotate_bits_right[shift](u16) return bitcast[DType.uint8, input.size](res) + + +fn load_simd[ + width: Int +](pointer: UnsafePointer[Byte], len: Int) -> Bytes[width]: + var result = Bytes[width]() + var buffer_ptr = UnsafePointer.address_of(result).bitcast[Byte]() + memcpy(dest=buffer_ptr, src=pointer, count=len) + return result + + +fn store_simd(ptr: UnsafePointer[Byte], owned v: Bytes, len: Int): + var buffer_ptr = UnsafePointer.address_of(v).bitcast[Byte]() + memcpy(dest=ptr, src=buffer_ptr, count=len) diff --git a/stdlib/src/math/math.mojo b/stdlib/src/math/math.mojo index 9e2d918eb73..0975f4a93ac 100644 --- a/stdlib/src/math/math.mojo +++ b/stdlib/src/math/math.mojo @@ -1049,7 +1049,7 @@ fn isclose[ # TODO: Remove this when `iota` works at compile-time -fn _compile_time_iota[type: DType, simd_width: Int]() -> SIMD[type, simd_width]: +fn _iota[type: DType, simd_width: Int]() -> SIMD[type, simd_width]: constrained[ type.is_integral(), "_compile_time_iota can only be used with integer types.",