Skip to content

Commit

Permalink
[thermo] moving type checks to estimate_umbrella_sampling() and minor…
Browse files Browse the repository at this point in the history
… refactoring
  • Loading branch information
cwehmeyer committed Jun 13, 2016
1 parent 090b80b commit 7b1359d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
32 changes: 29 additions & 3 deletions pyemma/thermo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,25 @@ def estimate_umbrella_sampling(
array([ 0.63..., 1.60..., 1.31...])
"""
from .util import get_umbrella_sampling_data as _get_umbrella_sampling_data
# sanity checks
if estimator not in ['wham', 'dtram', 'tram']:
raise ValueError("unsupported estimator: %s" % estimator)
from .util import get_umbrella_sampling_data as _get_umbrella_sampling_data
ttrajs, btrajs, umbrella_centers, force_constants, unbiased_state = _get_umbrella_sampling_data(
us_trajs, us_centers, us_force_constants, md_trajs=md_trajs, kT=kT)
if not isinstance(us_trajs, (list, tuple)):
raise ValueError("The parameter us_trajs must be a list of numpy.ndarray objects")
if not isinstance(us_centers, (list, tuple)):
raise ValueError(
"The parameter us_centers must be a list of floats or numpy.ndarray objects")
if not isinstance(us_force_constants, (list, tuple)):
raise ValueError(
"The parameter us_force_constants must be a list of floats or numpy.ndarray objects")
if len(us_trajs) != len(us_centers):
raise ValueError("Unmatching number of umbrella sampling trajectories and centers: %d!=%d" \
% (len(us_trajs), len(us_centers)))
if len(us_trajs) != len(us_force_constants):
raise ValueError(
"Unmatching number of umbrella sampling trajectories and force constants: %d!=%d" \
% (len(us_trajs), len(us_force_constants)))
if len(us_trajs) != len(us_dtrajs):
raise ValueError(
"Number of continuous and discrete umbrella sampling trajectories does not " + \
Expand All @@ -159,9 +173,16 @@ def estimate_umbrella_sampling(
"Lengths of continuous and discrete umbrella sampling trajectories with " + \
"index %d does not match: %d!=%d" % (i, len(us_trajs), len(us_dtrajs)))
i += 1
if md_trajs is not None:
if not isinstance(md_trajs, (list, tuple)):
raise ValueError("The parameter md_trajs must be a list of numpy.ndarray objects")
if md_dtrajs is None:
raise ValueError("You have provided md_trajs, but md_dtrajs is None")
if md_dtrajs is None:
md_dtrajs = []
else:
if md_trajs is None:
raise ValueError("You have provided md_dtrajs, but md_trajs is None")
if len(md_trajs) != len(md_dtrajs):
raise ValueError(
"Number of continuous and discrete unbiased trajectories does not " + \
Expand All @@ -173,7 +194,11 @@ def estimate_umbrella_sampling(
"Lengths of continuous and discrete unbiased trajectories with " + \
"index %d does not match: %d!=%d" % (i, len(md_trajs), len(md_dtrajs)))
i += 1
# data preparation
ttrajs, btrajs, umbrella_centers, force_constants, unbiased_state = _get_umbrella_sampling_data(
us_trajs, us_centers, us_force_constants, md_trajs=md_trajs, kT=kT)
estimator_obj = None
# estimation
if estimator == 'wham':
estimator_obj = wham(
ttrajs, us_dtrajs + md_dtrajs,
Expand All @@ -200,6 +225,7 @@ def estimate_umbrella_sampling(
maxiter=maxiter, maxerr=maxerr, save_convergence_info=save_convergence_info,
dt_traj=dt_traj, init=init, init_maxiter=init_maxiter, init_maxerr=init_maxerr,
**parsed_kwargs)
# adding thermodynamic state information and return results
try:
estimator_obj.umbrella_centers = umbrella_centers
estimator_obj.force_constants = force_constants
Expand Down
18 changes: 0 additions & 18 deletions pyemma/thermo/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,6 @@ def _get_umbrella_sampling_parameters(
nthermo = 0
unbiased_state = None
dimension = None
if not isinstance(us_trajs, (list, tuple)):
raise ValueError("The parameter us_trajs must be a list of numpy.ndarray objects")
if not isinstance(us_centers, (list, tuple)):
raise ValueError(
"The parameter us_centers must be a list of floats or numpy.ndarray objects")
if not isinstance(us_force_constants, (list, tuple)):
raise ValueError(
"The parameter us_force_constants must be a list of floats or numpy.ndarray objects")
if len(us_trajs) != len(us_centers):
raise ValueError("Unmatching number of umbrella sampling trajectories and centers: %d!=%d" \
% (len(us_trajs), len(us_centers)))
if len(us_trajs) != len(us_force_constants):
raise ValueError(
"Unmatching number of umbrella sampling trajectories and force constants: %d!=%d" \
% (len(us_trajs), len(us_force_constants)))
if md_trajs is not None:
if not isinstance(md_trajs, (list, tuple)):
raise ValueError("The parameter md_trajs must be a list of numpy.ndarray objects")
for i, traj in enumerate(us_trajs):
state = None
try:
Expand Down

0 comments on commit 7b1359d

Please sign in to comment.