Skip to content

Commit

Permalink
hotfix: a few more grad tests
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Dec 14, 2024
1 parent 734f2c5 commit bcd7ea6
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/unit/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,21 @@ def test_with_custom_gradient(self):
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])

def test_broadcast_gradient(self):
x = Tensor([[1.0], [2.0], [3.0]])
y = Tensor([[10.0, 20.0, 30.0, 40.0]])
z = (x + y).sum()
dx, dy = z.gradient(x, y)
self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])

def test_non_scalar_output(self):
x = Tensor([1.0, 2.0, 3.0])
z = x * x
with self.assertRaises(AssertionError): z.gradient(x)
dz = Tensor([1.0, 1.0, 1.0])
dx = z.gradient(x, gradient=dz)[0]
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])

if __name__ == '__main__':
unittest.main()

0 comments on commit bcd7ea6

Please sign in to comment.