diff --git a/src/evermore/util.py b/src/evermore/util.py
index e413f70..ef9c46f 100644
--- a/src/evermore/util.py
+++ b/src/evermore/util.py
@@ -167,4 +167,4 @@ def f(x: jax.Array) -> jax.Array:
         filepath = pathlib.Path('graph.gv')
         filepath.write_text(dump_hlo_graph(f, x), encoding='ascii')
     """
-    return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph()
+    return jax.jit(fun).lower(*args, **kwargs).compiler_ir("hlo").as_hlo_dot_graph()  # type: ignore[union-attr]