Skip to content

Commit

Permalink
movidy adaptability
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 24, 2024
1 parent 3ed75b2 commit 1b400dc
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 75 deletions.
2 changes: 1 addition & 1 deletion docker_unix/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ biopython
bokeh
-e git+https://github.com/Steel-Lab-Oxford/core-bioreaction-simulation.git@47391ff32aa2e0e9dcbb4541efb526ff6c43e427#egg=bioreaction
celluloid
chex>=0.1.6
chex
fire
jax==0.4.29
# -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down
245 changes: 199 additions & 46 deletions notebooks/22_adaptation_autograd.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"name": "IntaRNA",
"postprocess": true
},
"interaction_file_keyword": ["energies", "eqconstants", "binding_rates_dissociation"],
"interaction_file_keyword": ["energies", "eqconstants", "binding_rates_dissociation", "binding_sites"],
"molecular_params": "./synbio_morpher/utils/common/configs/RNA_circuit/molecular_params.json",
"experiment": {
"purpose": "gather_interaction_stats"
Expand Down Expand Up @@ -96,7 +96,8 @@
"threshold_steady_states": 0.05,
"use_rate_scaling": true,
"method": "Dopri5",
"stepsize_controller": "adaptive"
"stepsize_controller": "adaptive",
"use_initial_to_add_signal": true
},
"simulation_steady_state": {
"method": "Dopri5",
Expand Down
2 changes: 2 additions & 0 deletions synbio_morpher/scripts/vis_6_scatter/run_vis_6_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def get_selection(m):
if selection_conditions:
data_selected = select_rows_by_conditional_cols(
data, selection_conditions)
hue = 'adaptation' if (
~data_selected['adaptation'].isna()).sum() == int(0.5 * len(data_selected)) else 'overshoot'

if data_selected.empty:
continue
Expand Down
44 changes: 20 additions & 24 deletions synbio_morpher/utils/circuit/agnostic_circuits/circuit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,40 +430,36 @@ def simulate_signal_batch(self, circuits: List[Circuit],

def prepare_batch_params(circuits: List[Circuit]):

b_steady_states = [None] * len(circuits)
b_steady_states = np.zeros((len(circuits), len(circuits[0].model.species)))
b_og_states = np.array([c.result_collector.get_result(
'steady_states').analytics['steady_states'].flatten(
) for i, c in enumerate(circuits)])
b_reverse_rates = np.zeros(
(len(circuits), *circuits[0].qreactions.reactions.reverse_rates.shape))

species_chosen = circuits[0].model.species[np.argmax(
signal.onehot)]
other_species = flatten_listlike(
[r.output for r in circuits[0].model.reactions if species_chosen in r.input])
onehots = np.array([1 if s in other_species + [species_chosen]
else 0 for s in circuits[0].model.species])
add_sig_to_all_sigspecies = False
onehots = signal.onehot
if add_sig_to_all_sigspecies and (not self.use_initial_to_add_signal):
species_chosen = circuits[0].model.species[np.argmax(
signal.onehot)]
other_species = flatten_listlike(
[r.output for r in circuits[0].model.reactions if species_chosen in r.input])
onehots_all_sigspecies = np.array([1 if s in other_species + [species_chosen]
else 0 for s in circuits[0].model.species])
onehots = onehots_all_sigspecies
for i, c in enumerate(circuits):
analytics_stst = c.result_collector.get_result(
'steady_states').analytics
if analytics_stst is None:
raise ValueError(f'Could not find analytics for steady state result.')
if not c.use_prod_and_deg:
stst = analytics_stst['steady_states'].flatten()
if self.use_initial_to_add_signal:
inst = analytics_stst['initial_steady_states'].flatten()
b_steady_states[i] = stst * ((signal.onehot == 0) * 1) + \
(inst *
signal.func.keywords['target']) * signal.onehot
else:
b_steady_states[i] = stst * ((onehots == 0) * 1) + \
(stst * signal.func.keywords['target']) * onehots

stst_key = 'initial_steady_states' if self.use_initial_to_add_signal else 'steady_states'
stst = analytics_stst[stst_key].flatten()
b_steady_states[i] = stst * ((onehots == 0) * 1) + \
(stst * signal.func.keywords['target']) * onehots
else:
b_steady_states[i] = analytics_stst['steady_states'].flatten()
b_reverse_rates[i] = c.qreactions.reactions.reverse_rates
b_steady_states = np.asarray(b_steady_states)
b_reverse_rates = np.asarray(b_reverse_rates)
# b_og_states = np.array([analytics_stst['steady_states'].flatten(
# ) * onehots + b_steady_states[i] * ((onehots == 0) * 1) for i, c in enumerate(circuits)])
b_og_states = b_steady_states * onehots + b_steady_states * ((onehots == 0) * 1)

return b_steady_states, b_reverse_rates, b_og_states

Expand Down Expand Up @@ -527,8 +523,8 @@ def prepare_batch_params(circuits: List[Circuit]):
analytics_func = jax.vmap(partial(
generate_analytics, time=t, labels=[
s.name for s in ref_circuit.model.species],
signal_onehot=signal.onehot, signal_time=signal_time,
ref_circuit_data=ref_circuit_data))
signal_onehot=signal.onehot, signal_time=signal_time, # type: ignore
ref_circuit_data=ref_circuit_data)) # type: ignore
b_analytics = analytics_func(
data=b_new_copynumbers[ref_idx:ref_idx2])
b_analytics_l = append_nest_dicts(
Expand Down
8 changes: 6 additions & 2 deletions synbio_morpher/utils/results/analytics/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,12 @@ def compute_sensitivity_simple(starting_states, peaks, signal_factor):
def calculate_adaptation(s, p):
""" Adaptation = robustness to noise
s = sensitivity, p = precision """
return np.log(log_distance(s=s, p=p) * np.log(sp_prod(
s=s, p=p, sp_factor=(p / s).max(), s_weight=(np.log(p) / s))))
return sp_prod(s, p)
# return jnp.log(log_distance(s=s, p=p)) * sp_prod(s, p)
# return jnp.log(log_distance(s=s, p=p)) * sp_prod(
# s=s, p=p, sp_factor=(p / s).max(), s_weight=1)
# return jnp.log(log_distance(s=s, p=p) * jnp.log(sp_prod(
# s=s, p=p, sp_factor=(p / s).max(), s_weight=(jnp.log(p) / s))))


def compute_rmse(data: np.ndarray, ref_circuit_data: Optional[np.ndarray]):
Expand Down

0 comments on commit 1b400dc

Please sign in to comment.