Skip to content

Commit

Permalink
Replace tests that use generic calculator with EMT() objects with
Browse files Browse the repository at this point in the history
passing a 3-tuple, since newest ASE EMT objects can't be pickled.

Fixed generic calculator auto-property-prefix guessing to work better
with tuples
  • Loading branch information
bernstei committed Feb 8, 2024
1 parent b34b503 commit 1f7043c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_relax_fixed_vol(cu_slab):

def test_subselect_from_traj(cu_slab):

calc = EMT()
calc = (EMT, [], {})

cu_slab_optimised = cu_slab.copy()
cu_slab_optimised.set_positions(expected_relaxed_positions_constant_pressure)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_remote_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def do_generic_calc(tmp_path, sys_name, monkeypatch, remoteinfo_env):
monkeypatch.setenv('WFL_EXPYRE_NO_MARK_PROCESSED', '1')

t0 = time.time()
results = generic.calculate(inputs=ci, outputs=co, calculator=calc, properties=["energy", "forces"],
results = generic.calculate(inputs=ci, outputs=co, calculator=(EMT, [], {}), properties=["energy", "forces"],
autopara_info={"remote_info": ri})
dt = time.time() - t0
print('remote parallel calc_time', dt)
Expand All @@ -161,7 +161,7 @@ def do_generic_calc(tmp_path, sys_name, monkeypatch, remoteinfo_env):
co = OutputSpec(tmp_path / f'ats_o_{sys_name}.xyz')

t0 = time.time()
results = generic.calculate(inputs=ci, outputs=co, calculator=calc, properties=["energy", "forces"],
results = generic.calculate(inputs=ci, outputs=co, calculator=(EMT, [], {}), properties=["energy", "forces"],
autopara_info={"remote_info": ri})
dt_rerun = time.time() - t0
print('remote parallel calc_time', dt_rerun)
Expand Down Expand Up @@ -304,7 +304,7 @@ def do_resubmit_killed_jobs(tmp_path, sys_name, monkeypatch, remoteinfo_env):
ats[1] = Atoms(f'C{n**3}', positions=np.asarray(np.meshgrid(range(n), range(n), range(n))).reshape((3, -1)).T,
cell=[n]*3, pbc=[True]*3)

calc = EMT()
calc = (EMT, [], {})

# first just ignore failures
ri['ignore_failed_jobs'] = True
Expand Down
8 changes: 7 additions & 1 deletion wfl/calculators/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def _run_autopara_wrappable(atoms, calculator, properties=None, output_prefix='_
calculator_default = None

if output_prefix == '_auto_':
output_prefix = calculator.__class__.__name__ + '_'
if isinstance(calculator, tuple):
# constructor, get name directly
calc_class = calculator[0].__name__
else:
# calculator object, get name from class
calc_class = calculator.__class__.__name__
output_prefix = calc_class + '_'

at_out = []
for at in atoms_to_list(atoms):
Expand Down

0 comments on commit 1f7043c

Please sign in to comment.