Skip to content

Commit

Permalink
ess calculation returns nan if all weights are 0 since this means the…
Browse files Browse the repository at this point in the history
… data is uninformative
  • Loading branch information
joshuaspear committed Feb 28, 2024
1 parent 53164ba commit d3ed4d9
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/offline_rl_ope/Metrics/EffectiveSampleSize.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import torch

from ..components.ImportanceSampler import ImportanceSampler
import numpy as np

__all__ = [
"EffectiveSampleSize"
]

class EffectiveSampleSize:

def __init__(self, is_obj:ImportanceSampler) -> None:
self.__is_obj = is_obj
class EffectiveSampleSize:

def __ess(self) -> float:
def __init__(self, nan_if_all_0:bool=True) -> None:
self.__nan_if_all_0 = nan_if_all_0

def __ess(self, weights:torch.Tensor) -> float:
# https://victorelvira.github.io/papers/kong92.pdf
weights = self.__is_obj.traj_is_weights.sum(dim=1)
weights = weights.sum(dim=1)
numer = len(weights)
return (numer/(1+torch.var(weights))).item()
w_var = torch.var(weights).item()
if (w_var == 0) and (self.__nan_if_all_0):
res = np.nan
else:
res = (numer/(1+w_var))
return res


def __call__(self) -> float:
return self.__ess()
def __call__(self, weights:torch.Tensor) -> float:
return self.__ess(weights=weights)

0 comments on commit d3ed4d9

Please sign in to comment.