diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index bf718cbdd..89daeae64 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -407,8 +407,8 @@ def calc_e_perm(dr, mscales, kappa, lmax=2): # be aware of unit and dimension !! rInv = 1 / dr - rInvVec = jnp.array([DIELECTRIC*(rInv**i) for i in range(0, 9)]) - alphaRVec = jnp.array([(kappa*dr)**i for i in range(0, 10)]) + rInvVec = jnp.array([DIELECTRIC*jnp.power(rInv + 1e-10, i) for i in range(0, 9)]) + alphaRVec = jnp.array([jnp.power(kappa*dr + 1e-10, i) for i in range(0, 10)]) X = 2 * jnp.exp(-alphaRVec[2]) / jnp.sqrt(np.pi) tmp = jnp.array(alphaRVec[1]) doubleFactorial = 1 @@ -864,7 +864,9 @@ def pme_real(positions, box, pairs, @partial(vmap, in_axes=(0, 0), out_axes=(0)) @jit_condition(static_argnums=()) def get_pair_dmp(pol1, pol2): - return jnp.power(pol1*pol2, 1/6) + p12 = pol1 * pol2 + p12 = jnp.where(p12 < 1e-16, 1e-16, p12) + return jnp.power(p12, 1/6) @jit_condition(static_argnums=(2)) diff --git a/dmff/api/paramset.py b/dmff/api/paramset.py index d15bef2a8..141adb366 100644 --- a/dmff/api/paramset.py +++ b/dmff/api/paramset.py @@ -24,7 +24,11 @@ class ParamSet: Converts all parameters to jax arrays. """ - def __init__(self, data: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None, mask: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None): + def __init__( + self, + data: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None, + mask: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None, + ): """ Initializes a new ParamSet object. @@ -52,7 +56,13 @@ def addField(self, field: str) -> None: self.parameters[field] = {} self.mask[field] = {} - def addParameter(self, values: jnp.ndarray, name: str, field: str = None, mask: jnp.ndarray = None) -> None: + def addParameter( + self, + values: jnp.ndarray, + name: str, + field: str = None, + mask: jnp.ndarray = None, + ) -> None: """ Adds a new parameter to the parameters and mask dictionaries. @@ -87,8 +97,7 @@ def to_jax(self) -> None: for key1 in self.parameters: if isinstance(self.parameters[key1], dict): for key2 in self.parameters[key1]: - self.parameters[key1][key2] = jnp.array( - self.parameters[key1][key2]) + self.parameters[key1][key2] = jnp.array(self.parameters[key1][key2]) else: self.parameters[key1] = jnp.array(self.parameters[key1]) @@ -108,6 +117,12 @@ def __getitem__(self, key: str) -> Union[Dict[str, jnp.ndarray], jnp.ndarray]: """ return self.parameters[key] + def update_mask(self, gradients): + gradients = jax.tree_map( + lambda g, m: jnp.where(jnp.abs(m - 1.0) > 1e-5, g, 0.0), gradients, self.mask + ) + return gradients + def flatten_paramset(prmset: ParamSet) -> tuple: """ @@ -145,5 +160,4 @@ def unflatten_paramset(aux_data: Dict, contents: tuple) -> ParamSet: return ParamSet(data=contents[0], mask=aux_data) -jax.tree_util.register_pytree_node( - ParamSet, flatten_paramset, unflatten_paramset) +jax.tree_util.register_pytree_node(ParamSet, flatten_paramset, unflatten_paramset) diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index d2756c0b4..85ec54981 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -1075,27 +1075,35 @@ def getName(self): def overwrite(self, paramset): Q_global = convert_harm2cart(paramset[self.name]["Q_local"], self.lmax) + q_local_masks = paramset.mask[self.name]["Q_local"] + polar_masks = paramset.mask[self.name]["pol"] n_multipole, n_pol = 0, 0 for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])): node = self.ffinfo["Forces"][self.name]["node"][nnode] if node["name"] in ["Atom", "Multipole"]: node["c0"] = Q_global[n_multipole, 0] - node["dX"] = Q_global[n_multipole, 1] - node["dY"] = Q_global[n_multipole, 2] - node["dZ"] = Q_global[n_multipole, 3] - node["qXX"] = Q_global[n_multipole, 4] - node["qYY"] = Q_global[n_multipole, 5] - node["qZZ"] = Q_global[n_multipole, 6] - node["qXY"] = Q_global[n_multipole, 7] - node["qXZ"] = Q_global[n_multipole, 8] - node["qYZ"] = Q_global[n_multipole, 9] + if self.lmax >= 1: + node["dX"] = Q_global[n_multipole, 1] + node["dY"] = Q_global[n_multipole, 2] + node["dZ"] = Q_global[n_multipole, 3] + if self.lmax >= 2: + node["qXX"] = Q_global[n_multipole, 4] + node["qYY"] = Q_global[n_multipole, 5] + node["qZZ"] = Q_global[n_multipole, 6] + node["qXY"] = Q_global[n_multipole, 7] + node["qXZ"] = Q_global[n_multipole, 8] + node["qYZ"] = Q_global[n_multipole, 9] + if q_local_masks[n_multipole] < 0.999: + node["mask"] = "true" n_multipole += 1 elif node["name"] == "Polarize": node["polarizabilityXX"] = paramset[self.name]["pol"][n_pol] * 0.001 node["polarizabilityYY"] = paramset[self.name]["pol"][n_pol] * 0.001 node["polarizabilityZZ"] = paramset[self.name]["pol"][n_pol] * 0.001 node["thole"] = paramset[self.name]["thole"][n_pol] + if polar_masks[n_pol] < 0.999: + node["mask"] = "true" n_pol += 1 def _find_multipole_key_index(self, atype: str): diff --git a/dmff/optimize.py b/dmff/optimize.py index d2bacef6b..f2bc67ffb 100644 --- a/dmff/optimize.py +++ b/dmff/optimize.py @@ -1,4 +1,6 @@ +import jax from jax import grad +import jax.numpy as jnp from typing import Optional import optax @@ -12,41 +14,48 @@ def init_fn(params): def update_fn(updates, state, params): if params is None: - raise ValueError(optax.base.NO_PARAMS_MSG) + raise ValueError(optax._src.base.NO_PARAMS_MSG) updates = jax.tree_map( - lambda p, u: jnp.where((p + u) < pmin, u + pmax - pmin, u), params, - updates) + lambda p, u: jnp.where((p + u) < pmin, u + pmax - pmin, u), params, updates + ) updates = jax.tree_map( - lambda p, u: jnp.where((p + u) > pmax, u - pmax + pmin, u), params, - updates) + lambda p, u: jnp.where((p + u) > pmax, u - pmax + pmin, u), params, updates + ) return updates, state return optax._src.base.GradientTransformation(init_fn, update_fn) -def genOptimizer(optimizer="adam", - learning_rate=1.0, - nonzero=True, - clip=10.0, - periodic=None, - transition_steps=1000, - warmup_steps=0, - decay_rate=0.99, - options: dict={}): +def genOptimizer( + optimizer="adam", + learning_rate=1.0, + nonzero=True, + clip=10.0, + periodic=None, + transition_steps=1000, + warmup_steps=0, + decay_rate=0.99, + options: dict = {}, +): if decay_rate == 1.0 and warmup_steps == 0: options["learning_rate"] = learning_rate # Exponential decay of the learning rate. elif warmup_steps == 0: - scheduler = optax.exponential_decay(init_value=learning_rate, - transition_steps=transition_steps, - decay_rate=decay_rate) + scheduler = optax.exponential_decay( + init_value=learning_rate, + transition_steps=transition_steps, + decay_rate=decay_rate, + ) options["learning_rate"] = scheduler else: - scheduler = optax.warmup_exponential_decay_schedule(init_value=0, peak_value=learning_rate, - warmup_steps=warmup_steps, - transition_steps=transition_steps, - decay_rate=decay_rate) + scheduler = optax.warmup_exponential_decay_schedule( + init_value=0, + peak_value=learning_rate, + warmup_steps=warmup_steps, + transition_steps=transition_steps, + decay_rate=decay_rate, + ) options["learning_rate"] = scheduler # Combining gradient transforms using `optax.chain`. @@ -132,4 +141,4 @@ def __delitem__(self, key): del self.transforms[key] def finalize(self): - label2trans_iter(self.labels, self.mask, self.transforms) \ No newline at end of file + label2trans_iter(self.labels, self.mask, self.transforms)