Skip to content

Commit

Permalink
tensor reduction touchup (tinygrad#6402)
Browse files (browse the repository at this point in the history)
- fixing spacing
- use `get_args` to get the valid `Literal` values and raise `ValueError` (instead of asserting) on an invalid reduction, to match torch; add a test for that
- use `Y` to be consistent
  • Loading branch information
chenyuxyz authored Sep 8, 2024
1 parent 65da03e commit 7df4373
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
3 changes: 3 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,9 @@ def test_cross_entropy_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
lambda x,y: x.cross_entropy(y, reduction=r))
self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)

def test_cross_entropy_smoothing(self):
for ls in (0., 0.3, 0.7, 1.):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
Expand Down
32 changes: 16 additions & 16 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3050,37 +3050,37 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optio
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value

def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
return reductions [reduction](self)
return reductions[reduction](self)

def binary_crossentropy(self, y:Tensor, reduction:ReductionStr="mean") -> Tensor:
def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
"""
Computes the binary cross-entropy loss between `self` and `y`.
Computes the binary cross-entropy loss between `self` and `Y`.
See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([0.1, 0.9, 0.2])
y = Tensor([0, 1, 0])
print(t.binary_crossentropy(y).item())
Y = Tensor([0, 1, 0])
print(t.binary_crossentropy(Y).item())
```
"""
return (-y*self.log() - (1-y)*(1-self).log())._do_reduction(reduction)
return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)

def binary_crossentropy_logits(self, y:Tensor, reduction:ReductionStr="mean") -> Tensor:
def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
"""
Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, -3])
y = Tensor([0, 1, 0])
print(t.binary_crossentropy_logits(y).item())
Y = Tensor([0, 1, 0])
print(t.binary_crossentropy_logits(Y).item())
```
"""
return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)

def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
"""
Expand All @@ -3107,7 +3107,7 @@ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoot
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))

def cross_entropy(self, y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
"""
Compute the cross entropy loss between input logits and target.
Expand All @@ -3127,9 +3127,9 @@ def cross_entropy(self, y:Tensor, reduction:ReductionStr="mean", label_smoothing
```
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
y = y.one_hot(num_classes=cast(int, self.shape[1])) if y.ndim < 2 else y
y = (1 - label_smoothing)*y + label_smoothing / cast(int, y.shape[1])
ret = -self.log_softmax(axis=1).mul(y).sum(axis=1)
Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y
Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1])
ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
return ret._do_reduction(reduction)

# ***** Tensor Properties *****
Expand Down

0 comments on commit 7df4373

Please sign in to comment.