-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: martinvuyk <[email protected]>
- Loading branch information
1 parent
9f4544a
commit 824bd02
Showing
2 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# ===----------------------------------------------------------------------=== # | ||
# Copyright (c) 2024, Modular Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
# https://llvm.org/LICENSE.txt | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ===----------------------------------------------------------------------=== # | ||
"""Provides functions for bit manipulation. | ||
You can import these APIs from the `bit` package. For example: | ||
```mojo | ||
from bit.utils import count_leading_zeros | ||
``` | ||
""" | ||
|
||
from os import abort | ||
from sys.info import bitwidthof | ||
|
||
|
||
struct BitMask: | ||
"""Utils for building bitmasks.""" | ||
|
||
alias EQ = 0 | ||
"""Value for `==`.""" | ||
alias NE = 1 | ||
"""Value for `!=`.""" | ||
alias GT = 2 | ||
"""Value for `>`.""" | ||
alias GE = 3 | ||
"""Value for `>=`.""" | ||
alias LT = 4 | ||
"""Value for `<`.""" | ||
alias LE = 5 | ||
"""Value for `<=`.""" | ||
|
||
@always_inline | ||
@staticmethod | ||
fn is_negative(value: Int) -> Int: | ||
"""Get a bitmask of whether the value is negative. | ||
Args: | ||
value: The value to check. | ||
Returns: | ||
A bitmask filled with `1` if the value is negative, filled with `0` | ||
otherwise. | ||
""" | ||
return int(Self.is_negative(Scalar[DType.index](value))) | ||
|
||
@always_inline | ||
@staticmethod | ||
fn is_negative[D: DType](value: SIMD[D, _]) -> __type_of(value): | ||
"""Get a bitmask of whether the value is negative. | ||
Parameters: | ||
D: The DType. | ||
Args: | ||
value: The value to check. | ||
Returns: | ||
A bitmask filled with `1` if the value is negative, filled with `0` | ||
otherwise. | ||
""" | ||
constrained[ | ||
D.is_integral() and D.is_signed(), | ||
"This function is for signed integral types.", | ||
]() | ||
return value >> (bitwidthof[D]() - 1) | ||
|
||
@always_inline | ||
@staticmethod | ||
fn is_true[ | ||
D: DType, size: Int = 1 | ||
](value: SIMD[DType.bool, size]) -> SIMD[D, size]: | ||
"""Get a bitmask of whether the value is `True`. | ||
Parameters: | ||
D: The DType. | ||
size: The size of the SIMD vector. | ||
Args: | ||
value: The value to check. | ||
Returns: | ||
A bitmask filled with `1` if the value is `True`, filled with `0` | ||
otherwise. | ||
""" | ||
return Self.is_false[D](~value) | ||
|
||
@always_inline | ||
@staticmethod | ||
fn is_false[ | ||
D: DType, size: Int = 1 | ||
](value: SIMD[DType.bool, size]) -> SIMD[D, size]: | ||
"""Get a bitmask of whether the value is `False`. | ||
Parameters: | ||
D: The DType. | ||
size: The size of the SIMD vector. | ||
Args: | ||
value: The value to check. | ||
Returns: | ||
A bitmask filled with `1` if the value is `False`, filled with `0` | ||
otherwise. | ||
""" | ||
return (value.cast[DType.int8]() - 1).cast[D]() | ||
|
||
@always_inline | ||
@staticmethod | ||
fn compare[ | ||
D: DType, //, comp: Int | ||
](lhs: SIMD[D, _], rhs: __type_of(lhs)) -> __type_of(lhs): | ||
"""Get a bitmask of the comparison between the two values. | ||
Args: | ||
lhs: The value to check. | ||
rhs: The value to check. | ||
Returns: | ||
A bitmask filled with `1` if the comparison is true, filled with `0` | ||
otherwise. | ||
""" | ||
|
||
@parameter | ||
if comp == Self.EQ: | ||
return Self.is_true[D](lhs == rhs) | ||
elif comp == Self.NE: | ||
return Self.is_true[D](lhs != rhs) | ||
elif comp == Self.GT: | ||
return Self.is_true[D](lhs > rhs) | ||
elif comp == Self.GE: | ||
return Self.is_true[D](lhs >= rhs) | ||
elif comp == Self.LT: | ||
return Self.is_true[D](lhs < rhs) | ||
elif comp == Self.LE: | ||
return Self.is_true[D](lhs <= rhs) | ||
else: | ||
constrained[False, "comparison operator value not found"]() | ||
return abort[__type_of(lhs)]() | ||
|
||
@staticmethod | ||
fn compare[D: DType, //, comp: Int](lhs: Int, rhs: Int) -> Int: | ||
"""Get a bitmask of the comparison between the two values. | ||
Args: | ||
lhs: The value to check. | ||
rhs: The value to check. | ||
Returns: | ||
A bitmask filled with `1` if the comparison is true, filled with `0` | ||
otherwise. | ||
""" | ||
alias S = Scalar[DType.index] | ||
return int(Self.compare[comp=comp](S(lhs), S(rhs))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# ===----------------------------------------------------------------------=== # | ||
# Copyright (c) 2024, Modular Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
# https://llvm.org/LICENSE.txt | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ===----------------------------------------------------------------------=== # | ||
# RUN: %bare-mojo %s | ||
|
||
from testing import assert_equal | ||
from bit.mask import BitMask | ||
from sys.info import bitwidthof | ||
|
||
|
||
def test_is_negative(): | ||
alias dtypes = ( | ||
DType.int8, | ||
DType.int16, | ||
DType.int32, | ||
DType.int64, | ||
DType.index, | ||
) | ||
alias widths = (1, 2, 4, 8) | ||
|
||
@parameter | ||
for i in range(len(dtypes)): | ||
alias D = dtypes.get[i, DType]() | ||
var last_value = 2 ** (bitwidthof[D]() - 1) - 1 | ||
var values = List(1, 2, last_value - 1, last_value) | ||
|
||
@parameter | ||
for j in range(len(widths)): | ||
alias S = SIMD[D, widths.get[j, Int]()] | ||
|
||
for k in values: | ||
assert_equal(S(-1), BitMask.is_negative(S(-k[]))) | ||
assert_equal(S(0), BitMask.is_negative(S(k[]))) | ||
|
||
|
||
def test_is_true(): | ||
alias dtypes = ( | ||
DType.int8, | ||
DType.int16, | ||
DType.int32, | ||
DType.int64, | ||
DType.index, | ||
DType.uint8, | ||
DType.uint16, | ||
DType.uint32, | ||
DType.uint64, | ||
) | ||
alias widths = (1, 2, 4, 8) | ||
|
||
@parameter | ||
for i in range(len(dtypes)): | ||
alias D = dtypes.get[i, DType]() | ||
|
||
@parameter | ||
for j in range(len(widths)): | ||
alias w = widths.get[j, Int]() | ||
alias B = SIMD[DType.bool, w] | ||
assert_equal(SIMD[D, w](-1), BitMask.is_true[D](B(True))) | ||
assert_equal(SIMD[D, w](0), BitMask.is_true[D](B(False))) | ||
|
||
|
||
def test_compare(): | ||
alias dtypes = ( | ||
DType.int8, | ||
DType.int16, | ||
DType.int32, | ||
DType.int64, | ||
DType.index, | ||
) | ||
alias widths = (1, 2, 4, 8) | ||
|
||
@parameter | ||
for i in range(len(dtypes)): | ||
alias D = dtypes.get[i, DType]() | ||
var last_value = 2 ** (bitwidthof[D]() - 1) - 1 | ||
var values = List(1, 2, last_value - 1, last_value) | ||
|
||
@parameter | ||
for j in range(len(widths)): | ||
alias S = SIMD[D, widths.get[j, Int]()] | ||
|
||
for k in values: | ||
var s_k = S(k[]) | ||
var s_k_1 = S(k[] - 1) | ||
assert_equal(S(-1), BitMask.compare[BitMask.EQ](s_k, s_k)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.EQ](-s_k, -s_k)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.NE](s_k, s_k_1)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.NE](-s_k, s_k_1)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k, s_k_1)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.GT](s_k_1, -s_k)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.GE](-s_k, -s_k)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.LT](-s_k, s_k_1)) | ||
assert_equal(S(-1), BitMask.compare[BitMask.LE](-s_k, -s_k)) | ||
|
||
|
||
def main(): | ||
test_is_negative() | ||
test_is_true() | ||
test_compare() |