Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slight change to none_to_dict to avoid nesting brackets. #349

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions janus_core/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def __init__(
param_prefix : str | None
Additional parameters to add to default file_prefix. Default is None.
"""
(read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs) = none_to_dict(
(read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs)
read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs = none_to_dict(
read_kwargs, calc_kwargs, log_kwargs, tracker_kwargs
)

self.struct = struct
Expand Down
2 changes: 1 addition & 1 deletion janus_core/calculations/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
Keyword arguments to pass to ase.io.write if saving structure with
results of calculations. Default is {}.
"""
(read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs))
read_kwargs, write_kwargs = none_to_dict(read_kwargs, write_kwargs)

self.invariants_only = invariants_only
self.calc_per_element = calc_per_element
Expand Down
4 changes: 2 additions & 2 deletions janus_core/calculations/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def __init__(
Prefix for output filenames. Default is inferred from structure name, or
chemical formula of the structure.
"""
(read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs) = none_to_dict(
(read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs)
read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs = none_to_dict(
read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs
)

self.min_volume = min_volume
Expand Down
4 changes: 2 additions & 2 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def __init__(
Keyword arguments to pass to ase.io.write to save optimization trajectory.
Must include "filename" keyword. Default is {}.
"""
(read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs) = (
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = (
none_to_dict(
(read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs)
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs
)
)

Expand Down
22 changes: 10 additions & 12 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,11 @@ def __init__(
post_process_kwargs,
correlation_kwargs,
) = none_to_dict(
(
read_kwargs,
minimize_kwargs,
write_kwargs,
post_process_kwargs,
correlation_kwargs,
)
read_kwargs,
minimize_kwargs,
write_kwargs,
post_process_kwargs,
correlation_kwargs,
)

self.ensemble = ensemble
Expand Down Expand Up @@ -1190,7 +1188,7 @@ def __init__(
self.pressure = pressure
super().__init__(*args, ensemble=ensemble, file_prefix=file_prefix, **kwargs)

ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
self.ttime = thermostat_time * units.fs

if barostat_time:
Expand Down Expand Up @@ -1320,7 +1318,7 @@ def __init__(
"""
super().__init__(*args, ensemble=ensemble, **kwargs)

ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
self.dyn = Langevin(
self.struct,
timestep=self.timestep,
Expand Down Expand Up @@ -1411,7 +1409,7 @@ def __init__(
Additional keyword arguments.
"""
super().__init__(*args, ensemble=ensemble, **kwargs)
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)

self.dyn = VelocityVerlet(
self.struct,
Expand Down Expand Up @@ -1463,7 +1461,7 @@ def __init__(
**kwargs
Additional keyword arguments.
"""
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
super().__init__(
*args,
ensemble=ensemble,
Expand Down Expand Up @@ -1575,7 +1573,7 @@ def __init__(
**kwargs
Additional keyword arguments.
"""
ensemble_kwargs = ensemble_kwargs if ensemble_kwargs else {}
(ensemble_kwargs,) = none_to_dict(ensemble_kwargs)
super().__init__(
*args,
thermostat_time=thermostat_time,
Expand Down
14 changes: 6 additions & 8 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,13 @@ def __init__(
enable_progress_bar : bool
Whether to show a progress bar during phonon calculations. Default is False.
"""
(read_kwargs, displacement_kwargs, minimize_kwargs, dos_kwargs, pdos_kwargs) = (
read_kwargs, displacement_kwargs, minimize_kwargs, dos_kwargs, pdos_kwargs = (
none_to_dict(
(
read_kwargs,
displacement_kwargs,
minimize_kwargs,
dos_kwargs,
pdos_kwargs,
)
read_kwargs,
displacement_kwargs,
minimize_kwargs,
dos_kwargs,
pdos_kwargs,
)
)

Expand Down
2 changes: 1 addition & 1 deletion janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
Keyword arguments to pass to ase.io.write if saving structure with results
of calculations. Default is {}.
"""
(read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs))
read_kwargs, write_kwargs = none_to_dict(read_kwargs, write_kwargs)

self.write_results = write_results
self.write_kwargs = write_kwargs
Expand Down
3 changes: 2 additions & 1 deletion janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ase.calculators.mixing import SumCalculator

from janus_core.helpers.janus_types import Architectures, Devices, PathLike
from janus_core.helpers.utils import none_to_dict

if TYPE_CHECKING:
from ase.calculators.calculator import Calculator
Expand All @@ -39,7 +40,7 @@ def _set_model_path(
PathLike | torch.nn.Module | None
Path to MLIP model file, loaded model, or None.
"""
kwargs = kwargs if kwargs else {}
(kwargs,) = none_to_dict(kwargs)

# kwargs that may be used for `model_path`` for different MLIPs
# Note: "model" for chgnet (but not mace_mp or mace_off) and "potential" may refer
Expand Down
7 changes: 4 additions & 3 deletions janus_core/helpers/struct_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Properties,
)
from janus_core.helpers.mlip_calculators import choose_calculator
from janus_core.helpers.utils import none_to_dict


def results_to_info(
Expand Down Expand Up @@ -88,7 +89,7 @@ def attach_calculator(
calc_kwargs : dict[str, Any] | None
Keyword arguments to pass to the selected calculator. Default is {}.
"""
calc_kwargs = calc_kwargs if calc_kwargs else {}
(calc_kwargs,) = none_to_dict(calc_kwargs)

calculator = choose_calculator(
arch=arch,
Expand Down Expand Up @@ -151,7 +152,7 @@ def input_structs(
MaybeSequence[Atoms]
Structure(s) with attached MLIP calculators.
"""
read_kwargs = read_kwargs if read_kwargs else {}
(read_kwargs,) = none_to_dict(read_kwargs)

# Validate parameters
if not struct and not struct_path:
Expand Down Expand Up @@ -249,7 +250,7 @@ def output_structs(
"""
# Separate kwargs for output_structs from kwargs for ase.io.write
# This assumes values passed via kwargs have priority over passed parameters
write_kwargs = write_kwargs if write_kwargs else {}
(write_kwargs,) = none_to_dict(write_kwargs)
set_info = write_kwargs.pop("set_info", set_info)
properties = write_kwargs.pop("properties", properties)
invalidate_calc = write_kwargs.pop("invalidate_calc", invalidate_calc)
Expand Down
8 changes: 4 additions & 4 deletions janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ def _build_filename(
return built_filename


def none_to_dict(dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]:
def none_to_dict(*dictionaries: Sequence[dict | None]) -> Generator[dict, None, None]:
"""
Ensure dictionaries that may be None are dictionaries.

Parameters
----------
dictionaries : Sequence[dict | None]
*dictionaries : Sequence[dict | None]
Sequence of dictionaries that could be None.

Yields
Expand Down Expand Up @@ -259,7 +259,7 @@ def write_table(
2 4
3 6
"""
units = units if units else {}
(units,) = none_to_dict(units)
units.update(
{
key.removesuffix("_units"): val
Expand All @@ -268,7 +268,7 @@ def write_table(
}
)

formats = formats if formats else {}
(formats,) = none_to_dict(formats)
formats.update(
{
key.removesuffix("_format"): val
Expand Down
2 changes: 1 addition & 1 deletion janus_core/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def train(
tracker_kwargs : dict[str, Any] | None
Keyword arguments to pass to `config_tracker`. Default is {}.
"""
(log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))
log_kwargs, tracker_kwargs = none_to_dict(log_kwargs, tracker_kwargs)

if req_file_keys is None:
req_file_keys = ["train_file", "test_file", "valid_file", "statistics_file"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_output_structs(
)
def test_none_to_dict(dicts_in):
"""Test none_to_dict removes Nones from sequence, and preserves dictionaries."""
dicts = list(none_to_dict(dicts_in))
dicts = list(none_to_dict(*dicts_in))
for dictionary in dicts:
assert dictionary is not None

Expand Down