Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568475283
Change-Id: Ice9d8b610d2d8ab1c541679590e184ab91cbc05e
  • Loading branch information
Jake VanderPlas authored and jsspencer committed Nov 24, 2023
1 parent 1aa0a2a commit d592df1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ferminet/jastrows.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def _jastrow_ee(
for r in jnp.split(r_ee, nspins[0:1], axis=0)
]
r_ees_parallel = jnp.concatenate([
r_ees[0][0][jnp.triu_indices(nspins[0], k=1)],
r_ees[1][1][jnp.triu_indices(nspins[1], k=1)],
r_ees[0][0][jnp.triu_indices(nspins[0], k=1)], # pytype: disable=wrong-arg-types # jnp-type
r_ees[1][1][jnp.triu_indices(nspins[1], k=1)], # pytype: disable=wrong-arg-types # jnp-type
])

if r_ees_parallel.shape[0] > 0:
Expand Down

0 comments on commit d592df1

Please sign in to comment.