Refactor #65

Open · wants to merge 5 commits into master

Changes from all commits
2 changes: 2 additions & 0 deletions .github/workflows/test.yaml
@@ -33,6 +33,8 @@ jobs:
         run: poetry --help
       - name: Install poetry
         uses: abatilo/actions-poetry@v2
+      - name: Dynamic versioning
+        run: poetry self add poetry-dynamic-versioning
       - name: Setup a local virtual environment (if no poetry.toml file)
         run: |
           poetry config virtualenvs.create true --local
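Note: in GitHub Actions, a "uses:" key must point at a reusable action (as with abatilo/actions-poetry@v2), while an arbitrary shell command such as "poetry self add poetry-dynamic-versioning" belongs under "run:"; a "uses:" step whose value is a shell command fails when the workflow is parsed.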
21 changes: 15 additions & 6 deletions bioimage_embed/__init__.py
@@ -1,8 +1,12 @@
-import torch
-
-torch.cuda.empty_cache()
 # from . import models, lightning, cli, export, config
-from .lightning import AESupervised, AEUnsupervised, AE, AutoEncoderSupervised, AutoEncoderUnsupervised, AutoEncoder
+from .lightning import (
+    AESupervised,
+    AEUnsupervised,
+    AE,
+    AutoEncoderSupervised,
+    AutoEncoderUnsupervised,
+    AutoEncoder,
+)
 
 # TODO: Fix this import as it currently produces too many warnings
 from .models import ModelFactory, create_model
@@ -13,14 +17,19 @@
 import logging
 logging.captureWarnings(True)
 
+import torch
+
+torch.cuda.empty_cache()
 __all__ = [
     "AESupervised",
     "AutoEncoderUnsupervised",
     "AEUnsupervised",
     "AutoEncoderSupervised",
-    "AutoEncoder"
-    "AE"
+    "AutoEncoder",
+    "AE",
     "BioImageEmbed",
     "Config",
     "augmentations",
+    "ModelFactory",
+    "create_model",
 ]
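The trailing commas on "AutoEncoder" and "AE" are the substantive fix in this hunk: Python implicitly concatenates adjacent string literals, so the old list silently exported a single bogus name. A minimal illustration:

    broken = [
        "AutoEncoder"
        "AE",
    ]
    fixed = [
        "AutoEncoder",
        "AE",
    ]
    assert broken == ["AutoEncoderAE"]  # adjacent literals merge into one entry
    assert fixed == ["AutoEncoder", "AE"]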
25 changes: 0 additions & 25 deletions bioimage_embed/datasets/__init__.py

This file was deleted.

102 changes: 0 additions & 102 deletions bioimage_embed/hydra.py

This file was deleted.

20 changes: 18 additions & 2 deletions bioimage_embed/lightning/__init__.py
@@ -1,5 +1,21 @@
 from .pyro import LitAutoEncoderPyro
-from .torch import AESupervised, AEUnsupervised, AutoEncoder, AE, AutoEncoderSupervised, AutoEncoderUnsupervised
+from .torch import (
+    AESupervised,
+    AEUnsupervised,
+    AutoEncoder,
+    AE,
+    AutoEncoderSupervised,
+    AutoEncoderUnsupervised,
+)
 from .dataloader import DataModule
 
-__all__ = ["LitAutoEncoderPyro", "AESupervised", "AEUnsupervised", "DataModule", "AutoEncoder","AE","AutoEncoderUnsupervised","AutoEncoderSupervised"]
+__all__ = [
+    "LitAutoEncoderPyro",
+    "AESupervised",
+    "AEUnsupervised",
+    "DataModule",
+    "AutoEncoder",
+    "AE",
+    "AutoEncoderUnsupervised",
+    "AutoEncoderSupervised",
+]
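With these re-exports in place, downstream code can import from the subpackage root rather than reaching into the private modules. A sketch of the intended usage:

    # Equivalent to importing from bioimage_embed.lightning.torch and
    # bioimage_embed.lightning.dataloader directly:
    from bioimage_embed.lightning import AESupervised, DataModule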
7 changes: 4 additions & 3 deletions bioimage_embed/lightning/pyro.py
@@ -7,6 +7,10 @@
 
 
 class LitAutoEncoderPyro(pl.LightningModule):
+    """
+    WIP Unsupported
+    """
+
     def __init__(self, model, batch_size=1, learning_rate=1e-3):
         super().__init__()
         # self.autoencoder = AutoEncoder(batch_size, 1)
@@ -59,6 +63,3 @@ def pyro_training_step(self, train_batch, batch_idx):
 
     def training_step(self, train_batch, batch_idx):
         return self.torch_training_step(train_batch, batch_idx)
-
-    def training_step(self, train_batch, batch_idx):
-        return self.pyro_training_step(train_batch, batch_idx)
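The deleted lines are more than cleanup: when a class body defines the same method twice, the later definition silently replaces the earlier one, so before this change the pyro variant always won and torch_training_step was unreachable. A minimal sketch of the pitfall:

    class Demo:
        def step(self):
            return "first"

        def step(self):  # silently shadows the definition above
            return "second"

    assert Demo().step() == "second"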
1 change: 1 addition & 0 deletions bioimage_embed/lightning/tests/test_ndims.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+
 # Fixture for batch sizes
 @pytest.fixture(params=[1, 16])
 def batch_size(request):
12 changes: 0 additions & 12 deletions bioimage_embed/lightning/torch.py
@@ -143,16 +143,6 @@ def eval_step(self, batch, batch_idx):
         """
         return self.predict_step(batch, batch_idx)
 
-    # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
-    #     # Implement your own logic for updating the lr scheduler
-    #     # This method will be called at each training step
-    #     # Update the lr scheduler based on the provided arguments
-    #     # You can access the lr scheduler using `self.lr_schedulers()`
-
-    #     # Example:
-    #     for lr_scheduler in self.lr_schedulers():
-    #         lr_scheduler.step()
-
     def timm_optimizers(self, model):
         optimizer = optim.create_optimizer(self.args, model.parameters())
         lr_scheduler = scheduler.create_scheduler(self.args, optimizer)[0]
@@ -168,8 +158,6 @@ def timm_to_lightning(self, optimizer, lr_scheduler):
         }
 
     def configure_optimizers(self):
-        # optimizer = optim.create_optimizer(self.args, self.model.parameters())
-        # lr_scheduler = scheduler.create_scheduler(self.args, optimizer)[0]
         optimizer, lr_scheduler = self.timm_optimizers(self.model)
         return self.timm_to_lightning(optimizer, lr_scheduler)
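For context on the path that remains: timm_optimizers builds the optimizer and schedule with timm's factory functions, and timm_to_lightning adapts them to the dictionary that Lightning's configure_optimizers expects. A minimal sketch of such an adapter, assuming an epoch-stepped schedule (the exact keys in this repo may differ):

    def timm_to_lightning(self, optimizer, lr_scheduler):
        # Lightning accepts {"optimizer": ..., "lr_scheduler": {...}};
        # "interval" controls whether the scheduler steps per epoch or per batch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "epoch"},
        }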
23 changes: 12 additions & 11 deletions bioimage_embed/models/__init__.py
@@ -1,18 +1,19 @@
-import torch
-import torch.nn.functional as F
+# Description: main entry point for the models module; re-exports the factory helpers and the bolts encoder/decoder classes.
 
-# Note - you must have torchvision installed for this example
-from torch.utils.data import DataLoader
-
-# from .ae import AutoEncoder
-# from .vae_bio import Mask_VAE, Image_VAE
-# from .utils import BaseVAE
-# from .legacy.vae import VAE
-# from .vq_vae import VQ_VAE
 
 from .bolts import ResNet18VAEEncoder, ResNet18VAEDecoder
 
 from . import bolts
 from . import pythae
-from .factory import ModelFactory, create_model, __all_models__
\ No newline at end of file
+from .factory import ModelFactory, create_model, __all_models__
+
+__all__ = [
+    "ModelFactory",
+    "create_model",
+    "__all_models__",
+    "ResNet18VAEEncoder",
+    "ResNet18VAEDecoder",
+    "bolts",
+    "pythae",
+]
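For orientation, a hypothetical usage sketch of the factory exports; the keyword names mirror the parameters exercised in tests/test_factory.py (latent_dim, pretrained, progress) and are assumptions rather than a documented signature:

    from bioimage_embed.models import __all_models__, create_model

    # Hypothetical call: argument names are assumed from the test parameters.
    model = create_model(
        __all_models__[0],  # any registered architecture name
        latent_dim=64,
        pretrained=False,
        progress=True,
    )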
4 changes: 3 additions & 1 deletion bioimage_embed/models/tests/test_factory.py
@@ -14,7 +14,9 @@
 latent_dim = [64, 16]
 pretrained_options = [True, False]
 progress_options = [True, False]
-batch = [1,]
+batch = [
+    1,
+]
 
 
 @pytest.mark.parametrize("model", __all_models__)
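As a reminder of why these module-level lists exist: stacked pytest.mark.parametrize decorators take the Cartesian product of their values, so even the single-element batch list composes cleanly into the grid. For illustration:

    import pytest

    @pytest.mark.parametrize("latent_dim", [64, 16])
    @pytest.mark.parametrize("batch", [1])
    def test_grid(latent_dim, batch):
        # Runs once per (batch, latent_dim) pair: 1 x 2 = 2 cases.
        assert batch * latent_dim > 0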
5 changes: 3 additions & 2 deletions bioimage_embed/tests/test_bioimage_embed.py
@@ -2,13 +2,14 @@
 import torch
 from ..bie import BioImageEmbed
 
 
 def test_bioimage_embed():
     bie = BioImageEmbed()
     bie.train()
     bie.infer()
     bie.validate()
-    model_output = bie(torch.tensor([1, 2, 3, 4, 5]))
-    tensor = bie.model(torch.tensor([1, 2, 3, 4, 5]))
+    assert bie(torch.tensor([1, 2, 3, 4, 5]))
+    assert bie.model(torch.tensor([1, 2, 3, 4, 5]))
 
+    bie.model(torch.tensor([1, 2, 3, 4, 5]))
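One caveat on asserting model output directly: if the call returns a plain multi-element torch.Tensor, the assert raises RuntimeError because the tensor's truth value is ambiguous; checking against None (or asserting on shape) is the safe pattern. A small demonstration:

    import torch

    t = torch.tensor([1, 2, 3])
    try:
        assert t  # RuntimeError: truth value of a multi-element tensor is ambiguous
    except RuntimeError:
        pass

    assert t is not None   # unambiguous existence check
    assert t.numel() == 3  # or assert on shape/contents explicitly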