diff --git a/src/dilax/util.py b/src/dilax/util.py index 7622565..6ac9bbb 100644 --- a/src/dilax/util.py +++ b/src/dilax/util.py @@ -244,13 +244,75 @@ class HistDB(FrozenDB): def as1darray(x: float | jax.Array) -> jax.Array: + """ + Converts `x` to a 1d array. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + + + as1darray(1.0) + # -> Array([1.], dtype=float32, weak_type=True) + + as1darray(jnp.array(1.0)) + # -> Array([1.], dtype=float32, weak_type=True) + """ + return jnp.atleast_1d(jnp.asarray(x)) def dump_jaxpr(fun: Callable, *args: Any, **kwargs: Any) -> str: + """Helper function to dump the Jaxpr of a function. + + Example: + + .. code-block:: python + + import jax + import jax.numpy as jnp + + def f(x: jax.Array) -> jax.Array: + return jnp.sin(x) ** 2 + jnp.cos(x) ** 2 + + x = jnp.array([1.0, 2.0, 3.0]) + + print(dump_jaxpr(f, x)) + # -> { lambda ; a:f32[3]. let + # b:f32[3] = sin a # [] + # c:f32[3] = integer_pow[y=2] b # [] + # d:f32[3] = cos a # [] + # e:f32[3] = integer_pow[y=2] d # [] + # f:f32[3] = add c e # [] + # in (f,) } + """ jaxpr = jax.make_jaxpr(fun)(*args, **kwargs) return jaxpr.pretty_print(name_stack=True) def dump_hlo_graph(fun: Callable, *args: Any, **kwargs: Any) -> str: + """ + Helper to dump the HLO graph of a function as a `dot` graph. + + Example: + + .. code-block:: python + + import jax + import jax.numpy as jnp + + import path + + + def f(x: jax.Array) -> jax.Array: + return x + 1.0 + + x = jnp.array([1.0, 2.0, 3.0]) + + # dump dot graph to file + filepath = pathlib.Path('graph.gv') + filepath.write_text(dump_hlo_graph(f, x), encoding='ascii') + """ return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph()