
Support ops.sum(data, dim=None, keepdims=False) #490

Merged · 5 commits · Mar 18, 2021
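In short, ops.sum now accepts an optional dim (an int, a tuple of ints, or None for all event dims) and a keepdims flag, on both raw arrays and funsor Tensors. A minimal sketch of the new API, assuming the import paths used in funsor's test suite (randn and Tensor are funsor test/library utilities):

from funsor import ops
from funsor.tensor import Tensor
from funsor.testing import randn

x = Tensor(randn(2, 3, 4))
assert ops.sum(x).data.shape == ()                    # default: sum all dims
assert ops.sum(x, dim=1).data.shape == (2, 4)         # a single dim
assert ops.sum(x, dim=(0, 2), keepdims=True).data.shape == (1, 3, 1)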
27 changes: 27 additions & 0 deletions funsor/domains.py
@@ -248,6 +248,33 @@ def _find_domain_log_exp(op, domain):
    return Array["real", domain.shape]


@find_domain.register(ops.SumOp)
def _find_domain_sum(op, domain):
    # Canonicalize dim.
    dim = op.defaults.get("dim", None)
    ndims = len(domain.shape)
    if dim is None:
        dims = set(range(ndims))
    elif isinstance(dim, int):
        dims = {dim % ndims}
    else:
        dims = {i % ndims for i in dim}

    # Compute shape.
    if op.defaults.get("keepdims", False):
        shape = tuple(1 if i in dims else size for i, size in enumerate(domain.shape))
    else:
        shape = tuple(size for i, size in enumerate(domain.shape) if i not in dims)

    # Compute domain.
    if domain.dtype == "real":
        dtype = "real"
    else:
        raise NotImplementedError("TODO")

    return Array[dtype, shape]


@find_domain.register(ops.ReshapeOp)
def _find_domain_reshape(op, domain):
    return Array[domain.dtype, op.defaults["shape"]]
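The shape rule in _find_domain_sum computes the output domain without touching data, and it matches NumPy's reduction semantics. A quick standalone check of the same rule in plain NumPy, independent of funsor:

import numpy as np

x = np.zeros((2, 3, 4))
# Reduced dims (0, 2) are removed entirely...
assert np.sum(x, axis=(0, 2)).shape == (3,)
# ...or kept as size-1 dims when keepdims=True.
assert np.sum(x, axis=(0, 2), keepdims=True).shape == (1, 3, 1)
# dim=None reduces everything to a scalar.
assert np.sum(x).shape == ()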
4 changes: 2 additions & 2 deletions funsor/jax/ops.py
@@ -258,8 +258,8 @@ def _stack(parts, dim=0):


 @ops.sum.register(array)
-def _sum(x, dim):
-    return np.sum(x, axis=dim)
+def _sum(x, dim, keepdims):
+    return np.sum(x, dim, keepdims=keepdims)


@ops.triangular_solve.register(array, array)
4 changes: 2 additions & 2 deletions funsor/ops/array.py
@@ -70,8 +70,8 @@ def amin(x, dim=None, keepdims=False):


 @UnaryOp.make
-def sum(x, dim=None):
-    return np.sum(x, dim)
+def sum(x, dim=None, keepdims=False):
+    return np.sum(x, dim, keepdims=keepdims)


@UnaryOp.make
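Because ops.sum is a plain UnaryOp whose default implementation wraps np.sum, it can be applied directly to raw arrays, exactly as the tests below do. A small sanity check, assuming NumPy-backed data:

import numpy as np
from funsor import ops

x = np.ones((2, 3, 4))
assert ops.sum(x).shape == ()                       # default: sum everything
assert ops.sum(x, 0).shape == (3, 4)                # single dim
assert ops.sum(x, (0, 2), True).shape == (1, 3, 1)  # tuple dim + keepdims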
19 changes: 19 additions & 0 deletions funsor/tensor.py
@@ -772,6 +772,25 @@ def eager_reshape_tensor(op, arg):
    return Tensor(data, arg.inputs, arg.dtype)


@eager.register(Unary, ops.SumOp, Tensor)
def eager_sum_tensor(op, arg):
    if not arg.inputs:
        return Tensor(op(arg.data), arg.inputs, arg.dtype)

    # Work around batch inputs.
Review comment (Member):
This logic converting dimensions into batch-aware dimensions seems useful and general-purpose enough that some version of it should maybe live in an Op method or something - we don't want to have to write this from scratch in each new op with nontrivial shape semantics.

Reply (Member Author):
I agree this is useful, and I'm hoping @ordabayevy will generalize this in #482 once we have more than one use case. This file seems like the right place for that general logic, since funsor.ops should be agnostic to inputs and domains etc.
    dim = op.defaults.get("dim", None)
    keepdims = op.defaults.get("keepdims", False)
    ndims = len(arg.output.shape)
    if dim is None:
        dim = tuple(range(-ndims, 0))
    elif isinstance(dim, int):
        dim = dim % ndims - ndims
    else:
        dim = tuple(d % ndims - ndims for d in dim)
    data = op(arg.data, dim, keepdims)
    return Tensor(data, arg.inputs, arg.dtype)


@eager.register(Binary, GetitemOp, Tensor, Number)
def eager_getitem_tensor_number(op, lhs, rhs):
    offset = op.defaults["offset"]
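The review thread above asks for a reusable version of this dim canonicalization. A hypothetical extraction of the pattern, not part of this PR (the name _batch_safe_dims is invented here):

def _batch_safe_dims(dim, ndims):
    """Canonicalize `dim` over an event shape of length `ndims` into
    negative axes, so downstream ops ignore leading batch dimensions."""
    if dim is None:
        return tuple(range(-ndims, 0))
    if isinstance(dim, int):
        return dim % ndims - ndims
    return tuple(d % ndims - ndims for d in dim)

With a helper like this, the body of eager_sum_tensor shrinks to a few lines; per the reply above, #482 tracks generalizing the pattern.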
9 changes: 7 additions & 2 deletions funsor/torch/ops.py
@@ -271,8 +271,13 @@ def _scatter_add(destin, indices, source):


 @ops.sum.register(torch.Tensor)
-def _sum(x, dim):
-    return x.sum() if dim is None else x.sum(dim)
+def _sum(x, dim, keepdims):
+    if dim is None:
+        if keepdims:
+            dim = tuple(range(x.dim()))
+            return x.sum(dim, True)
+        return x.sum()
+    return x.sum(dim, keepdims)


@ops.triangular_solve.register(torch.Tensor, torch.Tensor)
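The extra branching exists because, at least in the torch versions current when this PR was written, Tensor.sum offers no direct way to say "all dims with keepdim=True": x.sum() takes no keepdim argument, so the wrapper materializes the full dim tuple. A short illustration, assuming torch is installed:

import torch

x = torch.randn(2, 3)
# To keep all dims, pass the explicit tuple of every dim plus keepdim=True:
assert x.sum(tuple(range(x.dim())), True).shape == (1, 1)
assert x.sum().shape == ()  # plain x.sum() reduces to a 0-dim tensor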
50 changes: 50 additions & 0 deletions test/test_tensor.py
@@ -1305,3 +1305,53 @@ def test_scatter_pure_renaming():

    assert actual.input_vars == expected.input_vars
    assert ((actual - expected).abs() < 1e-4).data.all()


@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str)
def test_sum(event_shape):
    data = randn(*event_shape)
    DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)]
    KEEPDIMS = [False, True]

    assert_close(Tensor(ops.sum(data)), ops.sum(Tensor(data)))
    for dim in DIMS:
        assert_close(Tensor(ops.sum(data, dim)), ops.sum(Tensor(data), dim))
        assert_close(Tensor(ops.sum(data, dim=dim)), ops.sum(Tensor(data), dim=dim))
    for keepdims in KEEPDIMS:
        assert_close(
            Tensor(ops.sum(data, keepdims=keepdims)),
            ops.sum(Tensor(data), keepdims=keepdims),
        )
        for dim in DIMS:
            assert_close(
                Tensor(ops.sum(data, dim, keepdims)),
                ops.sum(Tensor(data), dim, keepdims),
            )
            assert_close(
                Tensor(ops.sum(data, dim, keepdims=keepdims)),
                ops.sum(Tensor(data), dim, keepdims=keepdims),
            )
            assert_close(
                Tensor(ops.sum(data, dim=dim, keepdims=keepdims)),
                ops.sum(Tensor(data), dim=dim, keepdims=keepdims),
            )


@pytest.mark.parametrize("batch_shape", [(), (5,)], ids=str)
@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str)
def test_sum_batch(batch_shape, event_shape):
    inputs = OrderedDict((k, Bint[s]) for k, s in zip("abc", batch_shape))
    data = randn(*batch_shape, *event_shape)
    DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)]
    KEEPDIMS = [False, True]

    def raw_sum(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)):
        if batch_ndims == 0:
            return ops.sum(x, dim, keepdims)
        return ops.stack([raw_sum(part, dim, keepdims, batch_ndims - 1) for part in x])

    for keepdims in KEEPDIMS:
        for dim in DIMS:
            actual = ops.sum(Tensor(data, inputs), dim, keepdims)
            expected = Tensor(raw_sum(data, dim, keepdims), inputs)
            assert_close(actual, expected)
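Here raw_sum applies ops.sum beneath the batch dimensions by recursing over each batch slice and restacking the results. For batch_shape == (5,) it is equivalent to this plain-NumPy sketch (illustrative only, not part of the PR):

import numpy as np

data = np.random.randn(5, 2, 3, 4)  # one batch dim of size 5
# Sum each batch slice over event dims (0, 2), then restack:
expected = np.stack([np.sum(part, axis=(0, 2)) for part in data])
assert expected.shape == (5, 3)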