-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pytorch policy #31
Merged
Changes from 13 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
97793c8
add byom notebook
cheng-tan afef4a7
WIP: BYOM
cheng-tan 063faaf
Save resume model
cheng-tan 2ec66ab
fix format
cheng-tan 8480a8b
Merge branch 'main' of github.com:VowpalWabbit/learn_to_pick into byom
cheng-tan 8bb09a1
WIP
cheng-tan b093e1b
Fix test
cheng-tan 45e273e
Merge branch 'main' of github.com:VowpalWabbit/learn_to_pick into byom
cheng-tan 97d9f0c
format
cheng-tan 324ca81
separate vw and pytorch
cheng-tan 1295d33
format
cheng-tan cf4b284
fix tests
cheng-tan d6e9c87
update feature embedder
cheng-tan 20f8a21
add type hint to pytorch
cheng-tan af207ff
update notebook
cheng-tan d75f88a
rename notebook
cheng-tan 7f30c85
rename variables
cheng-tan d88ea8e
update readme
cheng-tan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from sentence_transformers import SentenceTransformer | ||
import torch | ||
from learn_to_pick import PickBestFeaturizer | ||
|
||
|
||
class PyTorchFeatureEmbedder: | ||
def __init__(self, model=None, *args, **kwargs): | ||
if model is None: | ||
model = SentenceTransformer("all-MiniLM-L6-v2") | ||
|
||
self.model = model | ||
self.featurizer = PickBestFeaturizer(auto_embed=False) | ||
|
||
def encode(self, stuff): | ||
embeddings = self.model.encode(stuff, convert_to_tensor=True) | ||
normalized = torch.nn.functional.normalize(embeddings) | ||
return normalized | ||
|
||
def convert_features_to_text(self, sparse_features): | ||
results = [] | ||
for ns, obj in sparse_features.items(): | ||
value = obj.get("default_ft", "") | ||
results.append(f"{ns}={value}") | ||
return " ".join(results) | ||
|
||
def format(self, event): | ||
# TODO: handle dense | ||
context_featurized, actions_featurized, selected = self.featurizer.featurize( | ||
event | ||
) | ||
|
||
context_sparse = self.encode( | ||
[self.convert_features_to_text(context_featurized.sparse)] | ||
) | ||
|
||
actions_sparse = [] | ||
for action_featurized in actions_featurized: | ||
actions_sparse.append( | ||
self.convert_features_to_text(action_featurized.sparse) | ||
) | ||
actions_sparse = self.encode(actions_sparse).unsqueeze(0) | ||
|
||
if selected.score is not None: | ||
return ( | ||
torch.Tensor([[selected.score]]), | ||
context_sparse, | ||
actions_sparse[:, selected.index, :].unsqueeze(1), | ||
) | ||
else: | ||
return context_sparse, actions_sparse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
|
||
|
||
def IGW(fhat, gamma): | ||
from math import sqrt | ||
|
||
fhatahat, ahat = fhat.max(dim=1) | ||
A = fhat.shape[1] | ||
gamma *= sqrt(A) | ||
p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat)) | ||
sump = p.sum(dim=1) | ||
p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None) | ||
return torch.multinomial(p, num_samples=1).squeeze(1), ahat | ||
|
||
|
||
def SamplingIGW(A, P, gamma): | ||
exploreind, _ = IGW(P, gamma) | ||
explore = [ind for _, ind in zip(A, exploreind)] | ||
return explore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import parameterfree | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
|
||
class MLP(torch.nn.Module): | ||
@staticmethod | ||
def new_gelu(x): | ||
import math | ||
|
||
return ( | ||
0.5 | ||
* x | ||
* ( | ||
1.0 | ||
+ torch.tanh( | ||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) | ||
) | ||
) | ||
) | ||
|
||
def __init__(self, dim): | ||
super().__init__() | ||
self.c_fc = torch.nn.Linear(dim, 4 * dim) | ||
self.c_proj = torch.nn.Linear(4 * dim, dim) | ||
self.dropout = torch.nn.Dropout(0.5) | ||
|
||
def forward(self, x): | ||
x = self.c_fc(x) | ||
x = self.new_gelu(x) | ||
x = self.c_proj(x) | ||
x = self.dropout(x) | ||
return x | ||
|
||
|
||
class Block(torch.nn.Module): | ||
def __init__(self, dim): | ||
super().__init__() | ||
self.layer = MLP(dim) | ||
|
||
def forward(self, x): | ||
return x + self.layer(x) | ||
|
||
|
||
class ResidualLogisticRegressor(torch.nn.Module): | ||
def __init__(self, in_features, depth, device): | ||
super().__init__() | ||
self._in_features = in_features | ||
self._depth = depth | ||
self.blocks = torch.nn.Sequential(*[Block(in_features) for _ in range(depth)]) | ||
self.linear = torch.nn.Linear(in_features=in_features, out_features=1) | ||
self.optim = parameterfree.COCOB(self.parameters()) | ||
self._device = device | ||
|
||
def clone(self): | ||
other = ResidualLogisticRegressor(self._in_features, self._depth, self._device) | ||
other.load_state_dict(self.state_dict()) | ||
other.optim = parameterfree.COCOB(other.parameters()) | ||
other.optim.load_state_dict(self.optim.state_dict()) | ||
return other | ||
|
||
def forward(self, X, A): | ||
return self.logits(X, A) | ||
|
||
def logits(self, X, A): | ||
# X = batch x features | ||
# A = batch x actionbatch x actionfeatures | ||
|
||
Xreshap = X.unsqueeze(1).expand( | ||
-1, A.shape[1], -1 | ||
) # batch x actionbatch x features | ||
XA = ( | ||
torch.cat((Xreshap, A), dim=-1) | ||
.reshape(X.shape[0], A.shape[1], -1) | ||
.to(self._device) | ||
) # batch x actionbatch x (features + actionfeatures) | ||
return self.linear(self.blocks(XA)).squeeze(2) # batch x actionbatch | ||
|
||
def predict(self, X, A): | ||
self.eval() | ||
return torch.special.expit(self.logits(X, A)) | ||
|
||
def bandit_learn(self, X, A, R): | ||
self.train() | ||
self.optim.zero_grad() | ||
output = self(X, A) | ||
loss = F.binary_cross_entropy_with_logits(output, R) | ||
loss.backward() | ||
self.optim.step() | ||
return loss.item() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe throw a not yet supported message