diff --git a/ferminet/psiformer.py b/ferminet/psiformer.py index ab33450..11dc797 100644 --- a/ferminet/psiformer.py +++ b/ferminet/psiformer.py @@ -445,7 +445,7 @@ def network_apply( jnp.reshape(orbital, (options.states, -1) + orbital.shape[1:]) for orbital in orbitals ] - return batch_logdet_matmul(*orbitals) + return batch_logdet_matmul(orbitals) else: return network_blocks.logdet_matmul(orbitals)