From 6faca7838f48762d62c26741f50d95ed13911736 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Sun, 29 Sep 2024 11:16:01 +0100 Subject: [PATCH] updated vwp metric --- src/offline_rl_ope/Metrics/ValidWeightsProp.py | 9 +++------ tests/Metrics/test_ValidWeightsProp.py | 5 ++++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/offline_rl_ope/Metrics/ValidWeightsProp.py b/src/offline_rl_ope/Metrics/ValidWeightsProp.py index e2c1810..91f7c63 100644 --- a/src/offline_rl_ope/Metrics/ValidWeightsProp.py +++ b/src/offline_rl_ope/Metrics/ValidWeightsProp.py @@ -26,13 +26,10 @@ def __valid_weights( weights:WeightTensor, weight_msk:WeightTensor ) -> float: - fnl_weights = get_traj_weight_final( - weights=weights, - is_msk=weight_msk - ) + sum_weights = torch.mul(weights,weight_msk).sum(dim=1) vw_mask = ( - (fnl_weights > self.__min_w) & - (fnl_weights < self.__max_w) + (sum_weights > self.__min_w) & + (sum_weights < self.__max_w) ).squeeze() return torch.mean(vw_mask.float()).item() diff --git a/tests/Metrics/test_ValidWeightsProp.py b/tests/Metrics/test_ValidWeightsProp.py index 6c691d1..7b97f31 100644 --- a/tests/Metrics/test_ValidWeightsProp.py +++ b/tests/Metrics/test_ValidWeightsProp.py @@ -17,7 +17,10 @@ def test_call(self): min_val=0.000001 fnl_weights = [] for idx,i in enumerate(self.test_conf.traj_lengths): - fnl_weights.append(self.test_conf.weight_test_res[idx,i-1][None]) + fnl_weights.append(self.test_conf.weight_test_res[idx,:i-1].sum( + dim=0, + keepdim=True + )) fnl_weights_tens = torch.concat(fnl_weights, axis=0) num = (fnl_weights_tens > min_val) & (fnl_weights_tens < max_val) num = torch.sum(num, axis=0)