Skip to content

Commit

Permalink
add more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Oct 2, 2023
1 parent cf381b0 commit 6b9f34b
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions src/dilax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,75 @@ class HistDB(FrozenDB):


def as1darray(x: float | jax.Array) -> jax.Array:
"""
Converts `x` to a 1d array.
Example:
.. code-block:: python
import jax.numpy as jnp
as1darray(1.0)
# -> Array([1.], dtype=float32, weak_type=True)
as1darray(jnp.array(1.0))
# -> Array([1.], dtype=float32, weak_type=True)
"""

return jnp.atleast_1d(jnp.asarray(x))


def dump_jaxpr(fun: Callable, *args: Any, **kwargs: Any) -> str:
"""Helper function to dump the Jaxpr of a function.
Example:
.. code-block:: python
import jax
import jax.numpy as jnp
def f(x: jax.Array) -> jax.Array:
return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
x = jnp.array([1.0, 2.0, 3.0])
print(dump_jaxpr(f, x))
# -> { lambda ; a:f32[3]. let
# b:f32[3] = sin a # []
# c:f32[3] = integer_pow[y=2] b # []
# d:f32[3] = cos a # []
# e:f32[3] = integer_pow[y=2] d # []
# f:f32[3] = add c e # []
# in (f,) }
"""
jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
return jaxpr.pretty_print(name_stack=True)


def dump_hlo_graph(fun: Callable, *args: Any, **kwargs: Any) -> str:
"""
Helper to dump the HLO graph of a function as a `dot` graph.
Example:
.. code-block:: python
import jax
import jax.numpy as jnp
import path
def f(x: jax.Array) -> jax.Array:
return x + 1.0
x = jnp.array([1.0, 2.0, 3.0])
# dump dot graph to file
filepath = pathlib.Path('graph.gv')
filepath.write_text(dump_hlo_graph(f, x), encoding='ascii')
"""
return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph()

0 comments on commit 6b9f34b

Please sign in to comment.