Commit 3c4eb46: New versions of notebooks

wouterboomsma committed Mar 2, 2022
1 parent 2c24db6
Showing 3 changed files with 1,715 additions and 484 deletions.
4 changes: 2 additions & 2 deletions blat_class_A1A2_experiments.ipynb
@@ -1110,7 +1110,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -1124,7 +1124,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.9.5"
   }
  },
  "nbformat": 4,
2,158 changes: 1,688 additions & 470 deletions distances.ipynb

Large diffs are not rendered by default.

37 changes: 25 additions & 12 deletions models/vae_geometric.py
@@ -96,7 +96,9 @@ def __init__(self, data, weights, perm, hparams, aa_weights=None):
                                      nn.ReLU())
 
         self.encoder_mu = nn.Linear(self.encoder_architecture[-1], zdim)
-        self.encoder_scale = nn.Sequential(nn.Linear(self.encoder_architecture[-1], zdim), nn.Softplus())
+
+        if not ("simplify_to_ae" in self.hparams and self.hparams.simplify_to_ae):
+            self.encoder_scale = nn.Sequential(nn.Linear(self.encoder_architecture[-1], zdim), nn.Softplus())
 
         if "sparsity_prior" in self.hparams and self.hparams.sparsity_prior:
 
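The `"simplify_to_ae" in self.hparams and self.hparams.simplify_to_ae` guard keeps configurations saved before this commit loadable: old hparams never stored the key, so a bare attribute access would fail. A minimal sketch of the pattern, using a plain argparse Namespace as a hypothetical stand-in for the model's hparams:

# Sketch of the backward-compatible hparams guard (Namespace is a stand-in for hparams).
from argparse import Namespace

old_hparams = Namespace(sparsity_prior=0)                        # saved before the flag existed
new_hparams = Namespace(sparsity_prior=0, simplify_to_ae=True)   # saved with the new flag

def wants_ae(hparams):
    # Evaluates to False for old checkpoints instead of raising AttributeError,
    # because "key in Namespace" checks the stored attributes first.
    return "simplify_to_ae" in hparams and hparams.simplify_to_ae

print(wants_ae(old_hparams))  # False
print(wants_ae(new_hparams))  # True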
@@ -216,9 +218,13 @@ def forward(self, x, n_samples=1):
         x = nn.functional.one_hot(x, len(aa1_to_index))
         h = self.encoder(x.float().reshape(x.shape[0], -1))
 
-        q_dist = D.Independent(D.Normal(self.encoder_mu(h),
-                                        self.encoder_scale(h) + 1e-4), 1)
-        z_samples = q_dist.rsample(torch.Size([n_samples]))
+        if not ("simplify_to_ae" in self.hparams and self.hparams.simplify_to_ae):
+            q_dist = D.Independent(D.Normal(self.encoder_mu(h),
+                                            self.encoder_scale(h) + 1e-4), 1)
+            z_samples = q_dist.rsample(torch.Size([n_samples]))
+        else:
+            q_dist = None
+            z_samples = self.encoder_mu(h).unsqueeze(0)
 
         recon = self.decode(z_samples)
         return recon, q_dist, z_samples
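In the autoencoder branch the latent code is just the deterministic encoder mean, and `unsqueeze(0)` adds a singleton sample dimension so the decoder receives the same `(n_samples, batch, zdim)` layout the VAE branch produces. A shape-only sketch, with batch size and zdim chosen arbitrarily for illustration:

# Shape check for the two branches (standalone sketch; all dimensions are assumptions).
import torch
import torch.distributions as D

n_samples, batch, zdim = 5, 8, 2
mu = torch.zeros(batch, zdim)
scale = torch.ones(batch, zdim)

# VAE branch: reparameterized samples from q(z|x)
q_dist = D.Independent(D.Normal(mu, scale + 1e-4), 1)
z_vae = q_dist.rsample(torch.Size([n_samples]))  # shape (5, 8, 2)

# AE branch: deterministic code with a singleton sample dimension
z_ae = mu.unsqueeze(0)                           # shape (1, 8, 2)

print(z_vae.shape, z_ae.shape)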
@@ -245,14 +251,20 @@ def _step(self, batch, batch_idx):
             log_prob_x += -self.hparams.sparsity_prior_lambda * self.sparsity_prior.log_prob(self.S.weight).sum()
 
         recon_loss = -log_prob_x.mean()
-        kl_loss = D.kl_divergence(q_dist, self.prior).mean()
 
+        kl_loss = 0
+
+        if ("simplify_to_ae" in self.hparams and self.hparams.simplify_to_ae):
+            loss = recon_loss
-        if self.hparams.iwae_bound:
-            # importance weighted autoencoder bound
-            loss = -torch.mean(torch.logsumexp((log_prob_x + self.prior.log_prob(z_samples) - q_dist.log_prob(z_samples)), dim=0) - np.log(z_samples.shape[0]))
         else:
+            kl_loss = D.kl_divergence(q_dist, self.prior).mean()
+
+            if self.hparams.iwae_bound:
+                # importance weighted autoencoder bound
+                loss = -torch.mean(torch.logsumexp((log_prob_x + self.prior.log_prob(z_samples) - q_dist.log_prob(z_samples)), dim=0) - np.log(z_samples.shape[0]))
-            # standard elbo
-            loss = -torch.mean(log_prob_x + self.prior.log_prob(z_samples) - q_dist.log_prob(z_samples))
+            else:
+                # standard elbo
+                loss = -torch.mean(log_prob_x + self.prior.log_prob(z_samples) - q_dist.log_prob(z_samples))
 
         acc = (recon.mean(dim=0).argmax(dim=1) == x)[x!=aa1_to_index['-']].float().mean()
         return loss, recon_loss, kl_loss, acc
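For reference, the two bounds the else branch estimates, with K = n_samples and w_k = log p(x|z_k) + log p(z_k) - log q(z_k|x), i.e. the term the code assembles from log_prob_x, self.prior.log_prob and q_dist.log_prob. The standard ELBO averages the w_k, while the IWAE bound takes a logsumexp over the sample dimension minus log K, matching the logsumexp(..., dim=0) - np.log(z_samples.shape[0]) line:

\mathcal{L}_{\mathrm{ELBO}} = \mathbb{E}_q\!\left[\frac{1}{K}\sum_{k=1}^{K} w_k\right],
\qquad
\mathcal{L}_{\mathrm{IWAE}} = \mathbb{E}_q\!\left[\log \frac{1}{K}\sum_{k=1}^{K} e^{w_k}\right]
= \mathbb{E}_q\!\left[\operatorname{logsumexp}_k\, w_k - \log K\right]

With simplify_to_ae on, neither bound applies: the loss is the reconstruction term alone and kl_loss is reported as 0.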
@@ -298,12 +310,12 @@ def configure_optimizers(self):
 # 'test_acc': acc,
 # 'beta': self.beta})
 
-    def train_dataloader(self, labels=None):
+    def train_dataloader(self, labels=None, reweighting=True):
         dataset = [self.data[self.perm[:self.train_idx]]]
         if labels is not None:
             dataset += [labels[self.perm[:self.train_idx]]]
         train_data = torch.utils.data.TensorDataset(*dataset)
-        if self.weights is not None:
+        if self.weights is not None and reweighting:
             weights_normalized = self.weights[self.perm[:self.train_idx]]
             weights_normalized /= weights_normalized.sum()
             sampler = torch.utils.data.sampler.WeightedRandomSampler(weights_normalized, len(weights_normalized))
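The new reweighting argument makes sequence reweighting optional: with it off, or with no weights at all, control falls through to whatever non-weighted loader the method builds (that fallback branch lies outside the shown hunk). A minimal sketch of the two paths on toy data; the tensor contents, batch size, and make_loader helper are all illustrative assumptions:

# Sketch: weighted vs. uniform sampling for a TensorDataset (toy data, hypothetical helper).
import torch

data = torch.arange(10)
weights = torch.rand(10)
train_data = torch.utils.data.TensorDataset(data)

def make_loader(weights=None, reweighting=True, batch_size=4):
    if weights is not None and reweighting:
        w = weights / weights.sum()  # normalization mirrors the code above
        sampler = torch.utils.data.sampler.WeightedRandomSampler(w, len(w))
        return torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=sampler)
    # assumed fallback: plain shuffled loader, every sequence weighted equally
    return torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

loader = make_loader(weights, reweighting=False)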
@@ -382,6 +394,7 @@ def str2bool(v):
     argparser.add_argument('-sparsity_prior', default=0, nargs='?', type=str2bool)
     argparser.add_argument('-mask_out_gaps', default=0, nargs='?', type=str2bool)
     argparser.add_argument('-sparsity_prior_lambda', default=1e-4, nargs='?', type=float)
+    argparser.add_argument('-simplify_to_ae', default=0, nargs='?', type=str2bool)
 
     return argparser.parse_args(args)
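The new flag follows the file's existing convention: single-dash option, 0/1 default, parsed through str2bool. A self-contained sketch of how the flag would be parsed; the str2bool body here is an assumption, since the actual helper's implementation is not part of this diff:

# Sketch: parsing the new flag (str2bool body is a guess at a typical implementation).
import argparse

def str2bool(v):
    # Accept common truthy spellings; everything else is False.
    return str(v).lower() in ("1", "true", "yes", "y", "t")

argparser = argparse.ArgumentParser()
argparser.add_argument('-simplify_to_ae', default=0, nargs='?', type=str2bool)

args = argparser.parse_args(['-simplify_to_ae', '1'])
print(args.simplify_to_ae)  # True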

