Skip to content

Commit

Permalink
Change none_to_dict to avoid nesting brackets (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 authored Nov 8, 2024
1 parent fb00ace commit bdeaca9
Show file tree
Hide file tree
Showing 12 changed files with 36 additions and 38 deletions.
4 changes: 2 additions & 2 deletions janus_core/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def __init__(
param_prefix : str | None
Additional parameters to add to default file_prefix. Default is None.
"""
(read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs) = none_to_dict(
(read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs)
read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs = none_to_dict(
read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs
)

self.struct = struct
Expand Down
2 changes: 1 addition & 1 deletion janus_core/calculations/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
Keyword arguments to pass to ase.io.write if saving structure with
results of calculations. Default is {}.
"""
(read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs))
read_kwargs, write_kwargs = none_to_dict(read_kwargs, write_kwargs)

self.invariants_only = invariants_only
self.calc_per_element = calc_per_element
Expand Down
4 changes: 2 additions & 2 deletions janus_core/calculations/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def __init__(
Prefix for output filenames. Default is inferred from structure name, or
chemical formula of the structure.
"""
(read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs) = none_to_dict(
(read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs)
read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs = none_to_dict(
read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs
)

self.min_volume = min_volume
Expand Down
4 changes: 2 additions & 2 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def __init__(
Keyword arguments to pass to ase.io.write to save optimization trajectory.
Must include "filename" keyword. Default is {}.
"""
(read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs) = (
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = (
none_to_dict(
(read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs)
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs
)
)

Expand Down
22 changes: 10 additions & 12 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,11 @@ def __init__(
post_process_kwargs,
correlation_kwargs,
) = none_to_dict(
(
read_kwargs,
minimize_kwargs,
write_kwargs,
post_process_kwargs,
correlation_kwargs,
)
read_kwargs,
minimize_kwargs,
write_kwargs,
post_process_kwargs,
correlation_kwargs,
)

self.ensemble = ensemble
Expand Down Expand Up @@ -1190,7 +1188,7 @@ def __init__(
self.pressure = pressure
super().__init__(*args, ensemble=ensemble, file_prefix=file_prefix, **kwargs)

ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
self.ttime = thermostat_time * units.fs

if barostat_time:
Expand Down Expand Up @@ -1320,7 +1318,7 @@ def __init__(
"""
super().__init__(*args, ensemble=ensemble, **kwargs)

ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
self.dyn = Langevin(
self.struct,
timestep=self.timestep,
Expand Down Expand Up @@ -1411,7 +1409,7 @@ def __init__(
Additional keyword arguments.
"""
super().__init__(*args, ensemble=ensemble, **kwargs)
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)

self.dyn = VelocityVerlet(
self.struct,
Expand Down Expand Up @@ -1463,7 +1461,7 @@ def __init__(
**kwargs
Additional keyword arguments.
"""
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
super().__init__(
*args,
ensemble=ensemble,
Expand Down Expand Up @@ -1575,7 +1573,7 @@ def __init__(
**kwargs
Additional keyword arguments.
"""
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
super().__init__(
*args,
thermostat_time=thermostat_time,
Expand Down
14 changes: 6 additions & 8 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,13 @@ def __init__(
enable_progress_bar : bool
Whether to show a progress bar during phonon calculations. Default is False.
"""
(read_kwargs, displacement_kwargs, minimize_kwargs, dos_kwargs, pdos_kwargs) = (
read_kwargs, displacement_kwargs, minimize_kwargs, dos_kwargs, pdos_kwargs = (
none_to_dict(
(
read_kwargs,
displacement_kwargs,
minimize_kwargs,
dos_kwargs,
pdos_kwargs,
)
read_kwargs,
displacement_kwargs,
minimize_kwargs,
dos_kwargs,
pdos_kwargs,
)
)

Expand Down
2 changes: 1 addition & 1 deletion janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
Keyword arguments to pass to ase.io.write if saving structure with results
of calculations. Default is {}.
"""
(read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs))
read_kwargs, write_kwargs = none_to_dict(read_kwargs, write_kwargs)

self.write_results = write_results
self.write_kwargs = write_kwargs
Expand Down
3 changes: 2 additions & 1 deletion janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ase.calculators.mixing import SumCalculator

from janus_core.helpers.janus_types import Architectures, Devices, PathLike
from janus_core.helpers.utils import none_to_dict

if TYPE_CHECKING:
from ase.calculators.calculator import Calculator
Expand All @@ -39,7 +40,7 @@ def _set_model_path(
PathLike | torch.nn.Module | None
Path to MLIP model file, loaded model, or None.
"""
kwargs = kwargs if kwargs else {}
(kwargs,) = none_to_dict(kwargs)

# kwargs that may be used for `model_path`` for different MLIPs
# Note: "model" for chgnet (but not mace_mp or mace_off) and "potential" may refer
Expand Down
7 changes: 4 additions & 3 deletions janus_core/helpers/struct_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Properties,
)
from janus_core.helpers.mlip_calculators import choose_calculator
from janus_core.helpers.utils import none_to_dict


def results_to_info(
Expand Down Expand Up @@ -88,7 +89,7 @@ def attach_calculator(
calc_kwargs : dict[str, Any] | None
Keyword arguments to pass to the selected calculator. Default is {}.
"""
calc_kwargs = calc_kwargs if calc_kwargs else {}
(calc_kwargs,) = none_to_dict(calc_kwargs)

calculator = choose_calculator(
arch=arch,
Expand Down Expand Up @@ -151,7 +152,7 @@ def input_structs(
MaybeSequence[Atoms]
Structure(s) with attached MLIP calculators.
"""
read_kwargs = read_kwargs if read_kwargs else {}
(read_kwargs,) = none_to_dict(read_kwargs)

# Validate parameters
if not struct and not struct_path:
Expand Down Expand Up @@ -249,7 +250,7 @@ def output_structs(
"""
# Separate kwargs for output_structs from kwargs for ase.io.write
# This assumes values passed via kwargs have priority over passed parameters
write_kwargs = write_kwargs if write_kwargs else {}
(write_kwargs,) = none_to_dict(write_kwargs)
set_info = write_kwargs.pop("set_info", set_info)
properties = write_kwargs.pop("properties", properties)
invalidate_calc = write_kwargs.pop("invalidate_calc", invalidate_calc)
Expand Down
8 changes: 4 additions & 4 deletions janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ def _build_filename(
return built_filename


def none_to_dict(dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]:
def none_to_dict(*dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]:
"""
Ensure dictionaries that may be None are dictionaries.
Parameters
----------
dictionaries : Sequence[dict | None]
*dictionaries : Sequence[dict | None]
Sequence of dictionaries that could be None.
Yields
Expand Down Expand Up @@ -259,7 +259,7 @@ def write_table(
2 4
3 6
"""
units = units if units else {}
(units,) = none_to_dict(units)
units.update(
{
key.removesuffix("_units"): val
Expand All @@ -268,7 +268,7 @@ def write_table(
}
)

formats = formats if formats else {}
(formats,) = none_to_dict(formats)
formats.update(
{
key.removesuffix("_format"): val
Expand Down
2 changes: 1 addition & 1 deletion janus_core/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def train(
tracker_kwargs : dict[str, Any] | None
Keyword arguments to pass to `config_tracker`. Default is {}.
"""
(log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))
log_kwargs, tracker_kwargs = none_to_dict(log_kwargs, tracker_kwargs)

if req_file_keys is None:
req_file_keys = ["train_file", "test_file", "valid_file", "statistics_file"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_output_structs(
)
def test_none_to_dict(dicts_in):
"""Test none_to_dict removes Nones from sequence, and preserves dictionaries."""
dicts = list(none_to_dict(dicts_in))
dicts = list(none_to_dict(*dicts_in))
for dictionary in dicts:
assert dictionary is not None

Expand Down

0 comments on commit bdeaca9

Please sign in to comment.