Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Vectorize ASCII helper functions #3859

Open
wants to merge 5 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 179 additions & 161 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,188 @@ fn chr(c: Int) -> String:


# ===----------------------------------------------------------------------=== #
# ascii
# isdigit
# ===----------------------------------------------------------------------=== #


fn _chr_ascii(c: UInt8) -> String:
"""Returns a string based on the given ASCII code point.
@always_inline
fn _isdigit_vec[w: Int](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `0` = SIMD[DType.uint8, w](Byte(ord("0")))
alias `9` = SIMD[DType.uint8, w](Byte(ord("9")))
return (`0` <= v) & (v <= `9`)


@always_inline
fn isdigit(c: Byte) -> Bool:
"""Determines whether the given character is a digit: [0, 9].

Args:
c: An integer that represents a code point.
c: The character to check.

Returns:
A string containing a single character based on the given code point.
True if the character is a digit.
"""
return _isdigit_vec(c)


# ===----------------------------------------------------------------------=== #
# isprintable
# ===----------------------------------------------------------------------=== #


@always_inline
fn _is_ascii_printable_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias ` ` = SIMD[DType.uint8, w](Byte(ord(" ")))
alias `~` = SIMD[DType.uint8, w](Byte(ord("~")))
return (` ` <= v) & (v <= `~`)


@always_inline
fn _nonprintable_ascii[w: Int](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
return (~_is_ascii_printable_vec(v)) & (v < 0b1000_0000)


@always_inline
fn _is_python_printable_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `\\` = SIMD[DType.uint8, w](Byte(ord(" ")))
return (v != `\\`) & _is_ascii_printable_vec(v)


@always_inline
fn _nonprintable_python[w: Int](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
return (~_is_python_printable_vec(v)) & (v < 0b1000_0000)


@always_inline
fn isprintable(c: Byte) -> Bool:
"""Determines whether the given character is ASCII printable.

Args:
c: The character to check.

Returns:
True if the character is printable, otherwise False.
"""
return _is_ascii_printable_vec(c)


# ===----------------------------------------------------------------------=== #
# isupper
# ===----------------------------------------------------------------------=== #


@always_inline
fn _is_ascii_uppercase_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `A` = SIMD[DType.uint8, w](Byte(ord("A")))
alias `Z` = SIMD[DType.uint8, w](Byte(ord("Z")))
return (`A` <= v) & (v <= `Z`)


@always_inline
fn _is_ascii_uppercase(c: Byte) -> Bool:
return _is_ascii_uppercase_vec(c)


@always_inline
fn isupper(c: Byte) -> Bool:
"""Determines whether the given character is an ASCII uppercase character:
`"ABCDEFGHIJKLMNOPQRSTUVWXYZ"`.

Args:
c: The character to check.

Returns:
True if the character is uppercase.
"""
return _is_ascii_uppercase(c)


# ===----------------------------------------------------------------------=== #
# islower
# ===----------------------------------------------------------------------=== #


@always_inline
fn _is_ascii_lowercase_vec[
w: Int
](v: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
alias `a` = SIMD[DType.uint8, w](Byte(ord("a")))
alias `z` = SIMD[DType.uint8, w](Byte(ord("z")))
return (`a` <= v) & (v <= `z`)


@always_inline
fn _is_ascii_lowercase(c: Byte) -> Bool:
return _is_ascii_lowercase_vec(c)


@always_inline
fn islower(c: Byte) -> Bool:
"""Determines whether the given character is an ASCII lowercase character:
`"abcdefghijklmnopqrstuvwxyz"`.

Args:
c: The character to check.

Returns:
True if the character is lowercase.
"""
return _is_ascii_lowercase(c)


# ===----------------------------------------------------------------------=== #
# isspace
# ===----------------------------------------------------------------------=== #


fn _is_ascii_space(c: Byte) -> Bool:
"""Determines whether the given character is an ASCII whitespace character:
`" \\t\\n\\v\\f\\r\\x1c\\x1d\\x1e"`.

Args:
c: The character to check.

Returns:
True if the character is one of the ASCII whitespace characters.

