diff --git a/src/evermore/util.py b/src/evermore/util.py index e413f70..ef9c46f 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -167,4 +167,4 @@ def f(x: jax.Array) -> jax.Array: filepath = pathlib.Path('graph.gv') filepath.write_text(dump_hlo_graph(f, x), encoding='ascii') """ - return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph() + return jax.jit(fun).lower(*args, **kwargs).compiler_ir("hlo").as_hlo_dot_graph() # type: ignore[union-attr]