diff --git a/bgflow/nn/flow/transformer/jax_bridge.py b/bgflow/nn/flow/transformer/jax_bridge.py index ecaaf3e7..0f114e56 100644 --- a/bgflow/nn/flow/transformer/jax_bridge.py +++ b/bgflow/nn/flow/transformer/jax_bridge.py @@ -4,7 +4,7 @@ try: import jax - import jax.numpy as jnp + from jax import numpy as jnp, lax, vmap import jax.dlpack except ImportError: jax = None @@ -67,6 +67,24 @@ def _body(_, left_right): return _inverted +def bisect_k(bijector, left_bound, right_bound, k=2, eps=1e-6): + """k-bin-bisection search. k=2 is normal bisection""" + @vmap + def _inverted(target): + init = (left_bound, right_bound) + n_iters = jnp.ceil(-jnp.log(eps)/jnp.log(k)).astype(int) + def _body(_, left_right): + left_bound, right_bound = left_right + cand = jnp.linspace(left_bound, right_bound, k+1) # cand: candidates + pred = vmap(bijector)(cand[1:-1]) # Don't calculate the bounds, which we know the result for already + comp = jnp.concatenate([jnp.array([True]), pred < target, jnp.array([False])]) # Add in the known bounds + lbin = jnp.bitwise_and(comp[:-1], ~comp[1:]).argmax() # This should contain only one True, at the boundary point + left_bound, right_bound = lax.dynamic_slice(cand, (lbin,), (2,)) + return left_bound, right_bound + + return jax.lax.fori_loop(0, n_iters, _body, init)[0] + + return _inverted def invert_bijector(bijector, root_finder): """Inverts a bijector with a root finder @@ -133,12 +151,12 @@ def _call(x, *params): return _call -def bijector_with_approx_inverse(bijector, domain=None, eps=1e-8): +def bijector_with_approx_inverse(bijector, domain=None, eps=1e-8, k=2): """Wraps bijector with approximate inverse.""" if domain is None: domain = (0, 1) root_finder = functools.partial( - bisect, + bisect if k==2 else functools.partial(bisect_k, k=k), left_bound=domain[0], right_bound=domain[1], eps=eps) @@ -234,11 +252,11 @@ def nested_vmap(fn, indices): return fn -def jax_compile(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8): +def jax_compile(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8, k=2): """Wraps simple JAX bijector into a transformer, that can be used within the bgflow eco-system.""" compile_bijector = compose(functools.partial(jax.jit)) - fwd, bwd = bijector_with_approx_inverse(nested_vmap(bijector, vmap_indices), domain, bisection_eps) + fwd, bwd = bijector_with_approx_inverse(nested_vmap(bijector, vmap_indices), domain, bisection_eps, k=k) return tuple(map(compile_bijector, (fwd, bwd))) @@ -249,14 +267,14 @@ def torch_to_jax_backend(backend): return backend -def to_torch_impl_(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8): +def to_torch_impl_(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8, k=2): """Helper impl function that can be cashed according to `vmap_indices` and `backend`""" - fwd, bwd = jax_compile(bijector, vmap_indices, backend, domain, bisection_eps) + fwd, bwd = jax_compile(bijector, vmap_indices, backend, domain, bisection_eps, k=k) return tuple(map(wrap_jax_fun, (fwd, bwd))) -def to_torch(bijector, vmap_indices=None, domain=None, bisection_eps=1e-8): +def to_torch(bijector, vmap_indices=None, domain=None, bisection_eps=1e-8, k=2): """Converts a simple JAX bijector into a torch bijector with - numerical inverses - automatic computation of log det jac @@ -270,7 +288,7 @@ def _cached(x): if indices is None: indices = tuple(range(len(x.shape))) backend = torch_to_jax_backend(x.device.type) - return cached_compile(indices, backend, domain, bisection_eps) + return cached_compile(indices, backend, domain, bisection_eps, k=k) def _fwd(x, *params): assert_float32(x) @@ -294,10 +312,10 @@ class JaxTransformer(Transformer): bijector.""" def __init__(self, bijector, compute_params, reduce_jacobian=True, - domain=None, bisection_eps=1e-8): + domain=None, bisection_eps=1e-8, k=2): super().__init__() self._compute_params = compute_params - fwd, bwd = to_torch(bijector) + fwd, bwd = to_torch(bijector, domain=domain, bisection_eps=bisection_eps, k=k) self.fwd = fwd self.bwd = bwd self.reduce_jacobian = reduce_jacobian