Skip to content

Commit

Permalink
Ts vae (#69)
Browse files Browse the repository at this point in the history
* start time vae

* vae model

* add fix bugs in timevae

* timevae mlp experiments

* timevae mlp experiments

* adjust timevae

* use different activation for decoder in timevae

* clean up timevae

* add timevae

* reload from checkpoint

* timevae evaluation
  • Loading branch information
emptymalei authored Dec 16, 2024
1 parent baff07a commit e6bb7bd
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 12 deletions.
77 changes: 66 additions & 11 deletions dl/notebooks/time_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,17 @@ def configure_optimizers(self) -> dict:

# ## Fitted Model

for i in time_vae_dm.predict_dataloader():
print(i.size())
i_pred = vae_model.model(i)
IS_RELOAD = True

if IS_RELOAD:
checkpoint_path = "lightning_logs/time_vae_naive/version_29/checkpoints/epoch=1999-step=354000.ckpt"
vae_model_reloaded = VAEModel.load_from_checkpoint(checkpoint_path, model=vae)
else:
vae_model_reloaded = vae_model

for pred_batch in time_vae_dm.predict_dataloader():
print(pred_batch.size())
i_pred = vae_model_reloaded.model(pred_batch.float().cuda())
break

i_pred[0].size()
Expand All @@ -752,23 +760,70 @@ def configure_optimizers(self) -> dict:

element = 4

ax.plot(i.detach().numpy()[element, :, 0])
ax.plot(i_pred[0].detach().numpy()[element, :, 0], "x-")
ax.plot(pred_batch.detach().numpy()[element, :, 0])
ax.plot(i_pred[0].cpu().detach().numpy()[element, :, 0], "x-")
# -

# Data generation using the decoder.

sampling_z = torch.randn(2, vae_model.model.encoder.params.latent_size).type_as(
vae_model.model.encoder.z_mean_layer.weight
sampling_z = torch.randn(
pred_batch.size(0), vae_model_reloaded.model.encoder.params.latent_size
).type_as(vae_model_reloaded.model.encoder.z_mean_layer.weight)
generated_samples_x = (
vae_model_reloaded.model.decoder(sampling_z).cpu().detach().numpy().squeeze()
)
sampling_x = vae_model.model.decoder(sampling_z)

sampling_x.size()
generated_samples_x.size()

# +
_, ax = plt.subplots()

for i in range(min(len(sampling_x), 4)):
ax.plot(sampling_x.detach().numpy()[i, :, 0], "x-")
for i in range(min(len(generated_samples_x), 4)):
ax.plot(generated_samples_x[i, :], "x-")

# -
from openTSNE import TSNE

n_tsne_samples = 100

original_samples = pred_batch.cpu().detach().numpy().squeeze()[:n_tsne_samples]
original_samples.shape

tsne = TSNE(
perplexity=30,
metric="euclidean",
n_jobs=8,
random_state=42,
verbose=True,
)

original_samples_embedding = tsne.fit(original_samples)

generated_samples_x[:n_tsne_samples]

generated_samples_embedding = original_samples_embedding.transform(
generated_samples_x[:n_tsne_samples]
)

# +
fig, ax = plt.subplots(figsize=(7, 7))

ax.scatter(
original_samples_embedding[:, 0],
original_samples_embedding[:, 1],
color="black",
marker=".",
label="original",
)

ax.scatter(
generated_samples_embedding[:, 0],
generated_samples_embedding[:, 1],
color="red",
marker="x",
label="generated",
)

ax.set_title("t-SNE of original and generated samples")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
33 changes: 32 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jupytext = "^1.15.2"

[tool.poetry.group.visualization.dependencies]
seaborn = "^0.13.2"
opentsne = "^1.0.2"


[tool.poetry.group.data.dependencies]
Expand Down

0 comments on commit e6bb7bd

Please sign in to comment.