Skip to content

Commit

Permalink
debugging jax failure
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 30, 2024
1 parent d190ce3 commit dfd8d7b
Showing 1 changed file with 56 additions and 139 deletions.
195 changes: 56 additions & 139 deletions notebooks/24_autodiff.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 25,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 26,
"metadata": {},
"outputs": [
{
Expand All @@ -44,7 +44,7 @@
"[CpuDevice(id=0)]"
]
},
"execution_count": 10,
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -86,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -108,7 +108,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -194,32 +194,19 @@
"metadata": {},
"outputs": [],
"source": [
"sim_func = jax.vmap(partial(bioreaction_sim_dfx_expanded,\n",
" t0=t0, t1=t1, dt0=dt0,\n",
" forward_rates=forward_rates,\n",
" inputs=inputs,\n",
" outputs=outputs,\n",
" solver=get_diffrax_solver(sim_method),\n",
" saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, 500)),\n",
" stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1, choice=stepsize_controller)))\n",
"\n",
"# sim_func = partial(bioreaction_sim_expanded,\n",
"# inputs=inputs, outputs=outputs,\n",
"# forward_rates=forward_rates.squeeze()\n",
"# )\n",
"\n",
"sim_func = partial(one_step_de_sim_expanded,\n",
" inputs=inputs,\n",
" outputs=outputs,\n",
" forward_rates=forward_rates) # + signal(t) * signal_onehot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def bioreaction_sim_expanded(t, y,\n",
" args,\n",
" inputs, outputs,\n",
" # signal, signal_onehot: jnp.ndarray,\n",
" forward_rates=None, reverse_rates=None):\n",
" params, y0 = args\n",
" y_t, y_sens, y_prec = y\n",
" dy_dt = one_step_de_sim_expanded(\n",
" spec_conc=y_t, inputs=inputs,\n",
" outputs=outputs,\n",
" forward_rates=forward_rates,\n",
" reverse_rates=reverse_rates) #+ signal(t) * signal_onehot\n",
" return dy_dt, y_sens, y_prec\n",
"\n",
"\n",
"def bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,\n",
Expand All @@ -237,125 +224,55 @@
" forward_rates=forward_rates.squeeze(), reverse_rates=reverse_rates.squeeze()\n",
" )\n",
" )\n",
" args = (reverse_rates, y0)\n",
" y00 = (y0, y0, y0)\n",
" sol = dfx.diffeqsolve(term, solver,\n",
" t0=t0, t1=t1, dt0=dt0,\n",
" y0=y0.squeeze(),\n",
" y0=y00,\n",
" args=args,\n",
" saveat=saveat, max_steps=max_steps,\n",
" stepsize_controller=stepsize_controller)\n",
" return sol.ts, sol.ys"
" return sol.ts, sol.ys\n",
"\n",
"\n",
"sim_func = jax.vmap(partial(bioreaction_sim_dfx_expanded,\n",
" t0=t0, t1=t1, dt0=dt0,\n",
" forward_rates=forward_rates,\n",
" inputs=inputs,\n",
" outputs=outputs,\n",
" solver=get_diffrax_solver(sim_method),\n",
" saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, 500)),\n",
" stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1, choice=stepsize_controller)))\n",
"# sim_func = partial(bioreaction_sim_expanded,\n",
"# inputs=inputs, outputs=outputs,\n",
"# forward_rates=forward_rates.squeeze()\n",
"# )\n",
"\n",
"# sim_func = partial(one_step_de_sim_expanded,\n",
"# inputs=inputs,\n",
"# outputs=outputs,\n",
"# forward_rates=forward_rates) # + signal(t) * signal_onehot\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Array([[ 0. , 1.002004, 2.004008, ..., 497.996 , 498.99802 ,\n",
" 500. ],\n",
" [ 0. , 1.002004, 2.004008, ..., 497.996 , 498.99802 ,\n",
" 500. ],\n",
" [ 0. , 1.002004, 2.004008, ..., 497.996 , 498.99802 ,\n",
" 500. ],\n",
" ...,\n",
" [ 0. , 1.002004, 2.004008, ..., 497.996 , 498.99802 ,\n",
" 500. ],\n",
" [ 0. , 1.002004, 2.004008, ..., 497.996 , 498.99802 ,\n",
" 500. ],\n",
" [ 0. , 1.002004, 2.004008, ..., 497.996 , 498.99802 ,\n",
" 500. ]], dtype=float32),\n",
" Array([[[2.00000000e+02, 2.00000000e+02, 2.00000000e+02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [9.24192581e+01, 9.08212891e+01, 9.23455048e+01, ...,\n",
" 2.71574574e+01, 2.75130463e+01, 2.75738888e+01],\n",
" [6.28656578e+01, 5.91621017e+01, 6.26813736e+01, ...,\n",
" 3.47783165e+01, 3.58699989e+01, 3.61319504e+01],\n",
" ...,\n",
" [8.88446236e+00, 1.47674494e+01, 7.52454662e+00, ...,\n",
" 2.26259670e+01, 1.19676529e+02, 3.60869331e+01],\n",
" [8.88373089e+00, 1.47605391e+01, 7.52323532e+00, ...,\n",
" 2.26030750e+01, 1.19744850e+02, 3.60535088e+01],\n",
" [8.88300419e+00, 1.47536402e+01, 7.52192688e+00, ...,\n",
" 2.25802345e+01, 1.19813004e+02, 3.60201683e+01]],\n",
" \n",
" [[2.00000000e+02, 2.00000000e+02, 2.00000000e+02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [9.08339233e+01, 9.05240707e+01, 9.10328979e+01, ...,\n",
" 2.73795090e+01, 2.74108887e+01, 2.71728058e+01],\n",
" [5.92983971e+01, 5.85163498e+01, 5.97971382e+01, ...,\n",
" 3.53879204e+01, 3.55095673e+01, 3.48664589e+01],\n",
" ...,\n",
" [1.62853565e+01, 1.80375528e+00, 1.82791328e+01, ...,\n",
" 5.14834785e+01, 8.26435699e+01, 3.21613808e+01],\n",
" [1.62860413e+01, 1.80158103e+00, 1.82763309e+01, ...,\n",
" 5.14822121e+01, 8.26739044e+01, 3.21499557e+01],\n",
" [1.62867241e+01, 1.79941380e+00, 1.82735348e+01, ...,\n",
" 5.14809341e+01, 8.27041779e+01, 3.21385574e+01]],\n",
" \n",
" [[2.00000000e+02, 2.00000000e+02, 2.00000000e+02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [1.09884285e+02, 8.99853973e+01, 1.04429390e+02, ...,\n",
" 2.73189831e+01, 2.65429802e+01, 2.98244076e+01],\n",
" [9.03784180e+01, 5.75780563e+01, 7.76243286e+01, ...,\n",
" 3.51244278e+01, 3.27386742e+01, 4.19864998e+01],\n",
" ...,\n",
" [5.56591721e+01, 4.52586269e+00, 2.06094503e+00, ...,\n",
" 4.21267128e+01, 1.19989447e-01, 9.88722610e+01],\n",
" [5.56493492e+01, 4.52690601e+00, 2.06073833e+00, ...,\n",
" 4.21153603e+01, 1.20004341e-01, 9.88723602e+01],\n",
" [5.56395454e+01, 4.52794790e+00, 2.06053448e+00, ...,\n",
" 4.21040382e+01, 1.20019317e-01, 9.88724670e+01]],\n",
" \n",
" ...,\n",
" \n",
" [[2.00000000e+02, 2.00000000e+02, 2.00000000e+02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [9.18916397e+01, 9.64394455e+01, 1.05049713e+02, ...,\n",
" 2.81655197e+01, 2.05473309e+01, 2.40432758e+01],\n",
" [6.13188171e+01, 7.00012360e+01, 8.94586868e+01, ...,\n",
" 3.80337715e+01, 2.02778664e+01, 2.85694027e+01],\n",
" ...,\n",
" [1.08690071e+01, 1.69269638e+01, 9.83243179e+01, ...,\n",
" 8.74400940e+01, 4.37396336e+00, 4.24577866e+01],\n",
" [1.08690042e+01, 1.69269638e+01, 9.83243179e+01, ...,\n",
" 8.74400940e+01, 4.37396336e+00, 4.24577866e+01],\n",
" [1.08690023e+01, 1.69269638e+01, 9.83243179e+01, ...,\n",
" 8.74400940e+01, 4.37396336e+00, 4.24577866e+01]],\n",
" \n",
" [[2.00000000e+02, 2.00000000e+02, 2.00000000e+02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [9.17368622e+01, 9.19353180e+01, 9.05790710e+01, ...,\n",
" 2.74012737e+01, 2.74069099e+01, 2.73210106e+01],\n",
" [6.13807716e+01, 6.18817673e+01, 5.85931969e+01, ...,\n",
" 3.56110382e+01, 3.55415955e+01, 3.52179146e+01],\n",
" ...,\n",
" [9.37140942e+00, 1.87672386e+01, 1.09823017e+01, ...,\n",
" 5.42997704e+01, 7.00314026e+01, 4.44822464e+01],\n",
" [9.37071514e+00, 1.87668285e+01, 1.09823265e+01, ...,\n",
" 5.42950783e+01, 7.00414658e+01, 4.44831886e+01],\n",
" [9.37002373e+00, 1.87664185e+01, 1.09823513e+01, ...,\n",
" 5.42904053e+01, 7.00514755e+01, 4.44841270e+01]],\n",
" \n",
" [[2.00000000e+02, 2.00000000e+02, 2.00000000e+02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [9.39313278e+01, 9.03014832e+01, 9.59741287e+01, ...,\n",
" 2.73625488e+01, 2.74032383e+01, 2.56552315e+01],\n",
" [6.64336548e+01, 5.77052078e+01, 7.12187729e+01, ...,\n",
" 3.52692719e+01, 3.56921158e+01, 3.13752174e+01],\n",
" ...,\n",
" [3.84244003e+01, 2.39125037e+00, 6.83517838e+01, ...,\n",
" 5.37357140e+01, 1.29268179e+01, 4.71906662e+01],\n",
" [3.84264984e+01, 2.39043307e+00, 6.83527832e+01, ...,\n",
" 5.37439995e+01, 1.29212008e+01, 4.71921272e+01],\n",
" [3.84285965e+01, 2.38961887e+00, 6.83537903e+01, ...,\n",
" 5.37522812e+01, 1.29156113e+01, 4.71935768e+01]]], dtype=float32))"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
"ename": "EquinoxRuntimeError",
"evalue": "Above is the stack outside of JIT. Below is the stack inside of JIT:\n File \"/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py\", line 1423, in diffeqsolve\n sol = result.error_if(sol, jnp.invert(is_okay(result)))\nequinox.EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.\n\n-------------------\n\nAn error occurred during the runtime of your JAX program.\n\n1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful\nway to debug such errors. This can be interacted with using most of the usual commands\nfor the Python debugger: `u` and `d` to move up and down frames, the name of a variable\nto print its value, etc.\n\n2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can\n(mostly) inspect the state of your program as if it was normal Python.\n\n3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mEquinoxRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[30], line 33\u001b[0m\n\u001b[1;32m 17\u001b[0m solution \u001b[38;5;241m=\u001b[39m dfx\u001b[38;5;241m.\u001b[39mdiffeqsolve(\n\u001b[1;32m 18\u001b[0m dfx\u001b[38;5;241m.\u001b[39mODETerm(ode_sp),\n\u001b[1;32m 19\u001b[0m solver\u001b[38;5;241m=\u001b[39mget_diffrax_solver(sim_method),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;66;03m# adjoint=dfx.BacksolveAdjoint() # This enables differentiation\u001b[39;00m\n\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m solution\n\u001b[0;32m---> 33\u001b[0m \u001b[43msim_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43my0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreverse_rates\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
" \u001b[0;31m[... skipping hidden 3 frame]\u001b[0m\n",
"Cell \u001b[0;32mIn[29], line 32\u001b[0m, in \u001b[0;36mbioreaction_sim_dfx_expanded\u001b[0;34m(y0, t0, t1, dt0, inputs, outputs, forward_rates, reverse_rates, args, solver, saveat, max_steps, stepsize_controller)\u001b[0m\n\u001b[1;32m 25\u001b[0m dt0 \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 26\u001b[0m term \u001b[38;5;241m=\u001b[39m dfx\u001b[38;5;241m.\u001b[39mODETerm(\n\u001b[1;32m 27\u001b[0m partial(bioreaction_sim_expanded,\n\u001b[1;32m 28\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs, outputs\u001b[38;5;241m=\u001b[39moutputs,\n\u001b[1;32m 29\u001b[0m forward_rates\u001b[38;5;241m=\u001b[39mforward_rates\u001b[38;5;241m.\u001b[39msqueeze(), reverse_rates\u001b[38;5;241m=\u001b[39mreverse_rates\u001b[38;5;241m.\u001b[39msqueeze()\n\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m )\n\u001b[0;32m---> 32\u001b[0m sol \u001b[38;5;241m=\u001b[39m \u001b[43mdfx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiffeqsolve\u001b[49m\u001b[43m(\u001b[49m\u001b[43mterm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 33\u001b[0m \u001b[43m \u001b[49m\u001b[43mt0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdt0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdt0\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 34\u001b[0m \u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[43m \u001b[49m\u001b[43msaveat\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msaveat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[43m \u001b[49m\u001b[43mstepsize_controller\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstepsize_controller\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sol\u001b[38;5;241m.\u001b[39mts, sol\u001b[38;5;241m.\u001b[39mys\n",
" \u001b[0;31m[... skipping hidden 2 frame]\u001b[0m\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:262\u001b[0m, in \u001b[0;36m_JitWrapper._call\u001b[0;34m(self, is_lower, args, kwargs)\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m JaxRuntimeError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# Catch Equinox's runtime errors, and re-raise them with actually useful\u001b[39;00m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;66;03m# information. (By default XlaRuntimeError produces a lot of terrifying\u001b[39;00m\n\u001b[1;32m 252\u001b[0m \u001b[38;5;66;03m# but useless information.)\u001b[39;00m\n\u001b[1;32m 253\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 254\u001b[0m last_msg \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m last_stack \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# callback necessarily executed in the same interpreter as we are in\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# here?\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m EquinoxRuntimeError(\n\u001b[1;32m 263\u001b[0m _on_error_msg\u001b[38;5;241m.\u001b[39mformat(msg\u001b[38;5;241m=\u001b[39mlast_msg, stack\u001b[38;5;241m=\u001b[39mlast_stack)\n\u001b[1;32m 264\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# `from None` to hide the large but uninformative XlaRuntimeError.\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 267\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n",
"\u001b[0;31mEquinoxRuntimeError\u001b[0m: Above is the stack outside of JIT. Below is the stack inside of JIT:\n File \"/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py\", line 1423, in diffeqsolve\n sol = result.error_if(sol, jnp.invert(is_okay(result)))\nequinox.EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.\n\n-------------------\n\nAn error occurred during the runtime of your JAX program.\n\n1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful\nway to debug such errors. This can be interacted with using most of the usual commands\nfor the Python debugger: `u` and `d` to move up and down frames, the name of a variable\nto print its value, etc.\n\n2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can\n(mostly) inspect the state of your program as if it was normal Python.\n\n3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.\n"
]
}
],
"source": [
Expand Down

0 comments on commit dfd8d7b

Please sign in to comment.