Skip to content

Commit

Permalink
mutating rates
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 29, 2024
1 parent 439f34b commit 63160da
Showing 1 changed file with 29 additions and 66 deletions.
95 changes: 29 additions & 66 deletions notebooks/23_Monte_Carlo_adaptability_2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,18 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
Expand All @@ -24,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -33,7 +42,7 @@
"[cuda(id=0)]"
]
},
"execution_count": 2,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -68,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -89,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -114,9 +123,9 @@
"k_a = 0.00150958097\n",
"signal_target = 2\n",
"t0 = 0\n",
"t1 = 300\n",
"t1 = 10\n",
"ts = np.linspace(t0, t1, 500)\n",
"tmax = 700\n",
"tmax = 10\n",
"dt0 = 0.0005555558569638981\n",
"dt1_factor = 5\n",
"dt1 = dt0 * dt1_factor\n",
Expand Down Expand Up @@ -173,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -198,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -217,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -236,10 +245,14 @@
"\n",
"\n",
"def mutate(parents: jnp.ndarray, n_samples_per_parent, mutation_scale):\n",
" min_rate = parents.min()\n",
" # Generate mutated samples from each parent\n",
" mutated = jax.tree_util.tree_map(\n",
" lambda x: x + x * mutation_scale * np.random.randn(n_samples_per_parent, *x.shape), parents)\n",
" return mutated.reshape(mutated.shape[2], mutated.shape[0] * mutated.shape[1], *mutated.shape[3:])\n",
" mutated = jnp.power(10,\n",
" jax.tree_util.tree_map(\n",
" lambda x: x + x * mutation_scale * np.random.randn(n_samples_per_parent, *x.shape), jnp.log10(parents)))\n",
" mutated_nonzero = jnp.where(mutated < min_rate, min_rate, mutated)\n",
"\n",
" return mutated_nonzero.reshape(mutated_nonzero.shape[0] * mutated_nonzero.shape[1], *mutated_nonzero.shape[2:])\n",
"\n",
"\n",
"def simulate(y00, reverse_rates, sim_func, t0, t1, tmax, threshold):\n",
Expand All @@ -265,17 +278,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"next_starting, adaptability, sensitivity, precision = choose_next(params=gen_y, sol=(ys0, ys1), idxs_signal=idxs_signal, idxs_output=idxs_output,\n",
" use_sensitivity_func1=use_sensitivity_func1, choose_max=choose_max, n_samples_per_parent=n_samples_per_parent)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand All @@ -286,33 +288,14 @@
"Starting iteration 1 out of 10\n",
"\n",
"\n",
"Steady states: 300 iterations. 163 left to steady out. 0:00:19.037496\n",
"Steady states: 600 iterations. 81 left to steady out. 0:00:37.313892\n",
"Done: 0:00:55.538999\n",
"Steady states: 300 iterations. 17 left to steady out. 0:00:17.923077\n",
"Steady states: 600 iterations. 9 left to steady out. 0:00:34.969756\n",
"Done: 0:00:52.138873\n",
"Done: 0:00:01.666165\n",
"Done: 0:00:01.609551\n",
"\n",
"\n",
"Starting iteration 2 out of 10\n",
"\n",
"\n"
]
},
{
"ename": "ValueError",
"evalue": "vmap got inconsistent sizes for array axes to be mapped:\n * one axis had size 100: axis 0 of argument y0 of type float32[100,9];\n * one axis had size 2: axis 0 of argument reverse_rates of type float32[2,100,6]",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 23\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(total_steps):\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mStarting iteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstep\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m out of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtotal_steps\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 23\u001b[0m ys0, ts0, ys1, ts1 \u001b[38;5;241m=\u001b[39m \u001b[43msimulate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43my00\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgen_y\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msim_func\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtmax\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthreshold_steady_state\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 25\u001b[0m next_starting, adaptability, sensitivity, precision \u001b[38;5;241m=\u001b[39m choose_next(params\u001b[38;5;241m=\u001b[39mgen_y, sol\u001b[38;5;241m=\u001b[39m(ys0, ys1), idxs_signal\u001b[38;5;241m=\u001b[39midxs_signal, idxs_output\u001b[38;5;241m=\u001b[39midxs_output,\n\u001b[1;32m 26\u001b[0m use_sensitivity_func1\u001b[38;5;241m=\u001b[39muse_sensitivity_func1, choose_max\u001b[38;5;241m=\u001b[39mchoose_max, n_samples_per_parent\u001b[38;5;241m=\u001b[39mn_samples_per_parent)\n\u001b[1;32m 27\u001b[0m gen_z \u001b[38;5;241m=\u001b[39m mutate(next_starting, n_samples_per_parent, mutation_scale)\n",
"Cell \u001b[0;32mIn[7], line 21\u001b[0m, in \u001b[0;36msimulate\u001b[0;34m(y00, reverse_rates, sim_func, t0, t1, tmax, threshold)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msimulate\u001b[39m(y00, reverse_rates, sim_func, t0, t1, tmax, threshold):\n\u001b[0;32m---> 21\u001b[0m ys0, ts0 \u001b[38;5;241m=\u001b[39m \u001b[43msimulate_steady_states\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43my00\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtotal_time\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtmax\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msim_func\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msim_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mt0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mthreshold\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mthreshold\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mreverse_rates\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreverse_rates\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 27\u001b[0m y01 \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(ys0[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 28\u001b[0m y01[:, np\u001b[38;5;241m.\u001b[39marray(idxs_signal)] \u001b[38;5;241m=\u001b[39m y01[:, np\u001b[38;5;241m.\u001b[39marray(\n\u001b[1;32m 29\u001b[0m idxs_signal)] \u001b[38;5;241m*\u001b[39m signal_target\n",
"File \u001b[0;32m/workdir/synbio_morpher/utils/modelling/solvers.py:107\u001b[0m, in \u001b[0;36msimulate_steady_states\u001b[0;34m(y0, total_time, sim_func, t0, t1, threshold, disable_logging, **sim_kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 105\u001b[0m y00 \u001b[38;5;241m=\u001b[39m ys[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :]\n\u001b[0;32m--> 107\u001b[0m ts, ys \u001b[38;5;241m=\u001b[39m \u001b[43msim_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43my00\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msim_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m np\u001b[38;5;241m.\u001b[39msum(np\u001b[38;5;241m.\u001b[39margmax(ts \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39minf)) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 110\u001b[0m ys \u001b[38;5;241m=\u001b[39m ys[:, :np\u001b[38;5;241m.\u001b[39margmax(ts \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39minf), :]\n",
" \u001b[0;31m[... skipping hidden 2 frame]\u001b[0m\n",
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/api.py:1296\u001b[0m, in \u001b[0;36m_mapped_axis_size\u001b[0;34m(fn, tree, vals, dims, name)\u001b[0m\n\u001b[1;32m 1294\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1295\u001b[0m msg\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m * some axes (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mct\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m of them) had size \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msz\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, e.g. axis \u001b[39m\u001b[38;5;132;01m{\u001b[39;00max\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mex\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m;\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 1296\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(msg)[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m])\n",
"\u001b[0;31mValueError\u001b[0m: vmap got inconsistent sizes for array axes to be mapped:\n * one axis had size 100: axis 0 of argument y0 of type float32[100,9];\n * one axis had size 2: axis 0 of argument reverse_rates of type float32[2,100,6]"
]
}
],
"source": [
Expand Down Expand Up @@ -350,26 +333,6 @@
" precision_all[step] = precision"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20, 2, 6)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"next_starting.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
Expand Down

0 comments on commit 63160da

Please sign in to comment.