Commit

fixed per decision weighted IS. Updated testing. Altered effective sample size to return nan if all weights are 0
joshuaspear committed Mar 1, 2024
1 parent d3ed4d9 commit bf8d506
Showing 8 changed files with 264 additions and 48 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -1,6 +1,7 @@
# offline_rl_ope (BETA RELEASE)

**WARNING**
+- Per-decision weighted importance sampling was incorrectly implemented in versions < 5.X
- Weighted importance sampling was incorrectly implemented in versions 1.X.X and 2.1.X, 2.2.X
- Unit testing currently only running in Python 3.11. 3.10 will be supported in the future
- Only 1 dimensional discrete action spaces are currently supported!
@@ -88,6 +89,16 @@ If importance sampling based methods are evaluating to 0, consider visualising t
The different kinds of importance samples can also be visualised by querying the ```traj_is_weights``` attribute of a given ```ImportanceSampler``` object. If, for example, vanilla importance sampling is being used and the samples are not ```NaN``` or ```Inf```, then visualising the ```traj_is_weights``` may provide insight. In particular, IS weights will tend to infinity when the evaluation policy places large density on an action relative to the behaviour policy.
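As a toy illustration of that failure mode (not code from the package): the per-decision weight is a cumulative product of policy ratios, so even a moderately large per-step ratio compounds exponentially with trajectory length.

```python
import torch

# Illustration only (not offline_rl_ope code): per-decision IS weights are
# cumulative products of pi_e(a|s)/pi_b(a|s) ratios, so a per-step ratio of
# 9 grows to 9**t after t steps.
pi_e = torch.tensor([0.9, 0.9, 0.9])  # evaluation policy action densities
pi_b = torch.tensor([0.1, 0.1, 0.1])  # behaviour policy action densities
ratios = pi_e / pi_b                  # tensor([9., 9., 9.])
print(torch.cumprod(ratios, dim=0))   # tensor([  9.,  81., 729.])
```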

### Release log
+#### 5.0.0
+* Correctly implemented per-decision weighted importance sampling
+* Expanded the different types of weights that can be implemented, based on:
+    * http://proceedings.mlr.press/v48/jiang16.pdf: Per-decision weights are defined as the average weight at a given timepoint. This results in a different denominator for different timepoints. This is implemented with ```WISWeightNorm(avg_denom=True)```
+    * https://scholarworks.umass.edu/cgi/viewcontent.cgi?article=1079&context=cs_faculty_pubs: Per-decision weights are defined as the sum of discounted weights across all timesteps. This is implemented with ```WISWeightNorm(discount=discount_value)```
+    * Combinations of different weights can easily be implemented, for example 'average discounted weights' via ```WISWeightNorm(discount=discount_value, avg_denom=True)```; however, these do not necessarily have backing from the literature (see the sketch after this list)
+* EffectiveSampleSize metric optionally returns nan if all weights are 0
+* Bug fixes:
+    * Fixed a bug where, when running on CUDA, tensors were not being pushed to the CPU
+* Improved static typing
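A hypothetical usage sketch of the normaliser options described above; the argument names (`smooth_eps`, `avg_denom`, `discount`) come from the `utils.py` diff in this commit, while the surrounding values are illustrative only:

```python
from offline_rl_ope.OPEEstimators.utils import WISWeightNorm

discount_value = 0.99
sum_wis = WISWeightNorm()                          # denominator = sum of weights at each timepoint
avg_wis = WISWeightNorm(avg_denom=True)            # average weight per timepoint (first reference)
disc_wis = WISWeightNorm(discount=discount_value)  # discounted weight sum (second reference)
avg_disc_wis = WISWeightNorm(discount=discount_value, avg_denom=True)  # combination, no literature backing
```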
#### 4.0.0
* Predefined propensity models including:
* Generic feedforward MLP for continuous and discrete action spaces built in PyTorch
9 changes: 5 additions & 4 deletions src/offline_rl_ope/Metrics/EffectiveSampleSize.py
@@ -12,12 +12,13 @@ def __init__(self, nan_if_all_0:bool=True) -> None:

    def __ess(self, weights:torch.Tensor) -> float:
        # https://victorelvira.github.io/papers/kong92.pdf
-        weights = weights.sum(dim=1)
-        numer = len(weights)
-        w_var = torch.var(weights).item()
-        if (w_var == 0) and (self.__nan_if_all_0):
+        all_0 = (weights == 0).all().item()
+        if (all_0) and (self.__nan_if_all_0):
            res = np.nan
        else:
+            weights = weights.sum(dim=1)
+            numer = len(weights)
+            w_var = torch.var(weights).item()
            res = (numer/(1+w_var))
        return res
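Read on its own, the updated method implements the approximation ESS ≈ n / (1 + Var(w)) from the Kong (1992) reference cited in the code, short-circuiting to NaN when every weight is zero. A standalone restatement (illustrative, not the package API):

```python
import torch

# Standalone sketch of the logic above: ESS ~= n / (1 + Var(w)), with NaN
# when all weights are zero (mirrors nan_if_all_0=True).
def ess(weights: torch.Tensor, nan_if_all_0: bool = True) -> float:
    if nan_if_all_0 and bool((weights == 0).all()):
        return float("nan")
    per_traj = weights.sum(dim=1)  # one cumulative weight per trajectory
    return len(per_traj) / (1 + torch.var(per_traj).item())

print(ess(torch.zeros(3, 2)))                       # nan
print(ess(torch.tensor([[0.5, 0.5], [1.0, 1.0]])))  # 2 / (1 + 0.5) ~= 1.33
```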

6 changes: 3 additions & 3 deletions src/offline_rl_ope/OPEEstimators/IS.py
@@ -3,7 +3,7 @@
from typing import Any, Dict, List

from .utils import (
-    WISNormWeights, NormWeightsPass, WeightNorm,
+    WISWeightNorm, VanillaNormWeights, WeightNorm,
    clip_weights_pass as cwp,
    clip_weights as cw
)
@@ -23,9 +23,9 @@ ) -> None:
    ) -> None:
        super().__init__(cache_traj_rewards)
        if norm_weights:
-            _norm_weights = WISNormWeights(**norm_kwargs)
+            _norm_weights = WISWeightNorm(**norm_kwargs)
        else:
-            _norm_weights = NormWeightsPass(**norm_kwargs)
+            _norm_weights = VanillaNormWeights(**norm_kwargs)
        self.norm_weights:WeightNorm = _norm_weights
        self.clip = clip
        if clip_weights:
81 changes: 60 additions & 21 deletions src/offline_rl_ope/OPEEstimators/utils.py
@@ -9,25 +9,52 @@ class WeightNorm(metaclass=ABCMeta):
    def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
                 ) -> torch.Tensor:
        pass

+# is_msk.sum(axis=0, keepdim=True) is taken as the
+# denominator since it is required to take the average over valid time t
+# importance ratios. This may differ for different episodes.
+# ref: http://proceedings.mlr.press/v48/jiang16.pdf


+class WISWeightNorm(WeightNorm):
-class WISNormWeights(WeightNorm):

-    def __init__(self, smooth_eps:float=0.0, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        smooth_eps:float=0.0,
+        avg_denom:bool=False,
+        discount:float=1,
+        *args,
+        **kwargs
+    ) -> None:
        self.smooth_eps = smooth_eps
+        self.avg_denom = avg_denom
+        self.discount = discount

-    def calc_norm(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
-                  ) -> torch.Tensor:
-        """Calculates the denominator for weighted importance sampling i.e.
-        w_{t} = 1/n sum_{i=1}^{n} p_{1:t}. Note, if traj_is_weights represent
-        vanilla IS samples then this will be w_{t} = 1/n sum_{i=1}^{n} p_{1:H}
-        for all samples. is_msk.sum(axis=0, keepdim=True) is taken as the
-        denominator since it is required to take the average over valid time t
-        importance ratios. This may differ for different episodes.
-        ref: http://proceedings.mlr.press/v48/jiang16.pdf
+    def calc_norm(
+        self,
+        traj_is_weights:torch.Tensor,
+        is_msk:torch.Tensor
+    ) -> torch.Tensor:
+        """Calculates the denominator for weighted importance sampling.
+        smooth_eps prevents nan values occurring in instances where valid
+        time t importance ratios exist but are all 0. This should
+        be set as small as possible.
+        avg_denom: defines the denominator as the average weight for time t
+        as per http://proceedings.mlr.press/v48/jiang16.pdf
+        Note:
+        - If traj_is_weights represents vanilla IS samples then:
+            - The denominator will be w_{t} = sum_{i=1}^{n} p_{1:H} for all
+              samples.
+            - If avg_denom is set to true, the denominator will be
+              w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:H} where n_{t} is the number
+              of trajectories of at least length t.
+        - If traj_is_weights represents PD IS samples then:
+            - The denominator will be w_{t} = sum_{i=1}^{n} p_{1:t}.
+            - If avg_denom is set to true, the denominator will be
+              w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:t} where n_{t} is the number
+              of trajectories of at least length t. This definition aligns
+              with http://proceedings.mlr.press/v48/jiang16.pdf
        Args:
            traj_is_weights (torch.Tensor): (# trajectories, max(traj_length))
                Tensor. traj_is_weights[i,j] defines the jth timestep propensity
@@ -40,11 +67,19 @@ def calc_norm(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
            torch.Tensor: Tensor of dimension (# trajectories, 1) defining the
                normalisation value for each timestep
        """
-        denom:torch.Tensor = traj_is_weights.sum(dim=0, keepdim=True)
-        denom = (denom+self.smooth_eps)/(
-            is_msk.sum(dim=0, keepdim=True)+self.smooth_eps)
+        discnt_tens = torch.full(traj_is_weights.shape, self.discount)
+        discnt_pows = torch.arange(0, traj_is_weights.shape[1])[None,:].repeat(
+            traj_is_weights.shape[0],1)
+        discnt_tens = torch.pow(discnt_tens,discnt_pows)
+        traj_is_weights = torch.mul(traj_is_weights,discnt_tens)
+        denom = (
+            traj_is_weights.sum(dim=0, keepdim=True) + self.smooth_eps
+        )
+        if self.avg_denom:
+            denom = denom/(
+                is_msk.sum(dim=0, keepdim=True)+self.smooth_eps)
        return denom

    def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
                 ) -> torch.Tensor:
        """Normalised propensity weights according to
@@ -63,10 +98,12 @@ def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
                with normalised weights
        """
        denom = self.calc_norm(traj_is_weights=traj_is_weights, is_msk=is_msk)
-        res = traj_is_weights/(denom+self.smooth_eps)
+        res = traj_is_weights/denom
        return res
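A worked toy example (values invented) of what `calc_norm` now produces for per-decision weights, with `discount=0.5` and no smoothing:

```python
import torch

# Two trajectories; the second terminates after one timestep.
traj_is_weights = torch.tensor([[2.0, 4.0],
                                [6.0, 0.0]])
is_msk = torch.tensor([[1.0, 1.0],
                       [1.0, 0.0]])
discount = 0.5

# Column t is discounted by discount**t, as in the updated calc_norm.
discnt = torch.pow(discount, torch.arange(traj_is_weights.shape[1]))
weighted = traj_is_weights * discnt            # [[2., 2.], [6., 0.]]

denom = weighted.sum(dim=0, keepdim=True)      # [[8., 2.]] (smooth_eps=0)
avg = denom / is_msk.sum(dim=0, keepdim=True)  # [[4., 2.]] if avg_denom=True
```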



-class NormWeightsPass(WeightNorm):
+class VanillaNormWeights(WeightNorm):

    def __init__(self, *args, **kwargs) -> None:
        pass
@@ -84,9 +121,11 @@ def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
                ith trajectory was observed
        Returns:
-            torch.Tensor: Identical tensor to traj_is_weights
+            torch.Tensor: traj_is_weights divided element-wise by the
+                number of trajectories
        """
-        return traj_is_weights
+        # The first dimension defines the number of trajectories and we require
+        # the average over trajectories
+        return traj_is_weights/traj_is_weights.shape[0]
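The division by the first dimension makes sense in light of the estimator pipeline: downstream code sums the weighted per-trajectory returns, so pre-scaling each weight by 1/n turns that sum into the ordinary IS mean (this is also what the updated tests below check). A sketch with invented numbers:

```python
import torch

# Invented numbers: IS weights 1 and 3, trajectory returns 10 and 20.
# Pre-dividing weights by n makes a plain sum equal the IS mean:
# (1*10 + 3*20) / 2 = 35.
weights = torch.tensor([1.0, 3.0])
returns = torch.tensor([10.0, 20.0])
n = weights.shape[0]
print(((weights / n) * returns).sum())  # tensor(35.)
```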

def clip_weights(
    traj_is_weights:torch.Tensor,
9 changes: 6 additions & 3 deletions tests/Metrics/test_EffectiveSampleSize.py
@@ -23,6 +23,9 @@ def test_call(self):
        assert len(weights) == 2
        denum = 1 + torch.var(weights)
        act_res = (num/denum).item()
-        metric = EffectiveSampleSize(is_obj=TestImportanceSampler())
-        pred_res = metric()
-        self.assertEqual(act_res,pred_res)
+        metric = EffectiveSampleSize(nan_if_all_0=True)
+        pred_res = metric(
+            weights=weight_test_res
+        )
+        tol = act_res/1000
+        np.testing.assert_allclose(pred_res, act_res, atol=tol)
8 changes: 5 additions & 3 deletions tests/OPEEstimators/test_DoublyRobust.py
@@ -120,18 +120,20 @@ def v_side_effect(state:torch.Tensor):
            weights=weight_test_res,
            discount=gamma,
            is_msk=msk_test_res)
+        #weight_test_res = weight_test_res/weight_test_res.shape[0]
+        denom = weight_test_res.shape[0]
        for idx, (r,s,a,w,msk) in enumerate(zip(rewards, states, actions,
                                                weight_test_res, msk_test_res)):
+            w = w/denom
            p = torch.masked_select(w, msk>0)
            __test_res = is_est.get_traj_discnt_reward(
                reward_array=r, discount=gamma, state_array=s, action_array=a,
                weight_array=p)
            test_res.append(__test_res.numpy())
        #test_res = np.concatenate(test_res).mean()
        test_res = np.concatenate(test_res)
-        tol = (test_res.mean()/1000).item()
+        tol = (np.abs(test_res.mean()/100)).item()
        self.assertEqual(pred_res.shape, torch.Size((len(rewards),)))
-        np.testing.assert_allclose(pred_res.numpy(),test_res, atol=tol)
-
+        np.testing.assert_allclose(pred_res.numpy(),test_res, atol=tol)


5 changes: 4 additions & 1 deletion tests/OPEEstimators/test_IS.py
@@ -58,7 +58,10 @@ def __mock_return(rewards, discount, h):
        pred_res = self.is_estimator.predict_traj_rewards(
            rewards=rewards, actions=[], states=[], weights=weight_test_res,
            discount=gamma, is_msk=msk_test_res)
-        test_res = np.multiply(reward_test_res.numpy(), weight_test_res.numpy())
+        test_res = np.multiply(
+            reward_test_res.numpy(),
+            weight_test_res.numpy()/weight_test_res.shape[0]
+        )
        test_res=test_res.sum(axis=1)
        #test_res = test_res.sum(axis=1).mean()
        tol = test_res.mean()/1000
