Skip to content

Commit

Permalink
add unit tests for to_dtype (tinygrad#8217)
Browse files Browse the repository at this point in the history
* add unit test for to_dtype

* add unit test for to_dtype

---------

Co-authored-by: pkotzbach <[email protected]>
  • Loading branch information
pkotzbach and pkotzbach authored Dec 13, 2024
1 parent 8a50868 commit c1b79c1
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion test/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, to_dtype
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from hypothesis import given, settings, strategies as strat
Expand Down Expand Up @@ -854,5 +854,18 @@ def test_max_w_alu(self):
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()

class TestToDtype(unittest.TestCase):
def test_dtype_to_dtype(self):
dtype = dtypes.int32
res = to_dtype(dtype)
self.assertIsInstance(res, DType)
self.assertEqual(res, dtypes.int32)

def test_str_to_dtype(self):
dtype = "int32"
res = to_dtype(dtype)
self.assertIsInstance(res, DType)
self.assertEqual(res, dtypes.int32)

if __name__ == '__main__':
unittest.main()

0 comments on commit c1b79c1

Please sign in to comment.