From dd316d7063f2d5a039dd2ef5acaccd16c9879c3b Mon Sep 17 00:00:00 2001 From: olive004 Date: Fri, 18 Oct 2024 10:57:55 +0100 Subject: [PATCH] typing hints --- ..._mutation_effect_on_interactions_signal.py | 2 +- .../srv/io/loaders/circuit_loader.py | 4 +- synbio_morpher/srv/io/loaders/data_loader.py | 2 +- synbio_morpher/srv/io/manage/data_manager.py | 6 +- .../circuit/agnostic_circuits/circuit.py | 26 ++++----- .../agnostic_circuits/circuit_manager.py | 55 +++++++++---------- synbio_morpher/utils/common/setup.py | 8 +-- synbio_morpher/utils/data/common.py | 12 ++-- .../utils/data/data_format_tools/common.py | 2 +- synbio_morpher/utils/evolution/evolver.py | 17 +++--- synbio_morpher/utils/misc/io.py | 8 +-- 11 files changed, 70 insertions(+), 72 deletions(-) diff --git a/synbio_morpher/scripts/mutation_effect_on_interactions_signal/run_mutation_effect_on_interactions_signal.py b/synbio_morpher/scripts/mutation_effect_on_interactions_signal/run_mutation_effect_on_interactions_signal.py index 24b75df7..bf06fdfe 100644 --- a/synbio_morpher/scripts/mutation_effect_on_interactions_signal/run_mutation_effect_on_interactions_signal.py +++ b/synbio_morpher/scripts/mutation_effect_on_interactions_signal/run_mutation_effect_on_interactions_signal.py @@ -74,7 +74,7 @@ def logging_circuit(clist: list): [ # Mutate circuit Protocol( - partial(Evolver(data_writer=data_writer, sequence_type=config_file.get('system_type'), + partial(Evolver(data_writer=data_writer, sequence_type=config_file['system_type'], seed=config_file.get('mutations_args', {}).get('seed', np.random.randint(1000))).mutate, write_to_subsystem=True, algorithm=config_file.get('mutations_args', {}).get('algorithm', 'random')), diff --git a/synbio_morpher/srv/io/loaders/circuit_loader.py b/synbio_morpher/srv/io/loaders/circuit_loader.py index f9a58e89..e7245762 100644 --- a/synbio_morpher/srv/io/loaders/circuit_loader.py +++ b/synbio_morpher/srv/io/loaders/circuit_loader.py @@ -5,9 +5,9 @@ from synbio_morpher.srv.parameter_prediction.interactions import INTERACTION_FIELDS_TO_WRITE -from synbio_morpher.utils.common.setup import construct_circuit_from_cfg, prepare_config, expand_model_config, compose_kwargs +from synbio_morpher.utils.common.setup import construct_circuit_from_cfg from synbio_morpher.utils.data.data_format_tools.common import FORMAT_EXTS, load_json_as_dict -from synbio_morpher.utils.evolution.mutation import implement_mutation +from synbio_morpher.utils.evolution.evolver import implement_mutation from synbio_morpher.utils.misc.type_handling import inverse_dict, flatten_nested_dict from synbio_morpher.utils.evolution.evolver import load_mutations diff --git a/synbio_morpher/srv/io/loaders/data_loader.py b/synbio_morpher/srv/io/loaders/data_loader.py index 853b0c3c..1a172b0e 100644 --- a/synbio_morpher/srv/io/loaders/data_loader.py +++ b/synbio_morpher/srv/io/loaders/data_loader.py @@ -52,5 +52,5 @@ class GeneCircuitLoader(DataLoader): def __init__(self) -> None: super().__init__() - def load_data(self, filepath: str, identities: dict = {}): + def load_data(self, filepath: str, identities: dict = {}) -> Data: return Data(super().load_data(filepath=filepath), source_files=filepath, identities=identities) diff --git a/synbio_morpher/srv/io/manage/data_manager.py b/synbio_morpher/srv/io/manage/data_manager.py index a3e07f30..1111adb0 100644 --- a/synbio_morpher/srv/io/manage/data_manager.py +++ b/synbio_morpher/srv/io/manage/data_manager.py @@ -8,17 +8,17 @@ import logging +from typing import Optional from synbio_morpher.srv.io.loaders.data_loader import GeneCircuitLoader from synbio_morpher.utils.data.common import Data class DataManager(): - def __init__(self, filepath: str = None, identities: dict = None, data: dict = None): + def __init__(self, identities: dict, filepath: Optional[str] = None, data: Optional[dict] = None): self.loader = GeneCircuitLoader() self.source = filepath - self.data = data if data is None and filepath: - self.data = self.loader.load_data(filepath, identities=identities) + self.data: Data = self.loader.load_data(filepath, identities=identities) elif data is None and filepath is None: self.data = Data(loaded_data=data, identities=identities) # logging.warning( diff --git a/synbio_morpher/utils/circuit/agnostic_circuits/circuit.py b/synbio_morpher/utils/circuit/agnostic_circuits/circuit.py index 3e9441d6..d3a92b9c 100644 --- a/synbio_morpher/utils/circuit/agnostic_circuits/circuit.py +++ b/synbio_morpher/utils/circuit/agnostic_circuits/circuit.py @@ -66,7 +66,7 @@ def init_refcircuit(self, config: dict): assert self.interactions_state != 'uninitialised', f'The interactions should have been initialised from {config.get("interactions")}' self.signal: Signal # = None self.mutations_args: dict = config.get('mutations_args', {}) - self.mutations: Dict[str, Mutations] = {} + self.mutations: Dict[str, Dict[str, Mutations]] = {} self.update_species_simulated_rates(self.interactions) @@ -141,17 +141,17 @@ def update_species_simulated_rates(self, interactions: MolecularInteractions): self.qreactions.reactions = self.qreactions.init_reactions( self.model) - @property - def signal(self): - return self._signal + # @property + # def signal(self): + # return self._signal - @signal.getter - def signal(self): - # if self._signal is None: - # logging.warning( - # f'Trying to retrieve None signal from circuit. Make sure signal specified in circuit config') - return self._signal + # @signal.getter + # def signal(self): + # # if self._signal is None: + # # logging.warning( + # # f'Trying to retrieve None signal from circuit. Make sure signal specified in circuit config') + # return self._signal - @signal.setter - def signal(self, value): - self._signal = value + # @signal.setter + # def signal(self, value): + # self._signal = value diff --git a/synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py b/synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py index 8884dbc5..2052eef4 100644 --- a/synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py +++ b/synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py @@ -20,6 +20,7 @@ import diffrax as dfx import numpy as np import jax +import jax.numpy as jnp # jax.config.update('jax_platform_name', 'cpu') from scipy import integrate @@ -28,7 +29,7 @@ from bioreaction.simulation.manager import simulate_steady_states from synbio_morpher.srv.parameter_prediction.simulator import SIMULATOR_UNITS, make_piecewise_stepcontrol -from synbio_morpher.srv.parameter_prediction.interactions import InteractionDataHandler, InteractionSimulator, INTERACTION_FIELDS_TO_WRITE +from synbio_morpher.srv.parameter_prediction.interactions import InteractionSimulator, INTERACTION_FIELDS_TO_WRITE, MolecularInteractions from synbio_morpher.utils.circuit.agnostic_circuits.circuit import Circuit, interactions_to_df from synbio_morpher.utils.misc.helper import vanilla_return from synbio_morpher.utils.misc.numerical import invert_onehot, zero_out_negs @@ -222,7 +223,7 @@ def compute_interactions_batch(self, circuits: List[Circuit], batch=True): circuits[i] = self.compute_interactions(c) return circuits - def run_interaction_simulator(self, species: List[Species], quantities, filename=None) -> InteractionDataHandler: + def run_interaction_simulator(self, species: List[Species], quantities, filename=None) -> MolecularInteractions: data = {s: s.sequence for s in species} # if filename is not None: # return self.interaction_simulator.run((filename, data), compute_by_filename=True) @@ -251,7 +252,7 @@ def find_steady_states(self, circuits: List[Circuit], batch=True) -> List[Circui def compute_steady_states(self, circuits: List[Circuit], solver_type: str = 'jax', use_zero_rates: bool = False) -> Tuple[np.ndarray, np.ndarray]: - + t = np.zeros(0) if solver_type == 'ivp': b_copynumbers = [] for circuit in circuits: @@ -262,8 +263,8 @@ def compute_steady_states(self, circuits: List[Circuit], ((circuit.qreactions.reactions.reverse_rates - circuit.qreactions.reactions.forward_rates) < 1e2) * 1 - signal_onehot = np.zeros_like( - circuit.signal.reactions_onehot) if circuit.use_prod_and_deg else np.zeros_like(circuit.signal.onehot) + signal_onehot = jnp.zeros_like( + circuit.signal.reactions_onehot) if circuit.use_prod_and_deg else jnp.zeros_like(circuit.signal.onehot) steady_state_result = integrate.solve_ivp( partial(bioreaction_sim, args=None, reactions=r, signal=vanilla_return, signal_onehot=signal_onehot), @@ -319,7 +320,7 @@ def compute_steady_states(self, circuits: List[Circuit], b_copynumbers = np.swapaxes(b_copynumbers, 1, 2) - elif solver_type == 'torchode': + elif solver_type in ['torchode', 'torchdiffeq']: raise NotImplementedError() # import torchode as tode # import torch @@ -347,9 +348,8 @@ def compute_steady_states(self, circuits: List[Circuit], # self.solver = torch.compile(solver) # sol = self.solver.solve(prob) - elif solver_type == 'torchdiffeq': - t_eval = np.linspace(self.t0, self.t1, 100) - raise NotImplementedError() + # elif solver_type == 'torchdiffeq': + # t_eval = np.linspace(self.t0, self.t1, 100) # tdeq.odeint_adjoint(sim_func, starting_states, t) else: @@ -358,21 +358,21 @@ def compute_steady_states(self, circuits: List[Circuit], return np.asarray(b_copynumbers), np.squeeze(t) - def model_circuit(self, y0: np.ndarray, circuit: Circuit): - assert np.shape(y0)[circuit.time_axis] == 1, 'Please only use 1-d ' \ - f'initial copynumbers instead of {np.shape(y0)}' + # def model_circuit(self, y0: np.ndarray, circuit: Circuit): + # assert np.shape(y0)[circuit.time_axis] == 1, 'Please only use 1-d ' \ + # f'initial copynumbers instead of {np.shape(y0)}' - modelling_func = partial(bioreaction_sim, args=None, - reactions=circuit.qreactions.reactions, - signal=circuit.signal.func, - signal_onehot=circuit.signal.reactions_onehot) + # modelling_func = partial(bioreaction_sim, args=None, + # reactions=circuit.qreactions.reactions, + # signal=circuit.signal.func, + # signal_onehot=jnp.array(circuit.signal.reactions_onehot)) - copynumbers = self.iterate_modelling_func(y0, modelling_func, - max_time=self.t1, - time_interval=self.dt0, - signal_f=circuit.signal.func, - signal_onehot=circuit.signal.onehot) - return copynumbers + # copynumbers = self.iterate_modelling_func(y0, modelling_func, + # max_time=self.t1, + # time_interval=self.dt0, + # signal_f=circuit.signal.func, + # signal_onehot=circuit.signal.onehot) + # return copynumbers def iterate_modelling_func(self, init_copynumbers, modelling_func, max_time, @@ -430,7 +430,7 @@ def simulate_signal_batch(self, circuits: List[Circuit], def prepare_batch_params(circuits: List[Circuit]): b_steady_states = [None] * len(circuits) - b_reverse_rates = [None] * len(circuits) + b_reverse_rates = np.zeros((len(circuits), *circuits[0].qreactions.reactions.reverse_rates.shape)) species_chosen = circuits[0].model.species[np.argmax( signal.onehot)] @@ -565,14 +565,14 @@ def make_subcircuit(self, circuit: Circuit, mutation_name: str, mutation: Option subcircuit.reset_to_initial_state() subcircuit.strip_to_core() if mutation is None: - mutation = circuit.mutations.get(mutation_name) + mutation = circuit.mutations[mutation_name] subcircuit.subname = mutation_name subcircuit = implement_mutation(circuit=subcircuit, mutation=mutation) return subcircuit def load_mutations(self, circuit: Circuit): - subcircuits = [Circuit(config=None, as_mutation=True) + subcircuits = [Circuit(config={}, as_mutation=True) for m in flatten_nested_dict(circuit.mutations)] for i, (m_name, m) in enumerate(flatten_nested_dict(circuit.mutations).items()): if not m: @@ -608,13 +608,12 @@ def wrap_mutations(self, circuit: Circuit, methods: dict, include_normal_run=Tru # logging.info(f'Running methods on mutation {name} ({i})') if include_normal_run and i == 0: self.result_writer.unsubdivide_last_dir() - circuit = self.apply_to_circuit( - circuit, methods, ref_circuit=circuit) + circuit = self.apply_to_circuit(circuit, methods) self.result_writer.subdivide_writing( 'mutations', safe_dir_change=False) subcircuit = self.make_subcircuit(circuit, name, mutation) self.result_writer.subdivide_writing(name, safe_dir_change=False) - self.apply_to_circuit(subcircuit, methods, ref_circuit=circuit) + self.apply_to_circuit(subcircuit, methods) self.result_writer.unsubdivide_last_dir() self.result_writer.unsubdivide() return circuit diff --git a/synbio_morpher/utils/common/setup.py b/synbio_morpher/utils/common/setup.py index b868cac9..3c31549c 100644 --- a/synbio_morpher/utils/common/setup.py +++ b/synbio_morpher/utils/common/setup.py @@ -7,7 +7,7 @@ import logging from copy import deepcopy -from typing import List, Union +from typing import List, Union, Optional from bioreaction.model.data_containers import BasicModel from synbio_morpher.srv.io.manage.sys_interface import make_filename_safely from synbio_morpher.srv.io.manage.data_manager import DataManager @@ -32,7 +32,7 @@ ] -def expand_model_config(in_config: dict, out_config: dict, sample_names: List[str]) -> dict: +def expand_model_config(in_config: dict, out_config: dict, sample_names: Union[List[str], dict]) -> dict: if 'starting_concentration' not in out_config.keys(): out_config['starting_concentration'] = {} for s in sample_names: @@ -58,7 +58,7 @@ def process_molecular_params(params: dict, factor=1) -> dict: return params -def compose_kwargs(prev_configs: dict = None, config: dict = None) -> dict: +def compose_kwargs(config: dict, prev_configs: Optional[dict] = None) -> dict: """ Extra configs like data paths can be supplied here, eg. for circuits that were dynamically generated. """ if prev_configs is not None: @@ -119,7 +119,7 @@ def prepare_config(config_filepath: Union[str, dict, None] = None, config_file: config_file = add_empty_fields(config_file) return config_file -def construct_circuit_from_cfg(prev_configs: dict, config_file: dict): +def construct_circuit_from_cfg(prev_configs: Optional[dict], config_file: dict): kwargs = compose_kwargs(prev_configs=prev_configs, config=config_file) circuit = instantiate_system(kwargs) if kwargs.get("signal"): diff --git a/synbio_morpher/utils/data/common.py b/synbio_morpher/utils/data/common.py index 29b5f211..4275b318 100644 --- a/synbio_morpher/utils/data/common.py +++ b/synbio_morpher/utils/data/common.py @@ -3,18 +3,18 @@ # All rights reserved. # This source code is licensed under the MIT-style license found in the -# LICENSE file in the root directory of this source tree. - +# LICENSE file in the root directory of this source tree. import logging +from typing import Optional, List import pandas as pd class Data(): """ Holds things like FASTA files or other genetic info files. """ - def __init__(self, loaded_data: dict, identities: dict = {}, source_files=None) -> None: + def __init__(self, loaded_data: Optional[dict], identities: dict = {}, source_files: Optional[str] = None) -> None: self.data = loaded_data if loaded_data is not None else {} self.source = source_files self.sample_names = self.make_sample_names() @@ -39,12 +39,12 @@ def convert_names_to_idxs(names_table: dict, source: list) -> dict: f'Identities not found: {names_table.values()} not in {source}') return indexed_identities - def make_sample_names(self, sample_names: list = None) -> list: + def make_sample_names(self, sample_names: Optional[List[str]] = None) -> List[str]: if type(self.data) == dict: return list(self.data.keys()) elif type(self.data) == pd.DataFrame: return list(self.data.columns) - elif self.data is None: + elif self.data is None and (sample_names is not None): return sample_names raise ValueError(f'Unrecognised loaded data type {type(self.data)}.') @@ -71,7 +71,7 @@ def sample_names(self): return self._sample_names @sample_names.getter - def sample_names(self): + def sample_names(self) -> List[str]: return self.make_sample_names() # if type(self.data) == dict: # return list(self.data.keys()) diff --git a/synbio_morpher/utils/data/data_format_tools/common.py b/synbio_morpher/utils/data/data_format_tools/common.py index d173906f..533b702b 100644 --- a/synbio_morpher/utils/data/data_format_tools/common.py +++ b/synbio_morpher/utils/data/data_format_tools/common.py @@ -60,7 +60,7 @@ def get_filename_from_within_package(json_pathname: str) -> str: return full_json_pathname -def load_json_as_dict(json_pathname: Union[str, dict], process=True) -> dict: +def load_json_as_dict(json_pathname: Union[str, dict, None], process=True) -> dict: if not json_pathname: return {} elif type(json_pathname) == dict: diff --git a/synbio_morpher/utils/evolution/evolver.py b/synbio_morpher/utils/evolution/evolver.py index c19cb479..5194a16d 100644 --- a/synbio_morpher/utils/evolution/evolver.py +++ b/synbio_morpher/utils/evolution/evolver.py @@ -8,14 +8,13 @@ from functools import partial import logging import os -from typing import Tuple, List +from typing import Tuple, List, Union import numpy as np from bioreaction.model.data_containers import Species from synbio_morpher.srv.io.loaders.misc import load_csv from synbio_morpher.utils.evolution.mutation import get_mutation_type_mapping, Mutations, EXCLUDED_NUCS -from synbio_morpher.utils.misc.type_handling import flatten_listlike, flatten_nested_dict +from synbio_morpher.utils.misc.type_handling import flatten_listlike from synbio_morpher.utils.results.writer import DataWriter, kwargs_from_table -from synbio_morpher.utils.misc.string_handling import add_outtype from synbio_morpher.utils.circuit.agnostic_circuits.circuit import Circuit @@ -26,9 +25,9 @@ class Evolver(): def __init__(self, data_writer: DataWriter, mutation_type: str = 'random', - sequence_type: str = None, - seed: int = None, - concurrent_species_to_mutate: str = '') -> None: + sequence_type: str = 'RNA', + seed: int = 0, + concurrent_species_to_mutate: Union[str, list] = '') -> None: self.data_writer = data_writer self.mutation_type = mutation_type # Not implemented self.out_name = 'mutations' @@ -258,9 +257,9 @@ def load_mutations(circuit, filename=None): def apply_mutation_to_sequence(sequence: str, mutation_positions: List[int], mutation_types: List[str]) -> List[str]: - sequence = np.array([*sequence]) - sequence[mutation_positions] = mutation_types - return ''.join(sequence) + sequence_arr = np.array([*sequence]) + sequence_arr[mutation_positions] = mutation_types + return ''.join(sequence_arr) def implement_mutation(circuit: Circuit, mutation: Mutations): diff --git a/synbio_morpher/utils/misc/io.py b/synbio_morpher/utils/misc/io.py index 4bb2223e..7c65fce3 100644 --- a/synbio_morpher/utils/misc/io.py +++ b/synbio_morpher/utils/misc/io.py @@ -5,7 +5,7 @@ # This source code is licensed under the MIT-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Union +from typing import List, Union, Optional import glob import os from synbio_morpher.utils.misc.string_handling import remove_file_extension @@ -15,7 +15,7 @@ from synbio_morpher.utils.data.data_format_tools.common import load_multiple_as_list -def isolate_filename(filepath: str): +def isolate_filename(filepath: Optional[str]): if type(filepath) == str: return os.path.splitext(os.path.basename(filepath))[0] return None @@ -33,7 +33,7 @@ def get_pathnames_from_mult_dirs(search_dirs: List[str], **get_pathnames_kwargs) def get_pathnames(search_dir: str, file_key: Union[List, str] = '', first_only: bool = False, allow_empty: bool = False, subdir: str = '', - subdirs: list = None, + subdirs: Optional[list] = None, conditional: Union[str, None] = 'filenames', as_dict=False) -> Union[dict, list]: """ Get the pathnames in a folder given a keyword. @@ -70,7 +70,7 @@ def get_pathnames(search_dir: str, file_key: Union[List, str] = '', first_only: elif not file_key: path_names = sorted([os.path.join(search_dir, f) for f in os.listdir( search_dir) if path_condition_f(os.path.join(search_dir, f))]) - else: + elif type(file_key) == str: path_names = sorted([f for f in glob.glob(os.path.join( search_dir, '*' + file_key + '*')) if path_condition_f(f)]) if first_only and path_names: