Skip to content

Commit

Permalink
Add BitMask utils
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Dec 16, 2024
1 parent 9f4544a commit 824bd02
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 0 deletions.
163 changes: 163 additions & 0 deletions stdlib/src/bit/mask.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Provides functions for bit manipulation.
You can import these APIs from the `bit` package. For example:
```mojo
from bit.utils import count_leading_zeros
```
"""

from os import abort
from sys.info import bitwidthof


struct BitMask:
"""Utils for building bitmasks."""

alias EQ = 0
"""Value for `==`."""
alias NE = 1
"""Value for `!=`."""
alias GT = 2
"""Value for `>`."""
alias GE = 3
"""Value for `>=`."""
alias LT = 4
"""Value for `<`."""
alias LE = 5
"""Value for `<=`."""

@always_inline
@staticmethod
fn is_negative(value: Int) -> Int:
"""Get a bitmask of whether the value is negative.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
return int(Self.is_negative(Scalar[DType.index](value)))

@always_inline
@staticmethod
fn is_negative[D: DType](value: SIMD[D, _]) -> __type_of(value):
"""Get a bitmask of whether the value is negative.
Parameters:
D: The DType.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is negative, filled with `0`
otherwise.
"""
constrained[
D.is_integral() and D.is_signed(),
"This function is for signed integral types.",
]()
return value >> (bitwidthof[D]() - 1)

@always_inline
@staticmethod
fn is_true[
D: DType, size: Int = 1
](value: SIMD[DType.bool, size]) -> SIMD[D, size]:
"""Get a bitmask of whether the value is `True`.
Parameters:
D: The DType.
size: The size of the SIMD vector.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is `True`, filled with `0`
otherwise.
"""
return Self.is_false[D](~value)

@always_inline
@staticmethod
fn is_false[
D: DType, size: Int = 1
](value: SIMD[DType.bool, size]) -> SIMD[D, size]:
"""Get a bitmask of whether the value is `False`.
Parameters:
D: The DType.
size: The size of the SIMD vector.
Args:
value: The value to check.
Returns:
A bitmask filled with `1` if the value is `False`, filled with `0`
otherwise.
"""
return (value.cast[DType.int8]() - 1).cast[D]()

@always_inline
@staticmethod
fn compare[
D: DType, //, comp: Int
](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs):
"""Get a bitmask of the comparison between the two values.
Args:
lhs: The value to check.
rhs: The value to check.
Returns:
A bitmask filled with `1` if the comparison is true, filled with `0`
otherwise.
"""

@parameter
if comp == Self.EQ:
return Self.is_true[D](lhs == rhs)
elif comp == Self.NE:
return Self.is_true[D](lhs != rhs)
elif comp == Self.GT:
return Self.is_true[D](lhs > rhs)
elif comp == Self.GE:
return Self.is_true[D](lhs >= rhs)
elif comp == Self.LT:
return Self.is_true[D](lhs < rhs)
elif comp == Self.LE:
return Self.is_true[D](lhs <= rhs)
else:
constrained[False, "comparison operator value not found"]()
return abort[__type_of(lhs)]()

@staticmethod
fn compare[D: DType, //, comp: Int](lhs: Int, rhs: Int) -> Int:
"""Get a bitmask of the comparison between the two values.
Args:
lhs: The value to check.
rhs: The value to check.
Returns:
A bitmask filled with `1` if the comparison is true, filled with `0`
otherwise.
"""
alias S = Scalar[DType.index]
return int(Self.compare[comp=comp](S(lhs), S(rhs)))
108 changes: 108 additions & 0 deletions stdlib/test/bit/test_mask.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %bare-mojo %s

from testing import assert_equal
from bit.mask import BitMask
from sys.info import bitwidthof


def test_is_negative():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes.get[i, DType]()
var last_value = 2 ** (bitwidthof[D]() - 1) - 1
var values = List(1, 2, last_value - 1, last_value)

@parameter
for j in range(len(widths)):
alias S = SIMD[D, widths.get[j, Int]()]

for k in values:
assert_equal(S(-1), BitMask.is_negative(S(-k[])))
assert_equal(S(0), BitMask.is_negative(S(k[])))


def test_is_true():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
DType.uint8,
DType.uint16,
DType.uint32,
DType.uint64,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes.get[i, DType]()

@parameter
for j in range(len(widths)):
alias w = widths.get[j, Int]()
alias B = SIMD[DType.bool, w]
assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True)))
assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False)))


def test_compare():
alias dtypes = (
DType.int8,
DType.int16,
DType.int32,
DType.int64,
DType.index,
)
alias widths = (1, 2, 4, 8)

@parameter
for i in range(len(dtypes)):
alias D = dtypes.get[i, DType]()
var last_value = 2 ** (bitwidthof[D]() - 1) - 1
var values = List(1, 2, last_value - 1, last_value)

@parameter
for j in range(len(widths)):
alias S = SIMD[D, widths.get[j, Int]()]

for k in values:
var s_k = S(k[])
var s_k_1 = S(k[] - 1)
assert_equal(S(-1), BitMask.compare[BitMask.EQ](s_k, s_k))
assert_equal(S(-1), BitMask.compare[BitMask.EQ](-s_k, -s_k))
assert_equal(S(-1), BitMask.compare[BitMask.NE](s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.NE](-s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k_1, -s_k))
assert_equal(S(-1), BitMask.compare[BitMask.GE](-s_k, -s_k))
assert_equal(S(-1), BitMask.compare[BitMask.LT](-s_k, s_k_1))
assert_equal(S(-1), BitMask.compare[BitMask.LE](-s_k, -s_k))


def main():
test_is_negative()
test_is_true()
test_compare()

0 comments on commit 824bd02

Please sign in to comment.