-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
376 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,330 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Imports" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import pandas as pd\n", | ||
"import seaborn as sns\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import diffrax as dfx\n", | ||
"from typing import List\n", | ||
"\n", | ||
"from functools import partial\n", | ||
"import os\n", | ||
"import sys\n", | ||
"\n", | ||
"jax.config.update('jax_platform_name', 'gpu')\n", | ||
"\n", | ||
"if __package__ is None:\n", | ||
"\n", | ||
" module_path = os.path.abspath(os.path.join('..'))\n", | ||
" sys.path.append(module_path)\n", | ||
"\n", | ||
" __package__ = os.path.basename(module_path)\n", | ||
"\n", | ||
"\n", | ||
"np.random.seed(0)\n", | ||
"jax.devices()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from synbio_morpher.srv.parameter_prediction.simulator import make_piecewise_stepcontrol\n", | ||
"from synbio_morpher.utils.misc.type_handling import flatten_listlike\n", | ||
"from synbio_morpher.utils.modelling.physical import eqconstant_to_rates, equilibrium_constant_reparameterisation\n", | ||
"from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded\n", | ||
"from synbio_morpher.utils.modelling.solvers import get_diffrax_solver, make_stepsize_controller\n", | ||
"from synbio_morpher.utils.results.analytics.timeseries import calculate_adaptation, compute_peaks, compute_adaptability_full\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Set up test circuits" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def make_species_bound(species_unbound):\n", | ||
" return sorted(set(flatten_listlike([['-'.join(sorted([x, y])) for x in species_unbound] for y in species_unbound])))\n", | ||
"\n", | ||
"\n", | ||
"# RNA circuit settings\n", | ||
"species_unbound = ['RNA_0', 'RNA_1', 'RNA_2']\n", | ||
"species_bound = make_species_bound(species_unbound)\n", | ||
"species = species_unbound + species_bound\n", | ||
"species_signal = ['RNA_0']\n", | ||
"species_output = ['RNA_2']\n", | ||
"species_nonsignal = [s for s in species_unbound if s not in species_signal]\n", | ||
"idxs_signal = np.array([species.index(s) for s in species_signal])\n", | ||
"idxs_output = np.array([species.index(s) for s in species_output])\n", | ||
"idxs_unbound = np.array([species.index(s) for s in species_unbound])\n", | ||
"idxs_bound = np.array([species.index(s) for s in species_bound])\n", | ||
"signal_onehot = np.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])\n", | ||
"\n", | ||
"# Initial parameters\n", | ||
"n_circuits = 1000\n", | ||
"n_circuits_display = 30\n", | ||
"k_a = 0.00150958097\n", | ||
"N0 = 200\n", | ||
"y00 = np.array([[N0, N0, N0, 0, 0, 0, 0, 0, 0]]).astype(np.float32)\n", | ||
"y00 = np.repeat(y00, repeats=n_circuits, axis=0)\n", | ||
"\n", | ||
"# Dynamic Simulation parameters\n", | ||
"signal_target = 2\n", | ||
"t0 = 0\n", | ||
"t1 = 200\n", | ||
"ts = np.linspace(t0, t1, 500)\n", | ||
"dt0 = 0.0005555558569638981\n", | ||
"dt1_factor = 5\n", | ||
"dt1 = dt0 * dt1_factor\n", | ||
"max_steps = 16**4 * 10\n", | ||
"use_sensitivity_func1 = False\n", | ||
"sim_method = 'Dopri5'\n", | ||
"stepsize_controller = 'adaptive'\n", | ||
"\n", | ||
"# MC parameters\n", | ||
"total_steps = 10\n", | ||
"\n", | ||
"# Reactions\n", | ||
"energies = np.random.rand(n_circuits, len(species_unbound), len(species_unbound))\n", | ||
"energies = np.interp(energies, (energies.min(), energies.max()), (-25, 0))\n", | ||
"energies[np.tril_indices(len(species_unbound))] = energies[np.triu_indices(len(species_unbound))]\n", | ||
"eqconstants = jax.vmap(equilibrium_constant_reparameterisation)(energies, y00[:, idxs_unbound])\n", | ||
"forward_rates, reverse_rates = eqconstant_to_rates(eqconstants, k_a)\n", | ||
"forward_rates = np.array(list(map(lambda r: r[np.triu_indices(len(species_unbound))], forward_rates)))\n", | ||
"reverse_rates = np.array(list(map(lambda r: r[np.triu_indices(len(species_unbound))], reverse_rates)))\n", | ||
"\n", | ||
"inputs = np.array([\n", | ||
" [2, 0, 0, 0, 0, 0, 0, 0, 0],\n", | ||
" [1, 1, 0, 0, 0, 0, 0, 0, 0],\n", | ||
" [1, 0, 1, 0, 0, 0, 0, 0, 0],\n", | ||
" [0, 2, 0, 0, 0, 0, 0, 0, 0],\n", | ||
" [0, 1, 1, 0, 0, 0, 0, 0, 0],\n", | ||
" [0, 0, 2, 0, 0, 0, 0, 0, 0],\n", | ||
"], dtype=np.float64)\n", | ||
"outputs = np.array([\n", | ||
" [0, 0, 0, 1, 0, 0, 0, 0, 0],\n", | ||
" [0, 0, 0, 0, 1, 0, 0, 0, 0],\n", | ||
" [0, 0, 0, 0, 0, 1, 0, 0, 0],\n", | ||
" [0, 0, 0, 0, 0, 0, 1, 0, 0],\n", | ||
" [0, 0, 0, 0, 0, 0, 0, 1, 0],\n", | ||
" [0, 0, 0, 0, 0, 0, 0, 0, 1],\n", | ||
"], dtype=np.float64)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Initialise simulations" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sim_func = jax.jit(jax.vmap(\n", | ||
" partial(bioreaction_sim_dfx_expanded,\n", | ||
" t0=t0, t1=t1, dt0=dt0,\n", | ||
" forward_rates=forward_rates,\n", | ||
" inputs=inputs,\n", | ||
" outputs=outputs,\n", | ||
" solver=get_diffrax_solver(\n", | ||
" sim_method),\n", | ||
" saveat=dfx.SaveAt(\n", | ||
" ts=jnp.linspace(t0, t1, 500)), # int(np.min([500, t1-t0]))))\n", | ||
" stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1,\n", | ||
" choice=stepsize_controller)\n", | ||
" )))\n", | ||
"sol_steady_states = jax.vmap(bioreaction_sim_dfx_expanded)(y00, reverse_rates)\n", | ||
"\n", | ||
"y01 = np.array(sol_steady_states.ys[:, -1])\n", | ||
"y01[:, np.array(idxs_signal)] = y01[:, np.array(idxs_signal)] * signal_target\n", | ||
"sol_signal = jax.vmap(bioreaction_sim_dfx_expanded)(y01, reverse_rates)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"adaptability, sensitivity, precision = jax.vmap(partial(compute_adaptability_full, idx_sig=idxs_signal[0], use_sensitivity_func1=use_sensitivity_func1))(\n", | ||
" sol_steady_states.ys, sol_signal.ys)\n", | ||
"\n", | ||
"sensitivity = np.array(sensitivity)\n", | ||
"precision = np.array(precision)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Monte Carlo iterations" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"def choose_next(batch: list, data_writer, distance_func, choose_max: int = 4, target_species: List[str] = ['RNA_1', 'RNA_2'], use_diversity: bool = False):\n", | ||
" \n", | ||
" def make_data(batch, batch_analytics, target_species: List[str]):\n", | ||
" d = pd.DataFrame(\n", | ||
" data=np.concatenate(\n", | ||
" [\n", | ||
" np.asarray([c.name for c in batch])[:, None],\n", | ||
" np.asarray([c.subname for c in batch])[:, None]\n", | ||
" ], axis=1\n", | ||
" ),\n", | ||
" columns=['Name', 'Subname']\n", | ||
" )\n", | ||
" d['Circuit Obj'] = batch\n", | ||
" species_names = [s.name for s in batch[0].model.species]\n", | ||
" t_idxs = {s: species_names.index(s) for s in species_names if s in target_species}\n", | ||
" for t in target_species:\n", | ||
" t_idx = t_idxs[t]\n", | ||
" d[f'Sensitivity species-{t}'] = np.asarray([b['sensitivity_wrt_species-6'][t_idx] for b in batch_analytics])\n", | ||
" d[f'Precision species-{t}'] = np.asarray([b['precision_wrt_species-6'][t_idx] for b in batch_analytics])\n", | ||
" d[f'Overshoot species-{t}'] = np.asarray([b['overshoot'][t_idx] for b in batch_analytics])\n", | ||
" \n", | ||
" rs = d[d['Subname'] == 'ref_circuit']\n", | ||
" d[f'Parent Sensitivity species-{t}'] = jax.tree_util.tree_map(lambda n: rs[rs['Name'] == n][f'Sensitivity species-{t}'].iloc[0], d['Name'].to_list())\n", | ||
" d[f'Parent Precision species-{t}'] = jax.tree_util.tree_map(lambda n: rs[rs['Name'] == n][f'Precision species-{t}'].iloc[0], d['Name'].to_list())\n", | ||
" \n", | ||
" d[f'dS species-{t}'] = np.asarray([b['sensitivity_wrt_species-6_diff_to_base_circuit'][t_idx] for b in batch_analytics])\n", | ||
" d[f'dP species-{t}'] = np.asarray([b['precision_wrt_species-6_diff_to_base_circuit'][t_idx] for b in batch_analytics])\n", | ||
" # d[f'dS species-{t}'] = d[f'Sensitivity species-{t}'] - d[f'Parent Sensitivity species-{t}']\n", | ||
" # d[f'dP species-{t}'] = d[f'Precision species-{t}'] - d[f'Parent Precision species-{t}']\n", | ||
" \n", | ||
" # d[f'Diag Distance species-{t}'] = distance_func(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy())\n", | ||
" d[f'SP Prod species-{t}'] = sp_prod(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy(), \n", | ||
" sp_factor=1, #(d[f'Precision species-{t}'] / d[f'Sensitivity species-{t}']).max(), \n", | ||
" s_weight=0) #np.log(d[f'Precision species-{t}']) / d[f'Sensitivity species-{t}'])\n", | ||
" d[f'Log Distance species-{t}'] = np.array(log_distance(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy()))\n", | ||
" # d[f'SP and distance species-{t}'] = np.log( np.power(d[f'Log Distance species-{t}'], dist_weight) * np.log(d[f'SP Prod species-{t}']))\n", | ||
" d[f'SP and distance species-{t}'] = d[f'Sensitivity species-{t}'] * d[f'Log Distance species-{t}']\n", | ||
" \n", | ||
" return d\n", | ||
" \n", | ||
" def select_next(data_1, choose_max, t, use_diversity: bool):\n", | ||
" # filt = (data_1[f'dS species-{t}'] >= 0) & (data_1[f'dP species-{t}'] >= 0) & (\n", | ||
" # data_1[f'Sensitivity species-{t}'] >= data_1[data_1['Subname'] == 'ref_circuit'][f'Sensitivity species-{t}'].min()) & (\n", | ||
" # data_1[f'Precision species-{t}'] >= data_1[data_1['Subname'] == 'ref_circuit'][f'Precision species-{t}'].min())\n", | ||
" \n", | ||
" data_1['Diversity selection'] = False\n", | ||
" circuits_chosen = data_1.sort_values(\n", | ||
" by=[f'SP and distance species-{t}', f'Log Distance species-{t}', f'SP Prod species-{t}', 'Name', 'Subname'], ascending=False)['Circuit Obj'].iloc[:choose_max].to_list()\n", | ||
" prev_circuits = data_1[data_1['Subname'] == 'ref_circuit']\n", | ||
" keep_n = int(0.7 * choose_max)\n", | ||
" if use_diversity and all([c in prev_circuits for c in circuits_chosen]) and (len(data_1) >= keep_n):\n", | ||
" _, circuits_chosen = select_next(data_1[data_1['Circuit Obj'].isin(prev_circuits[:keep_n])], choose_max, t)\n", | ||
" data_1['Diversity selection'] = data_1['Circuit Obj'].isin(circuits_chosen)\n", | ||
" \n", | ||
" data_1['Next selected'] = data_1['Circuit Obj'].isin(circuits_chosen)\n", | ||
" return data_1, circuits_chosen\n", | ||
" \n", | ||
" def get_batch_analytics(batch, data_writer):\n", | ||
" batch_analytics = []\n", | ||
" for c in batch:\n", | ||
" if c.subname == 'ref_circuit':\n", | ||
" batch_analytics.append(\n", | ||
" load_json_as_dict(os.path.join(data_writer.top_write_dir, c.name, 'report_signal.json')))\n", | ||
" else:\n", | ||
" batch_analytics.append(\n", | ||
" load_json_as_dict(os.path.join(data_writer.top_write_dir, c.name, 'mutations', c.subname, 'report_signal.json'))\n", | ||
" )\n", | ||
" batch_analytics = jax.tree_util.tree_map(lambda x: np.float64(x), batch_analytics)\n", | ||
" return batch_analytics\n", | ||
" \n", | ||
" batch_analytics = get_batch_analytics(batch, data_writer)\n", | ||
" data_1 = make_data(batch, batch_analytics, target_species)\n", | ||
" \n", | ||
" t = target_species[0]\n", | ||
" # circuits_chosen = data_1[(data_1[f'dS species-{t}'] >= 0) & (data_1[f'dP species-{t}'] >= 0)].sort_values(by=[f'Sensitivity species-{t}', f'Precision species-{t}'], ascending=False)['Circuit Obj'].iloc[:choose_max].to_list()\n", | ||
" data_1, circuits_chosen = select_next(data_1, choose_max, t, use_diversity)\n", | ||
" return circuits_chosen, data_1\n", | ||
"\n", | ||
"\n", | ||
"def mutate(circuits):\n", | ||
" return circuits\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"for step in range(total_steps):\n", | ||
" \n", | ||
" print(f'\\n\\nStarting step {step+1} out of {total_steps}\\n\\n')\n", | ||
"\n", | ||
" batch = mutate(starting, evolver, algorithm=config['mutations_args']['algorithm'])\n", | ||
" batch = simulate(batch, modeller, config)\n", | ||
" starting, summary_data = choose_next(batch=expanded_batchs, data_writer=data_writer, distance_func=distance_func, \n", | ||
" choose_max=choose_max, target_species=target_species, use_diversity=config.get('use_diversity', False))\n", | ||
" starting = process_for_next_run(starting, data_writer=data_writer)\n", | ||
" " | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.