Skip to content

Commit

Permalink
bug fix when running on cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Feb 27, 2024
1 parent 3a1cbd4 commit 53164ba
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/d3rlpy_training_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(self, estimator:MultiOutputClassifier) -> None:
self.estimator = estimator

def __call__(self, y:torch.Tensor, x:torch.Tensor):
x = x.numpy()
y = y.numpy()
x = x.cpu().numpy()
y = y.cpu().numpy()
probs = self.estimator.predict_proba(X=x)
res = []
for i,out_prob in enumerate(probs):
Expand Down
4 changes: 2 additions & 2 deletions examples/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __init__(self, estimator:MultiOutputClassifier) -> None:
self.estimator = estimator

def __call__(self, y:torch.Tensor, x:torch.Tensor):
x = x.numpy()
y = y.numpy()
x = x.cpu().numpy()
y = y.cpu().numpy()
probs = self.estimator.predict_proba(X=x)
res = []
for i,out_prob in enumerate(probs):
Expand Down
2 changes: 1 addition & 1 deletion src/offline_rl_ope/OPEEstimators/DirectMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, model:QLearningAlgoBase) -> None:

def get_q(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor:
values = torch.tensor(self.model.predict_value(
x=state.numpy(), action=action.numpy()))
x=state.cpu().numpy(), action=action.cpu().numpy()))
return values

def get_v(self, state:torch.Tensor) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion src/offline_rl_ope/api/d3rlpy/Misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ def __init__(self, predict_func:D3rlpyAlgoPredictProtocal):
self.predict_func = predict_func

def __call__(self, x:torch.Tensor):
pred = self.predict_func(x.numpy())
pred = self.predict_func(x.cpu().numpy())
return torch.Tensor(pred)
2 changes: 1 addition & 1 deletion src/offline_rl_ope/api/d3rlpy/Scorers/IS.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __call__(
weights=self.cache[self.is_type].traj_is_weights,
is_msk=self.cache.weight_msk, discount=self.discount
)
return res.numpy()
return res.cpu().numpy()



Expand Down

0 comments on commit 53164ba

Please sign in to comment.