fixed static example, updated metadata and bumped version
joshuaspear committed Oct 9, 2023
1 parent 51df5a4 commit f00c8cf
Showing 3 changed files with 40 additions and 20 deletions.
53 changes: 37 additions & 16 deletions examples/static.py
@@ -1,14 +1,14 @@
-from d3rlpy.algos import DQN
+from d3rlpy.algos import DQNConfig
 import pickle
 from d3rlpy.datasets import get_cartpole
 import numpy as np
 from sklearn.multioutput import MultiOutputClassifier
 from sklearn.multiclass import OneVsRestClassifier
-from d3rlpy.ope import DiscreteFQE
-from d3rlpy.metrics.scorer import (soft_opc_scorer,
-    initial_state_value_estimation_scorer)
+from d3rlpy.ope import DiscreteFQE, FQEConfig
+from d3rlpy.metrics import (SoftOPCEvaluator,
+    InitialStateValueEstimationEvaluator)
+from d3rlpy.dataset import BasicTransitionPicker
 from xgboost import XGBClassifier
-import math
 import torch
 
 from offline_rl_ope.Dataset import ISEpisode
@@ -25,7 +25,7 @@
 
 # setup algorithm
 gamma = 0.99
-dqn = DQN(gamma=gamma, target_update_interval=100)
+dqn = DQNConfig(gamma=gamma, target_update_interval=100).create()
 
 unique_pol_acts = np.arange(0,env.action_space.n)
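Context for the change above: d3rlpy 2.x moved algorithm hyperparameters onto config classes, with .create() building the learner. A minimal sketch of the pattern (the explicit device argument is an illustrative assumption; the commit relies on the default):

from d3rlpy.algos import DQNConfig

# d3rlpy 2.x pattern: hyperparameters sit on a config object and
# .create() instantiates the algorithm (device=False keeps it on CPU).
dqn = DQNConfig(gamma=0.99, target_update_interval=100).create(device=False)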

@@ -54,25 +54,46 @@ def eval_pdf(self, indep_vals:np.array, dep_vals:np.array):
     objective="binary:logistic")))
 
 # Fit the behaviour model
-behav_est.fit(X=dataset.observations, Y=dataset.actions.reshape(-1,1))
+observations = []
+actions = []
+tp = BasicTransitionPicker()
+for ep in dataset.episodes:
+    for i in range(ep.transition_count):
+        _transition = tp(ep,i)
+        observations.append(_transition.observation.reshape(1,-1))
+        actions.append(_transition.action)
+
+observations = np.concatenate(observations)
+actions = np.concatenate(actions)
+
+behav_est.fit(X=observations, Y=actions.reshape(-1,1))
 
 gbt_est = GbtEst(estimator=behav_est)
 gbt_policy_be = BehavPolicy(policy_class=gbt_est, collect_res=False)
 
-dqn.fit(dataset.episodes, n_epochs=1)
+no_obs_steps = int(len(actions)*0.025)
+n_epochs=1
+n_steps_per_epoch = no_obs_steps
+n_steps = no_obs_steps*n_epochs
+dqn.fit(dataset, n_steps=n_steps, n_steps_per_epoch=n_steps_per_epoch,
+        with_timestamp=False)
 
 fqe_scorers = {
-    "soft_opc": soft_opc_scorer(70),
-    "init_state_val": initial_state_value_estimation_scorer
+    "soft_opc": SoftOPCEvaluator(
+        return_threshold=70,
+        episodes=dataset.episodes
+    ),
+    "init_state_val": InitialStateValueEstimationEvaluator(
+        episodes=dataset.episodes
+    )
 }
 
-fqe_init_kwargs = {"use_gpu": False, "discrete_action": True,
-                   "q_func_factory": 'mean', "learning_rate": 1e-4
-                   }
-discrete_fqe = DiscreteFQE(algo=dqn, **fqe_init_kwargs)
 
-discrete_fqe.fit(dataset.episodes, eval_episodes=dataset.episodes,
-                 scorers=fqe_scorers, n_epochs=1)
+fqe_config = FQEConfig(learning_rate=1e-4)
+#discrete_fqe = DiscreteFQE(algo=dqn, **fqe_init_kwargs)
+discrete_fqe = DiscreteFQE(algo=dqn, config=fqe_config, device=False)
+
+discrete_fqe.fit(dataset, evaluators=fqe_scorers, n_steps=no_obs_steps)
 
 
 # Static OPE evaluation
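Taken together, the hunk above migrates the example from the 1.x scorer API to the 2.x evaluator API. A condensed sketch of the migrated flow, assuming d3rlpy >= 2.0.4 (step counts are illustrative placeholders, and the xgboost behaviour-model fitting from the full example is elided):

import numpy as np
from d3rlpy.algos import DQNConfig
from d3rlpy.dataset import BasicTransitionPicker
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics import (SoftOPCEvaluator,
                            InitialStateValueEstimationEvaluator)
from d3rlpy.ope import DiscreteFQE, FQEConfig

dataset, env = get_cartpole()

# v2 datasets no longer expose flat .observations/.actions arrays;
# transitions are materialised per episode via a TransitionPicker.
tp = BasicTransitionPicker()
actions = np.concatenate([
    tp(ep, i).action
    for ep in dataset.episodes
    for i in range(ep.transition_count)
])  # would feed the behaviour model in the full example

dqn = DQNConfig(gamma=0.99, target_update_interval=100).create()
dqn.fit(dataset, n_steps=1000, n_steps_per_epoch=1000, with_timestamp=False)

# Scorer functions become evaluator objects bound to episodes.
evaluators = {
    "soft_opc": SoftOPCEvaluator(return_threshold=70,
                                 episodes=dataset.episodes),
    "init_state_val": InitialStateValueEstimationEvaluator(
        episodes=dataset.episodes),
}

# FQE is configured via FQEConfig; device=False keeps it on CPU.
fqe = DiscreteFQE(algo=dqn, config=FQEConfig(learning_rate=1e-4), device=False)
fqe.fit(dataset, evaluators=evaluators, n_steps=1000)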
5 changes: 2 additions & 3 deletions setup.py
@@ -24,9 +24,8 @@
     license='MIT',
     classifiers=[],
     package_dir={"": "src"},
-    python_requires="",
+    python_requires=">=3.11",
     install_requires=[
-        #"d3rlpy @ git+https://github.com/takuseno/d3rlpy.git"
-        "d3rlpy==2.0.4"
+        "d3rlpy>=2.0.4"
     ],
 )
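Given the tightened python_requires and the relaxed d3rlpy pin, a quick environment check before running the examples might look like this (illustrative only, not part of the commit):

import sys
import importlib.metadata

# Mirrors the new constraints in setup.py: Python >= 3.11 and
# any d3rlpy release from 2.0.4 onwards.
assert sys.version_info >= (3, 11), "offline_rl_ope now requires Python >= 3.11"
print("d3rlpy version:", importlib.metadata.version("d3rlpy"))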
2 changes: 1 addition & 1 deletion src/offline_rl_ope/_version.py
@@ -1 +1 @@
__version__ = "3.0.0"
__version__ = "3.0.1"
