From 7ff022f7a59fa91b6992cbbd932b4407acb7e881 Mon Sep 17 00:00:00 2001 From: olive004 Date: Mon, 28 Oct 2024 21:12:52 +0000 Subject: [PATCH] new autodiff notebook --- notebooks/24_autodiff.ipynb | 157 ++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 notebooks/24_autodiff.ipynb diff --git a/notebooks/24_autodiff.ipynb b/notebooks/24_autodiff.ipynb new file mode 100644 index 00000000..b559c62a --- /dev/null +++ b/notebooks/24_autodiff.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0, Loss: 36.00006866455078, Params: -0.2058422565460205\n", + "Step 100, Loss: 36.00006866455078, Params: -0.2058422565460205\n", + "Step 200, Loss: 36.00006866455078, Params: -0.2058422565460205\n", + "Step 300, Loss: 36.00006866455078, Params: -0.2058422565460205\n", + "Step 400, Loss: 36.00006866455078, Params: -0.2058422565460205\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 60\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;66;03m# Optimization loop\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1000\u001b[39m):\n\u001b[0;32m---> 60\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43moptimize_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m100\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 62\u001b[0m current_loss \u001b[38;5;241m=\u001b[39m loss_fn(params, y0, t0, t1, target_max)\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:327\u001b[0m, in \u001b[0;36m_cpp_pjit..cache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;129m@api_boundary\u001b[39m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcache_miss\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 327\u001b[0m outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked \u001b[38;5;241m=\u001b[39m \u001b[43m_python_pjit_helper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 328\u001b[0m \u001b[43m \u001b[49m\u001b[43mjit_info\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 329\u001b[0m executable \u001b[38;5;241m=\u001b[39m _read_most_recent_pjit_call_executable(jaxpr)\n\u001b[1;32m 330\u001b[0m pgle_profiler \u001b[38;5;241m=\u001b[39m _read_pgle_profiler(jaxpr)\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:185\u001b[0m, in \u001b[0;36m_python_pjit_helper\u001b[0;34m(jit_info, *args, **kwargs)\u001b[0m\n\u001b[1;32m 182\u001b[0m args_flat \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m*\u001b[39minit_states, \u001b[38;5;241m*\u001b[39margs_flat]\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 185\u001b[0m out_flat \u001b[38;5;241m=\u001b[39m \u001b[43mpjit_p\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs_flat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m pxla\u001b[38;5;241m.\u001b[39mDeviceAssignmentMismatchError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 187\u001b[0m fails, \u001b[38;5;241m=\u001b[39m e\u001b[38;5;241m.\u001b[39margs\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:2822\u001b[0m, in \u001b[0;36mAxisPrimitive.bind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 2818\u001b[0m axis_main \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m((axis_frame(a)\u001b[38;5;241m.\u001b[39mmain_trace \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m used_axis_names(\u001b[38;5;28mself\u001b[39m, params)),\n\u001b[1;32m 2819\u001b[0m default\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, key\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mlambda\u001b[39;00m t: \u001b[38;5;28mgetattr\u001b[39m(t, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlevel\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 2820\u001b[0m top_trace \u001b[38;5;241m=\u001b[39m (top_trace \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m axis_main \u001b[38;5;129;01mor\u001b[39;00m axis_main\u001b[38;5;241m.\u001b[39mlevel \u001b[38;5;241m<\u001b[39m top_trace\u001b[38;5;241m.\u001b[39mlevel\n\u001b[1;32m 2821\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m axis_main\u001b[38;5;241m.\u001b[39mwith_cur_sublevel())\n\u001b[0;32m-> 2822\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtop_trace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:420\u001b[0m, in \u001b[0;36mPrimitive.bind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind_with_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, trace, args, params):\n\u001b[1;32m 419\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m pop_level(trace\u001b[38;5;241m.\u001b[39mlevel):\n\u001b[0;32m--> 420\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfull_raise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 421\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(full_lower, out) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmultiple_results \u001b[38;5;28;01melse\u001b[39;00m full_lower(out)\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:909\u001b[0m, in \u001b[0;36mEvalTrace.process_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 907\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m call_impl_with_key_reuse_checks(primitive, primitive\u001b[38;5;241m.\u001b[39mimpl, \u001b[38;5;241m*\u001b[39mtracers, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams)\n\u001b[1;32m 908\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 909\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1636\u001b[0m, in \u001b[0;36m_pjit_call_impl\u001b[0;34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)\u001b[0m\n\u001b[1;32m 1633\u001b[0m donated_argnums \u001b[38;5;241m=\u001b[39m [i \u001b[38;5;28;01mfor\u001b[39;00m i, d \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(donated_invars) \u001b[38;5;28;01mif\u001b[39;00m d]\n\u001b[1;32m 1634\u001b[0m has_explicit_sharding \u001b[38;5;241m=\u001b[39m _pjit_explicit_sharding(\n\u001b[1;32m 1635\u001b[0m in_shardings, out_shardings, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m-> 1636\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mxc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_xla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpjit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1637\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcall_impl_cache_miss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_argnums\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1638\u001b[0m \u001b[43m \u001b[49m\u001b[43mtree_util\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch_registry\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1639\u001b[0m \u001b[43m \u001b[49m\u001b[43mpxla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshard_arg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1640\u001b[0m \u001b[43m \u001b[49m\u001b[43m_get_cpp_global_cache\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhas_explicit_sharding\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1615\u001b[0m, in \u001b[0;36m_pjit_call_impl..call_impl_cache_miss\u001b[0;34m(*args_, **kwargs_)\u001b[0m\n\u001b[1;32m 1614\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcall_impl_cache_miss\u001b[39m(\u001b[38;5;241m*\u001b[39margs_, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs_):\n\u001b[0;32m-> 1615\u001b[0m out_flat, compiled \u001b[38;5;241m=\u001b[39m \u001b[43m_pjit_call_impl_python\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1616\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjaxpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_shardings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43min_shardings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1617\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_shardings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout_shardings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_layouts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43min_layouts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1618\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_layouts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout_layouts\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresource_env\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresource_env\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1619\u001b[0m \u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_unused\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_unused\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1620\u001b[0m \u001b[43m \u001b[49m\u001b[43minline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minline\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1621\u001b[0m pgle_profiler \u001b[38;5;241m=\u001b[39m _read_pgle_profiler(jaxpr)\n\u001b[1;32m 1622\u001b[0m fastpath_data \u001b[38;5;241m=\u001b[39m _get_fastpath_data(\n\u001b[1;32m 1623\u001b[0m compiled, tree_structure(out_flat), args, out_flat, [], jaxpr\u001b[38;5;241m.\u001b[39meffects,\n\u001b[1;32m 1624\u001b[0m jaxpr\u001b[38;5;241m.\u001b[39mconsts, \u001b[38;5;28;01mNone\u001b[39;00m, pgle_profiler)\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1569\u001b[0m, in \u001b[0;36m_pjit_call_impl_python\u001b[0;34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)\u001b[0m\n\u001b[1;32m 1561\u001b[0m distributed_debug_log((\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRunning pjit\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124md function\u001b[39m\u001b[38;5;124m\"\u001b[39m, name),\n\u001b[1;32m 1562\u001b[0m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124min_shardings\u001b[39m\u001b[38;5;124m\"\u001b[39m, in_shardings),\n\u001b[1;32m 1563\u001b[0m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mout_shardings\u001b[39m\u001b[38;5;124m\"\u001b[39m, out_shardings),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1566\u001b[0m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mabstract args\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mmap\u001b[39m(xla\u001b[38;5;241m.\u001b[39mabstractify, args)),\n\u001b[1;32m 1567\u001b[0m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfingerprint\u001b[39m\u001b[38;5;124m\"\u001b[39m, fingerprint))\n\u001b[1;32m 1568\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsafe_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m, compiled\n\u001b[1;32m 1570\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mFloatingPointError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1571\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m config\u001b[38;5;241m.\u001b[39mdebug_nans\u001b[38;5;241m.\u001b[39mvalue \u001b[38;5;129;01mor\u001b[39;00m config\u001b[38;5;241m.\u001b[39mdebug_infs\u001b[38;5;241m.\u001b[39mvalue \u001b[38;5;66;03m# compiled_fun can only raise in this case\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py:335\u001b[0m, in \u001b[0;36mannotate_function..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 333\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 335\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrapper\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py:1216\u001b[0m, in \u001b[0;36mExecuteReplicated.__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1213\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mordered_effects \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_unordered_effects\n\u001b[1;32m 1214\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_host_callbacks):\n\u001b[1;32m 1215\u001b[0m input_bufs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_add_tokens_to_inputs(input_bufs)\n\u001b[0;32m-> 1216\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mxla_executable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute_sharded\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1217\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_bufs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwith_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 1218\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1220\u001b[0m result_token_bufs \u001b[38;5;241m=\u001b[39m results\u001b[38;5;241m.\u001b[39mdisassemble_prefix_into_single_device_arrays(\n\u001b[1;32m 1221\u001b[0m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mordered_effects))\n\u001b[1;32m 1222\u001b[0m sharded_runtime_token \u001b[38;5;241m=\u001b[39m results\u001b[38;5;241m.\u001b[39mconsume_token()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import diffrax as dfx\n", + "from jax import random\n", + "\n", + "# Define the ODE function\n", + "def ode_func(t, y, params):\n", + " return y * params\n", + "\n", + "# Modified function to solve the ODE and track the maximum value with params\n", + "def solve_ode_and_track_max(ode_func, y0, t0, t1, params, max_steps=1000):\n", + " def max_value_ode(t, y, args):\n", + " # args is a tuple (params, max_val)\n", + " params, max_val = args\n", + " dy_dt = ode_func(t, y[0], params)\n", + " max_val = jnp.maximum(max_val, y[0])[0]\n", + " return dy_dt, max_val\n", + "\n", + " solver = dfx.Dopri5()\n", + " initial_max = jnp.array(y0[0]) # Initial max is the initial condition\n", + " max_callback = dfx.SaveAt(ts=jnp.linspace(t0, t1, max_steps))\n", + " \n", + " # dy_dt, max_val = max_value_ode(1, (jnp.array(y0), initial_max), (params, initial_max))\n", + "\n", + " # Solve the ODE with an auxiliary variable to track the max\n", + " solution = dfx.diffeqsolve(\n", + " dfx.ODETerm(max_value_ode),\n", + " solver=solver,\n", + " t0=t0,\n", + " t1=t1,\n", + " dt0=0.1,\n", + " y0=(jnp.array(y0), initial_max),\n", + " args=(params, initial_max), # Pass params and initial max as args\n", + " saveat=max_callback,\n", + " )\n", + " \n", + " max_value = solution.ys[1][-1] # Get the max value reached\n", + " return max_value, solution\n", + "\n", + "# Define a loss function to optimize params\n", + "def loss_fn(params, y0, t0, t1, target_max):\n", + " max_value = solve_ode_and_track_max(ode_func, y0, t0, t1, params)\n", + " return (max_value - target_max) ** 2 # Squared error loss\n", + "\n", + "# Define optimization step\n", + "@jax.jit\n", + "def optimize_step(params, y0, t0, t1, target_max, learning_rate=0.01):\n", + " grads = jax.grad(loss_fn)(params, y0, t0, t1, target_max)\n", + " return params - learning_rate * grads\n", + "\n", + "# Initialize parameters and settings\n", + "key = random.PRNGKey(0)\n", + "params = random.normal(key, ()) # Start with a random parameter\n", + "y0 = jnp.array([1.0]) # Initial condition\n", + "t0, t1 = 0.0, 10.0 # Start and end times\n", + "target_max = 5.0 # Desired maximum value\n", + "n_steps = 1000 # Number of optimization steps\n", + "\n", + "# Optimization loop\n", + "for i in range(n_steps):\n", + " params = optimize_step(params, y0, t0, t1, target_max)\n", + " if i % 100 == 0:\n", + " current_loss = loss_fn(params, y0, t0, t1, target_max)\n", + " print(f\"Step {i}, Loss: {current_loss}, Params: {params}\")\n", + "\n", + "print(f\"Optimized Params: {params}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}