From 7afff92257a9f63be15be3e5d34d2b51adaefa56 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 12 Mar 2021 21:34:29 -0500
Subject: [PATCH 1/4] WIP sketch ParametrizedOp and a richer ops.sum

---
 funsor/domains.py   | 26 ++++++++++++++++++++++++++
 funsor/ops/array.py |  2 +-
 funsor/ops/op.py    | 29 +++++++++++++++++++++++------
 funsor/torch/ops.py | 13 +++++++++++++
 test/test_tensor.py | 31 +++++++++++++++++++++++++++++++
 5 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/funsor/domains.py b/funsor/domains.py
index 617e965c5..5631dca45 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -248,6 +248,32 @@ def _find_domain_log_exp(op, domain):
     return Array["real", domain.shape]
 
 
+@find_domain.register(ops.SumOp)
+def _find_domain_sum(op, domain):
+    # Canonicalize axes.
+    ndim = len(domain.shape)
+    if op.axis is None:
+        axis = set(range(ndim))
+    elif isinstance(op.axis, int):
+        axis = {op.axis % ndim}
+    else:
+        axis = {i % ndim for i in op.axis}
+
+    # Compute shape.
+    if op.keepdims:
+        shape = tuple(1 if i in axis else size for i, size in enumerate(domain.shape))
+    else:
+        shape = tuple(size for i, size in enumerate(domain.shape) if i not in axis)
+
+    # Compute domain.
+    if domain.dtype == "real":
+        dtype = "real"
+    else:
+        raise NotImplementedError("TODO")
+
+    return Array[dtype, shape]
+
+
 @find_domain.register(ops.ReshapeOp)
 def _find_domain_reshape(op, domain):
     return Array[domain.dtype, op.shape]
diff --git a/funsor/ops/array.py b/funsor/ops/array.py
index d31f95668..730013fbd 100644
--- a/funsor/ops/array.py
+++ b/funsor/ops/array.py
@@ -42,7 +42,7 @@
 isnan = make_op(np.isnan)
 prod = make_op(np.prod)
 stack = make_op("stack")
-sum = make_op(np.sum)
+sum = make_op(np.sum, params=dict(axis=None, keepdims=False))
 transpose = make_op("transpose")
 
 sqrt.register(array)(np.sqrt)
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index 12d6f5769..e4a134f69 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -121,25 +121,42 @@ def decorator(fn):
     return decorator
 
 
-def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
+class ParametrizedOp(Op, metaclass=CachedOpMeta):
+    def __call__(self, *args, **kwargs):
+        # _params maps each parameter name to its default value.
+        params = tuple(kwargs.get(k, v) for k, v in self._params.items())
+        op = type(self)()  # The canonical dispatcher.
+        return super(ParametrizedOp, op).__call__(*args, *params)
+
+
+def make_op(fn=None, parent=None, *, name=None, params=None, module_name="funsor.ops"):
     """
     Factory to create a new :class:`Op` subclass and a new instance of that
     class.
     """
     # Support use as decorator.
     if fn is None:
-        return lambda fn: make_op(fn, parent, name=name, module_name=module_name)
-
-    if parent is None:
-        parent = Op
-    assert issubclass(parent, Op)
+        return lambda fn: make_op(
+            fn,
+            parent,
+            name=name,
+            params=params,
+            module_name=module_name,
+        )
 
     if name is None:
         name = fn if isinstance(fn, str) else fn.__name__
     assert isinstance(name, str)
 
+    if params is not None:
+        parent = ParametrizedOp
+    if parent is None:
+        parent = Op
+    assert issubclass(parent, Op)
+
     classname = name.capitalize().rstrip("_") + "Op"  # e.g. add -> AddOp
     cls = type(classname, (parent,), {})
     cls.__module__ = module_name
+    if params is not None:
+        cls._params = params
     op = cls(fn, name=name)
     return op
diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py
index f1a2fd11e..bf3475a07 100644
--- a/funsor/torch/ops.py
+++ b/funsor/torch/ops.py
@@ -279,11 +279,24 @@ def _stack(dim, *x):
     return torch.stack(x, dim=dim)
 
 
+# OLD
 @ops.sum.register(torch.Tensor, (int, type(None)))
 def _sum(x, dim):
     return x.sum() if dim is None else x.sum(dim)
 
 
+# NEW
+@ops.sum.register(torch.Tensor)
+@ops.sum.register(torch.Tensor, type(None), bool)
+def _sum_none(x, axis=None, keepdims=False):
+    return x.sum(tuple(range(x.dim())), keepdim=True) if keepdims else x.sum()
+
+
+@ops.sum.register(torch.Tensor, int, bool)
+def _sum_int(x, axis, keepdims):
+    return x.sum(axis, keepdim=keepdims)
+
+
 @ops.triangular_solve.register(torch.Tensor, torch.Tensor)
 def _triangular_solve(x, y, upper=False, transpose=False):
     return x.triangular_solve(y, upper, transpose).solution
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 8be2d769c..ba1880960 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -1306,3 +1306,34 @@ def test_scatter_pure_renaming():
 
     assert actual.input_vars == expected.input_vars
     assert ((actual - expected).abs() < 1e-4).data.all()
+
+
+# TODO add a test with batch dimensions
+@pytest.mark.parametrize("shape", [(2, 3, 4)], ids=str)
+def test_sum_parameters(shape):
+    data = randn(*shape)
+    AXES = [None, 0, 1, 2, -1, -2, -3, [0, 2]]
+    KEEPDIMS = [False, True]
+
+    assert_close(Tensor(ops.sum(data)), ops.sum(Tensor(data)))
+    for axis in AXES:
+        assert_close(Tensor(ops.sum(data, axis)), ops.sum(Tensor(data), axis))
+        assert_close(Tensor(ops.sum(data, axis=axis)), ops.sum(Tensor(data), axis=axis))
+    for keepdim in KEEPDIMS:
+        assert_close(
+            Tensor(ops.sum(data, keepdim=keepdim)),
+            ops.sum(Tensor(data), keepdim=keepdim),
+        )
+        for axis in AXES:
+            assert_close(
+                Tensor(ops.sum(data, axis, keepdim)),
+                ops.sum(Tensor(data), axis, keepdim),
+            )
+            assert_close(
+                Tensor(ops.sum(data, axis, keepdim=keepdim)),
+                ops.sum(Tensor(data), axis, keepdim=keepdim),
+            )
+            assert_close(
+                Tensor(ops.sum(data, axis=axis, keepdim=keepdim)),
+                ops.sum(Tensor(data), axis=axis, keepdim=keepdim),
+            )

From 394f0b4338092236cd080b86de6964e13562c067 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 17 Mar 2021 21:15:19 -0400
Subject: [PATCH 2/4] Revert parametrized op changes

---
 funsor/ops/array.py |  2 +-
 funsor/ops/op.py    | 29 ++++++-----------------------
 funsor/torch/ops.py | 13 -------------
 3 files changed, 7 insertions(+), 37 deletions(-)

diff --git a/funsor/ops/array.py b/funsor/ops/array.py
index 730013fbd..d31f95668 100644
--- a/funsor/ops/array.py
+++ b/funsor/ops/array.py
@@ -42,7 +42,7 @@
 isnan = make_op(np.isnan)
 prod = make_op(np.prod)
 stack = make_op("stack")
-sum = make_op(np.sum, params=dict(axis=None, keepdims=False))
+sum = make_op(np.sum)
 transpose = make_op("transpose")
 
 sqrt.register(array)(np.sqrt)
diff --git a/funsor/ops/op.py b/funsor/ops/op.py
index e4a134f69..12d6f5769 100644
--- a/funsor/ops/op.py
+++ b/funsor/ops/op.py
@@ -121,42 +121,25 @@ def decorator(fn):
     return decorator
 
 
-class ParametrizedOp(Op, metaclass=CachedOpMeta):
-    def __call__(self, *args, **kwargs):
-        # _params maps each parameter name to its default value.
-        params = tuple(kwargs.get(k, v) for k, v in self._params.items())
-        op = type(self)()  # The canonical dispatcher.
-        return super(ParametrizedOp, op).__call__(*args, *params)
-
-
-def make_op(fn=None, parent=None, *, name=None, params=None, module_name="funsor.ops"):
+def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"):
     """
     Factory to create a new :class:`Op` subclass and a new instance of that
     class.
     """
     # Support use as decorator.
     if fn is None:
-        return lambda fn: make_op(
-            fn,
-            parent,
-            name=name,
-            params=params,
-            module_name=module_name,
-        )
-
-    if name is None:
-        name = fn if isinstance(fn, str) else fn.__name__
-    assert isinstance(name, str)
+        return lambda fn: make_op(fn, parent, name=name, module_name=module_name)
 
-    if params is not None:
-        parent = ParametrizedOp
     if parent is None:
         parent = Op
     assert issubclass(parent, Op)
 
+    if name is None:
+        name = fn if isinstance(fn, str) else fn.__name__
+    assert isinstance(name, str)
+
     classname = name.capitalize().rstrip("_") + "Op"  # e.g. add -> AddOp
     cls = type(classname, (parent,), {})
     cls.__module__ = module_name
-    if params is not None:
-        cls._params = params
     op = cls(fn, name=name)
     return op
diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py
index bf3475a07..f1a2fd11e 100644
--- a/funsor/torch/ops.py
+++ b/funsor/torch/ops.py
@@ -279,24 +279,11 @@ def _stack(dim, *x):
     return torch.stack(x, dim=dim)
 
 
-# OLD
 @ops.sum.register(torch.Tensor, (int, type(None)))
 def _sum(x, dim):
     return x.sum() if dim is None else x.sum(dim)
 
 
-# NEW
-@ops.sum.register(torch.Tensor)
-@ops.sum.register(torch.Tensor, type(None), bool)
-def _sum_none(x, axis=None, keepdims=False):
-    return x.sum(tuple(range(x.dim())), keepdim=True) if keepdims else x.sum()
-
-
-@ops.sum.register(torch.Tensor, int, bool)
-def _sum_int(x, axis, keepdims):
-    return x.sum(axis, keepdim=keepdims)
-
-
 @ops.triangular_solve.register(torch.Tensor, torch.Tensor)
 def _triangular_solve(x, y, upper=False, transpose=False):
     return x.triangular_solve(y, upper, transpose).solution

From bd092fed0fafc0fb62583fefee49b0c82063fa0d Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 17 Mar 2021 21:54:24 -0400
Subject: [PATCH 3/4] Implement richer ops.sum

---
 funsor/domains.py   | 21 ++++++++---------
 funsor/jax/ops.py   |  4 ++--
 funsor/ops/array.py |  4 ++--
 funsor/tensor.py    | 19 ++++++++++++++++
 funsor/torch/ops.py |  9 ++++++--
 test/test_tensor.py | 55 ++++++++++++++++++++++++++++--------------
 6 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/funsor/domains.py b/funsor/domains.py
index d97ff5cbd..3dba4b34b 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -250,20 +250,21 @@ def _find_domain_log_exp(op, domain):
 
 @find_domain.register(ops.SumOp)
 def _find_domain_sum(op, domain):
-    # Canonicalize axes.
-    ndim = len(domain.shape)
-    if op.axis is None:
-        axis = set(range(ndim))
-    elif isinstance(op.axis, int):
-        axis = {op.axis % ndim}
+    # Canonicalize dim.
+    dim = op.defaults.get("dim", None)
+    ndims = len(domain.shape)
+    if dim is None:
+        dims = set(range(ndims))
+    elif isinstance(dim, int):
+        dims = {dim % ndims}
     else:
-        axis = {i % ndim for i in op.axis}
+        dims = {i % ndims for i in dim}
 
     # Compute shape.
-    if op.keepdims:
-        shape = tuple(1 if i in axis else size for i, size in enumerate(domain.shape))
+    if op.defaults.get("keepdims", False):
+        shape = tuple(1 if i in dims else size for i, size in enumerate(domain.shape))
     else:
-        shape = tuple(size for i, size in enumerate(domain.shape) if i not in axis)
+        shape = tuple(size for i, size in enumerate(domain.shape) if i not in dims)
 
     # Compute domain.
     if domain.dtype == "real":
diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py
index be784071d..05dccc2ec 100644
--- a/funsor/jax/ops.py
+++ b/funsor/jax/ops.py
@@ -258,8 +258,8 @@ def _stack(parts, dim=0):
 
 
 @ops.sum.register(array)
-def _sum(x, dim):
-    return np.sum(x, axis=dim)
+def _sum(x, dim, keepdims):
+    return np.sum(x, dim, keepdims=keepdims)
 
 
 @ops.triangular_solve.register(array, array)
diff --git a/funsor/ops/array.py b/funsor/ops/array.py
index 3e6bc9289..97831fa9b 100644
--- a/funsor/ops/array.py
+++ b/funsor/ops/array.py
@@ -70,8 +70,8 @@ def amin(x, dim=None, keepdims=False):
 
 
 @UnaryOp.make
-def sum(x, dim=None):
-    return np.sum(x, dim)
+def sum(x, dim=None, keepdims=False):
+    return np.sum(x, dim, keepdims=keepdims)
 
 
 @UnaryOp.make
diff --git a/funsor/tensor.py b/funsor/tensor.py
index c34bb4ba9..f3be2d3d2 100644
--- a/funsor/tensor.py
+++ b/funsor/tensor.py
@@ -772,6 +772,25 @@ def eager_reshape_tensor(op, arg):
     return Tensor(data, arg.inputs, arg.dtype)
 
 
+@eager.register(Unary, ops.SumOp, Tensor)
+def eager_reshape_tensor(op, arg):
+    if not arg.inputs:
+        return Tensor(op(arg.data), arg.inputs, arg.dtype)
+
+    # Work around batch inputs.
+    dim = op.defaults.get("dim", None)
+    keepdims = op.defaults.get("keepdims", False)
+    ndims = len(arg.output.shape)
+    if dim is None:
+        dim = tuple(range(-ndims, 0))
+    elif isinstance(dim, int):
+        dim = dim % ndims - ndims
+    else:
+        dim = tuple(d % ndims - ndims for d in dim)
+    data = op(arg.data, dim, keepdims)
+    return Tensor(data, arg.inputs, arg.dtype)
+
+
 @eager.register(Binary, GetitemOp, Tensor, Number)
 def eager_getitem_tensor_number(op, lhs, rhs):
     offset = op.defaults["offset"]
diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py
index f99b9bd25..52e03f64e 100644
--- a/funsor/torch/ops.py
+++ b/funsor/torch/ops.py
@@ -271,8 +271,13 @@ def _scatter_add(destin, indices, source):
 
 
 @ops.sum.register(torch.Tensor)
-def _sum(x, dim):
-    return x.sum() if dim is None else x.sum(dim)
+def _sum(x, dim, keepdims):
+    if dim is None:
+        if keepdims:
+            dim = tuple(range(x.dim()))
+            return x.sum(dim, True)
+        return x.sum()
+    return x.sum(dim, keepdims)
 
 
 @ops.triangular_solve.register(torch.Tensor, torch.Tensor)
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 435aba799..203c942bf 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -1307,32 +1307,51 @@ def test_scatter_pure_renaming():
     assert ((actual - expected).abs() < 1e-4).data.all()
 
 
-# TODO add a test with batch dimensions
-@pytest.mark.parametrize("shape", [(2, 3, 4)], ids=str)
-def test_sum_parameters(shape):
-    data = randn(*shape)
-    AXES = [None, 0, 1, 2, -1, -2, -3, [0, 2]]
+@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str)
+def test_sum(event_shape):
+    data = randn(*event_shape)
+    DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)]
     KEEPDIMS = [False, True]
 
     assert_close(Tensor(ops.sum(data)), ops.sum(Tensor(data)))
-    for axis in AXES:
-        assert_close(Tensor(ops.sum(data, axis)), ops.sum(Tensor(data), axis))
-        assert_close(Tensor(ops.sum(data, axis=axis)), ops.sum(Tensor(data), axis=axis))
-    for keepdim in KEEPDIMS:
+    for dim in DIMS:
+        assert_close(Tensor(ops.sum(data, dim)), ops.sum(Tensor(data), dim))
+        assert_close(Tensor(ops.sum(data, dim=dim)), ops.sum(Tensor(data), dim=dim))
+    for keepdims in KEEPDIMS:
         assert_close(
-            Tensor(ops.sum(data, keepdim=keepdim)),
-            ops.sum(Tensor(data), keepdim=keepdim),
+            Tensor(ops.sum(data, keepdims=keepdims)),
+            ops.sum(Tensor(data), keepdims=keepdims),
         )
-        for axis in AXES:
+        for dim in DIMS:
             assert_close(
-                Tensor(ops.sum(data, axis, keepdim)),
-                ops.sum(Tensor(data), axis, keepdim),
+                Tensor(ops.sum(data, dim, keepdims)),
+                ops.sum(Tensor(data), dim, keepdims),
             )
             assert_close(
-                Tensor(ops.sum(data, axis, keepdim=keepdim)),
-                ops.sum(Tensor(data), axis, keepdim=keepdim),
+                Tensor(ops.sum(data, dim, keepdims=keepdims)),
+                ops.sum(Tensor(data), dim, keepdims=keepdims),
             )
             assert_close(
-                Tensor(ops.sum(data, axis=axis, keepdim=keepdim)),
-                ops.sum(Tensor(data), axis=axis, keepdim=keepdim),
+                Tensor(ops.sum(data, dim=dim, keepdims=keepdims)),
+                ops.sum(Tensor(data), dim=dim, keepdims=keepdims),
             )
+
+
+@pytest.mark.parametrize("batch_shape", [(), (5,)], ids=str)
+@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str)
+def test_sum_batch(batch_shape, event_shape):
+    inputs = OrderedDict((k, Bint[s]) for k, s in zip("abc", batch_shape))
+    data = randn(*batch_shape, *event_shape)
+    DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)]
+    KEEPDIMS = [False, True]
+
+    def raw_sum(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)):
+        if batch_ndims == 0:
+            return ops.sum(x, dim, keepdims)
+        return ops.stack([raw_sum(part, dim, keepdims, batch_ndims - 1) for part in x])
+
+    for keepdims in KEEPDIMS:
+        for dim in DIMS:
+            actual = ops.sum(Tensor(data, inputs), dim, keepdims)
+            expected = Tensor(raw_sum(data, dim, keepdims), inputs)
+            assert_close(actual, expected)

From d893762ada04799d1cec54079376104ae0fd1828 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 17 Mar 2021 22:20:29 -0400
Subject: [PATCH 4/4] nit

---
 funsor/tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/funsor/tensor.py b/funsor/tensor.py
index f3be2d3d2..ec86524c3 100644
--- a/funsor/tensor.py
+++ b/funsor/tensor.py
@@ -773,7 +773,7 @@ def eager_reshape_tensor(op, arg):
 
 
 @eager.register(Unary, ops.SumOp, Tensor)
-def eager_reshape_tensor(op, arg):
+def eager_sum_tensor(op, arg):
     if not arg.inputs:
         return Tensor(op(arg.data), arg.inputs, arg.dtype)
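
A usage sketch of the richer ops.sum introduced in PATCH 3 (illustrative only, not part
of the patch series): dim and keepdims address only a funsor's event shape, while named
batch inputs pass through untouched, per eager_sum_tensor above. The batch name "i" and
all shapes below are assumptions for illustration, and this assumes funsor's default
numpy backend.

    from collections import OrderedDict

    import funsor.ops as ops
    from funsor.domains import Bint
    from funsor.tensor import Tensor
    from funsor.testing import randn

    data = randn(5, 2, 3, 4)  # assumed: one batch dim of size 5, event shape (2, 3, 4)
    x = Tensor(data, OrderedDict(i=Bint[5]))  # "i" is a named batch input

    # dim/keepdims index into the event shape (2, 3, 4); input "i" survives each call.
    assert ops.sum(x).output.shape == ()
    assert ops.sum(x, 0).output.shape == (3, 4)
    assert ops.sum(x, (0, 2)).output.shape == (3,)
    assert ops.sum(x, dim=1, keepdims=True).output.shape == (2, 1, 4)
    assert set(ops.sum(x, 0).inputs) == {"i"}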