From 4aa5b77f043006c6262260fd0ff5078c55f21a5a Mon Sep 17 00:00:00 2001 From: Wang Xinyan Date: Sun, 22 Oct 2023 20:57:02 +0800 Subject: [PATCH] Add jaxopt requirement in github workflow --- .github/workflows/ut.yml | 2 +- dmff/admp/qeq.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index 5b8de3ba1..6838f84b0 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -22,7 +22,7 @@ jobs: conda create -n dmff -y python=${{ matrix.python-version }} numpy openmm==7.7.0 pytest rdkit biopandas openbabel mdtraj ambertools -c conda-forge conda activate dmff pip install --upgrade pip - pip install jax jaxlib networkx parmed pymbar==4.0.1 chex==0.1.4 tqdm + pip install jax jaxlib jaxopt networkx parmed pymbar==4.0.1 chex==0.1.4 tqdm - name: Install DMFF run: | source $CONDA/bin/activate dmff && pip install . diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py index e13e8eecc..1bd0fcb49 100644 --- a/dmff/admp/qeq.py +++ b/dmff/admp/qeq.py @@ -29,13 +29,18 @@ def mask_index(idx, max_idx): return jnp.piecewise( idx, [idx < max_idx, idx >= max_idx], [lambda x: CONST_1, lambda x: CONST_0] ) + + mask_index = jax.vmap(mask_index, in_axes=(0, None)) + @jit_condition() def group_sum(val_list, indices): max_idx = val_list.shape[0] mask = mask_index(indices, max_idx) return jnp.sum(val_list[indices] * mask) + + group_sum = jax.vmap(group_sum, in_axes=(None, 0))