Fix linting issues
gui11aume committed Oct 4, 2023
1 parent f88a16e commit 930e32a
Showing 1 changed file with 11 additions and 7 deletions.
tests/infer/test_gradient.py
@@ -31,6 +31,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 def DiffTrace_ELBO(*args, **kwargs):
     return Trace_ELBO(*args, **kwargs).differentiable_loss
 
@@ -229,7 +230,7 @@ def guide(subsample):
     "reparameterized,has_rsample",
     [(True, None), (True, False), (True, True), (False, None)],
     ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
-)
+)
 @pytest.mark.parametrize(
     "Elbo,local_samples",
     [
@@ -242,7 +243,12 @@ def guide(subsample):
     ],
 )
 def test_mask_gradient(
-    Elbo, reparameterized, has_rsample, local_samples, mask, with_x_unobserved,
+    Elbo,
+    reparameterized,
+    has_rsample,
+    local_samples,
+    mask,
+    with_x_unobserved,
 ):
     pyro.clear_param_store()
     data = torch.tensor([-0.5, 2.0])
@@ -286,9 +292,7 @@ def guide(data, mask):
     for _ in range(accumulation):
         inference = SVI(model, guide, optim, loss=elbo)
         with xfail_if_not_implemented():
-            inference.loss_and_grads(
-                model, guide, data=data, mask=torch.tensor(mask)
-            )
+            inference.loss_and_grads(model, guide, data=data, mask=torch.tensor(mask))
     params = dict(pyro.get_param_store().named_parameters())
     actual_grads = {
         name: param.grad.detach().cpu().numpy() / accumulation
@@ -298,8 +302,8 @@ def guide(data, mask):
     # grad(loc) = (n+1) * loc - (x1 + ... + xn)
     # grad(scale) = (n+1) * scale - 1 / scale
     expected_grads = {
-        "loc": sum(mask) + 1. - data[mask].sum(0, keepdim=True).numpy(),
-        "scale": sum(mask) + 1 - np.ones(1)
+        "loc": sum(mask) + 1.0 - data[mask].sum(0, keepdim=True).numpy(),
+        "scale": sum(mask) + 1 - np.ones(1),
     }
     for name in sorted(params):
         logger.info("expected {} = {}".format(name, expected_grads[name]))
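For context on the last hunk: the hand-derived formulas grad(loc) = (n+1) * loc - (x1 + ... + xn) and grad(scale) = (n+1) * scale - 1 / scale match the closed-form negative ELBO of a conjugate model with prior z ~ Normal(0, 1), likelihood x_i ~ Normal(z, 1) for each masked-in point, and guide z ~ Normal(loc, scale) initialized at loc = scale = 1. The model and guide definitions lie outside this diff, so that structure is an assumption; below is a minimal sketch, not code from the commit, checking the formulas with torch autograd.

import torch

# Sketch only: model/guide structure inferred from the comments in the
# diff, not copied from test_gradient.py.
data = torch.tensor([-0.5, 2.0])
mask = torch.tensor([True, True])
n = int(mask.sum())

loc = torch.tensor(1.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)

# Closed-form negative ELBO (additive constants dropped) for
#   z ~ Normal(0, 1), x_i ~ Normal(z, 1), guide z ~ Normal(loc, scale):
#   -ELBO = (n + 1) * (loc**2 + scale**2) / 2 - loc * sum(x) - log(scale)
loss = (n + 1) * (loc**2 + scale**2) / 2 - loc * data[mask].sum() - scale.log()
loss.backward()

# grad(loc)   = (n + 1) * loc - (x1 + ... + xn) = 3 - 1.5 = 1.5
# grad(scale) = (n + 1) * scale - 1 / scale     = 3 - 1   = 2.0
print(loc.grad, scale.grad)

With data = [-0.5, 2.0] and a full mask, this gives 1.5 and 2.0, the same values the test's expected_grads produces via sum(mask) + 1.0 - data[mask].sum(...) and sum(mask) + 1 - np.ones(1).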
