From da2245a45834904ec4a2cc5372e4f0b561cbb3d4 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Sun, 15 Dec 2024 21:24:05 +0200 Subject: [PATCH] Fix double => half cast on clang (#8265) --- test/test_dtype.py | 3 --- tinygrad/renderer/cstyle.py | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 60fb28a4c019..07c59dce00e3 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -52,9 +52,6 @@ def _test_cast(a:Tensor, target_dtype:DType): if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON": # TODO: struct.pack cannot pack value > 65504 (max of half) into e format a = (a > 65504).where(65504, a) - if CI and Device.DEFAULT == "CLANG" and (target_dtype, a.dtype) in [(dtypes.double, dtypes.half), (dtypes.half, dtypes.double)]: - # TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084 - return _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype)))) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 5425e0966e9b..4d7aaa8c23e4 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -172,6 +172,9 @@ class ClangRenderer(CStyleLanguage): type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"} code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}), Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"} + # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall. + extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \ + CStyleLanguage.extra_matcher if AMX: tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)