From e6bb7bd9964d466a3fe54b1aa7a57cc10b07ab24 Mon Sep 17 00:00:00 2001 From: LM Date: Mon, 16 Dec 2024 17:53:46 +0100 Subject: [PATCH] Ts vae (#69) * start time vae * vae model * add fix bugs in timevae * timevae mlp experiments * timevae mlp experiments * adjust timevae * use different activation for decoder in timevae * clean up timevae * add timevae * reload from checkpoint * timevae evaluation --- dl/notebooks/time_vae.py | 77 ++++++++++++++++++++++++++++++++++------ poetry.lock | 33 ++++++++++++++++- pyproject.toml | 1 + 3 files changed, 99 insertions(+), 12 deletions(-) diff --git a/dl/notebooks/time_vae.py b/dl/notebooks/time_vae.py index 0cb1d0b2..fbd50b57 100644 --- a/dl/notebooks/time_vae.py +++ b/dl/notebooks/time_vae.py @@ -738,9 +738,17 @@ def configure_optimizers(self) -> dict: # ## Fitted Model -for i in time_vae_dm.predict_dataloader(): - print(i.size()) - i_pred = vae_model.model(i) +IS_RELOAD = True + +if IS_RELOAD: + checkpoint_path = "lightning_logs/time_vae_naive/version_29/checkpoints/epoch=1999-step=354000.ckpt" + vae_model_reloaded = VAEModel.load_from_checkpoint(checkpoint_path, model=vae) +else: + vae_model_reloaded = vae_model + +for pred_batch in time_vae_dm.predict_dataloader(): + print(pred_batch.size()) + i_pred = vae_model_reloaded.model(pred_batch.float().cuda()) break i_pred[0].size() @@ -752,23 +760,70 @@ def configure_optimizers(self) -> dict: element = 4 -ax.plot(i.detach().numpy()[element, :, 0]) -ax.plot(i_pred[0].detach().numpy()[element, :, 0], "x-") +ax.plot(pred_batch.detach().numpy()[element, :, 0]) +ax.plot(i_pred[0].cpu().detach().numpy()[element, :, 0], "x-") # - # Data generation using the decoder. -sampling_z = torch.randn(2, vae_model.model.encoder.params.latent_size).type_as( - vae_model.model.encoder.z_mean_layer.weight +sampling_z = torch.randn( + pred_batch.size(0), vae_model_reloaded.model.encoder.params.latent_size +).type_as(vae_model_reloaded.model.encoder.z_mean_layer.weight) +generated_samples_x = ( + vae_model_reloaded.model.decoder(sampling_z).cpu().detach().numpy().squeeze() ) -sampling_x = vae_model.model.decoder(sampling_z) -sampling_x.size() +generated_samples_x.size() # + _, ax = plt.subplots() -for i in range(min(len(sampling_x), 4)): - ax.plot(sampling_x.detach().numpy()[i, :, 0], "x-") +for i in range(min(len(generated_samples_x), 4)): + ax.plot(generated_samples_x[i, :], "x-") # - +from openTSNE import TSNE + +n_tsne_samples = 100 + +original_samples = pred_batch.cpu().detach().numpy().squeeze()[:n_tsne_samples] +original_samples.shape + +tsne = TSNE( + perplexity=30, + metric="euclidean", + n_jobs=8, + random_state=42, + verbose=True, +) + +original_samples_embedding = tsne.fit(original_samples) + +generated_samples_x[:n_tsne_samples] + +generated_samples_embedding = original_samples_embedding.transform( + generated_samples_x[:n_tsne_samples] +) + +# + +fig, ax = plt.subplots(figsize=(7, 7)) + +ax.scatter( + original_samples_embedding[:, 0], + original_samples_embedding[:, 1], + color="black", + marker=".", + label="original", +) + +ax.scatter( + generated_samples_embedding[:, 0], + generated_samples_embedding[:, 1], + color="red", + marker="x", + label="generated", +) + +ax.set_title("t-SNE of original and generated samples") +ax.set_xlabel("t-SNE 1") +ax.set_ylabel("t-SNE 2") diff --git a/poetry.lock b/poetry.lock index 41bb6341..3d4ba0eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3118,6 +3118,37 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "opentsne" +version = "1.0.2" +description = "Extensible, parallel implementations of t-SNE" +optional = false +python-versions = ">=3.9" +files = [ + {file = "openTSNE-1.0.2-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:c82a2c263e570c75256d58590f7c99273c8f8152fada2e3f36a3de92d165a483"}, + {file = "openTSNE-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a7270156cbabc249301cd30f6010387f618295ca68c50913c98b9dad8d9c682"}, + {file = "openTSNE-1.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:99c90695c95f09a100216532d1cb7ec24a6269dd005f1e835c1ca0d603d43542"}, + {file = "openTSNE-1.0.2-cp311-cp311-macosx_10_12_universal2.whl", hash = "sha256:7882df123f210946d2806fb73feec3666e76eee2d6b6744893e14203e0641a38"}, + {file = "openTSNE-1.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:949f694893fd4f803acb513bc2e3d80ef04b707166c28a469ca43033f52b8e1b"}, + {file = "openTSNE-1.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:6799cb38e560e50f5ea932be587ad3efd4cf62d5a55d28c3061ca3e2aee210ce"}, + {file = "openTSNE-1.0.2-cp312-cp312-macosx_10_12_universal2.whl", hash = "sha256:0de8826568aa4f03658274edb393a5be031f771ea86f3493e91aecad27100c56"}, + {file = "openTSNE-1.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37f24d7d139bd466f00ae765120c3a8049ceddc1282e63d75e3406c3ac3b3783"}, + {file = "openTSNE-1.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7f342ec51fe365cd1a23ad25e6a7b5417f8bd1bf4d71a5d526f42ad4c4b64114"}, + {file = "openTSNE-1.0.2-cp39-cp39-macosx_10_12_universal2.whl", hash = "sha256:78d1122ce9233ba3e9de07825127b34b3f330665047f3e04c3914ee0e2b3fad2"}, + {file = "openTSNE-1.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e349d876a26f417326b0aa4e031f2ca5af167608538eaae2b5d2fbaabd353df"}, + {file = "openTSNE-1.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:8c927dc2b47560d06abab48a8cca0ddf388a9200a2d27fc197c786b15f644e7c"}, + {file = "opentsne-1.0.2.tar.gz", hash = "sha256:e2aecaa7a487100246f2d3fef9855d1bd6cc02a1c6da8fb2a54583f307aa4229"}, +] + +[package.dependencies] +numpy = ">=1.16.6" +scikit-learn = ">=0.20" +scipy = "*" + +[package.extras] +hnsw = ["hnswlib (>=0.4.0,<0.5.0)"] +pynndescent = ["pynndescent (>=0.5.0,<0.6.0)"] + [[package]] name = "optuna" version = "4.0.0" @@ -5989,4 +6020,4 @@ test = ["pytest"] [metadata] lock-version = "2.0" python-versions = "3.10.14" -content-hash = "a1f8365360620b7ab7f998731d9c97af5164bcab13ff0c0710b377b07dc43c1e" +content-hash = "d60d46a90877b7d0a5a457b0d97b0dd6cbcf1a1a228129c0d0645d572c110a6a" diff --git a/pyproject.toml b/pyproject.toml index 167811db..204115f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ jupytext = "^1.15.2" [tool.poetry.group.visualization.dependencies] seaborn = "^0.13.2" +opentsne = "^1.0.2" [tool.poetry.group.data.dependencies]