Skip to content

Commit

Permalink
Apparently the median is not defined for complex numbers in JAX
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 518310313
Change-Id: If7bc78b35bbcb20c260ac0c91f8e64ba7f5b8e1a
  • Loading branch information
dpfau authored and jsspencer committed Apr 25, 2023
1 parent 962c19b commit 314a657
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ferminet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def clip_at_total_variation(values, center, scale):
if clip_from_median:
# More natural place to center the clipping, but expensive due to both
# the median and all_gather (at least on multihost)
clip_center = jnp.median(constants.all_gather(local_values))
clip_center = jnp.median(constants.all_gather(local_values).real)
else:
clip_center = mean_local_values
# roughly, the total variation of the local energies
Expand Down

0 comments on commit 314a657

Please sign in to comment.