epsilon adjustment for deterministic policies

joshuaspear committed Oct 30, 2023
1 parent 2fdfa76 commit f695a18
Showing 2 changed files with 35 additions and 5 deletions.
10 changes: 8 additions & 2 deletions src/offline_rl_ope/components/Policy.py
@@ -67,8 +67,10 @@ def __call__(self, state: torch.Tensor, action: torch.Tensor):
 
 class D3RlPyDeterministic(Policy):
 
-    def __init__(self, policy_class: Callable, collect_res:bool=False,
-                 collect_act:bool=False, gpu:bool=True) -> None:
+    def __init__(
+        self, policy_class: Callable, collect_res:bool=False,
+        collect_act:bool=False, gpu:bool=True, eps:float=0
+        ) -> None:
         super().__init__(policy_class, collect_res=collect_res,
                          collect_act=collect_act)
         if gpu:
@@ -77,6 +79,7 @@ def __init__(self, policy_class: Callable, collect_res:bool=False,
         else:
             self.__preproc_tens = lambda x: x
             self.__postproc_tens = lambda x: x
+        self.__eps = eps
 
 
     def __call__(self, state: torch.Tensor, action: torch.Tensor)->torch.Tensor:
@@ -85,6 +88,9 @@ def __call__(self, state: torch.Tensor, action: torch.Tensor)->torch.Tensor:
         greedy_action = self.__postproc_tens(greedy_action)
         self.collect_act_func(greedy_action)
         res = (greedy_action == action).all(dim=1, keepdim=True).int()
+        res_eps_upper = res*(1-self.__eps)
+        res_eps_lower = (1-res)*(self.__eps)
+        res = res_eps_upper + res_eps_lower
         self.collect_res_fn(res)
         return res
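Taken on its own, the change turns the deterministic policy's hard 0/1 action-match indicator into 1-eps for a match and eps for a mismatch, so the evaluation policy never assigns an observed action exactly zero probability (which would, for example, zero out an importance-sampling weight). Below is a minimal sketch of that adjustment in isolation, using plain torch and made-up tensor values rather than the library's Policy wrapper:

    import torch

    eps = 0.001

    greedy_action = torch.tensor([[1.], [0.], [1.]])  # deterministic policy output
    action        = torch.tensor([[1.], [1.], [1.]])  # behaviour actions from the data

    # Hard 0/1 indicator: did the behaviour action match the greedy action?
    res = (greedy_action == action).all(dim=1, keepdim=True).int()

    # eps smoothing: matches get probability 1-eps, mismatches get eps
    res_eps_upper = res * (1 - eps)
    res_eps_lower = (1 - res) * eps
    res = res_eps_upper + res_eps_lower

    print(res)  # tensor([[0.9990], [0.0010], [0.9990]])

With the default eps=0 the adjustment is a no-op (res*(1-0) + (1-res)*0 == res), which is why the existing behaviour is preserved.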
30 changes: 27 additions & 3 deletions tests/components/test_Policy.py
@@ -12,6 +12,8 @@
 logger = logging.getLogger("offline_rl_ope")
 
 
+eps = 0.001
+
 class D3RlPyDeterministicTest(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -22,9 +24,11 @@ def __mock_return(x):
             }
             return lkp[str(x)]
         policy_class = MagicMock(side_effect=__mock_return)
-        self.policy = D3RlPyDeterministic(policy_class, gpu=False)
+        self.policy_0_eps = D3RlPyDeterministic(policy_class, gpu=False)
+        self.policy_001_eps = D3RlPyDeterministic(
+            policy_class, gpu=False, eps=eps)
 
-    def test___call__(self):
+    def test___call__0_eps(self):
         test_pred = []
         __test_action_vals = [np.array(i) for i in test_action_vals]
         __test_eval_action_vals = [np.array(i) for i in test_eval_action_vals]
@@ -35,7 +39,27 @@ def test___call__(self):
         for s,a in zip(test_state_vals, test_action_vals):
             s = torch.Tensor(s)
             a = torch.Tensor(a)
-            pred = self.policy(state=s, action=a)
+            pred = self.policy_0_eps(state=s, action=a)
             self.assertEqual(pred.shape, torch.Size((s.shape[0],1)))
             test_pred.append(pred.squeeze().numpy())
         test_pred = np.concatenate(test_pred)
         np.testing.assert_allclose(test_pred, test_res, atol=tollerance)
+
+    def test___call__0001_eps(self):
+        test_pred = []
+        __test_action_vals = [np.array(i) for i in test_action_vals]
+        __test_eval_action_vals = [np.array(i) for i in test_eval_action_vals]
+        test_res = [(x==y).astype(int)
+            for x,y in zip(__test_action_vals, __test_eval_action_vals)]
+        test_res = np.concatenate(test_res).squeeze()
+        test_res = np.where(
+            test_res == 1, 1-eps, 0+eps
+        )
+        tollerance = test_res.mean()/1000
+        for s,a in zip(test_state_vals, test_action_vals):
+            s = torch.Tensor(s)
+            a = torch.Tensor(a)
+            pred = self.policy_001_eps(state=s, action=a)
+            self.assertEqual(pred.shape, torch.Size((s.shape[0],1)))
+            test_pred.append(pred.squeeze().numpy())
+        test_pred = np.concatenate(test_pred)
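For intuition, the expected values the new test builds are just the 0/1 match array pushed through the same smoothing via np.where; a toy check with hypothetical match indicators:

    import numpy as np

    eps = 0.001

    test_res = np.array([1, 0, 1, 1])  # hypothetical 0/1 action-match indicators
    # Matches become 1-eps, mismatches become eps, mirroring the test's np.where
    expected = np.where(test_res == 1, 1 - eps, 0 + eps)

    print(expected)  # [0.999 0.001 0.999 0.999]

The test then scales its assertion tolerance to these expected values via test_res.mean()/1000.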
