Skip to content

Commit

Permalink
updated vwp metric
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Sep 29, 2024
1 parent 12462f1 commit 6faca78
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
9 changes: 3 additions & 6 deletions src/offline_rl_ope/Metrics/ValidWeightsProp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,10 @@ def __valid_weights(
weights:WeightTensor,
weight_msk:WeightTensor
) -> float:
fnl_weights = get_traj_weight_final(
weights=weights,
is_msk=weight_msk
)
sum_weights = torch.mul(weights,weight_msk).sum(dim=1)
vw_mask = (
(fnl_weights > self.__min_w) &
(fnl_weights < self.__max_w)
(sum_weights > self.__min_w) &
(sum_weights < self.__max_w)
).squeeze()
return torch.mean(vw_mask.float()).item()

Expand Down
5 changes: 4 additions & 1 deletion tests/Metrics/test_ValidWeightsProp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ def test_call(self):
min_val=0.000001
fnl_weights = []
for idx,i in enumerate(self.test_conf.traj_lengths):
fnl_weights.append(self.test_conf.weight_test_res[idx,i-1][None])
fnl_weights.append(self.test_conf.weight_test_res[idx,:i-1].sum(
dim=0,
keepdim=True
))
fnl_weights_tens = torch.concat(fnl_weights, axis=0)
num = (fnl_weights_tens > min_val) & (fnl_weights_tens < max_val)
num = torch.sum(num, axis=0)
Expand Down

0 comments on commit 6faca78

Please sign in to comment.