diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index fa528792..d5c5392e 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -2,8 +2,8 @@ name: Publish Docker on: push: branches: - - main - - master + - main + - master # pull_request: ~ env: @@ -14,37 +14,29 @@ jobs: build: runs-on: ubuntu-latest steps: - - name: Checkout - uses: actions/checkout@v3.3.0 - with: - fetch-depth: 2 - - name: Log in to the Container registry - uses: docker/login-action@v2.1.0 - with: - registry: ${{ env.REGISTRY }} - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + - name: Checkout + uses: actions/checkout@v3.3.0 + with: + fetch-depth: 2 + - name: Log in to the Container registry + if: ${{ !env.ACT }} + uses: docker/login-action@v2.1.0 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} - - name: Extract metadata (tags, labels) for Docker - id: meta - uses: docker/metadata-action@v4.3.0 - with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@v4.3.0 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - - name: Build and push Docker image (version tag) - if: steps.check-version.outputs.current-version - uses: docker/build-push-action@v3.3.0 - with: - context: . - push: true - tags: ghcr.io/${{ github.repository }}:${{ steps.check-version.outputs.current-version }} - labels: ${{ steps.meta.outputs.labels }} - - - name: Build and push Docker image (latest tag) - if: steps.check-version.outputs.current-version - uses: docker/build-push-action@v3.3.0 - with: - context: . - push: true - tags: ghcr.io/${{ github.repository }}:latest - labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file + - name: Build and push Docker image (version tag) + if: steps.check-version.outputs.current-version + uses: docker/build-push-action@v3.3.0 + with: + context: . 
+ push: true + tags: ghcr.io/${{ github.repository }}:${{ steps.check-version.outputs.current-version }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5f3d9f5f..290f82f8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,36 +1,44 @@ -# https://github.com/marketplace/actions/install-poetry-action -name: test - -on: [pull_request,push] - +name: conda +on: [push] jobs: - test: + constructor: + name: conda build (${{ matrix.python-version }}, ${{ matrix.os }}) + runs-on: ${{ matrix.os }}-latest defaults: run: - shell: bash -l {0} + shell: ${{ matrix.shell }} strategy: - fail-fast: false matrix: + # os: [ubuntu, windows, macos] + os: [ubuntu] python-version: ["3.9"] - os: [ubuntu-latest] - # os: [ubuntu-18.04, macos-latest, windows-latest] - runs-on: ${{ matrix.os }} + include: + - os: ubuntu + shell: bash -l {0} + # - os: windows + # shell: cmd /C call {0} + # - os: macos + # shell: bash -l {0} steps: - - name: Check out repository - uses: actions/checkout@v2 - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - use-mamba: true - environment-file: environment.yml - python-version: ${{ matrix.python-version }} - - name: poetry env - run: poetry env use python - - name: Poetry lock - run: poetry lock - - name: Install library - run: poetry install --no-interaction - # - name: Run tests - # run: | - # source .venv/bin/activate - # pytest tests/ \ No newline at end of file + - uses: actions/checkout@v2 + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + - uses: conda-incubator/setup-miniconda@v2 + with: + environment-file: environment.yml + miniforge-variant: Mambaforge + miniforge-version: latest + mamba-version: "*" + use-mamba: true + python-version: ${{ matrix.python-version }} + - name: Run tests + run: | + make test diff --git a/Makefile b/Makefile index 1f1fb42a..08d73569 100644 --- a/Makefile +++ b/Makefile @@ -9,101 +9,3 @@ download.data: test: pytest - -GOOGLE_APPLICATION_CREDENTIALS=$(shell pwd)/credentials.json -BUCKET_NAME=idr-hipsci -TRAINING_DIR=idr0034-kilpinen-hipsci -PROJECT=prj-ext-dev-bia-binder-113155 - -JOB_PREFIX=vae -JOB_NAME=$(JOB_PREFIX)_$(shell date +%Y%m%d_%H%M%S) -JOB_DIR=gs://${BUCKET_NAME}/${JOB_NAME}/models -DATA_DIR=gs://${BUCKET_NAME}/${TRAINING_DIR} - -.EXPORT_ALL_VARIABLES: - GOOGLE_APPLICATION_CREDENTIALS - BUCKET_NAME - TRAINING_DIR - JOB_PREFIX - JOB_NAME - JOB_DIR - - -# MY_VAR := $(shell echo whatever) - -# test: -# @echo MY_VAR IS $(MY_VAR) - -test: - @echo $$GOOGLE_APPLICATION_CREDENTIALS $$BUCKET_NAME $$TRAINING_DIR - -all: get_data_list build - -build: - conda activate torch - python idr_get_data.py - -get_data_list: - ls /nfs/bioimage/drop/idr*/**/*.tiff > file_list.txt - ls -u /nfs/bioimage/drop/idr*/**/*.tiff > file_list.txt - -run.on.cloud: - python idr_get_data_s3.py - -run.on.cloud.snake: - snakemake --use-conda --cores all \ - --verbose --google-lifesciences \ - --default-remote-prefix idr-hipsci \ - --google-lifesciences-region eu-west2 - -run.snake: - snakemake --cores all -F --use-conda --verbose - -get.env.file: - conda env export --from-history -f environment.yml -n torch - -on.gcp: - gcloud ai-platform jobs submit training ${JOB_NAME} \ - --region=europe-west2 \ - --master-image-uri=gcr.io/cloud-ml-public/training/pytorch-gpu.1-9 \ - 
--scale-tier=CUSTOM \
-	--master-machine-type=n1-standard-8 \
-	--master-accelerator=type=nvidia-tesla-t4,count=1 \
-	--job-dir=${JOB_DIR} \
-	--package-path=./trainer \
-	--module-name=trainer.train \
-	--stream-logs \
-	-- \
-	--num-epochs=10 \
-	--batch-size=100 \
-	--learning-rate=0.001 \
-	--gpus=1
-
-
-on.gcp.big:
-	gcloud ai-platform jobs submit training ${JOB_NAME} \
-	--region=europe-west2 \
-	--master-image-uri=gcr.io/cloud-ml-public/training/pytorch-gpu.1-9 \
-	--config=config.yaml \
-	--job-dir=${JOB_DIR} \
-	--package-path=./trainer \
-	--module-name=trainer.train \
-	--stream-logs \
-	-- \
-	--num-epochs=10 \
-	--batch-size=100 \
-	--learning-rate=0.001 \
-	--gpus=2 \
-	--accelerator='ddp'\
-	--num_nodes=3
-
-tensorboard:
-	tensorboard --logdir=gs://$(BUCKET_NAME)/${JOB_NAME}
-download.data:
-	kaggle competitions download -c data-science-bowl-2018
-
-test:
-	pytest
-
-download.idr:
-	rsync -avR --progress ctr26@noah-login:/nfs/bioimage/drop/idr0093-mueller-perturbation/ data/idr
diff --git a/README.md b/README.md
index 6b34690d..ccf6c625 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,13 @@ This utility makes it simple to fetch the necessary datasets:
 ```bash
 make download.data
 ```
+If you don't have a Kaggle account, you must create one and then follow these steps:
+1. Install the Kaggle API package so you can download the data from the Makefile; all the information you need is in their [GitHub repository](https://github.com/Kaggle/kaggle-api).
+2. To use the Kaggle API you also need to create an API token.
+   You can find out how to do it in their [documentation](https://github.com/Kaggle/kaggle-api#api-credentials).
+3. After that, add your username and key to a `kaggle.json` file under `~/.kaggle/` in your home directory and restrict its permissions with `chmod 600 ~/.kaggle/kaggle.json`.
+4. Don't forget to accept the conditions for the "2018 Data Science Bowl" on the Kaggle website.
+   Otherwise you will not be able to pull this data from the command line.
 
 ### 4. Developer Installation:
 
@@ -88,4 +95,4 @@ bioimage_embed is licensed under the MIT License. Please refer to the [LICENSE](
 
 ---
 
-Happy Embedding! 🧬🔬
\ No newline at end of file
+Happy Embedding!
🧬🔬
diff --git a/bioimage_embed/augmentations.py b/bioimage_embed/augmentations.py
index e2c14074..6c9daba4 100644
--- a/bioimage_embed/augmentations.py
+++ b/bioimage_embed/augmentations.py
@@ -1,40 +1,6 @@
 import albumentations as A
 import cv2
-DEFAULT_AUGMENTATION = A.Compose(
-    [
-        # Flip the images horizontally or vertically with a 50% chance
-        A.OneOf(
-            [
-                A.HorizontalFlip(p=0.5),
-                A.VerticalFlip(p=0.5),
-            ],
-            p=0.5,
-        ),
-        # Rotate the images by a random angle within a specified range
-        A.Rotate(limit=45, p=0.5),
-        # Randomly scale the image intensity to adjust brightness and contrast
-        A.RandomGamma(gamma_limit=(80, 120), p=0.5),
-        # Apply random elastic transformations to the images
-        A.ElasticTransform(
-            alpha=1,
-            sigma=50,
-            alpha_affine=50,
-            p=0.5,
-        ),
-        # Shift the image channels along the intensity axis
-        A.ChannelShuffle(p=0.5),
-        # Add a small amount of noise to the images
-        A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
-        # Crop a random part of the image and resize it back to the original size
-        A.RandomResizedCrop(
-            height=512, width=512, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.5
-        ),
-        # Adjust image intensity with a specified range for individual channels
-        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
-    ]
-)
-
 from typing import Any
 
 import albumentations
@@ -43,6 +9,39 @@
 from omegaconf import DictConfig
 from PIL import Image
 
+DEFAULT_AUGMENTATION_LIST = [
+    # Flip the images horizontally or vertically with a 50% chance
+    A.OneOf(
+        [
+            A.HorizontalFlip(p=0.5),
+            A.VerticalFlip(p=0.5),
+        ],
+        p=0.5,
+    ),
+    # Rotate the images by a random angle within a specified range
+    A.Rotate(limit=45, p=0.5),
+    # Apply a random gamma correction to vary image brightness and contrast
+    A.RandomGamma(gamma_limit=(80, 120), p=0.5),
+    # Apply random elastic transformations to the images
+    A.ElasticTransform(
+        alpha=1,
+        sigma=50,
+        alpha_affine=50,
+        p=0.5,
+    ),
+    # Randomly shuffle the order of the image channels
+    A.ChannelShuffle(p=0.5),
+    # Add a small amount of noise to the images
+    A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
+    # Crop a random part of the image and resize it back to the original size
+    A.RandomResizedCrop(
+        height=512, width=512, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.5
+    ),
+    # Randomly adjust brightness and contrast within the given limits
+    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
+]
+
+DEFAULT_AUGMENTATION = A.Compose(DEFAULT_AUGMENTATION_LIST)
 
 class TransformsWrapper:
     def __init__(self, transforms_cfg: DictConfig) -> None:
@@ -81,9 +80,7 @@ def __init__(self, transforms_cfg: DictConfig) -> None:
                 _convert_="object",
             )
             valid_test_predict_aug.append(aug)
-        self.valid_test_predict_aug = albumentations.Compose(
-            valid_test_predict_aug
-        )
+        self.valid_test_predict_aug = albumentations.Compose(valid_test_predict_aug)
 
     def set_mode(self, mode: str) -> None:
         """Set `__call__` mode.
@@ -111,4 +108,4 @@ def __call__(self, image: Any, **kwargs: Any) -> Any:
         image = np.asarray(image)
         if self.mode == "train":
             return self.train_aug(image=image, **kwargs)
-        return self.valid_test_predict_aug(image=image, **kwargs)
\ No newline at end of file
+        return self.valid_test_predict_aug(image=image, **kwargs)
diff --git a/bioimage_embed/cli.py b/bioimage_embed/cli.py
new file mode 100644
index 00000000..45529654
--- /dev/null
+++ b/bioimage_embed/cli.py
@@ -0,0 +1,12 @@
+from .hydra import train, infer
+from typer import Typer
+
+app = Typer()
+app.command()(train)
+app.command()(infer)
+
+def main():
+    app()
+
+if __name__ == "__main__":
+    main()
diff --git a/bioimage_embed/hydra.py b/bioimage_embed/hydra.py
new file mode 100644
index 00000000..46ad75de
--- /dev/null
+++ b/bioimage_embed/hydra.py
@@ -0,0 +1,106 @@
+from hydra.core.config_store import ConfigStore
+from dataclasses import dataclass, field
+from omegaconf import OmegaConf
+import hydra
+import albumentations as A
+import os
+
+from bioimage_embed.augmentations import DEFAULT_AUGMENTATION_LIST
+
+
+@dataclass
+class Recipe:
+    _target_: str = "types.SimpleNamespace"
+    opt: str = "adamw"
+    weight_decay: float = 0.001
+    momentum: float = 0.9
+    sched: str = "cosine"
+    epochs: int = 50
+    lr: float = 1e-4
+    min_lr: float = 1e-6
+    t_initial: int = 10
+    t_mul: int = 2
+    lr_min: float = None
+    decay_rate: float = 0.1
+    warmup_lr: float = 1e-6
+    warmup_lr_init: float = 1e-6
+    warmup_epochs: int = 5
+    cycle_limit: int = None
+    t_in_epochs: bool = False
+    noisy: bool = False
+    noise_std: float = 0.1
+    noise_pct: float = 0.67
+    noise_seed: int = None
+    cooldown_epochs: int = 5
+    warmup_t: int = 0
+
+
+@dataclass
+class Transform:
+    _target_: str = "albumentations.Compose"
+    # default_factory must be a zero-argument callable, so wrap the Compose
+    # construction in a lambda rather than passing a Compose instance
+    transforms: A.Compose = field(
+        default_factory=lambda: A.Compose(DEFAULT_AUGMENTATION_LIST)
+    )
+
+
+# @dataclass
+# class AlbumentationsTransform:
+#     _target_: str = "albumentations.from_dict"
+#     transform_dict: dict = field(default_factory=A.from_dict)
+#     transform = A.from_dict(OmegaConf.to_container(cfg.albumentations, resolve=True))
+
+
+@dataclass
+class ImageDataset:
+    _target_: str = "torchvision.datasets.ImageFolder"
+    transform: Transform = field(default_factory=Transform)
+
+
+@dataclass
+class Dataset:
+    pass
+
+
+@dataclass
+class DataLoader:
+    _target_: str = "bioimage_embed.lightning.dataloader.DataModule"
+    dataset: ImageDataset = field(default_factory=ImageDataset)
+
+
+# def cs_generator():
+cs = ConfigStore.instance()
+cs.store(name="recipe", node=Recipe)
+cs.store(name="dataloader", node=DataLoader)
+
+
+# return cs
+def train():
+    main(job_name="test_app")
+
+
+# Minimal placeholder: cli.py imports `infer` from this module, but no
+# inference logic exists yet, so stub it out to keep the import working.
+def infer():
+    raise NotImplementedError("inference is not implemented yet")
+
+
+def write_default_config_file(config_path, config_filename, config):
+    os.makedirs(config_path, exist_ok=True)
+    with open(os.path.join(config_path, config_filename), "w") as file:
+        file.write(OmegaConf.to_yaml(config))
+
+
+def main(config_path="conf", job_name="test_app"):
+    config_file = os.path.join(config_path, "config.yaml")
+
+    # Check if the configuration directory exists, if not, create it
+    if not os.path.exists(config_path):
+        os.makedirs(config_path)
+        # Initialize Hydra with a basic configuration
+        hydra.initialize(version_base=None, config_path=config_path, job_name=job_name)
+        cfg = hydra.compose(config_name="config")
+        # Save the default configuration
+        with open(config_file,
"w") as file: + file.write(OmegaConf.to_yaml(cfg)) + else: + # Initialize Hydra normally if the configuration directory exists + hydra.initialize(version_base=None, config_path=config_path, job_name=job_name) + cfg = hydra.compose(config_name="config") + + print(OmegaConf.to_yaml(cfg)) diff --git a/bioimage_embed/lightning/dataloader.py b/bioimage_embed/lightning/dataloader.py index 29f608a4..34b84097 100644 --- a/bioimage_embed/lightning/dataloader.py +++ b/bioimage_embed/lightning/dataloader.py @@ -35,6 +35,7 @@ def __init__( "pin_memory": True, "shuffle": False, "sampler": sampler, + "drop_last": True, # "collate_fn": self.collate_wrapper(self.collate_filter_for_none), # "collate_fn": self.collate_filter_for_none, } diff --git a/bioimage_embed/lightning/torch.py b/bioimage_embed/lightning/torch.py index e9eef522..22147b81 100644 --- a/bioimage_embed/lightning/torch.py +++ b/bioimage_embed/lightning/torch.py @@ -7,7 +7,7 @@ import argparse import timm from pythae.models.base.base_utils import ModelOutput - +import torch.nn.functional as F class LitAutoEncoderTorch(pl.LightningModule): args = argparse.Namespace( @@ -45,10 +45,12 @@ def __init__(self, model, args=SimpleNamespace()): if args: self.args = SimpleNamespace(**{**vars(args), **vars(self.args)}) # if kwargs: - # merged_kwargs = {k: v for d in kwargs.values() for k, v in d.items()} - # self.args = SimpleNamespace(**{**merged_kwargs, **vars(self.args)}) + # merged_kwargs = {k: v for d in kwargs.values() for k, v in d.items()} + # self.args = SimpleNamespace(**{**merged_kwargs, **vars(self.args)}) self.save_hyperparameters(vars(self.args)) # self.model.train() + # keep a handle on metrics logged by the model + self.metrics = {} def forward(self, batch): x = self.batch_to_tensor(batch) @@ -72,35 +74,32 @@ def get_model_output(self, x, batch_idx): return model_output, loss def training_step(self, batch, batch_idx): - # results = self.get_results(batch) self.model.train() x = self.batch_to_tensor(batch) model_output, loss = self.get_model_output( x, batch_idx, ) - # loss = self.model.training_step(x) - # loss = self.loss_function(model_output,optimizer_idx) - - # self.log("train_loss", self.loss) - # self.log("train_loss", loss) - self.logger.experiment.add_scalar("Loss/train", loss, batch_idx) - - self.logger.experiment.add_image( - "input", torchvision.utils.make_grid(x["data"]), batch_idx - ) - - # if self.PYTHAE_FLAG: - self.logger.experiment.add_image( - "output", - torchvision.utils.make_grid(model_output.recon_x), - batch_idx, + self.log_dict( + { + "loss/train": loss, + "mse/train": F.mse_loss(model_output.recon_x, x["data"]), + }, + # on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, ) - + if isinstance(self.logger, pl.loggers.TensorBoardLogger): + self.log_tensorboard(model_output, x) return loss def loss_function(self, model_output, *args, **kwargs): - return model_output.loss + #return model_output.loss + return { + "loss": model_output.loss, + "recon_loss": model_output.recon_loss, + } # def logging_step(self, z, loss, x, model_output, batch_idx): # self.logger.experiment.add_embedding( @@ -121,20 +120,13 @@ def validation_step(self, batch, batch_idx): x = self.batch_to_tensor(batch) model_output, loss = self.get_model_output(x, batch_idx) z = self.embedding_from_output(model_output) - # z_indices - self.logger.experiment.add_embedding( - z, - label_img=x["data"], - global_step=self.current_epoch, - tag="z", - ) - - self.logger.experiment.add_scalar("Loss/val", loss, batch_idx) - 
self.logger.experiment.add_image( - "val", - torchvision.utils.make_grid(model_output["recon_x"]), - batch_idx, - ) + val_metrics ={ + "loss/val": loss, + "mse/val": F.mse_loss(model_output.recon_x, x["data"]), + } + self.log_dict( val_metrics,) + self.metrics = {**self.metrics, **val_metrics} + return loss # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): # # Implement your own logic for updating the lr scheduler @@ -181,19 +173,27 @@ def test_step(self, batch, batch_idx): loss = self.loss_function(model_output) # Log test metrics - self.log("test_loss", loss) + test_metrics = { + "loss/test": loss, + "mse/test": F.mse_loss(model_output.recon_x, x["data"]), + } + self.log_dict(test_metrics) + self.metrics = {**self.metrics, **test_metrics} + return loss + + def log_wandb(self): + pass + + def log_tensorboard(self, model_output, x): # Optionally you can add more logging, for example, visualizations: self.logger.experiment.add_image( "test_input", torchvision.utils.make_grid(x["data"]), - batch_idx, + self.global_step, ) self.logger.experiment.add_image( "test_output", torchvision.utils.make_grid(model_output.recon_x), - batch_idx, + self.global_step, ) - - # Return whatever data you need, for example, the loss - return loss diff --git a/bioimage_embed/models/factory.py b/bioimage_embed/models/factory.py index 749ebaa6..4c5f1a21 100644 --- a/bioimage_embed/models/factory.py +++ b/bioimage_embed/models/factory.py @@ -18,7 +18,6 @@ from . import bolts from functools import partial - class ModelFactory: def __init__( self, input_dim, latent_dim, pretrained=False, progress=True, **kwargs @@ -97,6 +96,32 @@ def resnet18_vae(self): bolts.ResNet18VAEDecoder, ) + def resnet18_vqvae(self): + return self.create_model( + partial( + pythae.models.VQVAEConfig, + use_default_encoder=False, + use_default_decoder=False, + **self.kwargs + ), + pythae.models.VQVAE, + bolts.ResNet18VQVAEEncoder, + bolts.ResNet18VQVAEDecoder, + ) + + def resnet18_beta_vae(self): + return self.create_model( + partial( + pythae.models.BetaVAEConfig, + use_default_encoder=False, + use_default_decoder=False, + **self.kwargs + ), + pythae.models.BetaVAE, + bolts.ResNet18VAEEncoder, + bolts.ResNet18VAEDecoder, + ) + def resnet50_vae(self): return self.create_model( partial( @@ -110,7 +135,7 @@ def resnet50_vae(self): bolts.ResNet50VAEDecoder, ) - def resnet18_vqvae(self): + def resnet50_vqvae(self): return self.create_model( partial( pythae.models.VQVAEConfig, @@ -119,21 +144,21 @@ def resnet18_vqvae(self): **self.kwargs ), pythae.models.VQVAE, - bolts.ResNet18VQVAEEncoder, - bolts.ResNet18VQVAEDecoder, + bolts.ResNet50VQVAEEncoder, + bolts.ResNet50VQVAEDecoder, ) - def resnet50_vqvae(self): + def resnet50_beta_vae(self): return self.create_model( partial( - pythae.models.VQVAEConfig, + pythae.models.BetaVAEConfig, use_default_encoder=False, use_default_decoder=False, **self.kwargs ), - pythae.models.VQVAE, - bolts.ResNet50VQVAEEncoder, - bolts.ResNet50VQVAEDecoder, + pythae.models.BetaVAE, + bolts.ResNet50VAEEncoder, + bolts.ResNet50VAEDecoder, ) def resnet_vae_legacy(self, depth): @@ -174,10 +199,77 @@ def resnet110_vqvae_legacy(self): def resnet152_vqvae_legacy(self): return self.resnet_vqvae_legacy(152) + def o2vae(self): + from .o2vae.models.decoders.cnn_decoder import CnnDecoder + from .o2vae.models.encoders_o2.e2scnn import E2SFCNN + from .o2vae.models.vae import VAE as O2VAE + + # encoder + q_net = E2SFCNN( + n_channels = 1, + n_classes = 64 * 2, # bc vae saves mean and 
stdDev vectors
+            # `name`: 'o2_cnn' for o2-invariant encoder. 'cnn_encoder' for standard cnn encoder.
+            name="o2_cnn_encoder",
+            # `cnn_dims`: must be 6 elements long. Increase numbers for larger model capacity
+            cnn_dims=[6, 9, 12, 12, 19, 25],
+            # `layer_type`: type of cnn layer (following e2cnn library examples)
+            layer_type="inducedgated_norm",  # recommend not changing
+            # `N`: Ignored if `name!='o2'`. Negative means the model will be O2-invariant.
+            # Again, see (e2cnn library examples). Recommend not changing.
+            N=-3,
+        )
+
+        # decoder
+        p_net = CnnDecoder(
+            zdim = 64,
+            name="cnn_decoder",  # 'cnn' is the only option
+            # `cnn_dims`: each extra layer doubles the dimension (image width) by a factor of 2.
+            # E.g. if there are 6 elements, image width is 2^6=64
+            cnn_dims=[192, 96, 96, 48, 48, 48],
+            #cnn_dims=[192, 96, 96, 48, 48, 24, 24, 12, 12],
+            out_channels=1,
+        )
+
+        # vae
+        model = O2VAE(
+            q_net = q_net,
+            p_net = p_net,
+            zdim = 64,  # vae bottleneck layer
+            do_sigmoid = True,  # whether to make the output be between [0,1]. Usually True.
+            loss_kwargs = dict(
+                # 'beta' from beta-vae, or the weight on the KL-divergence term https://openreview.net/forum?id=Sy2fzU9gl
+                beta=0.01,
+                # `recon_loss_type`: "bce" (binary cross entropy) or "mse" (mean square error)
+                # or "ce" (cross-entropy; warning: not well tested)
+                #recon_loss_type="bce",
+                recon_loss_type="mse",
+                # for reconstruction loss, pixel mask. Must be either `None` or an array with same dimension as the images.
+                mask=None,
+                align_loss=True,  # whether to align the output image to the input image
+                # whether to use efficient Fourier-based loss alignment. (Ignored if align_loss==False)
+                align_fourier=True,
+                # whether to align with the best rotation AND flip, instead of just rotation. (Ignored if align_loss==False)
+                do_flip=True,
+                # if doing brute force align loss, this is the rotation discretization. (Ignored if
+                # align_loss==False or if align_fourier==True)
+                rot_steps=2,
+                # Recommend not changing. The vae prior distribution. Options: ("standard","normal","gmm"). See models.vae.VAE for details.
+ prior_kwargs=dict( prior="standard",), + ) + ) + + # extra attributes + model.encoder = q_net + model.decoder = p_net + + return model + MODELS = [ "resnet18_vae", + "resnet18_beta_vae", "resnet50_vae", + "resnet50_beta_vae", "resnet18_vae_bolt", "resnet50_vae_bolt", "resnet18_vqvae", @@ -189,6 +281,7 @@ def resnet152_vqvae_legacy(self): "resnet152_vqvae_legacy", "resnet18_vae_legacy", "resnet50_vae_legacy", + "o2vae", ] from typing import Tuple diff --git a/bioimage_embed/models/o2vae_shapeembed_integration.diff b/bioimage_embed/models/o2vae_shapeembed_integration.diff new file mode 100644 index 00000000..309d7206 --- /dev/null +++ b/bioimage_embed/models/o2vae_shapeembed_integration.diff @@ -0,0 +1,97 @@ +diff --git a/models/align_reconstructions.py b/models/align_reconstructions.py +index d07d1ab..c52b40d 100644 +--- a/models/align_reconstructions.py ++++ b/models/align_reconstructions.py +@@ -6,7 +6,7 @@ import torch + import torchgeometry as tgm + import torchvision.transforms.functional as T_f + +-from registration import registration ++from ..registration import registration + + + def loss_reconstruction_fourier_batch(x, y, recon_loss_type="bce", mask=None): +diff --git a/models/decoders/cnn_decoder.py b/models/decoders/cnn_decoder.py +index ba3a1cc..1740945 100644 +--- a/models/decoders/cnn_decoder.py ++++ b/models/decoders/cnn_decoder.py +@@ -58,7 +58,7 @@ class CnnDecoder(nn.Module): + + self.dec_conv = nn.Sequential(*layers) + +- def forward(self, x): ++ def forward(self, x, epoch = None): + bs = x.size(0) + x = self.fc(x) + dim = x.size(1) +diff --git a/models/encoders_o2/e2scnn.py b/models/encoders_o2/e2scnn.py +index 9c4f47f..e292b1e 100644 +--- a/models/encoders_o2/e2scnn.py ++++ b/models/encoders_o2/e2scnn.py +@@ -219,14 +219,20 @@ class E2SFCNN(torch.nn.Module): + repr += f"\t{i: <3} - {name: <70} | {params: <8} |\n" + return repr + +- def forward(self, input: torch.tensor): ++ def forward(self, input: torch.tensor, epoch = None): ++ #print(f"DEBUG: e2scnn forward: input.shape: {input.shape}") + x = GeometricTensor(input, self.in_repr) ++ #print(f"DEBUG: e2scnn forward: pre layers x.shape: {x.shape}") + + for layer in self.eq_layers: + x = layer(x) + ++ #print(f"DEBUG: e2scnn forward: pre fully_net x.shape: {x.shape}") ++ + x = self.fully_net(x.tensor.reshape(x.tensor.shape[0], -1)) + ++ #print(f"DEBUG: e2scnn forward: pre final x.shape: {x.shape}") ++ + return x + + def build_layer_regular( +diff --git a/models/vae.py b/models/vae.py +index 3af262b..af1a2dc 100644 +--- a/models/vae.py ++++ b/models/vae.py +@@ -3,8 +3,9 @@ import importlib + import numpy as np + import torch + import torchvision ++from pythae.models.base.base_utils import ModelOutput + +-from models import align_reconstructions ++from . import align_reconstructions + + from . import model_utils as mut + +@@ -273,10 +274,11 @@ class VAE(torch.nn.Module): + + return y + +- def forward(self, x): ++ def forward(self, x, epoch = None): ++ x = x["data"] + in_shape = x.shape + bs = in_shape[0] +- assert x.ndim == 4 ++ assert len(in_shape) == 4 + + # inference and sample + z = self.q_net(x) +@@ -290,8 +292,12 @@ class VAE(torch.nn.Module): + y = torch.sigmoid(y) + # check the spatial dimensions are good (if doing multiclass prediction per pixel, the `c` dim may be different) + assert in_shape[-2:] == y.shape[-2:], ( +- "output image different dimension to " +- "input image ... 
probably change the number of layers (cnn_dims) in the decoder" ++ f"output image different dimension {y.shape[-2:]} to " ++ f"input image {in_shape[-2:]} ... probably change the number of layers (cnn_dims) in the decoder" + ) + +- return x, y, mu, logvar ++ # gather losses ++ losses = self.loss(x, y, mu, logvar) ++ ++ return ModelOutput(recon_x=y, z=z_sample, loss=losses['loss'], recon_loss=losses['loss_recon']) ++ #return ModelOutput(recon_x=y, z=z_sample) diff --git a/bioimage_embed/models/pythae/legacy/vq_vae.py b/bioimage_embed/models/pythae/legacy/vq_vae.py index 38a45706..8ddc00c1 100644 --- a/bioimage_embed/models/pythae/legacy/vq_vae.py +++ b/bioimage_embed/models/pythae/legacy/vq_vae.py @@ -132,10 +132,12 @@ def forward(self, x, epoch=None): input=x["data"], ) # This matches how pythae returns the loss + + indices = (encodings == 1).nonzero(as_tuple=True) + recon_loss = F.mse_loss(x_recon, x["data"], reduction="sum") - mse_loss = F.mse_loss(x_recon, x["data"]) + mse_loss = F.mse_loss(x_recon, x["data"], reduction="mean") - indices = (encodings == 1).nonzero(as_tuple=True) variational_loss = loss-mse_loss pythae_loss_dict = { diff --git a/bioimage_embed/shapes/contours.py b/bioimage_embed/shapes/contours.py index fd82c4ba..6845b97f 100644 --- a/bioimage_embed/shapes/contours.py +++ b/bioimage_embed/shapes/contours.py @@ -35,7 +35,7 @@ def cubic_polar_resample_contour(contour: np.array, size: int) -> np.array: def contour_to_xy(contour: np.array): - return contour[0][:, 0], contour[0][:, 1] + return contour[:, 0], contour[:, 1] def uniform_spline_resample_contour(contour: np.array, size: int) -> np.array: diff --git a/bioimage_embed/shapes/lightning.py b/bioimage_embed/shapes/lightning.py index e5ec529e..a9a1e947 100644 --- a/bioimage_embed/shapes/lightning.py +++ b/bioimage_embed/shapes/lightning.py @@ -7,6 +7,7 @@ from torch import nn from ..lightning import LitAutoEncoderTorch from . 
import loss_functions as lf +import pythae from pythae.models.base.base_utils import ModelOutput from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from types import SimpleNamespace @@ -35,24 +36,45 @@ def batch_to_tensor(self, batch): return ModelOutput(data=normalised_data / scalings, scalings=scalings) def loss_function(self, model_output, *args, **kwargs): - loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=True) + loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=False) loss = model_output.loss - loss += torch.sum( + shape_loss = torch.sum( torch.stack( [ loss_ops.diagonal_loss(), loss_ops.symmetry_loss(), - # loss_ops.triangle_inequality(), loss_ops.non_negative_loss(), + # loss_ops.triangle_inequality(), # loss_ops.clockwise_order_loss(), ] ) ) + loss += shape_loss # loss += lf.diagonal_loss(model_output.recon_x) # loss += lf.symmetry_loss(model_output.recon_x) # loss += lf.triangle_inequality_loss(model_output.recon_x) # loss += lf.non_negative_loss(model_output.recon_x) + #return loss + + #variational_loss = model_output.loss - model_output.recon_loss + + metrics = { + "loss": loss, + "shape_loss": shape_loss, + "reconstruction_loss": model_output.recon_loss, + } + if isinstance(self.model, pythae.models.VQVAE): + metrics["vq_loss"] = model_output.vq_loss + if isinstance(self.model, pythae.models.BetaVAE): + metrics['KLD_loss'] = model_output.reg_loss + + self.log_dict( + metrics, + on_epoch=True, + prog_bar=True, + logger=True, + ) return loss diff --git a/bioimage_embed/shapes/mds.py b/bioimage_embed/shapes/mds.py index fdcf2af1..19846375 100644 --- a/bioimage_embed/shapes/mds.py +++ b/bioimage_embed/shapes/mds.py @@ -7,11 +7,12 @@ def mds(d): :return: A matrix of x, y coordinates. """ n = d.size(0) - I = torch.eye(n) + I = torch.eye(n, dtype=torch.float64) H = I - torch.ones((n, n)) / n S = -0.5 * H @ d @ H - eigvals, eigvecs = S.symeig(eigenvectors=True) + #eigvals, eigvecs = S.symeig(eigenvectors=True) + eigvals, eigvecs = torch.linalg.eigh(S) # Sort the eigenvalues and eigenvectors in decreasing order idx = eigvals.argsort(descending=True) diff --git a/bioimage_embed/shapes/transforms.py b/bioimage_embed/shapes/transforms.py index 33535871..2abd401f 100644 --- a/bioimage_embed/shapes/transforms.py +++ b/bioimage_embed/shapes/transforms.py @@ -11,6 +11,7 @@ from sklearn.metrics.pairwise import euclidean_distances from skimage.measure import find_contours import torch +from torch import nn import torch.nn.functional as F from . 
import contours @@ -159,14 +160,25 @@ def __repr__(self): def get_distogram(self, coords, matrix_normalised=False): xii, yii = coords - distance_matrix = euclidean_distances(np.array([xii, yii]).T) - # Fro norm is the same as the L2 norm, but for positive semi-definite matrices + distance_matrix = euclidean_distances(np.array([xii, yii]).T) / ( + np.sqrt(2) * self.size + ) + # TODO size should be shape of matrix and the normalisation should be + # D / (np.linalg.norm(x.shape[-2:])) + + norm = np.linalg.norm(distance_matrix, "fro") if matrix_normalised: return distance_matrix / np.linalg.norm(distance_matrix, "fro") if not matrix_normalised: return distance_matrix / np.linalg.norm([self.size, self.size]) +def find_longest_array(arrays): + lengths = [len(arr.flatten()) for arr in arrays] + max_length_index = np.argmax(lengths) + return arrays[max_length_index] + + class ImageToCoords(torch.nn.Module): def __init__(self, size): super().__init__() @@ -198,7 +210,8 @@ def get_coords_C( return torch.tensor(np.array(coords_list)) def get_coords(self, image, size, method="uniform_spline", contour_level=0.8): - contour = find_contours(np.array(image), contour_level) + contour_list = find_contours(np.array(image), contour_level) + contour = find_longest_array(contour_list) if method == "uniform_spline": return contours.uniform_spline_resample_contour(contour=contour, size=size) if method == "cubic_polar": @@ -365,3 +378,20 @@ def asym_dist_to_sym_dist(self, asymm_dist): sym_dist = np.max(dist_stack, axis=0) return torch.tensor(np.array(sym_dist)) + + +class RotateIndexingClockwise(nn.Module): + def __init__(self, max_rotations=None, p=1.0): + super(RotateIndexingClockwise, self).__init__() + self.max_rotations = max_rotations + self.probability = p + + def forward(self, img): + if np.random.rand() < self.probability: + if self.max_rotations is None: + self.max_rotations = img.shape[0] + num_rotations = np.random.randint(0, self.max_rotations) + img = np.roll( + img.numpy(), shift=[num_rotations, num_rotations], axis=[0, 1] + ) + return torch.from_numpy(img) diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py new file mode 100644 index 00000000..dca082aa --- /dev/null +++ b/bioimage_embed/tests/test_cli.py @@ -0,0 +1,42 @@ +import os +import pytest +from ..hydra import main + +def test_main_creates_config(): + # Arrange + config_path = "test_conf" + job_name = "test_app" + + # Ensure the configuration directory does not exist initially + if os.path.exists(config_path): + os.rmdir(config_path) + + # Act + main(config_path=config_path, job_name=job_name) + + # Assert + assert os.path.exists(config_path), "Config directory was not created" + assert os.path.isfile(os.path.join(config_path, "config.yaml")), "Config file was not created" + + # Clean up + os.remove(os.path.join(config_path, "config.yaml")) + os.rmdir(config_path) + +@pytest.mark.parametrize("config_path, job_name", [ + ("conf", "test_app"), + ("another_conf", "another_job") +]) +def test_hydra_initializes(config_path, job_name): + # Act + main(config_path=config_path, job_name=job_name) + + # Assert + # Here you can assert specifics about the cfg object if needed. + # Since main does not return anything, you might need to adjust + # the main function to return the cfg for more thorough testing. 
+ + # Clean up + if os.path.exists(config_path): + os.remove(os.path.join(config_path, "config.yaml")) + os.rmdir(config_path) + \ No newline at end of file diff --git a/bioimage_embed/tests/test_lightning.py b/bioimage_embed/tests/test_lightning.py index e1e5dc4a..a02ed2ca 100644 --- a/bioimage_embed/tests/test_lightning.py +++ b/bioimage_embed/tests/test_lightning.py @@ -109,7 +109,7 @@ def data(input_dim): @pytest.fixture() def dataset(data): - return data.unsqueeze(0) + return data @pytest.fixture() diff --git a/conf/augmentations/default.yaml b/conf/augmentations/default.yaml deleted file mode 100644 index 3ab17c45..00000000 --- a/conf/augmentations/default.yaml +++ /dev/null @@ -1,70 +0,0 @@ -# __version__: 1.3.0 -# transform: -# __class_fullname__: Compose -# additional_targets: {} -# bbox_params: null -# keypoint_params: null -# p: 1.0 -# transforms: -# - __class_fullname__: OneOf -# p: 0.5 -# transforms: -# - __class_fullname__: HorizontalFlip -# always_apply: false -# p: 0.5 -# - __class_fullname__: VerticalFlip -# always_apply: false -# p: 0.5 -# - __class_fullname__: Rotate -# always_apply: false -# border_mode: 4 -# crop_border: false -# interpolation: 1 -# limit: -# - -45 -# - 45 -# mask_value: null -# p: 0.5 -# rotate_method: largest_box -# value: null -# - __class_fullname__: RandomGamma -# always_apply: false -# eps: null -# gamma_limit: -# - 80 -# - 120 -# p: 0.5 -# - __class_fullname__: ElasticTransform -# alpha: 1 -# alpha_affine: 50 -# always_apply: false -# approximate: false -# border_mode: 4 -# interpolation: 1 -# mask_value: null -# p: 0.5 -# same_dxdy: false -# sigma: 50 -# value: null -# - __class_fullname__: GaussNoise -# always_apply: false -# mean: 0 -# p: 0.5 -# per_channel: true -# var_limit: -# - 10.0 -# - 50.0 -# - __class_fullname__: RandomCrop -# always_apply: false -# height: ${dataset.crop_size[0]} -# p: 1 -# width: ${dataset.crop_size[1]} -# - __class_fullname__: Normalize -# always_apply: true -# p: 1.0 -# transpose_mask: false -# - __class_fullname__: ToTensorV2 -# always_apply: true -# p: 1.0 -# transpose_mask: false - diff --git a/conf/bio_vae/default.yaml b/conf/bio_vae/default.yaml deleted file mode 100644 index 12f762d1..00000000 --- a/conf/bio_vae/default.yaml +++ /dev/null @@ -1,8 +0,0 @@ -_target_: bioimage_embed.models.BioimageEmbed -model: "VQVAE" -input_dim: - - 3 - - 128 - - 128 -latent_dim: 64 -model_config: ${pythae.model_config} diff --git a/conf/checkpoints/default.yaml b/conf/checkpoints/default.yaml deleted file mode 100644 index 76ebb7cf..00000000 --- a/conf/checkpoints/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint -dirpath: ${paths.output_dir} -save_last: True \ No newline at end of file diff --git a/conf/config.yaml b/conf/config.yaml deleted file mode 100644 index d8156ab3..00000000 --- a/conf/config.yaml +++ /dev/null @@ -1,153 +0,0 @@ -defaults: - - _self_ - - trainer: default.yaml - - pythae: default.yaml - # - optimizer: default.yaml - # - scheulder: default.yaml - - timm: default.yaml - - augmentations: default.yaml - # - dataset: default.yaml - - dataloader: default.yaml - - paths: default.yaml - - lightning: default.yaml - - bioimage_embed: default.yaml - - logger: default.yaml - - checkpoints: default.yaml - -version_base: 2.0 - -# seed for random number generators in pytorch, numpy and python.random -seed: 42 - -# name of the run, accessed by loggers -name: null - -trainer: - accelerator: "gpu" - devices: "auto" - gradient_clip_val: 1 - 
accumulate_grad_batches: 16 - min_epochs: 0 - max_epochs: 200 - strategy: "ddp" - profiler: null - fast_dev_run: False - -dataset: - name: "ivy_gap" - # dir: "data" - train_dataset_glob: ${paths.data_dir}/${dataset.name}/random/*png - crop_size: - - 256 - - 256 - -dataloader: - batch_size: 32 - num_workers: 8 - pin_memory: false - shuffle: true - persistent_workers: true - -model: - _target_: bioimage_embed.models.create_model - name: "resnet18_vqvae_legacy" - # Dims match ImageNet - input_dim: [3, 64, 64] - latent_dim: 8 - opt: LAMB - lr: 1.0e-4 - weight_decay: 0.0001 - momentum: 0.9 - sched: cosine - min_lr: 1.0e-6 - warmup_epochs: 5 - warmup_lr: 1.0e-6 - cooldown_epochs: 10 - t_max: 50 - cycle_momentum: false - -# pythae: -# encoder: bioimage_embed.models.ResNet18VAEEncoder -# # _target_: Encoder_ResNet_VQVAE_CELEBA -# decoder: bioimage_embed.models.ResNet18VAEDecoder -# model_config: -# _target_: pythae.models.VAEConfig - -albumentations: - __version__: 1.3.0 - transform: - __class_fullname__: Compose - additional_targets: {} - bbox_params: null - keypoint_params: null - p: 1.0 - transforms: - - __class_fullname__: OneOf - p: 0.5 - transforms: - - __class_fullname__: HorizontalFlip - always_apply: false - p: 0.5 - - __class_fullname__: VerticalFlip - always_apply: false - p: 0.5 - - __class_fullname__: RandomCrop - always_apply: true - height: ${dataset.crop_size[0]} - p: 1 - width: ${dataset.crop_size[1]} - # scale: - # - 1.0 - # - 1.0 - # - __class_fullname__: Rotate - # always_apply: false - # border_mode: 4 - # crop_border: false - # interpolation: 1 - # limit: - # - -45 - # - 45 - # mask_value: null - # p: 0.5 - # rotate_method: largest_box - # value: null - # - __class_fullname__: RandomGamma - # always_apply: false - # eps: null - # gamma_limit: - # - 80 - # - 120 - # p: 0.5 - # - __class_fullname__: ElasticTransform - # alpha: 1 - # alpha_affine: 50 - # always_apply: false - # approximate: false - # border_mode: 4 - # interpolation: 1 - # mask_value: null - # p: 0.5 - # same_dxdy: false - # sigma: 50 - # value: null - # - __class_fullname__: GaussNoise - # always_apply: false - # mean: 0 - # p: 0.5 - # per_channel: true - # var_limit: - # - 10.0 - # - 50.0 - - __class_fullname__: Resize - always_apply: true - height: ${model.input_dim[1]} - p: 1 - width: ${model.input_dim[2]} - - __class_fullname__: ToFloat - always_apply: true - p: 1.0 - max_value: 1.0 - - __class_fullname__: ToTensorV2 - always_apply: true - p: 1.0 - # transpose_mask: false diff --git a/conf/dataloader/default.yaml b/conf/dataloader/default.yaml deleted file mode 100644 index 872861b6..00000000 --- a/conf/dataloader/default.yaml +++ /dev/null @@ -1,7 +0,0 @@ -_target_: bioimage_embed.lightning.DatamoduleGlob -glob_str: ${dataset.train_dataset_glob} -batch_size: 32 -num_workers: 4 -pin_memory: true -shuffle: true -persistent_workers: true \ No newline at end of file diff --git a/conf/dataset/default.yaml b/conf/dataset/default.yaml deleted file mode 100644 index e69de29b..00000000 diff --git a/conf/hydra/default.yaml b/conf/hydra/default.yaml deleted file mode 100644 index 9de8ac12..00000000 --- a/conf/hydra/default.yaml +++ /dev/null @@ -1,14 +0,0 @@ -# https://hydra.cc/docs/configure_hydra/intro/ -# https://github.com/ashleve/lightning-hydra-template/blob/main/configs/hydra/default.yaml - -# enable color logging -defaults: - - override hydra_logging: colorlog - - override job_logging: colorlog - -# output directory, generated dynamically on each run -run: - dir: 
${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} -sweep: - dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} - subdir: ${hydra.job.num} \ No newline at end of file diff --git a/conf/ivy_gap.yaml b/conf/ivy_gap.yaml deleted file mode 100644 index 777faccc..00000000 --- a/conf/ivy_gap.yaml +++ /dev/null @@ -1,103 +0,0 @@ -dataset: "ivy_gap" -data_dir: "data" -train_dataset_glob: f"{data_dir}/{dataset}/random/*png" - -optimizer_params: - opt: LAMB - lr: 0.001 - weight_decay: 0.0001 - momentum: 0.9 - -lr_scheduler_params: - sched: cosine - min_lr: 1.0e-6 - warmup_epochs: 5 - warmup_lr: 1.0e-6 - cooldown_epochs: 10 - t_max: 50 - cycle_momentum: false - -albumentations: - __version__: 1.3.0 - transform: - __class_fullname__: Compose - additional_targets: {} - bbox_params: null - keypoint_params: null - p: 1.0 - transforms: - - __class_fullname__: OneOf - p: 0.5 - transforms: - - __class_fullname__: HorizontalFlip - always_apply: false - p: 0.5 - - __class_fullname__: VerticalFlip - always_apply: false - p: 0.5 - - __class_fullname__: Rotate - always_apply: false - border_mode: 4 - crop_border: false - interpolation: 1 - limit: - - -45 - - 45 - mask_value: null - p: 0.5 - rotate_method: largest_box - value: null - - __class_fullname__: RandomGamma - always_apply: false - eps: null - gamma_limit: - - 80 - - 120 - p: 0.5 - - __class_fullname__: ElasticTransform - alpha: 1 - alpha_affine: 50 - always_apply: false - approximate: false - border_mode: 4 - interpolation: 1 - mask_value: null - p: 0.5 - same_dxdy: false - sigma: 50 - value: null - - __class_fullname__: GaussNoise - always_apply: false - mean: 0 - p: 0.5 - per_channel: true - var_limit: - - 10.0 - - 50.0 - - __class_fullname__: RandomCrop - always_apply: false - height: 128 - p: 1 - width: 128 - - __class_fullname__: RandomBrightnessContrast - always_apply: false - brightness_by_max: true - brightness_limit: - - -0.2 - - 0.2 - contrast_limit: - - -0.2 - - 0.2 - p: 0.5 - - __class_fullname__: Normalize - always_apply: false - max_pixel_value: 255.0 - mean: - - 0.485 - - 0.456 - - 0.406 - p: 1.0 - std: - - 0.229 - - 0.224 - - 0.225 diff --git a/conf/lightning/default.yaml b/conf/lightning/default.yaml deleted file mode 100644 index 6a45b2de..00000000 --- a/conf/lightning/default.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: bioimage_embed.lightning.LitAutoEncoderTorch -model: ${pythae} -args: ${timm} \ No newline at end of file diff --git a/conf/logger/default.yaml b/conf/logger/default.yaml deleted file mode 100644 index 2ad96e8b..00000000 --- a/conf/logger/default.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: pytorch_lightning.loggers.TensorBoardLogger -save_dir: ${paths.log_dir} diff --git a/conf/paths/default.yaml b/conf/paths/default.yaml deleted file mode 100644 index d8738dc1..00000000 --- a/conf/paths/default.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# path to root directory -# this requires PROJECT_ROOT environment variable to exist -# you can replace it with "." if you want the root to be the current working directory -# root_dir: ${oc.env:PROJECT_ROOT} -root_dir: . 
-# path to data directory -data_dir: ${paths.root_dir}/data/ - -# path to logging directory -log_dir: ${paths.root_dir}/logs/ - -# path to output directory, created dynamically by hydra -# path generation pattern is specified in `configs/hydra/default.yaml` -# use it to store all files generated during the run, like ckpts and metrics -output_dir: ${hydra:runtime.output_dir} - -# path to working directory -work_dir: ${hydra:runtime.cwd} \ No newline at end of file diff --git a/conf/pythae/default.yaml b/conf/pythae/default.yaml deleted file mode 100644 index f4c01e7f..00000000 --- a/conf/pythae/default.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# model_name: VQVAE - -# model: -_target_: pythae.models.VAE -# model_config: $(model.model_config) -encoder: - _target_: bioimage_embed.models.ResNet18VAEEncoder - model_config: ${pythae.model_config} -decoder: - _target_: bioimage_embed.models.ResNet18VAEDecoder - model_config: ${pythae.model_config} - -model_config: - _target_: pythae.models.VAEConfig - _convert_: all - input_dim: ${model.input_dim} - latent_dim: ${model.latent_dim} diff --git a/conf/timm/default.yaml b/conf/timm/default.yaml deleted file mode 100644 index 0d61e8c3..00000000 --- a/conf/timm/default.yaml +++ /dev/null @@ -1,15 +0,0 @@ - # _target_: timm.optim.optimizer -opt: LAMB -lr: 1.0e-3 -weight_decay: 0.0001 -momentum: 0.9 -# scheduler: -# _target_: timm.scheduler.scheduler -sched: cosine -min_lr: 1.0e-6 -warmup_epochs: 5 -warmup_lr: 1.0e-6 -cooldown_epochs: 10 -t_max: 50 -cycle_momentum: false -epochs: ${trainer.max_epochs} \ No newline at end of file diff --git a/conf/trainer/default.yaml b/conf/trainer/default.yaml deleted file mode 100644 index 86d4d552..00000000 --- a/conf/trainer/default.yaml +++ /dev/null @@ -1,18 +0,0 @@ -_target_: pytorch_lightning.Trainer - -accelerator: "gpu" -devices: "1" -# weights_summary: null -# progress_bar_refresh_rate: 5 -# resume_from_checkpoint: null -# val_check_interval: 1 -check_val_every_n_epoch: 1 -logger: ${logger} -gradient_clip_val: 1 -enable_checkpointing: True -accumulate_grad_batches: 4 -callbacks: - - ${checkpoints} -min_epochs: 50 -max_epochs: 200 -precision: 32 \ No newline at end of file diff --git a/environment.yml b/environment.yml index 32343b75..cdd9cd54 100644 --- a/environment.yml +++ b/environment.yml @@ -1,19 +1,19 @@ # name: bioimage_embed channels: - - conda-forge - - defaults - - torch - - bioconda +- conda-forge +- defaults +- torch +- bioconda dependencies: - - cudatoolkit-dev=10 - - python=3.9 - - mamba - - poetry - - gcc - - libgcc - - pytorch - - pillow=9.5.0 - - snakemake-minimal - - pip - - pip: - - -e . +- cudatoolkit-dev=10 +- python=3.9 +- mamba +- poetry +- gcc +- libgcc +- pytorch +- pillow=9.5.0 +- pip +- conda-forge::opencv +- pip: + - -e . 
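A quick aside on the `environment.yml` change above: with the `name:` field commented out in the file, conda needs an environment name supplied on the command line. A minimal sketch of creating the environment with the `mamba` listed among the dependencies (the name `bioimage_embed` is an assumption taken from the commented-out `# name:` line):

```bash
# Create the conda environment from environment.yml; -n is needed because
# the name: field is commented out (bioimage_embed is an assumed name).
mamba env create -n bioimage_embed -f environment.yml
conda activate bioimage_embed

# After later edits to environment.yml, bring an existing env back in sync:
mamba env update -n bioimage_embed -f environment.yml
```

The editable `pip: -e .` entry at the end of the file installs the repository itself into the environment, so the `bioimage_embed` package is importable straight away.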
diff --git a/pyproject.toml b/pyproject.toml
index 82fb7df8..23672073 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ scikit-image = "^0.21.0"
 iteround = "^1.0.4"
 ipykernel = "^6.25.1"
 nonechucks = "^0.4.2"
-pythae = "^0.1.1"
+pythae = { git = "https://github.com/clementchadebec/benchmark_VAE.git", branch = "main" }
 pytest = "^7.4.0"
 pandas = "^2.1.0"
 bokeh = "^3.2.2"
diff --git a/scripts/shapeembed/__init__.py b/scripts/shapeembed/__init__.py
new file mode 100644
index 00000000..cd331ee4
--- /dev/null
+++ b/scripts/shapeembed/__init__.py
@@ -0,0 +1,2 @@
+from .dataset_transformations import mask2distmatrix
+from .evaluation import *
diff --git a/scripts/shapeembed/common_helpers.py b/scripts/shapeembed/common_helpers.py
new file mode 100644
index 00000000..fd09a241
--- /dev/null
+++ b/scripts/shapeembed/common_helpers.py
@@ -0,0 +1,42 @@
+import re
+import os
+import glob
+import types
+import logging
+
+def compressed_n_features(dist_mat_size, comp_fact):
+  return dist_mat_size*(dist_mat_size-1)//(2**comp_fact)
+
+def model_str(params):
+  s = f'{params.model_name}'
+  if hasattr(params, 'model_args'):
+    s += f"-{'_'.join([f'{k}{v}' for k, v in vars(params.model_args).items()])}"
+  return s
+
+def job_str(params):
+  return f"{params.dataset.name}-{model_str(params)}-{params.compression_factor}-{params.latent_dim}-{params.batch_size}"
+
+def job_str_re():
+  # use a raw string so the \d escapes are not interpreted by Python
+  return re.compile(r"(.*)-(.*)-(\d+)-(\d+)-(\d+)")
+
+def params_from_job_str(jobstr):
+  raw = jobstr.split('-')
+  ps = types.SimpleNamespace()
+  ps.batch_size = int(raw.pop())
+  ps.latent_dim = int(raw.pop())
+  ps.compression_factor = int(raw.pop())
+  if len(raw) == 3:
+    ps.model_args = types.SimpleNamespace()
+    for p in raw.pop().split('-'):
+      if p[0:4] == 'beta': ps.model_args.beta = float(p[4:])
+  ps.model_name = raw.pop()
+  ps.dataset = types.SimpleNamespace(name=raw.pop())
+  return ps
+
+def find_existing_run_scores(dirname, logger=logging.getLogger(__name__)):
+  ps = []
+  for f in glob.glob(f'{dirname}/*-shapeembed-score_df.csv'):
+    p = params_from_job_str(os.path.basename(f)[:-24])
+    p.csv_file = f
+    ps.append(p)
+  return ps
diff --git a/scripts/shapeembed/dataset_transformations.py b/scripts/shapeembed/dataset_transformations.py
new file mode 100644
index 00000000..8c4c6693
--- /dev/null
+++ b/scripts/shapeembed/dataset_transformations.py
@@ -0,0 +1,216 @@
+import numpy as np
+import imageio.v3 as iio
+import skimage as sk
+from scipy.interpolate import splprep, splev
+import scipy.spatial
+import argparse
+import pathlib
+import types
+import glob
+import os
+import logging
+
+# logging facilities
+###############################################################################
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+# misc helpers
+###############################################################################
+
+def rgb2grey(rgb, cr = 0.2989, cg = 0.5870, cb = 0.1140):
+  """Turn an rgb array into a greyscale array using the following reduction:
+     grey = cr * r + cg * g + cb * b
+
+  :param rgb: The rgb array
+  :param cr: The red coefficient
+  :param cg: The green coefficient
+  :param cb: The blue coefficient
+
+  :returns: The greyscale array.
+  """
+  r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
+  return cr * r + cg * g + cb * b
+
+# API functions
+###############################################################################
+
+def find_longest_contour(mask, normalise_coord=False):
+  """Find all contours existing in 'mask' and return the longest one
+
+  :param mask: The image with masked objects
+  :param normalise_coord (default: False): optionally normalise coordinates
+
+  :returns: the longest contour as a pair of lists for the x and y
+            coordinates
+  """
+  # force the image to grayscale
+  if len(mask.shape) == 3: # (lines, columns, number of channels)
+    mask = rgb2grey(mask)
+  # extract the contours from the now grayscale image
+  contours = sk.measure.find_contours(mask, 0.8)
+  logger.debug(f'find_longest_contour: len(contours) {len(contours)}')
+  # sort the contours by length
+  contours = sorted(contours, key=lambda x: len(x), reverse=True)
+  # isolate the longest contour (first in the sorted list)
+  x, y = contours[0][:, 0], contours[0][:, 1]
+  # optionally normalise the coordinates in the contour
+  if normalise_coord:
+    x = x - np.min(x)
+    x = x / np.max(x)
+    y = y - np.min(y)
+    y = y / np.max(y)
+  # return the contour as a pair of lists of x and y coordinates
+  return x, y
+
+def spline_interpolation(x, y, spline_sampling, raw_sampling_sparsity=1):
+  """Return a resampled spline interpolation of a provided contour
+
+  :param x: The list of x coordinates of a contour
+  :param y: The list of y coordinates of a contour
+  :param spline_sampling: The number of points to sample on the spline
+  :param raw_sampling_sparsity (default=1):
+    The distance (in number of gaps) to the next point to consider in the
+    raw contour (i.e. whether to consider every point, every other point,
+    every 3 points... This can help avoid artifacts due to high point count
+    contours over low pixel resolution images, where the contour effectively
+    curves around individual pixel edges)
+
+  :returns: the resampled spline with spline_sampling points as a pair of
+            lists of x and y coordinates
+  """
+  # Force sparsity to be at least one
+  raw_sampling_sparsity = max(1, raw_sampling_sparsity)
+  logger.debug(f'spline_interpolation: running with raw_sampling_sparsity {raw_sampling_sparsity} and spline_sampling {spline_sampling}')
+  logger.debug(f'spline_interpolation: x.shape {x.shape} y.shape {y.shape}')
+  # prepare the spline interpolation of the given contour
+  tck, u = splprep( [x[::raw_sampling_sparsity], y[::raw_sampling_sparsity]]
+                  , s = 0 # XXX
+                  , per = True # closed contour (periodic spline)
+                  )
+  # how many times to sample the spline
+  # last parameter is how dense is our spline, how many points.
+  new_u = np.linspace(u.min(), u.max(), spline_sampling)
+  # evaluate and return the sampled spline
+  x_spline, y_spline = splev(new_u, tck)
+  return x_spline, y_spline
+
+def build_distance_matrix(x_reinterpolated, y_reinterpolated):
+  """Turn a (reinterpolated) contour into a distance matrix
+
+  :param x_reinterpolated: The list of x coordinates of a contour
+  :param y_reinterpolated: The list of y coordinates of a contour
+
+  :returns: the distance matrix characteristic of the provided contour
+  """
+  # reshape the pair of lists of individual x and y coordinates as a single
+  # numpy array of pairs of (x,y) coordinates
+  reinterpolated_contour = np.column_stack([ x_reinterpolated
+                                           , y_reinterpolated ])
+  # build the distance matrix from the reshaped input data
+  dm = scipy.spatial.distance_matrix( reinterpolated_contour
+                                    , reinterpolated_contour )
+  return dm
+
+def dist_to_coords(dst_mat):
+  """Turn a distance matrix into the corresponding contour
+  XXX
+  TODO sort out exactly the specifics here...
+  """
+  # MDS was previously referenced without being imported; import it locally
+  from sklearn.manifold import MDS
+  embedding = MDS(n_components=2, dissimilarity='precomputed')
+  return embedding.fit_transform(dst_mat)
+
+def mask2distmatrix(mask, matrix_size=512, raw_sampling_sparsity=1):
+  """Get the distance matrix characteristic of the (biggest) object in the
+  provided image
+
+  :param mask: The image with masked objects
+  :param matrix_size (default: 512): the desired matrix size
+  :param raw_sampling_sparsity (default=1):
+    The distance (in number of gaps) to the next point to consider in the
+    raw contour (i.e. whether to consider every point, every other point,
+    every 3 points... This can help avoid artifacts due to high point count
+    contours over low pixel resolution images, where the contour effectively
+    curves around individual pixel edges)
+
+  :returns: the distance matrix characteristic of the (biggest) object in
+            the provided image
+  """
+  logger.debug(f'mask2distmatrix: running with raw_sampling_sparsity {raw_sampling_sparsity} and matrix_size {matrix_size}')
+  # extract mask contour
+  x, y = find_longest_contour(mask, normalise_coord=True)
+  logger.debug(f'mask2distmatrix: found contour shape x {x.shape} y {y.shape}')
+  # Reinterpolate (spline)
+  x_reinterpolated, y_reinterpolated = spline_interpolation(x, y, matrix_size, raw_sampling_sparsity)
+  # Build the distance matrix
+  dm = build_distance_matrix(x_reinterpolated, y_reinterpolated)
+  logger.debug(f'mask2distmatrix: created distance matrix shape {dm.shape}')
+  return dm
+
+def bbox(img):
+  """
+  This function returns the bounding box of the content of an image, where
+  "content" is any non 0-valued pixel. The bounding box is returned as the
+  quadruple ymin, ymax, xmin, xmax.
+
+  Parameters
+  ----------
+  img : 2-d numpy array
+    An image with an object to find the bounding box for. The truth value of
+    object pixels should be True and of non-object pixels should be False.
+
+  Returns
+  -------
+  ymin: int
+    The lowest index row containing object pixels
+  ymax: int
+    The highest index row containing object pixels
+  xmin: int
+    The lowest index column containing object pixels
+  xmax: int
+    The highest index column containing object pixels
+  """
+  rows = np.any(img, axis=1)
+  cols = np.any(img, axis=0)
+  ymin, ymax = np.where(rows)[0][[0, -1]]
+  xmin, xmax = np.where(cols)[0][[0, -1]]
+  return ymin, ymax, xmin, xmax
+
+def recrop_image(img, square=False):
+  """
+  This function returns an image recropped to its content.
+
+  Parameters
+  ----------
+  img : 3-d numpy array
+    A 3-channels (rgb) 2-d image with an object to recrop around. The value of
+    object pixels should be non-zero (and zero for non-object pixels).
+
+  Returns
+  -------
+  3-d numpy array
+    The recropped image
+  """
+
+  ymin, ymax, xmin, xmax = bbox(img)
+  newimg = img[ymin:ymax+1, xmin:xmax+1]
+
+  if square: # slot the new image into a black square
+    dx, dy = xmax+1 - xmin, ymax+1 - ymin
+    dmax = max(dx, dy)
+    dd = max(dx, dy) - min(dx, dy)
+    off = dd // 2
+    res = np.zeros((dmax, dmax, 3)) # big black square
+    if dx < dy: # fewer columns, center horizontally
+      res[:, off:off+newimg.shape[1]] = newimg
+    else: # fewer lines, center vertically
+      res[off:off+newimg.shape[0],:] = newimg
+    return res
+  else:
+    return newimg
diff --git a/scripts/shapeembed/efd.py b/scripts/shapeembed/efd.py
new file mode 100755
index 00000000..9b9525f8
--- /dev/null
+++ b/scripts/shapeembed/efd.py
@@ -0,0 +1,94 @@
+#! /usr/bin/env python3
+
+import os
+import types
+import pyefd
+import random
+import logging
+import argparse
+
+# own imports
+#import bioimage_embed # necessary for the datamodule class to make sure we get the same test set
+from bioimage_embed.shapes.transforms import ImageToCoords
+from evaluation import *
+
+def get_dataset(dataset_params, contour_size):
+  # access the dataset (contour_size is passed explicitly rather than read
+  # from a global defined at the bottom of the script)
+  assert dataset_params.type == 'mask', f'unsupported dataset type {dataset_params.type}'
+  raw_dataset = datasets.ImageFolder( dataset_params.path
+                                    , transform=transforms.Compose([
+                                        transforms.Grayscale(1)
+                                      , ImageToCoords(contour_size) ]))
+  dataset = [x for x in raw_dataset]
+  random.shuffle(dataset)
+  return dataset
+
+def run_elliptic_fourier_descriptors(dataset, contour_size, logger):
+  # run efd on each image
+  dfs = []
+  logger.info(f'running efd on {len(dataset)} images')
+  for i, (img, lbl) in enumerate(tqdm.tqdm(dataset)):
+    coeffs = pyefd.elliptic_fourier_descriptors(img, order=10, normalize=False)
+    norm_coeffs = pyefd.normalize_efd(coeffs)
+    df = pandas.DataFrame({
+      "norm_coeffs": norm_coeffs.flatten().tolist()
+    , "coeffs": coeffs.flatten().tolist()
+    }).T.rename_axis("coeffs")
+    df['class'] = lbl
+    df.set_index("class", inplace=True, append=True)
+    dfs.append(df)
+  # concatenate results as a single dataframe and return it
+  df = pandas.concat(dfs).xs('coeffs', level='coeffs')
+  df.reset_index(level='class', inplace=True)
+  return df
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description='Run efd on a given dataset')
+
+  dflt_dataset=('tiny_synthetic_shapes', '/nfs/research/uhlmann/afoix/datasets/image_datasets/tiny_synthetic_shapes', 'mask')
+  parser.add_argument(
+    '-d', '--dataset', nargs=3, metavar=('NAME', 'PATH', 'TYPE'), default=dflt_dataset
+  , help=f"The NAME, PATH and TYPE of the dataset (default: {dflt_dataset})")
+
+  dflt_contour_size=512
+
+  parser.add_argument(
+    '-o', '--output-dir', metavar='OUTPUT_DIR', default='./'
+  , help=f"The OUTPUT_DIR path to use to dump results")
+
+  parser.add_argument('-v', '--verbose', action='count', default=0
+                     , help="Increase verbosity level by adding more \"v\".")
+
+  # parse command line arguments
+  clargs=parser.parse_args()
+
+  # set verbosity level
+  logging.basicConfig() # attach a default handler so INFO/DEBUG messages show
+  logger = logging.getLogger(__name__)
+  if clargs.verbose > 2:
+    logger.setLevel(logging.DEBUG)
+  elif clargs.verbose > 0:
+    logger.setLevel(logging.INFO)
+
+  # update default params with clargs
+  dataset = types.SimpleNamespace( name=clargs.dataset[0]
+                                 , path=clargs.dataset[1]
+                                 , type=clargs.dataset[2] )
+  contour_size = dflt_contour_size
+
+  # create output dir if it does not exist
+  os.makedirs(clargs.output_dir, exist_ok=True)
+
+  # efd on input data and score
+
+  efd_df = run_elliptic_fourier_descriptors(get_dataset(dataset, contour_size), contour_size, logger)
+
+  logger.info(f'-- efd on {dataset.name}, raw\n{efd_df}')
+  efd_df.to_csv(f"{clargs.output_dir}/{dataset.name}-efd-raw_df.csv")
+  umap_plot(efd_df, f'{dataset.name}-efd', outputdir=clargs.output_dir)
+
+  efd_cm, efd_score_df = score_dataframe(efd_df, 'efd')
+
+  logger.info(f'-- efd on {dataset.name}, score\n{efd_score_df}')
+  efd_score_df.to_csv(f"{clargs.output_dir}/{dataset.name}-efd-score_df.csv")
+  logger.info(f'-- confusion matrix:\n{efd_cm}')
+  confusion_matrix_plot(efd_cm, f'{dataset.name}-efd', clargs.output_dir)
diff --git a/scripts/shapeembed/evaluation.py b/scripts/shapeembed/evaluation.py
new file mode 100644
index 00000000..d530e9f6
--- /dev/null
+++ b/scripts/shapeembed/evaluation.py
@@ -0,0 +1,325 @@
+from torchvision import datasets, transforms
+import pyefd
+from umap import UMAP
+from skimage import measure
+from sklearn.cluster import KMeans
+from sklearn.pipeline import Pipeline
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.preprocessing import StandardScaler
+from sklearn import metrics
+from sklearn.metrics import make_scorer
+from sklearn.metrics import confusion_matrix, accuracy_score
+from sklearn.model_selection import cross_validate, cross_val_predict, KFold, train_test_split, StratifiedKFold
+
+import tqdm
+import numpy
+import pandas
+import logging
+import seaborn
+import matplotlib.pyplot as plt
+
+# logging facilities
+###############################################################################
+logger = logging.getLogger(__name__)
+#logging.basicConfig(level=logging.DEBUG)
+
+def dataloader_to_dataframe(dataloader):
+  # gather the data and the associated labels, and drop rows with NaNs
+  all_data = []
+  all_lbls = []
+  for batch in dataloader:
+    inputs, lbls = batch
+    for data, lbl in zip(inputs, lbls):
+      all_data.append(data.flatten().numpy())
+      all_lbls.append(int(lbl))
+  df = pandas.DataFrame(all_data)
+  df['class'] = all_lbls
+  df = df.dropna() # dropna returns a new dataframe: keep the result
+  return df
+
+def run_kmeans(dataframe, random_seed=42):
+  # run KMeans and derive accuracy metric and confusion matrix
+  # (NB: kmeans assigns arbitrary cluster ids, so the accuracy below is only
+  # meaningful if the cluster ids happen to align with the class labels)
+  kmeans = KMeans( n_clusters=len(dataframe['class'].unique())
+                 , random_state=random_seed
+                 ).fit(dataframe.drop('class', axis=1))
+  accuracy = accuracy_score(dataframe['class'], kmeans.labels_)
+  conf_mat = confusion_matrix(dataframe['class'], kmeans.labels_)
+  return kmeans, accuracy, conf_mat
+
+def score_dataframe( df, name
+                   , tag_columns=[]
+                   , test_sz=0.2, rand_seed=42, shuffle=True, k_folds=5 ):
+  # drop strings and python object columns
+  #clean_df = df.select_dtypes(exclude=['object'])
+  clean_df = df.select_dtypes(include=['number'])
+  # TODO, currently unused
+  # Split the data into training and test sets
+  #X_train, X_test, y_train, y_test = train_test_split(
+  #  clean_df.drop('class', axis=1), clean_df['class']
+  #, stratify=clean_df['class']
+  #, test_size=test_sz, random_state=rand_seed, shuffle=shuffle
+  #)
+  # Define a dictionary of metrics
+  scoring = {
+    "accuracy":
make_scorer(metrics.balanced_accuracy_score) + , "precision": make_scorer(metrics.precision_score, average="macro") + , "recall": make_scorer(metrics.recall_score, average="macro") + , "f1": make_scorer(metrics.f1_score, average="macro") + #, "roc_auc": make_scorer(metrics.roc_auc_score, average="macro") + } + # Create a random forest classifier + pipeline = Pipeline([ + ("scaler", StandardScaler()) + #, ("pca", PCA(n_components=0.95, whiten=True, random_state=rand_seed)) + , ("clf", RandomForestClassifier()) + #, ("clf", DummyClassifier()) + ]) + # build confusion matrix + clean_df.columns = clean_df.columns.astype(str) # only string column names + lbl_pred = cross_val_predict( pipeline + , clean_df.drop('class', axis=1) + , clean_df['class']) + conf_mat = confusion_matrix(clean_df['class'], lbl_pred) + # Perform k-fold cross-validation + cv_results = cross_validate( + estimator=pipeline + , X=clean_df.drop('class', axis=1) + , y=clean_df['class'] + , cv=StratifiedKFold(n_splits=k_folds) + , scoring=scoring + , n_jobs=-1 + , return_train_score=False + ) + # Put the results into a DataFrame + df = pandas.DataFrame(cv_results) + df = df.drop(["fit_time", "score_time"], axis=1) + df.insert(loc=0, column='trial', value=name) + tag_columns.reverse() + for tag_col_name, tag_col_value in tag_columns: + df.insert(loc=0, column=tag_col_name, value=tag_col_value) + return conf_mat, df + +def confusion_matrix_plot( cm, name, outputdir + , figsize=(10,7) ): + # Plot confusion matrix + plt.clf() # Clear figure + plt.figure(figsize=figsize) + seaborn.heatmap(cm, annot=True, fmt='d') + plt.title(f'{name} - Confusion Matrix') + plt.xlabel('Predicted') + plt.ylabel('Actual') + plt.savefig(f'{outputdir}/{name}-confusion_matrix.png') + plt.clf() # Clear figure + +def umap_plot( df + , name + , outputdir='.' + , n_neighbors=15 + , min_dist=0.1 + , n_components=2 + , rand_seed=42 + , split=0.7 + , width=3.45 + , height=3.45 / 1.618 ): + clean_df = df.select_dtypes(include=['number']) + umap_reducer = UMAP( n_neighbors=n_neighbors + , min_dist=min_dist + , n_components=n_components + , random_state=rand_seed ) + mask = numpy.random.rand(clean_df.shape[0]) < split + + #clean_df.reset_index(level='class', inplace=True) + classes = clean_df['class'].copy() + semi_labels = classes.copy() + semi_labels[~mask] = -1 # Assuming -1 indicates unknown label for semi-supervision + clean_df.drop('class', axis=1, inplace=True) + + umap_embedding = umap_reducer.fit_transform(clean_df, y=semi_labels) + umap_data=pandas.DataFrame(umap_embedding, columns=["umap0", "umap1"]) + umap_data['class'] = classes + + ax = seaborn.relplot( data=umap_data + , x="umap0" + , y="umap1" + , hue="class" + , palette="deep" + , alpha=0.5 + , edgecolor=None + , s=5 + , height=height + , aspect=0.5 * width / height ) + + seaborn.move_legend(ax, "upper center") + ax.set(xlabel=None, ylabel=None) + seaborn.despine(left=True, bottom=True) + plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False) + plt.tight_layout() + plt.savefig(f"{outputdir}/{name}-umap.pdf") + plt.close() + +def save_scores( scores_df + , outputdir='.' 
+ , width = 3.45 + , height = 3.45 / 1.618 ): + # save all raw scores as csv + scores_df.to_csv(f"{outputdir}/scores_df.csv") + # save score means as csv + scores_df.groupby("trial").mean().to_csv(f"{outputdir}/scores_df_mean.csv") + # save a barplot representation of scores + melted_df = scores_df.melt( id_vars="trial" + , var_name="Metric" + , value_name="Score" ) + seaborn.catplot( data=melted_df + , kind="bar" + , x="trial" + , hue="Metric" + , y="Score" + , errorbar="se" + , height=height + , aspect=width * 2**0.5 / height ) + plt.savefig(f"{outputdir}/scores_barplot.pdf") + plt.close() + # log info + logger.info(melted_df.set_index(["trial", "Metric"]) + .xs("test_f1", level="Metric", drop_level=False) + .groupby("trial") + .mean()) + +def save_barplot( scores_df + , outputdir='.' + , width = 7 + , height = 7 / 1.2 ): + # save a barplot representation of scores + melted_df = scores_df[['model', 'beta', 'compression_factor', 'latent_dim', 'batch_size', 'test_f1']].melt( + id_vars=['model', 'beta', 'compression_factor', 'latent_dim', 'batch_size'] + , var_name="Metric" + , value_name="Score" + ) + # test plots... + for m in melted_df['model'].unique(): + # 1 - general overview plot... + df = melted_df.loc[ (melted_df['model'] == m) + , ['compression_factor', 'latent_dim', 'batch_size', 'beta', 'Metric', 'Score'] ].sort_values(by=['compression_factor', 'latent_dim', 'batch_size', 'beta']) + hue = df[['compression_factor', 'latent_dim']].apply(lambda r: f'cf: {r.compression_factor}({r.latent_dim})', axis=1) + if 'beta' in m: + hue = df[['compression_factor', 'latent_dim', 'beta']].apply(lambda r: f'cf: {r.compression_factor}({r.latent_dim}), beta: {r.beta}', axis=1) + ax = seaborn.catplot( data=df + , kind="bar" + , x='batch_size' + , y="Score" + , hue=hue + , errorbar="se" + , height=height + , aspect=width * 2**0.5 / height ) + #ax.tick_params(axis='x', rotation=90) + #ax.set(xlabel=None) + #ax.set(xticklabels=[]) + ax._legend.remove() + #ax.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0.0), ncol=3) + #ax.fig.legend(ncol=4, loc='lower center') + ax.fig.legend(ncol=1) + #ax.fig.subplots_adjust(top=0.9) + #ax.set(title=f'f1 score against batch size ({m})') + + #add overall title + plt.title(f'f1 score against batch size ({m})', fontsize=16) + + ##add axis titles + #plt.xlabel('') + #plt.ylabel('') + + #rotate x-axis labels + #plt.xticks(rotation=45) + + plt.savefig(f"{outputdir}/barplot_{m}_x_bs.pdf", bbox_inches="tight") + plt.close() + + # 1b - general overview plot... 
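+    # (same data as the first overview plot, but with the compression factor
+    # on the x axis and the batch size -- plus beta for beta-VAE models --
+    # encoded in the hue)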
+    df = melted_df.loc[ (melted_df['model'] == m)
+                      , ['batch_size', 'compression_factor', 'latent_dim', 'beta', 'Metric', 'Score'] ].sort_values(by=['batch_size', 'compression_factor', 'latent_dim', 'beta'])
+    hue = df['batch_size'].apply(lambda r: f'bs: {r}')
+    if 'beta' in m:
+      hue = df[['batch_size', 'beta']].apply(lambda r: f'bs: {r.batch_size}, beta: {r.beta}', axis=1)
+    ax = seaborn.catplot( data=df
+                        , kind="bar"
+                        , x=df[['compression_factor', 'latent_dim']].apply(lambda r: f'cf: {r.compression_factor}({r.latent_dim})', axis=1)
+                        , y="Score"
+                        , hue=hue
+                        , errorbar="se"
+                        , height=height
+                        , aspect=width * 2**0.5 / height )
+    ax._legend.remove()
+    ax.fig.legend(ncol=1)
+
+    # add overall title
+    plt.title(f'f1 score against compression factor (latent space size) ({m})', fontsize=16)
+
+    plt.savefig(f"{outputdir}/barplot_{m}_x_cf.pdf", bbox_inches="tight")
+    plt.close()
+
+    # 2 - more specific plots
+    for cf in melted_df['compression_factor'].unique():
+      if 'beta' in m:
+        for bs in melted_df['batch_size'].unique():
+          ax = seaborn.catplot( data=melted_df.loc[ (melted_df['model'] == m) & (melted_df['compression_factor'] == cf) & (melted_df['batch_size'] == bs)
+                                                  , ['beta', 'Metric', 'Score'] ]
+                              , kind="bar"
+                              , x='beta'
+                              , hue="Metric"
+                              , y="Score"
+                              , errorbar="se"
+                              , height=height
+                              , aspect=width * 2**0.5 / height )
+          ax.tick_params(axis='x', rotation=90)
+          ax.fig.subplots_adjust(top=0.9)
+          ax.set(title=f'f1 score against beta ({m}, compression factor {cf}, batch size {bs})')
+          plt.savefig(f"{outputdir}/beta_barplot_{m}_{cf}_{bs}.pdf")
+          plt.close()
+      ax = seaborn.catplot( data=melted_df.loc[ (melted_df['model'] == m) & (melted_df['compression_factor'] == cf)
+                                              , ['batch_size', 'beta', 'Metric', 'Score'] ]
+                          , kind="bar"
+                          , x='batch_size'
+                          , hue='beta' if 'beta' in m else 'Metric'
+                          , y="Score"
+                          , errorbar="se"
+                          , height=height
+                          , aspect=width * 2**0.5 / height )
+      ax.tick_params(axis='x', rotation=90)
+      ax.fig.subplots_adjust(top=0.9)
+      ax.set(title=f'f1 score against batch size ({m}, compression factor {cf})')
+      plt.savefig(f"{outputdir}/barplot_{m}_x_bs_cf{cf}.pdf")
+      plt.close()
+    for bs in melted_df['batch_size'].unique():
+      ax = seaborn.catplot( data=melted_df.loc[ (melted_df['model'] == m) & (melted_df['batch_size'] == bs)
+                                              , ['compression_factor', 'beta', 'Metric', 'Score'] ]
+                          , kind="bar"
+                          , x='compression_factor'
+                          , hue='beta' if 'beta' in m else 'Metric'
+                          , y="Score"
+                          , errorbar="se"
+                          , height=height
+                          , aspect=width * 2**0.5 / height )
+      ax.tick_params(axis='x', rotation=90)
+      ax.fig.subplots_adjust(top=0.9)
+      ax.set(title=f'f1 score against compression factor ({m}, batch size {bs})')
+      plt.savefig(f"{outputdir}/barplot_{m}_x_cf_bs{bs}.pdf")
+      plt.close()
+  # log info
+  #logger.info(melted_df.set_index(["trial", "Metric"])
+  #            .xs("test_f1", level="Metric", drop_level=False)
+  #            .groupby("trial")
+  #            .mean())
diff --git a/scripts/shapeembed/gather_run_results.py b/scripts/shapeembed/gather_run_results.py
new file mode 100755
index 00000000..b14ffb58
--- /dev/null
+++ b/scripts/shapeembed/gather_run_results.py
@@ -0,0 +1,280 @@
+#! /usr/bin/env python3
+
+import os
+import re
+import glob
+import shutil
+import logging
+import seaborn
+import argparse
+import datetime
+import functools
+import pandas as pd
+
+from common_helpers import *
+from evaluation import *
+
+def trial_table(df, tname):
+  best_model = df.dropna(subset=['model']).sort_values(by='test_f1', ascending=False).iloc[0]
+  with open(f'{tname}_tabular.tex', 'w') as fp:
+    fp.write("\\begin{tabular}{|l|r|} \\hline\n")
+    fp.write("Trial & F1 score \\\\ \\hline\n")
+    name = best_model['trial'].replace('_', r'\_')
+    fp.write(f"{name} & {best_model['test_f1']} \\\\ \\hline\n")
+    fp.write(f"regionprops & {df[df['trial'] == 'regionprops'].iloc[0]['test_f1']} \\\\ \\hline\n")
+    fp.write(f"efd & {df[df['trial'] == 'efd'].iloc[0]['test_f1']} \\\\ \\hline\n")
+    fp.write("\\end{tabular}\n")
+
+#def simple_table(df, tname, model_re=".*vq.*"):
+def simple_table(df, tname, model_re=".*", sort_by_col=None, ascending=False, best_n=40):
+  cols=['model', 'compression_factor', 'latent_dim', 'batch_size', 'beta', 'test_f1', 'test_f1_std', 'mse/test']
+  df = df.loc[df.model.str.contains(model_re), cols].sort_values(by=cols)
+  if sort_by_col:
+    df = df.sort_values(by=sort_by_col, ascending=ascending)
+    df = df.iloc[:best_n]
+
+  with open(f'{tname}_tabular.tex', 'w') as fp:
+    fp.write("\\begin{tabular}{|llll|r|r|} \\hline\n")
+    fp.write("Model & CF (and latent space size) & batch size & BETA & F1 score & F1 score (std) & MSE \\\\ \\hline\n")
+    for _, r in df.iterrows():
+      mname = r['model'].replace('_', r'\_')
+      beta = '-' if pd.isna(r['beta']) else r['beta']
+      fp.write(f"{mname} & {r['compression_factor']} ({r['latent_dim']}) & {r['batch_size']} & {beta} & {r['test_f1']:f} & {r['test_f1_std']:f} & {r['mse/test']:f} \\\\\n")
+    fp.write("\\hline\n")
+    fp.write("\\end{tabular}\n")
+
+def compare_f1_mse_table(df, tname, best_n=40):
+  cols=['model', 'compression_factor', 'latent_dim', 'batch_size', 'beta', 'test_f1', 'mse/test']
+  df0 = df[cols].sort_values(by=cols)
+  df0 = df0.sort_values(by='test_f1', ascending=False)
+  df0 = df0.iloc[:best_n]
+  df1 = df[cols].sort_values(by=cols)
+  df1 = df1.sort_values(by='mse/test', ascending=True)
+  df1 = df1.iloc[:best_n]
+  df = pd.concat([df0.reset_index(), df1.reset_index()], axis=1, keys=['f1', 'mse'])
+  print(df)
+  with open(f'{tname}_tabular.tex', 'w') as fp:
+    fp.write("\\begin{tabular}{|llll|r|r|llll|r|r|} \\hline\n")
+    fp.write("\\multicolumn{6}{|l}{Best F1 score} & \\multicolumn{6}{|l|}{Best MSE} \\\\\n")
+    fp.write("Model & CF (latent space) & batch size & BETA & F1 score & MSE & Model & CF (latent space) & batch size & BETA & F1 score & MSE \\\\ \\hline\n")
+    for _, r in df.iterrows():
+      f1_name = r[('f1', 'model')].replace('_', r'\_')
+      mse_name = r[('mse', 'model')].replace('_', r'\_')
+      f1_beta = '-' if pd.isna(r[('f1', 'beta')]) else r[('f1', 'beta')]
+      mse_beta = '-' if pd.isna(r[('mse', 'beta')]) else r[('mse', 'beta')]
+      fp.write(f"{f1_name} & {r[('f1', 'compression_factor')]} ({r[('f1', 'latent_dim')]}) & {r[('f1', 'batch_size')]} & {f1_beta} & {r[('f1', 'test_f1')]:f} & {r[('f1', 'mse/test')]:f} & {mse_name} & {r[('mse', 'compression_factor')]} ({r[('mse', 'latent_dim')]}) & {r[('mse', 'batch_size')]} & {mse_beta} & {r[('mse', 'test_f1')]:f} & {r[('mse', 'mse/test')]:f} \\\\\n")
+    fp.write("\\hline\n")
+    fp.write("\\end{tabular}\n")
+
+def main_process(clargs, logger=logging.getLogger(__name__)):
+
+  dfs = []
+
+  # regionprops / efd
+  for dirname in clargs.run_folders:
+    for f in glob.glob(f'{dirname}/*-regionprops-score_df.csv'):
+      dfs.append(pd.read_csv(f, index_col=0))
+    for f in glob.glob(f'{dirname}/*-efd-score_df.csv'):
+      dfs.append(pd.read_csv(f, index_col=0))
+
+  # shapeembed
+  params = []
+  for f in clargs.run_folders:
+    ps = find_existing_run_scores(f)
+    for p in ps: p.folder = f
+    params.append(ps)
+  params = [x for ps in params for x in ps]
+  logger.debug(params)
+
+  os.makedirs(clargs.output_dir, exist_ok=True)
+
+  for p in params:
+
+    # open scores dataframe
+    df = pd.read_csv(p.csv_file, index_col=0)
+
+    # split model column in case model args are present
+    model_cols = df['model'].str.split('-', n=1, expand=True)
+    if model_cols.shape[1] == 2:
+      df = df.drop('model', axis=1)
+      df.insert(1, 'model_args', model_cols[1])
+      df.insert(1, 'model', model_cols[0])
+
+    # pair up with confusion matrix
+    conf_mat_file = f'{job_str(p)}-shapeembed-confusion_matrix.png'
+    logger.debug(f'{p.folder}/{conf_mat_file}')
+    if os.path.isfile(f'{p.folder}/{conf_mat_file}'):
+      shutil.copy(f'{p.folder}/{conf_mat_file}',f'{clargs.output_dir}/{conf_mat_file}')
+      df['conf_mat'] = f'./{conf_mat_file}'
+    else:
+      df['conf_mat'] = 'nofile'
+
+    # pair up with umap
+    umap_file = f'{job_str(p)}-shapeembed-umap.pdf'
+    if os.path.isfile(f'{p.folder}/{umap_file}'):
+      shutil.copy(f'{p.folder}/{umap_file}',f'{clargs.output_dir}/{umap_file}')
+      df['umap'] = f'./{umap_file}'
+    else:
+      df['umap'] = 'nofile'
+
+    # NA desired columns if not already present
+    if 'beta' not in df.keys():
+      df['beta'] = pd.NA
+
+    ## pair up with barplot
+    #barplot = f'scores_barplot.pdf'
+    #if os.path.isfile(f'{d}/{barplot}'):
+    #  shutil.copy(f'{d}/{barplot}',f'{clargs.output_dir}/{run_name}_{barplot}')
+    #  df.loc[df['trial'] == trial, 'barplot'] = f'./{run_name}_{barplot}'
+    #else:
+    #  df.loc[df['trial'] == trial, 'barplot'] = f'nofile'
+
+    # add dataframe to list for future concatenation
+    dfs.append(df.convert_dtypes())
+
+  # gather all dataframes together
+  df = pd.concat(dfs)
+  logger.debug(df)
+  df.to_csv(f'{clargs.output_dir}/all_scores_df.csv', index=False)
+  save_barplot(df.dropna(subset=['model']), clargs.output_dir)
+
+  #df = df.iloc[:, 1:] # drop first column 'unnamed' for non-mean df
+  # custom aggregation keeping the first actual file name in a series
+  # (i.e. skipping over 'nofile' entries)
+  def keep_first_fname(series):
+    return functools.reduce(lambda x, y: y if str(x) == 'nofile' else x, series)
+  idx_cols = ['trial', 'dataset', 'model', 'compression_factor', 'latent_dim', 'batch_size']
+  df.set_index(idx_cols, inplace=True)
+  df.sort_index(inplace=True)
+  df['test_f1_std'] = df['test_f1'].astype(float)
+  df = df.groupby(level=idx_cols, dropna=False).agg({
+    'beta': 'mean'
+  , 'test_accuracy': 'mean'
+  , 'test_precision': 'mean'
+  , 'test_recall': 'mean'
+  , 'test_f1': 'mean'
+  , 'test_f1_std': 'std'
+  , 'mse/test': 'mean'
+  , 'loss/test': 'mean'
+  , 'mse/val': 'mean'
+  , 'loss/val': 'mean'
+  , 'conf_mat': keep_first_fname
+  , 'umap': keep_first_fname
+  #, 'barplot': keep_first_fname
+  })
+
+  print('-'*80)
+  print(df)
+  print('-'*80)
+  df.to_csv(f'{clargs.output_dir}/all_scores_agg_df.csv')
+  df = df.reset_index()
+
+  # table results for f1 and mse comparison
+  simple_table(df, f'{clargs.output_dir}/table_top40_f1', sort_by_col='test_f1')
+  simple_table(df, f'{clargs.output_dir}/table_top40_mse', sort_by_col='mse/test', ascending=True)
+  # temporarily drop regionprops and efd rows for F1 and MSE comparison
+  dff = df[(df['trial'] != 'regionprops') & (df['trial'] != 'efd')]
+  compare_f1_mse_table(dff, f'{clargs.output_dir}/table_top5_compare', best_n=5)
+  if 'regionprops' in df['trial'].values and 'efd' in df['trial'].values:
+    trial_table(df, f'{clargs.output_dir}/trials')
+  else:
+    logger.info('skipped trial table comparison (need both regionprops and efd results)')
+
+  # mse / f1 plots
+  #dff = df[df['mse/test'] ...
+
+  #cell_hover = {  # for row hover use <tr> instead of <td>
+  #  'selector': 'td:hover',
+  #  'props': [('background-color', '#ffffb3')]
+  #  }
+  #index_names = {
+  #  'selector': '.index_name',
+  #  'props': 'font-style: italic; color: darkgrey; font-weight:normal;'
+  #  }
+  #headers = {
+  #  'selector': 'th:not(.index_name)',
+  #  'props': 'background-color: #eeeeee; color: #333333;'
+  #  }
+
+  #def html_img(path):
+  #  if os.path.splitext(path)[1][1:] == 'png':
+  #    return f'<img src="{path}"/>'
+  #  if os.path.splitext(path)[1][1:] == 'pdf':
+  #    return f'<embed src="{path}"/>'
+  #  return ':('
+  #df['conf_mat'] = df['conf_mat'].apply(html_img)
+  #df['umap'] = df['umap'].apply(html_img)
+  #df['barplot'] = df['barplot'].apply(html_img)
+
+  #def render_html(fname, d):
+  #  with open(fname, 'w') as f:
+  #    f.write('''
+  #
+  #
+  #
+  #    ''')
+  #    s = d.style
+  #    s.set_table_styles([cell_hover, index_names, headers])
+  #    s.to_html(f, classes='df')
+  #    f.write('')
+
+  #with open(f'{clargs.output_dir}/gathered_table.tex', 'w') as f:
+  #  f.write('\\documentclass[12pt]{article}\n\\usepackage{booktabs}\n\\usepackage{underscore}\n\\usepackage{multirow}\n\\begin{document}\n')
+  #  df.to_latex(f)
+  #  f.write('\\end{document}')
+  #render_html(f'{clargs.output_dir}/gathered_table.html', df)
+
+  #dft = df.transpose()
+  #with open(f'{clargs.output_dir}/gathered_table_transpose.tex', 'w') as f:
+  #  f.write('\\documentclass[12pt]{article}\n\\usepackage{booktabs}\n\\usepackage{underscore}\n\\usepackage{multirow}\n\\begin{document}\n')
+  #  dft.to_latex(f)
+  #  f.write('\\end{document}')
+  #render_html(f'{clargs.output_dir}/gathered_table_transpose.html', dft)
+
+
+if __name__ == "__main__":
+
+  parser = argparse.ArgumentParser(description='Gather results from shapeembed runs')
+
+  parser.add_argument( 'run_folders', metavar='run_folder', nargs="+", type=str
+                     , help=f"The runs folders to gather results from")
+  parser.add_argument( '-o', '--output-dir', metavar='OUTPUT_DIR'
+                     , default=f'{os.getcwd()}/gathered_results_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
+                     , help=f"The OUTPUT_DIR path to use to gather results")
+  parser.add_argument('-v', '--verbose', action='count', default=0
+                     , help="Increase verbosity level by adding more \"v\".")
+
+  # parse command line arguments
+  clargs=parser.parse_args()
+
+  # set verbosity level
+  logging.basicConfig()
+  logger = logging.getLogger(__name__)
+  if clargs.verbose > 1:
+    logger.setLevel('DEBUG')
+  elif clargs.verbose > 0:
+    logger.setLevel('INFO')
+
+  main_process(clargs, logger)
diff --git a/scripts/shapeembed/readme.md b/scripts/shapeembed/readme.md
new file mode 100644
index 00000000..76bebf92
--- /dev/null
+++ b/scripts/shapeembed/readme.md
@@ -0,0 +1,9 @@
+# Shape Embed
+
+There are currently 3 toplevel scripts:
+
+- shapeembed.py
+- regionprops.py
+- efd.py
+
+Each can be run to generate results, a umap and a confusion matrix. Each has a `-o` option to specify an output directory.
diff --git a/scripts/shapeembed/regionprops.py b/scripts/shapeembed/regionprops.py
new file mode 100755
index 00000000..a2325c86
--- /dev/null
+++ b/scripts/shapeembed/regionprops.py
@@ -0,0 +1,99 @@
+#!
/usr/bin/env python3 + +import os +import types +import random +import logging +import argparse +from skimage import measure + +# own imports +#import bioimage_embed # necessary for the datamodule class to make sure we get the same test set +from evaluation import * + +def get_dataset(dataset_params): + # access the dataset + assert dataset_params.type == 'mask', f'unsupported dataset type {dataset_params.type}' + raw_dataset = datasets.ImageFolder(dataset_params.path, transforms.Grayscale(1)) + dataset = [x for x in raw_dataset] + random.shuffle(dataset) + return dataset + +def run_regionprops( dataset + , properties + , logger ): + # run regionprops for the given properties for each image + dfs = [] + logger.info(f'running regionprops on {dataset}, properties: {properties}') + for i, (img, lbl) in enumerate(tqdm.tqdm(dataset)): + data = numpy.where(numpy.array(img)>20, 255, 0) + t = measure.regionprops_table(data, properties=properties) + df = pandas.DataFrame(t) + assert df.shape[0] == 1, f'More than one object in image #{i}' + df.index = [i] + df['class'] = lbl + #df.set_index("class", inplace=True) + dfs.append(df) + # concatenate results as a single dataframe and return it + df = pandas.concat(dfs) + return df + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Run regionprops on a given dataset') + + dflt_dataset=('tiny_synthetic_shapes', '/nfs/research/uhlmann/afoix/datasets/image_datasets/tiny_synthetic_shapes', 'mask') + parser.add_argument( + '-d', '--dataset', nargs=3, metavar=('NAME', 'PATH', 'TYPE'), default=dflt_dataset + , help=f"The NAME, PATH and TYPE of the dataset (default: {dflt_dataset})") + + dflt_properties=[ "area" + , "perimeter" + , "centroid" + , "major_axis_length" + , "minor_axis_length" + , "orientation" ] + + parser.add_argument( + '-p', '--properties', metavar='PROP', default=dflt_properties, nargs='+' + , help=f"Overwrite the list of properties to consider (default: {dflt_properties})") + + parser.add_argument( + '-o', '--output-dir', metavar='OUTPUT_DIR', default='./' + , help=f"The OUTPUT_DIR path to use to dump results") + + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level + logger = logging.getLogger(__name__) + if clargs.verbose > 2: + logger.setLevel(logging.DEBUG) + elif clargs.verbose > 0: + logger.setLevel(logging.INFO) + + # update default params with clargs + dataset = types.SimpleNamespace( name=clargs.dataset[0] + , path=clargs.dataset[1] + , type=clargs.dataset[2] ) + properties = clargs.properties + + # create output dir if it does not exist + os.makedirs(clargs.output_dir, exist_ok=True) + + # regionprops on input data and score + + regionprops_df = run_regionprops(get_dataset(dataset), properties, logger) + + logger.info(f'-- regionprops on {dataset.name}, raw\n{regionprops_df}') + regionprops_df.to_csv(f"{clargs.output_dir}/{dataset.name}-regionprops-raw_df.csv") + umap_plot(regionprops_df, f'{dataset.name}-regionprops', outputdir=clargs.output_dir) + + regionprops_cm, regionprops_score_df = score_dataframe(regionprops_df, 'regionprops') + + logger.info(f'-- regionprops on {dataset.name}, score\n{regionprops_score_df}') + regionprops_score_df.to_csv(f"{clargs.output_dir}/{dataset.name}-regionprops-score_df.csv") + logger.info(f'-- confusion matrix:\n{regionprops_cm}') + confusion_matrix_plot(regionprops_cm, f'{dataset.name}-regionprops', 
clargs.output_dir) diff --git a/scripts/shapeembed/shapeembed.py b/scripts/shapeembed/shapeembed.py new file mode 100755 index 00000000..9f18a7c0 --- /dev/null +++ b/scripts/shapeembed/shapeembed.py @@ -0,0 +1,548 @@ +#! /usr/bin/env python3 + +# machine learning utils +import torch +from torchvision import datasets, transforms +import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + +# general utils +import os +import re +import copy +import types +import pickle +import base64 +import pandas +import hashlib +import logging +import datetime +import functools + +# own source files +import bioimage_embed +import bioimage_embed.shapes +from dataset_transformations import * +from evaluation import * + +from common_helpers import * + +# logging facilities +############################################################################### +logger = logging.getLogger(__name__) + +# script inputs and parameters +############################################################################### + +# available types of datasets (raw, masks, distance matrix) +dataset_types = [ + "raw_image" +, "mask" +, "distance_matrix" +] + +# available models +models = [ + "resnet18_vae" +, "resnet50_vae" +, "resnet18_beta_vae" +, "resnet50_beta_vae" +, "resnet18_vae_bolt" +, "resnet50_vae_bolt" +, "resnet18_vqvae" +, "resnet50_vqvae" +, "resnet18_vqvae_legacy" +, "resnet50_vqvae_legacy" +, "resnet101_vqvae_legacy" +, "resnet110_vqvae_legacy" +, "resnet152_vqvae_legacy" +, "resnet18_vae_legacy" +, "resnet50_vae_legacy" +, "o2vae" +] + +# set of parameters for a run, with default values +dflt_params = types.SimpleNamespace( + model_name='resnet18_vae' +, dataset=types.SimpleNamespace( + name='tiny_synthetic_shapes' + , path='/nfs/research/uhlmann/afoix/datasets/image_datasets/tiny_synthetic_shapes' + , type='mask' + ) +, batch_size=4 +, compression_factor=2 +, distance_matrix_size=512 +, num_embeddings=1024 +, num_hiddens=1024 +, num_workers=8 +, min_epochs=50 +, max_epochs=150 +, pretrained=False +, frobenius_norm=False +, early_stop=False +, distance_matrix_normalize=True +, distance_matrix_roll_probability=1.0 +, checkpoints_path='./checkpoints' +, commitment_cost=0.25 +, decay=0.99 +# optimizer_params +, opt="AdamW" +, lr=0.001 +, weight_decay=0.0001 +, momentum=0.9 +# lr_scheduler_params +, sched="cosine" +, min_lr=1e-4 +, warmup_epochs=5 +, warmup_lr=1e-6 +, cooldown_epochs=10 +, t_max=50 +, cycle_momentum=False +) + +def tag_cols(params): + cols = [] + cols.append(('dataset', params.dataset.name)) + cols.append(('model', params.model_name)) + for k, v in vars(params.model_args).items(): cols.append((k, v)) + cols.append(('compression_factor', params.compression_factor)) + cols.append(('latent_dim', params.latent_dim)) + cols.append(('batch_size', params.batch_size)) + return cols + +def oom_retry(f, *args, n_oom_retries=1, logger=logging.getLogger(__name__), **kwargs): + try: + logger.info(f'Trying {f.__name__} within oom_retry, n_oom_retries = {n_oom_retries}') + return f(*args, **kwargs) + except RuntimeError as e: + if 'out of memory' in str(e) and n_oom_retries > 0: + logger.warning(f'{f.__name__} ran out of memory, retrying') + torch.cuda.empty_cache() + return oom_retry(f, *args, n_oom_retries=n_oom_retries-1, logger=logger, **kwargs) + else: + raise e + +# dataset loading functions 
+############################################################################### + +def maybe_roll(dist_mat, p = 0.5): + if np.random.rand() < p: + return np.roll(dist_mat, np.random.randint(0, dist_mat.shape[0]), (0,1)) + else: + return dist_mat + +def sanity_check(dist_mat): + if not np.allclose(dist_mat, dist_mat.T): + raise ValueError("Matrix is not symmetric") + if np.any(dist_mat < 0): + raise ValueError("Matrix has negative values") + if np.any(np.diag(dist_mat)): + raise ValueError("Matrix has non-zero diagonal") + return dist_mat + +def get_dataloader(params): + + # transformations / checks to run on distance matrices + ts = [] + if params.distance_matrix_normalize: # optionally normalize the matrix + ts.append(lambda x: x / np.linalg.norm(x, "fro")) + if params.distance_matrix_roll_probability > 0.0: # optionally try to roll the matrix + ts.append(lambda x: maybe_roll(x, p=params.distance_matrix_roll_probability)) + # always check if the matrix is symmetric, positive, and diagonal is zero + ts.append(sanity_check) + # turn (H,W) numpy array into a (H,W) tensor + ts.append(torch.as_tensor) + # turn (H,W) tensor into a (3,H,W) tensor (downstream model expectations) + ts.append(lambda x: x.repeat(3, 1, 1)) + # compose the all the distance matrix transformations + logger.debug(f'transformations to run: {len(ts)}') + distmat_ts = transforms.Compose(ts) + + # dataset to load + logger.info(f'loading dataset {params.dataset.name}') + dataset = None + if params.dataset.type == 'raw_image': # TODO + raise NotImplementedError("raw images not yet supported") + elif params.dataset.type == 'mask': # mask data, convert to distance matrix first + #dataset = datasets.ImageFolder( + # params.dataset.path + #, transforms.Compose([ np.array + # , functools.partial( mask2distmatrix + # , matrix_size=params.distance_matrix_size ) + # , distmat_ts ])) + def f(x): + print(f"DEBUG: shape:{x.shape}") + return x + def g(x): + print(f"-------------") + return x + dataset = datasets.ImageFolder( + params.dataset.path + , transforms.Compose([ np.array + , functools.partial(recrop_image, square=True) + , torch.as_tensor + , lambda x: torch.transpose(x, 0, 2) + , transforms.Resize(64) + , lambda x: torch.transpose(x, 0, 2) + , rgb2grey + #, lambda x: x.repeat(3, 1, 1) + , lambda x: x.repeat(1, 1, 1) + ])) + elif params.dataset.type == 'distance_matrix': # distance matrix data + dataset = datasets.DatasetFolder( params.dataset.path + , loader=np.load + , extensions=('npy') + , transform = distmat_ts ) + assert dataset, f"could not load dataset {params.dataset.name}" + # create the dataloader from the dataset and other parameters + dataloader = bioimage_embed.lightning.DataModule( + dataset + , batch_size=params.batch_size + , shuffle=True + , num_workers=params.num_workers + ) + dataloader.setup() + logger.info(f'dataloader ready') + return dataloader + +# model +############################################################################### + +def get_model(params): + logger.info(f'setup model') + model = bioimage_embed.models.create_model( + model=params.model_name + , input_dim=params.input_dim + , latent_dim=params.latent_dim + , pretrained=params.pretrained + , **vars(params.model_args) + ) + lit_model = bioimage_embed.shapes.MaskEmbed(model, params) + logger.info(f'model ready') + return lit_model + +# trainer +############################################################################### + +def hashing_fn(args): + serialized_args = pickle.dumps(vars(args)) + hash_object = hashlib.sha256(serialized_args) 
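+  # a url-safe base64 rendering of the sha256 digest gives a short,
+  # filesystem-friendly string that uniquely identifies this parameter set
+  # (used by get_trainer below to name the checkpoint directory)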
+ hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode() + return hashed_string + +def get_trainer(model, params): + + # setup WandB logger + logger.info('setup wandb logger') + wandblogger = pl_loggers.WandbLogger(entity=params.wandb_entity, project=params.wandb_project, name=job_str(params)) + wandblogger.watch(model, log="all") + + # setup checkpoints + logger.info('setup checkpoints') + model_dir = f"{params.checkpoints_path}/{hashing_fn(params)}" + os.makedirs(f"{model_dir}/", exist_ok=True) + checkpoint_callback = ModelCheckpoint( + dirpath=f"{model_dir}/" + , save_last=True + , save_top_k=1 + , monitor="loss/val" + , mode="min" + ) + + # setup trainer + logger.info('setup trainer') + trainer_callbacks = [checkpoint_callback] + if params.early_stop: + trainer_callbacks.append(EarlyStopping(monitor="loss/val", mode="min")) + trainer = pl.Trainer( + logger=[wandblogger] + , gradient_clip_val=0.5 + , enable_checkpointing=True + , devices=1 + , accelerator="gpu" + , accumulate_grad_batches=4 + , callbacks=trainer_callbacks + , min_epochs=params.min_epochs + , max_epochs=params.max_epochs + , log_every_n_steps=1 + ) + + logger.info(f'trainer ready') + return trainer + +# train / validate / test the model +############################################################################### + +def train_model(trainer, model, dataloader): + # retrieve the checkpoint information from the trainer and check if a + # checkpoint exists to resume from + checkpoint_callback = trainer.checkpoint_callback + last_checkpoint_path = checkpoint_callback.last_model_path + best_checkpoint_path = checkpoint_callback.best_model_path + if os.path.isfile(last_checkpoint_path): + resume_checkpoint = last_checkpoint_path + elif best_checkpoint_path and os.path.isfile(best_checkpoint_path): + resume_checkpoint = best_checkpoint_path + else: + resume_checkpoint = None + # train the model + logger.info('training the model') + trainer.fit(model, datamodule=dataloader, ckpt_path=resume_checkpoint) + model.eval() + + return model + +def validate_model(trainer, model, dataloader): + logger.info('validating the model') + validation = trainer.validate(model, datamodule=dataloader) + return validation + +def test_model(trainer, model, dataloader): + logger.info('testing the model') + testing = trainer.test(model, datamodule=dataloader) + return testing + +def run_predictions(trainer, model, dataloader, num_workers=8): + + # prepare new unshuffled datamodule + datamod = bioimage_embed.lightning.DataModule( + dataloader.dataset + , batch_size=1 + , shuffle=False + , num_workers=num_workers + ) + datamod.setup() + + # run predictions + predictions = trainer.predict(model, datamodule=datamod) + + # extract latent space + latent_space = torch.stack([d.out.z.flatten() for d in predictions]).numpy() + + # extract class indices and filenames and provide a richer pandas dataframe + ds = datamod.get_dataset() + class_indices = np.array([ int(lbl) + for _, lbl in datamod.predict_dataloader() ]) + fnames = [fname for fname, _ in ds.samples] + df = pandas.DataFrame(latent_space) + df.insert(loc=0, column='fname', value=fnames) + #df.insert(loc=0, column='scale', value=scalings[:,0].squeeze()) + df.insert( loc=0, column='class_name' + , value=[ds.classes[x] for x in class_indices]) + df.insert(loc=0, column='class', value=class_indices) + #df.set_index("class", inplace=True) + df.columns = df.columns.astype(str) # only string column names + + return latent_space, df + +# main process 
+############################################################################### + +def main_process(params): + + # setup + ####### + model = oom_retry(get_model, params) + trainer = oom_retry(get_trainer, model, params) + dataloader = oom_retry(get_dataloader, params) + + # run actual work + ################# + oom_retry(train_model, trainer, model, dataloader, n_oom_retries=2) + oom_retry(validate_model, trainer, model, dataloader) + oom_retry(test_model, trainer, model, dataloader) + + # run predictions + ################# + # ... and gather latent space + os.makedirs(f"{params.output_dir}/", exist_ok=True) + logger.info(f'-- run predictions and extract latent space --') + latent_space, shapeembed_df = run_predictions( + trainer, model, dataloader + , num_workers=params.num_workers + ) + + # gather and log stats + ###################### + logger.debug(f'\n{shapeembed_df}') + pfx=job_str(params) + np.save(f'{params.output_dir}/{pfx}-shapeembed-latent_space.npy', latent_space) + shapeembed_df.to_pickle(f'{params.output_dir}/{pfx}-shapeembed-latent_space.pkl') + shapeembed_df.to_csv(f"{params.output_dir}/{pfx}-shapeembed-raw_df.csv") + logger.info(f'-- generate shapeembed umap --') + umap_plot(shapeembed_df, f'{pfx}-shapeembed', outputdir=params.output_dir) + logger.info(f'-- score shape embed --') + shapeembed_cm, shapeembed_score_df = score_dataframe(shapeembed_df, pfx, tag_cols(params)+[(k, v.item()) for k, v in model.metrics.items()]) + logger.info(f'-- shapeembed on {params.dataset.name}, score\n{shapeembed_score_df}') + shapeembed_score_df.to_csv(f"{params.output_dir}/{pfx}-shapeembed-score_df.csv") + logger.info(f'-- confusion matrix:\n{shapeembed_cm}') + confusion_matrix_plot(shapeembed_cm, f'{pfx}-shapeembed', params.output_dir) + # XXX TODO move somewhere else if desired XXX + ## combined shapeembed + efd + regionprops + #logger.info(f'-- shapeembed + efd + regionprops --') + #comb_df = pandas.concat([ shapeembed_df + # , efd_df.drop('class', axis=1) + # , regionprops_df.drop('class', axis=1) ], axis=1) + #logger.debug(f'\n{comb_df}') + #comb_cm, comb_score_df = score_dataframe(comb_df, 'combined_all') + #logger.info(f'-- shapeembed + efd + regionprops on input data') + #logger.info(f'-- score:\n{comb_score_df}') + #logger.info(f'-- confusion matrix:\n{comb_cm}') + #confusion_matrix_plot(comb_cm, 'combined_all', params.output_dir) + ## XXX Not currently doing the kmeans + ## XXX kmeans on input data and score + ##logger.info(f'-- kmeans on input data --') + ##kmeans, accuracy, conf_mat = run_kmeans(dataloader_to_dataframe(dataloader.predict_dataloader())) + ##print(kmeans) + ##logger.info(f'-- kmeans accuracy: {accuracy}') + ##logger.info(f'-- kmeans confusion matrix:\n{conf_mat}') + + ## collate and save gathered results TODO KMeans + #scores_df = pandas.concat([ regionprops_score_df + # , efd_score_df + # , shapeembed_score_df + # , comb_score_df ]) + #save_scores(scores_df, outputdir=params.output_dir) + +# main entry point +############################################################################### +if __name__ == '__main__': + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError(f"argument must be a positive int. Got {val:d}.") + return val + + def prob (x): + val = float(x) + if val < 0.0 or val > 1.0: + raise argparse.ArgumentTypeError(f"argument must be between 0.0 and 1.0. 
Got {val:f}.") + return val + + parser = argparse.ArgumentParser(description='Run the shape embed pipeline') + + parser.add_argument( + '-m', '--model', choices=models, metavar='MODEL' + , help=f"The MODEL to use, one of {models} (default {dflt_params.model_name}).") + parser.add_argument( + '--model-arg-beta', type=float, metavar='BETA' + , help=f"The BETA parameter to use for a beta-vae model.") + parser.add_argument( + '-d', '--dataset', nargs=3, metavar=('NAME', 'PATH', 'TYPE') + , help=f"The NAME, PATH and TYPE of the dataset (default: {dflt_params.dataset})") + parser.add_argument( + '-o', '--output-dir', metavar='OUTPUT_DIR', default=None + , help=f"The OUTPUT_DIR path to use to dump results") + parser.add_argument( + '--wandb-entity', default="foix", metavar='WANDB_ENTITY' + , help=f"The WANDB_ENTITY name") + parser.add_argument( + '--wandb-project', default="simply-shape", metavar='WANDB_PROJECT' + , help=f"The WANDB_PROJECT name") + parser.add_argument( + '-b', '--batch-size', metavar='BATCH_SIZE', type=auto_pos_int + , help=f"The BATCH_SIZE for the run, a positive integer (default {dflt_params.batch_size})") + parser.add_argument( + '--early-stop', action=argparse.BooleanOptionalAction, default=None + , help=f'Whether to stop training early or not (when loss "stops" decreasing. Beware of second decay...)') + parser.add_argument( + '--distance-matrix-normalize', action=argparse.BooleanOptionalAction, default=None + , help=f'Whether to normalize the distance matrices or not') + parser.add_argument( + '--distance-matrix-roll-probability', metavar='ROLL_PROB', type=prob, default=None + , help=f'Probability to roll the distance matrices along the diagonal (default {dflt_params.distance_matrix_roll_probability})') + parser.add_argument( + '-c', '--compression-factor', metavar='COMPRESSION_FACTOR', type=auto_pos_int + , help=f"The COMPRESSION_FACTOR, a positive integer (default {dflt_params.compression_factor})") + parser.add_argument( + '--distance-matrix-size', metavar='MATRIX_SIZE', type=auto_pos_int + , help=f"The size of the distance matrix (default {dflt_params.distance_matrix_size})") + parser.add_argument( + '--number-embeddings', metavar='NUM_EMBEDDINGS', type=auto_pos_int + , help=f"The NUM_EMBEDDINGS, a positive integer (default {dflt_params.num_embeddings})") + parser.add_argument( + '--number-hiddens', metavar='NUM_HIDDENS', type=auto_pos_int + , help=f"The NUM_HIDDENS, a positive integer (default {dflt_params.num_hiddens})") + parser.add_argument( + '-n', '--num-workers', metavar='NUM_WORKERS', type=auto_pos_int + , help=f"The NUM_WORKERS for the run, a positive integer (default {dflt_params.num_workers})") + parser.add_argument( + '--min-epochs', metavar='MIN_EPOCHS', type=auto_pos_int + , help=f"Set the MIN_EPOCHS for the run, a positive integer (default {dflt_params.min_epochs})") + parser.add_argument( + '--max-epochs', metavar='MAX_EPOCHS', type=auto_pos_int + , help=f"Set the MAX_EPOCHS for the run, a positive integer (default {dflt_params.max_epochs})") + parser.add_argument( + '-e', '--num-epochs', metavar='NUM_EPOCHS', type=auto_pos_int + , help=f"Forces the NUM_EPOCHS for the run, a positive integer (sets both min and max epoch)") + parser.add_argument('--clear-checkpoints', action='store_true' + , help='remove checkpoints') + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level + if clargs.verbose > 2: 
+    logger.setLevel(logging.DEBUG)
+  elif clargs.verbose > 0:
+    logger.setLevel(logging.INFO)
+
+  # update default params with clargs
+  params = copy.deepcopy(dflt_params)
+  if clargs.model:
+    params.model_name = clargs.model
+  params.model_args = types.SimpleNamespace()
+  if clargs.model_arg_beta:
+    params.model_args.beta = clargs.model_arg_beta
+  if clargs.dataset:
+    params.dataset = types.SimpleNamespace( name=clargs.dataset[0]
+                                          , path=clargs.dataset[1]
+                                          , type=clargs.dataset[2] )
+  if clargs.wandb_entity:
+    params.wandb_entity = clargs.wandb_entity
+  if clargs.wandb_project:
+    params.wandb_project = clargs.wandb_project
+  if clargs.batch_size:
+    params.batch_size = clargs.batch_size
+  if clargs.distance_matrix_size:
+    params.distance_matrix_size = clargs.distance_matrix_size
+  # always derive input_dim (the model expects a 3-channel distance matrix)
+  params.input_dim = (3, params.distance_matrix_size, params.distance_matrix_size)
+  if clargs.early_stop is not None:
+    params.early_stop = clargs.early_stop
+  if clargs.distance_matrix_normalize is not None:
+    params.distance_matrix_normalize = clargs.distance_matrix_normalize
+  if clargs.distance_matrix_roll_probability is not None:
+    params.distance_matrix_roll_probability = clargs.distance_matrix_roll_probability
+  if clargs.compression_factor:
+    params.compression_factor = clargs.compression_factor
+  # always derive latent_dim from the (possibly updated) sizes
+  params.latent_dim = compressed_n_features(params.distance_matrix_size, params.compression_factor)
+  if clargs.number_embeddings:
+    params.num_embeddings = clargs.number_embeddings
+  if clargs.number_hiddens:
+    params.num_hiddens = clargs.number_hiddens
+  if clargs.num_workers:
+    params.num_workers = clargs.num_workers
+  if clargs.min_epochs:
+    params.min_epochs = clargs.min_epochs
+  if clargs.max_epochs:
+    params.max_epochs = clargs.max_epochs
+  if clargs.num_epochs:
+    params.min_epochs = clargs.num_epochs
+    params.max_epochs = clargs.num_epochs
+  if clargs.output_dir:
+    params.output_dir = clargs.output_dir
+  else:
+    params.output_dir = f'./{job_str(params)}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
+
+  # trade off a little float32 matmul precision for speed on tensor-core GPUs
+  torch.set_float32_matmul_precision('medium')
+  logger.debug(f'run parameters:\n{params}')
+  main_process(params)
diff --git a/scripts/shapeembed/slurm_sweep_shapeembed.py b/scripts/shapeembed/slurm_sweep_shapeembed.py
new file mode 100755
index 00000000..d04a5d5f
--- /dev/null
+++ b/scripts/shapeembed/slurm_sweep_shapeembed.py
@@ -0,0 +1,221 @@
+#!
/usr/bin/env python3 + +import os +import glob +import copy +import types +import logging +import tempfile +import argparse +import datetime +import itertools +import subprocess + +from common_helpers import * + +# shapeembed parameters to sweep +################################################################################ + +datasets_pfx = '/nfs/research/uhlmann/afoix/datasets/image_datasets' +datasets = [ +# ("synthetic_shapes", f"{datasets_pfx}/synthetic_shapes/", "mask") +# ("tiny_synthcell", f"{datasets_pfx}/tiny_synthcellshapes_dataset/", "mask") +# ("vampire", f"{datasets_pfx}/vampire/torchvision/Control/", "mask") + ("mefs_cells", f"{datasets_pfx}/mefs_single_object_cell/", "mask") +, ("vampire_nuclei", f"{datasets_pfx}/vampire_nuclei/", "mask") +, ("binary_vampire", f"{datasets_pfx}/binary_vampire/", "mask") +, ("bbbc010", f"{datasets_pfx}/bbbc010/BBBC010_v1_foreground_eachworm/", "mask") +, ("synthcell", f"{datasets_pfx}/synthcellshapes_dataset/", "mask") +, ("helakyoto", f"{datasets_pfx}/H2b_10x_MD_exp665/samples/", "mask") +, ("allen", f"{datasets_pfx}/allen_dataset/", "mask") +] + +models = [ + "o2vae" +# "resnet18_vqvae" +#, "resnet50_vqvae" +#, "resnet18_vae" +#, "resnet50_vae" +#, "resnet18_beta_vae" +#, "resnet50_beta_vae" +#, "resnet18_vae_bolt" +#, "resnet50_vae_bolt" +#, "resnet18_vqvae_legacy" +#, "resnet50_vqvae_legacy" +#, "resnet101_vqvae_legacy" +#, "resnet110_vqvae_legacy" +#, "resnet152_vqvae_legacy" +#, "resnet18_vae_legacy" +#, "resnet50_vae_legacy" +] + +model_params = { + #"resnet18_beta_vae": {'beta': [2,5]} +# "resnet18_beta_vae": {'beta': [0.0001]} +#, "resnet50_beta_vae": {'beta': [2,5]} +#, "resnet50_beta_vae": {'beta': [0.00001]} +} + +#compression_factors = [1,2,3,5,10] +compression_factors = [1] + +batch_sizes = [4, 16, 64, 128, 256] + +# XXX XXX XXX XXX XXX XXX XXX # +# XXX ad-hoc one-off config XXX # +# XXX XXX XXX XXX XXX XXX XXX # +# uncomment the lines below for a quick overwrite of the parameter sweep +#datasets = [("synthetic_shapes", f"{datasets_pfx}/synthetic_shapes/", "mask")] +#models = ["resnet50_vae"] +#model_params = {} #{"resnet50_beta_vae": {'beta': [1]}} +#compression_factors = [10] +#batch_sizes = [16] + +def gen_params_sweep_list(): + p_sweep_list = [] + for params in [ { 'dataset': types.SimpleNamespace(name=ds[0], path=ds[1], type=ds[2]) + , 'model_name': m + , 'compression_factor': cf + , 'latent_dim': compressed_n_features(512, cf) + , 'batch_size': bs + } for ds in datasets + for m in models + for cf in compression_factors + for bs in batch_sizes ]: + # per model params: + if params['model_name'] in model_params: + mps = model_params[params['model_name']] + for ps in [dict(zip(mps.keys(), vs)) for vs in itertools.product(*mps.values())]: + newparams = copy.deepcopy(params) + newparams['model_args'] = types.SimpleNamespace(**ps) + p_sweep_list.append(types.SimpleNamespace(**newparams)) + else: + p_sweep_list.append(types.SimpleNamespace(**params)) + return p_sweep_list + +def params_match(x, ys): + found = False + def check_model_args(a, b): + a_yes = hasattr(a, 'model_args') + b_yes = hasattr(b, 'model_args') + if not a_yes and not b_yes: return True + if a_yes and b_yes: return a.model_args == b.model_args + return False + for y in ys: + if x.dataset.name == y.dataset.name \ + and x.model_name == y.model_name \ + and check_model_args(x, y) \ + and x.compression_factor == y.compression_factor \ + and x.latent_dim == y.latent_dim \ + and x.batch_size == y.batch_size: + found = True + break + return found + +def 
find_submitted_slurm_jobs(): + jobs = subprocess.run(['squeue', '--format', '%j'], stdout=subprocess.PIPE).stdout.decode('utf-8').split() + return list(map(params_from_job_str, filter(lambda x: x, map(job_str_re().match, jobs[1:])))) + +# other parameters +################################################################################ + +dflt_slurm_dir=f'{os.getcwd()}/slurm_info_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}' +dflt_out_dir=f'{os.getcwd()}/output_results_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}' + +slurm_time = '50:00:00' +slurm_mem = '80G' +slurm_gpus = 'a100:1' + +shapeembed_script=f'{os.getcwd()}/shapeembed.py' +wandb_project='shapeembed' + +################################################################################ + +def spawn_slurm_job(slurm_out_dir, out_dir, ps, logger=logging.getLogger(__name__)): + + jobname = job_str(ps) + cmd = [ 'python3', shapeembed_script + , '--wandb-project', wandb_project + , '--output-dir', out_dir + ] + cmd += [ '--clear-checkpoints' + , '--no-early-stop' + , '--num-epochs', 150 + ] + cmd += [ '--dataset', ps.dataset.name, ps.dataset.path, ps.dataset.type + , '--model', ps.model_name + , '--compression-factor', ps.compression_factor + , '--batch-size', ps.batch_size + ] + if hasattr(ps, 'model_args'): + for k, v in vars(ps.model_args).items(): + cmd.append(f'--model-arg-{k}') + cmd.append(f'{v}') + logger.debug(" ".join(map(str,cmd))) + with tempfile.NamedTemporaryFile('w+') as fp: + fp.write('#! /usr/bin/env sh\n') + fp.write(" ".join(map(str,cmd))) + fp.write('\n') + fp.flush() + cmd = [ 'sbatch' + , '--time', slurm_time + , '--mem', slurm_mem + , '--job-name', jobname + , '--output', f'{slurm_out_dir}/{jobname}.out' + , '--error', f'{slurm_out_dir}/{jobname}.err' + , f'--gpus={slurm_gpus}' + , fp.name ] + logger.debug(" ".join(map(str,cmd))) + result = subprocess.run(cmd, stdout=subprocess.PIPE) + logger.debug(result.stdout.decode('utf-8')) + logger.info(f'job spawned for {ps}') + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Sweap parameters for shapeembed') + + parser.add_argument( + '-s', '--slurm-output-dir', metavar='SLURM_OUTPUT_DIR', default=dflt_slurm_dir + , help=f"The SLURM_OUTPUT_DIR path to use to dump slurm info") + + parser.add_argument( + '-o', '--output-dir', metavar='OUTPUT_DIR', default=dflt_out_dir + , help=f"The OUTPUT_DIR path to use to dump results") + + parser.add_argument( + '--filter-done', action=argparse.BooleanOptionalAction, default=True + , help=f'filter out jobs with results (a *scores_df.csv) in OUTPUT_DIR') + + parser.add_argument( + '--filter-submitted', action=argparse.BooleanOptionalAction, default=True + , help=f'filter out jobs present in the current slurm `squeue`') + + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level + logging.basicConfig() + logger = logging.getLogger(__name__) + if clargs.verbose > 1: + logger.setLevel('DEBUG') + elif clargs.verbose > 0: + logger.setLevel('INFO') + + os.makedirs(clargs.slurm_output_dir, exist_ok=True) + os.makedirs(clargs.output_dir, exist_ok=True) + + todo_params = gen_params_sweep_list() + + if clargs.filter_done: + done_params = find_existing_run_scores(clargs.output_dir) + todo_params = [x for x in todo_params if not params_match(x, done_params)] + if clargs.filter_submitted: + in_slurm_params = find_submitted_slurm_jobs() + todo_params = 
[x for x in todo_params if not params_match(x, in_slurm_params)] + + for ps in todo_params: + spawn_slurm_job(clargs.slurm_output_dir, clargs.output_dir, ps, logger=logger) diff --git a/scripts/shapes/check_latent_space.py b/scripts/shapes/check_latent_space.py new file mode 100644 index 00000000..6fb085a4 --- /dev/null +++ b/scripts/shapes/check_latent_space.py @@ -0,0 +1,125 @@ +import pandas as pd +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split, cross_validate +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +from sklearn.pipeline import Pipeline +from sklearn import svm +from sklearn.ensemble import GradientBoostingClassifier +from sklearn.metrics import classification_report, confusion_matrix +import umap +import seaborn as sns +import matplotlib.pyplot as plt +import os +from tabulate import tabulate +import json + +pd.set_option('display.max_colwidth', None) + +df = pd.read_csv("clustered_data.csv") + +df.insert(0, 'label', df['fname'].str.extract(r'^(?:[^/]*/){7}([^/]*)').squeeze()) +df.insert(0, 'n_label', df['label'].apply(lambda x: 0 if x == 'alive' else 1)) + +new_df = df.iloc[:, :-4] + +y = new_df.iloc[:, 0] +X = new_df.iloc[:, 2:] + +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y) + +def build_and_evaluate_model(clf, X_train, y_train, X_test, y_test): + model = Pipeline( + [ + ("scaler", StandardScaler()), + ("pca", PCA(n_components=0.95, whiten=True, random_state=42)), + ("clf", clf), + ] + ) + + pipeline = model.fit(X_train, y_train) + + score = pipeline.score(X_test, y_test) + print(f"Classification score: {score}") + + y_pred = pipeline.predict(X_test) + + print("Classification Report:") + print(classification_report(y_test, y_pred)) + + print("Confusion Matrix:") + cm = confusion_matrix(y_test, y_pred) + print(cm) + + # Cross-validation + cv_results = cross_validate(pipeline, X, y, cv=5) + print("Cross-validation results:") + print(cv_results) + + # Plot and save the confusion matrix + plt.figure(figsize=(10,7)) + sns.heatmap(cm, annot=True, fmt='d') + plt.xlabel('Predicted') + plt.ylabel('Truth') + plt.title(f'Confusion Matrix for {clf.__class__.__name__}') + plt.savefig(f'confusion_matrix_{clf.__class__.__name__}.png') + plt.clf() # Clear the current figure + + return score, cm, cv_results + +classifiers = [RandomForestClassifier(), GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=0), svm.SVC()] + +results = [] + +for clf in classifiers: + score, cm, cv_results = build_and_evaluate_model(clf, X_train, y_train, X_test, y_test) + results.append((clf.__class__.__name__, score, cm, cv_results)) + +known_labels = list(y[:50]) +unknown_labels = [-1]*len(y[50:]) +partial_labels = known_labels + unknown_labels + +reducer = umap.UMAP() +embedding = reducer.fit_transform(X, y=partial_labels) + +plt.scatter(embedding[:, 0], embedding[:, 1], c=partial_labels, cmap='Spectral', s=5) +plt.gca().set_aspect('equal', 'datalim') +plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10)) +plt.title('UMAP projection of the dataset', fontsize=24) + +plt.savefig('umap_visualization.png') +plt.clf() # Clear the current figure + +# Generate LaTeX report +with open('final_report.tex', 'w') as f: + f.write("\\documentclass{article}\n\\usepackage{graphicx}\n\\usepackage{longtable}\n\\usepackage{listings}\n\\begin{document}\n") + for name, score, cm, cv_results in results: + 
f.write(f"\\section*{{Results for {name}}}\n") + f.write("\\begin{longtable}{|l|l|}\n") + f.write("\\hline\n") + f.write(f"Classification Score & {score} \\\\\n") + f.write("\\hline\n") + f.write("Confusion Matrix & \\\\\n") + f.write("\\begin{lstlisting}\n") + f.write(np.array2string(cm).replace('\n', ' \\\\\n')) + f.write("\\end{lstlisting}\n") + f.write("\\hline\n") + f.write("Cross-validation Results & \\\\\n") + f.write("\\begin{lstlisting}\n") + cv_results_df = pd.DataFrame(cv_results) + cv_results_df = cv_results_df.applymap(lambda x: x.tolist() if isinstance(x, np.ndarray) else x) + f.write(cv_results_df.to_string().replace('\n', ' \\\\\n')) + f.write("\\end{lstlisting}\n") + f.write("\\hline\n") + f.write("\\end{longtable}\n") + f.write("\\section*{UMAP visualization}\n") + f.write("\\includegraphics[width=\\textwidth]{umap_visualization.png}\n") + f.write("\\end{document}\n") + +os.system('pdflatex final_report.tex') + +# Generate CSV report +report_df = pd.DataFrame(results, columns=['Classifier', 'Score', 'Confusion Matrix', 'Cross-validation Results']) +report_df['Cross-validation Results'] = report_df['Cross-validation Results'].apply(lambda x: pd.DataFrame(x).applymap(lambda y: y.tolist() if isinstance(y, np.ndarray) else y).to_dict()) +report_df.to_csv('final_report.csv', index=False) diff --git a/scripts/shapes/distmatrices2contour.py b/scripts/shapes/distmatrices2contour.py new file mode 100644 index 00000000..23b56bb8 --- /dev/null +++ b/scripts/shapes/distmatrices2contour.py @@ -0,0 +1,73 @@ +import matplotlib.pyplot as plt +from sklearn.manifold import MDS +import numpy as np +import argparse +import pathlib +import types +import glob + +# misc helpers +############################################################################### + +def vprint(tgtlvl, msg, pfx = f"{'':<5}"): + try: + if (tgtlvl <= vprint.lvl): + print(f"{pfx}{msg}") + except AttributeError: + print("verbosity level not set, defaulting to 0") + vprint.lvl = 0 + vprint(tgtlvl, msg) + +def asym_to_sym(asym_dist_mat): + return np.max(np.stack([asym_dist_mat, asym_dist_mat.T]), axis=0) + +def dist_to_coords(dst_mat): + embedding = MDS(n_components=2, dissimilarity='precomputed', normalized_stress='auto') + return embedding.fit_transform(dst_mat) + +def distmatrices2contour(params): + plt.clf() + dm_npys = glob.glob(f'{params.matrices_folder}/orig*.npy') + glob.glob(f'{params.matrices_folder}/recon*.npy') + for dm_npy in dm_npys: + dm = np.load(dm_npy) + vprint(2, f'{dm_npy}: dm.shape={dm.shape}') + dm = asym_to_sym(dm) + p = pathlib.Path(dm_npy) + p = p.with_suffix('.png') + reconstructed_coords = dist_to_coords(dm) + plt.axes().set_aspect('equal') + plt.scatter(*zip(*reconstructed_coords), s=6) + plt.savefig(p) + vprint(2, f'saved {p}') + plt.clf() + +############################################################################### + +params = types.SimpleNamespace(**{ + "matrices_folder": None +}) + +if __name__ == "__main__": + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. 
Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='Turn distance matrices into contours') + + parser.add_argument('matrices_folder', metavar='MATRICES_FOLDER', help=f"The path to the matrices folder") + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level for vprint function + vprint.lvl = clargs.verbose + + # update default params with clargs + params.matrices_folder = clargs.matrices_folder + + distmatrices2contour(params) diff --git a/scripts/shapes/distmatrix2embeding.py b/scripts/shapes/distmatrix2embeding.py new file mode 100644 index 00000000..f23b1d2a --- /dev/null +++ b/scripts/shapes/distmatrix2embeding.py @@ -0,0 +1,626 @@ +import seaborn as sns +import pyefd +from torchvision import datasets, transforms +import pytorch_lightning as pl +import pandas as pd +import numpy as np +import umap +import umap.plot +import bokeh.plotting +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans +import bioimage_embed +import bioimage_embed.shapes +import bioimage_embed.lightning +from bioimage_embed.lightning import DataModule +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +import argparse +import datetime +import pathlib +import torch +import types +import re +import shutil +from pathlib import Path +from sklearn.model_selection import cross_validate, KFold, train_test_split, StratifiedKFold +from sklearn.metrics import make_scorer +from sklearn import metrics +from sklearn.discriminant_analysis import StandardScaler +from sklearn.ensemble import RandomForestClassifier +from sklearn.pipeline import Pipeline +from skimage import measure +from tqdm import tqdm +import logging + +from bioimage_embed.shapes.transforms import ( + ImageToCoords, + CropCentroidPipeline +) + +import pickle +import base64 +import hashlib +import os + +# Seed everything +np.random.seed(42) +pl.seed_everything(42) + +# misc helpers +############################################################################### + +def vprint(tgtlvl, msg, pfx = f"{'':<5}"): + try: + if (tgtlvl <= vprint.lvl): + print(f"{pfx}{msg}") + except AttributeError: + print("verbosity level not set, defaulting to 0") + vprint.lvl = 0 + vprint(tgtlvl, msg) + +def maybe_roll (dist_mat, p = 0.5): + if np.random.rand() < p: + return np.roll(dist_mat, np.random.randint(0, dist_mat.shape[0]), (0,1)) + else: + return dist_mat + +def sanity_check (dist_mat): + if not np.allclose(dist_mat, dist_mat.T): + raise ValueError("Matrix is not symmetric") + if np.any(dist_mat < 0): + raise ValueError("Matrix has negative values") + if np.any(np.diag(dist_mat)): + raise ValueError("Matrix has non-zero diagonal") + return dist_mat + +def hashing_fn(args): + serialized_args = pickle.dumps(vars(args)) + hash_object = hashlib.sha256(serialized_args) + hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode() + return hashed_string + +def scoring_df(X, y): + # Split the data into training and test sets + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y + ) + # Define a dictionary of metrics + scoring = { + "accuracy": make_scorer(metrics.balanced_accuracy_score), + "precision": make_scorer(metrics.precision_score, average="macro"), + "recall": make_scorer(metrics.recall_score, average="macro"), + 
"f1": make_scorer(metrics.f1_score, average="macro"), + #"roc_auc": make_scorer(metrics.roc_auc_score, average="macro") + } + + # Create a random forest classifier + pipeline = Pipeline( + [ + ("scaler", StandardScaler()), + # ("pca", PCA(n_components=0.95, whiten=True, random_state=42)), + ("clf", RandomForestClassifier()), + # ("clf", DummyClassifier()), + ] + ) + + # Specify the number of folds + k_folds = 5 + + # Perform k-fold cross-validation + cv_results = cross_validate( + estimator=pipeline, + X=X, + y=y, + cv=StratifiedKFold(n_splits=k_folds), + scoring=scoring, + n_jobs=-1, + return_train_score=False, + ) + + # Put the results into a DataFrame + return pd.DataFrame(cv_results) + +def create_regionprops_df( dataset + , properties = [ "area" + , "perimeter" + , "centroid" + , "major_axis_length" + , "minor_axis_length" + , "orientation" ] ): + dfs = [] + # Distance matrix data + for i, data in enumerate(tqdm(dataset)): + X, y = data + # Do regionprops here + # Calculate shape summary statistics using regionprops + # We're considering that the mask has only one object, so we take the first element [0] + # props = regionprops(np.array(X).astype(int))[0] + props_table = measure.regionprops_table( + np.array(X).astype(int), properties=properties + ) + + # Store shape properties in a dataframe + df = pd.DataFrame(props_table) + + # Assuming the class or label is contained in 'y' variable + df["class"] = y + df.set_index("class", inplace=True) + dfs.append(df) + + return pd.concat(dfs) + +def create_efd_df(dataset): + dfs = [] + for i, data in enumerate(tqdm(dataset)): + # Convert the tensor to a numpy array + X, y = data + print(f" The image: {i}") + + # Feed it to PyEFD's calculate_efd function + coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False) + # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': [norm_coeffs.flatten().tolist()]}) + + norm_coeffs = pyefd.normalize_efd(coeffs) + df = pd.DataFrame( + { + "norm_coeffs": norm_coeffs.flatten().tolist(), + "coeffs": coeffs.flatten().tolist(), + } + ).T.rename_axis("coeffs") + df["class"] = y + df.set_index("class", inplace=True, append=True) + dfs.append(df) + + return pd.concat(dfs) + +def run_trials( trials, outputdir + , logger = logging.getLogger(__name__) + , width = 3.45 + , height = 3.45 / 1.618 ): + trial_df = pd.DataFrame() + for trial in trials: + X = trial["features"] + y = trial["labels"] + trial["score_df"] = scoring_df(X, y) + trial["score_df"]["trial"] = trial["name"] + logger.info(trial["score_df"]) + trial["score_df"].to_csv(f"{outputdir}/{trial['name']}_score_df.csv") + trial_df = pd.concat([trial_df, trial["score_df"]]) + trial_df = trial_df.drop(["fit_time", "score_time"], axis=1) + + trial_df.to_csv(f"{outputdir}/trial_df.csv") + trial_df.groupby("trial").mean().to_csv(f"{outputdir}/trial_df_mean.csv") + trial_df.plot(kind="bar") + + avg = trial_df.groupby("trial").mean() + logger.info(avg) + avg.to_latex(f"{outputdir}/trial_df.tex") + + melted_df = trial_df.melt(id_vars="trial", var_name="Metric", value_name="Score") + # fig, ax = plt.subplots(figsize=(width, height)) + ax = sns.catplot( + data=melted_df, + kind="bar", + x="trial", + hue="Metric", + y="Score", + errorbar="se", + height=height, + aspect=width * 2**0.5 / height, + ) + # ax.xtick_params(labelrotation=45) + # plt.legend(loc='lower center', bbox_to_anchor=(1, 1)) + # sns.move_legend(ax, "lower center", bbox_to_anchor=(1, 1)) + # ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + # plt.tight_layout() + 
plt.savefig(f"{outputdir}/trials_barplot.pdf") + plt.close() + + avs = ( + melted_df.set_index(["trial", "Metric"]) + .xs("test_f1", level="Metric", drop_level=False) + .groupby("trial") + .mean() + ) + logger.info(avs) + +# Main process +############################################################################### + +def main_process(params): + + # Loading the data (matrices) + ########################################################################### + + preproc_transform = transforms.Compose([ + lambda x: x / np.linalg.norm(x, "fro"), # normalize the matrix + #lambda x: x*1000, # scale the matrix + #lambda x: x / x.max(), # normalize each element to one using the max value (0-1) + lambda x: maybe_roll(x, p = 1.0), # "potentially" roll the matrix + sanity_check, # check if the matrix is symmetric and positive, and the diagonal is zero + torch.as_tensor, # turn (H,W) numpy array into a (H,W) tensor + lambda x: x.repeat(3, 1, 1) # turn (H,W) tensor into a (3,H,W) tensor (to fit downstream model expectations) + ]) + + dataset = datasets.DatasetFolder(params.dataset[1], loader=np.load, extensions=('npy'), transform = preproc_transform) + #dataset = datasets.DatasetFolder(params.dataset[1], loader=lambda x: np.load(x, allow_pickle=True), extensions=('npy'), transform = preproc_transform) + dataloader = bioimage_embed.lightning.DataModule( + dataset, + batch_size=params.batch_size, + shuffle=True, + num_workers=params.num_workers, + ) + dataloader.setup() + vprint(1, f'dataloader ready') + + # Build the model + ########################################################################### + + extra_params = {} + if re.match(".*_beta_vae", params.model): + extra_params['beta'] = params.model_beta_vae_beta + model = bioimage_embed.models.create_model( + model=params.model, + input_dim=params.input_dim, + latent_dim=params.latent_dim, + pretrained=params.pretrained, + **extra_params + ) + lit_model = bioimage_embed.shapes.MaskEmbed(model, params) + vprint(1, f'model ready') + + model_dir = f"checkpoints/{hashing_fn(params)}" + + + if clargs.clear_checkpoints: + print("cleaning checkpoints") + shutil.rmtree("checkpoints/") + model_dir = f"checkpoints/{hashing_fn(params)}" + + # WandB logger + ########################################################################### + jobname = f"{params.model}_{'_'.join([f'{k}{v}' for k, v in extra_params.items()])}_{params.latent_dim}_{params.batch_size}_{params.dataset[0]}" + wandblogger = pl_loggers.WandbLogger(entity=params.wandb_entity, project=params.wandb_project, name=jobname) + wandblogger.watch(lit_model, log="all") + # TODO: Sanity check: + # test_data = dataset[0][0].unsqueeze(0) + # test_output = lit_model.forward((test_data,)) + + # Train the model + ########################################################################### + + Path(f"{model_dir}/").mkdir(parents=True, exist_ok=True) + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{model_dir}/", + save_last=True, + save_top_k=1, + monitor="loss/val", + mode="min", + ) + + trainer = pl.Trainer( + logger=[wandblogger], + gradient_clip_val=0.5, + enable_checkpointing=True, + devices=1, + accelerator="gpu", + accumulate_grad_batches=4, + callbacks=[checkpoint_callback], + min_epochs=50, + max_epochs=params.epochs, + log_every_n_steps=1, + ) + + # Determine the checkpoint path for resuming + last_checkpoint_path = f"{model_dir}/last.ckpt" + best_checkpoint_path = checkpoint_callback.best_model_path + + # Check if a last checkpoint exists to resume from + if os.path.isfile(last_checkpoint_path): 
+        resume_checkpoint = last_checkpoint_path
+    elif best_checkpoint_path and os.path.isfile(best_checkpoint_path):
+        resume_checkpoint = best_checkpoint_path
+    else:
+        resume_checkpoint = None
+
+    trainer.fit(lit_model, datamodule=dataloader, ckpt_path=resume_checkpoint)
+    lit_model.eval()
+    vprint(1, f'trainer fitted')
+
+
+    #TODO: Validate the model
+    ###########################################################################
+    vprint(1, f'Validate the model')
+    validation = trainer.validate(lit_model, datamodule=dataloader)
+
+    #TODO: Test the model
+    ###########################################################################
+    vprint(1, f'Test the model')
+    testing = trainer.test(lit_model, datamodule=dataloader)
+
+    # Inference on full dataset
+    dataloader = DataModule(
+        dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=params.num_workers,
+        # Transform is commented here to avoid augmentations in real data
+        # HOWEVER, applying the transform multiple times and averaging the results might produce better latent embeddings
+        # transform=transform,
+    )
+    dataloader.setup()
+
+    # Predict
+    ###########################################################################
+    predictions = trainer.predict(lit_model, datamodule=dataloader)
+    filenames = [sample[0] for sample in dataloader.get_dataset().samples]
+    class_indices = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])
+
+    #TODO: Pull the embeddings and reconstructed distance matrices
+    ###########################################################################
+    # create the output directory
+    output_dir = params.output_dir
+    if output_dir is None:
+      output_dir = f'./{params.model}_{params.latent_dim}_{params.batch_size}_{params.dataset[0]}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
+    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
+    for class_label in dataset.classes:
+      pathlib.Path(f'{output_dir}/{class_label}').mkdir(parents=True, exist_ok=True)
+    # Save the latent space
+    vprint(1, f'pull the embeddings')
+    latent_space = torch.stack([d.out.z.flatten() for d in predictions]).numpy()
+    scalings = torch.stack([d.x.scalings.flatten() for d in predictions])
+
+    np.save(f'{output_dir}/latent_space.npy', latent_space)
+    df = pd.DataFrame(latent_space)
+    df['class_idx'] = class_indices
+    #df['class'] = [dataset.classes[x] for x in class_indices]
+    df['class'] = pd.Series([dataset.classes[x] for x in class_indices]).astype("category")
+    df['fname'] = filenames
+    #df['scale'] = scalings[:,0].squeeze()
+    df.to_pickle(f'{output_dir}/latent_space.pkl')
+
+    df_shape_embed = df.drop('fname', axis=1).copy()
+    df_shape_embed = df_shape_embed.set_index('class')
+    regionprop_dataset = datasets.ImageFolder('/nfs/research/uhlmann/afoix/image_datasets/tiny_broken_synthetic_shapes/', transform=transforms.Compose([
+    #regionprop_dataset = datasets.ImageFolder('/nfs/research/uhlmann/afoix/image_datasets/synthetic_shapes/', transform=transforms.Compose([
+      transforms.Grayscale(1)
+      #, CropCentroidPipeline(128 * 2)
+    ]))
+    df_regionprops = create_regionprops_df(regionprop_dataset)
+    efd_dataset = datasets.ImageFolder('/nfs/research/uhlmann/afoix/image_datasets/tiny_broken_synthetic_shapes/', transform=transforms.Compose([
+    #efd_dataset = datasets.ImageFolder('/nfs/research/uhlmann/afoix/image_datasets/synthetic_shapes/', transform=transforms.Compose([
+      transforms.Grayscale(1)
+      #, CropCentroidPipeline(128 * 2)
+      , ImageToCoords(128 * 2)
+    ]))
+    print(efd_dataset)
+    df_efd = create_efd_df(efd_dataset)
+
+    # setup trials
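+    # (as a rough smoke test, run_trials can also be exercised on random data;
+    # the following sketch is illustrative only, with made-up names and shapes:
+    #   rng = np.random.default_rng(42)
+    #   run_trials([{ "name": "dummy"
+    #               , "features": rng.normal(size=(40, 8))
+    #               , "labels": np.repeat(["a", "b"], 20) }], output_dir)
+    # each trial below is likewise a name plus a feature matrix and its labels)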
"name": "mask_embed", + "features": df_shape_embed.to_numpy(), + "labels": df_shape_embed.index, + }, + { + "name": "fourier_coeffs", + "features": df_efd.xs("coeffs", level="coeffs"), + "labels": df_efd.xs("coeffs", level="coeffs").index, + }, + # {"name": "fourier_norm_coeffs", + # "features": df_efd.xs("norm_coeffs", level="coeffs"), + # "labels": df_efd.xs("norm_coeffs", level="coeffs").index + # } + { + "name": "regionprops", + "features": df_regionprops, + "labels": df_regionprops.index, + } + ] + + run_trials(trials, output_dir) + + + # Save the (original input and) reconstructions + for i, (pred, class_idx, fname) in enumerate(zip(predictions, class_indices, filenames)): + vprint(5, f'pred#={i}, class_idx={class_idx}, fname={fname}') + class_label = dataset.classes[class_idx] + np.save(f'{output_dir}/{class_label}/original_{i}_{class_label}.npy', pred.x.data[0,0]) + np.save(f'{output_dir}/{class_label}/reconstruction_{i}_{class_label}.npy', pred.out.recon_x[0,0]) + # umap + vprint(4, f'generate umap') + umap_model = umap.UMAP(n_neighbors=50, min_dist=0.8, n_components=2, random_state=42) + mapper = umap_model.fit(df.drop(['class_idx','class','fname'], axis=1)) + umap.plot.points(mapper, labels=np.array(df['class'])) + plt.savefig(f'{output_dir}/umap.png') + #p = umap.plot.interactive(mapper, labels=df['class_idx'], hover_data=df[['class','fname']]) + p = umap.plot.interactive(mapper, values=df.drop(['class_idx','class','fname'], axis=1).mean(axis=1), theme='viridis', hover_data=df[['class','fname']]) + # save interactive plot as html + bokeh.plotting.output_file(f"{output_dir}/umap.html") + bokeh.plotting.save(p) + + # kmean and clustering information + # Perform KMeans clustering on the UMAP result + vprint(4, f'cluster data with kmean') + n_clusters = 4 # Define the number of clusters + kmeans = KMeans(n_clusters=n_clusters, random_state=42) + umap_result = umap_model.fit_transform(latent_space) + cluster_labels = kmeans.fit_predict(umap_result) + + # Concatenate the original data, UMAP result, and cluster labels + data_with_clusters = np.column_stack((latent_space, umap_result, cluster_labels)) + + # Convert to DataFrame for better handling + columns = [f'Feature_{i}' for i in range(latent_space.shape[1])] + \ + ['UMAP_Dimension_1', 'UMAP_Dimension_2', 'Cluster_Label'] + df = pd.DataFrame(data_with_clusters, columns=columns) + df['fname'] = filenames + + df.to_csv(f'{output_dir}/clustered_data.csv', index=False) + + # Plot the UMAP result with cluster labels + plt.figure(figsize=(10, 8)) + for i in range(n_clusters): + plt.scatter(umap_result[cluster_labels == i, 0], umap_result[cluster_labels == i, 1], label=f'Cluster {i+1}', s=5) + plt.title('UMAP Visualization of Latent Space with KMeans Clustering') + plt.xlabel('UMAP Dimension 1') + plt.ylabel('UMAP Dimension 2') + plt.legend() + + # Save the figure + plt.savefig(f'{output_dir}/umap_with_kmeans_clusters.png') + + # Test embeding for a classifcation task + + +# default parameters +############################################################################### + +models = [ + "resnet18_vae" +, "resnet50_vae" +, "resnet18_beta_vae" +, "resnet18_vae_bolt" +, "resnet50_vae_bolt" +, "resnet18_vqvae" +, "resnet50_vqvae" +, "resnet18_vqvae_legacy" +, "resnet50_vqvae_legacy" +, "resnet101_vqvae_legacy" +, "resnet110_vqvae_legacy" +, "resnet152_vqvae_legacy" +, "resnet18_vae_legacy" +, "resnet50_vae_legacy" +] + +#matrix_dim = 512 +matrix_dim = 4 +n = 2 +params = types.SimpleNamespace(**{ + # general params + "model": "resnet18_vae", 
+ "epochs": 150, + "batch_size": 4, + "num_workers": 2**4, + "input_dim": (3, 512, 512), + #"latent_dim": int((matrix_dim**2 - matrix_dim) / 2), + "latent_dim": int((matrix_dim*(matrix_dim-1))/2**n), + "num_embeddings": 1024, + "num_hiddens": 1024, + "pretrained": True, + "commitment_cost": 0.25, + "decay": 0.99, + "frobenius_norm": False, + "dataset": ("tiny_dist", "/nfs/research/uhlmann/afoix/distmat_datasets/tiny_synthcellshapes_dataset_distmat"), + # model-specific params + "model_beta_vae_beta": 1, + # optimizer_params + "opt": "AdamW", + "lr": 0.001, + "weight_decay": 0.0001, + "momentum": 0.9, + # lr_scheduler_params + "sched": "cosine", + "min_lr": 1e-4, + "warmup_epochs": 5, + "warmup_lr": 1e-6, + "cooldown_epochs": 10, + "t_max": 50, + "cycle_momentum": False, +}) + +############################################################################### + +if __name__ == "__main__": + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='Run the shape embed pipeline') + + parser.add_argument( + '-m', '--model', choices=models, metavar='MODEL' + , help=f"The MODEL to use, one of {models} (default {params.model}).") + parser.add_argument( + '--model-beta-vae-beta', type=float, metavar='BETA' + , help=f"The BETA parameter to use for a beta-vae model.") + parser.add_argument( + '-d', '--dataset', nargs=2, metavar=('NAME', 'PATH') + , help=f"The NAME of and PATH to the dataset (default: {params.dataset})") + parser.add_argument( + '-o', '--output-dir', metavar='OUTPUT_DIR', default=None + , help=f"The OUTPUT_DIR path to use to dump results") + parser.add_argument( + '--wandb-entity', default="foix", metavar='WANDB_ENTITY' + , help=f"The WANDB_ENTITY name") + parser.add_argument( + '--wandb-project', default="simply-shape", metavar='WANDB_PROJECT' + , help=f"The WANDB_PROJECT name") + parser.add_argument( + '-b', '--batch-size', metavar='BATCH_SIZE', type=auto_pos_int + , help=f"The BATCH_SIZE for the run, a positive integer (default {params.batch_size})") + parser.add_argument( + '-l', '--latent-space-size', metavar='LATENT_SPACE_SIZE', type=auto_pos_int + , help=f"The LATENT_SPACE_SIZE, a positive integer (default {params.latent_dim})") + parser.add_argument( + '--input-dimensions', metavar='INPUT_DIM', nargs=2, type=auto_pos_int + , help=f"The width and height INPUT_DIM for the input dimensions (default {params.input_dim[1]} and {params.input_dim[2]})") + parser.add_argument( + '--number-embeddings', metavar='NUM_EMBEDDINGS', type=auto_pos_int + , help=f"The NUM_EMBEDDINGS, a positive integer (default {params.num_embeddings})") + parser.add_argument( + '--number-hiddens', metavar='NUM_HIDDENS', type=auto_pos_int + , help=f"The NUM_HIDDENS, a positive integer (default {params.num_hiddens})") + parser.add_argument( + '-n', '--num-workers', metavar='NUM_WORKERS', type=auto_pos_int + , help=f"The NUM_WORKERS for the run, a positive integer (default {params.num_workers})") + parser.add_argument( + '-e', '--num-epochs', metavar='NUM_EPOCHS', type=auto_pos_int + , help=f"The NUM_EPOCHS for the run, a positive integer (default {params.epochs})") + parser.add_argument('--clear-checkpoints', action='store_true' + , help='remove checkpoints') + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set 
verbosity level for vprint function + vprint.lvl = clargs.verbose + + # update default params with clargs + if clargs.model: + params.model = clargs.model + if clargs.model_beta_vae_beta: + params.model_beta_vae_beta = clargs.model_beta_vae_beta + params.output_dir = clargs.output_dir + if clargs.dataset: + params.dataset = clargs.dataset + if clargs.wandb_entity: + params.wandb_entity = clargs.wandb_entity + if clargs.wandb_project: + params.wandb_project = clargs.wandb_project + if clargs.batch_size: + params.batch_size = clargs.batch_size + if clargs.latent_space_size: + params.latent_dim = clargs.latent_space_size + if clargs.input_dimensions: + params.input_dim = (params.input_dim[0], clargs.input_dimensions[0], clargs.input_dimensions[1]) + if clargs.number_embeddings: + params.num_embeddings = clargs.number_embeddings + if clargs.number_hiddens: + params.num_hiddens = clargs.number_hiddens + if clargs.num_workers: + params.num_workers = clargs.num_workers + if clargs.num_epochs: + params.epochs = clargs.num_epochs + + logging.basicConfig(level=logging.INFO) + # run main process + main_process(params) \ No newline at end of file diff --git a/scripts/shapes/drawContourFromDM.py b/scripts/shapes/drawContourFromDM.py new file mode 100644 index 00000000..fde5172f --- /dev/null +++ b/scripts/shapes/drawContourFromDM.py @@ -0,0 +1,74 @@ + +import matplotlib.pyplot as plt +from sklearn.manifold import MDS +import numpy as np +import argparse +import pathlib +import types +import glob + +# misc helpers +############################################################################### + +def vprint(tgtlvl, msg, pfx = f"{'':<5}"): + try: + if (tgtlvl <= vprint.lvl): + print(f"{pfx}{msg}") + except AttributeError: + print("verbosity level not set, defaulting to 0") + vprint.lvl = 0 + vprint(tgtlvl, msg) + +def asym_to_sym(asym_dist_mat): + return np.max(np.stack([asym_dist_mat, asym_dist_mat.T]), axis=0) + +def dist_to_coords(dst_mat): + embedding = MDS(n_components=2, dissimilarity='precomputed', normalized_stress='auto') + return embedding.fit_transform(dst_mat) + +def distmatrices2contour(params): + plt.clf() + dm_npys = glob.glob(f'{params.matrices_folder}/*.npy') + for dm_npy in dm_npys: + dm = np.load(dm_npy) + vprint(2, f'{dm_npy}: dm.shape={dm.shape}') + dm = asym_to_sym(dm) + p = pathlib.Path(dm_npy) + p = p.with_suffix('.png') + reconstructed_coords = dist_to_coords(dm) + plt.axes().set_aspect('equal') + plt.scatter(*zip(*reconstructed_coords), s=6) + plt.savefig(p) + vprint(2, f'saved {p}') + plt.clf() + +############################################################################### + +params = types.SimpleNamespace(**{ + "matrices_folder": None +}) + +if __name__ == "__main__": + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. 
Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='Turn distance matrices into contours') + + parser.add_argument('matrices_folder', metavar='MATRICES_FOLDER', help=f"The path to the matrices folder") + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level for vprint function + vprint.lvl = clargs.verbose + + # update default params with clargs + params.matrices_folder = clargs.matrices_folder + + distmatrices2contour(params) diff --git a/scripts/shapes/genUMAPs.py b/scripts/shapes/genUMAPs.py new file mode 100755 index 00000000..265f0525 --- /dev/null +++ b/scripts/shapes/genUMAPs.py @@ -0,0 +1,141 @@ +#! /usr/bin/env python3 + +import os +import os.path +import pandas as pd +import numpy as np +import umap +import umap.plot +import matplotlib.pyplot as plt +import bokeh.plotting +import argparse +import datetime +import pathlib +import multiprocessing +import subprocess + +# Seed everything +np.random.seed(42) + +# misc helpers +############################################################################### + +def vprint(tgtlvl, msg, pfx = f"{'':<5}"): + try: + if (tgtlvl <= vprint.lvl): + print(f"{pfx}{msg}") + except AttributeError: + print("verbosity level not set, defaulting to 0") + vprint.lvl = 0 + vprint(tgtlvl, msg) + +# render UMAPS +def render_umap_core(df, output_dir, n_neighbors, min_dist, n_components): + name = f'umap_{n_neighbors}_{min_dist}_{n_components}' + vprint(4, f'generate {name}') + vprint(5, f'n_neigbhors: {type(n_neighbors)} {n_neighbors}') + vprint(5, f'min_dist: {type(min_dist)} {min_dist}') + vprint(5, f'n_components: {type(n_components)} {n_components}') + umap_model = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=42) + mapper = umap_model.fit(df.drop(['class_idx','class','fname'], axis=1)) + umap.plot.points(mapper, labels=np.array(df['class'])) + plt.savefig(f'{output_dir}/{name}.png') + theme_values = df.drop(['class_idx','class','fname'], axis=1).mean(axis=1) + vprint(5, f'theme_values type: {type(theme_values)}') + if True: #temporary condition to work ONLY with the tree dataset + theme_values = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), df['fname'])) + vprint(5, f'new theme_values type: {type(theme_values)}') + vprint(5, f'theme_values: {theme_values}') + #p = umap.plot.interactive(mapper, labels=df['class_idx'], hover_data=df[['class','fname']]) + #p = umap.plot.interactive(mapper, values=df.drop(['class_idx','class','fname'], axis=1).mean(axis=1), theme='viridis', hover_data=df[['class','fname']]) + p = umap.plot.interactive(mapper, values=theme_values, theme='viridis', hover_data=df[['class','fname']]) + # save interactive plot as html + bokeh.plotting.output_file(f"{output_dir}/{name}.html") + bokeh.plotting.save(p) + +def render_umap(latent_space_pkl, output_dir, n_neighbors, min_dist, n_components): + # create output directory if it does not already exist + os.makedirs(output_dir, exist_ok=True) + # load latent space + df = pd.read_pickle(latent_space_pkl) + # render umap + render_umap_core(df, output_dir, n_neighbors, min_dist, n_components) + +############################################################################### + +if __name__ == "__main__": + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. 
Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='generate umaps') + + parser.add_argument('latent_space', metavar='LATENT_SPACE', type=os.path.abspath + , help=f"The path to the latent space") + parser.add_argument('-j', '--n_jobs', type=auto_pos_int, default=2*os.cpu_count() + , help="number of jobs to start. Default is 2x the number of CPUs.") + parser.add_argument('--slurm', action=argparse.BooleanOptionalAction) + parser.add_argument('-n', '--n_neighbors', nargs='+', type=auto_pos_int, default=[50] + , help="A list of the number of neighbors to use in UMAP. Default is [50].") + parser.add_argument('-m', '--min_dist', nargs='+', type=float, default=[0.8] + , help="A list of the minimum distances to use in UMAP. Default is [0.8].") + parser.add_argument('-c', '--n_components', nargs='+', type=auto_pos_int, default=[2] + , help="A list of the number of components to use in UMAP. Default is [2].") + parser.add_argument( '-o', '--output-dir', metavar='OUTPUT_DIR', default=f'{os.getcwd()}/umaps' + , help=f"The OUTPUT_DIR path to use to dump results") + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level for vprint function + vprint.lvl = clargs.verbose + + #for x,y,z in [(x, y, z) for x in clargs.n_neighbors + # for y in clargs.min_dist + # for z in clargs.n_components]: + # render_umap(df, x, y, z) + + params=[(x, y, z) for x in clargs.n_neighbors + for y in clargs.min_dist + for z in clargs.n_components] + if clargs.slurm: + vprint(2, f'running with slurm') + for (n_neighbors, min_dist, n_components) in params: + vprint(3, f'running with n_neighbors={n_neighbors}, min_dist={min_dist}, n_components={n_components}') + print('Directory Name: ', os.path.dirname(__file__)) + + cmd = [ "srun" + , "-t", "50:00:00" + , "--mem=200G" + , "--gpus=a100:1" + , "--job-name", f"render_umap_{n_neighbors}_{min_dist}_{n_components}" + , "--pty" + , "python3", "-c" + , f""" +import sys +sys.path.insert(1, '{os.path.dirname(__file__)}') +import genUMAPs +genUMAPs.vprint.lvl = {clargs.verbose} +genUMAPs.render_umap('{clargs.latent_space}','{clargs.output_dir}',{n_neighbors},{min_dist},{n_components}) +"""] + vprint(4, cmd) + subprocess.run(cmd) + + else: + vprint(2, f'running with python multiprocessing') + + # create output directory if it does not already exist + os.makedirs(clargs.output_dir, exist_ok=True) + + # load latent space + df = pd.read_pickle(clargs.latent_space) + + def render_umap_wrapper(args): + render_umap(df, clargs.output_dir, *args) + with multiprocessing.Pool(clargs.n_jobs) as pool: + pool.starmap(render_umap_wrapper, params) \ No newline at end of file diff --git a/scripts/shapes/masks2distmatrices.py b/scripts/shapes/masks2distmatrices.py new file mode 100644 index 00000000..c6af9ae8 --- /dev/null +++ b/scripts/shapes/masks2distmatrices.py @@ -0,0 +1,353 @@ +import numpy as np +import imageio.v3 as iio +import skimage as sk +from scipy.interpolate import splprep, splev +import scipy.spatial +import argparse +import pathlib +import types +import glob +import os + +# misc helpers +############################################################################### + +def vprint(tgtlvl, msg, pfx = f"{'':<5}"): + try: + if (tgtlvl <= vprint.lvl): + print(f"{pfx}{msg}") + except AttributeError: + print("verbosity level not set, defaulting to 0") + vprint.lvl = 0 + vprint(tgtlvl, msg) + +def 
+def rgb2grey(rgb, cr = 0.2989, cg = 0.5870, cb = 0.1140):
+  """Turn an rgb array into a greyscale array using the following reduction:
+     grey = cr * r + cg * g + cb * b
+
+  :param rgb: The rgb array
+  :param cr: The red coefficient
+  :param cg: The green coefficient
+  :param cb: The blue coefficient
+
+  :returns: The greyscale array.
+  """
+  r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
+  return cr * r + cg * g + cb * b
+
+##########################################################################
+####### Simplified version, to make things work properly ################
+##########################################################################
+
+def find_longest_contour(mask, normalise_coord=False):
+  if len(mask.shape) == 3: # (lines, columns, number of channels)
+    mask = rgb2grey(mask)
+  contours = sk.measure.find_contours(mask, 0.8)
+  vprint(4, f'len(contours) {len(contours)}')
+  # keep only the longest contour found in the mask
+  contours = sorted(contours, key=lambda x: len(x), reverse=True)
+  x, y = contours[0][:, 0], contours[0][:, 1]
+  if normalise_coord:
+    x = x - np.min(x)
+    x = x / np.max(x)
+    y = y - np.min(y)
+    y = y / np.max(y)
+  return x, y
+
+def spline_interpolation(x, y, raw_sampling_sparsity, spline_sampling):
+  # Sparsity of the contour: drop some of the samples (points) to make the spline smoother
+  raw_sampling_sparsity = max(1, raw_sampling_sparsity)
+  vprint(3, f'running with raw_sampling_sparsity {raw_sampling_sparsity} and spline_sampling {spline_sampling}')
+  vprint(3, f'x.shape {x.shape} y.shape {y.shape}')
+  tck, u = splprep([x[::raw_sampling_sparsity], y[::raw_sampling_sparsity]], s = 0, per = True)
+  # How many times to sample the spline
+  new_u = np.linspace(u.min(), u.max(), spline_sampling) # the last parameter sets how densely the spline is sampled
+  # Evaluate the spline
+  x_spline, y_spline = splev(new_u, tck)
+  return x_spline, y_spline
+
+def build_distance_matrix(x_reinterpolated, y_reinterpolated):
+  reinterpolated_contour = np.column_stack([x_reinterpolated, y_reinterpolated])
+  dm = scipy.spatial.distance_matrix(reinterpolated_contour, reinterpolated_contour)
+  return dm
+
+def dist_to_coords(dst_mat):
+  embedding = MDS(n_components=2, dissimilarity='precomputed')
+  return embedding.fit_transform(dst_mat)
+
+def mask2distmatrix(mask, raw_sampling_sparsity=1, spline_sampling=512):
+  vprint(3, f'running with raw_sampling_sparsity {raw_sampling_sparsity} and spline_sampling {spline_sampling}')
+  # extract mask contour
+  x, y = find_longest_contour(mask, normalise_coord=True)
+  vprint(3, f'found contour shape x {x.shape} y {y.shape}')
+  # Reinterpolate (spline)
+  x_reinterpolated, y_reinterpolated = spline_interpolation(x, y, raw_sampling_sparsity, spline_sampling)
+  # Build the distance matrix
+  dm = build_distance_matrix(x_reinterpolated, y_reinterpolated)
+  vprint(3, f'created distance matrix shape {dm.shape}')
+  return dm
+
+def masks2distmatrices(params):
+
+  vprint(1, 'loading base dataset')
+
+  if not params.mask_dataset_path:
+    sys.exit("no mask dataset provided")
+  if not params.output_path:
+    p = pathlib.Path(params.mask_dataset_path)
+    params.output_path=p.joinpath(p.parent, p.name+'_distmat')
+
+  vprint(2, f'>>>> params.mask_dataset_path: {params.mask_dataset_path}')
+  vprint(2, f'>>>> params.mask_dataset_path: {next(os.walk(params.mask_dataset_path))[1]}')
+  vprint(2, f'>>>> params.output_path: {params.output_path}')
+  pathlib.Path(params.output_path).mkdir(parents=True, exist_ok=True)
+  class_folders = next(os.walk(params.mask_dataset_path))[1]
+  vprint(2, f'>>>> class_folders: 
{class_folders}') + for class_folder in class_folders: + vprint(2, f'>>>> class_folder: {class_folder}') + output_class_folder=os.path.join(params.output_path, class_folder) + vprint(2, f'creating output class folder: {output_class_folder}') + pathlib.Path(output_class_folder).mkdir(parents=True, exist_ok=True) + for mask_png in glob.glob(params.mask_dataset_path+'/'+class_folder+'/'+'*.png'): + vprint(3, f'{"-"*80}') + vprint(3, f'working on {mask_png}') + filename = os.path.basename(mask_png).split('.')[0] + vprint(3, f'filename {filename}') + mask = iio.imread(mask_png) + dm = mask2distmatrix(mask, params.raw_sampling_sparsity, params.spline_sampling) + output_file_name=f"{output_class_folder}/{filename}.npy" + vprint(3, f'saving {output_file_name}') + vprint(3, f'{"-"*80}') + np.save(output_file_name, dm) + + + #print('loading base dataset') + #dataset = datasets.ImageFolder(mask_dataset_path, transform=transforms.Compose([ + # np.array, + # mask2distmatrix + #])) + #for idx, data in enumerate(dataset): + # print(f'idx: {idx}') + # print(f'data: {data}') + # #torch.save(data, 'data_drive_path{}'.format(idx)) + #print(dataset) + +# # Simplified version for test +# def process_png_file(mask_path, idx, output_folder='./results/reconstruction'): +# # Perform specific action for each PNG file +# print("Processing:", mask_path) +# mask = plt.imread(mask_path) + +# # Get the contour +# x, y = find_longest_contour(mask) + +# # Reinterpolate (spline) +# x_reinterpolated, y_reinterpolated = spline_interpolation(x, y) +# plt.scatter(x_reinterpolated, y_reinterpolated, s=6) +# plt.savefig(f'{output_folder}/original_contour{idx}.png') +# plt.clf() + +# # Build the distance matrix +# dm = build_distance_matrix(x_reinterpolated, y_reinterpolated) +# np.save(f"{output_folder}/matrix_{idx}.npy", dm) + +# # Reconstruction coordinates and matrix (MDS) +# reconstructed_coords = dist_to_coords(dm) +# print(reconstructed_coords) +# plt.scatter(*zip(*reconstructed_coords), s=6) +# plt.savefig(f'{output_folder}/reconstructed_contour{idx}.png') +# plt.clf() +# reconstructed_matrix = euclidean_distances(reconstructed_coords) + +# # Error with matrix +# err = np.average(dm - reconstructed_matrix) +# print(f"Dist error is: {err}") + +############################################################################### + +params = types.SimpleNamespace(**{ + "mask_dataset_path": None + , "output_path": None + , "raw_sampling_sparsity": 4 + , "spline_sampling": 512 +}) + +if __name__ == "__main__": + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. 
Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='Turn mask dataset into distance matrix dataset') + + parser.add_argument('path', metavar='PATH', help=f"The PATH to the dataset") + parser.add_argument('-o', '--output-path', help="The desired output path to the generated dataset") + parser.add_argument('-s', '--raw-sampling-sparsity', type=auto_pos_int + , help=f"The desired sparsity (in number of points) when sampling the raw contour (default, every {params.raw_sampling_sparsity} point(s))") + parser.add_argument('-n', '--spline-sampling', type=auto_pos_int + , help=f"The desired number of points when sampling the spline contour (default, {params.spline_sampling} point(s))") + parser.add_argument('-v', '--verbose', action='count', default=0 + , help="Increase verbosity level by adding more \"v\".") + + # parse command line arguments + clargs=parser.parse_args() + + # set verbosity level for vprint function + vprint.lvl = clargs.verbose + + # update default params with clargs + if clargs.path: + params.mask_dataset_path = clargs.path + #params.mask_dataset_path = "/nfs/research/uhlmann/afoix/tiny_synthcellshapes_dataset" + if clargs.output_path: + params.output_path = clargs.output_path + if clargs.raw_sampling_sparsity: + params.raw_sampling_sparsity = clargs.raw_sampling_sparsity + if clargs.spline_sampling: + params.spline_sampling = clargs.spline_sampling + + masks2distmatrices(params) + + + +############################################################################### +############################################################################### +############################################################################### +######################################## +############# Other code ############### +######################################## + +# # Needed variables +# window_size = 256 # needs to be the same as the latent space size +# interp_size = 256 # latent space size needs to match the window size + +# # This crops the image using the centroid by window sizes. 
(remember to remove it and see what happens)
+# transform_crop = CropCentroidPipeline(window_size)
+
+# # From the coordinates of the distance matrix, this is actually building the distance matrix
+# transform_coord_to_dist = CoordsToDistogram(interp_size, matrix_normalised=False)
+
+# # It takes the image and converts it into a numpy array of the image and the size
+# transform_coords = ImageToCoords(window_size)
+
+# # Combination of transforms
+# transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)])
+
+# transform_mask_to_crop = transforms.Compose(
+#     [
+#         # transforms.ToTensor(),
+#         transform_mask_to_gray,
+#         transform_crop,
+#     ]
+# )
+
+# transform_mask_to_coords = transforms.Compose(
+#     [
+#         transform_mask_to_crop,
+#         transform_coords,
+#     ]
+# )
+
+# transform_mask_to_dist = transforms.Compose(
+#     [
+#         transform_mask_to_coords,
+#         transform_coord_to_dist,
+#     ]
+# )
+
+# def dist_to_coords(dst_mat):
+#     embedding = MDS(n_components=2, dissimilarity='precomputed', max_iter=1)
+#     return embedding.fit_transform(dst_mat)
+
+    #coords_prime = MDS(
+    #n_components=2, dissimilarity="precomputed", random_state=0).fit_transform(dst_mat)
+
+    #return coords_prime
+    #return mds(dst_mat)
+
+  # from https://math.stackexchange.com/a/423898 and https://stackoverflow.com/a/17177833/16632916
+#  m = np.zeros(shape=dst_mat.shape)
+#  for i in range(dst_mat.shape[0]):
+#    for j in range(dst_mat.shape[1]):
+#      m[i,j]= 0.5*(dst_mat[0, j]**2 + dst_mat[i, 0]**2 - dst_mat[i, j]**2)
+#  eigenvalues, eigenvectors = np.linalg.eig(m)
+#  print(f'm:{m}')
+#  print(f'eigenvalues:{eigenvalues}')
+#  print(f'eigenvectors:{eigenvectors}')
+#  return np.sqrt(eigenvalues)*eigenvectors
+
+# # Convert your grayscale image to an rgb one
+# gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))
+
+# # choose the transformation you want to apply to your data and Compose
+# transform = transforms.Compose(
+#     [
+#         transform_mask_to_dist,
+#         transforms.ToTensor(),
+#         RotateIndexingClockwise(p=1), # This module effectively allows for random clockwise rotations of input images with a specified probability.
+# gray2rgb, +# ] +# ) + +# transforms_dict = { +# "none": transform_mask_to_gray, +# "transform_crop": transform_mask_to_crop, +# "transform_dist": transform_mask_to_dist, +# "transform_coords": transform_mask_to_coords, +# } + + + +# diagonal = np.diag(dm) + +# if np.all(diagonal == 0): +# print("All elements in the diagonal are zeros.") +# dataset_raw[i][0].save(f'original_{i}.png') +# np.save(f"random_matrix_{i}.npy", dataset_trans[i][0][0]) +# matplotlib.image.imsave(f'dist_mat_{i}.png', dataset_trans[i][0][0]) +# coords = dist_to_coords(dataset_trans[i][0][0]) +# print(coords) +# x, y = list(zip(*coords)) +# plt.scatter(x_reinterpolated, y_reinterpolated) +# plt.savefig(f'mask_{i}.png') +# plt.clf() +# fig, ax = plt.subplots(1, 4, figsize=(20, 5)) +# ax[0].imshow(mask) +# ax[1].scatter(x_reinterpolated, y_reinterpolated) +# ax[1].imshow(dm) +# ax[3].scatter(x, y) +# fig.savefig(f'combined_{i}.png') +# else: +# print("Not all elements in the diagonal are zeros.") + + + +# # Apply transform to find which images don't work +# dataset_raw = datasets.ImageFolder(dataset) +# dataset_contours = datasets.ImageFolder(dataset, transform=transform_mask_to_coords) +# dataset_trans = datasets.ImageFolder(dataset, transform=transform) + +# # This is a single image distance matrix +# for i in range(0, 10): +# print(dataset_trans[i][0][0]) +# diagonal = np.diag(dataset_trans[i][0][0]) +# if np.all(diagonal == 0): +# print("All elements in the diagonal are zeros.") +# dataset_raw[i][0].save(f'original_{i}.png') +# np.save(f"random_matrix_{i}.npy", dataset_trans[i][0][0]) +# matplotlib.image.imsave(f'dist_mat_{i}.png', dataset_trans[i][0][0]) +# coords = dist_to_coords(dataset_trans[i][0][0]) +# print(coords) +# x, y = list(zip(*coords)) +# plt.scatter(x, y) +# plt.savefig(f'mask_{i}.png') +# plt.clf() +# fig, ax = plt.subplots(1, 4, figsize=(20, 5)) +# ax[0].imshow(dataset_raw[i][0]) +# ax[1].imshow(dataset_trans[i][0][0]) +# ax[2].scatter(dataset_contours[i][0][0], dataset_contours[i][0][1]) +# ax[3].scatter(x, y) +# fig.savefig(f'combined_{i}.png') +# else: +# print("Not all elements in the diagonal are zeros.") diff --git a/scripts/shapes/shape_embed.py b/scripts/shapes/shape_embed.py index 3d121881..2a82c708 100644 --- a/scripts/shapes/shape_embed.py +++ b/scripts/shapes/shape_embed.py @@ -1,36 +1,41 @@ # %% import seaborn as sns import pyefd +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import StandardScaler from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import cross_validate, KFold, train_test_split +from sklearn.model_selection import cross_validate, KFold, train_test_split, StratifiedKFold from sklearn.metrics import make_scorer import pandas as pd from sklearn import metrics import matplotlib as mpl -import seaborn as sns from pathlib import Path -import umap +from sklearn.pipeline import Pipeline from torch.autograd import Variable from types import SimpleNamespace import numpy as np -import logging from skimage import measure import umap.plot from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint import pytorch_lightning as pl import torch from types import SimpleNamespace - -# Deal with the filesystem +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +import argparse +import wandb +import shutil +from umap import UMAP +import os import torch.multiprocessing +import logging +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO) 
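+# NOTE: sharing tensors through the file system (rather than file descriptors)
+# avoids "too many open files" errors when many dataloader workers are in use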
torch.multiprocessing.set_sharing_strategy("file_system") from bioimage_embed import shapes import bioimage_embed - -# Note - you must have torchvision installed for this example - from pytorch_lightning import loggers as pl_loggers from torchvision import transforms from bioimage_embed.lightning import DataModule @@ -41,18 +46,70 @@ CropCentroidPipeline, DistogramToCoords, MaskToDistogramPipeline, + RotateIndexingClockwise, + CoordsToDistogram, + AsymmetricDistogramToCoordsPipeline, ) - import matplotlib.pyplot as plt from bioimage_embed.lightning import DataModule import matplotlib as mpl from matplotlib import rc -import logging +import pickle +import base64 +import hashlib logger = logging.getLogger(__name__) +# Seed everything +np.random.seed(42) +pl.seed_everything(42) + + +def hashing_fn(args): + serialized_args = pickle.dumps(vars(args)) + hash_object = hashlib.sha256(serialized_args) + hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode() + return hashed_string + + +def umap_plot(df, metadata, width=3.45, height=3.45 / 1.618, split=0.8): + umap_reducer = UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42) + mask = np.random.rand(len(df)) < split + + semi_labels = df.index.codes.copy() + semi_labels[~mask] = -1 + + umap_embedding = umap_reducer.fit_transform(df.sample(frac=1), y=semi_labels) + + ax = sns.relplot( + data=pd.DataFrame( + umap_embedding, columns=["umap0", "umap1"], index=df.index + ).reset_index(), + x="umap0", + y="umap1", + hue="Class", + palette="deep", + alpha=0.5, + edgecolor=None, + s=5, + height=height, + aspect=0.5 * width / height, + ) + + sns.move_legend( + ax, + "upper center", + ) + ax.set(xlabel=None, ylabel=None) + sns.despine(left=True, bottom=True) + plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False) + plt.tight_layout() + plt.savefig(metadata(f"umap_no_axes.pdf")) + # plt.show() + plt.close() + def scoring_df(X, y): # Split the data into training and test sets @@ -61,24 +118,31 @@ def scoring_df(X, y): ) # Define a dictionary of metrics scoring = { - "accuracy": make_scorer(metrics.accuracy_score), + "accuracy": make_scorer(metrics.balanced_accuracy_score), "precision": make_scorer(metrics.precision_score, average="macro"), "recall": make_scorer(metrics.recall_score, average="macro"), "f1": make_scorer(metrics.f1_score, average="macro"), + "roc_auc": make_scorer(metrics.roc_auc_score, average="macro"), } # Create a random forest classifier - clf = RandomForestClassifier() + pipeline = Pipeline( + [ + ("scaler", StandardScaler()), + # ("pca", PCA(n_components=0.95, whiten=True, random_state=42)), + ("clf", RandomForestClassifier()), + ] + ) # Specify the number of folds - k_folds = 10 + k_folds = 5 # Perform k-fold cross-validation cv_results = cross_validate( - estimator=clf, + estimator=pipeline, X=X, y=y, - cv=KFold(n_splits=k_folds), + cv=StratifiedKFold(n_splits=k_folds), scoring=scoring, n_jobs=-1, return_train_score=False, @@ -88,7 +152,7 @@ def scoring_df(X, y): return pd.DataFrame(cv_results) -def shape_embed_process(): +def shape_embed_process(clargs): # Setting the font size mpl.rcParams["font.size"] = 10 @@ -101,36 +165,40 @@ def shape_embed_process(): sns.set(style="white", context="notebook", rc={"figure.figsize": (width, height)}) # matplotlib.use("TkAgg") - interp_size = 128 * 2 + interp_size = clargs.latent_space_size * 2 + #interp_size = 128 * 2 max_epochs = 100 - window_size = 128 * 2 + window_size = clargs.latent_space_size * 2 + #window_size = 128 * 2 params = { - 
"model":"resnet18_vqvae_legacy", - "epochs": 75, - "batch_size": 4, + "model":clargs.model, + #"model":"resnet18_vae", + "epochs": 150, + "batch_size": clargs.batch_size, + #"batch_size": 4, "num_workers": 2**4, "input_dim": (3, interp_size, interp_size), "latent_dim": interp_size, "num_embeddings": interp_size, "num_hiddens": interp_size, - "num_residual_hiddens": 32, - "num_residual_layers": 150, "pretrained": True, - # "embedding_dim": 32, - # "num_embeddings": 16, "commitment_cost": 0.25, "decay": 0.99, "frobenius_norm": False, + # dataset = "bbbc010/BBBC010_v1_foreground_eachworm" + # dataset = "vampire/mefs/data/processed/Control" + #"dataset": "synthcellshapes_dataset", + "dataset": clargs.dataset[0], } - + optimizer_params = { - "opt": "LAMB", + "opt": "AdamW", "lr": 0.001, "weight_decay": 0.0001, "momentum": 0.9, } - + lr_scheduler_params = { "sched": "cosine", "min_lr": 1e-4, @@ -140,34 +208,27 @@ def shape_embed_process(): "t_max": 50, "cycle_momentum": False, } - + args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params) - - #dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm" - dataset_path = "shape_embed_data/data/bbbc010/BBBC010_v1_foreground_eachworm/" - # dataset_path = "vampire/mefs/data/processed/Control" - # dataset_path = "shape_embed_data/data/vampire/torchvision/Control/" - # dataset_path = "vampire/torchvision/Control" - # dataset = "bbbc010" - - # train_data_path = f"scripts/shapes/data/{dataset_path}" - train_data_path = f"scripts/shapes/data/{dataset_path}" + + dataset_path = clargs.dataset[1] + train_data_path = f"/nfs/research/uhlmann/afoix/{dataset_path}" metadata = lambda x: f"results/{dataset_path}_{args.model}/{x}" - + path = Path(metadata("")) path.mkdir(parents=True, exist_ok=True) - model_dir = f"models/{dataset_path}_{args.model}" # %% - + transform_crop = CropCentroidPipeline(window_size) - transform_dist = MaskToDistogramPipeline( - window_size, interp_size, matrix_normalised=False - ) - transform_mdscoords = DistogramToCoords(window_size) + # transform_dist = MaskToDistogramPipeline( + # window_size, interp_size, matrix_normalised=False + # ) + transform_coord_to_dist = CoordsToDistogram(interp_size, matrix_normalised=False) + #transform_mdscoords = DistogramToCoords(window_size) transform_coords = ImageToCoords(window_size) - + transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)]) - + transform_mask_to_crop = transforms.Compose( [ # transforms.ToTensor(), @@ -176,76 +237,112 @@ def shape_embed_process(): ] ) - transform_mask_to_dist = transforms.Compose( + transform_mask_to_coords = transforms.Compose( [ transform_mask_to_crop, - transform_dist, + transform_coords, ] ) - transform_mask_to_coords = transforms.Compose( + + transform_mask_to_dist = transforms.Compose( [ - transform_mask_to_crop, - transform_coords, + transform_mask_to_coords, + transform_coord_to_dist, ] ) + gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1)) + transform = transforms.Compose( + [ + transform_mask_to_dist, + transforms.ToTensor(), + RotateIndexingClockwise(p=1), + gray2rgb, + ] + ) + transforms_dict = { "none": transform_mask_to_gray, "transform_crop": transform_mask_to_crop, "transform_dist": transform_mask_to_dist, "transform_coords": transform_mask_to_coords, } + + # Apply transform to find which images don't work + dataset = datasets.ImageFolder(train_data_path, transform=transform) + + valid_indices = [] + # Iterate through the dataset and apply the transform to each image + for idx in range(len(dataset)): + try: + 
+            image, label = dataset[idx]
+            # If the transform works without errors, add the index to the list of valid indices
+            valid_indices.append(idx)
+        except Exception as e:
+            # A better way to do this would be with batch collation
+            logger.warning(f"Error occurred for image {idx}: {e}")
 
     train_data = {
-        key: datasets.ImageFolder(train_data_path, transform=value)
+        key: torch.utils.data.Subset(
+            datasets.ImageFolder(train_data_path, transform=value), valid_indices
+        )
         for key, value in transforms_dict.items()
     }
 
+    dataset = torch.utils.data.Subset(
+        datasets.ImageFolder(train_data_path, transform=transform), valid_indices
+    )
+
     for key, value in train_data.items():
-        print(key, len(value))
-        plt.imshow(train_data[key][0][0], cmap="gray")
+        logger.info("%s: %d", key, len(value))
+        plt.imshow(np.array(train_data[key][0][0]), cmap="gray")
         plt.imsave(metadata(f"{key}.png"), train_data[key][0][0], cmap="gray")
         # plt.show()
         plt.close()
-
+
     # plt.scatter(*train_data["transform_coords"][0][0])
     # plt.savefig(metadata(f"transform_coords.png"))
     # plt.show()
-
+
     # plt.imshow(train_data["transform_crop"][0][0], cmap="gray")
     # plt.scatter(*train_data["transform_coords"][0][0],c=np.arange(interp_size), cmap='rainbow', s=1)
     # plt.show()
     # plt.savefig(metadata(f"transform_coords.png"))
-
+
     # Retrieve the coordinates and cropped image
     coords = train_data["transform_coords"][0][0]
     crop_image = train_data["transform_crop"][0][0]
-
+
     fig = plt.figure(frameon=True)
     ax = plt.Axes(fig, [0, 0, 1, 1])
     ax.set_axis_off()
     fig.add_axes(ax)
-
+
     # Display the cropped image using grayscale colormap
     plt.imshow(crop_image, cmap="gray_r")
-
+
     # Scatter plot with smaller point size
     plt.scatter(*coords, c=np.arange(interp_size), cmap="rainbow", s=2)
-
+
     # Save the plot as an image without border and coordinate axes
     plt.savefig(metadata(f"transform_coords.png"), bbox_inches="tight", pad_inches=0)
-
+
     # Close the plot
     plt.close()
     # import albumentations as A
     # %%
     gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))
     transform = transforms.Compose(
-        [transform_mask_to_dist, transforms.ToTensor(), gray2rgb]
+        [
+            transform_mask_to_dist,
+            transforms.ToTensor(),
+            RotateIndexingClockwise(p=1),
+            gray2rgb,
+        ]
     )
-
+
     dataset = datasets.ImageFolder(train_data_path, transform=transform)
-
+
     valid_indices = []
     # Iterate through the dataset and apply the transform to each image
     for idx in range(len(dataset)):
@@ -256,7 +353,7 @@ def shape_embed_process():
         except Exception as e:
             # A better way to do with would be with batch collation
             print(f"Error occurred for image {idx}: {e}")
-
+
     # Create a Subset using the valid indices
     dataset = torch.utils.data.Subset(dataset, valid_indices)
     dataloader = DataModule(
@@ -265,37 +362,53 @@ def shape_embed_process():
         shuffle=True,
         num_workers=args.num_workers,
     )
-
+
     # model = bioimage_embed.models.create_model("resnet18_vqvae_legacy", **vars(args))
     #
+
     model = bioimage_embed.models.create_model(
         model=args.model,
         input_dim=args.input_dim,
         latent_dim=args.latent_dim,
         pretrained=args.pretrained,
     )
-
+
     # model = bioimage_embed.models.factory.ModelFactory(**vars(args)).resnet50_vqvae_legacy()
-
+
     # lit_model = shapes.MaskEmbedLatentAugment(model, args)
     lit_model = shapes.MaskEmbed(model, args)
     test_data = dataset[0][0].unsqueeze(0)
     # test_lit_data = 2*(dataset[0][0].unsqueeze(0).repeat_interleave(3, dim=1),)
     test_output = lit_model.forward((test_data,))
-
+
     dataloader.setup()
     model.eval()
-
-    model_dir = f"my_models/{dataset_path}_{model._get_name()}_{lit_model._get_name()}"
-
+
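+    # Checkpoints are grouped per run configuration: hashing_fn (defined above)
+    # pickles vars(args) and returns a urlsafe-base64 SHA-256 digest, so the
+    # same (model, epochs, batch size, ...) combination always maps back to the
+    # same checkpoint directory and can resume, e.g. (illustrative value only):
+    #   hashing_fn(args)  # -> a 44-character id such as 'Jx3F...'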
f"checkpoints/{hashing_fn(args)}" + + if clargs.clear_checkpoints: + print("cleaning checkpoints") + shutil.rmtree("checkpoints/") + model_dir = f"checkpoints/{hashing_fn(args)}" + tb_logger = pl_loggers.TensorBoardLogger(f"logs/") - + jobname = f"{params['model']}_{interp_size}_{params['batch_size']}_{clargs.dataset[0]}" + wandblogger = pl_loggers.WandbLogger(entity='foix', project="shape_embed_fixes", name=jobname) + #wandblogger = pl_loggers.WandbLogger(project=clargs.wandb_project, name=jobname) + Path(f"{model_dir}/").mkdir(parents=True, exist_ok=True) - - checkpoint_callback = ModelCheckpoint(dirpath=f"{model_dir}/", save_last=True) - + + checkpoint_callback = ModelCheckpoint( + dirpath=f"{model_dir}/", + save_last=True, + save_top_k=1, + monitor="loss/val", + mode="min", + ) + wandblogger.watch(lit_model, log="all") + trainer = pl.Trainer( - logger=tb_logger, + logger=[wandblogger, tb_logger], gradient_clip_val=0.5, enable_checkpointing=True, devices=1, @@ -304,106 +417,85 @@ def shape_embed_process(): callbacks=[checkpoint_callback], min_epochs=50, max_epochs=args.epochs, + log_every_n_steps=1, ) # %% - try: - trainer.fit( - lit_model, datamodule=dataloader, ckpt_path=f"{model_dir}/last.ckpt" - ) - except: - trainer.fit(lit_model, datamodule=dataloader) - + + # Determine the checkpoint path for resuming + last_checkpoint_path = f"{model_dir}/last.ckpt" + best_checkpoint_path = checkpoint_callback.best_model_path + + # Check if a last checkpoint exists to resume from + if os.path.isfile(last_checkpoint_path): + resume_checkpoint = last_checkpoint_path + elif best_checkpoint_path and os.path.isfile(best_checkpoint_path): + resume_checkpoint = best_checkpoint_path + else: + resume_checkpoint = None + + trainer.fit(lit_model, datamodule=dataloader, ckpt_path=resume_checkpoint) + lit_model.eval() - + validation = trainer.validate(lit_model, datamodule=dataloader) - # testing = trainer.test(lit_model, datamodule=dataloader) + testing = trainer.test(lit_model, datamodule=dataloader) example_input = Variable(torch.rand(1, *args.input_dim)) - + # torch.jit.save(lit_model.to_torchscript(), f"{model_dir}/model.pt") # torch.onnx.export(lit_model, example_input, f"{model_dir}/model.onnx") - + # %% - # Inference - + # Inference on full dataset dataloader = DataModule( dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, # Transform is commented here to avoid augmentations in real data - # HOWEVER, applying a the transform multiple times and averaging the results might produce better latent embeddings - # transform=transform, + # HOWEVER, applying the transform multiple times and averaging the results might produce better latent embeddings # transform=transform, ) dataloader.setup() - + predictions = trainer.predict(lit_model, datamodule=dataloader) - # Use the namespace variables - latent_space = torch.stack([d.out.z.flatten() for d in predictions]) - scalings = torch.stack([d.x.scalings.flatten() for d in predictions]) - - idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()} - - y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()]) - - y_partial = y.copy() - indices = np.random.choice(y.size, int(0.3 * y.size), replace=False) - y_partial[indices] = -1 - y_blind = -1 * np.ones_like(y) - umap_labels = y_blind - classes = np.array([idx_to_class[i] for i in y]) + test_dist_pred = predictions[0].out.recon_x + plt.imsave(metadata(f"test_dist_pred.png"), test_dist_pred.mean(axis=(0,1))) + plt.close() - n_components = 64 # Number of UMAP components - 
component_names = [f"umap{i}" for i in range(n_components)] # List of column names + test_dist_in = predictions[0].x.data + plt.imsave(metadata(f"test_dist_in.png"), test_dist_in.mean(axis=(0,1))) + plt.close() - logger.info("UMAP fitting") - mapper = umap.UMAP(n_components=64, random_state=42).fit( - latent_space.numpy(), y=umap_labels + test_pred_coords = AsymmetricDistogramToCoordsPipeline(window_size=window_size)( + np.array(test_dist_pred[:, 0, :, :].unsqueeze(dim=0)) ) - logger.info("UMAP transforming") - semi_supervised_latent = mapper.transform(latent_space.numpy()) + plt.scatter(*test_pred_coords[0,0].T) + # Save the plot as an image without border and coordinate axes + plt.savefig(metadata(f"test_pred_coords.png"), bbox_inches="tight", pad_inches=0) + plt.close() - df = pd.DataFrame(semi_supervised_latent, columns=component_names) - df["Class"] = y - # Map numeric classes to their labels - idx_to_class = {0: "alive", 1: "dead"} - df["Class"] = df["Class"].map(idx_to_class) + # Use the namespace variables + latent_space = torch.stack([d.out.z.flatten() for d in predictions]) + scalings = torch.stack([d.x.scalings.flatten() for d in predictions]) + idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()} + y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()]) + + df = pd.DataFrame(latent_space.numpy()) + df["Class"] = pd.Series(y).map(idx_to_class).astype("category") df["Scale"] = scalings[:, 0].squeeze() df = df.set_index("Class") df_shape_embed = df.copy() - - ax = sns.relplot( - data=df, - x="umap0", - y="umap1", - hue="Class", - palette="deep", - alpha=0.5, - edgecolor=None, - s=5, - height=height, - aspect=0.5 * width / height, - ) - - sns.move_legend( - ax, - "upper center", - ) - ax.set(xlabel=None, ylabel=None) - sns.despine(left=True, bottom=True) - plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False) - plt.tight_layout() - plt.savefig(metadata(f"umap_no_axes.pdf")) - # plt.show() - plt.close() - + # %% + # %% UMAP plot - X = df_shape_embed.to_numpy() - y = df_shape_embed.index.values + umap_plot(df, metadata, width, height, split=0.9) + X = df_shape_embed.to_numpy() + y = df_shape_embed.index + properties = [ "area", "perimeter", @@ -413,36 +505,35 @@ def shape_embed_process(): "orientation", ] dfs = [] - for i, data in enumerate(train_data["transform_crop"]): + # Distance matrix data + for i, data in enumerate(tqdm(train_data["transform_crop"])): X, y = data # Do regionprops here # Calculate shape summary statistics using regionprops - # We're considering that the mask has only one object, thus we take the first element [0] + # We're considering that the mask has only one object, so we take the first element [0] # props = regionprops(np.array(X).astype(int))[0] props_table = measure.regionprops_table( np.array(X).astype(int), properties=properties ) - + # Store shape properties in a dataframe df = pd.DataFrame(props_table) - + # Assuming the class or label is contained in 'y' variable df["class"] = y df.set_index("class", inplace=True) dfs.append(df) - + df_regionprops = pd.concat(dfs) - - # Assuming 'dataset_contour' is your DataLoader for the dataset dfs = [] - for i, data in enumerate(train_data["transform_coords"]): + for i, data in enumerate(tqdm(train_data["transform_coords"])): # Convert the tensor to a numpy array X, y = data - + # Feed it to PyEFD's calculate_efd function coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False) # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': 
[norm_coeffs.flatten().tolist()]}) - + norm_coeffs = pyefd.normalize_efd(coeffs) df = pd.DataFrame( { @@ -453,9 +544,9 @@ def shape_embed_process(): df["class"] = y df.set_index("class", inplace=True, append=True) dfs.append(df) - + df_pyefd = pd.concat(dfs) - + trials = [ { "name": "mask_embed", @@ -477,21 +568,37 @@ def shape_embed_process(): "labels": df_regionprops.index, }, ] - + trial_df = pd.DataFrame() for trial in trials: X = trial["features"] y = trial["labels"] trial["score_df"] = scoring_df(X, y) trial["score_df"]["trial"] = trial["name"] - print(trial["score_df"]) + logger.info(trial["score_df"]) trial["score_df"].to_csv(metadata(f"{trial['name']}_score_df.csv")) trial_df = pd.concat([trial_df, trial["score_df"]]) trial_df = trial_df.drop(["fit_time", "score_time"], axis=1) - + trial_df.to_csv(metadata(f"trial_df.csv")) trial_df.groupby("trial").mean().to_csv(metadata(f"trial_df_mean.csv")) trial_df.plot(kind="bar") + + #mean_df = trial_df.groupby("trial").mean() + #std_df = trial_df.groupby("trial").std() + #wandb.log_table(mean_df) + #wandb.log_table(std_df) + + #Special metrics for f1 score for wandb + wandblogger.experiment.log({"trial_df": wandb.Table(dataframe=trial_df)}) + mean_df = trial_df.groupby("trial").mean() + std_df = trial_df.groupby("trial").std() + wandblogger.experiment.log({"Mean": wandb.Table(dataframe=mean_df)}) + wandblogger.experiment.log({"Std": wandb.Table(dataframe=std_df)}) + + avg = trial_df.groupby("trial").mean() + logger.info(avg) + avg.to_latex(metadata(f"trial_df.tex")) melted_df = trial_df.melt(id_vars="trial", var_name="Metric", value_name="Score") # fig, ax = plt.subplots(figsize=(width, height)) @@ -512,16 +619,66 @@ def shape_embed_process(): # plt.tight_layout() plt.savefig(metadata(f"trials_barplot.pdf")) plt.close() - + avs = ( melted_df.set_index(["trial", "Metric"]) .xs("test_f1", level="Metric", drop_level=False) .groupby("trial") .mean() ) - print(avs) + logger.info(avs) # tikzplotlib.save(metadata(f"trials_barplot.tikz")) + + +############################################################################### + if __name__ == "__main__": - shape_embed_process() + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. 
Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='Run the shape embed pipeline') + + models = [ + "resnet18_vae" + , "resnet50_vae" + , "resnet18_vae_bolt" + , "resnet50_vae_bolt" + , "resnet18_vqvae" + , "resnet50_vqvae" + , "resnet18_vqvae_legacy" + , "resnet50_vqvae_legacy" + , "resnet101_vqvae_legacy" + , "resnet110_vqvae_legacy" + , "resnet152_vqvae_legacy" + , "resnet18_vae_legacy" + , "resnet50_vae_legacy" + ] + parser.add_argument( + '-m', '--model', choices=models, default=models[0], metavar='MODEL' + , help=f"The MODEL to use, one of {models} (default {models[0]}).") + parser.add_argument( + '-d', '--dataset', nargs=2, default=("vampire", "vampire/torchvision/Control/"), metavar=('NAME', 'PATH') + , help=f"The NAME of and PATH to the dataset") + parser.add_argument( + '-w', '--wandb-project', default="shape-embed", metavar='PROJECT' + , help=f"The wandb PROJECT name") + parser.add_argument( + '-b', '--batch-size', default=int(4), metavar='BATCH_SIZE', type=auto_pos_int + , help="The BATCH_SIZE for the run, a positive integer (default 4)") + parser.add_argument( + '-l', '--latent-space-size', default=int(128), metavar='LATENT_SPACE_SIZE', type=auto_pos_int + , help="The LATENT_SPACE_SIZE, a positive integer (default 128)") + parser.add_argument('--clear-checkpoints', action='store_true' + , help='remove checkpoints') + #parser.add_argument('-v', '--verbose', action='count', default=0, + # help="Increase verbosity level by adding more \"v\".") + + #clargs=parser.parse_args() + #print(clargs.dataset) + shape_embed_process(parser.parse_args()) diff --git a/scripts/shapes/shape_embed_backup.py b/scripts/shapes/shape_embed_backup.py new file mode 100644 index 00000000..eea708e4 --- /dev/null +++ b/scripts/shapes/shape_embed_backup.py @@ -0,0 +1,558 @@ +# %% +import seaborn as sns +import pyefd +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import cross_validate, KFold, train_test_split +from sklearn.metrics import make_scorer +import pandas as pd +from sklearn import metrics +import matplotlib as mpl +import seaborn as sns +from pathlib import Path +import umap +from torch.autograd import Variable +from types import SimpleNamespace +import numpy as np +import logging +from skimage import measure +import umap.plot +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +import pytorch_lightning as pl +import torch +from types import SimpleNamespace +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +import argparse +import wandb +import shutil + +# Deal with the filesystem +import torch.multiprocessing + +torch.multiprocessing.set_sharing_strategy("file_system") + +from bioimage_embed import shapes +import bioimage_embed + +# Note - you must have torchvision installed for this example + +from pytorch_lightning import loggers as pl_loggers +from torchvision import transforms +from bioimage_embed.lightning import DataModule + +from torchvision import datasets +from bioimage_embed.shapes.transforms import ( + ImageToCoords, + CropCentroidPipeline, + DistogramToCoords, + MaskToDistogramPipeline, + RotateIndexingClockwise, +) + +import matplotlib.pyplot as plt + +from bioimage_embed.lightning import DataModule +import matplotlib as mpl +from matplotlib import rc + +import logging +import pickle +import base64 +import hashlib + +logger = logging.getLogger(__name__) + +def hashing_fn(args): + serialized_args = pickle.dumps(vars(args)) + hash_object = hashlib.sha256(serialized_args) + 
hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode() + return hashed_string + +def scoring_df(X, y): + # Split the data into training and test sets + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y + ) + # Define a dictionary of metrics + scoring = { + "accuracy": make_scorer(metrics.accuracy_score), + "precision": make_scorer(metrics.precision_score, average="macro"), + "recall": make_scorer(metrics.recall_score, average="macro"), + "f1": make_scorer(metrics.f1_score, average="macro"), + } + + # Create a random forest classifier + clf = RandomForestClassifier() + + # Specify the number of folds + k_folds = 10 + + # Perform k-fold cross-validation + cv_results = cross_validate( + estimator=clf, + X=X, + y=y, + cv=KFold(n_splits=k_folds), + scoring=scoring, + n_jobs=-1, + return_train_score=False, + ) + + # Put the results into a DataFrame + return pd.DataFrame(cv_results) + + +def shape_embed_process(clargs): + # Setting the font size + mpl.rcParams["font.size"] = 10 + + # rc("text", usetex=True) + rc("font", **{"family": "sans-serif", "sans-serif": ["Arial"]}) + width = 3.45 + height = width / 1.618 + plt.rcParams["figure.figsize"] = [width, height] + + sns.set(style="white", context="notebook", rc={"figure.figsize": (width, height)}) + + # matplotlib.use("TkAgg") + interp_size = clargs.latent_space_size * 2 + #interp_size = 128 * 2 + max_epochs = 100 + window_size = clargs.latent_space_size * 2 + #window_size = 128 * 2 + + params = { + "model":clargs.model, + #"model":"resnet18_vae", + "epochs": 75, + "batch_size": clargs.batch_size, + #"batch_size": 4, + "num_workers": 2**4, + "input_dim": (3, interp_size, interp_size), + "latent_dim": interp_size, + "num_embeddings": interp_size, + "num_hiddens": interp_size, + "pretrained": True, + "commitment_cost": 0.25, + "decay": 0.99, + "frobenius_norm": False, + } + + optimizer_params = { + "opt": "AdamW", + "lr": 0.001, + "weight_decay": 0.0001, + "momentum": 0.9, + } + + lr_scheduler_params = { + "sched": "cosine", + "min_lr": 1e-4, + "warmup_epochs": 5, + "warmup_lr": 1e-6, + "cooldown_epochs": 10, + "t_max": 50, + "cycle_momentum": False, + } + + args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params) + + dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm/" + dataset = "bbbc010" + train_data_path = f"/nfs/research/uhlmann/afoix/{dataset_path}" + metadata = lambda x: f"results/{dataset_path}_{args.model}/{x}" + + path = Path(metadata("")) + path.mkdir(parents=True, exist_ok=True) + # %% + + transform_crop = CropCentroidPipeline(window_size) + transform_dist = MaskToDistogramPipeline( + window_size, interp_size, matrix_normalised=False + ) + transform_mdscoords = DistogramToCoords(window_size) + transform_coords = ImageToCoords(window_size) + + transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)]) + + transform_mask_to_crop = transforms.Compose( + [ + # transforms.ToTensor(), + transform_mask_to_gray, + transform_crop, + ] + ) + + transform_mask_to_dist = transforms.Compose( + [ + transform_mask_to_crop, + transform_dist, + ] + ) + transform_mask_to_coords = transforms.Compose( + [ + transform_mask_to_crop, + transform_coords, + ] + ) + + transforms_dict = { + "none": transform_mask_to_gray, + "transform_crop": transform_mask_to_crop, + "transform_dist": transform_mask_to_dist, + "transform_coords": transform_mask_to_coords, + } + + train_data = { + key: datasets.ImageFolder(train_data_path, transform=value) + for 
key, value in transforms_dict.items() + } + + for key, value in train_data.items(): + print(key, len(value)) + plt.imshow(train_data[key][0][0], cmap="gray") + plt.imsave(metadata(f"{key}.png"), train_data[key][0][0], cmap="gray") + # plt.show() + plt.close() + + # plt.scatter(*train_data["transform_coords"][0][0]) + # plt.savefig(metadata(f"transform_coords.png")) + # plt.show() + + # plt.imshow(train_data["transform_crop"][0][0], cmap="gray") + # plt.scatter(*train_data["transform_coords"][0][0],c=np.arange(interp_size), cmap='rainbow', s=1) + # plt.show() + # plt.savefig(metadata(f"transform_coords.png")) + + # Retrieve the coordinates and cropped image + coords = train_data["transform_coords"][0][0] + crop_image = train_data["transform_crop"][0][0] + + fig = plt.figure(frameon=True) + ax = plt.Axes(fig, [0, 0, 1, 1]) + ax.set_axis_off() + fig.add_axes(ax) + + # Display the cropped image using grayscale colormap + plt.imshow(crop_image, cmap="gray_r") + + # Scatter plot with smaller point size + plt.scatter(*coords, c=np.arange(interp_size), cmap="rainbow", s=2) + + # Save the plot as an image without border and coordinate axes + plt.savefig(metadata(f"transform_coords.png"), bbox_inches="tight", pad_inches=0) + + # Close the plot + plt.close() + # import albumentations as A + # %% + gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1)) + transform = transforms.Compose( + [ + transform_mask_to_dist, + transforms.ToTensor(), + RotateIndexingClockwise(p=1), + gray2rgb, + ] + ) + + dataset = datasets.ImageFolder(train_data_path, transform=transform) + + valid_indices = [] + # Iterate through the dataset and apply the transform to each image + for idx in range(len(dataset)): + try: + image, label = dataset[idx] + # If the transform works without errors, add the index to the list of valid indices + valid_indices.append(idx) + except Exception as e: + # A better way to do with would be with batch collation + print(f"Error occurred for image {idx}: {e}") + + # Create a Subset using the valid indices + dataset = torch.utils.data.Subset(dataset, valid_indices) + dataloader = DataModule( + dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + ) + + # model = bioimage_embed.models.create_model("resnet18_vqvae_legacy", **vars(args)) + # + model = bioimage_embed.models.create_model( + model=args.model, + input_dim=args.input_dim, + latent_dim=args.latent_dim, + pretrained=args.pretrained, + ) + + # model = bioimage_embed.models.factory.ModelFactory(**vars(args)).resnet50_vqvae_legacy() + + # lit_model = shapes.MaskEmbedLatentAugment(model, args) + lit_model = shapes.MaskEmbed(model, args) + test_data = dataset[0][0].unsqueeze(0) + # test_lit_data = 2*(dataset[0][0].unsqueeze(0).repeat_interleave(3, dim=1),) + test_output = lit_model.forward((test_data,)) + + dataloader.setup() + model.eval() + + if clargs.clear_checkpoints: + print("cleaning checkpoints") + shutil.rmtree("checkpoints/") + model_dir = f"checkpoints/{hashing_fn(args)}" + + tb_logger = pl_loggers.TensorBoardLogger(f"logs/") + wandblogger = pl_loggers.WandbLogger(project="shape-embed", name=f"{params['model']}_{interp_size}_{params['batch_size']}") + + Path(f"{model_dir}/").mkdir(parents=True, exist_ok=True) + + checkpoint_callback = ModelCheckpoint(dirpath=f"{model_dir}/", save_last=True) + wandblogger.watch(lit_model, log="all") + + trainer = pl.Trainer( + logger=[wandblogger,tb_logger], + gradient_clip_val=0.5, + enable_checkpointing=True, + devices=1, + accelerator="gpu", + 
accumulate_grad_batches=4, + callbacks=[checkpoint_callback], + min_epochs=50, + max_epochs=args.epochs, + log_every_n_steps=1, + ) + # %% + try: + trainer.fit( + lit_model, datamodule=dataloader, ckpt_path=f"{model_dir}/last.ckpt" + ) + except: + trainer.fit(lit_model, datamodule=dataloader) + + lit_model.eval() + + validation = trainer.validate(lit_model, datamodule=dataloader) + testing = trainer.test(lit_model, datamodule=dataloader) + example_input = Variable(torch.rand(1, *args.input_dim)) + + # torch.jit.save(lit_model.to_torchscript(), f"{model_dir}/model.pt") + # torch.onnx.export(lit_model, example_input, f"{model_dir}/model.onnx") + + # %% + # Inference + + dataloader = DataModule( + dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_workers, + # Transform is commented here to avoid augmentations in real data + # HOWEVER, applying a the transform multiple times and averaging the results might produce better latent embeddings + # transform=transform, + # transform=transform, + ) + dataloader.setup() + + predictions = trainer.predict(lit_model, datamodule=dataloader) + + # Use the namespace variables + latent_space = torch.stack([d.out.z.flatten() for d in predictions]) + scalings = torch.stack([d.x.scalings.flatten() for d in predictions]) + idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()} + y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()]) + + y_partial = y.copy() + indices = np.random.choice(y.size, int(0.3 * y.size), replace=False) + y_partial[indices] = -1 + y_blind = -1 * np.ones_like(y) + + df = pd.DataFrame(latent_space.numpy()) + df["Class"] = y + # Map numeric classes to their labels + idx_to_class = {0: "alive", 1: "dead"} + df["Class"] = df["Class"].map(idx_to_class) + df["Scale"] = scalings[:, 0].squeeze() + df = df.set_index("Class") + df_shape_embed = df.copy() + + # %% + + X = df_shape_embed.to_numpy() + y = df_shape_embed.index.values + + properties = [ + "area", + "perimeter", + "centroid", + "major_axis_length", + "minor_axis_length", + "orientation", + ] + dfs = [] + for i, data in enumerate(train_data["transform_crop"]): + X, y = data + # Do regionprops here + # Calculate shape summary statistics using regionprops + # We're considering that the mask has only one object, thus we take the first element [0] + # props = regionprops(np.array(X).astype(int))[0] + props_table = measure.regionprops_table( + np.array(X).astype(int), properties=properties + ) + + # Store shape properties in a dataframe + df = pd.DataFrame(props_table) + + # Assuming the class or label is contained in 'y' variable + df["class"] = y + df.set_index("class", inplace=True) + dfs.append(df) + + df_regionprops = pd.concat(dfs) + + # Assuming 'dataset_contour' is your DataLoader for the dataset + dfs = [] + for i, data in enumerate(train_data["transform_coords"]): + # Convert the tensor to a numpy array + X, y = data + + # Feed it to PyEFD's calculate_efd function + coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False) + # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': [norm_coeffs.flatten().tolist()]}) + + norm_coeffs = pyefd.normalize_efd(coeffs) + df = pd.DataFrame( + { + "norm_coeffs": norm_coeffs.flatten().tolist(), + "coeffs": coeffs.flatten().tolist(), + } + ).T.rename_axis("coeffs") + df["class"] = y + df.set_index("class", inplace=True, append=True) + dfs.append(df) + + df_pyefd = pd.concat(dfs) + + trials = [ + { + "name": "mask_embed", + "features": df_shape_embed.to_numpy(), + "labels": 
df_shape_embed.index, + }, + { + "name": "fourier_coeffs", + "features": df_pyefd.xs("coeffs", level="coeffs"), + "labels": df_pyefd.xs("coeffs", level="coeffs").index, + }, + # {"name": "fourier_norm_coeffs", + # "features": df_pyefd.xs("norm_coeffs", level="coeffs"), + # "labels": df_pyefd.xs("norm_coeffs", level="coeffs").index + # } + { + "name": "regionprops", + "features": df_regionprops, + "labels": df_regionprops.index, + }, + ] + + trial_df = pd.DataFrame() + for trial in trials: + X = trial["features"] + y = trial["labels"] + trial["score_df"] = scoring_df(X, y) + trial["score_df"]["trial"] = trial["name"] + print(trial["score_df"]) + trial["score_df"].to_csv(metadata(f"{trial['name']}_score_df.csv")) + trial_df = pd.concat([trial_df, trial["score_df"]]) + trial_df = trial_df.drop(["fit_time", "score_time"], axis=1) + + trial_df.to_csv(metadata(f"trial_df.csv")) + trial_df.groupby("trial").mean().to_csv(metadata(f"trial_df_mean.csv")) + trial_df.plot(kind="bar") + + #mean_df = trial_df.groupby("trial").mean() + #std_df = trial_df.groupby("trial").std() + #wandb.log_table(mean_df) + #wandb.log_table(std_df) + + #Special metrics for f1 score for wandb + wandblogger.experiment.log({"trial_df": wandb.Table(dataframe=trial_df)}) + mean_df = trial_df.groupby("trial").mean() + std_df = trial_df.groupby("trial").std() + wandblogger.experiment.log({"Mean": wandb.Table(dataframe=mean_df)}) + wandblogger.experiment.log({"Std": wandb.Table(dataframe=std_df)}) + + melted_df = trial_df.melt(id_vars="trial", var_name="Metric", value_name="Score") + # fig, ax = plt.subplots(figsize=(width, height)) + ax = sns.catplot( + data=melted_df, + kind="bar", + x="trial", + hue="Metric", + y="Score", + errorbar="se", + height=height, + aspect=width * 2**0.5 / height, + ) + # ax.xtick_params(labelrotation=45) + # plt.legend(loc='lower center', bbox_to_anchor=(1, 1)) + # sns.move_legend(ax, "lower center", bbox_to_anchor=(1, 1)) + # ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + # plt.tight_layout() + plt.savefig(metadata(f"trials_barplot.pdf")) + plt.close() + + avs = ( + melted_df.set_index(["trial", "Metric"]) + .xs("test_f1", level="Metric", drop_level=False) + .groupby("trial") + .mean() + ) + print(avs) + # tikzplotlib.save(metadata(f"trials_barplot.tikz")) + + + + +############################################################################### + +if __name__ == "__main__": + + def auto_pos_int (x): + val = int(x,0) + if val <= 0: + raise argparse.ArgumentTypeError("argument must be a positive int. 
Got {:d}.".format(val)) + return val + + parser = argparse.ArgumentParser(description='Run the shape embed pipeline') + + models = [ + "resnet18_vae" + , "resnet50_vae" + , "resnet18_vae_bolt" + , "resnet50_vae_bolt" + , "resnet18_vqvae" + , "resnet50_vqvae" + , "resnet18_vqvae_legacy" + , "resnet50_vqvae_legacy" + , "resnet101_vqvae_legacy" + , "resnet110_vqvae_legacy" + , "resnet152_vqvae_legacy" + , "resnet18_vae_legacy" + , "resnet50_vae_legacy" + ] + parser.add_argument( + '-m', '--model', choices=models, default=models[0], metavar='MODEL' + , help=f"The MODEL to use, one of {models} (default {models[0]}).") + parser.add_argument( + '-b', '--batch-size', default=int(4), metavar='BATCH_SIZE', type=auto_pos_int + , help="The BATCH_SIZE for the run, a positive integer (default 4)") + parser.add_argument( + '-l', '--latent-space-size', default=int(128), metavar='LATENT_SPACE_SIZE', type=auto_pos_int + , help="The LATENT_SPACE_SIZE, a positive integer (default 128)") + parser.add_argument('--clear-checkpoints', action='store_true' + , help='remove checkpoints') + #parser.add_argument('-v', '--verbose', action='count', default=0, + # help="Increase verbosity level by adding more \"v\".") + + shape_embed_process(parser.parse_args()) diff --git a/slurm_shape_embed.py b/slurm_shape_embed.py new file mode 100644 index 00000000..daea5ca5 --- /dev/null +++ b/slurm_shape_embed.py @@ -0,0 +1,103 @@ +#! /usr/bin/env python3 + +import os +import subprocess +import tempfile + +## Assign the arguments to variables +#model_arg=$1 +#sizes_list="${@:2}" +# +## Create SLURM job script +#job_script="slurm_job.sh" +# +#echo "#!/bin/bash" > "$job_script" +#echo "#SBATCH --job-name=ite_shape_embed" >> "$job_script" +#echo "#SBATCH --output=ite_shape_embed.out" >> "$job_script" +#echo "#SBATCH --error=ite_shape_embed.err" >> "$job_script" +#echo "#SBATCH --gres=gpu:2" >> "$job_script" # Adjust the number of CPUs as needed +#echo "#SBATCH --mem=50GB" >> "$job_script" # Adjust the memory requirement as needed +#echo "" >> "$job_script" +# +## Loop through the sizes and append the Python command to the job script +#for size in $sizes_list; do +# echo "python ite_shape_embed.py --model $model_arg --ls_size $size" >> "$job_script" +#done +# +## Submit SLURM job +#sbatch "$job_script" + +models = [ + "resnet18_vae" +, "resnet50_vae" +, "resnet18_vae_bolt" +, "resnet50_vae_bolt" +, "resnet18_vqvae" +, "resnet50_vqvae" +, "resnet18_vqvae_legacy" +, "resnet50_vqvae_legacy" +, "resnet101_vqvae_legacy" +, "resnet110_vqvae_legacy" +, "resnet152_vqvae_legacy" +, "resnet18_vae_legacy" +, "resnet50_vae_legacy" +] +batch_sizes = [4, 8, 16] +latent_space_sizes = [64, 128, 256, 512] + +slurm_script="""#!/bin/bash + +JOB_NAME=shape_embed_{model}_{b_size}_{ls_size} +echo "running shape embed with:" +echo " - model {model}" +echo " - batch size {b_size}" +echo " - latent space size {ls_size}" +rand_name=$(cat /dev/urandom | tr -cd 'a-f0-9' | head -c 16) +mkdir -p slurm_rundir/$rand_name +cp -r $(ls | grep -v slurm_rundir) slurm_rundir/$rand_name/. 
+cd slurm_rundir/$rand_name
+python3 scripts/shapes/shape_embed.py --model {model} --batch-size {b_size} --latent-space-size {ls_size} --clear-checkpoints
+"""
+
+def mem_size(ls):
+    # Check the largest threshold first so big latent sizes are not swallowed
+    # by the `ls > 128` branch
+    if ls > 256:
+        return '300GB'
+    if ls > 128:
+        return '100GB'
+    return '50GB'
+
+def n_gpus(ls):
+    if ls > 256:
+        return 'gpu:3'
+    return 'gpu:2'
+
+if __name__ == "__main__":
+
+    slurmdir = f'{os.getcwd()}/slurmdir'
+    os.makedirs(slurmdir, exist_ok=True)
+    for m, bs, ls in [ (m,bs,ls) for m in models
+                                 for bs in batch_sizes
+                                 for ls in latent_space_sizes ]:
+        jobname = f'shape_embed_{m}_{bs}_{ls}'
+        print(jobname)
+        fp = open(mode='w+', file=f'{slurmdir}/slurm_script_shape_embed_{m}_{bs}_{ls}.script')
+        fp.write(slurm_script.format(model=m, b_size=bs, ls_size=ls))
+        fp.flush()
+        print(f'{fp.name}')
+        print(f'cat {fp.name}')
+        result = subprocess.run(['cat', fp.name], stdout=subprocess.PIPE)
+        print(result.stdout.decode('utf-8'))
+        print(mem_size(ls))
+        result = subprocess.run([ 'sbatch'
+                                , '--time', '10:00:00'
+                                , '--mem', mem_size(ls)
+                                , '--job-name', jobname
+                                , '--output', f'{slurmdir}/{jobname}.out'
+                                , '--error', f'{slurmdir}/{jobname}.err'
+                                , '--gres', n_gpus(ls)
+                                , fp.name], stdout=subprocess.PIPE)
+        print(result.stdout.decode('utf-8'))
+        fp.close()
diff --git a/slurm_shape_embed_dataset.py b/slurm_shape_embed_dataset.py
new file mode 100644
index 00000000..0651c361
--- /dev/null
+++ b/slurm_shape_embed_dataset.py
@@ -0,0 +1,111 @@
+#! /usr/bin/env python3
+
+import os
+import subprocess
+import tempfile
+
+## Assign the arguments to variables
+#model_arg=$1
+#sizes_list="${@:2}"
+#
+## Create SLURM job script
+#job_script="slurm_job.sh"
+#
+#echo "#!/bin/bash" > "$job_script"
+#echo "#SBATCH --job-name=ite_shape_embed" >> "$job_script"
+#echo "#SBATCH --output=ite_shape_embed.out" >> "$job_script"
+#echo "#SBATCH --error=ite_shape_embed.err" >> "$job_script"
+#echo "#SBATCH --gres=gpu:2" >> "$job_script" # Adjust the number of CPUs as needed
+#echo "#SBATCH --mem=50GB" >> "$job_script" # Adjust the memory requirement as needed
+#echo "" >> "$job_script"
+#
+## Loop through the sizes and append the Python command to the job script
+#for size in $sizes_list; do
+#  echo "python ite_shape_embed.py --model $model_arg --ls_size $size" >> "$job_script"
+#done
+#
+## Submit SLURM job
+#sbatch "$job_script"
+
+models = [
+  "resnet50_vae"
+, "resnet50_vqvae"
+, "resnet50_vqvae_legacy"
+, "resnet50_vae_legacy"
+, "resnet18_vae"
+, "resnet18_vqvae"
+, "resnet18_vqvae_legacy"
+, "resnet18_vae_legacy"]
+
+batch_sizes = [4]
+latent_space_sizes = [512]
+
+datasets = [
+#  ("tiny_synthcell", "tiny_synthcellshapes_dataset/")
+  ("vampire", "vampire/torchvision/Control/")
+, ("bbbc010", "bbbc010/BBBC010_v1_foreground_eachworm/")
+, ("synthcell", "synthcellshapes_dataset/")
+, ("helakyoto", "H2b_10x_MD_exp665/samples/")
+, ("allen", "allen_dataset/")
+]
+
+wandb_project='shape-embed-no-norm'
+
+slurm_script="""#!/bin/bash
+
+echo "running shape embed with:"
+echo " - model {model}"
+echo " - dataset {dataset[0]} ({dataset[1]})"
+echo " - batch size {b_size}"
+echo " - latent space size {ls_size}"
+rand_name=$(cat /dev/urandom | tr -cd 'a-f0-9' | head -c 16)
+mkdir -p slurm_rundir/$rand_name
+cp -r $(ls | grep -v slurm_rundir) slurm_rundir/$rand_name/.
+cd slurm_rundir/$rand_name
+python3 scripts/shapes/shape_embed.py --wandb-project {wandb_project} --model {model} --dataset {dataset[0]} {dataset[1]} --batch-size {b_size} --latent-space-size {ls_size} --clear-checkpoints
+"""
+
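+# Map the requested latent space size to SLURM resource requests. The branches
+# below run from the largest threshold down so the big configurations match
+# first; with these tables, e.g. mem_size(512) -> '300GB' and
+# mem_size(64) -> '50GB'.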
+def mem_size(ls):
+    if ls > 256:
+        return '300GB'
+    if ls > 128:
+        return '100GB'
+    return '50GB'
+
+def n_gpus(ls):
+    if ls > 256:
+        return 'gpu:3'
+    return 'gpu:2'
+
+if __name__ == "__main__":
+
+    slurmdir = f'{os.getcwd()}/slurmdir'
+    os.makedirs(slurmdir, exist_ok=True)
+    for m, bs, ls, ds in [ (m,bs,ls,ds) for m in models
+                                        for bs in batch_sizes
+                                        for ls in latent_space_sizes
+                                        for ds in datasets ]:
+        jobname = f'shape_embed_{m}_{ds[0]}_{bs}_{ls}'
+        print(jobname)
+        # Include the dataset name in the script filename so runs over several
+        # datasets do not overwrite each other's scripts
+        fp = open(mode='w+', file=f'{slurmdir}/slurm_script_shape_embed_{m}_{ds[0]}_{bs}_{ls}.script')
+        fp.write(slurm_script.format(model=m, dataset=ds, b_size=bs, ls_size=ls, wandb_project=wandb_project))
+        fp.flush()
+        print(f'{fp.name}')
+        print(f'cat {fp.name}')
+        result = subprocess.run(['cat', fp.name], stdout=subprocess.PIPE)
+        print(result.stdout.decode('utf-8'))
+        print(mem_size(ls))
+        result = subprocess.run([ 'sbatch'
+                                , '--time', '24:00:00'
+                                , '--mem', mem_size(ls)
+                                , '--job-name', jobname
+                                , '--output', f'{slurmdir}/{jobname}.out'
+                                , '--error', f'{slurmdir}/{jobname}.err'
+                                #, '--gres', n_gpus(ls)
+                                , '--gpus=a100:1'
+                                , fp.name], stdout=subprocess.PIPE)
+        print(result.stdout.decode('utf-8'))
+        fp.close()
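+
+# Usage sketch: running `python3 slurm_shape_embed_dataset.py` from the current
+# directory writes one batch script per (model, batch size, latent size,
+# dataset) combination into ./slurmdir and submits each with sbatch; with the
+# lists above that is 8 models x 1 batch size x 1 latent size x 5 datasets =
+# 40 jobs. Assumes sbatch is on PATH and the dataset folders exist under
+# /nfs/research/uhlmann/afoix/.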