Skip to content

Commit

Permalink
added abstract classes for torch propensity trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Feb 25, 2024
1 parent fe3bc26 commit 3f23e3b
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/offline_rl_ope/PropensityModels/torch/Trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
import torch
import torch.nn as nn
import numpy as np
import pickle

Expand Down Expand Up @@ -49,6 +49,25 @@ def save(self, path:str) -> None:
if self.gpu:
self.to_gpu()

@abstractmethod
def predict(
self,
x:torch.Tensor,
*args,
**kwargs
) -> torch.Tensor:
pass

@abstractmethod
def predict_proba(
self,
x: torch.Tensor,
y: torch.Tensor,
*args,
**kwargs
) -> torch.Tensor:
pass



class TorchClassTrainer(TorchPropensityTrainer):
Expand Down

0 comments on commit 3f23e3b

Please sign in to comment.