Skip to content

Commit

Permalink
adding debug mode
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 17, 2024
1 parent 64b826d commit 719429d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"purpose": "mutation_effect_on_interactions_signal",
"no_visualisations": false,
"no_numerical": true,
"debug_mode": false
"debug_mode": true
},
"interaction_simulator": {
"name": "IntaRNA",
Expand Down Expand Up @@ -89,9 +89,9 @@
"solver": "diffrax",
"use_batch_mutations": true,
"interaction_factor": 1,
"batch_size": 20000,
"batch_size": 5,
"max_circuits": 60000,
"device": "gpu",
"device": "cpu",
"threshold_steady_states": 0.05,
"use_rate_scaling": true
},
Expand Down
2 changes: 1 addition & 1 deletion synbio_morpher/utils/circuit/agnostic_circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_mutation(self):
self.use_prod_and_deg = True
self.model = None
self.circuit_size = None
self.qreactions = None
self.qreactions: QuantifiedReactions = None
self.interactions_state: str = 'uninitialised'
self.interactions = None
self.signal: Signal = None
Expand Down
54 changes: 33 additions & 21 deletions synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from synbio_morpher.utils.misc.runtime import clear_caches
from synbio_morpher.utils.misc.type_handling import flatten_nested_dict, flatten_listlike, append_nest_dicts
from synbio_morpher.utils.results.visualisation import VisODE
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded, bioreaction_sim_dfx_debug
from synbio_morpher.utils.modelling.solvers import get_diffrax_solver, make_stepsize_controller
from synbio_morpher.utils.evolution.mutation import implement_mutation
from synbio_morpher.utils.results.analytics.timeseries import generate_analytics
Expand All @@ -45,9 +45,6 @@
# Set modelling environment variables
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
debug_mode = True
if debug_mode:
jax.config.update("jax_disable_jit", True)


def wrap_queue_res(k, inp, q, f, **kwargs):
Expand All @@ -62,7 +59,8 @@ def wrap_queue_res(k, inp, q, f, **kwargs):
class CircuitModeller():

def __init__(self, result_writer=None, config: dict = {}) -> None:
self.result_writer = ResultWriter(None) if result_writer is None else result_writer
self.result_writer = ResultWriter(
None) if result_writer is None else result_writer
self.steady_state_args = config['simulation_steady_state']
self.simulator_args = config['interaction_simulator']
self.simulation_args = config.get('simulation', {})
Expand Down Expand Up @@ -92,6 +90,9 @@ def __init__(self, result_writer=None, config: dict = {}) -> None:
jax.config.update('jax_platform_name', config.get(
'simulation', {}).get('device', 'cpu'))

if self.debug_mode:
jax.config.update("jax_disable_jit", True)

