diff --git a/README.md b/README.md index d0ceb6a..b05615c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # offline_rl_ope (BETA RELEASE) **WARNING** +- Per-decision weighted importance sampling was incorrectly implemented in versions < 5.X - Weighted importance sampling was incorrectly implemented in versions 1.X.X and 2.1.X, 2.2.X - Unit testing currently only running in Python 3.11. 3.10 will be supported in the future - Only 1 dimensional discrete action spaces are currently supported! @@ -88,6 +89,16 @@ If importance sampling based methods are evaluating to 0, consider visualising t The different kinds of importance samples can also be visualised by querying the ```traj_is_weights``` attribute of a given ```ImportanceSampler``` object. If for example, vanilla importance sampling is being used and the samples are not ```NaN``` or ```Inf``` then visualising the ```traj_is_weights``` may provide insight. In particular, IS weights will tend to inifinity when the evaluation policy places large density on an action in comparison to the behaviour policy. ### Release log +#### 5.0.0 +* Correctly implemented per-decision weighted importance sampling +* Expanded the different types of weights that can be implemented based on: + * http://proceedings.mlr.press/v48/jiang16.pdf: Per-decision weights are defined as the average weight at a given timepoint. This results in a different denominator for different timepoints. This is implemented with the following ```WISWeightNorm(avg_denom=True)``` + * https://scholarworks.umass.edu/cgi/viewcontent.cgi?article=1079&context=cs_faculty_pubs: Per-decision weights are defined as the sum of discounted weights across all timesteps. This is implemented with the following ```WISWeightNorm(discount=discount_value)``` + * Combinations of different weights can be easily implemented for example 'average discounted weights' ```WISWeightNorm(discount=discount_value, avg_denom=True)``` however, these do not necessaily have backing from literature. +* EffectiveSampleSize metric optinally returns nan if all weights are 0 +* Bug fixes: + * Fix bug when running on cuda where tensors were not being pushed to CPU + * Improved static typing #### 4.0.0 * Predefined propensity models including: * Generic feedforward MLP for continuous and discrete action spaces built in PyTorch diff --git a/src/offline_rl_ope/Metrics/EffectiveSampleSize.py b/src/offline_rl_ope/Metrics/EffectiveSampleSize.py index 79ed63b..34004e7 100644 --- a/src/offline_rl_ope/Metrics/EffectiveSampleSize.py +++ b/src/offline_rl_ope/Metrics/EffectiveSampleSize.py @@ -12,12 +12,13 @@ def __init__(self, nan_if_all_0:bool=True) -> None: def __ess(self, weights:torch.Tensor) -> float: # https://victorelvira.github.io/papers/kong92.pdf - weights = weights.sum(dim=1) - numer = len(weights) - w_var = torch.var(weights).item() - if (w_var == 0) and (self.__nan_if_all_0): + all_0 = (weights == 0).all().item() + if (all_0) and (self.__nan_if_all_0): res = np.nan else: + weights = weights.sum(dim=1) + numer = len(weights) + w_var = torch.var(weights).item() res = (numer/(1+w_var)) return res diff --git a/src/offline_rl_ope/OPEEstimators/IS.py b/src/offline_rl_ope/OPEEstimators/IS.py index 8b207c4..672ac9f 100644 --- a/src/offline_rl_ope/OPEEstimators/IS.py +++ b/src/offline_rl_ope/OPEEstimators/IS.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List from .utils import ( - WISNormWeights, NormWeightsPass, WeightNorm, + WISWeightNorm, VanillaNormWeights, WeightNorm, clip_weights_pass as cwp, clip_weights as cw ) @@ -23,9 +23,9 @@ def __init__( ) -> None: super().__init__(cache_traj_rewards) if norm_weights: - _norm_weights = WISNormWeights(**norm_kwargs) + _norm_weights = WISWeightNorm(**norm_kwargs) else: - _norm_weights = NormWeightsPass(**norm_kwargs) + _norm_weights = VanillaNormWeights(**norm_kwargs) self.norm_weights:WeightNorm = _norm_weights self.clip = clip if clip_weights: diff --git a/src/offline_rl_ope/OPEEstimators/utils.py b/src/offline_rl_ope/OPEEstimators/utils.py index 2ec5deb..0719f8c 100644 --- a/src/offline_rl_ope/OPEEstimators/utils.py +++ b/src/offline_rl_ope/OPEEstimators/utils.py @@ -9,25 +9,52 @@ class WeightNorm(metaclass=ABCMeta): def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor ) -> torch.Tensor: pass + +# is_msk.sum(axis=0, keepdim=True) is taken as the +# denominator since it is required to take the average over valid time t +# importance ratios. This may differ for different episodes. +# ref: http://proceedings.mlr.press/v48/jiang16.pdf + + +class WISWeightNorm(WeightNorm): -class WISNormWeights(WeightNorm): - - def __init__(self, smooth_eps:float=0.0, *args, **kwargs) -> None: + def __init__( + self, + smooth_eps:float=0.0, + avg_denom:bool=False, + discount:float=1, + *args, + **kwargs + ) -> None: self.smooth_eps = smooth_eps + self.avg_denom = avg_denom + self.discount = discount - def calc_norm(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor - ) -> torch.Tensor: - """Calculates the denominator for weighted importance sampling i.e. - w_{t} = 1/n sum_{i=1}^{n} p_{1:t}. Note, if traj_is_weights represent - vanilla IS samples then this will be w_{t} = 1/n sum_{i=1}^{n} p_{1:H} - for all samples. is_msk.sum(axis=0, keepdim=True) is taken as the - denominator since it is required to take the average over valid time t - importance ratios. This may differ for different episodes. - ref: http://proceedings.mlr.press/v48/jiang16.pdf + def calc_norm( + self, + traj_is_weights:torch.Tensor, + is_msk:torch.Tensor + ) -> torch.Tensor: + """Calculates the denominator for weighted importance sampling. smooth_eps prevents nan values occuring in instances where there exists valid time t importance ratios however, these are all 0. This should be set as small as possible. - + avg_denom: defines the denominator as the average weight for time t + as per http://proceedings.mlr.press/v48/jiang16.pdf + + Note: + - If traj_is_weights represents vanilla IS samples then: + - The denominator will be w_{t} = sum_{i=1}^{n} p_{1:H} for all + samples. + - If avg_denom is set to true, the denominator will be + w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:H} where n_{t} is the number of + trajectories of at least length, t. + - If traj_is_weights represents PD IS samples then: + - The denominator will be w_{t} = sum_{i=1}^{n} p_{1:t}. + - If avg_denom is set to true, the denominator will be + w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:t} where n_{t} is the number of + trajectories of at least length, t. This definition aligns with + http://proceedings.mlr.press/v48/jiang16.pdf Args: traj_is_weights (torch.Tensor): (# trajectories, max(traj_length)) Tensor. traj_is_weights[i,j] defines the jth timestep propensity @@ -40,11 +67,19 @@ def calc_norm(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor torch.Tensor: Tensor of dimension (# trajectories, 1) defining the normalisation value for each timestep """ - denom:torch.Tensor = traj_is_weights.sum(dim=0, keepdim=True) - denom = (denom+self.smooth_eps)/( - is_msk.sum(dim=0, keepdim=True)+self.smooth_eps) + discnt_tens = torch.full(traj_is_weights.shape, self.discount) + discnt_pows = torch.arange(0, traj_is_weights.shape[1])[None,:].repeat( + traj_is_weights.shape[0],1) + discnt_tens = torch.pow(discnt_tens,discnt_pows) + traj_is_weights = torch.mul(traj_is_weights,discnt_tens) + denom = ( + traj_is_weights.sum(dim=0, keepdim=True) + self.smooth_eps + ) + if self.avg_denom: + denom = denom/( + is_msk.sum(dim=0, keepdim=True)+self.smooth_eps) return denom - + def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor ) -> torch.Tensor: """Normalised propensity weights according to @@ -63,10 +98,12 @@ def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor with normalised weights """ denom = self.calc_norm(traj_is_weights=traj_is_weights, is_msk=is_msk) - res = traj_is_weights/(denom+self.smooth_eps) + res = traj_is_weights/denom return res + + -class NormWeightsPass(WeightNorm): +class VanillaNormWeights(WeightNorm): def __init__(self, *args, **kwargs) -> None: pass @@ -84,9 +121,11 @@ def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor ith trajectory was observed Returns: - torch.Tensor: Identical tensor to traj_is_weights + torch.Tensor: traj_is_weights with element wise average """ - return traj_is_weights + # The first dimension defines the number of trajectories and we require + # the average over trajectories + return traj_is_weights/traj_is_weights.shape[0] def clip_weights( traj_is_weights:torch.Tensor, diff --git a/tests/Metrics/test_EffectiveSampleSize.py b/tests/Metrics/test_EffectiveSampleSize.py index 6d6fe10..10a5878 100644 --- a/tests/Metrics/test_EffectiveSampleSize.py +++ b/tests/Metrics/test_EffectiveSampleSize.py @@ -23,6 +23,9 @@ def test_call(self): assert len(weights) == 2 denum = 1 + torch.var(weights) act_res = (num/denum).item() - metric = EffectiveSampleSize(is_obj=TestImportanceSampler()) - pred_res = metric() - self.assertEqual(act_res,pred_res) \ No newline at end of file + metric = EffectiveSampleSize(nan_if_all_0=True) + pred_res = metric( + weights=weight_test_res + ) + tol = act_res/1000 + np.testing.assert_allclose(pred_res, act_res, atol=tol) \ No newline at end of file diff --git a/tests/OPEEstimators/test_DoublyRobust.py b/tests/OPEEstimators/test_DoublyRobust.py index a1c0f72..5cfda2a 100644 --- a/tests/OPEEstimators/test_DoublyRobust.py +++ b/tests/OPEEstimators/test_DoublyRobust.py @@ -120,8 +120,11 @@ def v_side_effect(state:torch.Tensor): weights=weight_test_res, discount=gamma, is_msk=msk_test_res) + #weight_test_res = weight_test_res/weight_test_res.shape[0] + denom = weight_test_res.shape[0] for idx, (r,s,a,w,msk) in enumerate(zip(rewards, states, actions, weight_test_res, msk_test_res)): + w = w/denom p = torch.masked_select(w, msk>0) __test_res = is_est.get_traj_discnt_reward( reward_array=r, discount=gamma, state_array=s, action_array=a, @@ -129,9 +132,8 @@ def v_side_effect(state:torch.Tensor): test_res.append(__test_res.numpy()) #test_res = np.concatenate(test_res).mean() test_res = np.concatenate(test_res) - tol = (test_res.mean()/1000).item() + tol = (np.abs(test_res.mean()/100)).item() self.assertEqual(pred_res.shape, torch.Size((len(rewards),))) - np.testing.assert_allclose(pred_res.numpy(),test_res, atol=tol) - + np.testing.assert_allclose(pred_res.numpy(),test_res, atol=tol) \ No newline at end of file diff --git a/tests/OPEEstimators/test_IS.py b/tests/OPEEstimators/test_IS.py index 0af439a..f88c597 100644 --- a/tests/OPEEstimators/test_IS.py +++ b/tests/OPEEstimators/test_IS.py @@ -58,7 +58,10 @@ def __mock_return(rewards, discount, h): pred_res = self.is_estimator.predict_traj_rewards( rewards=rewards, actions=[], states=[], weights=weight_test_res, discount=gamma, is_msk=msk_test_res) - test_res = np.multiply(reward_test_res.numpy(), weight_test_res.numpy()) + test_res = np.multiply( + reward_test_res.numpy(), + weight_test_res.numpy()/weight_test_res.shape[0] + ) test_res=test_res.sum(axis=1) #test_res = test_res.sum(axis=1).mean() tol = test_res.mean()/1000 diff --git a/tests/OPEEstimators/test_utils.py b/tests/OPEEstimators/test_utils.py index 9dca121..2d00b9e 100644 --- a/tests/OPEEstimators/test_utils.py +++ b/tests/OPEEstimators/test_utils.py @@ -3,7 +3,7 @@ import torch import unittest from offline_rl_ope.OPEEstimators.utils import ( - clip_weights, clip_weights_pass, NormWeightsPass, WISNormWeights) + clip_weights, clip_weights_pass, VanillaNormWeights, WISWeightNorm) from ..base import (weight_test_res, msk_test_res) weight_test_res_alter = copy.deepcopy(weight_test_res) @@ -37,10 +37,11 @@ def test_clip_weights_pass(self): # np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), # atol=toll.numpy()) - def test_norm_weights_pass(self): - test_res = copy.deepcopy(weight_test_res) + def test_norm_weights_vanilla(self): + denom = weight_test_res.shape[0] + test_res = weight_test_res/denom toll = test_res.mean()/1000 - calculator = NormWeightsPass() + calculator = VanillaNormWeights() pred_res = calculator(traj_is_weights=weight_test_res, is_msk=msk_test_res) self.assertEqual(pred_res.shape,weight_test_res.shape) @@ -48,10 +49,10 @@ def test_norm_weights_pass(self): atol=toll.numpy()) def test_norm_weights_wis(self): - denom = weight_test_res.sum(dim=0)/msk_test_res.sum(dim=0) + denom = weight_test_res.sum(dim=0) test_res = weight_test_res/denom toll = test_res.mean()/1000 - calculator = WISNormWeights() + calculator = WISWeightNorm() pred_res = calculator(traj_is_weights=weight_test_res, is_msk=msk_test_res) self.assertEqual(pred_res.shape,weight_test_res.shape) @@ -60,11 +61,10 @@ def test_norm_weights_wis(self): def test_norm_weights_wis_smooth(self): smooth_eps = 0.00000001 - denom = (weight_test_res_alter.sum(dim=0)+smooth_eps)/( - msk_test_res.sum(dim=0)+smooth_eps) - test_res: torch.Tensor = weight_test_res_alter/(denom) + denom = weight_test_res_alter.sum(dim=0)+smooth_eps + test_res = weight_test_res_alter/denom toll = test_res.nanmean()/1000 - calculator = WISNormWeights(smooth_eps=smooth_eps) + calculator = WISWeightNorm(smooth_eps=smooth_eps) pred_res = calculator(traj_is_weights=weight_test_res_alter, is_msk=msk_test_res) self.assertEqual(pred_res.shape,weight_test_res_alter.shape) @@ -72,13 +72,170 @@ def test_norm_weights_wis_smooth(self): atol=toll.numpy()) def test_norm_weights_wis_no_smooth(self): - denom = weight_test_res_alter.sum(dim=0)/msk_test_res.sum(dim=0) - test_res: torch.Tensor = weight_test_res_alter/denom + denom = weight_test_res_alter.sum(dim=0) + test_res = weight_test_res_alter/denom toll = test_res.nanmean()/1000 - calculator = WISNormWeights() + calculator = WISWeightNorm() pred_res = calculator(traj_is_weights=weight_test_res_alter, is_msk=msk_test_res) self.assertEqual(pred_res.shape,weight_test_res_alter.shape) np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), atol=toll.numpy(), equal_nan=True) + def test_norm_weights_wis_smooth_discount(self): + smooth_eps = 0.00000001 + discount=0.99 + discnt_tens = torch.full( + weight_test_res_alter.shape, + discount + ) + discnt_pows = torch.arange( + 0, weight_test_res_alter.shape[1])[None,:].repeat( + weight_test_res_alter.shape[0],1 + ) + discnt_tens = torch.pow(discnt_tens,discnt_pows) + denom = torch.mul( + weight_test_res_alter, + discnt_tens + ) + denom = denom.sum(dim=0)+smooth_eps + test_res = weight_test_res_alter/denom + toll = test_res.nanmean()/1000 + calculator = WISWeightNorm( + smooth_eps=smooth_eps, + discount=discount + ) + pred_res = calculator(traj_is_weights=weight_test_res_alter, + is_msk=msk_test_res) + self.assertEqual(pred_res.shape,weight_test_res_alter.shape) + np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), + atol=toll.numpy()) + + def test_norm_weights_wis_no_smooth_discount(self): + discount=0.99 + discnt_tens = torch.full( + weight_test_res_alter.shape, + discount + ) + discnt_pows = torch.arange( + 0, weight_test_res_alter.shape[1])[None,:].repeat( + weight_test_res_alter.shape[0],1 + ) + discnt_tens = torch.pow(discnt_tens,discnt_pows) + denom = torch.mul( + weight_test_res_alter, + discnt_tens + ) + denom = denom.sum(dim=0) + test_res = weight_test_res_alter/denom + toll = test_res.nanmean()/1000 + calculator = WISWeightNorm( + discount=discount + ) + pred_res = calculator(traj_is_weights=weight_test_res_alter, + is_msk=msk_test_res) + self.assertEqual(pred_res.shape,weight_test_res_alter.shape) + np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), + atol=toll.numpy()) + + def test_norm_weights_wis_smooth_avg(self): + smooth_eps = 0.00000001 + time_t_freq = msk_test_res.sum(dim=0, keepdim=True).repeat( + msk_test_res.shape[0],1 + ) + denom = weight_test_res_alter/time_t_freq + denom = denom.sum(dim=0)+smooth_eps + test_res = weight_test_res_alter/denom + toll = test_res.nanmean()/1000 + calculator = WISWeightNorm( + smooth_eps=smooth_eps, + avg_denom=True + ) + pred_res = calculator(traj_is_weights=weight_test_res_alter, + is_msk=msk_test_res) + self.assertEqual(pred_res.shape,weight_test_res_alter.shape) + np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), + atol=toll.numpy()) + + def test_norm_weights_wis_no_smooth_avg(self): + time_t_freq = msk_test_res.sum(dim=0, keepdim=True).repeat( + msk_test_res.shape[0],1 + ) + denom = weight_test_res_alter/time_t_freq + denom = denom.sum(dim=0) + test_res = weight_test_res_alter/denom + toll = test_res.nanmean()/1000 + calculator = WISWeightNorm( + avg_denom=True + ) + pred_res = calculator(traj_is_weights=weight_test_res_alter, + is_msk=msk_test_res) + self.assertEqual(pred_res.shape,weight_test_res_alter.shape) + np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), + atol=toll.numpy()) + + def test_norm_weights_wis_smooth_discount_avg(self): + smooth_eps = 0.00000001 + discount=0.99 + discnt_tens = torch.full( + weight_test_res_alter.shape, + discount + ) + discnt_pows = torch.arange( + 0, weight_test_res_alter.shape[1])[None,:].repeat( + weight_test_res_alter.shape[0],1 + ) + discnt_tens = torch.pow(discnt_tens,discnt_pows) + denom = torch.mul( + weight_test_res_alter, + discnt_tens + ) + time_t_freq = msk_test_res.sum(dim=0, keepdim=True).repeat( + msk_test_res.shape[0],1 + ) + denom = denom/time_t_freq + denom = denom.sum(dim=0)+smooth_eps + test_res = weight_test_res_alter/denom + toll = test_res.nanmean()/1000 + calculator = WISWeightNorm( + smooth_eps=smooth_eps, + discount=discount, + avg_denom=True + ) + pred_res = calculator(traj_is_weights=weight_test_res_alter, + is_msk=msk_test_res) + self.assertEqual(pred_res.shape,weight_test_res_alter.shape) + np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), + atol=toll.numpy()) + + def test_norm_weights_wis_no_smooth_discount_avg(self): + discount=0.99 + discnt_tens = torch.full( + weight_test_res_alter.shape, + discount + ) + discnt_pows = torch.arange( + 0, weight_test_res_alter.shape[1])[None,:].repeat( + weight_test_res_alter.shape[0],1 + ) + discnt_tens = torch.pow(discnt_tens,discnt_pows) + denom = torch.mul( + weight_test_res_alter, + discnt_tens + ) + time_t_freq = msk_test_res.sum(dim=0, keepdim=True).repeat( + msk_test_res.shape[0],1 + ) + denom = denom/time_t_freq + denom = denom.sum(dim=0) + test_res = weight_test_res_alter/denom + toll = test_res.nanmean()/1000 + calculator = WISWeightNorm( + discount=0.99, + avg_denom=True + ) + pred_res = calculator(traj_is_weights=weight_test_res_alter, + is_msk=msk_test_res) + self.assertEqual(pred_res.shape,weight_test_res_alter.shape) + np.testing.assert_allclose(pred_res.numpy(), test_res.numpy(), + atol=toll.numpy())