From 5b8949b3d868d4d446d696f2c7254022349f45a4 Mon Sep 17 00:00:00 2001 From: mcbal Date: Tue, 12 Oct 2021 11:33:22 +0200 Subject: [PATCH] Simplify attention module for now --- afem/attention.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/afem/attention.py b/afem/attention.py index 1e3a920..1b02c27 100644 --- a/afem/attention.py +++ b/afem/attention.py @@ -52,10 +52,6 @@ def forward( x, t0=None, beta=None, - return_afe=False, - return_magnetizations=True, - return_internal_energy=False, - return_log_prob=False, use_analytical_grads=True, ): h = self.pre_norm(x) / np.sqrt(self.spin_model.dim) @@ -64,13 +60,8 @@ def forward( h, t0=t0 if exists(t0) else torch.ones_like(x[0, :, 0]), beta=beta, - return_afe=return_afe, - return_magnetizations=return_magnetizations, - return_internal_energy=return_internal_energy, - return_log_prob=return_log_prob, + return_magnetizations=True, use_analytical_grads=use_analytical_grads, ) - out.magnetizations = self.post_norm(out.magnetizations) if exists(out.magnetizations) else None - - return out + return self.post_norm(out.magnetizations)