diff --git a/stdlib/src/collections/string/string.mojo b/stdlib/src/collections/string/string.mojo index b31dbf6a49..9cd5126d69 100644 --- a/stdlib/src/collections/string/string.mojo +++ b/stdlib/src/collections/string/string.mojo @@ -129,15 +129,11 @@ fn chr(c: Int) -> String: Examples: ```mojo - print(chr(97)) # "a" - print(chr(8364)) # "€" + print(chr(97), chr(8364)) # "a €" ``` . """ - if c < 0b1000_0000: # 1 byte ASCII char - return String(String._buffer_type(c, 0)) - var num_bytes = _unicode_codepoint_utf8_byte_length(c) var p = UnsafePointer[UInt8].alloc(num_bytes + 1) _shift_unicode_to_utf8(p, c, num_bytes) diff --git a/stdlib/src/collections/string/string_slice.mojo b/stdlib/src/collections/string/string_slice.mojo index f632e806fd..a43671ee22 100644 --- a/stdlib/src/collections/string/string_slice.mojo +++ b/stdlib/src/collections/string/string_slice.mojo @@ -67,8 +67,8 @@ fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int: debug_assert( 0 <= c <= 0x10FFFF, "Value: ", c, " is not a valid Unicode code point" ) - alias sizes = SIMD[DType.int32, 4](0, 0b0111_1111, 0b0111_1111_1111, 0xFFFF) - return int((sizes < c).cast[DType.uint8]().reduce_add()) + alias sizes = SIMD[DType.uint32, 4](0, 0x80, 0x8_00, 0x1_00_00) + return int((sizes <= c).cast[DType.uint8]().reduce_add()) @always_inline @@ -80,12 +80,17 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int: (b & 0b1100_0000) != 0b1000_0000, "Function does not work correctly if given a continuation byte.", ) - return int(count_leading_zeros(~b)) + int(b < 0b1000_0000) + return int(count_leading_zeros(~b) | (b < 0b1000_0000).cast[DType.uint8]()) -fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): +fn _shift_unicode_to_utf8[ + optimize_ascii: Bool = True +](ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): """Shift unicode to utf8 representation. + Parameters: + optimize_ascii: Optimize for languages with mostly ASCII characters. + ### Unicode (represented as UInt32 BE) to UTF-8 conversion: - 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa - a @@ -98,19 +103,32 @@ fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): - (a >> 18) | 0b11110000, (b >> 12) | 0b10000000, (c >> 6) | 0b10000000, d | 0b10000000 """ - if num_bytes == 1: - ptr[0] = UInt8(c) - return - var shift = 6 * (num_bytes - 1) - var mask = UInt8(0xFF) >> (num_bytes + 1) - var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) - ptr[0] = ((c >> shift) & mask) | num_bytes_marker - for i in range(1, num_bytes): - shift -= 6 - ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 + @parameter + if optimize_ascii: + if likely(num_bytes == 1): + ptr[0] = UInt8(c) + return + var shift = 6 * (num_bytes - 1) + var mask = UInt8(0xFF) >> (num_bytes + 1) + var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) + ptr[0] = ((c >> shift) & mask) | num_bytes_marker + for i in range(1, num_bytes): + shift -= 6 + ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 + else: + var shift = 6 * (num_bytes - 1) + var mask = UInt8(0xFF) >> (num_bytes + int(num_bytes > 1)) + var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) + ptr[0] = ((c >> shift) & mask) | ( + num_bytes_marker & -int(num_bytes != 1) + ) + for i in range(1, num_bytes): + shift -= 6 + ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 +@always_inline fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b): """UTF-8 byte type. @@ -125,7 +143,7 @@ fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b): - 3 -> start of 3 byte long sequence. - 4 -> start of 4 byte long sequence. """ - return count_leading_zeros(~(b & UInt8(0b1111_0000))) + return count_leading_zeros(~b) @always_inline diff --git a/stdlib/test/collections/string/test_string.mojo b/stdlib/test/collections/string/test_string.mojo index 74ae492479..f2c9e37486 100644 --- a/stdlib/test/collections/string/test_string.mojo +++ b/stdlib/test/collections/string/test_string.mojo @@ -326,6 +326,7 @@ def test_ord(): def test_chr(): + assert_equal("\0", chr(0)) assert_equal("A", chr(65)) assert_equal("a", chr(97)) assert_equal("!", chr(33))