forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_scalartype.py
36 lines (32 loc) · 1.11 KB
/
test_scalartype.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pytest
import torch
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize(
    "type_tuple",
    (
        # Explicit bounds: (expected_min, expected_max, scalar_type).
        (-8, 7, scalar_types.int4),
        (0, 15, scalar_types.uint4),
        (-8, 7, scalar_types.uint4b8),
        (-128, 127, scalar_types.uint8b128),
        (-28., 28., scalar_types.float6_e3m2f),
        # Bounds taken from the matching torch dtype: (torch_dtype, scalar_type).
        (torch.int8, scalar_types.int8),
        (torch.uint8, scalar_types.uint8),
        (torch.float8_e5m2, scalar_types.float8_e5m2),
        (torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
        (torch.bfloat16, scalar_types.float16_e8m7),
        (torch.float16, scalar_types.float16_e5m10),
    ),
    ids=str,
)
def test_scalar_type_min_max(type_tuple):
    """Check that a scalar type reports the expected ``min()`` and ``max()``.

    ``type_tuple`` comes in one of two shapes:

    * ``(expected_min, expected_max, scalar_type)`` -- the bounds are given
      literally.
    * ``(torch_dtype, scalar_type)`` -- the bounds are read from
      ``torch.finfo`` for floating-point dtypes and ``torch.iinfo`` otherwise.
    """
    print(type_tuple)

    if len(type_tuple) == 3:
        expected_min, expected_max, t = type_tuple
    else:
        torch_type, t = type_tuple
        # finfo only accepts floating-point dtypes and iinfo only integer
        # ones, so pick the right one from the dtype itself.
        if torch_type.is_floating_point:
            limits = torch.finfo(torch_type)
        else:
            limits = torch.iinfo(torch_type)
        expected_min, expected_max = limits.min, limits.max

    # Evaluate once so the printed values are exactly the ones asserted on.
    actual_min, actual_max = t.min(), t.max()
    print(t, expected_min, expected_max, actual_min, actual_max)

    assert expected_min == actual_min, (
        f"{t}: expected min {expected_min}, got {actual_min}")
    assert expected_max == actual_max, (
        f"{t}: expected max {expected_max}, got {actual_max}")