Skip to content

Commit

Permalink
new MC notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 27, 2024
1 parent 60b3b80 commit 8d43ecb
Show file tree
Hide file tree
Showing 4 changed files with 376 additions and 57 deletions.
33 changes: 1 addition & 32 deletions notebooks/22_adaptation_random_sim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"from synbio_morpher.srv.parameter_prediction.simulator import make_piecewise_stepcontrol\n",
"from synbio_morpher.utils.misc.type_handling import flatten_listlike\n",
"from synbio_morpher.utils.modelling.physical import eqconstant_to_rates, equilibrium_constant_reparameterisation\n",
"from synbio_morpher.utils.results.analytics.timeseries import calculate_adaptation, compute_sensitivity, compute_sensitivity2, compute_precision, compute_peaks"
"from synbio_morpher.utils.results.analytics.timeseries import calculate_adaptation, compute_adaptability_full, compute_peaks"
]
},
{
Expand Down Expand Up @@ -346,37 +346,6 @@
"metadata": {},
"outputs": [],
"source": [
"def compute_adaptability_full(ys_steady, ys_signal, idx_sig, use_sensitivity_func1):\n",
" \"\"\" ts: time series with dimensions [t, species] \"\"\"\n",
"\n",
" if use_sensitivity_func1:\n",
" peaks = compute_peaks(ys_steady[-1], ys_signal[-1],\n",
" ys_signal.max(axis=0), ys_signal.min(axis=0))\n",
"\n",
" s = compute_sensitivity(\n",
" signal_idx=idx_sig,\n",
" starting_states=ys_steady[-1],\n",
" peaks=peaks,\n",
" )\n",
" else:\n",
" s = compute_sensitivity2(\n",
" starting_states=ys_steady[-1],\n",
" minv=ys_signal.min(axis=0),\n",
" maxv=ys_signal.max(axis=0),\n",
" signal_0=ys_steady[-1, idx_sig],\n",
" signal_1=ys_signal[0, idx_sig]\n",
" )\n",
" p = compute_precision(\n",
" starting_states=ys_steady[-1],\n",
" steady_states=ys_signal[-1],\n",
" signal_0=ys_steady[-1, idx_sig],\n",
" signal_1=ys_signal[0, idx_sig])\n",
" a = calculate_adaptation(s, p)\n",
" # a = jnp.log(a)\n",
" \n",
" return a, s, p\n",
"\n",
"\n",
"adaptability, sensitivity, precision = jax.vmap(partial(compute_adaptability_full, idx_sig=idxs_signal[0], use_sensitivity_func1=use_sensitivity_func1))(\n",
" sol_steady_states.ys, sol_signal.ys)\n",
"\n",
Expand Down
330 changes: 330 additions & 0 deletions notebooks/23_Monte_Carlo_adaptability_2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,330 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import diffrax as dfx\n",
"from typing import List\n",
"\n",
"from functools import partial\n",
"import os\n",
"import sys\n",
"\n",
"jax.config.update('jax_platform_name', 'gpu')\n",
"\n",
"if __package__ is None:\n",
"\n",
" module_path = os.path.abspath(os.path.join('..'))\n",
" sys.path.append(module_path)\n",
"\n",
" __package__ = os.path.basename(module_path)\n",
"\n",
"\n",
"np.random.seed(0)\n",
"jax.devices()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from synbio_morpher.srv.parameter_prediction.simulator import make_piecewise_stepcontrol\n",
"from synbio_morpher.utils.misc.type_handling import flatten_listlike\n",
"from synbio_morpher.utils.modelling.physical import eqconstant_to_rates, equilibrium_constant_reparameterisation\n",
"from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded\n",
"from synbio_morpher.utils.modelling.solvers import get_diffrax_solver, make_stepsize_controller\n",
"from synbio_morpher.utils.results.analytics.timeseries import calculate_adaptation, compute_peaks, compute_adaptability_full\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set up test circuits"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_species_bound(species_unbound):\n",
" return sorted(set(flatten_listlike([['-'.join(sorted([x, y])) for x in species_unbound] for y in species_unbound])))\n",
"\n",
"\n",
"# RNA circuit settings\n",
"species_unbound = ['RNA_0', 'RNA_1', 'RNA_2']\n",
"species_bound = make_species_bound(species_unbound)\n",
"species = species_unbound + species_bound\n",
"species_signal = ['RNA_0']\n",
"species_output = ['RNA_2']\n",
"species_nonsignal = [s for s in species_unbound if s not in species_signal]\n",
"idxs_signal = np.array([species.index(s) for s in species_signal])\n",
"idxs_output = np.array([species.index(s) for s in species_output])\n",
"idxs_unbound = np.array([species.index(s) for s in species_unbound])\n",
"idxs_bound = np.array([species.index(s) for s in species_bound])\n",
"signal_onehot = np.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])\n",
"\n",
"# Initial parameters\n",
"n_circuits = 1000\n",
"n_circuits_display = 30\n",
"k_a = 0.00150958097\n",
"N0 = 200\n",
"y00 = np.array([[N0, N0, N0, 0, 0, 0, 0, 0, 0]]).astype(np.float32)\n",
"y00 = np.repeat(y00, repeats=n_circuits, axis=0)\n",
"\n",
"# Dynamic Simulation parameters\n",
"signal_target = 2\n",
"t0 = 0\n",
"t1 = 200\n",
"ts = np.linspace(t0, t1, 500)\n",
"dt0 = 0.0005555558569638981\n",
"dt1_factor = 5\n",
"dt1 = dt0 * dt1_factor\n",
"max_steps = 16**4 * 10\n",
"use_sensitivity_func1 = False\n",
"sim_method = 'Dopri5'\n",
"stepsize_controller = 'adaptive'\n",
"\n",
"# MC parameters\n",
"total_steps = 10\n",
"\n",
"# Reactions\n",
"energies = np.random.rand(n_circuits, len(species_unbound), len(species_unbound))\n",
"energies = np.interp(energies, (energies.min(), energies.max()), (-25, 0))\n",
"energies[np.tril_indices(len(species_unbound))] = energies[np.triu_indices(len(species_unbound))]\n",
"eqconstants = jax.vmap(equilibrium_constant_reparameterisation)(energies, y00[:, idxs_unbound])\n",
"forward_rates, reverse_rates = eqconstant_to_rates(eqconstants, k_a)\n",
"forward_rates = np.array(list(map(lambda r: r[np.triu_indices(len(species_unbound))], forward_rates)))\n",
"reverse_rates = np.array(list(map(lambda r: r[np.triu_indices(len(species_unbound))], reverse_rates)))\n",
"\n",
"inputs = np.array([\n",
" [2, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 1, 0, 0, 0, 0, 0, 0],\n",
" [0, 2, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 1, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 2, 0, 0, 0, 0, 0, 0],\n",
"], dtype=np.float64)\n",
"outputs = np.array([\n",
" [0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 1, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 1],\n",
"], dtype=np.float64)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Initialise simulations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sim_func = jax.jit(jax.vmap(\n",
" partial(bioreaction_sim_dfx_expanded,\n",
" t0=t0, t1=t1, dt0=dt0,\n",
" forward_rates=forward_rates,\n",
" inputs=inputs,\n",
" outputs=outputs,\n",
" solver=get_diffrax_solver(\n",
" sim_method),\n",
" saveat=dfx.SaveAt(\n",
" ts=jnp.linspace(t0, t1, 500)), # int(np.min([500, t1-t0]))))\n",
" stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1,\n",
" choice=stepsize_controller)\n",
" )))\n",
"sol_steady_states = jax.vmap(bioreaction_sim_dfx_expanded)(y00, reverse_rates)\n",
"\n",
"y01 = np.array(sol_steady_states.ys[:, -1])\n",
"y01[:, np.array(idxs_signal)] = y01[:, np.array(idxs_signal)] * signal_target\n",
"sol_signal = jax.vmap(bioreaction_sim_dfx_expanded)(y01, reverse_rates)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adaptability, sensitivity, precision = jax.vmap(partial(compute_adaptability_full, idx_sig=idxs_signal[0], use_sensitivity_func1=use_sensitivity_func1))(\n",
" sol_steady_states.ys, sol_signal.ys)\n",
"\n",
"sensitivity = np.array(sensitivity)\n",
"precision = np.array(precision)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Monte Carlo iterations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def choose_next(batch: list, data_writer, distance_func, choose_max: int = 4, target_species: List[str] = ['RNA_1', 'RNA_2'], use_diversity: bool = False):\n",
" \n",
" def make_data(batch, batch_analytics, target_species: List[str]):\n",
" d = pd.DataFrame(\n",
" data=np.concatenate(\n",
" [\n",
" np.asarray([c.name for c in batch])[:, None],\n",
" np.asarray([c.subname for c in batch])[:, None]\n",
" ], axis=1\n",
" ),\n",
" columns=['Name', 'Subname']\n",
" )\n",
" d['Circuit Obj'] = batch\n",
" species_names = [s.name for s in batch[0].model.species]\n",
" t_idxs = {s: species_names.index(s) for s in species_names if s in target_species}\n",
" for t in target_species:\n",
" t_idx = t_idxs[t]\n",
" d[f'Sensitivity species-{t}'] = np.asarray([b['sensitivity_wrt_species-6'][t_idx] for b in batch_analytics])\n",
" d[f'Precision species-{t}'] = np.asarray([b['precision_wrt_species-6'][t_idx] for b in batch_analytics])\n",
" d[f'Overshoot species-{t}'] = np.asarray([b['overshoot'][t_idx] for b in batch_analytics])\n",
" \n",
" rs = d[d['Subname'] == 'ref_circuit']\n",
" d[f'Parent Sensitivity species-{t}'] = jax.tree_util.tree_map(lambda n: rs[rs['Name'] == n][f'Sensitivity species-{t}'].iloc[0], d['Name'].to_list())\n",
" d[f'Parent Precision species-{t}'] = jax.tree_util.tree_map(lambda n: rs[rs['Name'] == n][f'Precision species-{t}'].iloc[0], d['Name'].to_list())\n",
" \n",
" d[f'dS species-{t}'] = np.asarray([b['sensitivity_wrt_species-6_diff_to_base_circuit'][t_idx] for b in batch_analytics])\n",
" d[f'dP species-{t}'] = np.asarray([b['precision_wrt_species-6_diff_to_base_circuit'][t_idx] for b in batch_analytics])\n",
" # d[f'dS species-{t}'] = d[f'Sensitivity species-{t}'] - d[f'Parent Sensitivity species-{t}']\n",
" # d[f'dP species-{t}'] = d[f'Precision species-{t}'] - d[f'Parent Precision species-{t}']\n",
" \n",
" # d[f'Diag Distance species-{t}'] = distance_func(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy())\n",
" d[f'SP Prod species-{t}'] = sp_prod(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy(), \n",
" sp_factor=1, #(d[f'Precision species-{t}'] / d[f'Sensitivity species-{t}']).max(), \n",
" s_weight=0) #np.log(d[f'Precision species-{t}']) / d[f'Sensitivity species-{t}'])\n",
" d[f'Log Distance species-{t}'] = np.array(log_distance(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy()))\n",
" # d[f'SP and distance species-{t}'] = np.log( np.power(d[f'Log Distance species-{t}'], dist_weight) * np.log(d[f'SP Prod species-{t}']))\n",
" d[f'SP and distance species-{t}'] = d[f'Sensitivity species-{t}'] * d[f'Log Distance species-{t}']\n",
" \n",
" return d\n",
" \n",
" def select_next(data_1, choose_max, t, use_diversity: bool):\n",
" # filt = (data_1[f'dS species-{t}'] >= 0) & (data_1[f'dP species-{t}'] >= 0) & (\n",
" # data_1[f'Sensitivity species-{t}'] >= data_1[data_1['Subname'] == 'ref_circuit'][f'Sensitivity species-{t}'].min()) & (\n",
" # data_1[f'Precision species-{t}'] >= data_1[data_1['Subname'] == 'ref_circuit'][f'Precision species-{t}'].min())\n",
" \n",
" data_1['Diversity selection'] = False\n",
" circuits_chosen = data_1.sort_values(\n",
" by=[f'SP and distance species-{t}', f'Log Distance species-{t}', f'SP Prod species-{t}', 'Name', 'Subname'], ascending=False)['Circuit Obj'].iloc[:choose_max].to_list()\n",
" prev_circuits = data_1[data_1['Subname'] == 'ref_circuit']\n",
" keep_n = int(0.7 * choose_max)\n",
" if use_diversity and all([c in prev_circuits for c in circuits_chosen]) and (len(data_1) >= keep_n):\n",
" _, circuits_chosen = select_next(data_1[data_1['Circuit Obj'].isin(prev_circuits[:keep_n])], choose_max, t)\n",
" data_1['Diversity selection'] = data_1['Circuit Obj'].isin(circuits_chosen)\n",
" \n",
" data_1['Next selected'] = data_1['Circuit Obj'].isin(circuits_chosen)\n",
" return data_1, circuits_chosen\n",
" \n",
" def get_batch_analytics(batch, data_writer):\n",
" batch_analytics = []\n",
" for c in batch:\n",
" if c.subname == 'ref_circuit':\n",
" batch_analytics.append(\n",
" load_json_as_dict(os.path.join(data_writer.top_write_dir, c.name, 'report_signal.json')))\n",
" else:\n",
" batch_analytics.append(\n",
" load_json_as_dict(os.path.join(data_writer.top_write_dir, c.name, 'mutations', c.subname, 'report_signal.json'))\n",
" )\n",
" batch_analytics = jax.tree_util.tree_map(lambda x: np.float64(x), batch_analytics)\n",
" return batch_analytics\n",
" \n",
" batch_analytics = get_batch_analytics(batch, data_writer)\n",
" data_1 = make_data(batch, batch_analytics, target_species)\n",
" \n",
" t = target_species[0]\n",
" # circuits_chosen = data_1[(data_1[f'dS species-{t}'] >= 0) & (data_1[f'dP species-{t}'] >= 0)].sort_values(by=[f'Sensitivity species-{t}', f'Precision species-{t}'], ascending=False)['Circuit Obj'].iloc[:choose_max].to_list()\n",
" data_1, circuits_chosen = select_next(data_1, choose_max, t, use_diversity)\n",
" return circuits_chosen, data_1\n",
"\n",
"\n",
"def mutate(circuits):\n",
" return circuits\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"for step in range(total_steps):\n",
" \n",
" print(f'\\n\\nStarting step {step+1} out of {total_steps}\\n\\n')\n",
"\n",
" batch = mutate(starting, evolver, algorithm=config['mutations_args']['algorithm'])\n",
" batch = simulate(batch, modeller, config)\n",
" starting, summary_data = choose_next(batch=expanded_batchs, data_writer=data_writer, distance_func=distance_func, \n",
" choose_max=choose_max, target_species=target_species, use_diversity=config.get('use_diversity', False))\n",
" starting = process_for_next_run(starting, data_writer=data_writer)\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 2 additions & 2 deletions synbio_morpher/utils/modelling/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Union
from typing import Union, Optional
import numpy as np
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -104,7 +104,7 @@ def bioreaction_sim_wrapper(qreactions: QuantifiedReactions, t0, t1, dt0,

def bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
inputs, outputs, forward_rates, reverse_rates,
signal, signal_onehot: Union[int, np.ndarray],
signal=None, signal_onehot: Optional[Union[int, np.ndarray]]=None,
solver=dfx.Tsit5(),
saveat=dfx.SaveAt(
t0=True, t1=True, steps=True),
Expand Down
Loading

0 comments on commit 8d43ecb

Please sign in to comment.