diff --git a/funsor/domains.py b/funsor/domains.py
index bfc358564..3dba4b34b 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -248,6 +248,33 @@ def _find_domain_log_exp(op, domain):
     return Array["real", domain.shape]
 
 
+@find_domain.register(ops.SumOp)
+def _find_domain_sum(op, domain):
+    # Canonicalize dim.
+    dim = op.defaults.get("dim", None)
+    ndims = len(domain.shape)
+    if dim is None:
+        dims = set(range(ndims))
+    elif isinstance(dim, int):
+        dims = {dim % ndims}
+    else:
+        dims = {i % ndims for i in dim}
+
+    # Compute shape.
+    if op.defaults.get("keepdims", False):
+        shape = tuple(1 if i in dims else size for i, size in enumerate(domain.shape))
+    else:
+        shape = tuple(size for i, size in enumerate(domain.shape) if i not in dims)
+
+    # Compute domain.
+    if domain.dtype == "real":
+        dtype = "real"
+    else:
+        raise NotImplementedError("TODO")
+
+    return Array[dtype, shape]
+
+
 @find_domain.register(ops.ReshapeOp)
 def _find_domain_reshape(op, domain):
     return Array[domain.dtype, op.defaults["shape"]]
diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py
index be784071d..05dccc2ec 100644
--- a/funsor/jax/ops.py
+++ b/funsor/jax/ops.py
@@ -258,8 +258,8 @@ def _stack(parts, dim=0):
 
 
 @ops.sum.register(array)
-def _sum(x, dim):
-    return np.sum(x, axis=dim)
+def _sum(x, dim, keepdims):
+    return np.sum(x, dim, keepdims=keepdims)
 
 
 @ops.triangular_solve.register(array, array)
diff --git a/funsor/ops/array.py b/funsor/ops/array.py
index 3e6bc9289..97831fa9b 100644
--- a/funsor/ops/array.py
+++ b/funsor/ops/array.py
@@ -70,8 +70,8 @@ def amin(x, dim=None, keepdims=False):
 
 
 @UnaryOp.make
-def sum(x, dim=None):
-    return np.sum(x, dim)
+def sum(x, dim=None, keepdims=False):
+    return np.sum(x, dim, keepdims=keepdims)
 
 
 @UnaryOp.make
diff --git a/funsor/tensor.py b/funsor/tensor.py
index c34bb4ba9..ec86524c3 100644
--- a/funsor/tensor.py
+++ b/funsor/tensor.py
@@ -772,6 +772,25 @@ def eager_reshape_tensor(op, arg):
     return Tensor(data, arg.inputs, arg.dtype)
 
 
+@eager.register(Unary, ops.SumOp, Tensor)
+def eager_sum_tensor(op, arg):
+    if not arg.inputs:
+        return Tensor(op(arg.data), arg.inputs, arg.dtype)
+
+    # Work around batch inputs.
+    dim = op.defaults.get("dim", None)
+    keepdims = op.defaults.get("keepdims", False)
+    ndims = len(arg.output.shape)
+    if dim is None:
+        dim = tuple(range(-ndims, 0))
+    elif isinstance(dim, int):
+        dim = dim % ndims - ndims
+    else:
+        dim = tuple(d % ndims - ndims for d in dim)
+    data = op(arg.data, dim, keepdims)
+    return Tensor(data, arg.inputs, arg.dtype)
+
+
 @eager.register(Binary, GetitemOp, Tensor, Number)
 def eager_getitem_tensor_number(op, lhs, rhs):
     offset = op.defaults["offset"]
diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py
index f99b9bd25..52e03f64e 100644
--- a/funsor/torch/ops.py
+++ b/funsor/torch/ops.py
@@ -271,8 +271,13 @@ def _scatter_add(destin, indices, source):
 
 
 @ops.sum.register(torch.Tensor)
-def _sum(x, dim):
-    return x.sum() if dim is None else x.sum(dim)
+def _sum(x, dim, keepdims):
+    if dim is None:
+        if keepdims:
+            dim = tuple(range(x.dim()))
+            return x.sum(dim, True)
+        return x.sum()
+    return x.sum(dim, keepdims)
 
 
 @ops.triangular_solve.register(torch.Tensor, torch.Tensor)
diff --git a/test/test_tensor.py b/test/test_tensor.py
index bacc5cdc0..203c942bf 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -1305,3 +1305,53 @@ def test_scatter_pure_renaming():
 
     assert actual.input_vars == expected.input_vars
     assert ((actual - expected).abs() < 1e-4).data.all()
+
+
+@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str)
+def test_sum(event_shape):
+    data = randn(*event_shape)
+    DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)]
+    KEEPDIMS = [False, True]
+
+    assert_close(Tensor(ops.sum(data)), ops.sum(Tensor(data)))
+    for dim in DIMS:
+        assert_close(Tensor(ops.sum(data, dim)), ops.sum(Tensor(data), dim))
+        assert_close(Tensor(ops.sum(data, dim=dim)), ops.sum(Tensor(data), dim=dim))
+    for keepdims in KEEPDIMS:
+        assert_close(
+            Tensor(ops.sum(data, keepdims=keepdims)),
+            ops.sum(Tensor(data), keepdims=keepdims),
+        )
+        for dim in DIMS:
+            assert_close(
+                Tensor(ops.sum(data, dim, keepdims)),
+                ops.sum(Tensor(data), dim, keepdims),
+            )
+            assert_close(
+                Tensor(ops.sum(data, dim, keepdims=keepdims)),
+                ops.sum(Tensor(data), dim, keepdims=keepdims),
+            )
+            assert_close(
+                Tensor(ops.sum(data, dim=dim, keepdims=keepdims)),
+                ops.sum(Tensor(data), dim=dim, keepdims=keepdims),
+            )
+
+
+@pytest.mark.parametrize("batch_shape", [(), (5,)], ids=str)
+@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str)
+def test_sum_batch(batch_shape, event_shape):
+    inputs = OrderedDict((k, Bint[s]) for k, s in zip("abc", batch_shape))
+    data = randn(*batch_shape, *event_shape)
+    DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)]
+    KEEPDIMS = [False, True]
+
+    def raw_sum(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)):
+        if batch_ndims == 0:
+            return ops.sum(x, dim, keepdims)
+        return ops.stack([raw_sum(part, dim, keepdims, batch_ndims - 1) for part in x])
+
+    for keepdims in KEEPDIMS:
+        for dim in DIMS:
+            actual = ops.sum(Tensor(data, inputs), dim, keepdims)
+            expected = Tensor(raw_sum(data, dim, keepdims), inputs)
+            assert_close(actual, expected)
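
Usage sketch (illustrative, not part of the diff): how the new keepdims flag is expected to behave on a batched funsor Tensor, per the eager_sum_tensor and _find_domain_sum changes above. The input name "i" and all shapes below are made up for the example; it assumes funsor's numpy backend and the existing helpers funsor.set_backend, funsor.testing.randn, Bint, and Tensor.

    from collections import OrderedDict

    import funsor
    import funsor.ops as ops
    from funsor import Bint, Tensor
    from funsor.testing import randn

    funsor.set_backend("numpy")

    # One batch input "i" of size 5, event shape (2, 3, 4).
    x = Tensor(randn(5, 2, 3, 4), OrderedDict(i=Bint[5]))

    y = ops.sum(x, (0, 2), True)  # sum event dims 0 and 2, keepdims=True
    print(y.output.shape)         # (1, 3, 1): summed dims kept as size 1
    print(tuple(y.inputs))        # ('i',): the batch input is untouched

    z = ops.sum(x)                # default dim=None sums all event dims
    print(z.output.shape)         # (): scalar output per batch element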
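A note on the "work around batch inputs" step in eager_sum_tensor (reasoning sketch, not part of the diff): arg.data has shape batch_shape + event_shape, while the user-facing dim indexes only the event shape. Canonicalizing each dim to a negative index counts from the right, so the same axis stays valid on the raw data no matter how many batch dims sit on the left.

    # Worked example with made-up shapes: event shape (2, 3, 4), so ndims = 3.
    ndims = 3
    for dim in (0, 2, -1):
        print(dim, "->", dim % ndims - ndims)
    # 0 -> -3, 2 -> -1, -1 -> -1
    # On data of shape (5, 2, 3, 4) (batch (5,) + event (2, 3, 4)),
    # axis -3 hits the size-2 event dim, as intended; axis 0 would
    # wrongly sum over the batch dim of size 5.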