From 9d22a5214df73ab87ca69180c4944e5060267efa Mon Sep 17 00:00:00 2001 From: aiueola Date: Mon, 24 Jul 2023 07:29:05 +0900 Subject: [PATCH 1/2] fix some bugs --- README.md | 5 +- docs/documentation/quickstart.rst | 4 +- scope_rl/ope/ope.py | 144 +++++++++++++++++------------- scope_rl/policy/head.py | 24 +++-- 4 files changed, 107 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index 557a7cd4..a823fd2a 100644 --- a/README.md +++ b/README.md @@ -280,12 +280,13 @@ random_ = EpsilonGreedyHead( evaluation_policies = [cql_, ddqn_, random_] # create input for the OPE class prep = CreateOPEInput( + env=env, logged_dataset=test_logged_dataset, - use_base_model=True, # use model-based prediction ) input_dict = prep.obtain_whole_inputs( + logged_dataset=test_logged_dataset, evaluation_policies=evaluation_policies, - env=env, + require_value_prediction=True, n_trajectories_on_policy_evaluation=100, random_state=random_state, ) diff --git a/docs/documentation/quickstart.rst b/docs/documentation/quickstart.rst index 7cfcab3b..b44b21ea 100644 --- a/docs/documentation/quickstart.rst +++ b/docs/documentation/quickstart.rst @@ -207,11 +207,11 @@ and Doubly Robust (DR) :cite:`jiang2016doubly, thomas2016data`. # create input for OPE class prep = CreateOPEInput( env=env, - logged_dataset=test_logged_dataset, - use_base_model=True, # use model-based prediction ) input_dict = prep.obtain_whole_inputs( + logged_dataset=test_logged_dataset, evaluation_policies=evaluation_policies, + require_value_prediction=True, n_trajectories_on_policy_evaluation=100, random_state=random_state, ) diff --git a/scope_rl/ope/ope.py b/scope_rl/ope/ope.py index c5a52e43..a3deeef4 100644 --- a/scope_rl/ope/ope.py +++ b/scope_rl/ope/ope.py @@ -2063,14 +2063,17 @@ def visualize_policy_value_with_multiple_estimates( ) else: - sns.swarmplot( - data=df, - x="estimator", - y="policy_value", - hue="behavior_policy", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="estimator", + y="policy_value", + hue="behavior_policy", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") on_policy = policy_value_dict[behavior_policy][eval_policy][ "on_policy" @@ -2178,14 +2181,17 @@ def visualize_policy_value_with_multiple_estimates( ) else: - sns.swarmplot( - data=df, - x="eval_policy", - y="policy_value", - hue="behavior_policy", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="eval_policy", + y="policy_value", + hue="behavior_policy", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") if visualize_on_policy: ax.scatter( @@ -2273,13 +2279,16 @@ def visualize_policy_value_with_multiple_estimates( ax=ax, ) else: - sns.swarmplot( - data=df, - x="estimator", - y="policy_value", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="estimator", + y="policy_value", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") on_policy = policy_value_dict[eval_policy]["on_policy"] if on_policy is not None: @@ -2358,13 +2367,16 @@ def visualize_policy_value_with_multiple_estimates( ) else: - sns.swarmplot( - data=df, - x="eval_policy", - y="policy_value", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="eval_policy", + y="policy_value", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") if visualize_on_policy: ax.scatter( @@ -5535,14 +5547,17 @@ def _visualize_off_policy_estimates_with_multiple_estimates( ) else: - sns.swarmplot( - data=df, - x="estimator", - y="policy_value", - hue="behavior_policy", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="estimator", + y="policy_value", + hue="behavior_policy", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") on_policy = estimation_dict[behavior_policy][eval_policy][ "on_policy" @@ -5650,14 +5665,17 @@ def _visualize_off_policy_estimates_with_multiple_estimates( ) else: - sns.swarmplot( - data=df, - x="eval_policy", - y="policy_value", - hue="behavior_policy", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="eval_policy", + y="policy_value", + hue="behavior_policy", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") if visualize_on_policy: ax.scatter( @@ -5745,13 +5763,16 @@ def _visualize_off_policy_estimates_with_multiple_estimates( ax=ax, ) else: - sns.swarmplot( - data=df, - x="estimator", - y="policy_value", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="estimator", + y="policy_value", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") on_policy = estimation_dict[eval_policy]["on_policy"] if on_policy is not None: @@ -5843,13 +5864,16 @@ def _visualize_off_policy_estimates_with_multiple_estimates( ) else: - sns.swarmplot( - data=df, - x="eval_policy", - y="policy_value", - palette=palette, - ax=ax, - ) + try: + sns.swarmplot( + data=df, + x="eval_policy", + y="policy_value", + palette=palette, + ax=ax, + ) + except: + warn("Encountered NaN values during plot.") if visualize_on_policy: ax.scatter( diff --git a/scope_rl/policy/head.py b/scope_rl/policy/head.py index 71acc6c6..b4463fa7 100644 --- a/scope_rl/policy/head.py +++ b/scope_rl/policy/head.py @@ -680,7 +680,7 @@ def __post_init__(self): raise ValueError("random_state must be given") self.random_ = check_random_state(self.random_state) - def _calc_pscore(self, greedy_action: np.ndarray, action: np.ndarray): + def calc_action_choice_probability(self, greedy_action: np.ndarray, action: np.ndarray): """Calculate pscore. Parameters @@ -723,7 +723,7 @@ def sample_action_and_output_pscore(self, x: np.ndarray): """ greedy_action = self.base_policy.predict(x) action = self.sample_action(x) - pscore = self._calc_pscore(greedy_action, action) + pscore = self.calc_action_choice_probability(greedy_action, action) return action, pscore def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray): @@ -744,7 +744,7 @@ def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray): """ greedy_action = self.base_policy.predict(x) - return self._calc_pscore(greedy_action, action) + return self.calc_action_choice_probability(greedy_action, action) def sample_action(self, x: np.ndarray): """Sample action. @@ -844,7 +844,7 @@ def __post_init__(self): raise ValueError("random_state must be given") self.random_ = check_random_state(self.random_state) - def _calc_pscore(self, greedy_action: np.ndarray, action: np.ndarray): + def calc_action_choice_probability(self, greedy_action: np.ndarray, action: np.ndarray): """Calculate pscore. Parameters @@ -889,7 +889,7 @@ def sample_action_and_output_pscore(self, x: np.ndarray): """ greedy_action = self.base_policy.predict(x) action = self.sample_action(x) - pscore = self._calc_pscore(greedy_action, action) + pscore = self.calc_action_choice_probability(greedy_action, action) return action, pscore def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray): @@ -910,7 +910,7 @@ def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray): """ greedy_action = self.base_policy.predict(x) - return self._calc_pscore(greedy_action, action) + return self.calc_action_choice_probability(greedy_action, action) def sample_action(self, x: np.ndarray): """Sample action. @@ -976,6 +976,18 @@ def __post_init__(self): if not isinstance(self.base_policy, AlgoBase): raise ValueError("base_policy must be a child class of AlgoBase") + def sample_action_and_output_pscore(self, x: np.ndarray): + """Only for API consistency.""" + pass + + def calc_action_choice_probability(self, x: np.ndarray): + """Only for API consistency.""" + pass + + def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray): + """Only for API consistency.""" + pass + def sample_action(self, x: np.ndarray): """Sample action. From 72b16a6a53039f75c190c7a298298551ca1f36eb Mon Sep 17 00:00:00 2001 From: aiueola Date: Mon, 24 Jul 2023 07:32:06 +0900 Subject: [PATCH 2/2] upgrade to scope-rl==0.1.3 --- scope_rl/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope_rl/version.py b/scope_rl/version.py index 349477cf..75731ad8 100644 --- a/scope_rl/version.py +++ b/scope_rl/version.py @@ -1,4 +1,4 @@ # Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved. # Licensed under the Apache 2.0 License. -__version__ = "0.1.2" +__version__ = "0.1.3"