From fe3bc267156fdf5de42f074c5f316537f5527e08 Mon Sep 17 00:00:00 2001
From: Joshua Spear
Date: Sun, 25 Feb 2024 14:48:33 +0000
Subject: [PATCH] prevented gradient in IS calculations and returned numpy
 array from IS d3rlpy scorer

---
 src/offline_rl_ope/api/d3rlpy/Scorers/IS.py        | 2 +-
 src/offline_rl_ope/components/ImportanceSampler.py | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py b/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py
index fed6b13..48a0e69 100644
--- a/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py
+++ b/src/offline_rl_ope/api/d3rlpy/Scorers/IS.py
@@ -48,7 +48,7 @@ def __call__(
             weights=self.cache[self.is_type].traj_is_weights,
             is_msk=self.cache.weight_msk, discount=self.discount
             )
-        return res
+        return res.numpy()
 
 
 
diff --git a/src/offline_rl_ope/components/ImportanceSampler.py b/src/offline_rl_ope/components/ImportanceSampler.py
index 412d467..6264cc7 100644
--- a/src/offline_rl_ope/components/ImportanceSampler.py
+++ b/src/offline_rl_ope/components/ImportanceSampler.py
@@ -34,10 +34,11 @@ def get_traj_w(self, states:torch.Tensor, actions:torch.Tensor,
             logger.debug("states.shape: {}".format(states.shape))
             logger.debug("actions.shape: {}".format(actions.shape))
             raise Exception("State and actions should have 2 dimensions")
-        behav_probs = self.__behav_policy(action=actions,
-                                          state=states)
-        #logger.debug("behav_probs: {}".format(behav_probs))
-        eval_probs = eval_policy(action=actions, state=states)
+        with torch.no_grad():
+            behav_probs = self.__behav_policy(action=actions,
+                                              state=states)
+            #logger.debug("behav_probs: {}".format(behav_probs))
+            eval_probs = eval_policy(action=actions, state=states)
         #logger.debug("eval_probs: {}".format(eval_probs))
         weight_array = eval_probs/behav_probs
         weight_array = weight_array.view(-1)
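
For context, a minimal standalone sketch of why the two changes plausibly go together: PyTorch refuses to call .numpy() on a tensor that is still attached to the autograd graph, so computing the importance weights under torch.no_grad() lets the d3rlpy scorer return res.numpy() without an explicit detach(). The behav_policy and eval_policy callables below are hypothetical stand-ins for the policy objects used in ImportanceSampler.get_traj_w; they only mimic the action/state keyword interface seen in the diff.

    import torch

    # Hypothetical evaluation policy backed by a differentiable model, so its
    # outputs require grad by default.
    model = torch.nn.Linear(3, 1)

    def eval_policy(action: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(model(state))

    # Hypothetical behaviour policy returning a constant per-step probability.
    def behav_policy(action: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        return torch.full((state.shape[0], 1), 0.25)

    states = torch.randn(4, 3)
    actions = torch.randint(0, 2, (4, 1)).float()

    # Without no_grad the weights stay attached to the autograd graph, so
    # weights.numpy() would raise "Can't call numpy() on Tensor that requires grad".
    weights = eval_policy(action=actions, state=states) / \
        behav_policy(action=actions, state=states)
    print(weights.requires_grad)   # True

    # With no_grad, as in the patch, the weights carry no graph and can be
    # converted straight to a NumPy array.
    with torch.no_grad():
        weights = eval_policy(action=actions, state=states) / \
            behav_policy(action=actions, state=states)
    print(weights.requires_grad)   # False
    print(weights.view(-1).numpy())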