Skip to content

Commit

Permalink
Refactor and bug fix of atol function
Browse files Browse the repository at this point in the history
- Support leading underscores (bug fix)
- Add handle_base_prefix and trim_and_handle_sign helper functions
- Rename atol_error to str_to_base_error for clarity
- Update atol docstring for improved clarity

Breaks up the functionality of atol for better readability
and reusability, as suggested in PR modularml#3178.

Co-authored-by: martinvuyk <[email protected]>
Co-authored-by: soraros <[email protected]>

Signed-off-by: Joshua James Venter <[email protected]>
  • Loading branch information
jjvraw committed Jul 6, 2024
1 parent 39d95f0 commit c5876a6
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 51 deletions.
137 changes: 91 additions & 46 deletions stdlib/src/builtin/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ fn _atol(str_ref: StringRef, base: Int = 10) raises -> Int:
if (base != 0) and (base < 2 or base > 36):
raise Error("Base must be >= 2 and <= 36, or 0.")
if not str_ref:
raise Error(_atol_error(base, str_ref))
raise Error(_str_to_base_error(base, str_ref))

var real_base: Int
var ord_num_max: Int
Expand All @@ -229,35 +229,12 @@ fn _atol(str_ref: StringRef, base: Int = 10) raises -> Int:
var is_negative: Bool = False
var start: Int = 0
var str_len = len(str_ref)
var buff = str_ref.unsafe_ptr()

for pos in range(start, str_len):
if _isspace(buff[pos]):
continue

if str_ref[pos] == "-":
is_negative = True
start = pos + 1
elif str_ref[pos] == "+":
start = pos + 1
else:
start = pos
break
start, is_negative = _trim_and_handle_sign(str_ref, str_len)

if str_ref[start] == "0" and start + 1 < str_len:
if base == 2 and (
str_ref[start + 1] == "b" or str_ref[start + 1] == "B"
):
start += 2
elif base == 8 and (
str_ref[start + 1] == "o" or str_ref[start + 1] == "O"
):
start += 2
elif base == 16 and (
str_ref[start + 1] == "x" or str_ref[start + 1] == "X"
):
start += 2
start = _handle_base_prefix(start, str_ref, str_len, base)

var buff = str_ref.unsafe_ptr()
alias ord_0 = ord("0")
# FIXME:
# Change this to `alias` after fixing support for __getitem__ of alias.
Expand All @@ -269,7 +246,7 @@ fn _atol(str_ref: StringRef, base: Int = 10) raises -> Int:
real_base = real_base_new_start[0]
start = real_base_new_start[1]
if real_base == -1:
raise Error(_atol_error(base, str_ref))
raise Error(_str_to_base_error(base, str_ref))
else:
real_base = base

Expand All @@ -284,20 +261,17 @@ fn _atol(str_ref: StringRef, base: Int = 10) raises -> Int:

var found_valid_chars_after_start = False
var has_space_after_number = False
# single underscores are only allowed between digits
# starting "was_last_digit_undescore" to true such that
# if the first digit is an undesrcore an error is raised
var was_last_digit_undescore = True
var was_last_digit_underscore = real_base == 10
for pos in range(start, str_len):
var ord_current = int(buff[pos])
if ord_current == ord_underscore:
if was_last_digit_undescore:
raise Error(_atol_error(base, str_ref))
if was_last_digit_underscore:
raise Error(_str_to_base_error(base, str_ref))
else:
was_last_digit_undescore = True
was_last_digit_underscore = True
continue
else:
was_last_digit_undescore = False
was_last_digit_underscore = False
if ord_0 <= ord_current <= ord_num_max:
result += ord_current - ord_0
found_valid_chars_after_start = True
Expand All @@ -312,29 +286,86 @@ fn _atol(str_ref: StringRef, base: Int = 10) raises -> Int:
start = pos + 1
break
else:
raise Error(_atol_error(base, str_ref))
raise Error(_str_to_base_error(base, str_ref))
if pos + 1 < str_len and not _isspace(buff[pos + 1]):
var nextresult = result * real_base
if nextresult < result:
raise Error(
_atol_error(base, str_ref)
_str_to_base_error(base, str_ref)
+ " String expresses an integer too large to store in Int."
)
result = nextresult

if was_last_digit_undescore or (not found_valid_chars_after_start):
raise Error(_atol_error(base, str_ref))
if was_last_digit_underscore or (not found_valid_chars_after_start):
raise Error(_str_to_base_error(base, str_ref))

if has_space_after_number:
for pos in range(start, str_len):
if not _isspace(buff[pos]):
raise Error(_atol_error(base, str_ref))
raise Error(_str_to_base_error(base, str_ref))
if is_negative:
result = -result
return result


fn _atol_error(base: Int, str_ref: StringRef) -> String:
@always_inline
fn _trim_and_handle_sign(str_ref: StringRef, str_len: Int) -> (Int, Bool):
"""Trims leading whitespace and handles the sign of the number in the string.
Args:
str_ref: A StringRef containing the number to parse.
str_len: The length of the string.
Returns:
A tuple containing:
- The starting index of the number after whitespace and sign.
- A boolean indicating whether the number is negative.
"""
var buff = str_ref.unsafe_ptr()
var is_negative: Bool = False
var start: Int = 0
for pos in range(start, str_len):
if _isspace(buff[pos]):
continue

if str_ref[pos] == "-":
is_negative = True
start = pos + 1
elif str_ref[pos] == "+":
start = pos + 1
else:
start = pos
break

return start, is_negative


@always_inline
fn _handle_base_prefix(
pos: Int, str_ref: StringRef, str_len: Int, base: Int
) -> Int:
"""Adjusts the starting position if a valid base prefix is present.
Handles "0b"/"0B" for base 2, "0o"/"0O" for base 8, and "0x"/"0X" for base 16.
Only adjusts if the base matches the prefix.
Args:
pos: Current position in the string.
str_ref: The input string.
str_len: Length of the input string.
base: The specified base.
Returns:
Updated position after the prefix, if applicable.
"""
var start = pos
if start + 1 < str_len:
var prefix_char = str_ref[start + 1]
if str_ref[start] == "0" and (
(base == 2 and (prefix_char == "b" or prefix_char == "B"))
or (base == 8 and (prefix_char == "o" or prefix_char == "O"))
or (base == 16 and (prefix_char == "x" or prefix_char == "X"))
):
start += 2
return start


fn _str_to_base_error(base: Int, str_ref: StringRef) -> String:
return (
"String is not convertible to integer with base "
+ str(base)
Expand Down Expand Up @@ -381,19 +412,33 @@ fn _identify_base(str_ref: StringRef, start: Int) -> Tuple[Int, Int]:
fn atol(str: String, base: Int = 10) raises -> Int:
"""Parses and returns the given string as an integer in the given base.
For example, `atol("19")` returns `19`. If base is 0 the the string is
parsed as an Integer literal, see: https://docs.python.org/3/reference/lexical_analysis.html#integers.
For example, `atol("32")` returns `32`, and `atol("FF", 16)` returns `255`.
If base is set to 0, the string is parsed as an Integer literal, with the
following considerations:
- '0b' or '0B' prefix indicates binary (base 2)
- '0o' or '0O' prefix indicates octal (base 8)
- '0x' or '0X' prefix indicates hexadecimal (base 16)
- Without a prefix, it's treated as decimal (base 10)
Raises:
If the given string cannot be parsed as an integer value. For example in
`atol("hi")`.
- If the given string cannot be parsed as an integer value. For example
in `atol("Mojo")`.
- Incorrect base is provided.
Args:
str: A string to be parsed as an integer in the given base.
base: Base used for conversion, value must be between 2 and 36, or 0.
Returns:
An integer value that represents the string, or otherwise raises.
Notes:
This follows [Python's integer literals](\
https://docs.python.org/3/reference/lexical_analysis.html#integers).
See Also:
- function `stol`: A similar function that returns both the parsed
integer and the remaining unparsed string.
"""
return _atol(str._strref_dangerous(), base)

Expand Down
16 changes: 11 additions & 5 deletions stdlib/test/builtin/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ def test_atol():
assert_equal(10, atol("0o12", 8))
assert_equal(10, atol("0O12", 8))
assert_equal(35, atol("Z", 36))
assert_equal(255, atol("0x_00_ff", 16))
assert_equal(18, atol("0b0001_0010", 2))
assert_equal(18, atol("0b_000_1001_0", 2))

# Negative cases
with assert_raises(
Expand Down Expand Up @@ -433,6 +436,14 @@ def test_atol_base_0():

assert_equal(0, atol("0X0", base=0))

assert_equal(255, atol("0x_00_ff", base=0))

assert_equal(18, atol("0b_0001_0010", base=0))
assert_equal(18, atol("0b000_1001_0", base=0))

assert_equal(10, atol("0o_000_12", base=0))
assert_equal(10, atol("0o00_12", base=0))

with assert_raises(
contains="String is not convertible to integer with base 0: ' 0x'"
):
Expand All @@ -453,11 +464,6 @@ def test_atol_base_0():
):
_ = atol("0r100", base=0)

with assert_raises(
contains="String is not convertible to integer with base 0: '0b_0'"
):
_ = atol("0b_0", base=0)

with assert_raises(
contains="String is not convertible to integer with base 0: '0xf__f'"
):
Expand Down

0 comments on commit c5876a6

Please sign in to comment.