Skip to content

Commit

Permalink
removed unnecessary read/write from FQE callback
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Jun 10, 2024
1 parent 6856a8b commit f725b63
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
6 changes: 3 additions & 3 deletions examples/d3rlpy_training_api_continuous_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

# Import evaluators
from offline_rl_ope.api.d3rlpy.Scorers import (
ISEstimatorScorer
ISEstimatorScorer, QueryScorer
)
from offline_rl_ope.components.Policy import Policy

Expand Down Expand Up @@ -161,8 +161,8 @@
discount=gamma, cache=is_callback, is_type="per_decision",
norm_weights=True)})

# for scr in fqe_scorers:
# scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
for scr in fqe_scorers:
scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})

epoch_callback = EpochCallbackHandler([is_callback, fqe_callback])

Expand Down
4 changes: 2 additions & 2 deletions examples/d3rlpy_training_api_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@
}
)

# for scr in fqe_scorers:
# scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
for scr in fqe_scorers:
scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})

epoch_callback = EpochCallbackHandler([is_callback, fqe_callback, dva_callback])

Expand Down
12 changes: 2 additions & 10 deletions src/offline_rl_ope/api/d3rlpy/Callbacks/DM.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self, algo: QLearningAlgoProtocol, epoch:int, total_step:int):
_msg = "Must provide n_steps_per_epoch for FQE training"
assert "n_steps_per_epoch" in self.__model_fit_kwargs, _msg

fqe.fit(
res = fqe.fit(
self.__dataset,
evaluators=self.__scorers,
**self.__model_fit_kwargs,
Expand All @@ -83,15 +83,7 @@ def run(self, algo: QLearningAlgoProtocol, epoch:int, total_step:int):
experiment_name=f"EXP_{str(self.__cur_exp)}"
)

res:Dict = {}
for scr in self.__scorers:
__file_path = os.path.join(
self.__logs_loc, "EXP_{}".format(self.__cur_exp),
"{}.csv".format(scr))
lines = np.genfromtxt(__file_path, delimiter=',')
if len(lines.shape) == 1:
lines = lines.reshape(-1,1)
res[scr] = lines[-1:,-1].item()
res = res[-1][1]
self.__cur_exp += 1
self.cache = res

Expand Down

0 comments on commit f725b63

Please sign in to comment.