diff --git a/stdlib/src/collections/string/string_slice.mojo b/stdlib/src/collections/string/string_slice.mojo index f632e806fd..a9aabe49a3 100644 --- a/stdlib/src/collections/string/string_slice.mojo +++ b/stdlib/src/collections/string/string_slice.mojo @@ -31,7 +31,7 @@ from memory.memory import _memcmp_impl_unconstrained from sys import bitwidthof, simdwidthof from sys.intrinsics import unlikely, likely from utils.stringref import StringRef, _memmem -from os import PathLike +from os import PathLike, abort alias StaticString = StringSlice[StaticConstantOrigin] """An immutable static string slice.""" @@ -165,6 +165,79 @@ fn _memrmem[ return UnsafePointer[Scalar[type]]() +@value +struct _SplitlinesIter[ + is_mutable: Bool, //, + origin: Origin[is_mutable], + forward: Bool = True, +]: + """Iterator over the lines of a `StringSlice` split at unicode linebreaks. + + Parameters: + is_mutable: Whether the slice is mutable. + origin: The origin of the underlying string data. + forward: Iteration direction. `False` (backwards) is unimplemented. 
+ """ + + alias `\r` = UInt8(ord("\r")) + alias `\n` = UInt8(ord("\n")) + + var index: Int + var ptr: UnsafePointer[Byte] + var length: Int + var keepends: Bool + + fn __iter__(self) -> Self: + return self + + fn __next__(mut self) -> StringSlice[origin]: + # highly performance sensitive code, benchmark before touching + @parameter + if forward: + var eol_start = self.index + var eol_length = 0 + + while eol_start < self.length: + var b0 = self.ptr[eol_start] + var char_len = _utf8_first_byte_sequence_length(b0) + debug_assert( + eol_start + char_len <= self.length, + "corrupted sequence causing unsafe memory access", + ) + var isnewline = unlikely( + _is_newline_char(self.ptr, eol_start, b0, char_len) + ) + var char_end = int(isnewline) * (eol_start + char_len) + var next_idx = char_end * int(char_end < self.length) + var is_r_n = b0 == Self.`\r` and next_idx != 0 and self.ptr[ + next_idx + ] == Self.`\n` + eol_length = int(isnewline) * char_len + int(is_r_n) + if isnewline: + break + eol_start += char_len + + var str_len = eol_start - self.index + int( + self.keepends + ) * eol_length + var s = StringSlice[origin]( + ptr=self.ptr + self.index, length=str_len + ) + self.index = eol_start + eol_length + return s + else: + constrained[False, "reversed splitlines not yet implemented"]() + return abort[StringSlice[origin]]() + + @always_inline + fn __has_next__(self) -> Bool: + @parameter + if forward: + return self.index < self.length + else: + return self.index > 0 + + @value struct _StringSliceIter[ mut: Bool, //, @@ -230,6 +303,24 @@ struct _StringSliceIter[ Span[Byte, ImmutableAnyOrigin](ptr=self.ptr, length=self.index) ) + fn splitlines( + owned self: _StringSliceIter[forward=True], *, keepends: Bool = False + ) -> _SplitlinesIter[origin, forward=True]: + """Split the string at line boundaries. 
This corresponds to Python's + [universal newlines:]( + https://docs.python.org/3/library/stdtypes.html#str.splitlines) + `"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`. + + Args: + keepends: If True, line breaks are kept in the resulting strings. + + Returns: + An iterator of StringSlices over the input split by line boundaries. + """ + return _SplitlinesIter[origin, True]( + self.index, self.ptr, self.length, keepends + ) + @value @register_passable("trivial") @@ -1089,17 +1180,12 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]]( offset += b_len return length != 0 - fn splitlines[ - O: ImmutableOrigin, // - ](self: StringSlice[O], keepends: Bool = False) -> List[StringSlice[O]]: + fn splitlines(self, keepends: Bool = False) -> List[Self]: """Split the string at line boundaries. This corresponds to Python's [universal newlines:]( https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`. - Parameters: - O: The immutable origin. - Args: keepends: If True, line breaks are kept in the resulting strings. @@ -1107,44 +1193,9 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]]( A List of Strings containing the input split by line boundaries. 
""" - # highly performance sensitive code, benchmark before touching - alias `\r` = UInt8(ord("\r")) - alias `\n` = UInt8(ord("\n")) - - output = List[StringSlice[O]](capacity=128) # guessing - var ptr = self.unsafe_ptr() - var length = self.byte_length() - var offset = 0 - - while offset < length: - var eol_start = offset - var eol_length = 0 - - while eol_start < length: - var b0 = ptr[eol_start] - var char_len = _utf8_first_byte_sequence_length(b0) - debug_assert( - eol_start + char_len <= length, - "corrupted sequence causing unsafe memory access", - ) - var isnewline = unlikely( - _is_newline_char(ptr, eol_start, b0, char_len) - ) - var char_end = int(isnewline) * (eol_start + char_len) - var next_idx = char_end * int(char_end < length) - var is_r_n = b0 == `\r` and next_idx != 0 and ptr[ - next_idx - ] == `\n` - eol_length = int(isnewline) * char_len + int(is_r_n) - if isnewline: - break - eol_start += char_len - - var str_len = eol_start - offset + int(keepends) * eol_length - var s = StringSlice[O](ptr=ptr + offset, length=str_len) + var output = List[Self](capacity=128) # initial capacity is a guess + for s in self.__iter__().splitlines(keepends=keepends): output.append(s) - offset = eol_start + eol_length - return output^ fn count(self, substr: StringSlice) -> Int: