Skip to content

Commit

Permalink
refactor to use bitwise operations
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Dec 16, 2024
1 parent 6b51a32 commit 01e8c97
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 183 deletions.
114 changes: 114 additions & 0 deletions stdlib/src/bit/utils.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Provides functions for bit manipulation.
You can import these APIs from the `bit` package. For example:
```mojo
from bit.utils import count_leading_zeros
```
"""

from sys.info import bitwidthof


# ===-----------------------------------------------------------------------===#
# bitmasks
# ===-----------------------------------------------------------------------===#


@always_inline
fn is_negative_bitmask(value: Int) -> Int:
"""Get a bitmask of whether the value is negative.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
return int(is_negative_bitmask(Scalar[DType.index](value)))


@always_inline
fn is_negative_bitmask[D: DType](value: SIMD[D, _]) -> __type_of(value):
"""Get a bitmask of whether the value is negative.
Parameters:
D: The DType.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
constrained[D.is_signed(), "This function is for signed types."]()
return value >> (bitwidthof[D]() - 1)


@always_inline
fn is_true_bitmask[
D: DType
](value: SIMD[DType.bool, _]) -> SIMD[D, __type_of(value).size]:
"""Get a bitmask of whether the value is `True`.
Parameters:
D: The DType.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is `True`, filled with `0`
otherwise.
"""
return is_negative_bitmask(value.cast[DType.int8]() - 1).cast[D]()


@always_inline
fn are_equal_bitmask(lhs: Int, rhs: Int) -> Int:
"""Get a bitmask of whether the values are equal.
Args:
lhs: The value to check.
rhs: The value to check.
Returns:
A bitmask filled with `1` if the values are equal, filled with `0`
otherwise.
"""
alias S = Scalar[DType.index]
return int(are_equal_bitmask(S(lhs), S(rhs)))


@always_inline
fn are_equal_bitmask[
D: DType
](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs):
"""Get a bitmask of whether the values are equal.
Parameters:
D: The DType.
Args:
lhs: The value to check.
rhs: The value to check.
Returns:
A bitmask filled with `1` if the values are equal, filled with `0`
otherwise.
"""
return is_true_bitmask[D](lhs ^ rhs != 0)
43 changes: 15 additions & 28 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ from utils.format import _CurlyEntryFormattable, _FormatCurlyEntry
from utils.string_slice import (
_StringSliceIter,
_to_string_list,
Stringlike,
_split,
)

Expand All @@ -51,7 +50,6 @@ struct StringLiteral(
FloatableRaising,
BytesCollectionElement,
_HashableWithHasher,
Stringlike,
):
"""This type represents a string literal.
Expand Down Expand Up @@ -440,25 +438,23 @@ struct StringLiteral(
"""
return self.__str__()

fn __iter__(ref self) -> _StringSliceIter[__origin_of(self)]:
fn __iter__(ref self) -> _StringSliceIter[StaticConstantOrigin]:
"""Iterate over the string unicode characters.
Returns:
An iterator of references to the string unicode characters.
"""
return _StringSliceIter[__origin_of(self)](
return _StringSliceIter[StaticConstantOrigin](
unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
)

fn __reversed__(
ref self,
) -> _StringSliceIter[__origin_of(self), forward=False]:
fn __reversed__(self) -> _StringSliceIter[StaticConstantOrigin, False]:
"""Iterate backwards over the string unicode characters.
Returns:
A reversed iterator of references to the string unicode characters.
"""
return _StringSliceIter[__origin_of(self), forward=False](
return _StringSliceIter[StaticConstantOrigin, forward=False](
unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
)

Expand Down Expand Up @@ -532,10 +528,10 @@ struct StringLiteral(

@always_inline
fn as_bytes(self) -> Span[Byte, StaticConstantOrigin]:
"""Returns a contiguous slice of bytes.
"""Returns a contiguous Span of the bytes owned by this string.
Returns:
A contiguous slice pointing to bytes.
A contiguous slice pointing to the bytes owned by this string.
Notes:
This does not include the trailing null terminator.
Expand All @@ -546,10 +542,10 @@ struct StringLiteral(

@always_inline
fn as_bytes(ref self) -> Span[Byte, __origin_of(self)]:
"""Returns a contiguous slice of bytes.
"""Returns a contiguous Span of the bytes owned by this string.
Returns:
A contiguous slice pointing to bytes.
A contiguous slice pointing to the bytes owned by this string.
Notes:
This does not include the trailing null terminator.
Expand Down Expand Up @@ -597,13 +593,10 @@ struct StringLiteral(

writer.write(self.as_string_slice())

fn find[T: Stringlike, //](self, substr: T, start: Int = 0) -> Int:
fn find(self, substr: StringLiteral, start: Int = 0) -> Int:
"""Finds the offset of the first occurrence of `substr` starting at
`start`. If not found, returns -1.
Parameters:
T: The type of the substring.
Args:
substr: The substring to find.
start: The offset from which to find.
Expand Down Expand Up @@ -698,12 +691,9 @@ struct StringLiteral(
return result

@always_inline
fn split[T: Stringlike, //](self, sep: T, maxsplit: Int) -> List[String]:
fn split(self, sep: StringSlice, maxsplit: Int) -> List[String]:
"""Split the string by a separator.
Parameters:
T: The type of the separator.
Args:
sep: The string to split on.
maxsplit: The maximum amount of items to split from String.
Expand All @@ -722,15 +712,12 @@ struct StringLiteral(
```
.
"""
return _split[has_maxsplit=True, has_sep=True](self, sep, maxsplit)
return _split[has_maxsplit=True](self, sep, maxsplit)

@always_inline
fn split[T: Stringlike, //](self, sep: T) -> List[String]:
fn split(self, sep: StringSlice) -> List[String]:
"""Split the string by a separator.
Parameters:
T: The type of the separator.
Args:
sep: The string to split on.
Expand All @@ -750,7 +737,7 @@ struct StringLiteral(
```
.
"""
return _split[has_maxsplit=False, has_sep=True](self, sep, -1)
return _split[has_maxsplit=False](self, sep, -1)

@always_inline
fn split(self, *, maxsplit: Int) -> List[String]:
Expand All @@ -770,7 +757,7 @@ struct StringLiteral(
```
.
"""
return _split[has_maxsplit=True, has_sep=False](self, None, maxsplit)
return _split[has_maxsplit=True](self, None, maxsplit)

@always_inline
fn split(self, sep: NoneType = None) -> List[String]:
Expand All @@ -797,7 +784,7 @@ struct StringLiteral(
```
.
"""
return _split[has_maxsplit=False, has_sep=False](self, sep, -1)
return _split[has_maxsplit=False](self, None, -1)

fn splitlines(self, keepends: Bool = False) -> List[String]:
"""Split the string literal at line boundaries. This corresponds to Python's
Expand Down
5 changes: 2 additions & 3 deletions stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,13 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False](
self.capacity = new_capacity

fn append(mut self, owned value: T):
"""Appends a value to this list. If there is no capacity left, resizes
to twice the current capacity. Except for 0 capacity where it sets 1.
"""Appends a value to this list.
Args:
value: The value to append.
"""
if self.size >= self.capacity:
self._realloc(self.capacity * 2 + int(self.capacity == 0))
self._realloc(max(1, self.capacity * 2))
(self.data + self.size).init_pointee_move(value^)
self.size += 1

Expand Down
Loading

0 comments on commit 01e8c97

Please sign in to comment.