From f695a18bb7afccdfd702b7ec4b2845b61fcc9318 Mon Sep 17 00:00:00 2001
From: Joshua Spear
Date: Mon, 30 Oct 2023 12:45:55 +0000
Subject: [PATCH] epsilon adjustment for deterministic policies

---
 src/offline_rl_ope/components/Policy.py | 10 +++++++--
 tests/components/test_Policy.py         | 30 ++++++++++++++++++++++---
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/src/offline_rl_ope/components/Policy.py b/src/offline_rl_ope/components/Policy.py
index 065c086..366b46f 100644
--- a/src/offline_rl_ope/components/Policy.py
+++ b/src/offline_rl_ope/components/Policy.py
@@ -67,8 +67,10 @@ def __call__(self, state: torch.Tensor, action: torch.Tensor):
 
 class D3RlPyDeterministic(Policy):
 
-    def __init__(self, policy_class: Callable, collect_res:bool=False,
-                 collect_act:bool=False, gpu:bool=True) -> None:
+    def __init__(
+        self, policy_class: Callable, collect_res:bool=False,
+        collect_act:bool=False, gpu:bool=True, eps:float=0
+        ) -> None:
         super().__init__(policy_class, collect_res=collect_res,
                          collect_act=collect_act)
         if gpu:
@@ -77,6 +79,7 @@ def __init__(self, policy_class: Callable, collect_res:bool=False,
         else:
             self.__preproc_tens = lambda x: x
             self.__postproc_tens = lambda x: x
+        self.__eps = eps
 
     def __call__(self, state: torch.Tensor,
                  action: torch.Tensor)->torch.Tensor:
@@ -85,6 +88,9 @@ def __call__(self, state: torch.Tensor,
         greedy_action = self.__postproc_tens(greedy_action)
         self.collect_act_func(greedy_action)
         res = (greedy_action == action).all(dim=1, keepdim=True).int()
+        res_eps_upper = res*(1-self.__eps)
+        res_eps_lower = (1-res)*(self.__eps)
+        res = res_eps_upper + res_eps_lower
         self.collect_res_fn(res)
         return res
 
diff --git a/tests/components/test_Policy.py b/tests/components/test_Policy.py
index 2ab19f6..81c01a7 100644
--- a/tests/components/test_Policy.py
+++ b/tests/components/test_Policy.py
@@ -12,6 +12,8 @@
 
 logger = logging.getLogger("offline_rl_ope")
 
+eps = 0.001
+
 class D3RlPyDeterministicTest(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -22,9 +24,11 @@ def __mock_return(x):
             }
             return lkp[str(x)]
         policy_class = MagicMock(side_effect=__mock_return)
-        self.policy = D3RlPyDeterministic(policy_class, gpu=False)
+        self.policy_0_eps = D3RlPyDeterministic(policy_class, gpu=False)
+        self.policy_001_eps = D3RlPyDeterministic(
+            policy_class, gpu=False, eps=eps)
 
-    def test___call__(self):
+    def test___call__0_eps(self):
         test_pred = []
         __test_action_vals = [np.array(i) for i in test_action_vals]
         __test_eval_action_vals = [np.array(i) for i in test_eval_action_vals]
@@ -35,7 +39,27 @@ def test___call__(self):
         for s,a in zip(test_state_vals, test_action_vals):
             s = torch.Tensor(s)
             a = torch.Tensor(a)
-            pred = self.policy(state=s, action=a)
+            pred = self.policy_0_eps(state=s, action=a)
             self.assertEqual(pred.shape, torch.Size((s.shape[0],1)))
             test_pred.append(pred.squeeze().numpy())
         test_pred = np.concatenate(test_pred)
+        np.testing.assert_allclose(test_pred, test_res, atol=tollerance)
+
+    def test___call__0001_eps(self):
+        test_pred = []
+        __test_action_vals = [np.array(i) for i in test_action_vals]
+        __test_eval_action_vals = [np.array(i) for i in test_eval_action_vals]
+        test_res = [(x==y).astype(int)
+                    for x,y in zip(__test_action_vals, __test_eval_action_vals)]
+        test_res = np.concatenate(test_res).squeeze()
+        test_res = np.where(
+            test_res == 1, 1-eps, 0+eps
+        )
+        tollerance = test_res.mean()/1000
+        for s,a in zip(test_state_vals, test_action_vals):
+            s = torch.Tensor(s)
+            a = torch.Tensor(a)
+            pred = self.policy_001_eps(state=s, action=a)
+            self.assertEqual(pred.shape, torch.Size((s.shape[0],1)))
+            test_pred.append(pred.squeeze().numpy())
+        test_pred = np.concatenate(test_pred)
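
As context for the change above: a deterministic evaluation policy assigns probability 1 to its greedy action and 0 to every other action, so any logged action that disagrees with the greedy action zeroes the importance-sampling weight for the whole trajectory. The eps adjustment smooths the agreement indicator to 1-eps where the actions match and eps where they do not. Below is a minimal standalone sketch of that adjustment, assuming illustrative tensor values; the names greedy_action and logged_action are hypothetical and not part of the offline_rl_ope API.

import torch

eps = 0.001

# Hypothetical example tensors (not taken from the library's tests)
greedy_action = torch.tensor([[1.0], [0.0], [1.0]])  # evaluation policy's greedy actions
logged_action = torch.tensor([[1.0], [1.0], [1.0]])  # actions observed in the logged data

# Agreement indicator per transition, shape (n, 1), as in D3RlPyDeterministic.__call__
res = (greedy_action == logged_action).all(dim=1, keepdim=True).int()

# eps adjustment: 1 -> 1-eps where the actions agree, 0 -> eps where they differ
res = res * (1 - eps) + (1 - res) * eps
print(res)  # values 0.999, 0.001, 0.999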