diff --git a/examples/d3rlpy_training_api.py b/examples/d3rlpy_training_api.py index 5d02607..deff62e 100644 --- a/examples/d3rlpy_training_api.py +++ b/examples/d3rlpy_training_api.py @@ -45,8 +45,8 @@ def __init__(self, estimator:MultiOutputClassifier) -> None: self.estimator = estimator def __call__(self, y:torch.Tensor, x:torch.Tensor): - x = x.numpy() - y = y.numpy() + x = x.cpu().numpy() + y = y.cpu().numpy() probs = self.estimator.predict_proba(X=x) res = [] for i,out_prob in enumerate(probs): diff --git a/examples/static.py b/examples/static.py index b836721..42f40f1 100644 --- a/examples/static.py +++ b/examples/static.py @@ -40,8 +40,8 @@ def __init__(self, estimator:MultiOutputClassifier) -> None: self.estimator = estimator def __call__(self, y:torch.Tensor, x:torch.Tensor): - x = x.numpy() - y = y.numpy() + x = x.cpu().numpy() + y = y.cpu().numpy() probs = self.estimator.predict_proba(X=x) res = [] for i,out_prob in enumerate(probs): diff --git a/src/offline_rl_ope/OPEEstimators/DirectMethod.py b/src/offline_rl_ope/OPEEstimators/DirectMethod.py index 98d193c..9495ae2 100644 --- a/src/offline_rl_ope/OPEEstimators/DirectMethod.py +++ b/src/offline_rl_ope/OPEEstimators/DirectMethod.py @@ -24,7 +24,7 @@ def __init__(self, model:QLearningAlgoBase) -> None: def get_q(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor: values = torch.tensor(self.model.predict_value( - x=state.numpy(), action=action.numpy())) + x=state.cpu().numpy(), action=action.cpu().numpy())) return values def get_v(self, state:torch.Tensor) -> torch.Tensor: diff --git a/src/offline_rl_ope/api/d3rlpy/Misc.py b/src/offline_rl_ope/api/d3rlpy/Misc.py index 214db44..a48d616 100644 --- a/src/offline_rl_ope/api/d3rlpy/Misc.py +++ b/src/offline_rl_ope/api/d3rlpy/Misc.py @@ -9,5 +9,5 @@ def __init__(self, predict_func:D3rlpyAlgoPredictProtocal): self.predict_func = predict_func def __call__(self, x:torch.Tensor): - pred = self.predict_func(x.numpy()) + pred = self.predict_func(x.cpu().numpy()) return torch.Tensor(pred) diff --git a/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py b/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py index 48a0e69..656cb24 100644 --- a/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py +++ b/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py @@ -48,7 +48,7 @@ def __call__( weights=self.cache[self.is_type].traj_is_weights, is_msk=self.cache.weight_msk, discount=self.discount ) - return res.numpy() + return res.cpu().numpy()