diff --git a/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..b2439a99 Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..107362b9 Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz new file mode 100644 index 00000000..04466f84 Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..c2df947e Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol2/SAMPLmol2_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..576e861b Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..dd643650 Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz new file mode 100644 index 00000000..aa3332ed Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz differ diff --git a/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..d64432ed Binary files /dev/null and b/data/test_data/vacuum/SAMPLmol4/SAMPLmol4_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..719fc133 Binary files /dev/null and b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_100_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_10_for_CompartimentedAlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_10_for_CompartimentedAlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..9cc16aab Binary files /dev/null and b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_10_for_CompartimentedAlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..8160d56b Binary files /dev/null and b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_20_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz new file mode 100644 index 00000000..7aba6a2a Binary files /dev/null and b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_20_for_CompartimentedAlchemicalANI2x_restraint_False.gz differ diff --git a/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz new file mode 100644 index 00000000..2f7588ab Binary files /dev/null and b/data/test_data/vacuum/molDWRow_298/molDWRow_298_FEC_50_for_AlchemicalANI1ccx_restraint_False.gz differ diff --git a/neutromeratio/parameter_gradients.py b/neutromeratio/parameter_gradients.py index e4cbc04d..4ee4f843 100644 --- a/neutromeratio/parameter_gradients.py +++ b/neutromeratio/parameter_gradients.py @@ -27,6 +27,17 @@ logger = logging.getLogger(__name__) +class FreeEnergy(NamedTuple): + """Returned by + + free_energy_estimate: tensor in kT + free_energy_estimate_error: float in kT + """ + + free_energy_estimate: torch.Tensor # in kT + free_energy_estimate_error: float # in kT + + class FreeEnergyCalculator: def __init__( self, @@ -243,7 +254,7 @@ def _end_state_free_energy_difference(self) -> np.ndarray: results = self.mbar.getFreeEnergyDifferences(return_dict=True) return results["Delta_f"][0, -1], results["dDelta_f"][0, -1] - def _compute_perturbed_free_energies(self, u_ln) -> torch.Tensor: + def _compute_perturbed_free_energies(self, u_ln) -> (torch.Tensor, float): """compute perturbed free energies at new thermodynamic states l""" assert type(u_ln) == torch.Tensor @@ -258,13 +269,15 @@ def _compute_perturbed_free_energies(self, u_ln) -> torch.Tensor: self.mbar.u_kn, dtype=torch.double, requires_grad=False, device=device ) + _, dG_error = self.mbar.computePerturbedFreeEnergies(u_ln.detach().numpy()) + # importance weighting log_q_k = f_k - u_kn.T A = log_q_k + torch.log(N_k) log_denominator_n = torch.logsumexp(A, dim=1) B = -u_ln - log_denominator_n - return -torch.logsumexp(B, dim=1) + return -torch.logsumexp(B, dim=1), dG_error[0][-1] def _form_u_ln(self, original_neural_network: bool = False) -> torch.Tensor: @@ -292,12 +305,12 @@ def _form_u_ln(self, original_neural_network: bool = False) -> torch.Tensor: u_ln = torch.stack([u_0, u_1]) return u_ln - def _compute_free_energy_difference(self) -> torch.Tensor: + def _compute_free_energy_difference(self) -> (torch.Tensor, float): u_ln = self._form_u_ln() - f_k = self._compute_perturbed_free_energies(u_ln) + f_k, dG_error = self._compute_perturbed_free_energies(u_ln) # keep u_ln in memory self.u_ln_rho_star_wrt_parameters = u_ln - return f_k[1] - f_k[0] + return f_k[1] - f_k[0], dG_error def get_u_ln_for_rho_and_rho_star(self) -> Tuple[(torch.Tensor, torch.Tensor)]: u_ln_rho = torch.stack( @@ -357,7 +370,7 @@ def mae_between_potentials_for_snapshots( if env == "vacuum": scale_with = 100 elif env == "droplet": - scale_with = 400 + scale_with = 100 else: raise RuntimeError( "If normalized, the environment needs to be specified" @@ -395,7 +408,9 @@ def torchify(x): return torch.tensor(x, dtype=torch.double, requires_grad=True, device=device) -def get_perturbed_free_energy_difference(fec: FreeEnergyCalculator) -> torch.Tensor: +def get_perturbed_free_energy_difference( + fec: FreeEnergyCalculator, +) -> FreeEnergy: """ Gets a list of fec instances and returns a torch.tensor with the computed free energy differences. @@ -407,11 +422,12 @@ def get_perturbed_free_energy_difference(fec: FreeEnergyCalculator) -> torch.Ten torch.tensor -- calculated free energy in kT """ if fec.flipped: - deltaF = fec._compute_free_energy_difference() * -1.0 + deltaF, dG_error = fec._compute_free_energy_difference() + deltaF *= -1 else: - deltaF = fec._compute_free_energy_difference() + deltaF, dG_error = fec._compute_free_energy_difference() - return deltaF + return FreeEnergy(deltaF, dG_error) def get_unperturbed_free_energy_difference(fec: FreeEnergyCalculator): @@ -431,7 +447,7 @@ def get_unperturbed_free_energy_difference(fec: FreeEnergyCalculator): else: deltaF = fec._end_state_free_energy_difference[0] - return torchify(deltaF) + return FreeEnergy(torchify(deltaF), fec._end_state_free_energy_difference[1]) def get_experimental_values(name: str) -> torch.Tensor: @@ -479,30 +495,6 @@ def init(): logging.basicConfig(level=logging.INFO) -# def _mp(n_proc: int, prop_list: list) -> List[FreeEnergyCalculator]: - -# with get_context("forkserver").Pool(processes=n_proc, initializer=init) as pool: -# pool_result = pool.map_async(_setup_FEC, prop_list) -# pool_result.wait(timeout=400) # should take around 120s -# try: -# if pool_result.ready(): -# FEC_list = pool_result.get(timeout=10) -# pool.close() # no more tasks -# pool.join() # wrap up current tasks -# else: -# pool.terminate() # otherwise shared memory is not released -# pool.join() # wrap up current tasks -# raise RuntimeError("Took too long ...") - -# except Exception: -# print("failing gracefully ...") -# pool.terminate() # otherwise shared memory is not released -# pool.join() # wrap up current tasks -# raise - -# return FEC_list - - def calculate_rmse_between_exp_and_calc( names: list, model: ANI, @@ -580,9 +572,17 @@ def calculate_rmse_between_exp_and_calc( for fec in FEC_list: # append calculated values if perturbed_free_energy: - e_calc.append(get_perturbed_free_energy_difference(fec).item()) + e_calc.append( + get_perturbed_free_energy_difference( + fec + ).free_energy_estimate.item() + ) else: - e_calc.append(get_unperturbed_free_energy_difference(fec).item()) + e_calc.append( + get_unperturbed_free_energy_difference( + fec + ).free_energy_estimate.item() + ) # append experimental values e_exp.append(get_experimental_values(fec.name).item()) @@ -696,6 +696,7 @@ def setup_and_perform_parameter_retraining_with_test_set_split( test_size: float = 0.2, validation_size: float = 0.2, snapshot_penalty_f: PenaltyFunction = PenaltyFunction(0, 0, 0, 0, False), + skipping: float = 1, ) -> Tuple[list, Number]: """ @@ -828,6 +829,7 @@ def setup_and_perform_parameter_retraining_with_test_set_split( lr_SGD=lr_SGD, weight_decay=weight_decay, snapshot_penalty_f=snapshot_penalty_f, + skipping=skipping, ) # final rmsd calculation on test set @@ -892,6 +894,7 @@ def _perform_training( lr_SGD: float, weight_decay: float, snapshot_penalty_f: PenaltyFunction, + skipping: float, ) -> list: early_stopping_learning_rate = 1.0e-8 @@ -930,7 +933,7 @@ def _perform_training( f"{base}_{0}.pt", ) - current_rmse = -1 + current_rmse = rmse_validation[-1] ## training loop for idx in range(AdamW_scheduler.last_epoch + 1, max_epochs): @@ -967,8 +970,8 @@ def _perform_training( snapshot_penalty_f=snapshot_penalty_f, ) - if idx > 1 and idx % 2: - #skip every second validation set calculation + if idx > 1 and idx % skipping == 0: + # skip every second validation set calculation with torch.no_grad(): # calculate the new free energies on the validation set with optimized parameters current_rmse, _ = calculate_rmse_between_exp_and_calc( @@ -1191,7 +1194,9 @@ def _loss_function( """ snapshot_penalty = torch.tensor([0.0]) # calculate the free energies - calc_free_energy_difference = get_perturbed_free_energy_difference(fec) + calc_free_energy_difference = get_perturbed_free_energy_difference( + fec + ).free_energy_estimate # obtain the experimental free energies exp_free_energy_difference = get_experimental_values(fec.name) # calculate the loss as MSE @@ -1392,6 +1397,7 @@ def setup_and_perform_parameter_retraining( lr_AdamW: float = 1e-3, lr_SGD: float = 1e-3, weight_decay: float = 0.000001, + skipping: float = 1, ) -> list: """ Much of this code is taken from: @@ -1486,6 +1492,7 @@ def setup_and_perform_parameter_retraining( lr_SGD=lr_SGD, weight_decay=weight_decay, snapshot_penalty_f=snapshot_penalty_f, + skipping=skipping, ) return rmse_validation diff --git a/neutromeratio/restraints.py b/neutromeratio/restraints.py index 7d6182b9..14dcfffa 100644 --- a/neutromeratio/restraints.py +++ b/neutromeratio/restraints.py @@ -3,18 +3,14 @@ import numpy as np from openmmtools.constants import kB -from scipy.stats import norm from simtk import unit -from torch.distributions.normal import Normal from .constants import ( bond_length_dict, device, mass_dict_in_daltons, - nm_to_angstroms, temperature, water_hoh_angle, - radian_to_degree, ) logger = logging.getLogger(__name__) diff --git a/neutromeratio/tests/test_ANI.py b/neutromeratio/tests/test_ANI.py index b01b1d72..9dff41a0 100644 --- a/neutromeratio/tests/test_ANI.py +++ b/neutromeratio/tests/test_ANI.py @@ -73,11 +73,9 @@ def test_tochani_neutromeratio_sync(): def test_neutromeratio_energy_calculations_with_ANI_in_vacuum(): - from ..tautomers import Tautomer - from ..constants import kT from ..analysis import setup_alchemical_system_and_energy_function import numpy as np - from ..ani import AlchemicalANI1ccx, ANI1ccx, ANI2x + from ..ani import AlchemicalANI1ccx # read in exp_results.pickle with open("data/test_data/exp_results.pickle", "rb") as f: @@ -128,11 +126,8 @@ def test_neutromeratio_energy_calculations_with_ANI_in_vacuum(): def test_neutromeratio_energy_calculations_with_AlchemicalANI1ccx(): - from ..tautomers import Tautomer - import numpy as np - from ..constants import kT from ..analysis import setup_alchemical_system_and_energy_function - from ..ani import AlchemicalANI1ccx, AlchemicalANI2x, ANI1ccx + from ..ani import AlchemicalANI1ccx # read in exp_results.pickle with open("data/test_data/exp_results.pickle", "rb") as f: @@ -175,10 +170,7 @@ def test_neutromeratio_energy_calculations_with_AlchemicalANI1ccx(): def test_neutromeratio_energy_calculations_with_ANI_in_droplet(): - from ..tautomers import Tautomer - from ..constants import kT from ..analysis import setup_alchemical_system_and_energy_function - import numpy as np from ..ani import AlchemicalANI1ccx # read in exp_results.pickle diff --git a/neutromeratio/tests/test_AlchemicalANI.py b/neutromeratio/tests/test_AlchemicalANI.py index 70f3b62b..1e110e50 100644 --- a/neutromeratio/tests/test_AlchemicalANI.py +++ b/neutromeratio/tests/test_AlchemicalANI.py @@ -9,7 +9,6 @@ import pickle from simtk import unit from neutromeratio.constants import device -import torchani from openmmtools.utils import is_quantity_close from neutromeratio.constants import device from neutromeratio.utils import _get_traj diff --git a/neutromeratio/tests/test_neutromeratio_misc.py b/neutromeratio/tests/test_neutromeratio_misc.py index d0bd7346..705df580 100644 --- a/neutromeratio/tests/test_neutromeratio_misc.py +++ b/neutromeratio/tests/test_neutromeratio_misc.py @@ -272,19 +272,16 @@ def test_max_nr_of_snapshots(): print(l_rmse) assert l_rmse == [ - (5.745586395263672, [3.219421508256698, -3.984754521510265, 3.775032043152413]), ( - 5.393607139587402, - [1.2104192737678794, -5.316053509708738, 4.0559353503118505], + 5.745594501495361, + [3.219412789711892, -3.984756004408934, 3.7750394158567984], ), + (5.393613815307617, [1.21041203944986, -5.316059264760289, 4.055940344778948]), ( - 5.6913323402404785, - [0.6301678213061068, -5.464017066061666, 4.675917867000249], - ), - ( - 5.574014663696289, - [1.1649843159209015, -5.0964898153872005, 4.284151701496853], + 5.691339015960693, + [0.6301607712237898, -5.464022528356984, 4.675923042604108], ), + (5.57402229309082, [1.1649766159387198, -5.096495181564678, 4.284156815859456]), ] @@ -312,8 +309,8 @@ def test_unperturbed_perturbed_free_energy(): save_pickled_FEC=False, ) - a_AlchemicalANI2x = get_unperturbed_free_energy_difference(fec) - b_AlchemicalANI2x = get_perturbed_free_energy_difference(fec) + a_AlchemicalANI2x = get_unperturbed_free_energy_difference(fec).free_energy_estimate + b_AlchemicalANI2x = get_perturbed_free_energy_difference(fec).free_energy_estimate np.isclose(a_AlchemicalANI2x.item(), b_AlchemicalANI2x.item()) del fec @@ -330,8 +327,8 @@ def test_unperturbed_perturbed_free_energy(): save_pickled_FEC=False, ) - a_CompartimentedAlchemicalANI2x = get_unperturbed_free_energy_difference(fec) - b_CompartimentedAlchemicalANI2x = get_perturbed_free_energy_difference(fec) + a_CompartimentedAlchemicalANI2x = get_unperturbed_free_energy_difference(fec).free_energy_estimate + b_CompartimentedAlchemicalANI2x = get_perturbed_free_energy_difference(fec).free_energy_estimate np.isclose( a_CompartimentedAlchemicalANI2x.item(), b_CompartimentedAlchemicalANI2x.item() ) diff --git a/neutromeratio/tests/test_parameter_retraining.py b/neutromeratio/tests/test_parameter_retraining.py index 41487e2c..197791ed 100644 --- a/neutromeratio/tests/test_parameter_retraining.py +++ b/neutromeratio/tests/test_parameter_retraining.py @@ -54,7 +54,7 @@ def test_u_ln_50_snapshots(): env="vacuum", data_path="data/test_data/vacuum", ANImodel=model, - bulk_energy_calculation=False, + bulk_energy_calculation=True, max_snapshots_per_window=50, load_pickled_FEC=True, include_restraint_energy_contribution=False, @@ -71,55 +71,8 @@ def test_u_ln_50_snapshots(): f_per_atom = fec.mae_between_potentials_for_snapshots( normalized=True, env="droplet" ) - f_scaled_to_mol = (f_per_atom / 400) * len(fec.ani_model.species[0]) - assert np.isclose(f_per_molecule.item(), f_scaled_to_mol.item()) - - -@pytest.mark.skipif( - os.environ.get("TRAVIS", None) == "true", reason="Can't upload necessary files." -) -def test_u_ln_20_snapshots(): - from ..parameter_gradients import ( - setup_FEC, - ) - from ..ani import CompartimentedAlchemicalANI2x - from neutromeratio.constants import initialize_NUM_PROC - - initialize_NUM_PROC(1) - - # with pickled tautomer object - name = "molDWRow_298" - model, model_name = CompartimentedAlchemicalANI2x, "CompartimentedAlchemicalANI2x" - model._reset_parameters() - model_instance = model([0, 0]) - model_instance.load_nn_parameters( - parameter_path="data/test_data/AlchemicalANI2x_3.pt" - ) - # vacuum - fec = setup_FEC( - name, - env="vacuum", - data_path="data/test_data/vacuum", - ANImodel=model, - bulk_energy_calculation=False, - max_snapshots_per_window=20, - load_pickled_FEC=True, - include_restraint_energy_contribution=False, - ) - fec._compute_free_energy_difference() - # compare to manually scaling - f_per_molecule = fec.mae_between_potentials_for_snapshots(env="vacuum") - f_per_atom = fec.mae_between_potentials_for_snapshots(normalized=True, env="vacuum") f_scaled_to_mol = (f_per_atom / 100) * len(fec.ani_model.species[0]) assert np.isclose(f_per_molecule.item(), f_scaled_to_mol.item()) - # for droplet - # compare to manually scaling - f_per_molecule = fec.mae_between_potentials_for_snapshots(env="droplet") - f_per_atom = fec.mae_between_potentials_for_snapshots( - normalized=True, env="droplet" - ) - f_scaled_to_mol = (f_per_atom / 400) * len(fec.ani_model.species[0]) - assert np.isclose(f_per_molecule.item(), f_scaled_to_mol.item()) def test_scaling_factor(): @@ -357,7 +310,12 @@ def test_postprocessing_vacuum(): ] # get calc free energy - f = torch.stack([get_perturbed_free_energy_difference(fec) for fec in fec_list]) + f = torch.stack( + [ + get_perturbed_free_energy_difference(fec).free_energy_estimate + for fec in fec_list + ] + ) # get exp free energy e = torch.stack([get_experimental_values(name) for name in names]) assert len(f) == 3 @@ -441,7 +399,12 @@ def test_postprocessing_droplet(): for name in names ] # get calc free energy - f = torch.stack([get_perturbed_free_energy_difference(fec) for fec in fec_list]) + f = torch.stack( + [ + get_perturbed_free_energy_difference(fec).free_energy_estimate + for fec in fec_list + ] + ) # get exp free energy e = torch.stack([get_experimental_values(name) for name in names]) @@ -488,7 +451,12 @@ def test_postprocessing_droplet(): for name in names ] # get calc free energy - f = torch.stack([get_perturbed_free_energy_difference(fec) for fec in fec_list]) + f = torch.stack( + [ + get_perturbed_free_energy_difference(fec).free_energy_estimate + for fec in fec_list + ] + ) # get exp free energy e = torch.stack([get_experimental_values(name) for name in names]) @@ -536,7 +504,12 @@ def test_postprocessing_droplet(): ] # get calc free energy - f = torch.stack([get_perturbed_free_energy_difference(fec) for fec in fec_list]) + f = torch.stack( + [ + get_perturbed_free_energy_difference(fec).free_energy_estimate + for fec in fec_list + ] + ) # get exp free energy e = torch.stack([get_experimental_values(name) for name in names]) @@ -583,7 +556,12 @@ def test_postprocessing_droplet(): ] # get calc free energy - f = torch.stack([get_perturbed_free_energy_difference(fec) for fec in fec_list]) + f = torch.stack( + [ + get_perturbed_free_energy_difference(fec).free_energy_estimate + for fec in fec_list + ] + ) # get exp free energy e = torch.stack([get_experimental_values(name) for name in names]) @@ -627,7 +605,7 @@ def test_snapshot_energy_loss_with_CompartimentedAlchemicalANI2x(): ) from ..ani import CompartimentedAlchemicalANI2x - CompartimentedAlchemicalANI2x._reset_parameters() + # CompartimentedAlchemicalANI2x._reset_parameters() name = "molDWRow_298" model_instance = CompartimentedAlchemicalANI2x([0, 0]) env = "vacuum" @@ -636,17 +614,19 @@ def test_snapshot_energy_loss_with_CompartimentedAlchemicalANI2x(): fec = setup_FEC( name, env=env, - diameter=10, + diameter=-1, data_path="data/test_data/vacuum", ANImodel=CompartimentedAlchemicalANI2x, bulk_energy_calculation=True, max_snapshots_per_window=100, - load_pickled_FEC=True, + load_pickled_FEC=False, include_restraint_energy_contribution=False, - save_pickled_FEC=True, + save_pickled_FEC=False, ) assert np.isclose( - 7.981540, get_perturbed_free_energy_difference(fec).item(), rtol=1e-3 + 7.981540, + get_perturbed_free_energy_difference(fec).free_energy_estimate.item(), + rtol=1e-3, ) assert np.isclose(fec.rmse_between_potentials_for_snapshots().item(), 0.0) @@ -655,7 +635,9 @@ def test_snapshot_energy_loss_with_CompartimentedAlchemicalANI2x(): model_instance.load_nn_parameters(f"data/test_data/AlchemicalANI2x_3.pt") assert np.isclose( - -11.25832, get_perturbed_free_energy_difference(fec).item(), rtol=1e-3 + -11.25832, + get_perturbed_free_energy_difference(fec).free_energy_estimate.item(), + rtol=1e-3, ) assert np.isclose( fec.rmse_between_potentials_for_snapshots().item(), diff --git a/neutromeratio/tests/test_setup_functions.py b/neutromeratio/tests/test_setup_functions.py index 0d41d651..5c65c4f3 100644 --- a/neutromeratio/tests/test_setup_functions.py +++ b/neutromeratio/tests/test_setup_functions.py @@ -341,14 +341,10 @@ def test_setup_FEC_test_pickle_files(): def test_FEC_with_different_free_energy_calls(): from ..parameter_gradients import ( get_perturbed_free_energy_difference, - get_unperturbed_free_energy_difference, ) - from ..constants import kT, device, exclude_set_ANI, mols_with_charge from ..parameter_gradients import setup_FEC - from glob import glob from ..ani import ( AlchemicalANI1ccx, - AlchemicalANI1x, AlchemicalANI2x, CompartimentedAlchemicalANI2x, ) @@ -384,12 +380,12 @@ def test_FEC_with_different_free_energy_calls(): # look at the two functions from parameter_gradient # both of these correct for the flipped energy calculation # and flip the sign of the prediction of mol298 - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec_values) for e1, e2 in zip(fec_values, [1.2104, -5.3161]): assert np.isclose(e1.item(), e2, rtol=1e-2) - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec_values) for e1, e2 in zip(fec_values, [1.2104, -5.3161]): assert np.isclose(e1.item(), e2, rtol=1e-2) @@ -403,8 +399,8 @@ def test_FEC_with_different_free_energy_calls(): assert np.isclose(e1.item(), e2, rtol=1e-2) fec_values = ( - fec_list[0]._compute_free_energy_difference(), - fec_list[1]._compute_free_energy_difference(), + fec_list[0]._compute_free_energy_difference()[0], + fec_list[1]._compute_free_energy_difference()[0], ) print(fec_values) for e1, e2 in zip(fec_values, [-1.2104, -5.3161]): @@ -437,12 +433,12 @@ def test_FEC_with_different_free_energy_calls(): # look at the two functions from parameter_gradient # both of these correct for the flipped energy calculation # and flip the sign of the prediction of mol298 - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec_values) for e1, e2 in zip(fec_values, [8.7152, -9.2873]): assert np.isclose(e1.item(), e2, rtol=1e-2) - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec_values) for e1, e2 in zip(fec_values, [8.7152, -9.2873]): assert np.isclose(e1.item(), e2, rtol=1e-2) @@ -456,8 +452,8 @@ def test_FEC_with_different_free_energy_calls(): assert np.isclose(e1.item(), e2, rtol=1e-2) fec_values = ( - fec_list[0]._compute_free_energy_difference(), - fec_list[1]._compute_free_energy_difference(), + fec_list[0]._compute_free_energy_difference()[0], + fec_list[1]._compute_free_energy_difference()[0], ) print(fec_values) for e1, e2 in zip(fec_values, [-8.7152, -9.2873]): @@ -490,12 +486,12 @@ def test_FEC_with_different_free_energy_calls(): # look at the two functions from parameter_gradient # both of these correct for the flipped energy calculation # and flip the sign of the prediction of mol298 - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec_values) for e1, e2 in zip(fec_values, [8.7152, -9.2873]): assert np.isclose(e1.item(), e2, rtol=1e-2) - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec_values) for e1, e2 in zip(fec_values, [8.7152, -9.2873]): assert np.isclose(e1.item(), e2, rtol=1e-2) @@ -509,8 +505,8 @@ def test_FEC_with_different_free_energy_calls(): assert np.isclose(e1.item(), e2, rtol=1e-2) fec_values = ( - fec_list[0]._compute_free_energy_difference(), - fec_list[1]._compute_free_energy_difference(), + fec_list[0]._compute_free_energy_difference()[0], + fec_list[1]._compute_free_energy_difference()[0], ) print(fec_values) for e1, e2 in zip(fec_values, [-8.7152, -9.2873]): @@ -524,7 +520,6 @@ def test_FEC_with_different_free_energy_calls(): ) def test_FEC_with_perturbed_free_energies(): from ..parameter_gradients import get_perturbed_free_energy_difference - from ..constants import kT, device from ..parameter_gradients import setup_FEC from ..ani import ( AlchemicalANI1ccx, @@ -567,7 +562,8 @@ def test_FEC_with_perturbed_free_energies(): ] assert len(fec_list) == 2 - fec_values = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec_values = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] + print(fec_values) for e1, e2 in zip(fec_values, [1.2104, -5.3161]): assert np.isclose(e1.item(), e2, rtol=1e-4) @@ -589,7 +585,7 @@ def test_FEC_with_perturbed_free_energies(): ] assert len(fec_list) == 2 - fec = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec) for e1, e2 in zip(fec, [1.2104, -5.3161]): assert np.isclose(e1.item(), e2, rtol=1e-4) @@ -607,7 +603,7 @@ def test_FEC_with_perturbed_free_energies(): ) assert np.isclose( fec._end_state_free_energy_difference[0], - fec._compute_free_energy_difference().item(), + fec._compute_free_energy_difference()[0].item(), rtol=1e-5, ) @@ -630,7 +626,7 @@ def test_FEC_with_perturbed_free_energies(): ] assert len(fec_list) == 2 - fec = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec) for e1, e2 in zip(fec, [10.3192, -9.746403840249418]): assert np.isclose(e1.item(), e2, rtol=1e-4) @@ -654,7 +650,7 @@ def test_FEC_with_perturbed_free_energies(): ] assert len(fec_list) == 2 - fec = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec) for e1, e2 in zip(fec, [8.8213, -9.664895714083166]): assert np.isclose(e1.item(), e2, rtol=1e-4) @@ -678,14 +674,13 @@ def test_FEC_with_perturbed_free_energies(): ] assert len(fec_list) == 2 - fec = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec) for e1, e2 in zip(fec, [8.8213, -9.664895714083166]): assert np.isclose(e1.item(), e2, rtol=1e-4) del fec_list del model - @pytest.mark.skipif( os.environ.get("TRAVIS", None) == "true", reason="Slow calculation." ) @@ -719,7 +714,7 @@ def test_FEC_with_perturbed_free_energies_with_and_without_restraints(): ] assert len(fec_list) == 2 - fec = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec) for e1, e2 in zip(fec, [8.8213, -9.6649]): assert np.isclose(e1.item(), e2, rtol=1e-4) @@ -740,7 +735,7 @@ def test_FEC_with_perturbed_free_energies_with_and_without_restraints(): ] assert len(fec_list) == 2 - fec = [get_perturbed_free_energy_difference(fec) for fec in fec_list] + fec = [get_perturbed_free_energy_difference(fec).free_energy_estimate for fec in fec_list] print(fec) for e1, e2 in zip(fec, [8.8213, -9.6649]): assert np.isclose(e1.item(), e2, rtol=1e-4) @@ -784,7 +779,7 @@ def test_loading_saving_mbar_object_AlchemicalANI2x(): save_pickled_FEC=True, ) assert np.isclose( - 18.348107633661936, get_perturbed_free_energy_difference(fec).item() + 18.348107633661936, get_perturbed_free_energy_difference(fec).free_energy_estimate.item() ) # load checkpoint parameter file and override optimized parameters @@ -802,7 +797,7 @@ def test_loading_saving_mbar_object_AlchemicalANI2x(): # get new free energy assert np.isclose( - 3.2730393726044866, get_perturbed_free_energy_difference(fec).item() + 3.2730393726044866, get_perturbed_free_energy_difference(fec).free_energy_estimate.item() ) del fec @@ -821,7 +816,7 @@ def test_loading_saving_mbar_object_AlchemicalANI2x(): ) assert np.isclose( - 3.2730393726044866, get_perturbed_free_energy_difference(fec).item() + 3.2730393726044866, get_perturbed_free_energy_difference(fec).free_energy_estimate.item() ) pickled_model = fec.ani_model.model @@ -868,7 +863,7 @@ def test_loading_saving_mbar_object_CompartimentedAlchemicalANI2x(): save_pickled_FEC=True, ) assert np.isclose( - 18.348107633661936, get_perturbed_free_energy_difference(fec).item() + 18.348107633661936, get_perturbed_free_energy_difference(fec).free_energy_estimate.item() ) # load checkpoint parameter file and override optimized parameters @@ -886,7 +881,7 @@ def test_loading_saving_mbar_object_CompartimentedAlchemicalANI2x(): # get new free energy assert np.isclose( 3.2730393726044866, - get_perturbed_free_energy_difference(fec).item(), # -1.3759686627878307 + get_perturbed_free_energy_difference(fec).free_energy_estimate.item(), # -1.3759686627878307 ) del fec @@ -905,7 +900,7 @@ def test_loading_saving_mbar_object_CompartimentedAlchemicalANI2x(): ) assert np.isclose( - 3.2730393726044866, get_perturbed_free_energy_difference(fec).item() + 3.2730393726044866, get_perturbed_free_energy_difference(fec).free_energy_estimate.item() ) pickled_model = fec.ani_model.model