Skip to content

Commit

Permalink
bug: eq / ka
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Nov 14, 2024
1 parent 9e7d84e commit 566f11d
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 205 deletions.
156 changes: 87 additions & 69 deletions notebooks/23_Monte_Carlo_adaptability_2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[cuda(id=0), cuda(id=1)]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
"ename": "ModuleNotFoundError",
"evalue": "No module named 'jax'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mjnp\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'jax'"
]
}
],
"source": [
Expand Down Expand Up @@ -85,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -171,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -203,7 +204,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run_mc = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -492,50 +502,52 @@
" stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1, choice=stepsize_controller)))\n",
"\n",
"curr_en = energies\n",
"for step in range(total_steps):\n",
"\n",
" print(f'\\n\\nStarting iteration {step+1} out of {total_steps}\\n\\n')\n",
"if run_mc:\n",
" for step in range(total_steps):\n",
"\n",
" curr_eq = jax.vmap(\n",
" partial(equilibrium_constant_reparameterisation, initial=N0))(curr_en)\n",
" _, curr_rt = eqconstant_to_rates(curr_eq, k_a)\n",
" print(f'\\n\\nStarting iteration {step+1} out of {total_steps}\\n\\n')\n",
"\n",
" ys0, ts0, ys1, ts1 = simulate(\n",
" y00, curr_rt, sim_func, t0, t1, tmax, batch_size, threshold_steady_state)\n",
" next_idxs, adaptability, sensitivity, precision = choose_next(sol=(ys0, ys1), idxs_signal=idxs_signal, idxs_output=idxs_output,\n",
" use_sensitivity_func1=use_sensitivity_func1, choose_max=choose_max, \n",
" total_samples=total_samples, diversity=diversity)\n",
" print(f'Choosing {len(next_idxs)} next circuits')\n",
" if len(next_idxs) < choose_max:\n",
" print('Not enough circuits chosen, will randomly choose the rest')\n",
" idxs_rnd = choose_next_rnd(choose_max - len(next_idxs), total_samples)\n",
" next_idxs = jnp.concatenate([next_idxs, idxs_rnd])\n",
" \n",
" if np.mod(step, int(total_steps/5)) == 0:\n",
" plt.figure(figsize=(13, 5))\n",
" ax = plt.subplot(1, 2, 1)\n",
" sns.scatterplot(x=sensitivity[..., idxs_output].flatten(), y=precision[..., idxs_output].flatten(), hue=adaptability[..., idxs_output].flatten(), alpha=0.2)\n",
" plt.xscale('log')\n",
" plt.yscale('log')\n",
" ax = plt.subplot(1, 2, 2)\n",
" sns.histplot(x=sensitivity[:, idxs_output].flatten(), y=precision[:, idxs_output].flatten(), bins=50, log_scale=[True, True])\n",
" plt.suptitle(f'Step {step}')\n",
" curr_eq = jax.vmap(\n",
" partial(equilibrium_constant_reparameterisation, initial=N0))(curr_en)\n",
" _, curr_rt = eqconstant_to_rates(curr_eq, k_a)\n",
"\n",
" ys0, ts0, ys1, ts1 = simulate(\n",
" y00, curr_rt, sim_func, t0, t1, tmax, batch_size, threshold_steady_state)\n",
" next_idxs, adaptability, sensitivity, precision = choose_next(sol=(ys0, ys1), idxs_signal=idxs_signal, idxs_output=idxs_output,\n",
" use_sensitivity_func1=use_sensitivity_func1, choose_max=choose_max, \n",
" total_samples=total_samples, diversity=diversity)\n",
" print(f'Choosing {len(next_idxs)} next circuits')\n",
" if len(next_idxs) < choose_max:\n",
" print('Not enough circuits chosen, will randomly choose the rest')\n",
" idxs_rnd = choose_next_rnd(choose_max - len(next_idxs), total_samples)\n",
" next_idxs = jnp.concatenate([next_idxs, idxs_rnd])\n",
" \n",
" if np.mod(step, int(total_steps/5)) == 0:\n",
" plt.figure(figsize=(13, 5))\n",
" ax = plt.subplot(1, 2, 1)\n",
" sns.scatterplot(x=sensitivity[..., idxs_output].flatten(), y=precision[..., idxs_output].flatten(), hue=adaptability[..., idxs_output].flatten(), alpha=0.2)\n",
" plt.xscale('log')\n",
" plt.yscale('log')\n",
" ax = plt.subplot(1, 2, 2)\n",
" sns.histplot(x=sensitivity[:, idxs_output].flatten(), y=precision[:, idxs_output].flatten(), bins=50, log_scale=[True, True])\n",
" plt.suptitle(f'Step {step}')\n",
"\n",
"\n",
" # Save results\n",
" all_params_en[step] = curr_en\n",
" all_params_eq[step] = curr_eq\n",
" all_params_rt[step] = curr_rt\n",
" all_is_parent[step][next_idxs] = True\n",
" all_adaptability[step] = adaptability\n",
" all_sensitivity[step] = sensitivity\n",
" all_precision[step] = precision\n",
" # Save results\n",
" all_params_en[step] = curr_en\n",
" all_params_eq[step] = curr_eq\n",
" all_params_rt[step] = curr_rt\n",
" all_is_parent[step][next_idxs] = True\n",
" all_adaptability[step] = adaptability\n",
" all_sensitivity[step] = sensitivity\n",
" all_precision[step] = precision\n",
"\n",
" # Mutate energies\n",
" next_en = mutate_expand(\n",
" curr_en[next_idxs], n_samples_per_parent, mutation_scale)[:total_samples]\n",
" print(f'Mutated and expanding {len(next_idxs)} into {len(next_en)} next circuits')\n",
" curr_en = next_en"
" # Mutate energies\n",
" next_en = mutate_expand(\n",
" curr_en[next_idxs], n_samples_per_parent, mutation_scale)[:total_samples]\n",
" print(f'Mutated and expanding {len(next_idxs)} into {len(next_en)} next circuits')\n",
" curr_en = next_en"
]
},
{
Expand Down Expand Up @@ -565,14 +577,15 @@
}
],
"source": [
"plt.figure(figsize=(13, 5))\n",
"ax = plt.subplot(1, 2, 1)\n",
"sns.scatterplot(x=sensitivity[..., idxs_output].flatten(), y=precision[..., idxs_output].flatten(), hue=adaptability[..., idxs_output].flatten(), alpha=0.2)\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"ax = plt.subplot(1, 2, 2)\n",
"sns.histplot(x=sensitivity[:, idxs_output].flatten(), y=precision[:, idxs_output].flatten(), bins=50, log_scale=[True, True])\n",
"plt.suptitle(f'Step {step}')\n"
"if run_mc:\n",
" plt.figure(figsize=(13, 5))\n",
" ax = plt.subplot(1, 2, 1)\n",
" sns.scatterplot(x=sensitivity[..., idxs_output].flatten(), y=precision[..., idxs_output].flatten(), hue=adaptability[..., idxs_output].flatten(), alpha=0.2)\n",
" plt.xscale('log')\n",
" plt.yscale('log')\n",
" ax = plt.subplot(1, 2, 2)\n",
" sns.histplot(x=sensitivity[:, idxs_output].flatten(), y=precision[:, idxs_output].flatten(), bins=50, log_scale=[True, True])\n",
" plt.suptitle(f'Step {step}')\n"
]
},
{
Expand All @@ -584,20 +597,24 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"d = pd.DataFrame(data={'Adaptability': all_adaptability[..., idxs_output].flatten(), \n",
" 'Sensitivity': all_sensitivity[..., idxs_output].flatten(), \n",
" 'Precision': all_precision[..., idxs_output].flatten(), \n",
" 'Is parent circuit': np.repeat(all_is_parent.flatten(), repeats=len(species_output)),\n",
" 'Circuit idx': np.repeat(np.repeat(np.arange(total_samples), repeats=len(species_output)), repeats=total_steps),\n",
" 'Species': flatten_listlike([[s] * total_samples for s in species_output] * total_steps),\n",
" 'Step': np.repeat(np.arange(total_steps), repeats=total_samples*len(species_output)),\n",
" 'Params energy': [l.tolist() for l in np.repeat(all_params_en.flatten(), repeats=len(species_output)).reshape(-1, energies.shape[-1])],\n",
" 'Params equilibrium constants': [l.tolist() for l in np.repeat(all_params_eq.flatten(), repeats=len(species_output)).reshape(-1, energies.shape[-1])],\n",
" 'Params rates': [l.tolist() for l in np.repeat(all_params_rt.flatten(), repeats=len(species_output)).reshape(-1, energies.shape[-1])]})"
"fn_save = f'results/23_Monte_Carlo_adaptability_2/23_Monte_Carlo_adaptability_n{total_samples}.csv'\n",
"if os.path.exists(fn_save):\n",
" d = pd.read_csv(fn_save)\n",
"else:\n",
" d = pd.DataFrame(data={'Adaptability': all_adaptability[..., idxs_output].flatten(), \n",
" 'Sensitivity': all_sensitivity[..., idxs_output].flatten(), \n",
" 'Precision': all_precision[..., idxs_output].flatten(), \n",
" 'Is parent circuit': np.repeat(all_is_parent.flatten(), repeats=len(species_output)),\n",
" 'Circuit idx': np.repeat(np.repeat(np.arange(total_samples), repeats=len(species_output)), repeats=total_steps),\n",
" 'Species': flatten_listlike([[s] * total_samples for s in species_output] * total_steps),\n",
" 'Step': np.repeat(np.arange(total_steps), repeats=total_samples*len(species_output)),\n",
" 'Params energy': [l.tolist() for l in np.repeat(all_params_en.flatten(), repeats=len(species_output)).reshape(-1, energies.shape[-1])],\n",
" 'Params equilibrium constants': [l.tolist() for l in np.repeat(all_params_eq.flatten(), repeats=len(species_output)).reshape(-1, energies.shape[-1])],\n",
" 'Params rates': [l.tolist() for l in np.repeat(all_params_rt.flatten(), repeats=len(species_output)).reshape(-1, energies.shape[-1])]})"
]
},
{
Expand All @@ -606,7 +623,8 @@
"metadata": {},
"outputs": [],
"source": [
"d.sort_values(by='Adaptability', ascending=False).to_csv(f'results/23_Monte_Carlo_adaptability_2/23_Monte_Carlo_adaptability_n{total_samples}.csv', index=False)\n"
"if not os.path.exists(fn_save):\n",
" d.sort_values(by='Adaptability', ascending=False).to_csv(fn_save, index=False)\n"
]
},
{
Expand Down
Loading

0 comments on commit 566f11d

Please sign in to comment.