From 3f23e3ba7212e1dc366c8eaf7d6d9d63b41f16f4 Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Sun, 25 Feb 2024 15:08:04 +0000 Subject: [PATCH] added abstract classes for torch propensity trainer --- .../PropensityModels/torch/Trainer.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/offline_rl_ope/PropensityModels/torch/Trainer.py b/src/offline_rl_ope/PropensityModels/torch/Trainer.py index 693c024..4d9f072 100644 --- a/src/offline_rl_ope/PropensityModels/torch/Trainer.py +++ b/src/offline_rl_ope/PropensityModels/torch/Trainer.py @@ -1,5 +1,5 @@ +from abc import abstractmethod import torch -import torch.nn as nn import numpy as np import pickle @@ -49,6 +49,25 @@ def save(self, path:str) -> None: if self.gpu: self.to_gpu() + @abstractmethod + def predict( + self, + x:torch.Tensor, + *args, + **kwargs + ) -> torch.Tensor: + pass + + @abstractmethod + def predict_proba( + self, + x: torch.Tensor, + y: torch.Tensor, + *args, + **kwargs + ) -> torch.Tensor: + pass + class TorchClassTrainer(TorchPropensityTrainer):