def init_circuit(self, circuit: Circuit) -> Circuit:
if self.simulation_args.get('use_rate_scaling', True):
circuit = self.scale_rates([circuit])[0]
Expand Down Expand Up @@ -276,24 +277,34 @@ def compute_steady_states(self, circuits: List[Circuit],

elif solver_type == 'diffrax':
ref_circuit = circuits[0]
forward_rates = ref_circuit.qreactions.reactions.forward_rates # Assuming all forward rates are the same
# Assuming all forward rates are the same
forward_rates = ref_circuit.qreactions.reactions.forward_rates
reverse_rates = np.asarray(
[c.qreactions.reactions.reverse_rates for c in circuits])
y0 = np.asarray([c.qreactions.quantities for c in circuits])
signal_onehot = np.zeros_like(
ref_circuit.signal.reactions_onehot) if ref_circuit.use_prod_and_deg else np.zeros_like(ref_circuit.signal.onehot)

sim_func = jax.vmap(partial(bioreaction_sim_dfx_expanded,
t0=self.t0, t1=self.t1, dt0=self.dt0,
signal=vanilla_return, signal_onehot=signal_onehot,
inputs=ref_circuit.qreactions.reactions.inputs,
outputs=ref_circuit.qreactions.reactions.outputs,
forward_rates=forward_rates,
solver=get_diffrax_solver(
self.steady_state_args.get('method', 'Dopri5')),
saveat=dfx.SaveAt(
ts=np.linspace(self.t0, self.t1, int(np.min([200, self.t1-self.t0])))),
stepsize_controller=make_stepsize_controller(self.t0, self.t1, self.dt0, self.dt1, choice='piecewise')))
if self.debug_mode:
sim_func = partial(bioreaction_sim_dfx_debug,
t0=self.t0, t1=self.t1, dt0=self.dt0,
inputs=ref_circuit.qreactions.reactions.inputs,
outputs=ref_circuit.qreactions.reactions.outputs,
forward_rates=forward_rates,
save_every_n_tsteps=5
)
else:
sim_func = jax.vmap(partial(bioreaction_sim_dfx_expanded,
t0=self.t0, t1=self.t1, dt0=self.dt0,
signal=vanilla_return, signal_onehot=signal_onehot,
inputs=ref_circuit.qreactions.reactions.inputs,
outputs=ref_circuit.qreactions.reactions.outputs,
forward_rates=forward_rates,
solver=get_diffrax_solver(
self.steady_state_args.get('method', 'Dopri5')),
saveat=dfx.SaveAt(
ts=np.linspace(self.t0, self.t1, int(np.min([200, self.t1-self.t0])))),
stepsize_controller=make_stepsize_controller(self.t0, self.t1, self.dt0, self.dt1, choice='piecewise')))

b_copynumbers, t = simulate_steady_states(
y0=y0, total_time=self.tmax, sim_func=sim_func,
Expand All @@ -305,7 +316,7 @@ def compute_steady_states(self, circuits: List[Circuit],
b_copynumbers = np.swapaxes(b_copynumbers, 1, 2)

elif solver_type == 'torchode':
raise NotImplementedError
raise NotImplementedError()
# import torchode as tode
# import torch
# ref_circuit = circuits[0]
Expand Down Expand Up @@ -334,11 +345,12 @@ def compute_steady_states(self, circuits: List[Circuit],

elif solver_type == 'torchdiffeq':
t_eval = np.linspace(self.t0, self.t1, 100)
raise NotImplementedError
raise NotImplementedError()
# tdeq.odeint_adjoint(sim_func, starting_states, t)

else:
raise ValueError(f'The chosen solver `{solver_type}` is not supported.')
raise ValueError(
f'The chosen solver `{solver_type}` is not supported.')

return np.asarray(b_copynumbers), np.squeeze(t)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"IntaRNA": "./synbio_morpher/utils/common/configs/simulators/intaRNA_args.json",
"simulation_steady_state": {
"max_time": 20,
"method": "Dopri5",
"method": "Euler",
"steady_state_solver": "diffrax",
"time_interval": 0.01,
"use_zero_rates": true
Expand Down
38 changes: 33 additions & 5 deletions synbio_morpher/utils/modelling/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
# All rights reserved.

# This source code is licensed under the MIT-style license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import diffrax as dfx
from bioreaction.simulation.simfuncs.basic_de import bioreaction_sim, bioreaction_sim_expanded
from diffrax._step_size_controller.base import AbstractStepSizeController
from bioreaction.simulation.simfuncs.basic_de import bioreaction_sim, bioreaction_sim_expanded, one_step_de_sim_expanded
from bioreaction.model.data_containers import QuantifiedReactions
from synbio_morpher.utils.modelling.base import Modeller

Expand Down Expand Up @@ -107,13 +108,13 @@ def bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
saveat=dfx.SaveAt(
t0=True, t1=True, steps=True),
max_steps=16**5,
stepsize_controller=dfx.ConstantStepSize()):
stepsize_controller: AbstractStepSizeController = dfx.ConstantStepSize()):
if type(stepsize_controller) == dfx.StepTo:
dt0 = None
term = dfx.ODETerm(
partial(bioreaction_sim_expanded,
inputs=inputs, outputs=outputs,
# signal=signal,
# signal=signal,
# signal_onehot=signal_onehot,
forward_rates=forward_rates.squeeze(), reverse_rates=reverse_rates.squeeze()
)
Expand All @@ -123,3 +124,30 @@ def bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
y0=y0.squeeze(),
saveat=saveat, max_steps=max_steps,
stepsize_controller=stepsize_controller)


def bioreaction_sim_dfx_debug(y0, reverse_rates,
t0, t1, dt0,
inputs, outputs, forward_rates,
save_every_n_tsteps: int = 1
):

y = y0
num_saves = int((t1 - t0) // (dt0 * save_every_n_tsteps) + 1)
saves_y = np.zeros((num_saves, *y0.shape))
saves_t = np.arange(t0, t1, dt0*save_every_n_tsteps)

save_index = 0 # To keep track of saves
for i, ti in enumerate(np.arange(t0, t1, dt0)):
for iy, yi in enumerate(y):
y[iy] = yi + one_step_de_sim_expanded(
spec_conc=yi, inputs=inputs,
outputs=outputs,
forward_rates=forward_rates,
reverse_rates=reverse_rates[iy]) * dt0

if i % save_every_n_tsteps == 0:
saves_y[save_index, iy] = y
save_index += 1

return saves_y, saves_t

0 comments on commit 719429d

Please sign in to comment.