Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add stol #3951

Open
wants to merge 1 commit into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stdlib/src/collections/string/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ from .string import (
ascii,
atof,
atol,
stol,
chr,
ord,
)
Expand Down
157 changes: 157 additions & 0 deletions stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,134 @@ fn ascii(value: StringSlice) -> String:
# ===----------------------------------------------------------------------=== #


fn stol(str_slice: StringSlice, base: Int = 10) raises -> (Int, String):
    """Convert a string to an integer and return the remaining unparsed string.

    Similar to `atol`, but `stol` parses only the leading portion of the string
    and returns both the parsed integer and the remaining unparsed part. For
    example, `stol("32abc")` returns `(32, "abc")`. When no digit can be
    parsed at all, the result is `0` and the whole input is returned as the
    remainder.

    If base is 0, the string is parsed as an [Integer literal][1], with the
    following considerations:
    - '0b' or '0B' prefix indicates binary (base 2)
    - '0o' or '0O' prefix indicates octal (base 8)
    - '0x' or '0X' prefix indicates hexadecimal (base 16)
    - Without a prefix, it's treated as decimal (base 10)

    Notes:
        This follows [Python's integer literals](\
        https://docs.python.org/3/reference/lexical_analysis.html#integers)

    Args:
        str_slice: A string to be parsed as an integer in the given base.
        base: Base used for conversion, value must be between 2 and 36, or 0.

    Returns:
        A tuple containing:
        - An integer value representing the parsed part of the string.
        - The remaining unparsed part of the string.

    Raises:
        If the base is invalid, if the string is empty, or if the parsed
        digits express an integer too large to store in `Int`.

    Examples:
        >>> stol("19abc")
        (19, "abc")
        >>> stol("0xFF hello", 16)
        (255, " hello")
        >>> stol("0x123ghi", 0)
        (291, "ghi")
        >>> stol("0b1010 binary", 0)
        (10, " binary")
        >>> stol("0o123 octal", 0)
        (83, " octal")

    See Also:
        `atol`: A similar function that parses the entire string and returns an integer.
        [1]: https://docs.python.org/3/reference/lexical_analysis.html#integers.
    """
    if (base != 0) and (base < 2 or base > 36):
        raise Error("Base must be >= 2 and <= 36, or 0.")

    if not str_slice:
        raise Error("Cannot convert empty string to integer.")

    var result: Int = 0
    var real_base: Int
    var start: Int = 0
    var is_negative: Bool = False
    var has_prefix: Bool = False
    var str_len = str_slice.byte_length()
    var buff = str_slice.unsafe_ptr()

    start, is_negative = _trim_and_handle_sign(str_slice, str_len)

    # Nothing left after whitespace/sign, or the first character cannot start
    # a number: report "nothing parsed" and hand back the untouched input.
    if start == str_len or not _is_valid_digit(buff[start], base):
        return 0, String(str_slice)

    alias ord_0 = ord("0")
    alias ord_letter_min = (ord("a"), ord("A"))
    alias ord_underscore = ord("_")
    var ord_num_max: Int
    var ord_letter_max = (-1, -1)

    if base == 0:
        real_base, start = _identify_base(str_slice, start)
        if real_base == -1:
            return 0, String(str_slice)

        has_prefix = real_base != 10
    else:
        start, has_prefix = _handle_base_prefix(start, str_slice, str_len, base)
        real_base = base

    if real_base <= 10:
        # Highest decimal digit character allowed in this base.
        ord_num_max = ord_0 + real_base - 1
    else:
        ord_num_max = ord("9")
        ord_letter_max = (
            ord("a") + (real_base - 11),
            ord("A") + (real_base - 11),
        )

    # An underscore is only accepted directly after a 0b/0o/0x prefix or
    # between two digits; starting in the "last was underscore" state makes a
    # leading underscore stop the parse when there is no prefix.
    var was_last_digit_underscore = not (real_base in (2, 8, 16) and has_prefix)
    var parsed_any_digit = False
    for pos in range(start, str_len):
        var ord_current = Int(buff[pos])
        if ord_current == ord_underscore:
            if was_last_digit_underscore:
                break  # Break out as opposed to raising exception as in `atol`
            was_last_digit_underscore = True
            continue

        var digit_value: Int
        if ord_0 <= ord_current <= ord_num_max:
            digit_value = ord_current - ord_0
        elif ord_letter_min[0] <= ord_current <= ord_letter_max[0]:
            digit_value = ord_current - ord_letter_min[0] + 10
        elif ord_letter_min[1] <= ord_current <= ord_letter_max[1]:
            digit_value = ord_current - ord_letter_min[1] + 10
        else:
            break

        # Check the bound *before* multiplying: comparing the wrapped product
        # with the previous value misses overflows that wrap around to a
        # larger positive number (possible for large bases).
        if result > (Int.MAX - digit_value) // real_base:
            raise Error(
                _str_to_base_error(real_base, str_slice)
                + " String expresses an integer too large to store in Int."
            )
        result = result * real_base + digit_value
        was_last_digit_underscore = False
        parsed_any_digit = True
        # Only advance past characters that contributed a digit, so a trailing
        # underscore stays in the remainder.
        start = pos + 1

    # A sign or base prefix with no digit after it is not a number; do not
    # silently consume it from the remainder.
    if not parsed_any_digit:
        return 0, String(str_slice)

    if is_negative:
        result = -result

    return result, String(
        StringSlice(unsafe_from_utf8=str_slice.as_bytes()[start:])
    )


