Skip to content

Commit

Permalink
Merge pull request #19 from hakuhodo-technologies/depreciated
Browse files Browse the repository at this point in the history
Upgrade to scope-rl==0.1.3
  • Loading branch information
aiueola authored Jul 24, 2023
2 parents f80e55e + 72b16a6 commit 138db9e
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 71 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,13 @@ random_ = EpsilonGreedyHead(
evaluation_policies = [cql_, ddqn_, random_]
# create input for the OPE class
prep = CreateOPEInput(
env=env,
logged_dataset=test_logged_dataset,
use_base_model=True, # use model-based prediction
)
input_dict = prep.obtain_whole_inputs(
logged_dataset=test_logged_dataset,
evaluation_policies=evaluation_policies,
env=env,
require_value_prediction=True,
n_trajectories_on_policy_evaluation=100,
random_state=random_state,
)
Expand Down
4 changes: 2 additions & 2 deletions docs/documentation/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ and Doubly Robust (DR) :cite:`jiang2016doubly, thomas2016data`.
# create input for OPE class
prep = CreateOPEInput(
env=env,
logged_dataset=test_logged_dataset,
use_base_model=True, # use model-based prediction
)
input_dict = prep.obtain_whole_inputs(
logged_dataset=test_logged_dataset,
evaluation_policies=evaluation_policies,
require_value_prediction=True,
n_trajectories_on_policy_evaluation=100,
random_state=random_state,
)
Expand Down
144 changes: 84 additions & 60 deletions scope_rl/ope/ope.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,14 +2063,17 @@ def visualize_policy_value_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = policy_value_dict[behavior_policy][eval_policy][
"on_policy"
Expand Down Expand Up @@ -2178,14 +2181,17 @@ def visualize_policy_value_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down Expand Up @@ -2273,13 +2279,16 @@ def visualize_policy_value_with_multiple_estimates(
ax=ax,
)
else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = policy_value_dict[eval_policy]["on_policy"]
if on_policy is not None:
Expand Down Expand Up @@ -2358,13 +2367,16 @@ def visualize_policy_value_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down Expand Up @@ -5535,14 +5547,17 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = estimation_dict[behavior_policy][eval_policy][
"on_policy"
Expand Down Expand Up @@ -5650,14 +5665,17 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down Expand Up @@ -5745,13 +5763,16 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
ax=ax,
)
else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = estimation_dict[eval_policy]["on_policy"]
if on_policy is not None:
Expand Down Expand Up @@ -5843,13 +5864,16 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down
24 changes: 18 additions & 6 deletions scope_rl/policy/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def __post_init__(self):
raise ValueError("random_state must be given")
self.random_ = check_random_state(self.random_state)

def _calc_pscore(self, greedy_action: np.ndarray, action: np.ndarray):
def calc_action_choice_probability(self, greedy_action: np.ndarray, action: np.ndarray):
"""Calculate pscore.
Parameters
Expand Down Expand Up @@ -723,7 +723,7 @@ def sample_action_and_output_pscore(self, x: np.ndarray):
"""
greedy_action = self.base_policy.predict(x)
action = self.sample_action(x)
pscore = self._calc_pscore(greedy_action, action)
pscore = self.calc_action_choice_probability(greedy_action, action)
return action, pscore

def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray):
Expand All @@ -744,7 +744,7 @@ def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray):
"""
greedy_action = self.base_policy.predict(x)
return self._calc_pscore(greedy_action, action)
return self.calc_action_choice_probability(greedy_action, action)

def sample_action(self, x: np.ndarray):
"""Sample action.
Expand Down Expand Up @@ -844,7 +844,7 @@ def __post_init__(self):
raise ValueError("random_state must be given")
self.random_ = check_random_state(self.random_state)

def _calc_pscore(self, greedy_action: np.ndarray, action: np.ndarray):
def calc_action_choice_probability(self, greedy_action: np.ndarray, action: np.ndarray):
"""Calculate pscore.
Parameters
Expand Down Expand Up @@ -889,7 +889,7 @@ def sample_action_and_output_pscore(self, x: np.ndarray):
"""
greedy_action = self.base_policy.predict(x)
action = self.sample_action(x)
pscore = self._calc_pscore(greedy_action, action)
pscore = self.calc_action_choice_probability(greedy_action, action)
return action, pscore

def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray):
Expand All @@ -910,7 +910,7 @@ def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray):
"""
greedy_action = self.base_policy.predict(x)
return self._calc_pscore(greedy_action, action)
return self.calc_action_choice_probability(greedy_action, action)

def sample_action(self, x: np.ndarray):
"""Sample action.
Expand Down Expand Up @@ -976,6 +976,18 @@ def __post_init__(self):
if not isinstance(self.base_policy, AlgoBase):
raise ValueError("base_policy must be a child class of AlgoBase")

def sample_action_and_output_pscore(self, x: np.ndarray):
"""Placeholder kept only for API consistency.

The body is a bare ``pass``: ``x`` is ignored, no action is sampled,
no pscore is computed, and the call implicitly returns ``None``.
"""
pass

def calc_action_choice_probability(self, x: np.ndarray):
"""Placeholder kept only for API consistency.

The body is a bare ``pass``: ``x`` is ignored and the call implicitly
returns ``None``; no probability is computed.

NOTE(review): other methods with this name in the same change take
``(greedy_action, action)`` rather than ``x`` -- confirm which
signature callers expect.
"""
pass

def calc_pscore_given_action(self, x: np.ndarray, action: np.ndarray):
"""Placeholder kept only for API consistency.

The body is a bare ``pass``: ``x`` and ``action`` are ignored and the
call implicitly returns ``None``; no pscore is computed.
"""
pass

def sample_action(self, x: np.ndarray):
"""Sample action.
Expand Down
2 changes: 1 addition & 1 deletion scope_rl/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved.
# Licensed under the Apache 2.0 License.

__version__ = "0.1.2"
__version__ = "0.1.3"

0 comments on commit 138db9e

Please sign in to comment.