Added type checking and removed LinearMixedPolicy, since it was outdated
joshuaspear committed Feb 26, 2024
1 parent 3f23e3b commit ab9cdd0
Showing 9 changed files with 42 additions and 49 deletions.
8 changes: 4 additions & 4 deletions src/offline_rl_ope/OPEEstimators/DirectMethod.py
@@ -9,11 +9,11 @@ def __init__(self, model: Callable) -> None:
self.model = model

@abstractmethod
def get_v(self, state:torch.Tensor):
def get_v(self, state:torch.Tensor) -> torch.Tensor:
pass

@abstractmethod
def get_q(self, state:torch.Tensor, action:torch.Tensor):
def get_q(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor:
pass


@@ -22,12 +22,12 @@ class D3rlpyQlearnDM(DirectMethodBase):
def __init__(self, model:QLearningAlgoBase) -> None:
super().__init__(model=model)

def get_q(self, state:torch.Tensor, action:torch.Tensor):
def get_q(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor:
values = torch.tensor(self.model.predict_value(
x=state.numpy(), action=action.numpy()))
return values

def get_v(self, state:torch.Tensor):
def get_v(self, state:torch.Tensor) -> torch.Tensor:
state = state.numpy()
actions = self.model.predict(state)
values = torch.tensor(self.model.predict_value(
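For context, a minimal sketch (not from the repository) of how the annotated interface above can be subclassed; the import path is assumed from the src/ layout and the zero-valued estimates are purely illustrative:

import torch

from offline_rl_ope.OPEEstimators.DirectMethod import DirectMethodBase  # path assumed

class ZeroDM(DirectMethodBase):
    """Toy direct-method model that returns zero value estimates."""

    def __init__(self) -> None:
        # DirectMethodBase only stores the model, so any callable will do here
        super().__init__(model=lambda x: x)

    def get_v(self, state: torch.Tensor) -> torch.Tensor:
        # one state-value per state in the batch
        return torch.zeros(state.shape[0])

    def get_q(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # one Q-value per (state, action) pair in the batch
        return torch.zeros(state.shape[0])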
14 changes: 9 additions & 5 deletions src/offline_rl_ope/OPEEstimators/IS.py
@@ -1,9 +1,12 @@
import logging
import torch
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from .utils import (WISNormWeights, NormWeightsPass, clip_weights_pass as cwp,
clip_weights as cw)
from .utils import (
WISNormWeights, NormWeightsPass, WeightNorm,
clip_weights_pass as cwp,
clip_weights as cw
)
from .base import OPEEstimatorBase

logger = logging.getLogger("offline_rl_ope")
@@ -20,9 +23,10 @@ def __init__(
) -> None:
super().__init__(cache_traj_rewards)
if norm_weights:
self.norm_weights = WISNormWeights(**norm_kwargs)
_norm_weights = WISNormWeights(**norm_kwargs)
else:
self.norm_weights = NormWeightsPass(**norm_kwargs)
_norm_weights = NormWeightsPass(**norm_kwargs)
self.norm_weights:WeightNorm = _norm_weights
self.clip = clip
if clip_weights:
self.clip_weights = cw
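The point of routing both branches through a single annotated attribute is that a type checker sees one declared type rather than inferring it from whichever branch happens to be assigned first. A self-contained sketch of the pattern; the classes below are simplified stand-ins, not the library's actual implementations:

from abc import ABC, abstractmethod
import torch

class WeightNorm(ABC):
    # assumed common base of the two weight-normalisation strategies
    @abstractmethod
    def __call__(self, weights: torch.Tensor) -> torch.Tensor: ...

class WISNormWeights(WeightNorm):
    def __call__(self, weights: torch.Tensor) -> torch.Tensor:
        return weights / weights.sum()   # illustrative WIS-style normalisation

class NormWeightsPass(WeightNorm):
    def __call__(self, weights: torch.Tensor) -> torch.Tensor:
        return weights                   # pass-through, no normalisation

def build_norm(norm_weights: bool) -> WeightNorm:
    # same shape as the __init__ above: assign to a temporary, then annotate
    _norm_weights: WeightNorm
    if norm_weights:
        _norm_weights = WISNormWeights()
    else:
        _norm_weights = NormWeightsPass()
    return _norm_weights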
9 changes: 4 additions & 5 deletions src/offline_rl_ope/PropensityModels/torch/Trainer.py
@@ -97,9 +97,9 @@ def predict(
"""
x = self.input_setup(x)
self.estimator.eval()
res = self.estimator(x)
propense_res = self.estimator(x)
# Take max over values
res = torch.argmax(res["out"], dim=1, keepdim=False)
res = torch.argmax(propense_res["out"], dim=1, keepdim=False)
return res

def predict_proba(
@@ -152,9 +152,8 @@ def predict(
) -> torch.Tensor:
x = self.input_setup(x)
self.estimator.eval()
res = self.estimator(x)
res = res["loc"]
return res
propense_res = self.estimator(x)
return propense_res["loc"]

def predict_proba(
self,
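The rename keeps the estimator's dict output and the derived tensor in separate variables, so each name has a single, stable type. A toy illustration with random data, independent of the trainer classes:

import torch

propense_res = {"out": torch.rand(5, 3)}                        # stands in for self.estimator(x)
res = torch.argmax(propense_res["out"], dim=1, keepdim=False)   # shape (5,): index of max per row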
6 changes: 3 additions & 3 deletions src/offline_rl_ope/PropensityModels/torch/models/Discrete.py
@@ -36,9 +36,9 @@ def __init__(
def forward(self, x) -> PropensityTorchOutputType:
for layer in self.layers:
x = layer(x)
out = []
out:List[torch.Tensor] = []
for head in self.out_layers:
out_val = head(x)
out.append(self.out_actvton(out_val)[:,:,None])
out = torch.concat(out, dim=2)
return {"out": out}
res = torch.concat(out, dim=2)
return {"out": res}
3 changes: 2 additions & 1 deletion src/offline_rl_ope/api/d3rlpy/Callbacks/Misc.py
@@ -49,7 +49,8 @@ def run(self, algo: QLearningAlgoProtocol, epoch:int, total_step:int):
"values": total_values,
"actions":total_actions
})
res = res.groupby(by="actions", as_index=False)["values"].mean()
res = res.groupby(by="actions")["values"].mean()
res = res.reset_index(drop=False)
res_dict = {key:val for key,val in zip(res["actions"], res["values"])}
res_dict = {
key: (res_dict[key] if key in res_dict.keys() else np.nan)
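The two-step groupby produces the same columns as the previous as_index=False call; splitting it just makes the intermediate Series explicit. A quick check on toy data:

import pandas as pd

res = pd.DataFrame({"values": [1.0, 2.0, 3.0], "actions": [0, 1, 0]})
res = res.groupby(by="actions")["values"].mean()   # Series indexed by action
res = res.reset_index(drop=False)                  # back to columns: actions, values
res_dict = {key: val for key, val in zip(res["actions"], res["values"])}
# res_dict == {0: 2.0, 1: 2.0}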
6 changes: 3 additions & 3 deletions src/offline_rl_ope/api/d3rlpy/Misc.py
@@ -1,13 +1,13 @@
from typing import Callable
import torch
from .types import D3rlpyAlgoPredictProtocal

__all__ = ["D3RlPyTorchAlgoPredict"]

class D3RlPyTorchAlgoPredict:

def __init__(self, predict_func:Callable):
def __init__(self, predict_func:D3rlpyAlgoPredictProtocal):
self.predict_func = predict_func

def __call__(self, x:torch.Tensor):
pred = self.predict_func(x.cpu().numpy())
pred = self.predict_func(x.numpy())
return torch.Tensor(pred)
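A hedged usage sketch of the wrapper above; the import path is assumed from the file layout, and fake_predict merely stands in for a d3rlpy algorithm's predict method:

import numpy as np
import torch

from offline_rl_ope.api.d3rlpy.Misc import D3RlPyTorchAlgoPredict  # path assumed

def fake_predict(x):
    # stand-in for algo.predict: NDArray in, NDArray of greedy actions out
    return np.zeros(len(x), dtype=np.int64)

wrapper = D3RlPyTorchAlgoPredict(predict_func=fake_predict)
actions = wrapper(torch.rand(8, 4))   # torch.Tensor in, torch.Tensor out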
11 changes: 11 additions & 0 deletions src/offline_rl_ope/api/d3rlpy/types.py
@@ -0,0 +1,11 @@
from typing import Any, Sequence, Union, Protocol
import numpy.typing as npt

NDArray = npt.NDArray[Any]
Observation = Union[NDArray, Sequence[NDArray]]


class D3rlpyAlgoPredictProtocal(Protocol):

def __call__(self, x:Observation) -> NDArray:
...
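Because D3rlpyAlgoPredictProtocal is a structural Protocol, any callable with a matching signature satisfies it without inheriting from it. A minimal sketch (import path assumed from the file layout):

import numpy as np

from offline_rl_ope.api.d3rlpy.types import (   # path assumed
    D3rlpyAlgoPredictProtocal, NDArray, Observation
)

def constant_predict(x: Observation) -> NDArray:
    # always returns action 0 for each observation in the batch
    return np.zeros(len(x), dtype=np.int64)

predictor: D3rlpyAlgoPredictProtocal = constant_predict   # accepted by a type checker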
6 changes: 3 additions & 3 deletions src/offline_rl_ope/components/ImportanceSampler.py
@@ -13,9 +13,9 @@
class ISWeightCalculator:
def __init__(self, behav_policy:Policy) -> None:
self.__behav_policy = behav_policy
self.is_weights = torch.empty(0)
self.weight_msk = torch.empty(0)
self.policy_actions = torch.empty(0)
self.is_weights:torch.Tensor = torch.empty(0)
self.weight_msk:torch.Tensor = torch.empty(0)
self.policy_actions:List[torch.Tensor] = [torch.empty(0)]

def get_traj_w(self, states:torch.Tensor, actions:torch.Tensor,
eval_policy:Policy)->torch.Tensor:
28 changes: 3 additions & 25 deletions src/offline_rl_ope/components/Policy.py
@@ -13,7 +13,7 @@ def preproc_cuda(x:torch.Tensor)->torch.Tensor:


__all__ = [
"Policy", "GreedyDeterministic", "BehavPolicy", "LinearMixedPolicy"
"Policy", "GreedyDeterministic", "BehavPolicy"
]

class Policy(metaclass=ABCMeta):
@@ -33,8 +33,8 @@ def __init__(
collect_act (bool, optional): _description_. Defaults to False.
"""
self.policy_func = policy_func
self.policy_predictions = []
self.policy_actions = []
self.policy_predictions:List[torch.Tensor] = []
self.policy_actions:List[torch.Tensor] = []
if collect_res:
self.collect_res_fn = self.__cllct_res_true
else:
@@ -125,25 +125,3 @@ def __call__(self, state: torch.Tensor, action: torch.Tensor)->torch.Tensor:
self.collect_res_fn(res)
return res

class LinearMixedPolicy:

def __init__(self, policy_funcs:List[Policy],
mixing_params:torch.Tensor) -> None:
if sum(mixing_params) != 1:
raise Exception("Mixing params must equal 1")
self.__policy_funcs = policy_funcs
self.__mixing_params = mixing_params
self.__policy_predictions = []

@property
def policy_predictions(self):
return self.__policy_predictions

def __call__(self, state: torch.Tensor, action: torch.Tensor)->torch.Tensor:
res = []
for pol in self.__policy_funcs:
pol_out = pol(state=state, action=action)
res.append(pol_out)
res = torch.cat(res, dim=1)
self.__policy_predictions.append(res)
return torch.sum(res*self.__mixing_params, dim=1, keepdim=True)
