From f725b63bd7b24d15ccdca8b9bbb2f4710667cb08 Mon Sep 17 00:00:00 2001
From: Joshua Spear
Date: Mon, 10 Jun 2024 17:35:43 +0100
Subject: [PATCH] removed unnecessary read/write from FQE callback

---
 .../d3rlpy_training_api_continuous_stochastic.py |  6 +++---
 examples/d3rlpy_training_api_discrete.py         |  4 ++--
 src/offline_rl_ope/api/d3rlpy/Callbacks/DM.py    | 12 ++----------
 3 files changed, 7 insertions(+), 15 deletions(-)

diff --git a/examples/d3rlpy_training_api_continuous_stochastic.py b/examples/d3rlpy_training_api_continuous_stochastic.py
index 7effe36..e72d07f 100644
--- a/examples/d3rlpy_training_api_continuous_stochastic.py
+++ b/examples/d3rlpy_training_api_continuous_stochastic.py
@@ -29,7 +29,7 @@
 
 # Import evaluators
 from offline_rl_ope.api.d3rlpy.Scorers import (
-    ISEstimatorScorer
+    ISEstimatorScorer, QueryScorer
 )
 from offline_rl_ope.components.Policy import Policy
 
@@ -161,8 +161,8 @@
         discount=gamma, cache=is_callback, is_type="per_decision",
         norm_weights=True)})
 
-    # for scr in fqe_scorers:
-    #     scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
+    for scr in fqe_scorers:
+        scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
 
     epoch_callback = EpochCallbackHandler([is_callback, fqe_callback])
 
diff --git a/examples/d3rlpy_training_api_discrete.py b/examples/d3rlpy_training_api_discrete.py
index 2ae1241..94dd738 100644
--- a/examples/d3rlpy_training_api_discrete.py
+++ b/examples/d3rlpy_training_api_discrete.py
@@ -146,8 +146,8 @@
         }
     )
 
-    # for scr in fqe_scorers:
-    #     scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
+    for scr in fqe_scorers:
+        scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
 
     epoch_callback = EpochCallbackHandler([is_callback, fqe_callback, dva_callback])
 
diff --git a/src/offline_rl_ope/api/d3rlpy/Callbacks/DM.py b/src/offline_rl_ope/api/d3rlpy/Callbacks/DM.py
index 4c2de26..aced7fd 100644
--- a/src/offline_rl_ope/api/d3rlpy/Callbacks/DM.py
+++ b/src/offline_rl_ope/api/d3rlpy/Callbacks/DM.py
@@ -74,7 +74,7 @@ def run(self, algo: QLearningAlgoProtocol, epoch:int, total_step:int):
         _msg = "Must provide n_steps_per_epoch for FQE training"
         assert "n_steps_per_epoch" in self.__model_fit_kwargs, _msg
 
-        fqe.fit(
+        res = fqe.fit(
             self.__dataset,
             evaluators=self.__scorers,
             **self.__model_fit_kwargs,
@@ -83,15 +83,7 @@ def run(self, algo: QLearningAlgoProtocol, epoch:int, total_step:int):
             experiment_name=f"EXP_{str(self.__cur_exp)}"
         )
 
-        res:Dict = {}
-        for scr in self.__scorers:
-            __file_path = os.path.join(
-                self.__logs_loc, "EXP_{}".format(self.__cur_exp),
-                "{}.csv".format(scr))
-            lines = np.genfromtxt(__file_path, delimiter=',')
-            if len(lines.shape) == 1:
-                lines = lines.reshape(-1,1)
-            res[scr] = lines[-1:,-1].item()
+        res = res[-1][1]
 
         self.__cur_exp += 1
         self.cache = res
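
Note on the DM.py change: d3rlpy's QLearningAlgoBase.fit(), which FQE
inherits, returns a list of (epoch, metrics) tuples, one per training epoch,
so the final epoch's evaluator results can be taken straight from the return
value instead of being parsed back out of the logged CSV files. A minimal
standalone sketch of that flow, assuming d3rlpy>=2.0; the `trained_algo` and
`dataset` variables and the step counts are illustrative, not taken from
this repo:

    from d3rlpy.metrics import InitialStateValueEstimationEvaluator
    from d3rlpy.ope import FQE, FQEConfig

    # `trained_algo` is the fitted policy being evaluated and `dataset` is
    # the d3rlpy dataset it was trained on (both assumed to exist here).
    fqe = FQE(algo=trained_algo, config=FQEConfig())
    history = fqe.fit(
        dataset,
        n_steps=10_000,
        n_steps_per_epoch=1_000,
        evaluators={"init_value": InitialStateValueEstimationEvaluator()},
    )
    # history is a list of (epoch, metrics) pairs; history[-1][1] is the
    # metrics dict for the last epoch, which is exactly what
    # `res = res[-1][1]` extracts in the callback above.
    final_scores = history[-1][1]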