working tests for base. Not all new tests added yet
joshuaspear committed Jul 24, 2024
1 parent f91fd0b commit 5f59fa1
Showing 18 changed files with 1,099 additions and 869 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -88,8 +88,8 @@ If importance sampling based methods are evaluating to 0, consider visualising t
The different kinds of importance samples can also be visualised by querying the ```traj_is_weights``` attribute of a given ```ImportanceSampler``` object. If, for example, vanilla importance sampling is being used and the samples are not ```NaN``` or ```Inf```, then visualising the ```traj_is_weights``` may provide insight. In particular, IS weights will tend to infinity when the evaluation policy places large density on an action in comparison to the behaviour policy.
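As a rough sketch of the kind of sanity check this enables (assuming an already-fitted ```ImportanceSampler``` instance named ```is_sampler```; only the ```traj_is_weights``` attribute below is taken from the text above):

```python
import torch

# Assumption: `is_sampler` is a fitted ImportanceSampler instance.
weights = is_sampler.traj_is_weights  # (n_trajectories, max_length)

# Count pathological values before plotting anything.
print(f"NaN entries: {torch.isnan(weights).sum().item()}")
print(f"Inf entries: {torch.isinf(weights).sum().item()}")

# Very large finite weights flag actions where the evaluation policy
# places far more density than the behaviour policy.
finite = weights[torch.isfinite(weights)]
if finite.numel() > 0:
    print(f"max finite weight: {finite.max().item():.3g}")
```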

### Release log
#### 6.1.0 (forthcoming)
* Altered discrete torch propensity model to use softmax instead of torch. Requires modelling both classes for binary classification; however, this improves the generalisability of the code
#### 7.0.0 (Major API release)
* Altered discrete torch propensity model to use softmax instead of torch. Requires modelling both classes for binary classification; however, this improves the generalisability of the code

#### 6.0.0
* Updated PropensityModels structure for sklearn and added a helper class for compatibility with torch
70 changes: 11 additions & 59 deletions src/offline_rl_ope/OPEEstimators/DoublyRobust.py
@@ -6,42 +6,34 @@
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

from ..types import (
WeightTensor,
RewardTensor,
StateTensor,
ActionTensor,
SingleTrajSingleStepTensor)
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
from .WeightDenom import WeightDenomBase
from ..types import WeightTensor
from .IS import ISEstimator
from .DirectMethod import DirectMethodBase
from ..RuntimeChecks import check_array_shape


class DREstimator(ISEstimator):
""" Doubly robust estimator implemented as per:
https://arxiv.org/pdf/1511.03722.pdf
"""

def __init__(
self,
empirical_denom:EmpiricalMeanDenomBase,
weight_denom:WeightDenomBase,
dm_model:DirectMethodBase,
norm_weights: bool,
clip_weights:bool=False,
clip:float=0.0,
cache_traj_rewards:bool=False,
norm_kwargs:Dict[str,Any] = {}
) -> None:
assert isinstance(dm_model,DirectMethodBase)
assert isinstance(norm_weights,bool)
assert isinstance(clip,(float,type(None)))
assert isinstance(cache_traj_rewards,bool)
assert isinstance(norm_kwargs,Dict)
super().__init__(
norm_weights=norm_weights,
empirical_denom=empirical_denom,
weight_denom=weight_denom,
clip_weights=clip_weights,
clip=clip,
cache_traj_rewards=cache_traj_rewards,
norm_kwargs=norm_kwargs
cache_traj_rewards=cache_traj_rewards
)
self.dm_model = dm_model

@@ -93,6 +85,7 @@ def predict_traj_rewards(
rewards=rewards, discount=discount, h=h)
# weights dim is (n_trajectories, max_length)
weights = self.process_weights(weights=weights, is_msk=is_msk)
print(f"weights:{weights}")
v:List[Float[torch.Tensor, "max_length 1"]] = []
q:List[Float[torch.Tensor, "max_length 1"]] = []
for s,a in zip(states, actions):
@@ -136,46 +129,5 @@ def predict_traj_rewards(
0, 1
)
_t4 = torch.mul((_t2-_t3),discnt_vals)
res = (_t1-_t4).sum(dim=1)/n_traj
res = (_t1-_t4).sum(dim=1)
return res
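For orientation, the quantities assembled above are consistent with the step-wise doubly robust estimator from the paper cited in the class docstring (https://arxiv.org/pdf/1511.03722.pdf); with $w_t$ the (processed) importance weight at step $t$ and $w_{-1} := 1$:

$$
\hat{V}_{\mathrm{DR}} \;=\; \sum_{t=0}^{H-1} \gamma^{t} w_t r_t \;-\; \sum_{t=0}^{H-1} \gamma^{t} \bigl( w_t \hat{Q}(s_t, a_t) - w_{t-1} \hat{V}(s_t) \bigr)
$$

Here the first sum corresponds to `_t1` and the second to `_t4` in the code above; reading `_t2` and `_t3` as the weighted $\hat{Q}$ term and the shifted-weight $\hat{V}$ term is an assumption, since their definitions fall outside this hunk.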


class DR(DREstimator):

def __init__(
self,
dm_model: DirectMethodBase,
clip_weights: bool = False,
clip: float = 0.0,
cache_traj_rewards: bool = False,
norm_kwargs: Dict[str, Any] = {}
) -> None:
assert "avg_denom" not in norm_kwargs.keys(), "avg_denom is already set"
super().__init__(
dm_model=dm_model,
norm_weights=False,
clip_weights=clip_weights,
clip=clip,
cache_traj_rewards=cache_traj_rewards,
norm_kwargs={"avg_denom": False, **norm_kwargs}
)

class WDR(DREstimator):

def __init__(
self,
dm_model: DirectMethodBase,
clip_weights: bool = False,
clip: float = 0.0,
cache_traj_rewards: bool = False,
norm_kwargs: Dict[str, Any] = {}
) -> None:
assert "avg_denom" not in norm_kwargs.keys(), "avg_denom is already set"
super().__init__(
dm_model=dm_model,
norm_weights=True,
clip_weights=clip_weights,
clip=clip,
cache_traj_rewards=cache_traj_rewards,
norm_kwargs={"avg_denom": False, **norm_kwargs}
)
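With ```DR``` and ```WDR``` removed from the package exports (see the ```__init__.py``` diff below), a doubly robust estimator is now assembled by injecting denominator components into ```DREstimator``` directly. A minimal sketch under that reading — the pairing of components and the final signature are assumptions based on this diff, not a documented API:

```python
from offline_rl_ope.OPEEstimators.DoublyRobust import DREstimator
from offline_rl_ope.OPEEstimators.DirectMethod import DirectMethodBase
from offline_rl_ope.OPEEstimators.EmpiricalMeanDenom import EmpiricalMeanDenom
from offline_rl_ope.OPEEstimators.WeightDenom import PassWeightDenom

# Assumption: a fitted direct-method model supplying the V/Q estimates.
dm_model: DirectMethodBase = ...

dr_estimator = DREstimator(
    empirical_denom=EmpiricalMeanDenom(),  # divide by n_trajectories
    weight_denom=PassWeightDenom(),        # leave IS weights untouched
    dm_model=dm_model,
    norm_weights=False,                    # parameter as shown in this diff
)
```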
69 changes: 69 additions & 0 deletions src/offline_rl_ope/OPEEstimators/EmpiricalMeanDenom.py
@@ -0,0 +1,69 @@
from abc import ABCMeta, abstractmethod
import torch
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

from ..types import WeightTensor
from .utils import get_traj_weight_final


class EmpiricalMeanDenomBase(metaclass=ABCMeta):

@abstractmethod
def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> Float[torch.Tensor,""]:
pass

class EmpiricalMeanDenom(EmpiricalMeanDenomBase):

def __init__(self) -> None:
"""Empirical mean denominator:
- http://incompleteideas.net/papers/PSS-00.pdf (Q^{IS} when weights are IS)
- http://incompleteideas.net/papers/PSS-00.pdf (Q^{PD} when weights are PD)
"""
super().__init__()

@jaxtyped(typechecker=typechecker)
def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> Float[torch.Tensor,""]:
return torch.tensor(weights.shape[0]).float()


class WeightedEmpiricalMeanDenom(EmpiricalMeanDenomBase):

def __init__(
self,
smooth_eps:float=0.0,
cumulative:bool=False
) -> None:
"""Empirical mean denominator:
- http://incompleteideas.net/papers/PSS-00.pdf
(Q^{ISW} when weights are IS)
- http://incompleteideas.net/papers/PSS-00.pdf
(Q^{PDW} when weights are PD and cumulative = True)
"""
super().__init__()
self.cumulative = cumulative
self.smooth_eps = smooth_eps

@jaxtyped(typechecker=typechecker)
def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> Float[torch.Tensor,""]:
if self.cumulative:
# Sum the masked weights over all trajectories and timepoints
denom = torch.mul(weights,is_msk).sum()
else:
denom = get_traj_weight_final(weights=weights, is_msk=is_msk)
denom = denom.sum()
denom = denom + self.smooth_eps
return denom
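A toy check of the two denominators defined above (module path as added in this commit; the expected values follow directly from the definitions):

```python
import torch
from offline_rl_ope.OPEEstimators.EmpiricalMeanDenom import (
    EmpiricalMeanDenom, WeightedEmpiricalMeanDenom)

# (n_trajectories=2, max_length=2); the second trajectory has length 1.
weights = torch.tensor([[0.5, 2.0], [1.0, 0.0]])
is_msk = torch.tensor([[1.0, 1.0], [1.0, 0.0]])

# Unweighted empirical mean: always n_trajectories.
print(EmpiricalMeanDenom()(weights=weights, is_msk=is_msk))  # tensor(2.)

# Weighted, cumulative: the sum of all masked weights.
denom = WeightedEmpiricalMeanDenom(cumulative=True)
print(denom(weights=weights, is_msk=is_msk))  # tensor(3.5000)
```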

47 changes: 23 additions & 24 deletions src/offline_rl_ope/OPEEstimators/IS.py
@@ -3,12 +3,12 @@
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

from .. import logger
from .utils import (
WISWeightNorm, VanillaNormWeights, WeightNorm,
clip_weights_pass as cwp,
clip_weights as cw
)
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
from .WeightDenom import WeightDenomBase
from .base import OPEEstimatorBase
from ..types import (RewardTensor,WeightTensor)

@@ -17,28 +17,28 @@ class ISEstimatorBase(OPEEstimatorBase):

def __init__(
self,
norm_weights:bool,
empirical_denom:EmpiricalMeanDenomBase,
weight_denom:WeightDenomBase,
clip_weights:bool=False,
cache_traj_rewards:bool=False,
clip:float=0.0,
norm_kwargs:Dict[str,Union[str,bool]] = {}
) -> None:
super().__init__(cache_traj_rewards)
assert isinstance(norm_weights,bool)
super().__init__(
empirical_denom=empirical_denom,
cache_traj_rewards=cache_traj_rewards
)
assert isinstance(weight_denom,WeightDenomBase)
assert isinstance(clip_weights,bool)
assert isinstance(cache_traj_rewards,bool)
assert isinstance(clip,float)
assert isinstance(norm_kwargs,Dict)
if norm_weights:
_norm_weights = WISWeightNorm(**norm_kwargs)
else:
_norm_weights = VanillaNormWeights(**norm_kwargs)
self.norm_weights:WeightNorm = _norm_weights
self.clip = clip
if clip_weights:
self.clip_weights = cw
else:
self.clip_weights = cwp
self.weight_denom = weight_denom

@jaxtyped(typechecker=typechecker)
def process_weights(
Expand All @@ -58,12 +58,8 @@ def process_weights(
WeightTensor: Tensor of processed weight, of dimension
(n_trajectories, max_length)
"""
# assert isinstance(weights,torch.Tensor)
# assert isinstance(is_msk,torch.Tensor)
# assert weights.shape == is_msk.shape
weights = self.clip_weights(
traj_is_weights=weights, clip=self.clip)
weights = self.norm_weights(traj_is_weights=weights, is_msk=is_msk)
weights = self.clip_weights(weights=weights, clip=self.clip)
weights = self.weight_denom(weights=weights, is_msk=is_msk)
return weights

def get_dataset_discnt_reward(
@@ -137,15 +133,19 @@ class ISEstimator(ISEstimatorBase):

def __init__(
self,
norm_weights: bool,
empirical_denom:EmpiricalMeanDenomBase,
weight_denom: WeightDenomBase,
clip_weights:bool=False,
clip: float = 0.0,
cache_traj_rewards:bool=False,
norm_kwargs:Dict[str,Union[str,bool]] = {}
cache_traj_rewards:bool=False
) -> None:
super().__init__(norm_weights=norm_weights, clip_weights=clip_weights,
clip=clip, cache_traj_rewards=cache_traj_rewards,
norm_kwargs=norm_kwargs)
super().__init__(
empirical_denom=empirical_denom,
weight_denom=weight_denom,
clip_weights=clip_weights,
cache_traj_rewards=cache_traj_rewards,
clip=clip
)

@jaxtyped(typechecker=typechecker)
def predict_traj_rewards(
@@ -194,5 +194,4 @@ def predict_traj_rewards(
weights = self.process_weights(weights=weights, is_msk=is_msk)
# (n_trajectories,max_length) ELEMENT WISE * (n_trajectories,max_length)
res = torch.mul(discnt_rewards,weights).sum(dim=1)
return res
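Under the new signature, the normalisation strategy is injected rather than toggled via the removed ```norm_weights```/```norm_kwargs``` arguments. A sketch of vanilla IS wiring — pairing ```EmpiricalMeanDenom``` with ```PassWeightDenom``` is an assumption based on the docstrings in this commit:

```python
from offline_rl_ope.OPEEstimators.IS import ISEstimator
from offline_rl_ope.OPEEstimators.EmpiricalMeanDenom import EmpiricalMeanDenom
from offline_rl_ope.OPEEstimators.WeightDenom import PassWeightDenom

# Vanilla IS: raw weights, empirical mean taken over n_trajectories.
vanilla_is = ISEstimator(
    empirical_denom=EmpiricalMeanDenom(),
    weight_denom=PassWeightDenom(),
    clip_weights=False,
)
```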
76 changes: 76 additions & 0 deletions src/offline_rl_ope/OPEEstimators/WeightDenom.py
@@ -0,0 +1,76 @@
from abc import ABCMeta, abstractmethod
import torch
from jaxtyping import jaxtyped
from typeguard import typechecked as typechecker

from ..types import WeightTensor

class WeightDenomBase(metaclass=ABCMeta):

@abstractmethod
def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> WeightTensor:
pass


class PassWeightDenom(WeightDenomBase):

def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> WeightTensor:
return weights


class AvgWeightDenom(WeightDenomBase):

def __init__(self) -> None:
"""Weight denominator as per:
- https://arxiv.org/pdf/1604.00923 (DR when weights are IS)
"""
super().__init__()

@jaxtyped(typechecker=typechecker)
def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> WeightTensor:
return weights/torch.tensor(weights.shape[0])


class PiTWeightDenom(WeightDenomBase):

def __init__(
self,
smooth_eps:float=0.0
) -> None:
"""Weight denominator as per:
- https://arxiv.org/pdf/1906.03735 (SNSIS when weights are PD)
- https://arxiv.org/pdf/1604.00923 (WDR when weights are PD)
Args:
smooth_eps (float, optional): Laplacian smoothing. Defaults to 0.0.
"""
super().__init__()
assert isinstance(smooth_eps,float)
self.smooth_eps = smooth_eps

@jaxtyped(typechecker=typechecker)
def __call__(
self,
weights:WeightTensor,
is_msk:WeightTensor
) -> WeightTensor:
msked_weights = torch.mul(weights, is_msk)
pit_vals = msked_weights.mean(dim=0,keepdim=True).repeat(
weights.shape[0], 1
)
pit_vals = pit_vals + self.smooth_eps
res = msked_weights/pit_vals
return res
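A small numeric example of ```PiTWeightDenom```: each timestep (column) is divided by its mean weight across trajectories, so every fully-masked column of the output averages to one:

```python
import torch
from offline_rl_ope.OPEEstimators.WeightDenom import PiTWeightDenom

weights = torch.tensor([[0.5, 2.0], [1.5, 4.0]])
is_msk = torch.ones_like(weights)  # both trajectories run to max_length

out = PiTWeightDenom()(weights=weights, is_msk=is_msk)
print(out)              # tensor([[0.5000, 0.6667], [1.5000, 1.3333]])
print(out.mean(dim=0))  # tensor([1., 1.])
```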

6 changes: 4 additions & 2 deletions src/offline_rl_ope/OPEEstimators/__init__.py
@@ -1,4 +1,6 @@
from .DirectMethod import DirectMethodBase, D3rlpyQlearnDM
from .DoublyRobust import DREstimator, DR, WDR
from .DoublyRobust import DREstimator
from .IS import ISEstimatorBase, ISEstimator
from .base import OPEEstimatorBase
from .EmpiricalMeanDenom import *
from .WeightDenom import *
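With the star imports added here, the new denominator classes should also be reachable from the subpackage root (assuming neither new module restricts its exports with ```__all__```):

```python
from offline_rl_ope.OPEEstimators import (
    DREstimator, ISEstimator, EmpiricalMeanDenom, PiTWeightDenom)
```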