diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index 4e6f092e4b7f..bb503f2f1e12 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -77,5 +77,21 @@ def test_with_custom_gradient(self): dx = z.gradient(x, gradient=Tensor([3.0]))[0] self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0]) + def test_broadcast_gradient(self): + x = Tensor([[1.0], [2.0], [3.0]]) + y = Tensor([[10.0, 20.0, 30.0, 40.0]]) + z = (x + y).sum() + dx, dy = z.gradient(x, y) + self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]]) + self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]]) + + def test_non_scalar_output(self): + x = Tensor([1.0, 2.0, 3.0]) + z = x * x + with self.assertRaises(AssertionError): z.gradient(x) + dz = Tensor([1.0, 1.0, 1.0]) + dx = z.gradient(x, gradient=dz)[0] + self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0]) + if __name__ == '__main__': unittest.main()