generated from joshuaspear/python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixed bug in behavPolicy testing. Added metrics and tests
- Loading branch information
1 parent
6bb313b
commit 0238094
Showing
10 changed files
with
144 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
|
||
from ..components.ImportanceSampler import ImportanceSampler | ||
|
||
__all__ = [ | ||
"EffectiveSampleSize" | ||
] | ||
|
||
class EffectiveSampleSize: | ||
|
||
def __init__(self, is_obj:ImportanceSampler) -> None: | ||
self.__is_obj = is_obj | ||
|
||
def __ess(self) -> float: | ||
numer = torch.sum(torch.pow(self.__is_obj.traj_is_weights,2)) | ||
denom = torch.pow(torch.sum(self.__is_obj.traj_is_weights),2) | ||
return (numer/denom).item() | ||
|
||
|
||
def __call__(self) -> float: | ||
return self.__ess() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
|
||
from ..components.ImportanceSampler import ImportanceSampler | ||
|
||
__all__ = [ | ||
"ValidWeightsProp" | ||
] | ||
|
||
class ValidWeightsProp: | ||
|
||
def __init__( | ||
self, | ||
is_obj:ImportanceSampler, | ||
min_w:float, | ||
max_w:float | ||
) -> None: | ||
self.__is_obj = is_obj | ||
self.__min_w = min_w | ||
self.__max_w = max_w | ||
|
||
def __valid_weights(self) -> float: | ||
vw_mask = ( | ||
(self.__is_obj.traj_is_weights > self.__min_w) & | ||
(self.__is_obj.traj_is_weights < self.__max_w) | ||
) | ||
vw_num = torch.sum(vw_mask, axis=1) | ||
vw_denom = torch.sum(self.__is_obj.is_weight_calc.weight_msk, axis=1) | ||
return torch.mean(vw_num/vw_denom).item() | ||
|
||
def __call__(self) -> float: | ||
return self.__valid_weights() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .EffectiveSampleSize import * | ||
from .ValidWeightsProp import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "3.0.3" | ||
__version__ = "4.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
import torch | ||
import logging | ||
import numpy as np | ||
import copy | ||
from offline_rl_ope.Metrics import EffectiveSampleSize | ||
from ..base import weight_test_res | ||
|
||
logger = logging.getLogger("offline_rl_ope") | ||
|
||
class TestImportanceSampler: | ||
|
||
def __init__(self) -> None: | ||
self.is_weight_calc = None | ||
self.traj_is_weights = weight_test_res | ||
|
||
|
||
class EffectiveSampleSizeTest(unittest.TestCase): | ||
|
||
def test_call(self): | ||
num = torch.sum(torch.pow(weight_test_res,2)) | ||
denum = torch.pow(torch.sum(weight_test_res),2) | ||
act_res = (num/denum).item() | ||
metric = EffectiveSampleSize(is_obj=TestImportanceSampler()) | ||
pred_res = metric() | ||
self.assertEqual(act_res,pred_res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import unittest | ||
import torch | ||
import logging | ||
import numpy as np | ||
import copy | ||
from offline_rl_ope.Metrics import ValidWeightsProp | ||
from ..base import weight_test_res, msk_test_res | ||
|
||
logger = logging.getLogger("offline_rl_ope") | ||
|
||
|
||
class TestImportanceCalc: | ||
|
||
def __init__(self) -> None: | ||
self.weight_msk = msk_test_res | ||
|
||
class TestImportanceSampler: | ||
|
||
def __init__(self) -> None: | ||
self.is_weight_calc = None | ||
self.traj_is_weights = weight_test_res | ||
self.is_weight_calc = TestImportanceCalc() | ||
|
||
class TestValidWeightsProp(unittest.TestCase): | ||
|
||
def test_call(self): | ||
max_val=10000 | ||
min_val=0.000001 | ||
num = (weight_test_res > min_val) & (weight_test_res < max_val) | ||
num = torch.sum(num, axis=1) | ||
denum = torch.sum(msk_test_res, axis=1) | ||
act_res = torch.mean(num/denum).item() | ||
metric = ValidWeightsProp( | ||
is_obj=TestImportanceSampler(), | ||
max_w=max_val, | ||
min_w=min_val | ||
) | ||
pred_res = metric() | ||
self.assertEqual(act_res,pred_res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters