diff --git a/docs/changelog.md b/docs/changelog.md index ee7a63b884..67d3d5a404 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -153,6 +153,12 @@ what we publish. - Added `StringSlice(..)` initializer from a `StringLiteral`. +- Added a `byte_length()` method to `String`, `StringSlice`, and `StringLiteral` +and deprecated their private `_byte_length()` methods. The `String.__len__` +docstring now warns that a future release will return the length in Unicode +codepoints; `StringSlice.__len__` already returns the length in Unicode codepoints. +([PR #2960](https://github.com/modularml/mojo/pull/2960) by [@martinvuyk](https://github.com/martinvuyk)) + - Added new `StaticString` type alias. This can be used in place of `StringLiteral` for runtime string arguments. diff --git a/stdlib/src/base64/base64.mojo b/stdlib/src/base64/base64.mojo index d042c27e21..62e021d38f 100644 --- a/stdlib/src/base64/base64.mojo +++ b/stdlib/src/base64/base64.mojo @@ -72,7 +72,7 @@ fn b64encode(str: String) -> String: alias lookup = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" var b64chars = lookup.unsafe_ptr() - var length = len(str) + var length = str.byte_length() var out = String._buffer_type(capacity=length + 1) @parameter @@ -121,7 +121,7 @@ fn b64decode(str: String) -> String: Returns: The decoded string. 
""" - var n = len(str) + var n = str.byte_length() debug_assert(n % 4 == 0, "Input length must be divisible by 4") var p = String._buffer_type(capacity=n + 1) @@ -170,7 +170,7 @@ fn b16encode(str: String) -> String: alias lookup = "0123456789ABCDEF" var b16chars = lookup.unsafe_ptr() - var length = len(str) + var length = str.byte_length() var out = List[UInt8](capacity=length * 2 + 1) @parameter @@ -221,7 +221,7 @@ fn b16decode(str: String) -> String: return -1 - var n = len(str) + var n = str.byte_length() debug_assert(n % 2 == 0, "Input length must be divisible by 2") var p = List[UInt8](capacity=n // 2 + 1) diff --git a/stdlib/src/builtin/error.mojo b/stdlib/src/builtin/error.mojo index f80f276378..1891039620 100644 --- a/stdlib/src/builtin/error.mojo +++ b/stdlib/src/builtin/error.mojo @@ -80,7 +80,7 @@ struct Error( Returns: The constructed Error object. """ - var length = len(src) + var length = src.byte_length() var dest = UnsafePointer[UInt8].alloc(length + 1) memcpy( dest=dest, diff --git a/stdlib/src/builtin/file.mojo b/stdlib/src/builtin/file.mojo index 143ba314df..faa7e73653 100644 --- a/stdlib/src/builtin/file.mojo +++ b/stdlib/src/builtin/file.mojo @@ -239,7 +239,7 @@ struct FileHandle: var bytes = file.read(ptr, 8) print("bytes read", bytes) - var first_element = ptr.load(0) + var first_element = ptr[0] print(first_element) # Skip 2 elements @@ -374,7 +374,7 @@ struct FileHandle: ```mojo import os var f = open("/tmp/example.txt", "r") - f.seek(os.SEEK_CUR, 32) + _ = f.seek(32, os.SEEK_CUR) ``` Start from 32 bytes from the end of the file: @@ -382,7 +382,7 @@ struct FileHandle: ```mojo import os var f = open("/tmp/example.txt", "r") - f.seek(os.SEEK_END, -32) + _ = f.seek(-32, os.SEEK_END) ``` . """ @@ -409,7 +409,7 @@ struct FileHandle: Args: data: The data to write to the file. 
""" - self._write(data.unsafe_ptr(), len(data)) + self._write(data.unsafe_ptr(), data.byte_length()) fn write(self, data: Span[UInt8, _]) raises: """Write a borrowed sequence of data to the file. diff --git a/stdlib/src/builtin/io.mojo b/stdlib/src/builtin/io.mojo index 54ba79d319..f0d725f349 100644 --- a/stdlib/src/builtin/io.mojo +++ b/stdlib/src/builtin/io.mojo @@ -320,7 +320,7 @@ fn _put(x: DType, file: FileDescriptor = stdout): @no_inline fn _put(x: StringSlice, file: FileDescriptor = stdout): # Avoid printing "(null)" for an empty/default constructed `String` - var str_len = x._byte_length() + var str_len = x.byte_length() if not str_len: return @@ -341,7 +341,7 @@ fn _put(x: StringSlice, file: FileDescriptor = stdout): # The string can be printed, so that's fine. if str_len < MAX_STR_LEN: - _printf["%.*s"](x._byte_length(), x.unsafe_ptr(), file=file) + _printf["%.*s"](x.byte_length(), x.unsafe_ptr(), file=file) return # The string is large, then we need to chunk it. diff --git a/stdlib/src/builtin/string.mojo b/stdlib/src/builtin/string.mojo index 21bf6bf690..8c514bdfb8 100644 --- a/stdlib/src/builtin/string.mojo +++ b/stdlib/src/builtin/string.mojo @@ -68,11 +68,11 @@ fn ord(s: StringSlice) -> Int: var p = s.unsafe_ptr().bitcast[UInt8]() var b1 = p[] if (b1 >> 7) == 0: # This is 1 byte ASCII char - debug_assert(s._byte_length() == 1, "input string length must be 1") + debug_assert(s.byte_length() == 1, "input string length must be 1") return int(b1) var num_bytes = countl_zero(~b1) debug_assert( - s._byte_length() == int(num_bytes), "input string must be one character" + s.byte_length() == int(num_bytes), "input string must be one character" ) debug_assert( 1 < int(num_bytes) < 5, "invalid UTF-8 byte " + str(b1) + " at index 0" @@ -1008,11 +1008,10 @@ struct String( Construct a String from several `Formattable` arguments: ```mojo - from testing import assert_equal - var string = String.format_sequence(1, ", ", 2.0, ", ", "three") - - assert_equal(string, 
"1, 2.0, three") + print(string) # "1, 2.0, three" + %# from testing import assert_equal + %# assert_equal(string, "1, 2.0, three") ``` . """ @@ -1076,6 +1075,7 @@ struct String( Returns: A new string containing the character at the specified position. """ + # TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time var normalized_idx = normalize_index["String"](idx, self) var buf = Self._buffer_type(capacity=1) buf.append(self._buffer[normalized_idx]) @@ -1094,13 +1094,12 @@ struct String( var start: Int var end: Int var step: Int - start, end, step = span.indices(len(self)) + # TODO(#933): implement this for unicode when we support llvm intrinsic evaluation at compile time + + start, end, step = span.indices(self.byte_length()) var r = range(start, end, step) if step == 1: - return StringRef( - self._buffer.data + start, - len(r), - ) + return StringRef(self._buffer.data + start, len(r)) var buffer = Self._buffer_type() var result_len = len(r) @@ -1197,8 +1196,8 @@ struct String( return other if not other: return self - var self_len = len(self) - var other_len = len(other) + var self_len = self.byte_length() + var other_len = other.byte_length() var total_len = self_len + other_len var buffer = Self._buffer_type() buffer.resize(total_len + 1, 0) @@ -1237,8 +1236,8 @@ struct String( return if not other: return - var self_len = len(self) - var other_len = len(other) + var self_len = self.byte_length() + var other_len = other.byte_length() var total_len = self_len + other_len self._buffer.resize(total_len + 1, 0) # Copy the data alongside the terminator. @@ -1255,7 +1254,7 @@ struct String( An iterator of references to the string elements. 
""" return _StringIter[__lifetime_of(self)]( - unsafe_pointer=self.unsafe_ptr(), length=len(self) + unsafe_pointer=self.unsafe_ptr(), length=self.byte_length() ) fn __reversed__(ref [_]self) -> _StringIter[__lifetime_of(self), False]: @@ -1265,7 +1264,7 @@ struct String( A reversed iterator of references to the string elements. """ return _StringIter[__lifetime_of(self), forward=False]( - unsafe_pointer=self.unsafe_ptr(), length=len(self) + unsafe_pointer=self.unsafe_ptr(), length=self.byte_length() ) # ===------------------------------------------------------------------=== # @@ -1279,20 +1278,24 @@ struct String( Returns: True if the string length is greater than zero, and False otherwise. """ - return len(self) > 0 + return self.byte_length() > 0 fn __len__(self) -> Int: - """Gets the string length, in bytes. + """Gets the string length, in bytes (for now). Prefer + `String.byte_length()`; a future version will make this method return + the length in Unicode codepoints. Returns: - The string length, in bytes. + The string length, in bytes (for now). """ - # Avoid returning -1 if the buffer is not initialized - if not self.unsafe_ptr(): - return 0 + var unicode_length = self.byte_length() + + # TODO: everything uses this method assuming it's byte length + # for i in range(unicode_length): + # if _utf8_byte_type(self._buffer[i]) == 1: + # unicode_length -= 1 - # The negative 1 is to account for the terminator. - return len(self._buffer) - 1 + return unicode_length @always_inline fn __str__(self) -> String: @@ -1447,7 +1450,7 @@ struct String( strings. Using this requires the use of the _strref_keepalive() method to keep the underlying string alive long enough. 
""" - return StringRef(self.unsafe_ptr(), len(self)) + return StringRef(self.unsafe_ptr(), self.byte_length()) fn _strref_keepalive(self): """ @@ -1497,19 +1500,18 @@ struct String( @always_inline fn as_bytes_slice(ref [_]self) -> Span[UInt8, __lifetime_of(self)]: - """ - Returns a contiguous slice of the bytes owned by this string. - - This does not include the trailing null terminator. + """Returns a contiguous slice of the bytes owned by this string. Returns: A contiguous slice pointing to the bytes owned by this string. + + Notes: + This does not include the trailing null terminator. """ + # Does NOT include the NUL terminator. return Span[UInt8, __lifetime_of(self)]( - unsafe_ptr=self._buffer.unsafe_ptr(), - # Does NOT include the NUL terminator. - len=self._byte_length(), + unsafe_ptr=self._buffer.unsafe_ptr(), len=self.byte_length() ) @always_inline @@ -1524,21 +1526,30 @@ struct String( # guaranteed to be valid. return StringSlice(unsafe_from_utf8=self.as_bytes_slice()) - fn _byte_length(self) -> Int: + @always_inline + fn byte_length(self) -> Int: """Get the string length in bytes. - This does not include the trailing null terminator in the count. - Returns: The length of this string in bytes, excluding null terminator. + + Notes: + This does not include the trailing null terminator in the count. """ + return max(len(self._buffer) - 1, 0) - var buffer_len = len(self._buffer) + @always_inline + @deprecated("use byte_length() instead") + fn _byte_length(self) -> Int: + """Get the string length in bytes. - if buffer_len > 0: - return buffer_len - 1 - else: - return buffer_len + Returns: + The length of this string in bytes, excluding null terminator. + + Notes: + This does not include the trailing null terminator in the count. + """ + return max(len(self._buffer) - 1, 0) fn _steal_ptr(inout self) -> UnsafePointer[UInt8]: """Transfer ownership of pointer to the underlying memory. 
@@ -1578,7 +1589,7 @@ struct String( break res += 1 - offset = pos + len(substr) + offset = pos + substr.byte_length() return res @@ -1653,11 +1664,11 @@ struct String( var ptr2 = DTypePointer(item2) return memcmp(ptr1, ptr2, amnt) == 0 - if len(self) == 0: + if self.byte_length() == 0: return False for s in self: - var no_null_len = len(s) + var no_null_len = s.byte_length() var ptr = s.unsafe_ptr() if no_null_len == 1 and not _isspace(ptr[0]): return False @@ -1697,15 +1708,15 @@ struct String( """ var output = List[String]() - var str_iter_len = len(self) - 1 + var str_byte_len = self.byte_length() - 1 var lhs = 0 var rhs = 0 var items = 0 - var sep_len = len(sep) + var sep_len = sep.byte_length() if sep_len == 0: raise Error("ValueError: empty separator") - while lhs <= str_iter_len: + while lhs <= str_byte_len: rhs = self.find(sep, lhs) if rhs == -1: output.append(self[lhs:]) @@ -1724,12 +1735,11 @@ struct String( output.append("") return output - fn split(self, *, maxsplit: Int = -1) -> List[String]: + fn split(self, sep: NoneType = None, maxsplit: Int = -1) -> List[String]: """Split the string by every Whitespace separator. - Currently only uses C style separators. - Args: + sep: None. maxsplit: The maximum amount of items to split from String. Defaults to unlimited. @@ -1745,43 +1755,40 @@ struct String( # Splitting a string with leading, trailing, and middle whitespaces _ = String(" hello world ").split() # ["hello", "world"] + # Splitting adjacent universal newlines: + _ = String( + "hello \\t\\n\\r\\f\\v\\x1c\\x1d\\x1e\\x85\\u2028\\u2029world" + ).split() # ["hello", "world"] ``` . 
""" - # TODO: implement and document splitting adjacent universal newlines: - # _ = String( - # "hello \\t\\n\\r\\f\\v\\x1c\\x1e\\x85\\u2028\\u2029world" - # ).split() # ["hello", "world"] var output = List[String]() - - var str_iter_len = len(self) - 1 + var str_byte_len = self.byte_length() - 1 var lhs = 0 var rhs = 0 var items = 0 - # FIXME: this should iterate and build unicode strings - # and use self.isspace() - while lhs <= str_iter_len: + while lhs <= str_byte_len: # Python adds all "whitespace chars" as one separator # if no separator was specified - while lhs <= str_iter_len: - if not _isspace(self._buffer.unsafe_get(lhs)): + for s in self[lhs:]: + if not str(s).isspace(): # TODO: with StringSlice.isspace() break - lhs += 1 + lhs += s.byte_length() # if it went until the end of the String, then # it should be sliced up until the original # start of the whitespace which was already appended - if lhs - 1 == str_iter_len: + if lhs - 1 == str_byte_len: break - elif lhs == str_iter_len: + elif lhs == str_byte_len: # if the last char is not whitespace - output.append(self[str_iter_len]) + output.append(self[str_byte_len]) break rhs = lhs + 1 - while rhs <= str_iter_len: - if _isspace(self._buffer.unsafe_get(rhs)): + for s in self[lhs + 1 :]: + if str(s).isspace(): # TODO: with StringSlice.isspace() break - rhs += 1 + rhs += s.byte_length() if maxsplit > -1: if items == maxsplit: @@ -1804,7 +1811,7 @@ struct String( A List of Strings containing the input split by line boundaries. 
""" var output = List[String]() - var length = len(self) + var length = self.byte_length() var current_offset = 0 while current_offset < length: @@ -1855,9 +1862,9 @@ struct String( var self_ptr = self.unsafe_ptr() var new_ptr = new.unsafe_ptr() - var self_len = len(self) - var old_len = len(old) - var new_len = len(new) + var self_len = self.byte_length() + var old_len = old.byte_length() + var new_len = new.byte_length() var res = List[UInt8]() res.reserve(self_len + (old_len - new_len) * occurrences + 1) @@ -1922,7 +1929,7 @@ struct String( A copy of the string with no trailing characters. """ - var r_idx = len(self) + var r_idx = self.byte_length() while r_idx > 0 and self[r_idx - 1] in chars: r_idx -= 1 @@ -1934,8 +1941,12 @@ struct String( Returns: A copy of the string with no trailing whitespaces. """ - # TODO: should use self.__iter__ and self.isspace() - var r_idx = len(self) + var r_idx = self.byte_length() + # TODO (#933): should use this once llvm intrinsics can be used at comp time + # for s in self.__reversed__(): + # if not s.isspace(): + # break + # r_idx -= 1 while r_idx > 0 and _isspace(self._buffer.unsafe_get(r_idx - 1)): r_idx -= 1 return self[:r_idx] @@ -1951,7 +1962,7 @@ struct String( """ var l_idx = 0 - while l_idx < len(self) and self[l_idx] in chars: + while l_idx < self.byte_length() and self[l_idx] in chars: l_idx += 1 return self[l_idx:] @@ -1962,9 +1973,15 @@ struct String( Returns: A copy of the string with no leading whitespaces. 
""" - # TODO: should use self.__iter__ and self.isspace() var l_idx = 0 - while l_idx < len(self) and _isspace(self._buffer.unsafe_get(l_idx)): + # TODO (#933): should use this once llvm intrinsics can be used at comp time + # for s in self: + # if not s.isspace(): + # break + # l_idx += 1 + while l_idx < self.byte_length() and _isspace( + self._buffer.unsafe_get(l_idx) + ): l_idx += 1 return self[l_idx:] @@ -1982,9 +1999,9 @@ struct String( var res = List[UInt8]() var val_ptr = val.unsafe_ptr() var self_ptr = self.unsafe_ptr() - res.reserve(len(val) * len(self) + 1) - for i in range(len(self)): - for j in range(len(val)): + res.reserve(val.byte_length() * self.byte_length() + 1) + for i in range(self.byte_length()): + for j in range(val.byte_length()): res.append(val_ptr[j]) res.append(self_ptr[i]) res.append(0) @@ -2021,7 +2038,7 @@ struct String( var char_ptr = copy.unsafe_ptr() - for i in range(len(self)): + for i in range(self.byte_length()): var char: UInt8 = char_ptr[i] if check_case(char): var lower = _toggle_ascii_case(char) @@ -2043,7 +2060,7 @@ struct String( """ if end == -1: return StringRef( - self.unsafe_ptr() + start, len(self) - start + self.unsafe_ptr() + start, self.byte_length() - start ).startswith(prefix._strref_dangerous()) return StringRef(self.unsafe_ptr() + start, end - start).startswith( @@ -2064,7 +2081,7 @@ struct String( """ if end == -1: return StringRef( - self.unsafe_ptr() + start, len(self) - start + self.unsafe_ptr() + start, self.byte_length() - start ).endswith(suffix._strref_dangerous()) return StringRef(self.unsafe_ptr() + start, end - start).endswith( @@ -2091,7 +2108,7 @@ struct String( or a copy of the original string otherwise. """ if self.startswith(prefix): - return self[len(prefix) :] + return self[prefix.byte_length() :] return self fn removesuffix(self, suffix: String, /) -> String: @@ -2114,7 +2131,7 @@ struct String( or a copy of the original string otherwise. 
""" if suffix and self.endswith(suffix): - return self[: -len(suffix)] + return self[: -suffix.byte_length()] return self fn __int__(self) raises -> Int: @@ -2140,7 +2157,7 @@ struct String( """ if n <= 0: return "" - var len_self = len(self) + var len_self = self.byte_length() var count = len_self * n + 1 var buf = Self._buffer_type(capacity=count) buf.resize(count, 0) @@ -2193,7 +2210,10 @@ struct String( var current_automatic_arg_index = 0 for e in entries: - debug_assert(pos_in_self < len(self), "pos_in_self >= len(self)") + debug_assert( + pos_in_self < self.byte_length(), + "pos_in_self >= self.byte_length()", + ) res += self[pos_in_self : e[].first_curly] if e[].is_escaped_brace(): @@ -2216,8 +2236,8 @@ struct String( pos_in_self = e[].last_curly + 1 - if pos_in_self < len(self): - res += self[pos_in_self : len(self)] + if pos_in_self < self.byte_length(): + res += self[pos_in_self : self.byte_length()] return res^ @@ -2239,7 +2259,7 @@ struct String( return _is_ascii_uppercase(c) or _is_ascii_lowercase(c) for c in self: - debug_assert(c._byte_length() == 1, "only implemented for ASCII") + debug_assert(c.byte_length() == 1, "only implemented for ASCII") if is_ascii_cased(ord(c)): @parameter @@ -2492,7 +2512,7 @@ struct _FormatCurlyEntry(CollectionElement, CollectionElementNew): var entries = List[Self]() var start = Optional[Int](None) var skip_next = False - for i in range(len(format_src)): + for i in range(format_src.byte_length()): if skip_next: skip_next = False continue @@ -2549,7 +2569,7 @@ struct _FormatCurlyEntry(CollectionElement, CollectionElementNew): start = None else: # python escapes double curlies - if (i + 1) < len(format_src): + if (i + 1) < format_src.byte_length(): if format_src[i + 1] == "}": var curren_entry = Self( first_curly=i, last_curly=i + 1, field=True diff --git a/stdlib/src/builtin/string_literal.mojo b/stdlib/src/builtin/string_literal.mojo index 60b77847e7..650f075b2e 100644 --- a/stdlib/src/builtin/string_literal.mojo +++ 
b/stdlib/src/builtin/string_literal.mojo @@ -191,7 +191,7 @@ struct StringLiteral( # TODO(MSTDL-160): # Properly count Unicode codepoints instead of returning this length # in bytes. - return self._byte_length() + return self.byte_length() @always_inline("nodebug") fn __bool__(self) -> Bool: @@ -221,7 +221,7 @@ struct StringLiteral( A new string. """ var string = String() - var length = self._byte_length() + var length = self.byte_length() var buffer = String._buffer_type() var new_capacity = length + 1 buffer._realloc(new_capacity) @@ -265,11 +265,27 @@ struct StringLiteral( # ===-------------------------------------------------------------------===# @always_inline + fn byte_length(self) -> Int: + """Get the string length in bytes. + + Returns: + The length of this StringLiteral in bytes. + + Notes: + This does not include the trailing null terminator in the count. + """ + return __mlir_op.`pop.string.size`(self.value) + + @always_inline + @deprecated("use byte_length() instead") fn _byte_length(self) -> Int: """Get the string length in bytes. Returns: The length of this StringLiteral in bytes. + + Notes: + This does not include the trailing null terminator in the count. """ return __mlir_op.`pop.string.size`(self.value) @@ -336,7 +352,7 @@ struct StringLiteral( return Span[UInt8, ImmutableStaticLifetime]( unsafe_ptr=ptr, - len=self._byte_length(), + len=self.byte_length(), ) fn format_to(self, inout writer: Formatter): diff --git a/stdlib/src/pathlib/path.mojo b/stdlib/src/pathlib/path.mojo index 9d67ddf916..43976519c7 100644 --- a/stdlib/src/pathlib/path.mojo +++ b/stdlib/src/pathlib/path.mojo @@ -162,7 +162,7 @@ struct Path( Returns: True if the path length is greater than zero, and False otherwise. 
""" - return len(self.path) > 0 + return self.path.byte_length() > 0 fn format_to(self, inout writer: Formatter): """ diff --git a/stdlib/src/sys/ffi.mojo b/stdlib/src/sys/ffi.mojo index 3c65863f0d..fd2d3b8d50 100644 --- a/stdlib/src/sys/ffi.mojo +++ b/stdlib/src/sys/ffi.mojo @@ -231,7 +231,7 @@ fn _get_global[ fn _get_global_or_null[name: StringLiteral]() -> UnsafePointer[NoneType]: return external_call[ "KGEN_CompilerRT_GetGlobalOrNull", UnsafePointer[NoneType] - ](name.unsafe_ptr(), name._byte_length()) + ](name.unsafe_ptr(), name.byte_length()) @always_inline diff --git a/stdlib/src/tempfile/tempfile.mojo b/stdlib/src/tempfile/tempfile.mojo index 9864ba1b5f..bd3c9d69f2 100644 --- a/stdlib/src/tempfile/tempfile.mojo +++ b/stdlib/src/tempfile/tempfile.mojo @@ -31,7 +31,9 @@ fn _get_random_name(size: Int = 8) -> String: alias characters = String("abcdefghijklmnopqrstuvwxyz0123456789_") var name_list = List[UInt8](capacity=size + 1) for _ in range(size): - var rand_index = int(random.random_ui64(0, len(characters) - 1)) + var rand_index = int( + random.random_ui64(0, characters.byte_length() - 1) + ) name_list.append(ord(characters[rand_index])) name_list.append(0) return String(name_list^) diff --git a/stdlib/src/utils/inline_string.mojo b/stdlib/src/utils/inline_string.mojo index d5481cd223..db93624194 100644 --- a/stdlib/src/utils/inline_string.mojo +++ b/stdlib/src/utils/inline_string.mojo @@ -123,7 +123,7 @@ struct InlineString(Sized, Stringable, CollectionElement, CollectionElementNew): Args: str_slice: The string to append. """ - var total_len = len(self) + str_slice._byte_length() + var total_len = len(self) + str_slice.byte_length() # NOTE: Not guaranteed that we're in the small layout even if our # length is shorter than the small capacity. 
@@ -157,7 +157,7 @@ struct InlineString(Sized, Stringable, CollectionElement, CollectionElementNew): memcpy( dest=buffer.unsafe_ptr() + len(self), src=str_slice.unsafe_ptr(), - count=str_slice._byte_length(), + count=str_slice.byte_length(), ) # Record that we've initialized `total_len` count of elements @@ -441,14 +441,14 @@ struct _FixedString[CAP: Int]( inout self, str_slice: StringSlice[_], ) -> Optional[Error]: - var total_len = len(self) + str_slice._byte_length() + var total_len = len(self) + str_slice.byte_length() # Ensure there is sufficient capacity to append `str_slice` if total_len > CAP: return Optional( Error( "Insufficient capacity to append len=" - + str(str_slice._byte_length()) + + str(str_slice.byte_length()) + " string to len=" + str(len(self)) + " FixedString with capacity=" @@ -460,7 +460,7 @@ struct _FixedString[CAP: Int]( memcpy( dest=self.buffer.unsafe_ptr() + len(self), src=str_slice.unsafe_ptr(), - count=str_slice._byte_length(), + count=str_slice.byte_length(), ) self.size = total_len diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index 0d3b1e5dbd..ca15404743 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -21,6 +21,7 @@ from utils import StringSlice """ from utils import Span +from builtin.string import _isspace, _utf8_byte_type alias StaticString = StringSlice[ImmutableStaticLifetime] """An immutable static string slice.""" @@ -68,8 +69,7 @@ struct StringSlice[ # FIXME(MSTDL-160): # Ensure StringLiteral _actually_ always uses UTF-8 encoding. self = StringSlice[lifetime]( - unsafe_from_utf8_ptr=literal.unsafe_ptr(), - len=literal._byte_length(), + unsafe_from_utf8_ptr=literal.unsafe_ptr(), len=literal.byte_length() ) @always_inline @@ -155,9 +155,13 @@ struct StringSlice[ Returns: The length in Unicode codepoints. """ - # FIXME(MSTDL-160): - # Actually perform UTF-8 decoding here to count the codepoints. 
- return len(self._slice) + var unicode_length = self.byte_length() + + for i in range(unicode_length): + if _utf8_byte_type(self._slice[i]) == 1: + unicode_length -= 1 + + return unicode_length fn format_to(self, inout writer: Formatter): """ @@ -263,8 +267,7 @@ struct StringSlice[ @always_inline fn as_bytes_slice(self) -> Span[UInt8, lifetime]: - """ - Get the sequence of encoded bytes as a slice of the underlying string. + """Get the sequence of encoded bytes as a slice of the underlying string. Returns: A slice containing the underlying sequence of encoded bytes. @@ -273,8 +276,7 @@ struct StringSlice[ @always_inline fn unsafe_ptr(self) -> UnsafePointer[UInt8]: - """ - Gets a pointer to the first element of this string slice. + """Gets a pointer to the first element of this string slice. Returns: A pointer pointing at the first element of this string slice. @@ -283,9 +285,19 @@ struct StringSlice[ return self._slice.unsafe_ptr() @always_inline - fn _byte_length(self) -> Int: + fn byte_length(self) -> Int: + """Get the length of this string slice in bytes. + + Returns: + The length of this string slice in bytes. """ - Get the length of this string slice in bytes. + + return len(self.as_bytes_slice()) + + @always_inline + @deprecated("use byte_length() instead") + fn _byte_length(self) -> Int: + """Get the length of this string slice in bytes. Returns: The length of this string slice in bytes. @@ -294,8 +306,7 @@ struct StringSlice[ return len(self.as_bytes_slice()) fn _strref_dangerous(self) -> StringRef: - """ - Returns an inner pointer to the string as a StringRef. + """Returns an inner pointer to the string as a StringRef. Safety: This functionality is extremely dangerous because Mojo eagerly @@ -303,11 +314,10 @@ struct StringSlice[ _strref_keepalive() method to keep the underlying string alive long enough. 
""" - return StringRef(self.unsafe_ptr(), self._byte_length()) + return StringRef(self.unsafe_ptr(), self.byte_length()) fn _strref_keepalive(self): - """ - A no-op that keeps `self` alive through the call. This + """A no-op that keeps `self` alive through the call. This can be carefully used with `_strref_dangerous()` to wield inner pointers without the string getting deallocated early. """ diff --git a/stdlib/test/builtin/test_string.mojo b/stdlib/test/builtin/test_string.mojo index 541ef040a3..2007ee6f75 100644 --- a/stdlib/test/builtin/test_string.mojo +++ b/stdlib/test/builtin/test_string.mojo @@ -661,8 +661,31 @@ def test_split(): assert_true(d[0] == "hello \t" and d[1] == "" and d[2] == "\v\fworld") # Should add all whitespace-like chars as one - alias utf8_spaces = String(" \t\n\r\v\f") - var s = utf8_spaces + "hello" + utf8_spaces + "world" + utf8_spaces + # test all unicode separators + # 0 is to build a String with null terminator + alias next_line = List[UInt8](0xC2, 0x85, 0) + """TODO: \\x85""" + alias unicode_line_sep = List[UInt8](0xE2, 0x80, 0xA8, 0) + """TODO: \\u2028""" + alias unicode_paragraph_sep = List[UInt8](0xE2, 0x80, 0xA9, 0) + """TODO: \\u2029""" + # TODO add line and paragraph separator as `StringLiteral` once unicode + # escape sequences are accepted + var univ_sep_var = ( + String(" ") + + String("\t") + + String("\n") + + String("\r") + + String("\v") + + String("\f") + + String("\x1c") + + String("\x1d") + + String("\x1e") + + String(next_line) + + String(unicode_line_sep) + + String(unicode_paragraph_sep) + ) + var s = univ_sep_var + "hello" + univ_sep_var + "world" + univ_sep_var d = s.split() assert_true(len(d) == 2) assert_true(d[0] == "hello" and d[1] == "world") @@ -1251,7 +1274,7 @@ def test_string_iter(): var utf8_sequence_len = 0 var byte_idx = 0 for v in item: - var byte_len = len(v) + var byte_len = v.byte_length() assert_equal(item[byte_idx : byte_idx + byte_len], v) byte_idx += byte_len utf8_sequence_len += 1