Fix optimizers for v1.0.0
WangXinyan940 committed Oct 27, 2023
1 parent 03627c0 commit 029ee16
Showing 4 changed files with 73 additions and 40 deletions.
8 changes: 5 additions & 3 deletions dmff/admp/pme.py
@@ -407,8 +407,8 @@ def calc_e_perm(dr, mscales, kappa, lmax=2):

# be aware of unit and dimension !!
rInv = 1 / dr
- rInvVec = jnp.array([DIELECTRIC*(rInv**i) for i in range(0, 9)])
- alphaRVec = jnp.array([(kappa*dr)**i for i in range(0, 10)])
+ rInvVec = jnp.array([DIELECTRIC*jnp.power(rInv + 1e-10, i) for i in range(0, 9)])
+ alphaRVec = jnp.array([jnp.power(kappa*dr + 1e-10, i) for i in range(0, 10)])
X = 2 * jnp.exp(-alphaRVec[2]) / jnp.sqrt(np.pi)
tmp = jnp.array(alphaRVec[1])
doubleFactorial = 1
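
The 1e-10 offsets added here are presumably numerical-stability guards for degenerate pair distances (for example padded neighbour-list entries); that scenario is an assumption, not stated in the commit. What can be checked directly is that the offset is invisible at physical separations; a minimal sketch with made-up values for dr and DIELECTRIC:

import jax.numpy as jnp

# Demo values only; DIELECTRIC and dr are not taken from dmff.
DIELECTRIC = 1389.35455846
dr = 3.0
rInv = 1.0 / dr

rInvVec_old = jnp.array([DIELECTRIC * (rInv ** i) for i in range(0, 9)])
rInvVec_new = jnp.array([DIELECTRIC * jnp.power(rInv + 1e-10, i) for i in range(0, 9)])
print(jnp.allclose(rInvVec_old, rInvVec_new))  # True: the guard does not perturb normal values
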
@@ -864,7 +864,9 @@ def pme_real(positions, box, pairs,
@partial(vmap, in_axes=(0, 0), out_axes=(0))
@jit_condition(static_argnums=())
def get_pair_dmp(pol1, pol2):
- return jnp.power(pol1*pol2, 1/6)
+ p12 = pol1 * pol2
+ p12 = jnp.where(p12 < 1e-16, 1e-16, p12)
+ return jnp.power(p12, 1/6)


@jit_condition(static_argnums=(2))
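
The clamp in get_pair_dmp guards a standard JAX pitfall: the derivative of x**(1/6) is (1/6)*x**(-5/6), which diverges at x = 0, so a vanishing polarizability product sends non-finite values back through the chain rule. A minimal sketch of the failure mode and of the clamp used above (the zero inputs are chosen only to trigger the edge case):

import jax
import jax.numpy as jnp

def dmp_unsafe(pol1, pol2):
    return jnp.power(pol1 * pol2, 1 / 6)

def dmp_safe(pol1, pol2):
    p12 = pol1 * pol2
    p12 = jnp.where(p12 < 1e-16, 1e-16, p12)  # clamp before the fractional power
    return jnp.power(p12, 1 / 6)

print(jax.grad(dmp_unsafe)(0.0, 0.0))  # nan: inf from x**(-5/6) times pol2 == 0
print(jax.grad(dmp_safe)(0.0, 0.0))    # 0.0: the clamped branch is constant in pol1
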
26 changes: 20 additions & 6 deletions dmff/api/paramset.py
@@ -24,7 +24,11 @@ class ParamSet:
Converts all parameters to jax arrays.
"""

- def __init__(self, data: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None, mask: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None):
+ def __init__(
+     self,
+     data: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None,
+     mask: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None,
+ ):
"""
Initializes a new ParamSet object.
@@ -52,7 +56,13 @@ def addField(self, field: str) -> None:
self.parameters[field] = {}
self.mask[field] = {}

- def addParameter(self, values: jnp.ndarray, name: str, field: str = None, mask: jnp.ndarray = None) -> None:
+ def addParameter(
+     self,
+     values: jnp.ndarray,
+     name: str,
+     field: str = None,
+     mask: jnp.ndarray = None,
+ ) -> None:
"""
Adds a new parameter to the parameters and mask dictionaries.
@@ -87,8 +97,7 @@ def to_jax(self) -> None:
for key1 in self.parameters:
if isinstance(self.parameters[key1], dict):
for key2 in self.parameters[key1]:
- self.parameters[key1][key2] = jnp.array(
-     self.parameters[key1][key2])
+ self.parameters[key1][key2] = jnp.array(self.parameters[key1][key2])
else:
self.parameters[key1] = jnp.array(self.parameters[key1])

@@ -108,6 +117,12 @@ def __getitem__(self, key: str) -> Union[Dict[str, jnp.ndarray], jnp.ndarray]:
"""
return self.parameters[key]

+ def update_mask(self, gradients):
+     gradients = jax.tree_map(
+         lambda g, m: jnp.where(jnp.abs(m - 1.0) > 1e-5, g, 0.0), gradients, self.mask
+     )
+     return gradients
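
The new update_mask helper filters a gradient pytree against the stored masks before an optimizer step: as written, a gradient entry is kept where its mask value differs from 1.0 and zeroed where the mask equals 1.0. A minimal usage sketch, assuming dmff >= 1.0 is installed, that addParameter stores the supplied mask under paramset.mask[field][name], and using invented field/parameter names:

import jax
import jax.numpy as jnp
from dmff.api.paramset import ParamSet

ps = ParamSet()
ps.addField("demo")
ps.addParameter(jnp.array([1.0, 2.0, 3.0]), "k", field="demo",
                mask=jnp.array([1.0, 1.0, 0.0]))

def loss(params):                        # params: the nested parameter dict
    return jnp.sum(params["demo"]["k"] ** 2)

grads = jax.grad(loss)(ps.parameters)    # {"demo": {"k": [2., 4., 6.]}}
filtered = ps.update_mask(grads)         # argument must mirror the structure of ps.mask
print(filtered["demo"]["k"])             # [0. 0. 6.]: entries whose mask equals 1.0 are zeroed
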


def flatten_paramset(prmset: ParamSet) -> tuple:
"""
@@ -145,5 +160,4 @@ def unflatten_paramset(aux_data: Dict, contents: tuple) -> ParamSet:
return ParamSet(data=contents[0], mask=aux_data)


- jax.tree_util.register_pytree_node(
-     ParamSet, flatten_paramset, unflatten_paramset)
+ jax.tree_util.register_pytree_node(ParamSet, flatten_paramset, unflatten_paramset)
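
The registration on the last line is what lets ParamSet objects travel through JAX transforms: flatten_paramset exposes the parameter dict as pytree children while the mask rides along as auxiliary data, and unflatten_paramset rebuilds the object. Continuing the sketch above (and assuming __init__ stores the passed dicts directly):

leaves, treedef = jax.tree_util.tree_flatten(ps)   # leaves: the parameter arrays
ps_roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(ps_roundtrip).__name__)                 # ParamSet
print(ps_roundtrip.mask["demo"]["k"])              # mask restored from aux_data
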
26 changes: 17 additions & 9 deletions dmff/generators/admp.py
@@ -1075,27 +1075,35 @@ def getName(self):

def overwrite(self, paramset):
Q_global = convert_harm2cart(paramset[self.name]["Q_local"], self.lmax)
+ q_local_masks = paramset.mask[self.name]["Q_local"]
+ polar_masks = paramset.mask[self.name]["pol"]

n_multipole, n_pol = 0, 0
for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])):
node = self.ffinfo["Forces"][self.name]["node"][nnode]
if node["name"] in ["Atom", "Multipole"]:
node["c0"] = Q_global[n_multipole, 0]
node["dX"] = Q_global[n_multipole, 1]
node["dY"] = Q_global[n_multipole, 2]
node["dZ"] = Q_global[n_multipole, 3]
node["qXX"] = Q_global[n_multipole, 4]
node["qYY"] = Q_global[n_multipole, 5]
node["qZZ"] = Q_global[n_multipole, 6]
node["qXY"] = Q_global[n_multipole, 7]
node["qXZ"] = Q_global[n_multipole, 8]
node["qYZ"] = Q_global[n_multipole, 9]
if self.lmax >= 1:
node["dX"] = Q_global[n_multipole, 1]
node["dY"] = Q_global[n_multipole, 2]
node["dZ"] = Q_global[n_multipole, 3]
if self.lmax >= 2:
node["qXX"] = Q_global[n_multipole, 4]
node["qYY"] = Q_global[n_multipole, 5]
node["qZZ"] = Q_global[n_multipole, 6]
node["qXY"] = Q_global[n_multipole, 7]
node["qXZ"] = Q_global[n_multipole, 8]
node["qYZ"] = Q_global[n_multipole, 9]
if q_local_masks[n_multipole] < 0.999:
node["mask"] = "true"
n_multipole += 1
elif node["name"] == "Polarize":
node["polarizabilityXX"] = paramset[self.name]["pol"][n_pol] * 0.001
node["polarizabilityYY"] = paramset[self.name]["pol"][n_pol] * 0.001
node["polarizabilityZZ"] = paramset[self.name]["pol"][n_pol] * 0.001
node["thole"] = paramset[self.name]["thole"][n_pol]
+ if polar_masks[n_pol] < 0.999:
+     node["mask"] = "true"
n_pol += 1

def _find_multipole_key_index(self, atype: str):
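
The rewritten write-back only touches the multipole attributes that exist for the current lmax (1 Cartesian component for lmax = 0, 4 for lmax = 1, 10 for lmax = 2) and flags nodes whose mask value falls below 0.999. A rough standalone sketch of that column-to-attribute mapping; whether Q_global is truncated or zero-padded for lower lmax is not visible in this hunk, so the helper below is an illustration, not dmff code:

# Hypothetical helper mirroring the gating above.
CART_KEYS = ["c0",                                      # lmax >= 0: charge
             "dX", "dY", "dZ",                          # lmax >= 1: dipole
             "qXX", "qYY", "qZZ", "qXY", "qXZ", "qYZ"]  # lmax >= 2: quadrupole

def n_cart_components(lmax: int) -> int:
    return {0: 1, 1: 4, 2: 10}[lmax]

def write_multipole(node: dict, q_row, lmax: int, masked: bool) -> None:
    for j in range(n_cart_components(lmax)):
        node[CART_KEYS[j]] = q_row[j]
    if masked:
        node["mask"] = "true"
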
53 changes: 31 additions & 22 deletions dmff/optimize.py
@@ -1,4 +1,6 @@
import jax
from jax import grad
import jax.numpy as jnp
from typing import Optional
import optax

@@ -12,41 +14,48 @@ def init_fn(params):

def update_fn(updates, state, params):
if params is None:
- raise ValueError(optax.base.NO_PARAMS_MSG)
+ raise ValueError(optax._src.base.NO_PARAMS_MSG)

updates = jax.tree_map(
-     lambda p, u: jnp.where((p + u) < pmin, u + pmax - pmin, u), params,
-     updates)
+     lambda p, u: jnp.where((p + u) < pmin, u + pmax - pmin, u), params, updates
+ )
updates = jax.tree_map(
-     lambda p, u: jnp.where((p + u) > pmax, u - pmax + pmin, u), params,
-     updates)
+     lambda p, u: jnp.where((p + u) > pmax, u - pmax + pmin, u), params, updates
+ )
return updates, state

return optax._src.base.GradientTransformation(init_fn, update_fn)
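
This transform keeps a parameter inside a periodic box [pmin, pmax): whenever a step would cross a boundary, the update is shifted by one period pmax - pmin so the parameter re-enters from the other side. A standalone sketch of the same rule (pmin, pmax and the sample values are made up):

import jax.numpy as jnp

def wrap_update(p, u, pmin, pmax):
    period = pmax - pmin
    u = jnp.where((p + u) < pmin, u + period, u)  # fell below pmin: re-enter from the top
    u = jnp.where((p + u) > pmax, u - period, u)  # rose above pmax: re-enter from the bottom
    return u

p, u = 0.2, -0.5                      # e.g. a dihedral-like parameter living in [0, 2*pi)
u_wrapped = wrap_update(p, u, 0.0, 2.0 * jnp.pi)
print(p + u_wrapped)                  # ~5.98 instead of -0.3
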


def genOptimizer(optimizer="adam",
learning_rate=1.0,
nonzero=True,
clip=10.0,
periodic=None,
transition_steps=1000,
warmup_steps=0,
decay_rate=0.99,
options: dict={}):
def genOptimizer(
optimizer="adam",
learning_rate=1.0,
nonzero=True,
clip=10.0,
periodic=None,
transition_steps=1000,
warmup_steps=0,
decay_rate=0.99,
options: dict = {},
):
if decay_rate == 1.0 and warmup_steps == 0:
options["learning_rate"] = learning_rate
# Exponential decay of the learning rate.
elif warmup_steps == 0:
- scheduler = optax.exponential_decay(init_value=learning_rate,
-                                     transition_steps=transition_steps,
-                                     decay_rate=decay_rate)
+ scheduler = optax.exponential_decay(
+     init_value=learning_rate,
+     transition_steps=transition_steps,
+     decay_rate=decay_rate,
+ )
options["learning_rate"] = scheduler
else:
- scheduler = optax.warmup_exponential_decay_schedule(init_value=0, peak_value=learning_rate,
-                                                     warmup_steps=warmup_steps,
-                                                     transition_steps=transition_steps,
-                                                     decay_rate=decay_rate)
+ scheduler = optax.warmup_exponential_decay_schedule(
+     init_value=0,
+     peak_value=learning_rate,
+     warmup_steps=warmup_steps,
+     transition_steps=transition_steps,
+     decay_rate=decay_rate,
+ )
options["learning_rate"] = scheduler

# Combining gradient transforms using `optax.chain`.
@@ -132,4 +141,4 @@ def __delitem__(self, key):
del self.transforms[key]

def finalize(self):
label2trans_iter(self.labels, self.mask, self.transforms)
label2trans_iter(self.labels, self.mask, self.transforms)
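
For context, a hedged usage sketch of the reformatted genOptimizer (the module path, toy loss, and parameter pytree are assumptions): with warmup_steps=0 and decay_rate < 1 it builds the optax.exponential_decay schedule shown above, i.e. learning_rate * decay_rate**(step / transition_steps), and the chained transformation it returns is driven like any optax optimizer:

import jax
import jax.numpy as jnp
import optax
from dmff.optimize import genOptimizer

params = {"k": jnp.array([1.0, 2.0])}
opt = genOptimizer(optimizer="adam", learning_rate=0.01,
                   transition_steps=500, decay_rate=0.99)
state = opt.init(params)

def loss(p):
    return jnp.sum((p["k"] - 3.0) ** 2)

for step in range(100):
    grads = jax.grad(loss)(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)
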
