Drop scivision #24

Merged
merged 14 commits into from
Aug 29, 2024
5 changes: 5 additions & 0 deletions .flake8
@@ -1,2 +1,7 @@
[flake8]
max-line-length=120
exclude =
    venv
    __pycache__
    tests
    vectors
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@
**/__pycache__/
vectors/
*.ipynb
*.egg-info/
venv/
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.12
40 changes: 30 additions & 10 deletions README.md
@@ -6,30 +6,50 @@ It's a companion project to an R-shiny based image annotation app that is not ye

## Installation

### Python environment setup
### Environment and package installation

Use anaconda or miniconda to create a python environment using the included `environment.yml`
#### Using pip

Create a fresh virtual environment in the repository root using Python >=3.12 and (e.g.) `venv`:

```
conda env create -f environment.yml
python -m venv venv
```

Please note that this is specifically pinned to python 3.9 due to dependency versions; we make experimental use of the [CEFAS plankton model available through SciVision](https://sci.vision/#/model/resnet50-plankton), which in turn uses an older version of pytorch that isn't packaged above python 3.9.
Next, install the package using `pip`:

### Object store connection
```
python -m pip install .
```

`.env` contains environment variable names for S3 connection details for the [JASMIN object store](https://github.com/NERC-CEH/object_store_tutorial/). Fill these in with your own credentials. If you're not sure what the `ENDPOINT` should be, please reach out to one of the project contributors listed below.
Most likely you are interested in developing and/or experimenting, so you will probably want to install the package in 'editable' mode (`-e`), along with dev tools and jupyter notebook functionality:

```
python -m pip install -e .[all]
```

### Package installation
#### Using conda

Get started by cloning this repository and running
Use anaconda or miniconda to create a python environment using the included `environment.yml`

`pip install -e .`
```
conda env create -f environment.yml
conda activate cyto_ml
```

Next install this package _without dependencies_:

```
python -m pip install --no-deps -e .
```

### Object store connection

`.env` contains environment variable names for S3 connection details for the [JASMIN object store](https://github.com/NERC-CEH/object_store_tutorial/). Fill these in with your own credentials. If you're not sure what the `ENDPOINT` should be, please reach out to one of the project contributors listed below.
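
As a rough illustration of how these settings might be read at runtime, the sketch below checks that `ENDPOINT` is set before building the catalog URL that the notebook opens. The helper name `catalog_url` is hypothetical and is not part of `cyto_ml`; only the `ENDPOINT` variable and the `metadata/intake.yml` path come from this repository.

```python
import os


def catalog_url(env=None):
    """Build the intake catalog URL from ENDPOINT (hypothetical helper)."""
    env = os.environ if env is None else env
    endpoint = env.get("ENDPOINT", "").rstrip("/")
    if not endpoint:
        # Fail early with a clear message instead of requesting "/metadata/intake.yml".
        raise RuntimeError("ENDPOINT is not set; fill in .env first")
    return f"{endpoint}/metadata/intake.yml"


# Example with a placeholder endpoint rather than real credentials.
print(catalog_url({"ENDPOINT": "https://example.invalid/"}))
```

In the notebook, `load_dotenv()` populates `os.environ` from `.env`, so calling the helper with no arguments would pick up the real value.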

### Running tests

`python -m pytest` or `py.test`
`pytest` or `py.test`

## Contents

57 changes: 0 additions & 57 deletions cyto_ml/models/scivision.py

This file was deleted.

28 changes: 16 additions & 12 deletions environment.yml
@@ -1,25 +1,29 @@
name: cyto_39
name: cyto_ml
channels:
- pytorch
- conda-forge
- defaults
channel_priority: flexible
dependencies:
- python=3.9
- pytorch=1.10.0
- mkl=2024.0
- chromadb=0.5.3
- python=3.12
- pytorch
- black
- chromadb
- flake8
- intake-xarray
- scikit-image
- scikit-learn
- intake=0.7
- isort
- jupyterlab
- jupytext
- matplotlib
- pandas
- pytest
- python-dotenv
- s3fs
- jupyterlab
- jupytext
- scikit-image
- scikit-learn
- xarray
- pip
- streamlit
- plotly
- pip:
- scivision
- git+https://github.com/alan-turing-institute/plankton-cefas-scivision@main
- git+https://github.com/jmarshrossney/resnet50-cefas
50 changes: 15 additions & 35 deletions notebooks/ImageEmbeddings.md
@@ -12,14 +12,13 @@ jupyter:
name: python3
---

Use this with the `cyto_39` environment (the scivision model needs a specific version of `pytorch` that isn't packaged for >3.9, i have raised a Github issue asking if they plan to update it)
Use this with the `cyto_ml` environment.

`conda env create -f environment.yml`
`conda activate cyto_39`
`conda activate cyto_ml`

```python
import os
from scivision import load_pretrained_model, load_dataset
from dotenv import load_dotenv
import torch
import torchvision
@@ -29,58 +28,37 @@ sys.path.append('../')
from cyto_ml.models.scivision import prepare_image
from intake_xarray import ImageSource
load_dotenv() # sets our object store endpoint and credentials from the .env file

from intake import open_catalog

from resnet50_cefas import load_model
```

```python
dataset = load_dataset(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml")
model = load_pretrained_model("https://github.com/alan-turing-institute/plankton-cefas-scivision")
dataset = open_catalog(f"{os.environ.get('ENDPOINT', '')}/metadata/intake.yml")
model = load_model()
dataset.test_image().to_dask()
```

The scivision wrapper depends on this being an xarray Dataset with settable attributes, rather than a DataArray

Setting exif_tags: True (Dataset) or False (DataArray) is what controls this
https://docs.xarray.dev/en/stable/generated/xarray.DataArray.to_dataset.html

https://github.com/alan-turing-institute/scivision/blob/07fb74e5231bc1d56cf39df38c19ef40e3265e4c/src/scivision/io/reader.py#L183
https://github.com/intake/intake/blob/29c8878aa7bf6e93185e2c9639f8739445dff22b/intake/__init__.py#L101

But now we're dependent on image height and width metadata being set in the EXIF tags to use the `predict` interface, this is set in the model description through `scivision`, this is brittle

https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/model.py#L71



A quick look at the example dataset that comes with the model, for reference


In this case we don't want to use the `predict` interface anyway (one of N class labels) - we want the features that go into the last fully-connected layer (as described here https://stackoverflow.com/a/52548419)

```python
network = torch.nn.Sequential(*(list(model._plumbing.model.pretrained_model.children())[:-1]))
network = load_model(strip_final_layer=True)
```

```python
imgs = dataset.test_image().to_dask()
i= imgs.to_numpy()
i.shape

imgs.to_numpy().shape
```

https://github.com/alan-turing-institute/plankton-cefas-scivision/blob/main/resnet50_cefas/data.py



Pass the image through our truncated network and get some embeddings out

```python
o = torch.stack([torchvision.transforms.ToTensor()(i)])
o = prepare_image(imgs)
feats = network(o)
feats.shape
```

```python
embeddings = list(feats[0].squeeze(1).squeeze(1).detach().numpy().astype(float))
embeddings = feats[0].tolist()
```

```python
@@ -129,7 +107,7 @@ index

```python
def flat_embeddings(features: torch.Tensor):
return list(features[0].squeeze(1).squeeze(1).detach().numpy().astype(float))
return features[0].tolist()
```

```python
Expand Down Expand Up @@ -158,6 +136,8 @@ This scales ok at 8000 or so images
collection.count()
```

This is _really_ slow - joe

```python
res = index.apply(file_embeddings, axis=1)
```
35 changes: 31 additions & 4 deletions pyproject.toml
@@ -1,12 +1,39 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "cyto_ml"
version = "0.1"
version = "0.2.0"
requires-python = ">=3.12"
description = "This package supports the processing and analysis of plankton sample data"
readme = "README.md"
requires-python = "==3.9.*"
dependencies = [
"chromadb",
"intake==0.7.0",
"intake-xarray",
"pandas",
"python-dotenv",
"s3fs",
"scikit-image", # secretly required by intake-xarray as default reader
"torch",
"xarray",
"resnet50-cefas@git+https://github.com/jmarshrossney/resnet50-cefas",
]

[tool.setuptools]
py-modules = []
[project.optional-dependencies]
jupyter = ["jupyterlab", "jupytext", "matplotlib"]
dev = ["pytest", "black", "flake8", "isort"]
all = ["cyto_ml[jupyter,dev]"]

[tool.jupytext]
formats = "ipynb,md"

[tool.pytest.ini_options]
filterwarnings = [
"ignore::DeprecationWarning",
]

[tool.black]
target-version = ["py312"]
line-length = 88
1 change: 1 addition & 0 deletions scripts/intake_metadata.py
@@ -7,6 +7,7 @@
Via https://gallery.pangeo.io/repos/pangeo-data/pangeo-tutorial-gallery/intake.html#Build-an-intake-catalog

"""

import os
from cyto_ml.data.intake import intake_yaml
from cyto_ml.data.s3 import s3_endpoint, image_index
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -9,7 +9,7 @@

logging.basicConfig(level=logging.INFO)
# TODO make this sensibly configurable, not confusingly hardcoded
STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../vectors")
STORE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../../vectors")

client = chromadb.PersistentClient(
path=STORE,
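
The extra `..` in `STORE` is presumably needed because the module now sits one directory deeper, under `src/`. As a sketch of one way to address the TODO above, the store location could be resolved with `pathlib` and overridden through an environment variable. The variable name `CYTO_ML_VECTOR_STORE`, the helper `vector_store_path` and the example module path are assumptions for illustration, not something this PR adds.

```python
import os
from pathlib import Path


def vector_store_path(module_file, levels_up=3, env=None):
    """Return the vector store directory (hypothetical helper, not in cyto_ml)."""
    env = os.environ if env is None else env
    override = env.get("CYTO_ML_VECTOR_STORE")  # assumed variable name
    if override:
        return Path(override)
    # parents[0] is the module's directory; parents[levels_up] is `levels_up`
    # directories above it, matching "../../../vectors" when levels_up=3.
    return Path(module_file).resolve().parents[levels_up] / "vectors"


# Hypothetical module location, resolved relative to the current directory.
print(vector_store_path("src/cyto_ml/data/example_module.py", env={}))
```

Inside the real module, `__file__` would be passed as `module_file`, and the result converted with `str()` before being handed to `chromadb.PersistentClient(path=...)`.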
File renamed without changes.
36 changes: 36 additions & 0 deletions src/cyto_ml/models/scivision.py
@@ -0,0 +1,36 @@
import torch
from torchvision.transforms.v2.functional import to_image, to_dtype
from xarray import DataArray


def prepare_image(image: DataArray):
    """
    Take an xarray of image data and prepare it to pass through the model
    a) Converts the image data to a PyTorch tensor
    b) Accepts a single image or batch (no need for torch.stack)
    """
    # Computes the DataArray and returns a numpy array
    image_numpy = image.to_numpy()

    # Convert the image data to a PyTorch tensor
    tensor_image = to_dtype(
        to_image(image_numpy),  # permutes HWC -> CHW
        torch.float32,
        scale=True,  # rescales [0, 255] -> [0, 1]
    )
    assert torch.all((tensor_image >= 0.0) & (tensor_image <= 1.0))

    if tensor_image.dim() == 3:
        # Single image, add a batch dimension
        tensor_image = tensor_image.unsqueeze(0)

    assert tensor_image.dim() == 4

    return tensor_image


def flat_embeddings(features: torch.Tensor):
    """Utility function that takes the features returned by the model in truncate_model
    And flattens them into a list suitable for storing in a vector database"""
    # TODO: this only returns the 0th tensor in the batch...why?
    return features[0].detach().tolist()
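
On the TODO above: indexing `features[0]` keeps only the first item, which is fine for single-image batches but drops the rest otherwise. A minimal sketch of a batch-aware variant follows, written against plain nested lists (for example the result of `features.detach().tolist()`) so that it runs without PyTorch. The names `flat_embeddings_batch` and `_flatten` are hypothetical and not part of this PR.

```python
def _flatten(values):
    """Recursively flatten arbitrarily nested lists into one flat list of floats."""
    flat = []
    for value in values:
        if isinstance(value, list):
            flat.extend(_flatten(value))
        else:
            flat.append(float(value))
    return flat


def flat_embeddings_batch(features):
    """Return one flat list of floats per item in the batch (hypothetical helper)."""
    return [
        _flatten(item) if isinstance(item, list) else [float(item)]
        for item in features
    ]


# Two items, each holding three features with trailing singleton dimensions.
batch = [[[[0.1]], [[0.2]], [[0.3]]], [[[0.4]], [[0.5]], [[0.6]]]]
print(flat_embeddings_batch(batch))
```

When working with tensors directly, `features.detach().flatten(start_dim=1).tolist()` should give the same per-item lists without the Python-level recursion.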