diff --git a/espfit/utils/graphs.py b/espfit/utils/graphs.py index a66145d..1960e42 100644 --- a/espfit/utils/graphs.py +++ b/espfit/utils/graphs.py @@ -514,7 +514,7 @@ def compute_relative_energy(self): del new_graphs - def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False, keyname=None): + def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False, keyname='u_ref'): """Reshape conformation size. This is a work around to handle different graph size (shape). DGL requires at least one dimension with same size. @@ -539,10 +539,10 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False, k include_min_energy_conf : boolean, default=False If True, then minimum energy conformer will be included for all split graphs. - keyname : str, default=None - Key name to be used to define the energy minima. This is usually u_ref or u_qm. - Note that depending on how the dataset was prepared, nonbonded energies could be subtracted from u_ref, - whereas u_qm could be the raw QM energies. + keyname : str, default='u_ref' + Key name to be used to define the energy minima. This is usually `u_ref` or `u_qm`. + Note that depending on how the dataset was prepared, nonbonded energies could be subtracted from `u_ref`, + whereas `u_qm` could be the raw QM energies. Returns ------- @@ -559,8 +559,8 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False, k import torch # Check if keyname is specified - if include_min_energy_conf == True and keyname == None: - raise Exception(f'Key name not specified. This is usually u_ref or u_qm.') + if include_min_energy_conf == True and keyname not in ['u_ref', 'u_qm']: + raise Exception(f'Key name {keyname} not supported. Supported keynames are u_ref and u_qm') new_graphs = [] n_confs_cache = n_confs