generated from joshuaspear/python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
working tests for base. Not all new tests added yet
- Loading branch information
1 parent
f91fd0b
commit 5f59fa1
Showing
18 changed files
with
1,099 additions
and
869 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from abc import ABCMeta, abstractmethod | ||
import torch | ||
from jaxtyping import jaxtyped, Float | ||
from typeguard import typechecked as typechecker | ||
|
||
from ..types import WeightTensor | ||
from .utils import get_traj_weight_final | ||
|
||
|
||
class EmpiricalMeanDenomBase(metaclass=ABCMeta): | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> Float[torch.Tensor,""]: | ||
pass | ||
|
||
class EmpiricalMeanDenom(EmpiricalMeanDenomBase): | ||
|
||
def __init__(self) -> None: | ||
"""Empirical mean denominator: | ||
- http://incompleteideas.net/papers/PSS-00.pdf (Q^{IS} when weights are IS) | ||
- http://incompleteideas.net/papers/PSS-00.pdf (Q^{PD} when weights are PD) | ||
""" | ||
super().__init__() | ||
|
||
@jaxtyped(typechecker=typechecker) | ||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> Float[torch.Tensor,""]: | ||
return torch.tensor(weights.shape[0]).float() | ||
|
||
|
||
class WeightedEmpiricalMeanDenom(EmpiricalMeanDenomBase): | ||
|
||
def __init__( | ||
self, | ||
smooth_eps:float=0.0, | ||
cumulative:bool=False | ||
) -> None: | ||
"""Empirical mean denominator: | ||
- http://incompleteideas.net/papers/PSS-00.pdf | ||
(Q^{ISW} when weights are IS) | ||
- http://incompleteideas.net/papers/PSS-00.pdf | ||
(Q^{PDW} when weights are PD and cumulative = True) | ||
""" | ||
super().__init__() | ||
self.cumulative = cumulative | ||
self.smooth_eps = smooth_eps | ||
|
||
@jaxtyped(typechecker=typechecker) | ||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> Float[torch.Tensor,""]: | ||
if self.cumulative: | ||
# For each timepoint, sum across the trajectories | ||
denom = torch.mul(weights,is_msk).sum() | ||
else: | ||
denom = get_traj_weight_final(weights=weights, is_msk=is_msk) | ||
denom = denom.sum() | ||
denom = denom + self.smooth_eps | ||
return denom | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from abc import ABCMeta, abstractmethod | ||
import torch | ||
from jaxtyping import jaxtyped | ||
from typeguard import typechecked as typechecker | ||
|
||
from ..types import WeightTensor | ||
|
||
class WeightDenomBase(metaclass=ABCMeta): | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> WeightTensor: | ||
pass | ||
|
||
|
||
class PassWeightDenom(WeightDenomBase): | ||
|
||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> WeightTensor: | ||
return weights | ||
|
||
|
||
class AvgWeightDenom(WeightDenomBase): | ||
|
||
def __init__(self) -> None: | ||
"""Weight denominator as per: | ||
- https://arxiv.org/pdf/1604.00923 (DR when weights are IS) | ||
""" | ||
super().__init__() | ||
|
||
@jaxtyped(typechecker=typechecker) | ||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> WeightTensor: | ||
return weights/torch.tensor(weights.shape[0]) | ||
|
||
|
||
class PiTWeightDenom(WeightDenomBase): | ||
|
||
def __init__( | ||
self, | ||
smooth_eps:float=0.0 | ||
) -> None: | ||
"""Weight denominator as per: | ||
- https://arxiv.org/pdf/1906.03735 (snsis when weights are PD) | ||
- https://arxiv.org/pdf/1604.00923 (WDR when weights are PD) | ||
Args: | ||
smooth_eps (float, optional): Laplacian smoothing. Defaults to 0.0. | ||
""" | ||
super().__init__() | ||
assert isinstance(smooth_eps,float) | ||
self.smooth_eps = smooth_eps | ||
|
||
@jaxtyped(typechecker=typechecker) | ||
def __call__( | ||
self, | ||
weights:WeightTensor, | ||
is_msk:WeightTensor | ||
) -> WeightTensor: | ||
msked_weights = torch.mul(weights, is_msk) | ||
pit_vals = msked_weights.mean(dim=0,keepdim=True).repeat( | ||
weights.shape[0], 1 | ||
) | ||
pit_vals = pit_vals + self.smooth_eps | ||
res = msked_weights/pit_vals | ||
return res | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
from .DirectMethod import DirectMethodBase, D3rlpyQlearnDM | ||
from .DoublyRobust import DREstimator, DR, WDR | ||
from .DoublyRobust import DREstimator | ||
from .IS import ISEstimatorBase, ISEstimator | ||
from .base import OPEEstimatorBase | ||
from .base import OPEEstimatorBase | ||
from .EmpiricalMeanDenom import * | ||
from .WeightDenom import * |
Oops, something went wrong.