From 352d4080c270785c2ba7a119b26865ae0c350ae4 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 31 Jan 2024 15:09:31 -0500 Subject: [PATCH] Correctly test for matching ConfigSet_loc so group iterator will work. optimize no longer increase nesting level when traj_subselect indicates only one config from each traj ("last" or "last_converged") --- complete_pytest.tin | 2 +- setup.py | 2 +- tests/test_optimize.py | 3 +-- wfl/configset.py | 49 ++++++++++++++++++++-------------------- wfl/generate/optimize.py | 8 ++----- 5 files changed, 30 insertions(+), 34 deletions(-) diff --git a/complete_pytest.tin b/complete_pytest.tin index 1f5fd97d..6f7a442d 100755 --- a/complete_pytest.tin +++ b/complete_pytest.tin @@ -66,7 +66,7 @@ echo "summary line $l" # ===== 152 passed, 17 skipped, 3 xpassed, 78 warnings in 4430.81s (1:13:50) ===== lp=$( echo $l | sed -E -e 's/ in .*//' -e 's/\s*,\s*/\n/g' ) -declare -A expected_n=( ["passed"]="163" ["skipped"]="21" ["warnings"]=801 ["xfailed"]=2 ["xpassed"]=1 ) +declare -A expected_n=( ["passed"]="163" ["skipped"]="21" ["warnings"]=803 ["xfailed"]=2 ["xpassed"]=1 ) IFS=$'\n' for out in $lp; do out_n=$(echo $out | sed -e 's/^=* //' -e 's/ .*//' -e 's/,//') diff --git a/setup.py b/setup.py index c1e24213..f176ba1b 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="wfl", - version="0.2.0", + version="0.2.1", packages=setuptools.find_packages(exclude=["tests"]), install_requires=["click>=7.0", "numpy", "ase>=3.21", "pyyaml", "spglib", "docstring_parser", "expyre-wfl @ https://github.com/libAtoms/ExPyRe/tarball/main", diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 9f06a35e..2b50e9cb 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -221,9 +221,8 @@ def test_subselect_from_traj(cu_slab): _autopara_per_item_info = [{} for _ in range(len(inputs))] ) - assert len(atoms_opt[1]) == 1 - assert isinstance(atoms_opt[1][0], Atoms) # and not None assert atoms_opt[0] is None + assert isinstance(atoms_opt[1], Atoms) # not None # check that iterable_loop handles Nones as expected inputs = ConfigSet([cu_slab.copy(), cu_slab_optimised.copy()]) diff --git a/wfl/configset.py b/wfl/configset.py index cbf9e9f9..2bad8242 100644 --- a/wfl/configset.py +++ b/wfl/configset.py @@ -1,4 +1,5 @@ import sys +import re from pathlib import Path @@ -148,7 +149,7 @@ def __iter__(self): ## print("DEBUG __iter__ one file", self.items) for at_i, at in enumerate(ase.io.iread(self.items, **self.read_kwargs)): loc = at.info.get("_ConfigSet_loc", ConfigSet._loc_sep + str(at_i)) - if loc.startswith(self._file_loc): + if len(self._file_loc) == 0 or re.match(self._file_loc + r'\b', loc): ## print("DEBUG matching loc", loc, "file_loc", self._file_loc) loc = loc.replace(self._file_loc, "", 1) ## print("DEBUG stripped loc", loc) @@ -160,7 +161,7 @@ def __iter__(self): for file_i, filepath in enumerate(self.items): for at_i, at in enumerate(ase.io.iread(filepath, **self.read_kwargs)): loc = ConfigSet._loc_sep + str(file_i) + at.info.get("_ConfigSet_loc", ConfigSet._loc_sep + str(at_i)) - if loc.startswith(self._file_loc): + if len(self._file_loc) == 0 or re.match(self._file_loc + r'\b', loc): loc = loc.replace(self._file_loc, "", 1) at.info["_ConfigSet_loc"] = loc self._cur_loc = at.info.get("_ConfigSet_loc") @@ -200,20 +201,20 @@ def advance(at_i=None): return self._cur_at[0].info.get("_ConfigSet_loc", ConfigSet._loc_sep + str(at_i) if at_i is not None else None) if isinstance(self.items, Path): - ## print("DEBUG groups() for one file self._open_reader", self._open_reader, "self._cur_at", self._cur_at) + ## print("DEBUG groups() for one file self._open_reader", self._open_reader, "self._cur_at", self._cur_at) ##DEBUG # one file, return a ConfigSet for each group, or a sequence of individual Atoms if self._open_reader is None or self._cur_at[0] is None: # initialize reading of file - ## print("DEBUG initializing reader", self.items) + ## print("DEBUG initializing reader", self.items) ##DEBUG self._open_reader = ase.io.iread(self.items, **self.read_kwargs) self._cur_at = [None] - ## print("DEBUG setting initial self._cur_at = [None]") + ## print("DEBUG setting initial self._cur_at = [None]") ##DEBUG try: - ## print("DEBUG advancing, getting at_loc with cur_at_i", cur_at_i) + ## print("DEBUG advancing, getting at_loc with cur_at_i", cur_at_i) ##DEBUG at_loc = advance(cur_at_i) - ## print("DEBUG got at_loc", at_loc) + ## print("DEBUG got at_loc", at_loc) ##DEBUG except StopIteration: - ## print("DEBUG got EOF, returning") + ## print("DEBUG got EOF, returning") ##DEBUG # indicate EOF self._cur_at = [None] return @@ -223,12 +224,12 @@ def advance(at_i=None): try: # Skip any that don't match self._file_loc - ## print("DEBUG first skipping non-matching, at_loc", at_loc, "self._file_loc", self._file_loc) + ## print("DEBUG first skipping non-matching, at_loc", at_loc, "self._file_loc", self._file_loc) ##DEBUG while not at_loc.startswith(self._file_loc): at_loc = advance() - ## print("DEBUG after first skipping non-matching, new at_loc", at_loc) + ## print("DEBUG after first skipping non-matching, new at_loc", at_loc) ##DEBUG except StopIteration: - ## print("DEBUG starting second skipping") + ## print("DEBUG starting second skipping") ##DEBUG # Failed to find config that matches self._file_loc from current position of reader. # Search again from start (in case we're doing something out of order) self._open_reader = ase.io.iread(self.items, **self.read_kwargs) @@ -242,42 +243,42 @@ def advance(at_i=None): # any matching configs raise RuntimeError(f"No matching configs in file {self.items} for location {self._file_loc}") from exc - ## print("DEBUG after possibly skipping, now should be at right place, at_loc", at_loc, "self._file_loc", - ## self._file_loc, "cur_at_i", cur_at_i, "self._cur_at.numbers", self._cur_at[0].numbers) + ## print("DEBUG after possibly skipping, now should be at right place, at_loc", at_loc, "self._file_loc", ##DEBUG + ## self._file_loc, "cur_at_i", cur_at_i, "self._cur_at.numbers", self._cur_at[0].numbers) ##DEBUG # now self._cur_at should be first config that matches self._file_loc requested_depth = len(self._file_loc.split(ConfigSet._loc_sep)) - ## print("DEBUG requested_depth", requested_depth) + ## print("DEBUG requested_depth", requested_depth) ##DEBUG if len(at_loc.split(ConfigSet._loc_sep)) == requested_depth + 1: - ## print("DEBUG in right container, starting to yield atoms") + ## print("DEBUG in right container, starting to yield atoms") ##DEBUG # in right container, yield Atoms while at_loc.startswith(self._file_loc): if "_ConfigSet_loc" in self._cur_at[0].info: del self._cur_at[0].info["_ConfigSet_loc"] - ## print("DEBUG in loop, yielding Atoms") + ## print("DEBUG in loop, yielding Atoms") ##DEBUG yield self._cur_at[0] cur_at_i += 1 try: at_loc = advance(cur_at_i) except StopIteration: - ## print("DEBUG EOF while yielding atoms") + ## print("DEBUG EOF while yielding atoms") ##DEBUG # indicate EOF self._cur_at[0] = None return return else: - ## print("DEBUG right place, but need to go deeper") + ## print("DEBUG right place, but need to go deeper") ##DEBUG # at first config that matches self._file_loc, but there are deeper levels to nest into while at_loc.startswith(self._file_loc): # get location of deeper iterator from at_loc down to one deeper than requested at this level new_file_loc = ConfigSet._loc_sep.join(at_loc.split(ConfigSet._loc_sep)[0:requested_depth + 1]) - ## print("DEBUG making and yielding ConfigSet with new _file_loc", new_file_loc) + ## print("DEBUG making and yielding ConfigSet with new _file_loc", new_file_loc) ##DEBUG t = ConfigSet(self.items, _open_reader=self._open_reader, _cur_at=self._cur_at, _file_loc=new_file_loc) - ## print("DEBUG yielding ConfigSet", t, "_open_reader", t._open_reader, "_cur_at", self._cur_at) + ## print("DEBUG yielding ConfigSet", t, "_open_reader", t._open_reader, "_cur_at", self._cur_at) ##DEBUG yield t - ## print("DEBUG after yield, got self._cur_at", self._cur_at[0].numbers if self._cur_at[0] is not None else None) + ## print("DEBUG after yield, got self._cur_at", self._cur_at[0].numbers if self._cur_at[0] is not None else None) ##DEBUG if self._cur_at[0] is None: - ## print("DEBUG got EOF, returning") + ## print("DEBUG got EOF, returning") ##DEBUG # got EOF deeper inside, exit return # self._cur_at could have advanced when calling function iterated over yielded @@ -289,12 +290,12 @@ def advance(at_i=None): try: at_loc = advance() except StopIteration: - ## print("DEBUG past yielded iterator got EOF") + ## print("DEBUG past yielded iterator got EOF") ##DEBUG self._cur_at[0] = None return # now we should be past configs that could have been consumed by previously yielded iterator - ## print("DEBUG end of loop, now at_loc is", at_loc) + ## print("DEBUG end of loop, now at_loc is", at_loc) ##DEBUG else: # self.items is list(Atoms) or list(...list(Atoms)) or list(Path) diff --git a/wfl/generate/optimize.py b/wfl/generate/optimize.py index 9741bc26..79898fec 100644 --- a/wfl/generate/optimize.py +++ b/wfl/generate/optimize.py @@ -237,12 +237,8 @@ def subselect_from_traj(traj, subselect=None): if subselect is None: return traj elif subselect == "last": - return [traj[-1]] + return traj[-1] elif subselect == "last_converged": - converged_configs = [at for at in traj if at.info["optimize_config_type"] == "optimize_last_converged"] - if len(converged_configs) == 0: - return None - else: - return converged_configs + return traj[-1] if (traj[-1].info["optimize_config_type"] == "optimize_last_converged") else None raise RuntimeError(f'Subselecting confgs from trajectory with rule ' f'"subselect={subselect}" is not yet implemented')