Skip to content

Commit

Permalink
removing Deterministic modeller
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Mar 21, 2024
1 parent 9a766a6 commit b00c172
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 48 deletions.
4 changes: 0 additions & 4 deletions notebooks/19_Jacobian.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"from synbio_morpher.utils.modelling.deterministic import Deterministic\n",
"\n",
"\n",
"unbound_species = ['RNA_0', 'RNA_1', 'RNA_2']\n",
"bound_species = sorted(set(flatten_listlike([['-'.join(sorted([x, y])) for x in unbound_species] for y in unbound_species])))\n",
"species = unbound_species + bound_species\n",
Expand Down
58 changes: 15 additions & 43 deletions synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,11 @@
from synbio_morpher.utils.misc.runtime import clear_caches
from synbio_morpher.utils.misc.type_handling import flatten_nested_dict, flatten_listlike, append_nest_dicts
from synbio_morpher.utils.results.visualisation import VisODE
from synbio_morpher.utils.modelling.base import Modeller
from synbio_morpher.utils.modelling.deterministic import Deterministic, bioreaction_sim_dfx_expanded
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.modelling.solvers import get_diffrax_solver
from synbio_morpher.utils.evolution.mutation import implement_mutation
from synbio_morpher.utils.results.analytics.timeseries import generate_analytics
from synbio_morpher.utils.results.result_writer import ResultWriter
from synbio_morpher.utils.signal.signals_new import Signal


def wrap_queue_res(k, inp, q, f, **kwargs):
Expand Down Expand Up @@ -163,8 +161,10 @@ def compute_interactions_batch(self, circuits: List[Circuit], batch=True):
for i in range(0, len(circuits), n_threads):
j = i + n_threads if i + \
n_threads < len(circuits) else len(circuits)
executor = concurrent.futures.ProcessPoolExecutor(max_workers=j - i)
results = list(executor.map(self.compute_interactions, circuits[i:j]))
executor = concurrent.futures.ProcessPoolExecutor(
max_workers=j - i)
results = list(executor.map(
self.compute_interactions, circuits[i:j]))
circuits[i:j] = results

