Skip to content

Commit

Permalink
Convert output of INIT_CONDS parameter to trajs
Browse files Browse the repository at this point in the history
  • Loading branch information
dwhswenson committed Aug 25, 2024
1 parent de53902 commit d1c7539
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
26 changes: 25 additions & 1 deletion paths_cli/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,31 @@
store='schemes',
)

INIT_CONDS = OPSStorageLoadMultiple(
class InitCondsLoader(OPSStorageLoadMultiple):
def _extract_trajectories(self, obj):
import openpathsampling as paths
if isinstance(obj, paths.SampleSet):
yield from (s.trajectory for s in obj)
elif isinstance(obj, paths.Sample):
yield obj.trajectory
elif isinstance(obj, paths.Trajectory):
yield obj
elif isinstance(obj, paths.BaseSnapshot):
yield paths.Trajectory([obj])
elif isinstance(obj, list):
for o in obj:
yield from self._extract_trajectories(o)
else:
raise RuntimeError("Unknown initial conditions type: "
f"{obj} (type: {type(obj)}")

def get(self, storage, names):
results = super().get(storage, names)
final_results = list(self._extract_trajectories(results))
return final_results


INIT_CONDS = InitCondsLoader(
param=Option('-t', '--init-conds', multiple=True,
help=("identifier for initial conditions "
+ "(sample set or trajectory)" + HELP_MULTIPLE)),
Expand Down
13 changes: 10 additions & 3 deletions paths_cli/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def test_get(self, getter):
storage = paths.Storage(filename, mode='r')
get_type, getter_style = self._parse_getter(getter)
expected = {
'sset': self.sample_set,
'traj': self.traj
'sset': [s.trajectory for s in self.sample_set],
'traj': [self.traj]
}[get_type]
get_arg = {
'name': 'traj',
Expand Down Expand Up @@ -277,7 +277,14 @@ def test_get_none(self, num_in_file):

st = paths.Storage(filename, mode='r')
obj = INIT_CONDS.get(st, None)
assert obj == stored_things[num_in_file - 1]
# TODO: fix this for all being trajectories
expected = [
[self.traj],
[s.trajectory for s in self.sample_set],
[s.trajectory for s in self.other_sample_set],
[s.trajectory for s in self.other_sample_set],
]
assert obj == expected[num_in_file - 1]

def test_get_multiple(self):
filename = self.create_file('number-traj')
Expand Down

0 comments on commit d1c7539

Please sign in to comment.