Commit
Refactor ops to allow non-funsor parameters (#491)
* Refactor ops to have known arity

* Add docs; merge CachedOpMeta into OpMeta

* Fix some ops

* Add Ternary -> Finitary pattern

* Fix some dispatch logic; fix patterns

* Fix ops.einsum usage

* Fix test_tensors.py

* Fix more ops

* Refactor TransformOp and WrappedTransformOp

* Fix misc ops

* lint

* Work around signature parsing in Python 3.6

* Fix test_cnf.py

* Disable obsolete/questionable test

* Add info to assertions failing only on ci

* Clean up error printing

* Fix typo

* Fix is_numeric_array(funsor.Tensor)

* Revert some changes to is_numeric_array and funsor.adjoint

* Set JAX_ENABLE_X64=1 on ci

* Increase number of samples in test_gaussian_mixture_distribution

* Increase number of samples in test_dirichlet_sample
fritzo authored Mar 18, 2021
1 parent 1392b5d commit 7c7e0c5
Showing 29 changed files with 788 additions and 576 deletions.
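The unifying change across the file diffs below is that op parameters which are not funsors (an axis, a shape, an einsum equation, a wrapped callable) are now declared parameters of the op with known arity: they are passed after the array arguments or by keyword, and recorded on the op instance, where code such as find_domain reads them back via op.defaults. The following is a minimal sketch of the resulting calling conventions collected from the hunks below; it assumes the default NumPy backend and is illustrative rather than an exhaustive API listing:

    import numpy as np
    import funsor
    from funsor import ops

    funsor.set_backend("numpy")  # assumption: exercising the ops on raw NumPy arrays

    x, y = np.zeros((2, 3)), np.ones((2, 3))

    ops.cat([x, y], -1)                 # was ops.cat(-1, x, y): parts first, axis last
    ops.stack((x, y))                   # was funsor.tensor.stack((x, y))
    ops.einsum([x, y.T], "ab,bc->ac")   # was ops.einsum("ab,bc->ac", x, y.T)

    # Non-funsor parameters live on the op and are read back via op.defaults,
    # e.g. op.defaults["equation"], op.defaults["shape"], op.defaults["dim"],
    # op.defaults["offset"], op.defaults["fn"] in the hunks below.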
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -86,4 +86,4 @@ jobs:
pip freeze
- name: Run test
run: |
-CI=1 FUNSOR_BACKEND=jax make test
+CI=1 JAX_ENABLE_X64=1 FUNSOR_BACKEND=jax make test
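For context: JAX_ENABLE_X64=1 makes JAX default to 64-bit floats and ints instead of 32-bit, presumably to reduce flakiness in the numerical tests adjusted elsewhere in this commit. The same switch can be flipped in code; this is standard JAX configuration, shown only for reference:

    import jax
    # In-code equivalent of the JAX_ENABLE_X64=1 environment variable set above
    jax.config.update("jax_enable_x64", True)  # default dtypes become float64 / int64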
4 changes: 2 additions & 2 deletions funsor/adjoint.py
@@ -4,7 +4,7 @@
from collections import defaultdict
from collections.abc import Hashable

-from funsor.cnf import Contraction, nullop
+from funsor.cnf import Contraction, null
from funsor.interpretations import Interpretation, reflect
from funsor.interpreter import stack_reinterpret
from funsor.ops import AssociativeOp
@@ -233,7 +233,7 @@ def adjoint_contract_generic(
def adjoint_contract(
adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs
):
-if prod_op is adj_prod_op and sum_op in (nullop, adj_sum_op):
+if prod_op is adj_prod_op and sum_op in (null, adj_sum_op):

# the only change is here:
out_adj = Approximate(
4 changes: 2 additions & 2 deletions funsor/affine.py
@@ -8,7 +8,7 @@

from funsor.domains import Bint
from funsor.interpreter import gensym
-from funsor.tensor import EinsumOp, Tensor, get_default_prototype
+from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable

from . import ops
@@ -92,7 +92,7 @@ def _(fn):
return affine_inputs(fn.arg) - fn.reduced_vars


-@affine_inputs.register(Finitary[EinsumOp, tuple])
+@affine_inputs.register(Finitary[ops.EinsumOp, tuple])
def _(fn):
# This is simply a multiary version of the above Binary(ops.mul, ...) case.
results = []
36 changes: 18 additions & 18 deletions funsor/cnf.py
@@ -17,7 +17,7 @@
from funsor.gaussian import Gaussian
from funsor.interpretations import eager, normalize, reflect
from funsor.interpreter import children, recursion_reinterpret
-from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop
+from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, null
from funsor.tensor import Tensor
from funsor.terms import (
_INFIX,
@@ -69,7 +69,7 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
for v in terms:
inputs.update((k, d) for k, d in v.inputs.items() if k not in bound)

-if bin_op is nullop:
+if bin_op is null:
output = terms[0].output
else:
output = reduce(
@@ -107,8 +107,8 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
if not sampled_vars:
return self

-if self.red_op in (ops.logaddexp, nullop):
-if self.bin_op in (ops.nullop, ops.logaddexp):
+if self.red_op in (ops.logaddexp, null):
+if self.bin_op in (ops.null, ops.logaddexp):
if rng_key is not None and get_backend() == "jax":
import jax

@@ -277,7 +277,7 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
if unique_vars:
result = term.reduce(red_op, unique_vars)
if result is not normalize.interpret(
-Contraction, red_op, nullop, unique_vars, (term,)
+Contraction, red_op, null, unique_vars, (term,)
):
terms[i] = result
reduced_vars -= unique_vars
@@ -432,7 +432,7 @@ def normalize_contraction_commutative_canonical_order(
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other):
return Contraction(
-mixture.red_op if red_op is nullop else red_op,
+mixture.red_op if red_op is null else red_op,
bin_op,
reduced_vars | mixture.reduced_vars,
*(mixture.terms + (other,))
@@ -444,7 +444,7 @@ def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture):
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture):
return Contraction(
-mixture.red_op if red_op is nullop else red_op,
+mixture.red_op if red_op is null else red_op,
bin_op,
reduced_vars | mixture.reduced_vars,
*(mixture.terms + (other,))
@@ -467,13 +467,13 @@ def normalize_trivial(red_op, bin_op, reduced_vars, term):
@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

-if not reduced_vars and red_op is not nullop:
-return Contraction(nullop, bin_op, reduced_vars, *terms)
+if not reduced_vars and red_op is not null:
+return Contraction(null, bin_op, reduced_vars, *terms)

-if len(terms) == 1 and bin_op is not nullop:
-return Contraction(red_op, nullop, reduced_vars, *terms)
+if len(terms) == 1 and bin_op is not null:
+return Contraction(red_op, null, reduced_vars, *terms)

-if red_op is nullop and bin_op is nullop:
+if red_op is null and bin_op is null:
return terms[0]

if red_op is bin_op:
@@ -498,11 +498,11 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
continue

# fuse operations without distributing
-if (v.red_op is nullop and bin_op is v.bin_op) or (
-bin_op is nullop and v.red_op in (red_op, nullop)
+if (v.red_op is null and bin_op is v.bin_op) or (
+bin_op is null and v.red_op in (red_op, null)
):
-red_op = v.red_op if red_op is nullop else red_op
-bin_op = v.bin_op if bin_op is nullop else bin_op
+red_op = v.red_op if red_op is null else red_op
+bin_op = v.bin_op if bin_op is null else bin_op
new_terms = terms[:i] + v.terms + terms[i + 1 :]
return Contraction(
red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
@@ -519,12 +519,12 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

@normalize.register(Binary, AssociativeOp, Funsor, Funsor)
def binary_to_contract(op, lhs, rhs):
-return Contraction(nullop, op, frozenset(), lhs, rhs)
+return Contraction(null, op, frozenset(), lhs, rhs)


@normalize.register(Reduce, AssociativeOp, Funsor, frozenset)
def reduce_funsor(op, arg, reduced_vars):
-return Contraction(op, nullop, reduced_vars, arg)
+return Contraction(op, null, reduced_vars, arg)


@normalize.register(
9 changes: 4 additions & 5 deletions funsor/distribution.py
@@ -27,7 +27,6 @@
get_default_prototype,
ignore_jit_warnings,
numeric_array,
-stack,
)
from funsor.terms import (
Funsor,
@@ -768,15 +767,15 @@ def LogNormal(loc, scale, value="value"):


def eager_beta(concentration1, concentration0, value):
-concentration = stack((concentration0, concentration1))
-value = stack((1 - value, value))
+concentration = ops.stack((concentration0, concentration1))
+value = ops.stack((1 - value, value))
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
return backend_dist.Dirichlet(concentration, value=value) # noqa: F821


def eager_binomial(total_count, probs, value):
-probs = stack((1 - probs, probs))
-value = stack((total_count - value, value))
+probs = ops.stack((1 - probs, probs))
+value = ops.stack((total_count - value, value))
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
return backend_dist.Multinomial(total_count, probs, value=value) # noqa: F821

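The only change to eager_beta and eager_binomial is that the stacking now goes through ops.stack instead of a helper imported from funsor.tensor; the stacking itself is what reduces a Beta to a two-component Dirichlet (and a Binomial to a two-category Multinomial). A quick independent check of the Beta/Dirichlet identity in plain NumPy/SciPy, unrelated to the funsor API:

    import numpy as np
    from scipy.stats import beta, dirichlet

    c1, c0, v = 2.5, 1.5, 0.3
    # Beta(c1, c0) evaluated at v equals Dirichlet([c0, c1]) evaluated at [1 - v, v],
    # matching the (concentration0, concentration1) and (1 - value, value) stacks above.
    print(beta.logpdf(v, c1, c0))
    print(dirichlet.logpdf(np.array([1 - v, v]), [c0, c1]))  # prints the same value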
38 changes: 35 additions & 3 deletions funsor/domains.py
@@ -250,14 +250,15 @@ def _find_domain_log_exp(op, domain):

@find_domain.register(ops.ReshapeOp)
def _find_domain_reshape(op, domain):
-return Array[domain.dtype, op.shape]
+return Array[domain.dtype, op.defaults["shape"]]


@find_domain.register(ops.GetitemOp)
def _find_domain_getitem(op, lhs_domain, rhs_domain):
if isinstance(lhs_domain, ArrayType):
+offset = op.defaults["offset"]
dtype = lhs_domain.dtype
-shape = lhs_domain.shape[: op.offset] + lhs_domain.shape[1 + op.offset :]
+shape = lhs_domain.shape[:offset] + lhs_domain.shape[1 + offset :]
return Array[dtype, shape]
elif isinstance(lhs_domain, ProductDomain):
# XXX should this return a Union?
@@ -342,7 +343,7 @@ def _find_domain_associative_generic(op, *domains):

@find_domain.register(ops.WrappedTransformOp)
def _transform_find_domain(op, domain):
-fn = op.dispatch(object)
+fn = op.defaults["fn"]
shape = fn.forward_shape(domain.shape)
return Array[domain.dtype, shape]

@@ -353,6 +354,37 @@ def _transform_log_abs_det_jacobian(op, domain, codomain):
return Real


+@find_domain.register(ops.StackOp)
+def _find_domain_stack(op, parts):
+shape = broadcast_shape(*(x.shape for x in parts))
+dim = op.defaults["dim"]
+if dim >= 0:
+dim = dim - len(shape) - 1
+assert dim < 0
+split = dim + len(shape) + 1
+shape = shape[:split] + (len(parts),) + shape[split:]
+output = Array[parts[0].dtype, shape]
+return output
+
+
+@find_domain.register(ops.EinsumOp)
+def _find_domain_einsum(op, operands):
+equation = op.defaults["equation"]
+ein_inputs, ein_output = equation.split("->")
+ein_inputs = ein_inputs.split(",")
+size_dict = {}
+for ein_input, x in zip(ein_inputs, operands):
+assert x.dtype == "real"
+assert len(ein_input) == len(x.shape)
+for name, size in zip(ein_input, x.shape):
+other_size = size_dict.setdefault(name, size)
+if other_size != size:
+raise ValueError(
+"Size mismatch at {}: {} vs {}".format(name, size, other_size)
+)
+return Reals[tuple(size_dict[d] for d in ein_output)]


__all__ = [
"Bint",
"BintType",
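The two new find_domain rules above are pure shape bookkeeping: StackOp inserts an axis of size len(parts) at the (normalized) dim, and EinsumOp sizes the output by unifying the dimension letters of the equation against the operand shapes. Below is a standalone restatement in plain Python that can be checked by hand; there are no funsor imports, the helper names are illustrative only, and for simplicity the stack sketch assumes the parts already share a common broadcast shape:

    def stack_output_shape(shape, dim, num_parts):
        """Mirror of _find_domain_stack: insert an axis of size num_parts at dim."""
        if dim >= 0:                    # normalize to a negative index in the output,
            dim = dim - len(shape) - 1  # which has one more axis than the inputs
        split = dim + len(shape) + 1
        return shape[:split] + (num_parts,) + shape[split:]

    assert stack_output_shape((2, 3), 0, 5) == (5, 2, 3)
    assert stack_output_shape((2, 3), -1, 5) == (2, 3, 5)

    def einsum_output_shape(equation, shapes):
        """Mirror of _find_domain_einsum: infer an einsum's output shape."""
        ein_inputs, ein_output = equation.split("->")
        size_dict = {}
        for ein_input, shape in zip(ein_inputs.split(","), shapes):
            assert len(ein_input) == len(shape)
            for name, size in zip(ein_input, shape):
                if size_dict.setdefault(name, size) != size:
                    raise ValueError("Size mismatch at " + name)
        return tuple(size_dict[d] for d in ein_output)

    assert einsum_output_shape("ab,bc->ac", [(2, 3), (3, 4)]) == (2, 4)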
2 changes: 1 addition & 1 deletion funsor/einsum/numpy_log.py
@@ -43,7 +43,7 @@ def einsum(equation, *operands):
shift = ops.permute(shift, [dims.index(dim) for dim in output])
shifts.append(shift)

-result = ops.log(ops.einsum(equation, *exp_operands))
+result = ops.log(ops.einsum(exp_operands, equation))
return sum(shifts + [result])


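Only the argument order of ops.einsum changes here; the surrounding einsum in funsor/einsum/numpy_log.py evaluates an einsum in log space by shifting the operands before exponentiating and adding the shifts back at the end (the standard log-sum-exp trick). A minimal NumPy illustration of that trick on its own, in its usual max-shift form (not funsor code):

    import numpy as np

    x = np.array([1000.0, 1001.0])   # log-space values; exp(x) alone would overflow
    shift = x.max()
    stable = np.log(np.exp(x - shift).sum()) + shift  # logsumexp(x), approx. 1001.3133
    print(stable)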
40 changes: 20 additions & 20 deletions funsor/gaussian.py
@@ -137,7 +137,7 @@ def as_tensor(self):

# Concatenate parts.
parts = [v for k, v in sorted(self.parts.items())]
-result = ops.cat(-1, *parts)
+result = ops.cat(parts, -1)
if not get_tracing_state():
assert result.shape == self.shape
return result
@@ -182,10 +182,10 @@ def as_tensor(self):
# TODO This could be optimized into a single .reshape().cat().reshape() if
# all inputs are contiguous, thereby saving a memcopy.
columns = {
-i: ops.cat(-1, *[v for j, v in sorted(part.items())])
+i: ops.cat([v for j, v in sorted(part.items())], -1)
for i, part in self.parts.items()
}
-result = ops.cat(-2, *[v for i, v in sorted(columns.items())])
+result = ops.cat([v for i, v in sorted(columns.items())], -2)
if not get_tracing_state():
assert result.shape == self.shape
return result
@@ -468,32 +468,32 @@ def _eager_subs_real(self, subs, remaining_subs):
k for k, d in self.inputs.items() if d.dtype == "real" and k not in b
)
prec_aa = ops.cat(
--2,
-*[
-ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
+[
+ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a], -1)
for k1, i1 in slices
if k1 in a
-]
+],
+-2,
)
prec_ab = ops.cat(
--2,
-*[
-ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
+[
+ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1)
for k1, i1 in slices
if k1 in a
-]
+],
+-2,
)
prec_bb = ops.cat(
--2,
-*[
-ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
+[
+ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1)
for k1, i1 in slices
if k1 in b
-]
+],
+-2,
)
-info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
-info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
-value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
+info_a = ops.cat([info_vec[..., i] for k, i in slices if k in a], -1)
+info_b = ops.cat([info_vec[..., i] for k, i in slices if k in b], -1)
+value_b = ops.cat([values[k] for k, i in slices if k in b], -1)
info_vec = info_a - _mv(prec_ab, value_b)
log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
@@ -637,8 +637,8 @@ def eager_reduce(self, op, reduced_vars):
1,
)
(b if key in reduced_vars else a).append(block)
-a = ops.cat(-1, *a)
-b = ops.cat(-1, *b)
+a = ops.cat(a, -1)
+b = ops.cat(b, -1)
prec_aa = self.precision[..., a[..., None], a]
prec_ba = self.precision[..., b[..., None], a]
prec_bb = self.precision[..., b[..., None], b]
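Throughout funsor/gaussian.py the change is only the new ops.cat(parts, axis) argument order; the surrounding logic, which assembles a block matrix by concatenating each row's blocks along the last axis and then the rows along the second-to-last (as in BlockMatrix.as_tensor and the prec_aa/prec_ab/prec_bb blocks above), is untouched. A shape check of that nested-cat pattern in plain NumPy, assuming ops.cat matches np.concatenate under the NumPy backend:

    import numpy as np

    blocks = {
        ("a", "a"): np.zeros((2, 2)), ("a", "b"): np.zeros((2, 3)),
        ("b", "a"): np.zeros((3, 2)), ("b", "b"): np.zeros((3, 3)),
    }
    # concatenate columns within a row (axis -1), then stack the rows (axis -2)
    rows = [np.concatenate([blocks[r, "a"], blocks[r, "b"]], -1) for r in ("a", "b")]
    full = np.concatenate(rows, -2)
    assert full.shape == (5, 5)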
2 changes: 1 addition & 1 deletion funsor/jax/distributions.py
@@ -243,7 +243,7 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):

@to_funsor.register(dist.transforms.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
-op = ops.WrappedTransformOp(tfm)
+op = ops.WrappedTransformOp(fn=tfm)
name = next(real_inputs.keys()) if real_inputs else "value"
return op(Variable(name, output))

(Diffs for the remaining changed files are not shown here.)

0 comments on commit 7c7e0c5
