Skip to content

Commit

Permalink
Merge pull request #288 from libAtoms/fix_group_iter
Browse files Browse the repository at this point in the history
Fix group iterator
  • Loading branch information
bernstei authored Feb 1, 2024
2 parents 1193a85 + 352d408 commit 2b95c14
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 34 deletions.
2 changes: 1 addition & 1 deletion complete_pytest.tin
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ echo "summary line $l"
# ===== 152 passed, 17 skipped, 3 xpassed, 78 warnings in 4430.81s (1:13:50) =====
lp=$( echo $l | sed -E -e 's/ in .*//' -e 's/\s*,\s*/\n/g' )

declare -A expected_n=( ["passed"]="163" ["skipped"]="21" ["warnings"]=801 ["xfailed"]=2 ["xpassed"]=1 )
declare -A expected_n=( ["passed"]="163" ["skipped"]="21" ["warnings"]=803 ["xfailed"]=2 ["xpassed"]=1 )
IFS=$'\n'
for out in $lp; do
out_n=$(echo $out | sed -e 's/^=* //' -e 's/ .*//' -e 's/,//')
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name="wfl",
version="0.2.0",
version="0.2.1",
packages=setuptools.find_packages(exclude=["tests"]),
install_requires=["click>=7.0", "numpy", "ase>=3.21", "pyyaml", "spglib", "docstring_parser",
"expyre-wfl @ https://github.com/libAtoms/ExPyRe/tarball/main",
Expand Down
3 changes: 1 addition & 2 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@ def test_subselect_from_traj(cu_slab):
_autopara_per_item_info = [{} for _ in range(len(inputs))]
)

assert len(atoms_opt[1]) == 1
assert isinstance(atoms_opt[1][0], Atoms) # and not None
assert atoms_opt[0] is None
assert isinstance(atoms_opt[1], Atoms) # not None

# check that iterable_loop handles Nones as expected
inputs = ConfigSet([cu_slab.copy(), cu_slab_optimised.copy()])
Expand Down
49 changes: 25 additions & 24 deletions wfl/configset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import re

from pathlib import Path

Expand Down Expand Up @@ -148,7 +149,7 @@ def __iter__(self):
## print("DEBUG __iter__ one file", self.items)
for at_i, at in enumerate(ase.io.iread(self.items, **self.read_kwargs)):
loc = at.info.get("_ConfigSet_loc", ConfigSet._loc_sep + str(at_i))
if loc.startswith(self._file_loc):
if len(self._file_loc) == 0 or re.match(self._file_loc + r'\b', loc):
## print("DEBUG matching loc", loc, "file_loc", self._file_loc)
loc = loc.replace(self._file_loc, "", 1)
## print("DEBUG stripped loc", loc)
Expand All @@ -160,7 +161,7 @@ def __iter__(self):
for file_i, filepath in enumerate(self.items):
for at_i, at in enumerate(ase.io.iread(filepath, **self.read_kwargs)):
loc = ConfigSet._loc_sep + str(file_i) + at.info.get("_ConfigSet_loc", ConfigSet._loc_sep + str(at_i))
if loc.startswith(self._file_loc):
if len(self._file_loc) == 0 or re.match(self._file_loc + r'\b', loc):
loc = loc.replace(self._file_loc, "", 1)
at.info["_ConfigSet_loc"] = loc
self._cur_loc = at.info.get("_ConfigSet_loc")
Expand Down Expand Up @@ -200,20 +201,20 @@ def advance(at_i=None):
return self._cur_at[0].info.get("_ConfigSet_loc", ConfigSet._loc_sep + str(at_i) if at_i is not None else None)

if isinstance(self.items, Path):
## print("DEBUG groups() for one file self._open_reader", self._open_reader, "self._cur_at", self._cur_at)
## print("DEBUG groups() for one file self._open_reader", self._open_reader, "self._cur_at", self._cur_at) ##DEBUG
# one file, return a ConfigSet for each group, or a sequence of individual Atoms
if self._open_reader is None or self._cur_at[0] is None:
# initialize reading of file
## print("DEBUG initializing reader", self.items)
## print("DEBUG initializing reader", self.items) ##DEBUG
self._open_reader = ase.io.iread(self.items, **self.read_kwargs)
self._cur_at = [None]
## print("DEBUG setting initial self._cur_at = [None]")
## print("DEBUG setting initial self._cur_at = [None]") ##DEBUG
try:
## print("DEBUG advancing, getting at_loc with cur_at_i", cur_at_i)
## print("DEBUG advancing, getting at_loc with cur_at_i", cur_at_i) ##DEBUG
at_loc = advance(cur_at_i)
## print("DEBUG got at_loc", at_loc)
## print("DEBUG got at_loc", at_loc) ##DEBUG
except StopIteration:
## print("DEBUG got EOF, returning")
## print("DEBUG got EOF, returning") ##DEBUG
# indicate EOF
self._cur_at = [None]
return
Expand All @@ -223,12 +224,12 @@ def advance(at_i=None):

