diff --git a/src/offline_rl_ope/components/Policy.py b/src/offline_rl_ope/components/Policy.py index e9973f2..d05154a 100644 --- a/src/offline_rl_ope/components/Policy.py +++ b/src/offline_rl_ope/components/Policy.py @@ -167,7 +167,7 @@ def __call__( action_prs = self.postproc_tens(p_return.action_prs) self.collect_res_fn(action_prs) self.collect_act_func(actions) - return p_return.action_prs + return action_prs class GreedyDeterministic(BasePolicy):