# manager = multiprocessing.Manager()
Expand Down Expand Up @@ -221,12 +221,8 @@ def run_interaction_simulator(self, species: List[Species], quantities, filename
return self.interaction_simulator.run(data, quantities=quantities, compute_by_filename=False)

def find_steady_states(self, circuits: List[Circuit], batch=True) -> List[Circuit]:
modeller_steady_state = Deterministic(
max_time=self.steady_state_args['max_time'],
time_interval=self.steady_state_args['time_interval'])

b_steady_states, t = self.compute_steady_states(modeller_steady_state,
circuits=circuits,
b_steady_states, t = self.compute_steady_states(circuits=circuits,
solver_type=self.steady_state_args['steady_state_solver'],
use_zero_rates=self.steady_state_args['use_zero_rates'])

Expand All @@ -244,7 +240,7 @@ def find_steady_states(self, circuits: List[Circuit], batch=True) -> List[Circui
no_write=False)
return circuits

def compute_steady_states(self, modeller: Modeller, circuits: List[Circuit],
def compute_steady_states(self, circuits: List[Circuit],
solver_type: str = 'jax', use_zero_rates: bool = False) -> List[Circuit]:

if solver_type == 'ivp':
Expand All @@ -262,7 +258,7 @@ def compute_steady_states(self, modeller: Modeller, circuits: List[Circuit],
steady_state_result = integrate.solve_ivp(
partial(bioreaction_sim, args=None, reactions=r, signal=vanilla_return,
signal_onehot=signal_onehot),
(0, modeller.max_time),
(self.t0, self.t1),
y0=circuit.qreactions.quantities,
method=self.steady_state_args.get('method', 'DOP853'))
if not steady_state_result.success:
Expand All @@ -288,7 +284,8 @@ def compute_steady_states(self, modeller: Modeller, circuits: List[Circuit],
inputs=ref_circuit.qreactions.reactions.inputs,
outputs=ref_circuit.qreactions.reactions.outputs,
forward_rates=forward_rates,
solver=get_diffrax_solver(self.steady_state_args.get('method', 'Dopri5')),
solver=get_diffrax_solver(
self.steady_state_args.get('method', 'Dopri5')),
saveat=dfx.SaveAt(
ts=np.linspace(self.t0, self.t1, int(np.min([200, self.t1-self.t0])))),
stepsize_controller=self.make_stepsize_controller(
Expand All @@ -303,7 +300,7 @@ def compute_steady_states(self, modeller: Modeller, circuits: List[Circuit],
)

b_copynumbers = np.swapaxes(b_copynumbers, 1, 2)

elif solver_type == 'torchode':
raise NotImplementedError
# import torchode as tode
Expand All @@ -321,7 +318,7 @@ def compute_steady_states(self, modeller: Modeller, circuits: List[Circuit],
# outputs=ref_circuit.qreactions.reactions.outputs,
# forward_rates=forward_rates.squeeze(), reverse_rates=reverse_rates.squeeze()
# )

# t_eval = np.linspace(self.t0, self.t1, int(np.min([200, self.t1-self.t0]))).repeat(len(circuits))
# prob = tode.InitialValueProblem(y0=y0, t_eval=t_eval)
# odeterm = tode.ODETerm(sim_func, with_args=True)
Expand All @@ -331,13 +328,12 @@ def compute_steady_states(self, modeller: Modeller, circuits: List[Circuit],
# solver = tode.AutoDiffAdjoint(step_method, step_controller)
# self.solver = torch.compile(solver)
# sol = self.solver.solve(prob)

elif solver_type == 'torchdiffeq':
t_eval = np.linspace(self.t0, self.t1, 100)
raise NotImplementedError
# tdeq.odeint_adjoint(sim_func, starting_states, t)


return np.asarray(b_copynumbers), np.squeeze(t)

def model_circuit(self, y0: np.ndarray, circuit: Circuit):
Expand Down Expand Up @@ -598,31 +594,6 @@ def wrap_mutations(self, circuit: Circuit, methods: dict, include_normal_run=Tru
self.result_writer.unsubdivide()
return circuit

def make_stepsize_controller(self, choice: str, **kwargs):
""" The choice can be either log or piecewise """
if choice == 'log':
return self.make_log_stepcontrol(**kwargs)
elif choice == 'piecewise':
return make_piecewise_stepcontrol(t0=self.t0, t1=self.t1, dt0=self.dt0, dt1=self.dt1, **kwargs)
elif choice == 'adaptive':
return dfx.PIDController(rtol=1e-3, atol=1e-5)
else:
raise ValueError(
f'The stepsize controller option `{choice}` is not available.')

def make_log_stepcontrol(self, upper_log: int = 3):
num = 1000
x = np.interp(np.logspace(0, upper_log, num=num), [
1, np.power(10, upper_log)], [self.dt0, self.dt1])
while np.cumsum(x)[-1] < self.t1:
x = np.interp(np.logspace(0, upper_log, num=num), [
1, np.power(10, upper_log)], [self.dt0, self.dt1])
num += 1
ts = np.cumsum(x)
ts[0] = self.t0
ts[-1] = self.t1
return dfx.StepTo(ts)

def prepare_internal_funcs(self, circuits: List[Circuit]):
""" Create simulation function. If more customisation is needed per circuit, move
variables into the relevant wrapper simulation method """
Expand Down Expand Up @@ -653,7 +624,8 @@ def prepare_internal_funcs(self, circuits: List[Circuit]):
forward_rates=forward_rates,
inputs=ref_circuit.qreactions.reactions.inputs,
outputs=ref_circuit.qreactions.reactions.outputs,
solver=get_diffrax_solver(self.simulation_args.get('method', 'Dopri5')),
solver=get_diffrax_solver(
self.simulation_args.get('method', 'Dopri5')),
saveat=dfx.SaveAt(
# t0=True, t1=True),
ts=np.linspace(self.t0, self.t1, 500)), # int(np.min([500, self.t1-self.t0]))))
Expand Down
29 changes: 28 additions & 1 deletion synbio_morpher/utils/modelling/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,32 @@ def get_diffrax_solver(solver_type):

assert solver_type in list(solvers.keys(
)), f'Diffrax solver option {solver_type} not found. See https://docs.kidger.site/diffrax/api/solvers/ode_solvers/ for valid solvers.'

return solvers[solver_type]


def make_stepsize_controller(self, choice: str, **kwargs):
""" The choice can be either log or piecewise """
if choice == 'log':
return self.make_log_stepcontrol(**kwargs)
elif choice == 'piecewise':
return make_piecewise_stepcontrol(t0=self.t0, t1=self.t1, dt0=self.dt0, dt1=self.dt1, **kwargs)
elif choice == 'adaptive':
return dfx.PIDController(rtol=1e-3, atol=1e-5)
else:
raise ValueError(
f'The stepsize controller option `{choice}` is not available.')


def make_log_stepcontrol(self, upper_log: int = 3):
num = 1000
x = np.interp(np.logspace(0, upper_log, num=num), [
1, np.power(10, upper_log)], [self.dt0, self.dt1])
while np.cumsum(x)[-1] < self.t1:
x = np.interp(np.logspace(0, upper_log, num=num), [
1, np.power(10, upper_log)], [self.dt0, self.dt1])
num += 1
ts = np.cumsum(x)
ts[0] = self.t0
ts[-1] = self.t1
return dfx.StepTo(ts)

0 comments on commit b00c172

Please sign in to comment.