
Commit

Updated d3rlpy API and examples
joshuaspear committed Jul 25, 2024
1 parent 6800c5b commit b0fc53e
Showing 8 changed files with 113 additions and 48 deletions.
25 changes: 19 additions & 6 deletions examples/d3rlpy_training_api_continuous_stochastic.py
@@ -21,6 +21,10 @@

from offline_rl_ope.PropensityModels.torch import FullGuassian, TorchRegTrainer
from offline_rl_ope.api.d3rlpy.Policy import PolicyFactory
+ from offline_rl_ope.OPEEstimators import (
+ EmpiricalMeanDenom, PassWeightDenom, WeightedEmpiricalMeanDenom
+ )


# Import callbacks
from offline_rl_ope.api.d3rlpy.Callbacks import (
@@ -144,22 +148,31 @@
scorers = {}

scorers.update({"vanilla_is_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="vanilla", norm_weights=False)})
discount=gamma, cache=is_callback, is_type="vanilla",
empirical_denom=EmpiricalMeanDenom(), weight_denom=PassWeightDenom()
)})

scorers.update({"pd_is_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="per_decision",
- norm_weights=False)})
+ empirical_denom=EmpiricalMeanDenom(), weight_denom=PassWeightDenom()
+ )})

scorers.update({"vanilla_wis_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="vanilla", norm_weights=True)})
discount=gamma, cache=is_callback, is_type="vanilla",
empirical_denom=WeightedEmpiricalMeanDenom(),
weight_denom=PassWeightDenom()
)})

scorers.update({"vanilla_wis_loss_smooth": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="vanilla", norm_weights=True,
norm_kwargs={"smooth_eps":0.0000001})})
discount=gamma, cache=is_callback, is_type="vanilla",
empirical_denom=WeightedEmpiricalMeanDenom(smooth_eps=0.0000001),
weight_denom=PassWeightDenom(),
)})

scorers.update({"pd_wis_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="per_decision",
- norm_weights=True)})
+ empirical_denom=EmpiricalMeanDenom(), weight_denom=PassWeightDenom()
+ )})

for scr in fqe_scorers:
scorers.update({scr: QueryScorer(cache=fqe_callback, query_key=scr)})
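The `empirical_denom`/`weight_denom` pair replaces the old `norm_weights`/`norm_kwargs` flags. Assuming the denominator classes behave as their names suggest (the diff itself does not spell this out), `EmpiricalMeanDenom` averages the weighted returns over the number of trajectories (ordinary importance sampling), while `WeightedEmpiricalMeanDenom` divides by the sum of the importance weights, optionally smoothed by `smooth_eps` (weighted/self-normalised importance sampling). A minimal NumPy sketch of the two estimates:

```python
import numpy as np

# Illustrative per-trajectory importance weights and discounted returns;
# these arrays and names are not part of the library's API.
weights = np.array([0.2, 1.7, 0.9, 3.1])
returns = np.array([1.0, 0.5, 2.0, 0.0])
smooth_eps = 1e-7

# Ordinary IS ("vanilla_is_loss"): average over the number of trajectories.
v_is = (weights * returns).sum() / len(returns)

# Weighted IS ("vanilla_wis_loss"): normalise by the weight sum instead,
# with smooth_eps guarding against a zero denominator.
v_wis = (weights * returns).sum() / (weights.sum() + smooth_eps)

print(v_is, v_wis)
```

With this reading, the WIS scorers above differ from the vanilla ones only in the denominator, and `smooth_eps` merely protects against a vanishing weight sum.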
24 changes: 18 additions & 6 deletions examples/d3rlpy_training_api_discrete.py
@@ -14,6 +14,9 @@
import shutil

from offline_rl_ope.api.d3rlpy.Policy import PolicyFactory
+ from offline_rl_ope.OPEEstimators import (
+ EmpiricalMeanDenom, PassWeightDenom, WeightedEmpiricalMeanDenom
+ )

# Import callbacks
from offline_rl_ope.api.d3rlpy.Callbacks import (
@@ -115,22 +118,31 @@
scorers = {}

scorers.update({"vanilla_is_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="vanilla", norm_weights=False)})
discount=gamma, cache=is_callback, is_type="vanilla",
empirical_denom=EmpiricalMeanDenom(), weight_denom=PassWeightDenom()
)})

scorers.update({"pd_is_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="per_decision",
- norm_weights=False)})
+ empirical_denom=EmpiricalMeanDenom(), weight_denom=PassWeightDenom()
+ )})

scorers.update({"vanilla_wis_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="vanilla", norm_weights=True)})
discount=gamma, cache=is_callback, is_type="vanilla",
empirical_denom=WeightedEmpiricalMeanDenom(),
weight_denom=PassWeightDenom()
)})

scorers.update({"vanilla_wis_loss_smooth": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="vanilla", norm_weights=True,
norm_kwargs={"smooth_eps":0.0000001})})
discount=gamma, cache=is_callback, is_type="vanilla",
empirical_denom=WeightedEmpiricalMeanDenom(smooth_eps=0.0000001),
weight_denom=PassWeightDenom(),
)})

scorers.update({"pd_wis_loss": ISEstimatorScorer(
discount=gamma, cache=is_callback, is_type="per_decision",
- norm_weights=True)})
+ empirical_denom=EmpiricalMeanDenom(), weight_denom=PassWeightDenom()
+ )})

for act in unique_pol_acts:
scorers.update(
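For context, these scorer dictionaries are consumed by d3rlpy at training time. A wiring sketch, assuming a d3rlpy 2.x `fit()` that accepts an `evaluators` dict and an `epoch_callback`; `dataset`, `is_callback` and `scorers` stand for the objects built earlier in this example, and the algorithm choice is arbitrary:

```python
from d3rlpy.algos import DiscreteCQLConfig

# Sketch only: the surrounding example script defines `dataset`, `is_callback`
# (which recomputes the cached importance weights each epoch) and `scorers`.
algo = DiscreteCQLConfig().create(device="cpu")

algo.fit(
    dataset,
    n_steps=10_000,
    n_steps_per_epoch=1_000,
    evaluators=scorers,          # ISEstimatorScorer / QueryScorer entries
    epoch_callback=is_callback,  # refresh cached weights before scoring
)
```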
31 changes: 22 additions & 9 deletions examples/static_torch_deterministic_continuous.py
@@ -16,8 +16,9 @@
from offline_rl_ope.Dataset import ISEpisode
from offline_rl_ope.components.Policy import Policy, GreedyDeterministic
from offline_rl_ope.components.ImportanceSampler import ISWeightOrchestrator
- from offline_rl_ope.OPEEstimators import (
- ISEstimator, WDR, D3rlpyQlearnDM)
+ from offline_rl_ope.OPEEstimators import D3rlpyQlearnDM
+ from offline_rl_ope.api.StandardEstimators import (
+ VanillaISPDIS, WIS, WDR)
from offline_rl_ope.PropensityModels.torch import FullGuassian, TorchRegTrainer
from offline_rl_ope.LowerBounds.HCOPE import get_lower_bound

@@ -171,14 +172,16 @@ def __call__(

fqe_dm_model = D3rlpyQlearnDM(model=fqe)

- is_estimator = ISEstimator(norm_weights=False, cache_traj_rewards=True)
- wis_estimator = ISEstimator(norm_weights=True)
- wis_estimator_smooth = ISEstimator(norm_weights=True, norm_kwargs={
- "smooth_eps":0.0000001
- })
+ is_estimator = VanillaISPDIS(cache_traj_rewards=True)
+ wis_estimator = WIS()
+ wis_estimator_smooth = WIS(smooth_eps=0.0000001)
w_dr_estimator = WDR(
- dm_model=fqe_dm_model,
+ dm_model=fqe_dm_model
)
+ w_dr_estimator_smooth = WDR(
+ dm_model=fqe_dm_model, smooth_eps=0.0000001
+ )



res = is_estimator.predict(
@@ -229,11 +232,21 @@
print(res)

res = w_dr_estimator.predict(
weights=is_weight_calculator["per_decision"].traj_is_weights,
weights=is_weight_calculator["vanilla"].traj_is_weights,
rewards=[ep.reward for ep in episodes], discount=0.99,
is_msk=is_weight_calculator.weight_msk,
states=[ep.state for ep in episodes],
actions=[ep.action for ep in episodes],
)
print(res)

+ res = w_dr_estimator_smooth.predict(
+ weights=is_weight_calculator["vanilla"].traj_is_weights,
+ rewards=[ep.reward for ep in episodes], discount=0.99,
+ is_msk=is_weight_calculator.weight_msk,
+ states=[ep.state for ep in episodes],
+ actions=[ep.action for ep in episodes],
+ )
+ print(res)
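The switch from `ISEstimator(norm_weights=...)` to the `StandardEstimators` convenience classes mirrors the scorer change above. A sketch of the presumed correspondence, using the denominator keywords this commit introduces elsewhere (the equivalence itself is an assumption, not something the diff states):

```python
from offline_rl_ope.OPEEstimators import (
    ISEstimator, EmpiricalMeanDenom, PassWeightDenom, WeightedEmpiricalMeanDenom
)
from offline_rl_ope.api.StandardEstimators import VanillaISPDIS, WIS

# Convenience classes, as used in the updated example above.
is_estimator = VanillaISPDIS(cache_traj_rewards=True)
wis_estimator = WIS(smooth_eps=1e-7)

# Presumably equivalent explicit configurations (assumption).
explicit_is = ISEstimator(
    empirical_denom=EmpiricalMeanDenom(),
    weight_denom=PassWeightDenom(),
)
explicit_wis = ISEstimator(
    empirical_denom=WeightedEmpiricalMeanDenom(smooth_eps=1e-7),
    weight_denom=PassWeightDenom(),
)
```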


27 changes: 19 additions & 8 deletions examples/static_torch_stochastic_continuous.py
@@ -16,8 +16,9 @@
from offline_rl_ope.Dataset import ISEpisode
from offline_rl_ope.components.Policy import Policy
from offline_rl_ope.components.ImportanceSampler import ISWeightOrchestrator
- from offline_rl_ope.OPEEstimators import (
- ISEstimator, WDR, D3rlpyQlearnDM)
+ from offline_rl_ope.OPEEstimators import D3rlpyQlearnDM
+ from offline_rl_ope.api.StandardEstimators import (
+ VanillaISPDIS, WIS, WDR)
from offline_rl_ope.PropensityModels.torch import FullGuassian, TorchRegTrainer
from offline_rl_ope.LowerBounds.HCOPE import get_lower_bound

@@ -153,14 +154,15 @@

fqe_dm_model = D3rlpyQlearnDM(model=fqe)

- is_estimator = ISEstimator(norm_weights=False, cache_traj_rewards=True)
- wis_estimator = ISEstimator(norm_weights=True)
- wis_estimator_smooth = ISEstimator(norm_weights=True, norm_kwargs={
- "smooth_eps":0.0000001
- })
+ is_estimator = VanillaISPDIS(cache_traj_rewards=True)
+ wis_estimator = WIS()
+ wis_estimator_smooth = WIS(smooth_eps=0.0000001)
w_dr_estimator = WDR(
dm_model=fqe_dm_model
)
+ w_dr_estimator_smooth = WDR(
+ dm_model=fqe_dm_model, smooth_eps=0.0000001
+ )


res = is_estimator.predict(
@@ -211,7 +213,16 @@
print(res)

res = w_dr_estimator.predict(
weights=is_weight_calculator["per_decision"].traj_is_weights,
weights=is_weight_calculator["vanilla"].traj_is_weights,
rewards=[ep.reward for ep in episodes], discount=0.99,
is_msk=is_weight_calculator.weight_msk,
states=[ep.state for ep in episodes],
actions=[ep.action for ep in episodes],
)
print(res)

+ res = w_dr_estimator_smooth.predict(
+ weights=is_weight_calculator["vanilla"].traj_is_weights,
rewards=[ep.reward for ep in episodes], discount=0.99,
is_msk=is_weight_calculator.weight_msk,
states=[ep.state for ep in episodes],
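Note that the WDR calls now consume the "vanilla" weights where they previously used "per_decision". For readers unfamiliar with the distinction, here is the standard definition of the two weight types (the orchestrator's internal layout may differ):

```python
import numpy as np

# Per-step policy ratios pi_e(a_t | s_t) / pi_b(a_t | s_t) for one trajectory;
# the numbers are illustrative only.
ratios = np.array([1.2, 0.8, 1.5, 0.9])

# Per-decision weights: cumulative product of ratios up to each timestep.
w_per_decision = np.cumprod(ratios)               # [1.2, 0.96, 1.44, 1.296]

# Vanilla (trajectory-level) weights: the full-trajectory product at every step.
w_vanilla = np.full_like(ratios, ratios.prod())   # [1.296, 1.296, 1.296, 1.296]

print(w_per_decision, w_vanilla)
```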
28 changes: 19 additions & 9 deletions examples/static_xgboost_discrete.py
@@ -15,8 +15,9 @@
from offline_rl_ope.components.Policy import (
GreedyDeterministic, Policy, NumpyPolicyFuncWrapper)
from offline_rl_ope.components.ImportanceSampler import ISWeightOrchestrator
- from offline_rl_ope.OPEEstimators import (
- ISEstimator, WDR, D3rlpyQlearnDM)
+ from offline_rl_ope.OPEEstimators import D3rlpyQlearnDM
+ from offline_rl_ope.api.StandardEstimators import (
+ VanillaISPDIS, WIS, WDR)
from offline_rl_ope.PropensityModels.sklearn import (
SklearnDiscrete)
from offline_rl_ope.LowerBounds.HCOPE import get_lower_bound
@@ -118,15 +119,15 @@

fqe_dm_model = D3rlpyQlearnDM(model=discrete_fqe)

- is_estimator = ISEstimator(norm_weights=False, cache_traj_rewards=True)
- wis_estimator = ISEstimator(norm_weights=True)
- wis_estimator_smooth = ISEstimator(norm_weights=True, norm_kwargs={
- "smooth_eps":0.0000001
- })
+ is_estimator = VanillaISPDIS(cache_traj_rewards=True)
+ wis_estimator = WIS()
+ wis_estimator_smooth = WIS(smooth_eps=0.0000001)
w_dr_estimator = WDR(
dm_model=fqe_dm_model
)

+ w_dr_estimator_smooth = WDR(
+ dm_model=fqe_dm_model, smooth_eps=0.0000001
+ )

res = is_estimator.predict(
rewards=[ep.reward for ep in episodes], discount=0.99,
@@ -176,7 +177,16 @@
print(res)

res = w_dr_estimator.predict(
weights=is_weight_calculator["per_decision"].traj_is_weights,
weights=is_weight_calculator["vanilla"].traj_is_weights,
rewards=[ep.reward for ep in episodes], discount=0.99,
is_msk=is_weight_calculator.weight_msk,
states=[ep.state for ep in episodes],
actions=[ep.action for ep in episodes],
)
print(res)

+ res = w_dr_estimator_smooth.predict(
+ weights=is_weight_calculator["vanilla"].traj_is_weights,
rewards=[ep.reward for ep in episodes], discount=0.99,
is_msk=is_weight_calculator.weight_msk,
states=[ep.state for ep in episodes],
1 change: 0 additions & 1 deletion src/offline_rl_ope/OPEEstimators/DoublyRobust.py
@@ -85,7 +85,6 @@ def predict_traj_rewards(
rewards=rewards, discount=discount, h=h)
# weights dim is (n_trajectories, max_length)
weights = self.process_weights(weights=weights, is_msk=is_msk)
print(f"weights:{weights}")
v:List[Float[torch.Tensor, "max_length 1"]] = []
q:List[Float[torch.Tensor, "max_length 1"]] = []
for s,a in zip(states, actions):
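For reference, the quantity `predict_traj_rewards` works towards is the step-wise doubly robust value estimate, which combines the importance weights with the FQE-based Q and V estimates. A single-trajectory sketch of the textbook formula (Jiang & Li, 2016); the WDR variant additionally normalises each weight by its per-timestep average across trajectories, and the library's tensorised implementation may differ in detail:

```python
def dr_trajectory_value(rewards, weights, q_hat, v_hat, discount):
    """Per-decision doubly robust estimate for one trajectory (textbook form).

    weights[t] is the cumulative importance weight up to step t, while q_hat[t]
    and v_hat[t] are Q(s_t, a_t) and V(s_t) from a direct model such as FQE.
    """
    total, w_prev = 0.0, 1.0
    for t, (r, w, q, v) in enumerate(zip(rewards, weights, q_hat, v_hat)):
        total += (discount ** t) * (w * r - w * q + w_prev * v)
        w_prev = w
    return total

# Illustrative numbers only.
print(dr_trajectory_value(
    rewards=[1.0, 0.0, 2.0],
    weights=[1.2, 0.96, 1.44],
    q_hat=[0.8, 0.5, 1.5],
    v_hat=[0.9, 0.6, 1.4],
    discount=0.99,
))
```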
1 change: 0 additions & 1 deletion src/offline_rl_ope/api/StandardEstimators/IS.py
@@ -95,7 +95,6 @@ def __init__(
smooth_eps:float = 0.0
) -> None:
"""_summary_
- https://arxiv.org/pdf/1906.03735 (snis when weights are IS)
- https://arxiv.org/pdf/1906.03735 (snsis when weights are PD)
Args:
clip_weights (bool, optional): _description_. Defaults to False.
24 changes: 16 additions & 8 deletions src/offline_rl_ope/api/d3rlpy/Scorers/IS.py
@@ -6,7 +6,9 @@
from d3rlpy.metrics import EvaluatorProtocol
from d3rlpy.dataset import ReplayBuffer

- from ....OPEEstimators import ISEstimator
+ from ....OPEEstimators import (
+ ISEstimator, WeightDenomBase, EmpiricalMeanDenomBase
+ )
from .base import OPEEstimatorScorerBase
from ..Callbacks.IS import ISCallback

@@ -19,18 +21,24 @@

class ISEstimatorScorer(OPEEstimatorScorerBase, ISEstimator):

- def __init__(self, discount, cache:ISCallback, is_type:str,
- norm_weights: bool, clip_weights:bool=False,
- clip: float = 0.0, norm_kwargs:Dict[str,Any] = {},
- episodes:Optional[Sequence[EpisodeBase]] = None
- ) -> None:
+ def __init__(
+ self,
+ discount:float,
+ cache:ISCallback,
+ is_type:str,
+ empirical_denom: WeightDenomBase,
+ weight_denom: EmpiricalMeanDenomBase,
+ clip_weights:bool=False,
+ clip: float = 0.0,
+ episodes:Optional[Sequence[EpisodeBase]] = None
+ ) -> None:
OPEEstimatorScorerBase.__init__(self, cache=cache, episodes=episodes)
ISEstimator.__init__(
self,
- norm_weights=norm_weights,
+ empirical_denom=empirical_denom,
+ weight_denom=weight_denom,
clip_weights=clip_weights,
clip=clip,
- norm_kwargs=norm_kwargs
)
self.is_type = is_type
self.discount = discount
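One detail worth flagging in the new signature: the type hints look transposed relative to the argument names. `empirical_denom` is annotated `WeightDenomBase` and `weight_denom` is annotated `EmpiricalMeanDenomBase`, whereas the updated examples pass `EmpiricalMeanDenom`/`WeightedEmpiricalMeanDenom` as `empirical_denom` and `PassWeightDenom` as `weight_denom`. Presumably the intended hints are `empirical_denom: EmpiricalMeanDenomBase` and `weight_denom: WeightDenomBase`; runtime behaviour is unaffected either way.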