try:
# Skip any that don't match self._file_loc
## print("DEBUG first skipping non-matching, at_loc", at_loc, "self._file_loc", self._file_loc)
## print("DEBUG first skipping non-matching, at_loc", at_loc, "self._file_loc", self._file_loc) ##DEBUG
while not at_loc.startswith(self._file_loc):
at_loc = advance()
## print("DEBUG after first skipping non-matching, new at_loc", at_loc)
## print("DEBUG after first skipping non-matching, new at_loc", at_loc) ##DEBUG
except StopIteration:
## print("DEBUG starting second skipping")
## print("DEBUG starting second skipping") ##DEBUG
# Failed to find config that matches self._file_loc from current position of reader.
# Search again from start (in case we're doing something out of order)
self._open_reader = ase.io.iread(self.items, **self.read_kwargs)
Expand All @@ -242,42 +243,42 @@ def advance(at_i=None):
# any matching configs
raise RuntimeError(f"No matching configs in file {self.items} for location {self._file_loc}") from exc

## print("DEBUG after possibly skipping, now should be at right place, at_loc", at_loc, "self._file_loc",
## self._file_loc, "cur_at_i", cur_at_i, "self._cur_at.numbers", self._cur_at[0].numbers)
## print("DEBUG after possibly skipping, now should be at right place, at_loc", at_loc, "self._file_loc", ##DEBUG
## self._file_loc, "cur_at_i", cur_at_i, "self._cur_at.numbers", self._cur_at[0].numbers) ##DEBUG
# now self._cur_at should be first config that matches self._file_loc
requested_depth = len(self._file_loc.split(ConfigSet._loc_sep))
## print("DEBUG requested_depth", requested_depth)
## print("DEBUG requested_depth", requested_depth) ##DEBUG
if len(at_loc.split(ConfigSet._loc_sep)) == requested_depth + 1:
## print("DEBUG in right container, starting to yield atoms")
## print("DEBUG in right container, starting to yield atoms") ##DEBUG
# in right container, yield Atoms
while at_loc.startswith(self._file_loc):
if "_ConfigSet_loc" in self._cur_at[0].info:
del self._cur_at[0].info["_ConfigSet_loc"]
## print("DEBUG in loop, yielding Atoms")
## print("DEBUG in loop, yielding Atoms") ##DEBUG
yield self._cur_at[0]
cur_at_i += 1
try:
at_loc = advance(cur_at_i)
except StopIteration:
## print("DEBUG EOF while yielding atoms")
## print("DEBUG EOF while yielding atoms") ##DEBUG
# indicate EOF
self._cur_at[0] = None
return
return
else:
## print("DEBUG right place, but need to go deeper")
## print("DEBUG right place, but need to go deeper") ##DEBUG
# at first config that matches self._file_loc, but there are deeper levels to nest into

while at_loc.startswith(self._file_loc):
# get location of deeper iterator from at_loc down to one deeper than requested at this level
new_file_loc = ConfigSet._loc_sep.join(at_loc.split(ConfigSet._loc_sep)[0:requested_depth + 1])
## print("DEBUG making and yielding ConfigSet with new _file_loc", new_file_loc)
## print("DEBUG making and yielding ConfigSet with new _file_loc", new_file_loc) ##DEBUG
t = ConfigSet(self.items, _open_reader=self._open_reader, _cur_at=self._cur_at, _file_loc=new_file_loc)
## print("DEBUG yielding ConfigSet", t, "_open_reader", t._open_reader, "_cur_at", self._cur_at)
## print("DEBUG yielding ConfigSet", t, "_open_reader", t._open_reader, "_cur_at", self._cur_at) ##DEBUG
yield t
## print("DEBUG after yield, got self._cur_at", self._cur_at[0].numbers if self._cur_at[0] is not None else None)
## print("DEBUG after yield, got self._cur_at", self._cur_at[0].numbers if self._cur_at[0] is not None else None) ##DEBUG
if self._cur_at[0] is None:
## print("DEBUG got EOF, returning")
## print("DEBUG got EOF, returning") ##DEBUG
# got EOF deeper inside, exit
return
# self._cur_at could have advanced when calling function iterated over yielded
Expand All @@ -289,12 +290,12 @@ def advance(at_i=None):
try:
at_loc = advance()
except StopIteration:
## print("DEBUG past yielded iterator got EOF")
## print("DEBUG past yielded iterator got EOF") ##DEBUG
self._cur_at[0] = None
return
# now we should be past configs that could have been consumed by previously yielded iterator

## print("DEBUG end of loop, now at_loc is", at_loc)
## print("DEBUG end of loop, now at_loc is", at_loc) ##DEBUG

else:
# self.items is list(Atoms) or list(...list(Atoms)) or list(Path)
Expand Down
8 changes: 2 additions & 6 deletions wfl/generate/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,8 @@ def subselect_from_traj(traj, subselect=None):
if subselect is None:
return traj
elif subselect == "last":
return [traj[-1]]
return traj[-1]
elif subselect == "last_converged":
converged_configs = [at for at in traj if at.info["optimize_config_type"] == "optimize_last_converged"]
if len(converged_configs) == 0:
return None
else:
return converged_configs
return traj[-1] if (traj[-1].info["optimize_config_type"] == "optimize_last_converged") else None
raise RuntimeError(f'Subselecting confgs from trajectory with rule '
f'"subselect={subselect}" is not yet implemented')

0 comments on commit 2b95c14

Please sign in to comment.