Skip to content

Commit

Permalink
prevented gradient in IS calculations and returned numpy array from IS d3rlpy scorer
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Feb 25, 2024
1 parent 9996b30 commit fe3bc26
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/offline_rl_ope/api/d3rlpy/Scorers/IS.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __call__(
weights=self.cache[self.is_type].traj_is_weights,
is_msk=self.cache.weight_msk, discount=self.discount
)
return res
return res.numpy()



Expand Down
9 changes: 5 additions & 4 deletions src/offline_rl_ope/components/ImportanceSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ def get_traj_w(self, states:torch.Tensor, actions:torch.Tensor,
logger.debug("states.shape: {}".format(states.shape))
logger.debug("actions.shape: {}".format(actions.shape))
raise Exception("State and actions should have 2 dimensions")
behav_probs = self.__behav_policy(action=actions,
state=states)
#logger.debug("behav_probs: {}".format(behav_probs))
eval_probs = eval_policy(action=actions, state=states)
with torch.no_grad():
behav_probs = self.__behav_policy(action=actions,
state=states)
#logger.debug("behav_probs: {}".format(behav_probs))
eval_probs = eval_policy(action=actions, state=states)
#logger.debug("eval_probs: {}".format(eval_probs))
weight_array = eval_probs/behav_probs
weight_array = weight_array.view(-1)
Expand Down

0 comments on commit fe3bc26

Please sign in to comment.