diff --git a/ferminet/loss.py b/ferminet/loss.py index d72dfaa..a41b4cd 100644 --- a/ferminet/loss.py +++ b/ferminet/loss.py @@ -110,7 +110,7 @@ def clip_at_total_variation(values, center, scale): if clip_from_median: # More natural place to center the clipping, but expensive due to both # the median and all_gather (at least on multihost) - clip_center = jnp.median(constants.all_gather(local_values)) + clip_center = jnp.median(constants.all_gather(local_values).real) else: clip_center = mean_local_values # roughly, the total variation of the local energies