diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index 81a234e..0000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "include": [ - "src" - ], - "exclude": [], - "ignore": [], - "defineConstant": { - "DEBUG": true - }, - "reportMissingImports": true, - "reportMissingTypeStubs": false, - - "pythonVersion": "3.10", - "pythonPlatform": "Linux", - - "executionEnvironments": [ - { - "root": "src/offline_rl_ope", - "pythonVersion": "3.10", - "pythonPlatform": "Linux", - }, - { - "root": "src" - } - ] - } \ No newline at end of file diff --git a/src/offline_rl_ope/PropensityModels/torch/Trainer.py b/src/offline_rl_ope/PropensityModels/torch/Trainer.py index 08ae62e..693c024 100644 --- a/src/offline_rl_ope/PropensityModels/torch/Trainer.py +++ b/src/offline_rl_ope/PropensityModels/torch/Trainer.py @@ -105,7 +105,7 @@ def predict_proba( x = self.input_setup(x) self.estimator.eval() res = self.estimator(x) - res_out = res["out"].cpu().detach() + res_out = res["out"] n_rows = res_out.shape[0] n_out = res_out.shape[2] dim_0_sub = np.arange(0,n_rows)[:,None] @@ -134,7 +134,7 @@ def predict( x = self.input_setup(x) self.estimator.eval() res = self.estimator(x) - res = res["loc"].cpu().detach().numpy() + res = res["loc"] return res def predict_proba( @@ -151,5 +151,5 @@ def predict_proba( self.estimator.eval() pred_res = self.estimator(x) d_f = self.dist_func(**pred_res) - res = torch.exp(d_f.log_prob(y)).cpu().detach().numpy() + res = torch.exp(d_f.log_prob(y)) return res \ No newline at end of file diff --git a/tests/PropensityModels/torch/test_Trainer.py b/tests/PropensityModels/torch/test_Trainer.py index 206ceca..41593ec 100644 --- a/tests/PropensityModels/torch/test_Trainer.py +++ b/tests/PropensityModels/torch/test_Trainer.py @@ -80,6 +80,7 @@ def test_predict1D(self) -> torch.Tensor: ) pred = trainer.predict(x=in_x) assert len(pred.shape) == 2 + assert isinstance(pred, torch.Tensor) nt.assert_array_equal( pred.numpy(), res_predict_true.numpy() ) @@ -92,6 +93,7 @@ def test_predict_proba1D(self) -> torch.Tensor: ) pred = trainer.predict_proba(x=in_x,y=in_y) assert len(pred.shape) == 2 + assert isinstance(pred, torch.Tensor) nt.assert_array_equal( pred.numpy(), res_predict_proba.numpy() )