Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update eann.py #177

Closed
wants to merge 126 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
126 commits
Select commit Hold shift + click to select a range
8564c13
Remove jax_md requirement to support high version of jax & tensorflow
WangXinyan940 Jan 5, 2023
c0d9408
Initialize topology information generator and its UT
WangXinyan940 Jan 10, 2023
c46c312
Half way of detecting propers
WangXinyan940 Jan 10, 2023
79a13c8
Finish proper searching
WangXinyan940 Jan 11, 2023
4417306
Update topology.py
WangXinyan940 Jan 11, 2023
0788c1e
Update topology.py
WangXinyan940 Jan 11, 2023
edced36
Auto detect impropers
WangXinyan940 Jan 11, 2023
35bd2e3
Update topology.py
WangXinyan940 Jan 11, 2023
8ff48d3
Initialize template IO
WangXinyan940 Jan 11, 2023
0a1647e
Remove jax version limitation
WangXinyan940 Jan 11, 2023
ba5a768
Finish atom type matching with templates
WangXinyan940 Jan 11, 2023
a160cf1
Initialize UT for topology tools
WangXinyan940 Jan 11, 2023
67edf0a
Upload auto workflow for UT
WangXinyan940 Jan 12, 2023
cabcd72
Update initial version of frontend impl & unit tests
Mar 1, 2023
18849fc
Modify new code to fit old parts
WangXinyan940 Mar 1, 2023
615b393
Make the structure clean
WangXinyan940 Mar 1, 2023
8cb20b7
Update code
WangXinyan940 Mar 1, 2023
3d423c2
Finished adding vsite from template and smarts parsers
WangXinyan940 Mar 2, 2023
b687404
Support VSite IO in xml template
WangXinyan940 Mar 2, 2023
a209f10
Support vsite in template matching
WangXinyan940 Mar 2, 2023
717e6f1
Add example
WangXinyan940 Mar 2, 2023
e7a3785
Add example of vsite addition
WangXinyan940 Mar 2, 2023
28da5bd
Try a better implementation
WangXinyan940 Mar 5, 2023
51fdb4a
Implement a better Topology
WangXinyan940 Mar 6, 2023
16bb78f
Update bettertopology.py
WangXinyan940 Mar 6, 2023
14753a0
Designed the usage of operators
WangXinyan940 Mar 9, 2023
214d154
Update our own Topology class
WangXinyan940 Mar 9, 2023
4d8d4d7
Update test_operators.py
WangXinyan940 Mar 9, 2023
1a9db72
Update test_operators.py
WangXinyan940 Mar 9, 2023
84de55c
Implement VSite and AType operators
WangXinyan940 Mar 9, 2023
0fd4918
Create smartsvsite.py
WangXinyan940 Mar 9, 2023
ea4bf5a
Update am1charge.py
WangXinyan940 Mar 9, 2023
a7a3406
Update test_operators.py
WangXinyan940 Mar 9, 2023
28fb354
Add AM1 charge OP
WangXinyan940 Mar 14, 2023
904380d
Add Coulomb Generator
WangXinyan940 Mar 21, 2023
54f0e7e
supporting generators
WangXinyan940 Mar 22, 2023
68a26a9
Add method to find equivalent atoms (without vsites)
WangXinyan940 Mar 23, 2023
2816e73
Let AM1 charge calculator use eqv info
WangXinyan940 Mar 23, 2023
a19e57b
Add full support of BCC charge
WangXinyan940 Mar 24, 2023
25f554e
Support ambertools based typification
WangXinyan940 Mar 24, 2023
49cd1fd
Update classical.py
WangXinyan940 Mar 24, 2023
e91652f
Add LJ generator and an example
WangXinyan940 Mar 28, 2023
b69e5c7
support LJ prms
WangXinyan940 Apr 3, 2023
8274677
Deal with atom classes in LJ
WangXinyan940 Apr 4, 2023
b46a904
Add water example for interaction calculation
WangXinyan940 Apr 4, 2023
209612a
Update test_inter_water.py
WangXinyan940 Apr 4, 2023
4a993c0
Add support for loading BCC prms
WangXinyan940 Apr 6, 2023
c60dd6b
Update classical.py
WangXinyan940 Apr 6, 2023
e272cc1
Finish BCC support
WangXinyan940 Apr 6, 2023
b29984b
support type2 vsite
WangXinyan940 Apr 19, 2023
cc5009b
Make operators use the same initialization method
WangXinyan940 Apr 19, 2023
c2e8395
Support smirks patching on vsites.
WangXinyan940 Apr 19, 2023
2e5ed90
Finished dimer energy example
WangXinyan940 Apr 20, 2023
0eaf31f
Update classical.py
WangXinyan940 Apr 20, 2023
a2f4530
Update topology.py
WangXinyan940 Apr 20, 2023
719b0a5
Fix ParamSet to be a correct PyTree
WangXinyan940 Apr 20, 2023
c5f6d26
Update paramset.py
WangXinyan940 Apr 20, 2023
1b7011f
Calculate test system by hand
WangXinyan940 Apr 20, 2023
8cbfb45
Add test case for NoCutoff Coul and LJ energy
WangXinyan940 Apr 20, 2023
e7b5d05
Support NBFix
WangXinyan940 Apr 21, 2023
3b3eca6
Add Hamiltonian and its unittest.
WangXinyan940 Apr 23, 2023
edf1f16
Delete test.xml
WangXinyan940 Apr 23, 2023
6080c47
update vsite positions in hamiltonian
WangXinyan940 Apr 23, 2023
e729b0c
replace Chem.SanitizeMol with topdata.regularize_aromaticity
TablewareBox Apr 23, 2023
f61c2be
fix N+ formal charge
TablewareBox Apr 24, 2023
ba065d3
Merge branch 'master' into wangxy/frontend-refactor
WangXinyan940 Apr 24, 2023
d700e64
Update .gitignore
WangXinyan940 Apr 24, 2023
dc8bd96
Merge branch 'wangxy/frontend-refactor' of https://github.com/deepmod…
TablewareBox Apr 24, 2023
f03d198
Update topology.py
WangXinyan940 Apr 24, 2023
4e5678d
Support the way of updating paramset in optax
WangXinyan940 Apr 24, 2023
b4d82f2
Update test_run_dimer_energy.py
WangXinyan940 Apr 24, 2023
71c3fa6
Initialize generators before loading hamiltonian
WangXinyan940 Apr 26, 2023
cbad09a
add warmup and nesterov optimizer
TablewareBox May 1, 2023
c442416
fix 3fd vsite coordinates
TablewareBox May 4, 2023
a0852dd
Add mask generation
WangXinyan940 May 11, 2023
786237b
build vsites in multiple molecules separately
TablewareBox May 14, 2023
b721c9e
support dummy atom reading
TablewareBox May 24, 2023
9c2aeee
fix vsite topologies and molecules
TablewareBox May 25, 2023
fc9a290
Init new example
WangXinyan940 Jun 19, 2023
33a452d
Support Parmed LJ modifier
WangXinyan940 Jun 20, 2023
766832a
Update example for lennard jones optimization
WangXinyan940 Jun 26, 2023
e64caae
Create simpler MBAR estimator
WangXinyan940 Jun 26, 2023
d5557c2
Add reweighting estimator
WangXinyan940 Jun 27, 2023
315416d
Update opt.ipynb
WangXinyan940 Jun 27, 2023
4b03410
Update opt.ipynb
WangXinyan940 Jun 27, 2023
4a82a09
Update opt.ipynb
WangXinyan940 Jun 27, 2023
3214592
Add a simple example to calculate lennard-jones potential
WangXinyan940 Aug 25, 2023
dfbc39d
Add openmm as reference
WangXinyan940 Aug 25, 2023
dc974c0
Pearl dingzhen patch 1 (#112)
pearlDingzhen Sep 11, 2023
8efbe63
Update inter.py
WangXinyan940 Sep 11, 2023
60fce34
Remove Jax-MD requirement
WangXinyan940 Oct 11, 2023
a4a5fce
Fix Hamiltonian
WangXinyan940 Oct 11, 2023
3b7ffcd
Add HarmonicAngleGenerator
WangXinyan940 Oct 11, 2023
3853f61
Add PeriodicTorsionGenerator
WangXinyan940 Oct 11, 2023
118b192
Update NonbondedForce
WangXinyan940 Oct 11, 2023
b97c462
Correct the order of improper matching
WangXinyan940 Oct 11, 2023
bd62fcc
Fix and update unittests.
WangXinyan940 Oct 15, 2023
837ae1a
Update unit-test workflow to fix requirement problem
WangXinyan940 Oct 16, 2023
1cad257
Update test_compute.py
WangXinyan940 Oct 16, 2023
7eb616e
Update ut.yml
WangXinyan940 Oct 16, 2023
57c27c9
Update ut.yml
WangXinyan940 Oct 16, 2023
50e7c54
Add generator for ADMPPmeForce
WangXinyan940 Oct 17, 2023
eb76ae0
Merge changes on devel branch
WangXinyan940 Oct 17, 2023
b26de25
Merge changes on devel branch
WangXinyan940 Oct 17, 2023
98ffe8b
Update test_compute.py
WangXinyan940 Oct 17, 2023
78e0405
Support LJ long range correction
WangXinyan940 Oct 17, 2023
f642cb1
Add generators for Slater type forces
WangXinyan940 Oct 17, 2023
be7f94e
Update ut.yml
WangXinyan940 Oct 17, 2023
cd199e4
Update ut.yml
WangXinyan940 Oct 17, 2023
5c153bc
Fix No-AxisType bug (#122)
WangXinyan940 Oct 20, 2023
66d2eb9
qeq merge (#124)
gust-07 Oct 20, 2023
800480a
Add frontend for sGNN (#125)
KuangYu Oct 22, 2023
5acacff
Modified QEQ potential and add JIT support
WangXinyan940 Oct 22, 2023
81741de
Add QEQ test
WangXinyan940 Oct 22, 2023
4aa5b77
Add jaxopt requirement in github workflow
WangXinyan940 Oct 22, 2023
6134176
Update qeq.py
WangXinyan940 Oct 22, 2023
337a8cf
Support aux data for ADMP and QEQ
WangXinyan940 Oct 23, 2023
db28782
Clean the way of aux_data implementation
WangXinyan940 Oct 23, 2023
a03c841
Add new unit test for QEQ with two residues.
WangXinyan940 Oct 23, 2023
9b03c2b
Change jaxopt root finder to be jit-able
WangXinyan940 Oct 24, 2023
919bef2
Update ut.yml
WangXinyan940 Oct 24, 2023
6de2f47
Explicit support nopbc calculation
WangXinyan940 Oct 24, 2023
cf07e05
Update admp.py
WangXinyan940 Oct 24, 2023
23a1238
make the behavior of ADMP correct while using NoCutoff
WangXinyan940 Oct 24, 2023
4dd36b0
Merge remote-tracking branch 'upstream/devel' into dev
junminchen Apr 6, 2024
6c63b35
[Enhancement] Improved EANN Model with Hard Cutoff for Efficient Inte…
junminchen Apr 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ jobs:
- name: Install Dependencies
run: |
source $CONDA/bin/activate
<<<<<<< HEAD
conda create -n dmff -y python=${{ matrix.python-version }} numpy openmm==7.7.0 pytest rdkit biopandas openbabel mdtraj ambertools -c conda-forge
conda activate dmff
pip install --upgrade pip
pip install jax jaxlib jaxopt networkx parmed pymbar==4.0.1 chex==0.1.4 tqdm
=======
conda create -n dmff -y python=${{ matrix.python-version }} numpy openmm==7.7.0 pytest rdkit openbabel mdtraj ambertools -c conda-forge
conda activate dmff
pip install --upgrade pip
pip install jax jaxlib jaxopt networkx parmed pymbar==4.0.1 optax tqdm
>>>>>>> upstream/devel
- name: Install DMFF
run: |
source $CONDA/bin/activate dmff && pip install .
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,9 @@ FodyWeavers.xsd
# debugging ipynb
debug.ipynb
test.xml
<<<<<<< HEAD
=======

# PyCharm Cache
.idea/
.idea/
>>>>>>> upstream/devel
56 changes: 56 additions & 0 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,33 @@
from .settings import POL_CONV, MAX_N_POL
from .recip import generate_pme_recip, Ck_1
from .multipole import (
<<<<<<< HEAD
C1_c2h,
=======
C1_c2h,
C1_h2c,
C2_h2c,
>>>>>>> upstream/devel
convert_cart2harm,
rot_ind_global2local,
rot_global2local,
rot_local2global,
)
<<<<<<< HEAD
from .spatial import (
v_pbc_shift,
generate_construct_local_frames,
build_quasi_internal
)
from .pairwise import (
distribute_scalar,
distribute_v3,
=======
from .spatial import v_pbc_shift, generate_construct_local_frames, build_quasi_internal
from .pairwise import (
distribute_scalar,
distribute_v3,
>>>>>>> upstream/devel
distribute_multipoles,
distribute_matrix,
)
Expand All @@ -38,6 +53,12 @@ class ADMPPmeForce:
This is a convenient wrapper for multipolar PME calculations
It wrapps all the environment parameters of multipolar PME calculation
The so called "environment paramters" means parameters that do not need to be differentiable
<<<<<<< HEAD
'''

def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None, has_aux=False):
'''
=======
"""

def __init__(
Expand All @@ -55,6 +76,7 @@ def __init__(
has_aux=False,
):
"""
>>>>>>> upstream/devel
Initialize the ADMPPmeForce calculator.

Input:
Expand Down Expand Up @@ -106,8 +128,11 @@ def __init__(
# self.n_atoms = int(covalent_map.shape[0]) # len(axis_type)
self.n_atoms = len(axis_type)
self.has_aux = has_aux
<<<<<<< HEAD
=======

self.Q_local_to_global = self.generate_Q_global_function()
>>>>>>> upstream/devel

# setup calculators
self.refresh_calculators()
Expand All @@ -118,6 +143,13 @@ def generate_get_energy(self):
if not self.lpol:

def get_energy(positions, box, pairs, Q_local, mScales):
<<<<<<< HEAD
return energy_pme(positions, box, pairs,
Q_local, None, None, None,
mScales, None, None,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme)
=======
return energy_pme(
positions,
box,
Expand All @@ -140,6 +172,7 @@ def get_energy(positions, box, pairs, Q_local, mScales):
lpme=self.lpme,
)

>>>>>>> upstream/devel
return get_energy
else:
# this is the bare energy calculator, with Uind as explicit input
Expand Down Expand Up @@ -184,6 +217,23 @@ def energy_fn(

# this is the wrapper that include a Uind optimizer
def get_energy(
<<<<<<< HEAD
positions, box, pairs,
Q_local, pol, tholes, mScales, pScales, dScales,
U_init = self.U_ind, aux = None):
self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
positions, box, pairs, Q_local, pol, tholes,
mScales, pScales, dScales,
U_init=U_init, steps_pol=self.steps_pol)
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
energy = self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales)
if aux is not None:
aux["U_ind"] = self.U_ind
return energy, aux
else:
return energy
=======
positions,
box,
pairs,
Expand Down Expand Up @@ -232,6 +282,7 @@ def get_energy(
else:
return energy

>>>>>>> upstream/devel
return get_energy

def generate_esp(self):
Expand Down Expand Up @@ -935,10 +986,15 @@ def pme_real_kernel(
Output:
energy:
float, realspace interaction energy between the sites
<<<<<<< HEAD
'''
cc, cd, dd_m0, dd_m1, cq, dq_m0, dq_m1, qq_m0, qq_m1, qq_m2 = calc_e_perm(dr, mscales, kappa, lmax)
=======
"""
cc, cd, dd_m0, dd_m1, cq, dq_m0, dq_m1, qq_m0, qq_m1, qq_m2 = calc_e_perm(
dr, mscales, kappa, lmax
)
>>>>>>> upstream/devel
if lpol:
cud, dud_m0, dud_m1, udq_m0, udq_m1, udud_m0, udud_m1 = calc_e_ind(
dr, thole1, thole2, dmp, pscales, dscales, kappa, lmax
Expand Down
39 changes: 39 additions & 0 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
from typing import Tuple, List
from ..settings import PRECISION
<<<<<<< HEAD
=======
from .pme import energy_pme
from .recip import generate_pme_recip, Ck_1
>>>>>>> upstream/devel

if PRECISION == "double":
CONST_0 = jnp.array(0, dtype=jnp.float64)
Expand All @@ -17,6 +20,16 @@

try:
import jaxopt
<<<<<<< HEAD
try:
from jaxopt import Broyden
JAXOPT_OLD = False
except ImportError:
JAXOPT_OLD = True
print("jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern.")
except ImportError:
print("jaxopt not found, QEQ cannot be used.")
=======

try:
from jaxopt import Broyden
Expand All @@ -31,6 +44,7 @@
except ImportError:
import warnings
warnings.warn("jaxopt not found, QEQ cannot be used.")
>>>>>>> upstream/devel
import jax

from jax.scipy.special import erf, erfc
Expand Down Expand Up @@ -106,9 +120,13 @@ def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):

@jit_condition()
def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales):
<<<<<<< HEAD
etasqrt = jnp.sqrt(eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2)
=======
etasqrt = jnp.sqrt(
eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2 + 1e-64
) # add eta to avoid division by zero
>>>>>>> upstream/devel
epiece = eta_piecewise(etasqrt, ds)
pre_pair = -epiece * DIELECTRIC
pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
Expand Down Expand Up @@ -164,12 +182,20 @@ def ds_pairs(positions, box, pairs, pbc_flag):
if pbc_flag is False:
dr = pos1 - pos2
else:
<<<<<<< HEAD
box_inv = jnp.linalg.inv(box)
=======
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
>>>>>>> upstream/devel
dpos = pos1 - pos2
dpos = dpos.dot(box_inv)
dpos -= jnp.floor(dpos + 0.5)
dr = dpos.dot(box)
<<<<<<< HEAD
ds = jnp.linalg.norm(dr, axis=1)
=======
ds = jnp.linalg.norm(dr + 1e-64, axis=1) # add eta to avoid division by zero
>>>>>>> upstream/devel
return ds


Expand Down Expand Up @@ -251,6 +277,14 @@ def __init__(
raise ValueError("damp_mod must be 1, 2 or 3")

if pbc_flag:
<<<<<<< HEAD
force = CoulombPMEForce(r_cut, kappa, K)
self.kappa = kappa
else:
force = CoulNoCutoffForce()
self.kappa = 1.0
self.coul_energy = force.generate_get_energy()
=======
pme_recip_fn = generate_pme_recip(
Ck_fn=Ck_1,
kappa=kappa / 10,
Expand Down Expand Up @@ -317,6 +351,7 @@ def coul_energy(positions, box, pairs, q, mscales):
self.kappa = 0.0

self.coul_energy = coul_energy
>>>>>>> upstream/devel

def generate_get_energy(self):
@jit_condition()
Expand Down Expand Up @@ -401,5 +436,9 @@ def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
return energy, aux
else:
return energy
<<<<<<< HEAD

=======

>>>>>>> upstream/devel
return get_energy
15 changes: 15 additions & 0 deletions dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import simtk.openmm.app as app
import simtk.unit as unit
from typing import Dict, Tuple, List
<<<<<<< HEAD
from rdkit import Chem
from rdkit.Chem import AllChem

=======
try:
from rdkit import Chem
from rdkit.Chem import AllChem
Expand All @@ -21,12 +26,17 @@ def is_same_list(l1, l2):
if l1[nn] != l2[nn]:
return False
return True
>>>>>>> upstream/devel

def matchTemplate(graph, template):
if graph.number_of_nodes() != template.number_of_nodes():
# print("Node with different number of nodes.")
return False, {}, {}

<<<<<<< HEAD
def match_func(n1, n2):
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"]
=======
name_graph = sorted([i[1]['name'] for i in graph.nodes.data()])
name_template = sorted([i[1]['name'] for i in template.nodes.data()])

Expand All @@ -36,6 +46,7 @@ def match_func(n1, n2):
else:
def match_func(n1, n2):
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"]
>>>>>>> upstream/devel

def edge_match(e1, e2):
if len(e1) == 0 and len(e2) == 0:
Expand Down Expand Up @@ -118,7 +129,11 @@ def graph2top(graph: nx.Graph) -> app.Topology:
return top


<<<<<<< HEAD
def top2rdmol(top: app.Topology, indices: List[int]) -> Chem.rdchem.Mol:
=======
def top2rdmol(top: app.Topology, indices: List[int]):
>>>>>>> upstream/devel
rdmol = Chem.Mol()
emol = Chem.EditableMol(rdmol)
idx2ridx = {}
Expand Down
21 changes: 21 additions & 0 deletions dmff/api/paramset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ class ParamSet:
Converts all parameters to jax arrays.
"""

<<<<<<< HEAD
def __init__(self, data: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None, mask: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None):
=======
def __init__(
self,
data: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None,
mask: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]] = None,
):
>>>>>>> upstream/devel
"""
Initializes a new ParamSet object.

Expand Down Expand Up @@ -56,13 +60,17 @@ def addField(self, field: str) -> None:
self.parameters[field] = {}
self.mask[field] = {}

<<<<<<< HEAD
def addParameter(self, values: jnp.ndarray, name: str, field: str = None, mask: jnp.ndarray = None) -> None:
=======
def addParameter(
self,
values: jnp.ndarray,
name: str,
field: str = None,
mask: jnp.ndarray = None,
) -> None:
>>>>>>> upstream/devel
"""
Adds a new parameter to the parameters and mask dictionaries.

Expand Down Expand Up @@ -97,7 +105,12 @@ def to_jax(self) -> None:
for key1 in self.parameters:
if isinstance(self.parameters[key1], dict):
for key2 in self.parameters[key1]:
<<<<<<< HEAD
self.parameters[key1][key2] = jnp.array(
self.parameters[key1][key2])
=======
self.parameters[key1][key2] = jnp.array(self.parameters[key1][key2])
>>>>>>> upstream/devel
else:
self.parameters[key1] = jnp.array(self.parameters[key1])

Expand All @@ -117,12 +130,15 @@ def __getitem__(self, key: str) -> Union[Dict[str, jnp.ndarray], jnp.ndarray]:
"""
return self.parameters[key]

<<<<<<< HEAD
=======
def update_mask(self, gradients):
gradients = jax.tree_map(
lambda g, m: jnp.where(jnp.abs(m - 1.0) > 1e-5, g, 0.0), gradients, self.mask
)
return gradients

>>>>>>> upstream/devel

def flatten_paramset(prmset: ParamSet) -> tuple:
"""
Expand Down Expand Up @@ -160,4 +176,9 @@ def unflatten_paramset(aux_data: Dict, contents: tuple) -> ParamSet:
return ParamSet(data=contents[0], mask=aux_data)


<<<<<<< HEAD
jax.tree_util.register_pytree_node(
ParamSet, flatten_paramset, unflatten_paramset)
=======
jax.tree_util.register_pytree_node(ParamSet, flatten_paramset, unflatten_paramset)
>>>>>>> upstream/devel
Loading
Loading