diff --git a/paths_cli/commands/pathsampling.py b/paths_cli/commands/pathsampling.py index 03a06f7..7436311 100644 --- a/paths_cli/commands/pathsampling.py +++ b/paths_cli/commands/pathsampling.py @@ -3,7 +3,8 @@ from paths_cli import OPSCommandPlugin from paths_cli.parameters import ( - INPUT_FILE, OUTPUT_FILE, INIT_CONDS, SCHEME, N_STEPS_MC + INPUT_FILE, OUTPUT_FILE, INIT_CONDS, SCHEME, N_STEPS_MC, + SIMULATION_CV_MODE, ) @@ -16,9 +17,12 @@ @SCHEME.clicked(required=False) @INIT_CONDS.clicked(required=False) @N_STEPS_MC -def pathsampling(input_file, output_file, scheme, init_conds, nsteps): +@SIMULATION_CV_MODE.clicked() +def pathsampling(input_file, output_file, scheme, init_conds, nsteps, + cv_mode): """General path sampling, using setup in INPUT_FILE""" storage = INPUT_FILE.get(input_file) + SIMULATION_CV_MODE(storage, cv_mode) pathsampling_main(output_storage=OUTPUT_FILE.get(output_file), scheme=SCHEME.get(storage, scheme), init_conds=INIT_CONDS.get(storage, init_conds), diff --git a/paths_cli/parameters.py b/paths_cli/parameters.py index 574b17d..c94dee9 100644 --- a/paths_cli/parameters.py +++ b/paths_cli/parameters.py @@ -1,4 +1,5 @@ import click +import warnings from paths_cli.param_core import ( Option, Argument, OPSStorageLoadSingle, OPSStorageLoadMultiple, OPSStorageLoadNames, StorageLoader, GetByName, GetByNumber, GetOnly, @@ -133,3 +134,41 @@ def get(self, storage, names): help="number of Monte Carlo trials to run") MULTI_CV = CVS + + +class CVMode: + """Class for generating CVMode parameters. + """ + def __init__(self, options, default): + allowed = {"production", "analysis", "no-caching"} + if extras := set(options) - allowed: + raise ValueError(f"Invalid options: {extras}") + if default not in options: + raise ValueError(f"Default '{default}' not in options {options}") + + self.default = default + self.param = Option( + "--cv-mode", + type=click.Choice(options), + help=( + "Mode for CVs (only used for SimStore DB files). Default " + f"'{default}'." + ), + default=default, + ) + + def clicked(self): + return self.param.clicked() + + def __call__(self, storage, cv_mode): + from openpathsampling.experimental.storage import Storage + if cv_mode != self.default and not isinstance(storage, Storage): + warnings.warn("Not a SimStore file: cv-mode argument unused") + return + + for cv in storage.cvs: + # TODO: add logger + # _logger.info(f"Setting '{cv.name}' to mode '{cv_mode}'.") + cv.mode = cv_mode + +SIMULATION_CV_MODE = CVMode(["production", "no-caching"], "production") diff --git a/paths_cli/tests/test_parameters.py b/paths_cli/tests/test_parameters.py index e4cc9f1..755ae78 100644 --- a/paths_cli/tests/test_parameters.py +++ b/paths_cli/tests/test_parameters.py @@ -511,3 +511,49 @@ def test_APPEND_FILE(ext): os.remove(filename) os.rmdir(tempdir) undo_monkey_patch(stored_functions) + + +class TestCVMode: + def test_bad_option(self): + with pytest.raises(ValueError, match="Invalid options"): + CVMode(["production", "foo"], "production") + + def test_bad_default(self): + with pytest.raises(ValueError, match="not in options"): + CVMode(["production", "no-caching"], "analysis") + + def test_call(self, tmp_path): + cv_mode = CVMode(["production", "no-caching"], "no-caching") + from openpathsampling.experimental.storage.collective_variables \ + import CollectiveVariable + from openpathsampling.experimental.storage import ( + Storage, monkey_patch_all + ) + cv = CollectiveVariable(lambda s: s.xyz[0][0]).named('x') + filename = str(tmp_path / "foo.db") + stored_functions = pre_monkey_patch() + monkey_patch_all(paths) + st = Storage(filename, mode='w') + assert cv.mode == "analysis" + st.save(cv) + st.close() + del cv + + storage = Storage(filename, mode='r') + cv = storage.cvs['x'] + assert cv.mode == "analysis" + cv_mode(storage, "no-caching") + assert cv.mode == "no-caching" + undo_monkey_patch(stored_functions) + + def test_call_non_simstore(self, tmp_path): + cv_mode = CVMode(["production", "no-caching"], "no-caching") + filename = str(tmp_path / "foo.nc") + cv = paths.FunctionCV("x", lambda x: x.xyz[0][0]) + st = paths.Storage(filename, mode='w') + st.save(cv) + st.close() + + storage = paths.Storage(filename, mode='r') + with pytest.warns(UserWarning, match="Not a SimStore"): + cv_mode(storage, "production")