Skip to content

Commit

Permalink
remove bitcast backward in function.py (tinygrad#7031)
Browse files Browse the repository at this point in the history
bitcast cannot backward
  • Loading branch information
chenyuxyz authored Oct 13, 2024
1 parent ace834e commit 13575f0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,9 @@ def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)

def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)

class TestOpsUint8(unittest.TestCase):
@unittest.skip('this is broken for negative numbers')
def test_cast(self):
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.bitcast(self.input_dtype) if self.bitcast else grad_output.cast(self.input_dtype)
if self.bitcast: raise RuntimeError("bitcast cannot backward")
return grad_output.cast(self.input_dtype)

# ************* unary ops *************

Expand Down

0 comments on commit 13575f0

Please sign in to comment.