
Add pytorch policy #31

Merged: 18 commits, Nov 28, 2023
258 changes: 258 additions & 0 deletions notebooks/news_recommendation_byom.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion setup.py
@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
import os

with open("README.md", "r", encoding="UTF-8") as fh:
long_description = fh.read()
17 changes: 12 additions & 5 deletions src/learn_to_pick/__init__.py
@@ -5,12 +5,9 @@
BasedOn,
Embed,
Featurizer,
ModelRepository,
Policy,
SelectionScorer,
ToSelectFrom,
VwPolicy,
VwLogger,
embed,
)
from learn_to_pick.pick_best import (
@@ -22,6 +19,14 @@
)


from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger

from learn_to_pick.pytorch.policy import PyTorchPolicy
from learn_to_pick.pytorch.feature_embedder import PyTorchFeatureEmbedder


def configure_logger() -> None:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -48,9 +53,11 @@ def configure_logger() -> None:
"SelectionScorer",
"AutoSelectionScorer",
"Featurizer",
"ModelRepository",
"Policy",
"PyTorchPolicy",
"PyTorchFeatureEmbedder",
"embed",
"ModelRepository",
"VwPolicy",
"VwLogger",
"embed",
]
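
With this change both policy families resolve from the package root, so downstream code keeps a single import path. A minimal sketch of the intended imports (names taken from the `__all__` list above):

```python
# Both the VW and PyTorch stacks are re-exported at the package root.
from learn_to_pick import (
    ModelRepository,
    PyTorchFeatureEmbedder,
    PyTorchPolicy,
    VwLogger,
    VwPolicy,
)
```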
53 changes: 1 addition & 52 deletions src/learn_to_pick/base.py
@@ -10,15 +10,12 @@
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
Callable,
)

from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger

from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

@@ -89,10 +86,6 @@ def EmbedAndKeep(anything: Any) -> Any:
# helper functions


def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
return [parser.parse_line(line) for line in input_str.split("\n")]


def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
return {
k: v.value
@@ -144,50 +137,6 @@ def save(self) -> None:
pass


class VwPolicy(Policy):
def __init__(
self,
model_repo: ModelRepository,
vw_cmd: List[str],
featurizer: Featurizer,
formatter: Callable,
vw_logger: VwLogger,
**kwargs: Any,
):
super().__init__(**kwargs)
self.model_repo = model_repo
self.vw_cmd = vw_cmd
self.workspace = self.model_repo.load(vw_cmd)
self.featurizer = featurizer
self.formatter = formatter
self.vw_logger = vw_logger

def format(self, event):
return self.formatter(*self.featurizer.featurize(event))

def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw

text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(_parse_lines(text_parser, self.format(event)))

def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw

vw_ex = self.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = _parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)

def log(self, event: TEvent) -> None:
if self.vw_logger.logging_enabled():
vw_ex = self.format(event)
self.vw_logger.log(vw_ex)

def save(self) -> None:
self.model_repo.save(self.workspace)


class Featurizer(Generic[TEvent], ABC):
def __init__(self, *args: Any, **kwargs: Any):
pass
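With VwPolicy relocated, base.py now carries only the abstract Policy surface. A rough sketch of a custom policy plugging into that surface (method names mirror the relocated VwPolicy above; the prediction format shown is an assumption, and save() already defaults to a no-op per the hunk header):

```python
from typing import Any
from learn_to_pick import base


class ConstantPolicy(base.Policy):
    # Illustrative stub: always proposes the first candidate.

    def predict(self, event) -> Any:
        # Assumed prediction format: (action index, score) pairs.
        return [(0, 1.0)]

    def learn(self, event) -> None:
        pass  # nothing to update

    def log(self, event) -> None:
        pass  # nothing to record
```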
10 changes: 7 additions & 3 deletions src/learn_to_pick/pick_best.py
@@ -7,6 +7,10 @@
import numpy as np

from learn_to_pick import base
from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger


logger = logging.getLogger(__name__)

@@ -333,14 +337,14 @@ def create_policy(

vw_cmd = interactions + vw_cmd

return base.VwPolicy(
model_repo=base.ModelRepository(
return VwPolicy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd,
featurizer=featurizer,
formatter=formatter,
vw_logger=base.VwLogger(rl_logs),
vw_logger=VwLogger(rl_logs),
)

def _default_policy(self):
Empty file.
50 changes: 50 additions & 0 deletions src/learn_to_pick/pytorch/feature_embedder.py
@@ -0,0 +1,50 @@
from sentence_transformers import SentenceTransformer
import torch
from learn_to_pick import PickBestFeaturizer


class PyTorchFeatureEmbedder:
def __init__(self, model=None, *args, **kwargs):
if model is None:
model = SentenceTransformer("all-MiniLM-L6-v2")

self.model = model
self.featurizer = PickBestFeaturizer(auto_embed=False)

def encode(self, stuff):
embeddings = self.model.encode(stuff, convert_to_tensor=True)
normalized = torch.nn.functional.normalize(embeddings)
return normalized

def convert_features_to_text(self, sparse_features):
results = []
for ns, obj in sparse_features.items():
value = obj.get("default_ft", "")
results.append(f"{ns}={value}")
return " ".join(results)

def format(self, event):
# TODO: handle dense
[Inline review comment from a Collaborator: maybe throw a not yet supported message]
context_featurized, actions_featurized, selected = self.featurizer.featurize(
event
)

context_sparse = self.encode(
[self.convert_features_to_text(context_featurized.sparse)]
)

actions_sparse = []
for action_featurized in actions_featurized:
actions_sparse.append(
self.convert_features_to_text(action_featurized.sparse)
)
actions_sparse = self.encode(actions_sparse).unsqueeze(0)

if selected.score is not None:
return (
torch.Tensor([[selected.score]]),
context_sparse,
actions_sparse[:, selected.index, :].unsqueeze(1),
)
else:
return context_sparse, actions_sparse
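
The embedder flattens each sparse namespace into a ns=value token string and leaves all real featurization to the sentence transformer. A quick sketch of the two self-contained helpers (the namespaces and values are made up):

```python
embedder = PyTorchFeatureEmbedder()  # defaults to all-MiniLM-L6-v2

text = embedder.convert_features_to_text(
    {"user": {"default_ft": "sports_fan"}, "time": {"default_ft": "evening"}}
)
# -> "user=sports_fan time=evening"

vec = embedder.encode([text])
# unit-normalized tensor of shape (1, 384) for all-MiniLM-L6-v2
```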
19 changes: 19 additions & 0 deletions src/learn_to_pick/pytorch/igw.py
@@ -0,0 +1,19 @@
import torch


def IGW(fhat, gamma):
from math import sqrt

fhatahat, ahat = fhat.max(dim=1)
A = fhat.shape[1]
gamma *= sqrt(A)
p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
sump = p.sum(dim=1)
p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None)
return torch.multinomial(p, num_samples=1).squeeze(1), ahat


def SamplingIGW(A, P, gamma):
exploreind, _ = IGW(P, gamma)
explore = [ind for _, ind in zip(A, exploreind)]
return explore
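
IGW here is inverse-gap weighting: each arm is sampled with probability inversely proportional to its estimated reward gap against the greedy arm, p[a] = 1 / (A + gamma * sqrt(A) * (fhat[ahat] - fhat[a])), with any leftover mass assigned to the greedy arm ahat (note the implementation folds sqrt(A) into gamma before use). SamplingIGW then returns one sampled index per context, zipped against the action list. A small numeric check (values are illustrative):

```python
import torch

# Two contexts, three candidate actions; fhat holds estimated rewards.
fhat = torch.tensor([[0.9, 0.1, 0.5],
                     [0.2, 0.8, 0.7]])
sampled, greedy = IGW(fhat, gamma=10.0)
# greedy == tensor([0, 1]): the argmax arm in each row.
# sampled follows the IGW distribution: an arm with a small gap to the
# greedy one (0.7 vs 0.8) keeps noticeable probability, while a clearly
# worse arm (0.1 vs 0.9) is rarely explored.
```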
90 changes: 90 additions & 0 deletions src/learn_to_pick/pytorch/logistic_regression.py
@@ -0,0 +1,90 @@
import parameterfree
import torch
import torch.nn.functional as F


class MLP(torch.nn.Module):
@staticmethod
def new_gelu(x):
import math

return (
0.5
* x
* (
1.0
+ torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
)
)
)

def __init__(self, dim):
super().__init__()
self.c_fc = torch.nn.Linear(dim, 4 * dim)
self.c_proj = torch.nn.Linear(4 * dim, dim)
self.dropout = torch.nn.Dropout(0.5)

def forward(self, x):
x = self.c_fc(x)
x = self.new_gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x


class Block(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.layer = MLP(dim)

def forward(self, x):
return x + self.layer(x)


class ResidualLogisticRegressor(torch.nn.Module):
def __init__(self, in_features, depth, device):
super().__init__()
self._in_features = in_features
self._depth = depth
self.blocks = torch.nn.Sequential(*[Block(in_features) for _ in range(depth)])
self.linear = torch.nn.Linear(in_features=in_features, out_features=1)
self.optim = parameterfree.COCOB(self.parameters())
self._device = device

def clone(self):
other = ResidualLogisticRegressor(self._in_features, self._depth, self._device)
other.load_state_dict(self.state_dict())
other.optim = parameterfree.COCOB(other.parameters())
other.optim.load_state_dict(self.optim.state_dict())
return other

def forward(self, X, A):
return self.logits(X, A)

def logits(self, X, A):
# X = batch x features
# A = batch x actionbatch x actionfeatures

Xreshap = X.unsqueeze(1).expand(
-1, A.shape[1], -1
) # batch x actionbatch x features
XA = (
torch.cat((Xreshap, A), dim=-1)
.reshape(X.shape[0], A.shape[1], -1)
.to(self._device)
) # batch x actionbatch x (features + actionfeatures)
return self.linear(self.blocks(XA)).squeeze(2) # batch x actionbatch

def predict(self, X, A):
self.eval()
return torch.special.expit(self.logits(X, A))

def bandit_learn(self, X, A, R):
self.train()
self.optim.zero_grad()
output = self(X, A)
loss = F.binary_cross_entropy_with_logits(output, R)
loss.backward()
self.optim.step()
return loss.item()
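
The regressor scores a context paired with each candidate action through residual MLP blocks and trains with parameterfree's COCOB optimizer, so no learning rate needs tuning. A rough end-to-end sketch on random tensors (dimensions are illustrative; in the policy's actual learn step, only the single chosen action and its reward are passed, per the feature embedder above):

```python
import torch

feat, act_feat, n_actions = 384, 384, 4
model = ResidualLogisticRegressor(
    in_features=feat + act_feat, depth=2, device="cpu"
)

X = torch.randn(8, feat)                 # batch of contexts
A = torch.randn(8, n_actions, act_feat)  # candidate actions per context

scores = model.predict(X, A)             # (8, 4) probabilities in (0, 1)

chosen = A[:, :1, :]                          # the action actually played
reward = torch.randint(0, 2, (8, 1)).float()  # its observed reward
loss = model.bandit_learn(X, chosen, reward)
```

As an aside, MLP.new_gelu is the tanh GELU approximation, i.e. torch.nn.functional.gelu(x, approximate="tanh").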