From bdeaca9dc2d1824ea1d6be280f0b062d9fdbe18f Mon Sep 17 00:00:00 2001 From: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:39:57 +0000 Subject: [PATCH] Change none_to_dict to avoid nesting brackets (#349) --- janus_core/calculations/base.py | 4 ++-- janus_core/calculations/descriptors.py | 2 +- janus_core/calculations/eos.py | 4 ++-- janus_core/calculations/geom_opt.py | 4 ++-- janus_core/calculations/md.py | 22 ++++++++++------------ janus_core/calculations/phonons.py | 14 ++++++-------- janus_core/calculations/single_point.py | 2 +- janus_core/helpers/mlip_calculators.py | 3 ++- janus_core/helpers/struct_io.py | 7 ++++--- janus_core/helpers/utils.py | 8 ++++---- janus_core/training/train.py | 2 +- tests/test_utils.py | 2 +- 12 files changed, 36 insertions(+), 38 deletions(-) diff --git a/janus_core/calculations/base.py b/janus_core/calculations/base.py index 0797b1db..5956b8fd 100644 --- a/janus_core/calculations/base.py +++ b/janus_core/calculations/base.py @@ -133,8 +133,8 @@ def __init__( param_prefix : str | None Additional parameters to add to default file_prefix. Default is None. """ - (read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs) = none_to_dict( - (read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs) + read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs = none_to_dict( + read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs ) self.struct = struct diff --git a/janus_core/calculations/descriptors.py b/janus_core/calculations/descriptors.py index f8fa98ad..6b71a56e 100644 --- a/janus_core/calculations/descriptors.py +++ b/janus_core/calculations/descriptors.py @@ -137,7 +137,7 @@ def __init__( Keyword arguments to pass to ase.io.write if saving structure with results of calculations. Default is {}. """ - (read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs)) + read_kwargs, write_kwargs = none_to_dict(read_kwargs, write_kwargs) self.invariants_only = invariants_only self.calc_per_element = calc_per_element diff --git a/janus_core/calculations/eos.py b/janus_core/calculations/eos.py index 09926588..bd7bc863 100644 --- a/janus_core/calculations/eos.py +++ b/janus_core/calculations/eos.py @@ -196,8 +196,8 @@ def __init__( Prefix for output filenames. Default is inferred from structure name, or chemical formula of the structure. """ - (read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs) = none_to_dict( - (read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs) + read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs = none_to_dict( + read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs ) self.min_volume = min_volume diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 170ae3f8..6077fddd 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -187,9 +187,9 @@ def __init__( Keyword arguments to pass to ase.io.write to save optimization trajectory. Must include "filename" keyword. Default is {}. """ - (read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs) = ( + read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = ( none_to_dict( - (read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs) + read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs ) ) diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index 3b691795..f16263c0 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -339,13 +339,11 @@ def __init__( post_process_kwargs, correlation_kwargs, ) = none_to_dict( - ( - read_kwargs, - minimize_kwargs, - write_kwargs, - post_process_kwargs, - correlation_kwargs, - ) + read_kwargs, + minimize_kwargs, + write_kwargs, + post_process_kwargs, + correlation_kwargs, ) self.ensemble = ensemble @@ -1190,7 +1188,7 @@ def __init__( self.pressure = pressure super().__init__(*args, ensemble=ensemble, file_prefix=file_prefix, **kwargs) - ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {} + (ensemble_kwargs,) = none_to_dict(ensemble_kwargs) self.ttime = thermostat_time * units.fs if barostat_time: @@ -1320,7 +1318,7 @@ def __init__( """ super().__init__(*args, ensemble=ensemble, **kwargs) - ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {} + (ensemble_kwargs,) = none_to_dict(ensemble_kwargs) self.dyn = Langevin( self.struct, timestep=self.timestep, @@ -1411,7 +1409,7 @@ def __init__( Additional keyword arguments. """ super().__init__(*args, ensemble=ensemble, **kwargs) - ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {} + (ensemble_kwargs,) = none_to_dict(ensemble_kwargs) self.dyn = VelocityVerlet( self.struct, @@ -1463,7 +1461,7 @@ def __init__( **kwargs Additional keyword arguments. """ - ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {} + (ensemble_kwargs,) = none_to_dict(ensemble_kwargs) super().__init__( *args, ensemble=ensemble, @@ -1575,7 +1573,7 @@ def __init__( **kwargs Additional keyword arguments. """ - ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {} + (ensemble_kwargs,) = none_to_dict(ensemble_kwargs) super().__init__( *args, thermostat_time=thermostat_time, diff --git a/janus_core/calculations/phonons.py b/janus_core/calculations/phonons.py index 0718908c..fd936fb6 100644 --- a/janus_core/calculations/phonons.py +++ b/janus_core/calculations/phonons.py @@ -278,15 +278,13 @@ def __init__( enable_progress_bar : bool Whether to show a progress bar during phonon calculations. Default is False. """ - (read_kwargs, displacement_kwargs, minimize_kwargs, dos_kwargs, pdos_kwargs) = ( + read_kwargs, displacement_kwargs, minimize_kwargs, dos_kwargs, pdos_kwargs = ( none_to_dict( - ( - read_kwargs, - displacement_kwargs, - minimize_kwargs, - dos_kwargs, - pdos_kwargs, - ) + read_kwargs, + displacement_kwargs, + minimize_kwargs, + dos_kwargs, + pdos_kwargs, ) ) diff --git a/janus_core/calculations/single_point.py b/janus_core/calculations/single_point.py index 0fcb1b95..953f0070 100644 --- a/janus_core/calculations/single_point.py +++ b/janus_core/calculations/single_point.py @@ -140,7 +140,7 @@ def __init__( Keyword arguments to pass to ase.io.write if saving structure with results of calculations. Default is {}. """ - (read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs)) + read_kwargs, write_kwargs = none_to_dict(read_kwargs, write_kwargs) self.write_results = write_results self.write_kwargs = write_kwargs diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 2308c3e5..5a436229 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -14,6 +14,7 @@ from ase.calculators.mixing import SumCalculator from janus_core.helpers.janus_types import Architectures, Devices, PathLike +from janus_core.helpers.utils import none_to_dict if TYPE_CHECKING: from ase.calculators.calculator import Calculator @@ -39,7 +40,7 @@ def _set_model_path( PathLike | torch.nn.Module | None Path to MLIP model file, loaded model, or None. """ - kwargs = kwargs if kwargs else {} + (kwargs,) = none_to_dict(kwargs) # kwargs that may be used for `model_path`` for different MLIPs # Note: "model" for chgnet (but not mace_mp or mace_off) and "potential" may refer diff --git a/janus_core/helpers/struct_io.py b/janus_core/helpers/struct_io.py index 9cdc04f4..e790a0ee 100644 --- a/janus_core/helpers/struct_io.py +++ b/janus_core/helpers/struct_io.py @@ -22,6 +22,7 @@ Properties, ) from janus_core.helpers.mlip_calculators import choose_calculator +from janus_core.helpers.utils import none_to_dict def results_to_info( @@ -88,7 +89,7 @@ def attach_calculator( calc_kwargs : dict[str, Any] | None Keyword arguments to pass to the selected calculator. Default is {}. """ - calc_kwargs = calc_kwargs if calc_kwargs else {} + (calc_kwargs,) = none_to_dict(calc_kwargs) calculator = choose_calculator( arch=arch, @@ -151,7 +152,7 @@ def input_structs( MaybeSequence[Atoms] Structure(s) with attached MLIP calculators. """ - read_kwargs = read_kwargs if read_kwargs else {} + (read_kwargs,) = none_to_dict(read_kwargs) # Validate parameters if not struct and not struct_path: @@ -249,7 +250,7 @@ def output_structs( """ # Separate kwargs for output_structs from kwargs for ase.io.write # This assumes values passed via kwargs have priority over passed parameters - write_kwargs = write_kwargs if write_kwargs else {} + (write_kwargs,) = none_to_dict(write_kwargs) set_info = write_kwargs.pop("set_info", set_info) properties = write_kwargs.pop("properties", properties) invalidate_calc = write_kwargs.pop("invalidate_calc", invalidate_calc) diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 3153fac4..bd416e3d 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -151,13 +151,13 @@ def _build_filename( return built_filename -def none_to_dict(dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]: +def none_to_dict(*dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]: """ Ensure dictionaries that may be None are dictionaries. Parameters ---------- - dictionaries : Sequence[dict | None] + *dictionaries : Sequence[dict | None] Sequence of dictionaries that could be None. Yields @@ -259,7 +259,7 @@ def write_table( 2 4 3 6 """ - units = units if units else {} + (units,) = none_to_dict(units) units.update( { key.removesuffix("_units"): val @@ -268,7 +268,7 @@ def write_table( } ) - formats = formats if formats else {} + (formats,) = none_to_dict(formats) formats.update( { key.removesuffix("_format"): val diff --git a/janus_core/training/train.py b/janus_core/training/train.py index 72bb5610..d70f6300 100644 --- a/janus_core/training/train.py +++ b/janus_core/training/train.py @@ -70,7 +70,7 @@ def train( tracker_kwargs : dict[str, Any] | None Keyword arguments to pass to `config_tracker`. Default is {}. """ - (log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs)) + log_kwargs, tracker_kwargs = none_to_dict(log_kwargs, tracker_kwargs) if req_file_keys is None: req_file_keys = ["train_file", "test_file", "valid_file", "statistics_file"] diff --git a/tests/test_utils.py b/tests/test_utils.py index c000b6af..fdcaea3f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -148,7 +148,7 @@ def test_output_structs( ) def test_none_to_dict(dicts_in): """Test none_to_dict removes Nones from sequence, and preserves dictionaries.""" - dicts = list(none_to_dict(dicts_in)) + dicts = list(none_to_dict(*dicts_in)) for dictionary in dicts: assert dictionary is not None