fn atol(str_slice: StringSlice, base: Int = 10) raises -> Int:
"""Parses and returns the given string as an integer in the given base.

Expand Down Expand Up @@ -415,6 +543,35 @@ fn _identify_base(str_slice: StringSlice, start: Int) -> Tuple[Int, Int]:
return 10, start


@always_inline
fn _is_valid_digit(char: UInt8, base: Int) -> Bool:
    """Tells whether a byte may appear in a number written in `base`.

    For base 0 only the characters `0-9`, `a-f` and `A-F` are accepted, since
    the real base is not known yet. For any other base an underscore is
    accepted as well as the digits and letters whose value is below `base`.

    Args:
        char: The character to check, as a UInt8.
        base: The numeric base (0-36, where 0 is special case).

    Returns:
        True if the character is a valid digit for the given base, False otherwise.
    """
    var is_decimal_char = char >= ord("0") and char <= ord("9")
    # Clearing bit 5 maps ASCII lowercase letters onto their uppercase form.
    var folded = char & ~32

    if base == 0:
        # Base still unknown: allow anything that could belong to a hex number.
        if is_decimal_char:
            return True
        return folded >= ord("A") and folded <= ord("F")

    if char == ord("_"):
        return True

    if is_decimal_char:
        return (char - ord("0")) < base

    if base <= 10:
        # Letters can never be digits in bases up to ten.
        return False

    if folded >= ord("A") and folded <= ord("Z"):
        return (folded - ord("A") + 10) < base
    return False


fn _atof_error(str_ref: StringSlice) -> Error:
return Error(
"String is not convertible to float: '" + String(str_ref) + "'"
Expand Down
1 change: 1 addition & 0 deletions stdlib/src/prelude/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from collections.string import (
ascii,
atof,
atol,
stol,
chr,
ord,
)
Expand Down
162 changes: 162 additions & 0 deletions stdlib/test/collections/string/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,166 @@ def test_string_indexing():
assert_equal("H", str[-50::50])


def test_stol():
    """Checks `stol` for explicit bases: parsed value, unparsed remainder and
    the error cases (empty input, invalid base, overflow)."""
    var result: Int
    var remaining: String

    # base 10
    result, remaining = stol(String("375 ABC"))
    assert_equal(375, result)
    assert_equal(" ABC", remaining)
    result, remaining = stol(String(" 005"))
    assert_equal(5, result)
    assert_equal("", remaining)
    result, remaining = stol(String(" 013 "))
    assert_equal(13, result)
    assert_equal(" ", remaining)
    result, remaining = stol(String("-89"))
    assert_equal(-89, result)
    assert_equal("", remaining)
    result, remaining = stol(String(" -52"))
    assert_equal(-52, result)
    assert_equal("", remaining)

    # other bases
    result, remaining = stol(" FF", 16)
    assert_equal(255, result)
    assert_equal("", remaining)
    result, remaining = stol(" 0xff ", 16)
    assert_equal(255, result)
    assert_equal(" ", remaining)
    result, remaining = stol("10010eighteen18", 2)
    assert_equal(18, result)
    assert_equal("eighteen18", remaining)
    # Every case also pins the remainder so a regression in the returned
    # suffix cannot slip through unnoticed.
    result, remaining = stol("0b10010", 2)
    assert_equal(18, result)
    assert_equal("", remaining)
    result, remaining = stol("0b_10010", 2)
    assert_equal(18, result)
    assert_equal("", remaining)
    result, remaining = stol("0b_0010010", 2)
    assert_equal(18, result)
    assert_equal("", remaining)
    result, remaining = stol("0b0000_0_010010", 2)
    assert_equal(18, result)
    assert_equal("", remaining)
    result, remaining = stol("0o12", 8)
    assert_equal(10, result)
    assert_equal("", remaining)
    result, remaining = stol("0o_12", 8)
    assert_equal(10, result)
    assert_equal("", remaining)
    result, remaining = stol("0o_012", 8)
    assert_equal(10, result)
    assert_equal("", remaining)
    result, remaining = stol("0o0000_0_0012", 8)
    assert_equal(10, result)
    assert_equal("", remaining)
    result, remaining = stol("Z", 36)
    assert_equal(35, result)
    assert_equal("", remaining)

    # test with trailing characters
    result, remaining = stol("123abc")
    assert_equal(123, result)
    assert_equal("abc", remaining)
    result, remaining = stol("-45def")
    assert_equal(-45, result)
    assert_equal("def", remaining)
    result, remaining = stol("0xffghi", 0)
    assert_equal(255, result)
    assert_equal("ghi", remaining)
    result, remaining = stol("0x_ffghi", 0)
    assert_equal(255, result)
    assert_equal("ghi", remaining)
    result, remaining = stol("0x_0ffghi", 0)
    assert_equal(255, result)
    assert_equal("ghi", remaining)
    result, remaining = stol("0x0000_0_00ffghi", 0)
    assert_equal(255, result)
    assert_equal("ghi", remaining)

    # Nothing parseable: value is 0 and the input comes back unchanged.
    result, remaining = stol(" ")
    assert_equal(0, result)
    assert_equal(" ", remaining)

    result, remaining = stol("123.456", 10)
    assert_equal(123, result)
    assert_equal(".456", remaining)
    result, remaining = stol("--123", 10)
    assert_equal(0, result)
    assert_equal("--123", remaining)

    result, remaining = stol("12a34", 10)
    assert_equal(12, result)
    assert_equal("a34", remaining)
    result, remaining = stol("1G5", 16)
    assert_equal(1, result)
    assert_equal("G5", remaining)

    result, remaining = stol("-1A", 16)
    assert_equal(-26, result)
    assert_equal("", remaining)
    result, remaining = stol("-110", 2)
    assert_equal(-6, result)
    assert_equal("", remaining)

    result, remaining = stol("Mojo!")
    assert_equal(0, result)
    assert_equal("Mojo!", remaining)

    # Negative Cases
    with assert_raises(contains="Cannot convert empty string to integer."):
        _ = stol("")

    with assert_raises(contains="Base must be >= 2 and <= 36, or 0."):
        _ = stol("Bad Base", 42)

    with assert_raises(
        contains="String expresses an integer too large to store in Int."
    ):
        _ = stol(String("9223372036854775832"), 10)


def test_stol_base_0():
    """Checks `stol` with base 0, where the base is inferred from the literal
    prefix and underscores may separate digits."""
    var value: Int
    var rest: String

    # Decimal literals with digit separators.
    value, rest = stol("155_155", 0)
    assert_equal(155155, value)
    assert_equal("", rest)
    value, rest = stol("1_2_3_4_5", 0)
    assert_equal(12345, value)
    assert_equal("", rest)
    # A trailing underscore is not consumed.
    value, rest = stol("1_2_3_4_5_", 0)
    assert_equal(12345, value)
    assert_equal("_", rest)

    # Prefixed literals with digit separators.
    value, rest = stol("0b1_0_1_0", 0)
    assert_equal(10, value)
    assert_equal("", rest)
    value, rest = stol("0o1_2_3", 0)
    assert_equal(83, value)
    assert_equal("", rest)
    value, rest = stol("0x1_A_B", 0)
    assert_equal(427, value)
    assert_equal("", rest)

    # Misplaced underscores: trailing, leading and doubled.
    value, rest = stol("123_", 0)
    assert_equal(123, value)
    assert_equal("_", rest)
    value, rest = stol("_123", 0)
    assert_equal(0, value)
    assert_equal("_123", rest)
    value, rest = stol("123__456", 0)
    assert_equal(123, value)
    assert_equal("__456", rest)

    value, rest = stol("0x1_23", 0)
    assert_equal(291, value)
    assert_equal("", rest)

    # Inputs that are not integer literals leave the string untouched.
    value, rest = stol("0_123", 0)
    assert_equal(0, value)
    assert_equal("0_123", rest)
    value, rest = stol("0z123", 0)
    assert_equal(0, value)
    assert_equal("0z123", rest)
    value, rest = stol("Mojo!", 0)
    assert_equal(0, value)
    assert_equal("Mojo!", rest)

    # Parsing stops at the first character outside the literal.
    value, rest = stol("0o123 octal", 0)
    assert_equal(83, value)
    assert_equal(" octal", rest)


def test_atol():
# base 10
assert_equal(375, atol("375"))
Expand Down Expand Up @@ -1468,6 +1628,8 @@ def main():
test_ord()
test_chr()
test_string_indexing()
test_stol()
test_stol_base_0()
test_atol()
test_atol_base_0()
test_atof()
Expand Down
Loading