Notes:
For semantics similar to Python, use `String.isspace()`.
"""
return String(String._buffer_type(c, 0))

# NOTE: a global LUT doesn't work at compile time so we can't use it here.
alias ` ` = Byte(ord(" "))
alias `\t` = Byte(ord("\t"))
alias `\n` = Byte(ord("\n"))
alias `\r` = Byte(ord("\r"))
alias `\f` = Byte(ord("\f"))
alias `\v` = Byte(ord("\v"))
alias `\x1c` = Byte(ord("\x1c"))
alias `\x1d` = Byte(ord("\x1d"))
alias `\x1e` = Byte(ord("\x1e"))

# This compiles to something very clever that's even faster than a LUT.
return (
c == ` `
or c == `\t`
or c == `\n`
or c == `\r`
or c == `\f`
or c == `\v`
or c == `\x1c`
or c == `\x1d`
or c == `\x1e`
)


# ===----------------------------------------------------------------------=== #
# ascii
# ===----------------------------------------------------------------------=== #


fn _repr_ascii(c: UInt8) -> String:
Expand All @@ -187,7 +355,7 @@ fn _repr_ascii(c: UInt8) -> String:
if c == ord_back_slash:
return r"\\"
elif isprintable(c):
return _chr_ascii(c)
return String(String._buffer_type(c, 0))
elif c == ord_tab:
return r"\t"
elif c == ord_new_line:
Expand Down Expand Up @@ -304,13 +472,13 @@ fn _atol(str_slice: StringSlice, base: Int = 10) raises -> Int:
elif ord_letter_min[1] <= ord_current <= ord_letter_max[1]:
result += ord_current - ord_letter_min[1] + 10
found_valid_chars_after_start = True
elif _isspace(ord_current):
elif _is_ascii_space(ord_current):
has_space_after_number = True
start = pos + 1
break
else:
raise Error(_str_to_base_error(base, str_slice))
if pos + 1 < str_len and not _isspace(buff[pos + 1]):
if pos + 1 < str_len and not _is_ascii_space(buff[pos + 1]):
var nextresult = result * real_base
if nextresult < result:
raise Error(
Expand All @@ -324,7 +492,7 @@ fn _atol(str_slice: StringSlice, base: Int = 10) raises -> Int:

if has_space_after_number:
for pos in range(start, str_len):
if not _isspace(buff[pos]):
if not _is_ascii_space(buff[pos]):
raise Error(_str_to_base_error(base, str_slice))
if is_negative:
result = -result
Expand All @@ -346,7 +514,7 @@ fn _trim_and_handle_sign(str_slice: StringSlice, str_len: Int) -> (Int, Bool):
"""
var buff = str_slice.unsafe_ptr()
var start: Int = 0
while start < str_len and _isspace(buff[start]):
while start < str_len and _is_ascii_space(buff[start]):
start += 1
var p: Bool = buff[start] == ord("+")
var n: Bool = buff[start] == ord("-")
Expand Down Expand Up @@ -585,156 +753,6 @@ fn atof(str: String) raises -> Float64:
return _atof(str.as_string_slice())


# ===----------------------------------------------------------------------=== #
# isdigit
# ===----------------------------------------------------------------------=== #


fn isdigit(c: UInt8) -> Bool:
"""Determines whether the given character is a digit [0-9].

Args:
c: The character to check.

Returns:
True if the character is a digit.
"""
alias ord_0 = ord("0")
alias ord_9 = ord("9")
return ord_0 <= int(c) <= ord_9


# ===----------------------------------------------------------------------=== #
# isupper
# ===----------------------------------------------------------------------=== #


fn isupper(c: UInt8) -> Bool:
"""Determines whether the given character is an uppercase character.

This currently only respects the default "C" locale, i.e. returns True iff
the character specified is one of "ABCDEFGHIJKLMNOPQRSTUVWXYZ".

Args:
c: The character to check.

Returns:
True if the character is uppercase.
"""
return _is_ascii_uppercase(c)


fn _is_ascii_uppercase(c: UInt8) -> Bool:
alias ord_a = ord("A")
alias ord_z = ord("Z")
return ord_a <= int(c) <= ord_z


# ===----------------------------------------------------------------------=== #
# islower
# ===----------------------------------------------------------------------=== #


fn islower(c: UInt8) -> Bool:
"""Determines whether the given character is an lowercase character.

This currently only respects the default "C" locale, i.e. returns True iff
the character specified is one of "abcdefghijklmnopqrstuvwxyz".

Args:
c: The character to check.

Returns:
True if the character is lowercase.
"""
return _is_ascii_lowercase(c)


fn _is_ascii_lowercase(c: UInt8) -> Bool:
alias ord_a = ord("a")
alias ord_z = ord("z")
return ord_a <= int(c) <= ord_z


# ===----------------------------------------------------------------------=== #
# _isspace
# ===----------------------------------------------------------------------=== #


fn _isspace(c: String) -> Bool:
"""Determines whether the given character is a whitespace character.

This only respects the default "C" locale, i.e. returns True only if the
character specified is one of " \\t\\n\\v\\f\\r". For semantics similar
to Python, use `String.isspace()`.

Args:
c: The character to check.

Returns:
True iff the character is one of the whitespace characters listed above.
"""
return _isspace(ord(c))


fn _isspace(c: UInt8) -> Bool:
"""Determines whether the given character is a whitespace character.

This only respects the default "C" locale, i.e. returns True only if the
character specified is one of " \\t\\n\\v\\f\\r". For semantics similar
to Python, use `String.isspace()`.

Args:
c: The character to check.

Returns:
True iff the character is one of the whitespace characters listed above.
"""

# NOTE: a global LUT doesn't work at compile time so we can't use it here.
alias ` ` = UInt8(ord(" "))
alias `\t` = UInt8(ord("\t"))
alias `\n` = UInt8(ord("\n"))
alias `\r` = UInt8(ord("\r"))
alias `\f` = UInt8(ord("\f"))
alias `\v` = UInt8(ord("\v"))
alias `\x1c` = UInt8(ord("\x1c"))
alias `\x1d` = UInt8(ord("\x1d"))
alias `\x1e` = UInt8(ord("\x1e"))

# This compiles to something very clever that's even faster than a LUT.
return (
c == ` `
or c == `\t`
or c == `\n`
or c == `\r`
or c == `\f`
or c == `\v`
or c == `\x1c`
or c == `\x1d`
or c == `\x1e`
)


# ===----------------------------------------------------------------------=== #
# isprintable
# ===----------------------------------------------------------------------=== #


fn isprintable(c: UInt8) -> Bool:
"""Determines whether the given character is a printable character.

Args:
c: The character to check.

Returns:
True if the character is a printable character, otherwise False.
"""
alias ord_space = ord(" ")
alias ord_tilde = ord("~")
return ord_space <= int(c) <= ord_tilde


# ===----------------------------------------------------------------------=== #
# String
# ===----------------------------------------------------------------------=== #
Expand Down
Loading
Loading