Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add BitMask utils #3886

Open
wants to merge 7 commits into
base: nightly
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions stdlib/src/bit/mask.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Provides functions for bit masks.

You can import these APIs from the `bit` package. For example:

```mojo
from bit.mask import BitMask
```
"""

from os import abort
from sys.info import bitwidthof


struct BitMask:
"""Utils for building bitmasks."""

alias EQ = 0
"""Value for `==`."""
alias NE = 1
"""Value for `!=`."""
alias GT = 2
"""Value for `>`."""
alias GE = 3
"""Value for `>=`."""
alias LT = 4
"""Value for `<`."""
alias LE = 5
"""Value for `<=`."""

@always_inline
@staticmethod
fn is_negative(value: Int) -> Int:
"""Get a bitmask of whether the value is negative.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
return int(Self.is_negative(Scalar[DType.index](value)))

@always_inline
@staticmethod
fn is_negative[D: DType](value: SIMD[D, _]) -> __type_of(value):
"""Get a bitmask of whether the value is negative.

Parameters:
D: The DType.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
constrained[
D.is_integral() and D.is_signed(),
"This function is for signed integral types.",
]()
return value >> (bitwidthof[D]() - 1)

@always_inline
@staticmethod
fn is_true[
D: DType, size: Int = 1
](value: SIMD[DType.bool, size]) -> SIMD[D, size]:
"""Get a bitmask of whether the value is `True`.

Parameters:
D: The DType.
size: The size of the SIMD vector.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is `True`, filled with `0`
otherwise.
"""
return (-(value.cast[DType.int8]())).cast[D]()

@always_inline
@staticmethod
fn is_false[
D: DType, size: Int = 1
](value: SIMD[DType.bool, size]) -> SIMD[D, size]:
"""Get a bitmask of whether the value is `False`.

Parameters:
D: The DType.
size: The size of the SIMD vector.

Args:
value: The value to check.

Returns:
A bitmask filled with `1` if the value is `False`, filled with `0`
otherwise.
"""
return Self.is_true[D](~value)

@always_inline
@staticmethod
fn compare[
D: DType, //, comp: Int
](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs):
"""Get a bitmask of the comparison between the two values.

Parameters:
D: The DType.
comp: The comparison operator, e.g. `BitMask.EQ`.

Args:
lhs: The value to check.
rhs: The value to check.

Returns:
A bitmask filled with `1` if the comparison is true, filled with `0`
otherwise.
"""

@parameter
if comp == Self.EQ:
return Self.is_true[D](lhs == rhs)
elif comp == Self.NE:
return Self.is_true[D](lhs != rhs)
elif comp == Self.GT:
return Self.is_true[D](lhs > rhs)
elif comp == Self.GE:
return Self.is_true[D](lhs >= rhs)
elif comp == Self.LT:
return Self.is_true[D](lhs < rhs)
elif comp == Self.LE:
return Self.is_true[D](lhs <= rhs)
else:
constrained[False, "comparison operator value not found"]()
return abort[__type_of(lhs)]()

@staticmethod
fn compare[comp: Int](lhs: Int, rhs: Int) -> Int:
"""Get a bitmask of the comparison between the two values.

Parameters:
comp: The comparison operator, e.g. `BitMask.EQ`.

Args:
lhs: The value to check.
rhs: The value to check.

Returns:
A bitmask filled with `1` if the comparison is true, filled with `0`
otherwise.
"""
alias S = Scalar[DType.index]
return int(Self.compare[comp=comp](S(lhs), S(rhs)))
108 changes: 108 additions & 0 deletions stdlib/test/bit/test_mask.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %bare-mojo %s

from testing import assert_equal
from bit.mask import BitMask
from sys.info import bitwidthof


def test_is_negative():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes.get[i, DType]()
var last_value = 2 ** (bitwidthof[D]() - 1) - 1
var values = List(1, 2, last_value - 1, last_value)

@parameter
for j in range(len(widths)):
alias S = SIMD[D, widths.get[j, Int]()]

for k in values:
assert_equal(S(-1), BitMask.is_negative(S(-k[])))
assert_equal(S(0), BitMask.is_negative(S(k[])))


def test_is_true():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
DType.uint8,
DType.uint16,
DType.uint32,
DType.uint64,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes.get[i, DType]()

@parameter
for j in range(len(widths)):
alias w = widths.get[j, Int]()
alias B = SIMD[DType.bool, w]
assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True)))
assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False)))


def test_compare():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes.get[i, DType]()
var last_value = 2 ** (bitwidthof[D]() - 1) - 1
var values = List(1, 2, last_value - 1, last_value)

@parameter
for j in range(len(widths)):
alias S = SIMD[D, widths.get[j, Int]()]

for k in values:
var s_k = S(k[])
var s_k_1 = S(k[] - 1)
assert_equal(S(-1), BitMask.compare[BitMask.EQ](s_k, s_k))
assert_equal(S(-1), BitMask.compare[BitMask.EQ](-s_k, -s_k))
assert_equal(S(-1), BitMask.compare[BitMask.NE](s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.NE](-s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k_1, -s_k))
assert_equal(S(-1), BitMask.compare[BitMask.GE](-s_k, -s_k))
assert_equal(S(-1), BitMask.compare[BitMask.LT](-s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.LE](-s_k, -s_k))


def main():
test_is_negative()
test_is_true()
test_compare()
Loading