From 7df4373fd92aad98dd4539ff7566d603774092b5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 8 Sep 2024 03:55:51 -0400 Subject: [PATCH] tensor reduction touchup (#6402) - fixing spacing - use get_args to get valid Literal values and raise ValueError to match, and a test for that - use `Y` to be consistent --- test/test_ops.py | 3 +++ tinygrad/tensor.py | 32 ++++++++++++++++---------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index cadceb2c7da7a..29882050a87eb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2109,6 +2109,9 @@ def test_cross_entropy_reductions(self): for r in ("mean", "sum", "none"): helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r), lambda x,y: x.cross_entropy(y, reduction=r)) + self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"), + lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError) + def test_cross_entropy_smoothing(self): for ls in (0., 0.3, 0.7, 1.): helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index df2bc382e106a..5c35fc24e52cc 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3050,37 +3050,37 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optio return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor: - assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']" + if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}") reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x} - return reductions [reduction](self) + return reductions[reduction](self) - def binary_crossentropy(self, y:Tensor, reduction:ReductionStr="mean") -> Tensor: + def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor: """ - Computes the binary cross-entropy loss between `self` and `y`. + Computes the binary cross-entropy loss between `self` and `Y`. See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([0.1, 0.9, 0.2]) - y = Tensor([0, 1, 0]) - print(t.binary_crossentropy(y).item()) + Y = Tensor([0, 1, 0]) + print(t.binary_crossentropy(Y).item()) ``` """ - return (-y*self.log() - (1-y)*(1-self).log())._do_reduction(reduction) + return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction) - def binary_crossentropy_logits(self, y:Tensor, reduction:ReductionStr="mean") -> Tensor: + def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor: """ - Computes the binary cross-entropy loss between `self` and `y` where `self` is logits. + Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits. See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 2, -3]) - y = Tensor([0, 1, 0]) - print(t.binary_crossentropy_logits(y).item()) + Y = Tensor([0, 1, 0]) + print(t.binary_crossentropy_logits(Y).item()) ``` """ - return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction) + return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction) def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor: """ @@ -3107,7 +3107,7 @@ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoot # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here) return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced)) - def cross_entropy(self, y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor: + def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor: """ Compute the cross entropy loss between input logits and target. @@ -3127,9 +3127,9 @@ def cross_entropy(self, y:Tensor, reduction:ReductionStr="mean", label_smoothing ``` """ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]" - y = y.one_hot(num_classes=cast(int, self.shape[1])) if y.ndim < 2 else y - y = (1 - label_smoothing)*y + label_smoothing / cast(int, y.shape[1]) - ret = -self.log_softmax(axis=1).mul(y).sum(axis=1) + Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y + Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1]) + ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1) return ret._do_reduction(reduction) # ***** Tensor Properties *****