diff --git a/src/offline_rl_ope/OPEEstimators/DirectMethod.py b/src/offline_rl_ope/OPEEstimators/DirectMethod.py
index 980a9fb..98d193c 100644
--- a/src/offline_rl_ope/OPEEstimators/DirectMethod.py
+++ b/src/offline_rl_ope/OPEEstimators/DirectMethod.py
@@ -9,11 +9,11 @@ def __init__(self, model: Callable) -> None:
         self.model = model
 
     @abstractmethod
-    def get_v(self, state:torch.Tensor):
+    def get_v(self, state:torch.Tensor) -> torch.Tensor:
         pass
 
     @abstractmethod
-    def get_q(self, state:torch.Tensor, action:torch.Tensor):
+    def get_q(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor:
         pass
 
 
@@ -22,12 +22,12 @@ class D3rlpyQlearnDM(DirectMethodBase):
     def __init__(self, model:QLearningAlgoBase) -> None:
         super().__init__(model=model)
 
-    def get_q(self, state:torch.Tensor, action:torch.Tensor):
+    def get_q(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor:
         values = torch.tensor(self.model.predict_value(
             x=state.numpy(), action=action.numpy()))
         return values
 
-    def get_v(self, state:torch.Tensor):
+    def get_v(self, state:torch.Tensor) -> torch.Tensor:
         state = state.numpy()
         actions = self.model.predict(state)
         values = torch.tensor(self.model.predict_value(
diff --git a/src/offline_rl_ope/OPEEstimators/IS.py b/src/offline_rl_ope/OPEEstimators/IS.py
index 9dc2966..8b207c4 100644
--- a/src/offline_rl_ope/OPEEstimators/IS.py
+++ b/src/offline_rl_ope/OPEEstimators/IS.py
@@ -1,9 +1,12 @@
 import logging
 import torch
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
-from .utils import (WISNormWeights, NormWeightsPass, clip_weights_pass as cwp,
-                    clip_weights as cw)
+from .utils import (
+    WISNormWeights, NormWeightsPass, WeightNorm,
+    clip_weights_pass as cwp,
+    clip_weights as cw
+    )
 from .base import OPEEstimatorBase
 
 logger = logging.getLogger("offline_rl_ope")
@@ -20,9 +23,10 @@ def __init__(
     ) -> None:
         super().__init__(cache_traj_rewards)
         if norm_weights:
-            self.norm_weights = WISNormWeights(**norm_kwargs)
+            _norm_weights = WISNormWeights(**norm_kwargs)
         else:
-            self.norm_weights = NormWeightsPass(**norm_kwargs)
+            _norm_weights = NormWeightsPass(**norm_kwargs)
+        self.norm_weights:WeightNorm = _norm_weights
         self.clip = clip
         if clip_weights:
             self.clip_weights = cw
diff --git a/src/offline_rl_ope/PropensityModels/torch/Trainer.py b/src/offline_rl_ope/PropensityModels/torch/Trainer.py
index 4d9f072..d3d17d4 100644
--- a/src/offline_rl_ope/PropensityModels/torch/Trainer.py
+++ b/src/offline_rl_ope/PropensityModels/torch/Trainer.py
@@ -97,9 +97,9 @@ def predict(
         """
         x = self.input_setup(x)
         self.estimator.eval()
-        res = self.estimator(x)
+        propense_res = self.estimator(x)
         # Take max over values
-        res = torch.argmax(res["out"], dim=1, keepdim=False)
+        res = torch.argmax(propense_res["out"], dim=1, keepdim=False)
         return res
     def predict_proba(
         self,
@@ -152,9 +152,8 @@ def predict(
     ) -> torch.Tensor:
         x = self.input_setup(x)
         self.estimator.eval()
-        res = self.estimator(x)
-        res = res["loc"]
-        return res
+        propense_res = self.estimator(x)
+        return propense_res["loc"]
 
     def predict_proba(
         self,
diff --git a/src/offline_rl_ope/PropensityModels/torch/models/Discrete.py b/src/offline_rl_ope/PropensityModels/torch/models/Discrete.py
index 791a6a1..0bab09a 100644
--- a/src/offline_rl_ope/PropensityModels/torch/models/Discrete.py
+++ b/src/offline_rl_ope/PropensityModels/torch/models/Discrete.py
@@ -36,9 +36,9 @@ def __init__(
     def forward(self, x) -> PropensityTorchOutputType:
         for layer in self.layers:
             x = layer(x)
-        out = []
+        out:List[torch.Tensor] = []
         for head in self.out_layers:
             out_val = head(x)
             out.append(self.out_actvton(out_val)[:,:,None])
-        out = torch.concat(out, dim=2)
-        return {"out": out}
+        res = torch.concat(out, dim=2)
+        return {"out": res}
diff --git a/src/offline_rl_ope/api/d3rlpy/Callbacks/Misc.py b/src/offline_rl_ope/api/d3rlpy/Callbacks/Misc.py
index 65d7dcd..145b524 100644
--- a/src/offline_rl_ope/api/d3rlpy/Callbacks/Misc.py
+++ b/src/offline_rl_ope/api/d3rlpy/Callbacks/Misc.py
@@ -49,7 +49,8 @@ def run(self, algo: QLearningAlgoProtocol, epoch:int, total_step:int):
             "values": total_values,
             "actions":total_actions
         })
-        res = res.groupby(by="actions", as_index=False)["values"].mean()
+        res = res.groupby(by="actions")["values"].mean()
+        res = res.reset_index(drop=False)
         res_dict = {key:val for key,val in zip(res["actions"], res["values"])}
         res_dict = {
             key: (res_dict[key] if key in res_dict.keys() else np.nan)
diff --git a/src/offline_rl_ope/api/d3rlpy/Misc.py b/src/offline_rl_ope/api/d3rlpy/Misc.py
index a2ca478..214db44 100644
--- a/src/offline_rl_ope/api/d3rlpy/Misc.py
+++ b/src/offline_rl_ope/api/d3rlpy/Misc.py
@@ -1,13 +1,13 @@
-from typing import Callable
 import torch
+from .types import D3rlpyAlgoPredictProtocal
 
 __all__ = ["D3RlPyTorchAlgoPredict"]
 
 class D3RlPyTorchAlgoPredict:
 
-    def __init__(self, predict_func:Callable):
+    def __init__(self, predict_func:D3rlpyAlgoPredictProtocal):
         self.predict_func = predict_func
 
     def __call__(self, x:torch.Tensor):
-        pred = self.predict_func(x.cpu().numpy())
+        pred = self.predict_func(x.numpy())
         return torch.Tensor(pred)
diff --git a/src/offline_rl_ope/api/d3rlpy/types.py b/src/offline_rl_ope/api/d3rlpy/types.py
new file mode 100644
index 0000000..83e9db2
--- /dev/null
+++ b/src/offline_rl_ope/api/d3rlpy/types.py
@@ -0,0 +1,11 @@
+from typing import Any, Sequence, Union, Protocol
+import numpy.typing as npt
+
+NDArray = npt.NDArray[Any]
+Observation = Union[NDArray, Sequence[NDArray]]
+
+
+class D3rlpyAlgoPredictProtocal(Protocol):
+
+    def __call__(self, x:Observation) -> NDArray:
+        ...
\ No newline at end of file
diff --git a/src/offline_rl_ope/components/ImportanceSampler.py b/src/offline_rl_ope/components/ImportanceSampler.py
index 6264cc7..8c7a065 100644
--- a/src/offline_rl_ope/components/ImportanceSampler.py
+++ b/src/offline_rl_ope/components/ImportanceSampler.py
@@ -13,9 +13,9 @@ class ISWeightCalculator:
 
     def __init__(self, behav_policy:Policy) -> None:
         self.__behav_policy = behav_policy
-        self.is_weights = torch.empty(0)
-        self.weight_msk = torch.empty(0)
-        self.policy_actions = torch.empty(0)
+        self.is_weights:torch.Tensor = torch.empty(0)
+        self.weight_msk:torch.Tensor = torch.empty(0)
+        self.policy_actions:List[torch.Tensor] = [torch.empty(0)]
 
     def get_traj_w(self, states:torch.Tensor, actions:torch.Tensor,
                    eval_policy:Policy)->torch.Tensor:
diff --git a/src/offline_rl_ope/components/Policy.py b/src/offline_rl_ope/components/Policy.py
index f770e3f..b02769f 100644
--- a/src/offline_rl_ope/components/Policy.py
+++ b/src/offline_rl_ope/components/Policy.py
@@ -13,7 +13,7 @@ def preproc_cuda(x:torch.Tensor)->torch.Tensor:
 
 
 __all__ = [
-    "Policy", "GreedyDeterministic", "BehavPolicy", "LinearMixedPolicy"
+    "Policy", "GreedyDeterministic", "BehavPolicy"
     ]
 
 class Policy(metaclass=ABCMeta):
@@ -33,8 +33,8 @@ def __init__(
             collect_act (bool, optional): _description_. Defaults to False.
         """
         self.policy_func = policy_func
-        self.policy_predictions = []
-        self.policy_actions = []
+        self.policy_predictions:List[torch.Tensor] = []
+        self.policy_actions:List[torch.Tensor] = []
         if collect_res:
             self.collect_res_fn = self.__cllct_res_true
         else:
@@ -125,25 +125,3 @@ def __call__(self, state: torch.Tensor, action: torch.Tensor)->torch.Tensor:
         self.collect_res_fn(res)
         return res
 
-class LinearMixedPolicy:
-
-    def __init__(self, policy_funcs:List[Policy],
-                 mixing_params:torch.Tensor) -> None:
-        if sum(mixing_params) != 1:
-            raise Exception("Mixing params must equal 1")
-        self.__policy_funcs = policy_funcs
-        self.__mixing_params = mixing_params
-        self.__policy_predictions = []
-
-    @property
-    def policy_predictions(self):
-        return self.__policy_predictions
-
-    def __call__(self, state: torch.Tensor, action: torch.Tensor)->torch.Tensor:
-        res = []
-        for pol in self.__policy_funcs:
-            pol_out = pol(state=state, action=action)
-            res.append(pol_out)
-        res = torch.cat(res, dim=1)
-        self.__policy_predictions.append(res)
-        return torch.sum(res*self.__mixing_params, dim=1, keepdim=True)
\ No newline at end of file