Skip to content

Commit

Permalink
add refresh in qeq.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gust-07 committed Oct 20, 2023
1 parent 32896fb commit 72cbf04
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ def get_chgs():
return energy

return get_energy
def update_env(self, attr, val):
'''
Update the environment of the calculator
'''
setattr(self, attr, val)
self.refresh_calculators()


def refresh_calculators(self):
'''
refresh the energy and force calculators according to the current environment
'''
# generate the force calculator
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)
return

def E_constQ(q, lagmt, const_list, const_vals):
constraint = (jnp.sum(q[const_list], axis=1) - const_vals) * lagmt
Expand Down Expand Up @@ -195,27 +211,3 @@ def etainv_piecewise(eta):
(lambda x: jnp.array(1/eta), lambda x:jnp.array(0)))


@jit_condition(static_argnums=(7,9,10,12))
def E_grads2(b_value, const_vals, chi, J, positions, box, pairs, damp_mod, eta, neutral_flag, constQ, const_list,pbc_flag):
n_const = len(const_vals)
q = b_value[:-n_const]
lagmt = b_value[-n_const:]
g1,g2 = grad(E_full,argnums=(3,13))(positions, box, pairs, q, damp_mod, eta, neutral_flag, chi, J, constQ, const_list, const_vals, lagmt,pbc_flag)
g = jnp.concatenate((g1,g2))
return jnp.sum(g **2)
@jit_condition(static_argnums=(7,9,10,12))
def Q_equi2(b_value, const_vals, chi, J, positions, box, pairs, damp_mod, eta, neutral_flag, constQ, const_list,pbc_flag ):
solver2 = jaxopt.BFGS(fun = E_grads2, maxiter=500)
q0,state2 = solver2.run(b_value, const_vals, chi, J, positions, box, pairs, damp_mod, eta, neutral_flag, constQ, const_list,pbc_flag)
return q0

#def Q_jac(b_value, chi, J, eta, constQ, const_list, const_vals):
# q0 = Q_equi(b_value, chi, J, eta, constQ, const_list, const_vals)
# q_jac = jax.jacobian(Q_equi, argnums=(1,2))(b_value, chi, J, eta, constQ, const_list, const_vals)
# return q_jac

@jit
def E_force(q, lagmt):
E,f = value_and_grad(E_full,argnums=(0))(pos, box, pairs, q, damp_mod, eta, neutral_flag, chi, J, constQ, const_list, const_vals, lagmt, pbc_flag)
return E,f

0 comments on commit 72cbf04

Please sign in to comment.