Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Reintroduce Stringlike trait and use it for Stringlike.find() #3861

Draft
wants to merge 12 commits into
base: nightly
Choose a base branch
from
12 changes: 6 additions & 6 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from memory import UnsafePointer, memcpy, Span
from utils import StaticString, StringRef, StringSlice, Writable, Writer
from utils._visualizers import lldb_formatter_wrapping_type
from utils.format import _CurlyEntryFormattable, _FormatCurlyEntry
from utils.string_slice import _StringSliceIter, _to_string_list
from utils.string_slice import Stringlike, _StringSliceIter, _to_string_list

# ===-----------------------------------------------------------------------===#
# StringLiteral
Expand All @@ -34,9 +34,9 @@ from utils.string_slice import _StringSliceIter, _to_string_list
@lldb_formatter_wrapping_type
@register_passable("trivial")
struct StringLiteral(
Stringlike,
Boolable,
Comparable,
CollectionElementNew,
Writable,
IntableRaising,
KeyElement,
Expand Down Expand Up @@ -593,7 +593,7 @@ struct StringLiteral(

writer.write(self.as_string_slice())

fn find(self, substr: StringLiteral, start: Int = 0) -> Int:
fn find(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the first occurrence of `substr` starting at
`start`. If not found, returns -1.

Expand All @@ -604,9 +604,9 @@ struct StringLiteral(
Returns:
The offset of `substr` relative to the beginning of the string.
"""
return StringRef(self).find(substr, start=start)
return self.as_string_slice().find(substr, start=start)

fn rfind(self, substr: StringLiteral, start: Int = 0) -> Int:
fn rfind(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the last occurrence of `substr` starting at
`start`. If not found, returns -1.

Expand All @@ -617,7 +617,7 @@ struct StringLiteral(
Returns:
The offset of `substr` relative to the beginning of the string.
"""
return StringRef(self).rfind(substr, start=start)
return self.as_string_slice().rfind(substr, start=start)

fn replace(self, old: StringLiteral, new: StringLiteral) -> StringLiteral:
"""Return a copy of the string with all occurrences of substring `old`
Expand Down
18 changes: 7 additions & 11 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ from utils._unicode import (
)
from utils.format import _CurlyEntryFormattable, _FormatCurlyEntry
from utils.string_slice import (
Stringlike,
_shift_unicode_to_utf8,
_StringSliceIter,
_to_string_list,
Expand Down Expand Up @@ -742,6 +743,7 @@ fn isprintable(c: UInt8) -> Bool:

@value
struct String(
Stringlike,
Sized,
Stringable,
AsBytes,
Expand All @@ -752,7 +754,6 @@ struct String(
Boolable,
Writable,
Writer,
CollectionElementNew,
FloatableRaising,
_HashableWithHasher,
):
Expand Down Expand Up @@ -1595,8 +1596,7 @@ struct String(
return String(buf^)

fn unsafe_ptr(
ref self,
) -> UnsafePointer[
ref self ) -> UnsafePointer[
Byte,
mut = Origin(__origin_of(self)).is_mutable,
origin = __origin_of(self),
Expand Down Expand Up @@ -1712,7 +1712,7 @@ struct String(
"""
return substr.as_string_slice() in self.as_string_slice()

fn find(self, substr: String, start: Int = 0) -> Int:
fn find(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the first occurrence of `substr` starting at
`start`. If not found, returns -1.

Expand All @@ -1723,10 +1723,9 @@ struct String(
Returns:
The offset of `substr` relative to the beginning of the string.
"""
return self.as_string_slice().find(substr, start)

return self.as_string_slice().find(substr.as_string_slice(), start)

fn rfind(self, substr: String, start: Int = 0) -> Int:
fn rfind(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the last occurrence of `substr` starting at
`start`. If not found, returns -1.

Expand All @@ -1737,10 +1736,7 @@ struct String(
Returns:
The offset of `substr` relative to the beginning of the string.
"""

return self.as_string_slice().rfind(
substr.as_string_slice(), start=start
)
return self.as_string_slice().rfind(substr, start=start)

fn isspace(self) -> Bool:
"""Determines whether every character in the given String is a
Expand Down
90 changes: 81 additions & 9 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,10 @@ struct _StringSliceIter[
@value
@register_passable("trivial")
struct StringSlice[mut: Bool, //, origin: Origin[mut]](
Stringlike,
Stringable,
Sized,
Writable,
CollectionElement,
CollectionElementNew,
Hashable,
):
"""A non-owning view to encoded string data.
Expand Down Expand Up @@ -905,7 +904,8 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
"""
return _FormatCurlyEntry.format(self, args)

fn find(ref self, substr: StringSlice, start: Int = 0) -> Int:
# FIXME(#3526): this should return unicode codepoint offsets
fn find(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the first occurrence of `substr` starting at
`start`. If not found, returns `-1`.

Expand All @@ -916,7 +916,7 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
Returns:
The offset of `substr` relative to the beginning of the string.
"""
if not substr:
if substr.byte_length() == 0:
return 0

if self.byte_length() < substr.byte_length() + start:
Expand All @@ -938,6 +938,7 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](

return int(loc) - int(self.unsafe_ptr())

# FIXME(#3526): this should return unicode codepoint offsets
fn rfind(self, substr: StringSlice, start: Int = 0) -> Int:
"""Finds the offset of the last occurrence of `substr` starting at
`start`. If not found, returns `-1`.
Expand All @@ -949,10 +950,10 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
Returns:
The offset of `substr` relative to the beginning of the string.
"""
if not substr:
return len(self)
if substr.byte_length() == 0:
return self.byte_length()

if len(self) < len(substr) + start:
if self.byte_length() < substr.byte_length() + start:
return -1

# The substring to search within, offset from the beginning if `start`
Expand All @@ -961,9 +962,9 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](

var loc = _memrmem(
haystack_str.unsafe_ptr(),
len(haystack_str),
haystack_str.byte_length(),
substr.unsafe_ptr(),
len(substr),
substr.byte_length(),
)

if not loc:
Expand Down Expand Up @@ -1111,6 +1112,77 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
# ===-----------------------------------------------------------------------===#


trait StringLike(CollectionElement, CollectionElementNew):
"""Trait intended to be used as a generic entrypoint for all String-like
types."""
...

trait StringOwnerLike(StringLike):
fn as_bytes(ref self) -> Span[Byte, __origin_of(self)]:
"""Returns a contiguous slice of the bytes owned by this string.

Returns:
A contiguous slice pointing to the bytes owned by this string.

Notes:
This does not include the trailing null terminator.
"""
...

fn as_string_slice(ref self) -> StringSlice[__origin_of(self)]:
"""Returns a string slice of the data owned by this string.

Returns:
A string slice pointing to the data owned by this string.
"""
...

trait StringSliceLike(StringLike):
alias mut: Bool
"""The mutability of the origin."""
alias origin: Origin[mut]
"""The origin of the data."""

fn as_bytes(self) -> Span[Byte, origin]:
"""Returns a contiguous slice of the bytes.

Returns:
A contiguous slice pointing to the bytes.

Notes:
This does not include the trailing null terminator.
"""
...

fn as_string_slice(self) -> StringSlice[origin]:
"""Returns a string slice of the data.

Returns:
A string slice pointing to the data.
"""
...

trait StringLiteralLike(StringLike):
fn as_bytes(self) -> Span[Byte, StaticConstantOrigin]:
"""Returns a contiguous slice of the bytes.

Returns:
A contiguous slice pointing to the bytes.

Notes:
This does not include the trailing null terminator.
"""
...

fn as_string_slice(self) -> StringSlice[StaticConstantOrigin]:
"""Returns a string slice of the data.

Returns:
A string slice pointing to the data.
"""
...


fn _to_string_list[
T: CollectionElement, # TODO(MOCO-1446): Make `T` parameter inferred
len_fn: fn (T) -> Int,
Expand Down
Loading