Skip to content

Commit

Permalink
typing hints
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 18, 2024
1 parent 344ecd7 commit dd316d7
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def logging_circuit(clist: list):
[
# Mutate circuit
Protocol(
partial(Evolver(data_writer=data_writer, sequence_type=config_file.get('system_type'),
partial(Evolver(data_writer=data_writer, sequence_type=config_file['system_type'],
seed=config_file.get('mutations_args', {}).get('seed', np.random.randint(1000))).mutate,
write_to_subsystem=True,
algorithm=config_file.get('mutations_args', {}).get('algorithm', 'random')),
Expand Down
4 changes: 2 additions & 2 deletions synbio_morpher/srv/io/loaders/circuit_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


from synbio_morpher.srv.parameter_prediction.interactions import INTERACTION_FIELDS_TO_WRITE
from synbio_morpher.utils.common.setup import construct_circuit_from_cfg, prepare_config, expand_model_config, compose_kwargs
from synbio_morpher.utils.common.setup import construct_circuit_from_cfg
from synbio_morpher.utils.data.data_format_tools.common import FORMAT_EXTS, load_json_as_dict
from synbio_morpher.utils.evolution.mutation import implement_mutation
from synbio_morpher.utils.evolution.evolver import implement_mutation
from synbio_morpher.utils.misc.type_handling import inverse_dict, flatten_nested_dict
from synbio_morpher.utils.evolution.evolver import load_mutations

Expand Down
2 changes: 1 addition & 1 deletion synbio_morpher/srv/io/loaders/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ class GeneCircuitLoader(DataLoader):
def __init__(self) -> None:
super().__init__()

def load_data(self, filepath: str, identities: dict = {}):
def load_data(self, filepath: str, identities: dict = {}) -> Data:
return Data(super().load_data(filepath=filepath), source_files=filepath, identities=identities)
6 changes: 3 additions & 3 deletions synbio_morpher/srv/io/manage/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@


import logging
from typing import Optional
from synbio_morpher.srv.io.loaders.data_loader import GeneCircuitLoader
from synbio_morpher.utils.data.common import Data


class DataManager():
def __init__(self, filepath: str = None, identities: dict = None, data: dict = None):
def __init__(self, identities: dict, filepath: Optional[str] = None, data: Optional[dict] = None):
self.loader = GeneCircuitLoader()
self.source = filepath
self.data = data
if data is None and filepath:
self.data = self.loader.load_data(filepath, identities=identities)
self.data: Data = self.loader.load_data(filepath, identities=identities)
elif data is None and filepath is None:
self.data = Data(loaded_data=data, identities=identities)
# logging.warning(
Expand Down
26 changes: 13 additions & 13 deletions synbio_morpher/utils/circuit/agnostic_circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def init_refcircuit(self, config: dict):
assert self.interactions_state != 'uninitialised', f'The interactions should have been initialised from {config.get("interactions")}'
self.signal: Signal # = None
self.mutations_args: dict = config.get('mutations_args', {})
self.mutations: Dict[str, Mutations] = {}
self.mutations: Dict[str, Dict[str, Mutations]] = {}

self.update_species_simulated_rates(self.interactions)

Expand Down Expand Up @@ -141,17 +141,17 @@ def update_species_simulated_rates(self, interactions: MolecularInteractions):
self.qreactions.reactions = self.qreactions.init_reactions(
self.model)

@property
def signal(self):
return self._signal
# @property
# def signal(self):
# return self._signal

@signal.getter
def signal(self):
# if self._signal is None:
# logging.warning(
# f'Trying to retrieve None signal from circuit. Make sure signal specified in circuit config')
return self._signal
# @signal.getter
# def signal(self):
# # if self._signal is None:
# # logging.warning(
# # f'Trying to retrieve None signal from circuit. Make sure signal specified in circuit config')
# return self._signal

@signal.setter
def signal(self, value):
self._signal = value
# @signal.setter
# def signal(self, value):
# self._signal = value
55 changes: 27 additions & 28 deletions synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import diffrax as dfx
import numpy as np
import jax
import jax.numpy as jnp
# jax.config.update('jax_platform_name', 'cpu')

from scipy import integrate
Expand All @@ -28,7 +29,7 @@
from bioreaction.simulation.manager import simulate_steady_states

from synbio_morpher.srv.parameter_prediction.simulator import SIMULATOR_UNITS, make_piecewise_stepcontrol
from synbio_morpher.srv.parameter_prediction.interactions import InteractionDataHandler, InteractionSimulator, INTERACTION_FIELDS_TO_WRITE
from synbio_morpher.srv.parameter_prediction.interactions import InteractionSimulator, INTERACTION_FIELDS_TO_WRITE, MolecularInteractions
from synbio_morpher.utils.circuit.agnostic_circuits.circuit import Circuit, interactions_to_df
from synbio_morpher.utils.misc.helper import vanilla_return
from synbio_morpher.utils.misc.numerical import invert_onehot, zero_out_negs
Expand Down Expand Up @@ -222,7 +223,7 @@ def compute_interactions_batch(self, circuits: List[Circuit], batch=True):
circuits[i] = self.compute_interactions(c)
return circuits

def run_interaction_simulator(self, species: List[Species], quantities, filename=None) -> InteractionDataHandler:
def run_interaction_simulator(self, species: List[Species], quantities, filename=None) -> MolecularInteractions:
data = {s: s.sequence for s in species}
# if filename is not None:
# return self.interaction_simulator.run((filename, data), compute_by_filename=True)
Expand Down Expand Up @@ -251,7 +252,7 @@ def find_steady_states(self, circuits: List[Circuit], batch=True) -> List[Circui

def compute_steady_states(self, circuits: List[Circuit],
solver_type: str = 'jax', use_zero_rates: bool = False) -> Tuple[np.ndarray, np.ndarray]:

t = np.zeros(0)
if solver_type == 'ivp':
b_copynumbers = []
for circuit in circuits:
Expand All @@ -262,8 +263,8 @@ def compute_steady_states(self, circuits: List[Circuit],
((circuit.qreactions.reactions.reverse_rates -
circuit.qreactions.reactions.forward_rates) < 1e2) * 1

signal_onehot = np.zeros_like(
circuit.signal.reactions_onehot) if circuit.use_prod_and_deg else np.zeros_like(circuit.signal.onehot)
signal_onehot = jnp.zeros_like(
circuit.signal.reactions_onehot) if circuit.use_prod_and_deg else jnp.zeros_like(circuit.signal.onehot)
steady_state_result = integrate.solve_ivp(
partial(bioreaction_sim, args=None, reactions=r, signal=vanilla_return,
signal_onehot=signal_onehot),
Expand Down Expand Up @@ -319,7 +320,7 @@ def compute_steady_states(self, circuits: List[Circuit],

b_copynumbers = np.swapaxes(b_copynumbers, 1, 2)

elif solver_type == 'torchode':
elif solver_type in ['torchode', 'torchdiffeq']:
raise NotImplementedError()
# import torchode as tode
# import torch
Expand Down Expand Up @@ -347,9 +348,8 @@ def compute_steady_states(self, circuits: List[Circuit],
# self.solver = torch.compile(solver)
# sol = self.solver.solve(prob)

elif solver_type == 'torchdiffeq':
t_eval = np.linspace(self.t0, self.t1, 100)
raise NotImplementedError()
# elif solver_type == 'torchdiffeq':
# t_eval = np.linspace(self.t0, self.t1, 100)
# tdeq.odeint_adjoint(sim_func, starting_states, t)

else:
Expand All @@ -358,21 +358,21 @@ def compute_steady_states(self, circuits: List[Circuit],

return np.asarray(b_copynumbers), np.squeeze(t)

def model_circuit(self, y0: np.ndarray, circuit: Circuit):
assert np.shape(y0)[circuit.time_axis] == 1, 'Please only use 1-d ' \
f'initial copynumbers instead of {np.shape(y0)}'
# def model_circuit(self, y0: np.ndarray, circuit: Circuit):
# assert np.shape(y0)[circuit.time_axis] == 1, 'Please only use 1-d ' \
# f'initial copynumbers instead of {np.shape(y0)}'

modelling_func = partial(bioreaction_sim, args=None,
reactions=circuit.qreactions.reactions,
signal=circuit.signal.func,
signal_onehot=circuit.signal.reactions_onehot)
# modelling_func = partial(bioreaction_sim, args=None,
# reactions=circuit.qreactions.reactions,
# signal=circuit.signal.func,
# signal_onehot=jnp.array(circuit.signal.reactions_onehot))

copynumbers = self.iterate_modelling_func(y0, modelling_func,
max_time=self.t1,
time_interval=self.dt0,
signal_f=circuit.signal.func,
signal_onehot=circuit.signal.onehot)
return copynumbers
# copynumbers = self.iterate_modelling_func(y0, modelling_func,
# max_time=self.t1,
# time_interval=self.dt0,
# signal_f=circuit.signal.func,
# signal_onehot=circuit.signal.onehot)
# return copynumbers

def iterate_modelling_func(self, init_copynumbers,
modelling_func, max_time,
Expand Down Expand Up @@ -430,7 +430,7 @@ def simulate_signal_batch(self, circuits: List[Circuit],
def prepare_batch_params(circuits: List[Circuit]):

b_steady_states = [None] * len(circuits)
b_reverse_rates = [None] * len(circuits)
b_reverse_rates = np.zeros((len(circuits), *circuits[0].qreactions.reactions.reverse_rates.shape))

species_chosen = circuits[0].model.species[np.argmax(
signal.onehot)]
Expand Down Expand Up @@ -565,14 +565,14 @@ def make_subcircuit(self, circuit: Circuit, mutation_name: str, mutation: Option
subcircuit.reset_to_initial_state()
subcircuit.strip_to_core()
if mutation is None:
mutation = circuit.mutations.get(mutation_name)
mutation = circuit.mutations[mutation_name]
subcircuit.subname = mutation_name

subcircuit = implement_mutation(circuit=subcircuit, mutation=mutation)
return subcircuit

def load_mutations(self, circuit: Circuit):
subcircuits = [Circuit(config=None, as_mutation=True)
subcircuits = [Circuit(config={}, as_mutation=True)
for m in flatten_nested_dict(circuit.mutations)]
for i, (m_name, m) in enumerate(flatten_nested_dict(circuit.mutations).items()):
if not m:
Expand Down Expand Up @@ -608,13 +608,12 @@ def wrap_mutations(self, circuit: Circuit, methods: dict, include_normal_run=Tru
# logging.info(f'Running methods on mutation {name} ({i})')
if include_normal_run and i == 0:
self.result_writer.unsubdivide_last_dir()
circuit = self.apply_to_circuit(
circuit, methods, ref_circuit=circuit)
circuit = self.apply_to_circuit(circuit, methods)
self.result_writer.subdivide_writing(
'mutations', safe_dir_change=False)
subcircuit = self.make_subcircuit(circuit, name, mutation)
self.result_writer.subdivide_writing(name, safe_dir_change=False)
self.apply_to_circuit(subcircuit, methods, ref_circuit=circuit)
self.apply_to_circuit(subcircuit, methods)
self.result_writer.unsubdivide_last_dir()
self.result_writer.unsubdivide()
return circuit
Expand Down
8 changes: 4 additions & 4 deletions synbio_morpher/utils/common/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
from copy import deepcopy
from typing import List, Union
from typing import List, Union, Optional
from bioreaction.model.data_containers import BasicModel
from synbio_morpher.srv.io.manage.sys_interface import make_filename_safely
from synbio_morpher.srv.io.manage.data_manager import DataManager
Expand All @@ -32,7 +32,7 @@
]


def expand_model_config(in_config: dict, out_config: dict, sample_names: List[str]) -> dict:
def expand_model_config(in_config: dict, out_config: dict, sample_names: Union[List[str], dict]) -> dict:
if 'starting_concentration' not in out_config.keys():
out_config['starting_concentration'] = {}
for s in sample_names:
Expand All @@ -58,7 +58,7 @@ def process_molecular_params(params: dict, factor=1) -> dict:
return params


def compose_kwargs(prev_configs: dict = None, config: dict = None) -> dict:
def compose_kwargs(config: dict, prev_configs: Optional[dict] = None) -> dict:
""" Extra configs like data paths can be supplied here, eg. for circuits that were dynamically generated. """

if prev_configs is not None:
Expand Down Expand Up @@ -119,7 +119,7 @@ def prepare_config(config_filepath: Union[str, dict, None] = None, config_file:
config_file = add_empty_fields(config_file)
return config_file

def construct_circuit_from_cfg(prev_configs: dict, config_file: dict):
def construct_circuit_from_cfg(prev_configs: Optional[dict], config_file: dict):
kwargs = compose_kwargs(prev_configs=prev_configs, config=config_file)
circuit = instantiate_system(kwargs)
if kwargs.get("signal"):
Expand Down
12 changes: 6 additions & 6 deletions synbio_morpher/utils/data/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
# All rights reserved.

# This source code is licensed under the MIT-style license found in the
# LICENSE file in the root directory of this source tree.

# LICENSE file in the root directory of this source tree.


import logging
from typing import Optional, List
import pandas as pd


class Data():
""" Holds things like FASTA files or other genetic info files. """

def __init__(self, loaded_data: dict, identities: dict = {}, source_files=None) -> None:
def __init__(self, loaded_data: Optional[dict], identities: dict = {}, source_files: Optional[str] = None) -> None:
self.data = loaded_data if loaded_data is not None else {}
self.source = source_files
self.sample_names = self.make_sample_names()
Expand All @@ -39,12 +39,12 @@ def convert_names_to_idxs(names_table: dict, source: list) -> dict:
f'Identities not found: {names_table.values()} not in {source}')
return indexed_identities

def make_sample_names(self, sample_names: list = None) -> list:
def make_sample_names(self, sample_names: Optional[List[str]] = None) -> List[str]:
if type(self.data) == dict:
return list(self.data.keys())
elif type(self.data) == pd.DataFrame:
return list(self.data.columns)
elif self.data is None:
elif self.data is None and (sample_names is not None):
return sample_names
raise ValueError(f'Unrecognised loaded data type {type(self.data)}.')

Expand All @@ -71,7 +71,7 @@ def sample_names(self):
return self._sample_names

@sample_names.getter
def sample_names(self):
def sample_names(self) -> List[str]:
return self.make_sample_names()
# if type(self.data) == dict:
# return list(self.data.keys())
Expand Down
2 changes: 1 addition & 1 deletion synbio_morpher/utils/data/data_format_tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_filename_from_within_package(json_pathname: str) -> str:
return full_json_pathname


def load_json_as_dict(json_pathname: Union[str, dict], process=True) -> dict:
def load_json_as_dict(json_pathname: Union[str, dict, None], process=True) -> dict:
if not json_pathname:
return {}
elif type(json_pathname) == dict:
Expand Down
17 changes: 8 additions & 9 deletions synbio_morpher/utils/evolution/evolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from functools import partial
import logging
import os
from typing import Tuple, List
from typing import Tuple, List, Union
import numpy as np
from bioreaction.model.data_containers import Species
from synbio_morpher.srv.io.loaders.misc import load_csv
from synbio_morpher.utils.evolution.mutation import get_mutation_type_mapping, Mutations, EXCLUDED_NUCS
from synbio_morpher.utils.misc.type_handling import flatten_listlike, flatten_nested_dict
from synbio_morpher.utils.misc.type_handling import flatten_listlike
from synbio_morpher.utils.results.writer import DataWriter, kwargs_from_table
from synbio_morpher.utils.misc.string_handling import add_outtype


from synbio_morpher.utils.circuit.agnostic_circuits.circuit import Circuit
Expand All @@ -26,9 +25,9 @@ class Evolver():
def __init__(self,
data_writer: DataWriter,
mutation_type: str = 'random',
sequence_type: str = None,
seed: int = None,
concurrent_species_to_mutate: str = '') -> None:
sequence_type: str = 'RNA',
seed: int = 0,
concurrent_species_to_mutate: Union[str, list] = '') -> None:
self.data_writer = data_writer
self.mutation_type = mutation_type # Not implemented
self.out_name = 'mutations'
Expand Down Expand Up @@ -258,9 +257,9 @@ def load_mutations(circuit, filename=None):


def apply_mutation_to_sequence(sequence: str, mutation_positions: List[int], mutation_types: List[str]) -> List[str]:
sequence = np.array([*sequence])
sequence[mutation_positions] = mutation_types
return ''.join(sequence)
sequence_arr = np.array([*sequence])
sequence_arr[mutation_positions] = mutation_types
return ''.join(sequence_arr)


def implement_mutation(circuit: Circuit, mutation: Mutations):
Expand Down
Loading

0 comments on commit dd316d7

Please sign in to comment.