truncate consts early (tinygrad#6741)
* truncate consts early

* ptx still fails

* Update dtype.py
geohot authored Sep 25, 2024
1 parent e31552e commit cb22ef3
Showing 6 changed files with 22 additions and 21 deletions.
3 changes: 1 addition & 2 deletions test/test_dtype.py
@@ -3,10 +3,9 @@
import torch
from typing import Any, List
from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
+from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
-from tinygrad.ops import truncate_fp16
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported, rand_for_dtype

2 changes: 1 addition & 1 deletion test/test_multitensor.py
@@ -358,7 +358,7 @@ def test_data_parallel_resnet_train_step(self):
shard_output.backward()
shard_grad = m.conv1.weight.grad.numpy()
# sometimes there is zeros in these grads... why?
-np.testing.assert_allclose(grad, shard_grad, atol=3e-6, rtol=3e-6)
+np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

def test_multi_tensor_jit_param(self):
@TinyJit
4 changes: 2 additions & 2 deletions test/test_schedule.py
@@ -1520,7 +1520,7 @@ def test_sparse_categorical_crossentropy_simple(self):
X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
Y = Tensor([1, 2]).realize()
loss = X.sparse_categorical_crossentropy(Y)
-self.check_schedule(loss, 6)
+self.check_schedule(loss, 4)
np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)

def test_mnist_val(self):
@@ -1531,7 +1531,7 @@ def test_mnist_val(self):
yt = Tensor.randn(BS, 10)
with Context(SPLIT_REDUCEOP=0):
loss = yt.sparse_categorical_crossentropy(Y_train[samples])
-self.check_schedule(loss, 7)
+self.check_schedule(loss, 6)
loss_fused = loss.numpy()
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
16 changes: 15 additions & 1 deletion tinygrad/dtype.py
@@ -1,5 +1,6 @@
from __future__ import annotations
-from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
+import math, struct, ctypes
from dataclasses import dataclass
import functools
from tinygrad.helpers import getenv
@@ -63,6 +64,7 @@ def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
+# TODO: should truncate here
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod
@functools.lru_cache(None)
@@ -144,3 +146,15 @@ def sum_acc_dtype(dt:DType):
if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
return least_upper_dtype(dt, dtypes.float)

+def truncate_fp16(x):
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
+  if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
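
The hunk above moves truncate_fp16 and the truncate table into tinygrad/dtype.py. As a rough illustration (not part of the commit), here is a minimal sketch of how these helpers behave, assuming the definitions shown above are importable from tinygrad.dtype:

from tinygrad.dtype import dtypes, truncate, truncate_fp16

print(truncate_fp16(65519.0))          # 65504.0: rounds down to the largest finite float16
print(truncate_fp16(1e6))              # inf: struct.pack overflows, so copysign(inf, x) is returned
print(truncate[dtypes.uint8](300))     # 44: wraps modulo 2**8, like a C uint8_t
print(truncate[dtypes.int8](-129))     # 127: wraps like a C int8_t
print(truncate[dtypes.float32](1e40))  # inf: value exceeds the float32 range
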
16 changes: 2 additions & 14 deletions tinygrad/ops.py
@@ -1,10 +1,10 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
-import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
+import sys, time, functools, itertools, math, operator, hashlib
from enum import auto, IntEnum, Enum
from collections import defaultdict
from dataclasses import dataclass, field
-from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
+from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
from tinygrad.shape.symbolic import Variable, sint
if TYPE_CHECKING:
@@ -310,18 +310,6 @@ def wfxn(*args):
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}

-def truncate_fp16(x):
-  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
-  except OverflowError: return math.copysign(math.inf, x)
-
-truncate: Dict[DType, Callable] = {dtypes.bool: bool,
-  # TODO: bfloat16
-  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
-  if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
-
def exec_alu(op:Op, dtype:DType, operands):
if dtype.count > 1:
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
2 changes: 1 addition & 1 deletion tinygrad/tensor.py
@@ -3164,7 +3164,7 @@ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoot
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
-log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
+log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
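
The tensor.py change above only builds the (Y != ignore_index) mask when a real ignore_index is passed; with the default of -1 the mask is a constant all-ones bool tensor. A minimal usage sketch (illustrative values, not from the commit), assuming the Tensor API used elsewhere in this diff:

from tinygrad import Tensor

logits = Tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = Tensor([0, 1])

# default ignore_index=-1: loss_mask is all ones, every sample contributes
loss_all = logits.sparse_categorical_crossentropy(labels)

# explicit ignore_index: samples whose label equals it are masked out
loss_masked = logits.sparse_categorical_crossentropy(labels, ignore_index=1)

print(loss_all.item(), loss_masked.item())
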
