Skip to content

Commit

Permalink
implemented fix
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Feb 25, 2024
1 parent 17a0f1f commit 9996b30
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 29 deletions.
26 changes: 0 additions & 26 deletions pyrightconfig.json

This file was deleted.

6 changes: 3 additions & 3 deletions src/offline_rl_ope/PropensityModels/torch/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def predict_proba(
x = self.input_setup(x)
self.estimator.eval()
res = self.estimator(x)
res_out = res["out"].cpu().detach()
res_out = res["out"]
n_rows = res_out.shape[0]
n_out = res_out.shape[2]
dim_0_sub = np.arange(0,n_rows)[:,None]
Expand Down Expand Up @@ -134,7 +134,7 @@ def predict(
x = self.input_setup(x)
self.estimator.eval()
res = self.estimator(x)
res = res["loc"].cpu().detach().numpy()
res = res["loc"]
return res

def predict_proba(
Expand All @@ -151,5 +151,5 @@ def predict_proba(
self.estimator.eval()
pred_res = self.estimator(x)
d_f = self.dist_func(**pred_res)
res = torch.exp(d_f.log_prob(y)).cpu().detach().numpy()
res = torch.exp(d_f.log_prob(y))
return res
2 changes: 2 additions & 0 deletions tests/PropensityModels/torch/test_Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_predict1D(self) -> torch.Tensor:
)
pred = trainer.predict(x=in_x)
assert len(pred.shape) == 2
assert isinstance(pred, torch.Tensor)
nt.assert_array_equal(
pred.numpy(), res_predict_true.numpy()
)
Expand All @@ -92,6 +93,7 @@ def test_predict_proba1D(self) -> torch.Tensor:
)
pred = trainer.predict_proba(x=in_x,y=in_y)
assert len(pred.shape) == 2
assert isinstance(pred, torch.Tensor)
nt.assert_array_equal(
pred.numpy(), res_predict_proba.numpy()
)
Expand Down

0 comments on commit 9996b30

Please sign in to comment.