
Merge branch 'main' into drop-scivision
metazool authored Aug 27, 2024
2 parents 2bf3586 + e4f4271 commit a21719c
Showing 6 changed files with 576 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -74,6 +74,15 @@ If you modify the contents of a notebook, run the command after closing the note

For more information see the [Jupytext docs](https://jupytext.readthedocs.io/en/latest/).

## Visualisation

A Streamlit app based on the [text embeddings for EIDC catalogue metadata](https://github.com/NERC-CEH/embeddings_app/) app:

```
streamlit run cyto_ml/visualisation/visualisation_app.py
```

The demo should automatically open in your browser when you run streamlit. If it does not, connect using: http://localhost:8501.
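
If the default port is already in use, or you are running the app on a remote machine, streamlit's standard server options should also work, for example:

```
streamlit run cyto_ml/visualisation/visualisation_app.py --server.port 8502 --server.headless true
```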

### TBC (object store upload, derived classifiers, etc)

56 changes: 56 additions & 0 deletions cyto_ml/tests/test_visualisation_app.py
@@ -0,0 +1,56 @@
"""
Unit tests for the visualisation application.
"""

from unittest import TestCase, mock

import pandas as pd
from streamlit.testing.v1 import AppTest

from cyto_ml.visualisation.visualisation_app import create_figure


class TestClusteringApp(TestCase):
"""
Test class for the visualisation streamlit app.
"""

def setUp(self):
"""
Create some dummy data for testing.
"""
self.data = pd.DataFrame(
{
"x": [1, 2],
"y": [10, 11],
"topic_number": [1, 2],
"doc_id": ["id1", "id2"],
"short_title": ["stitle1", "stitle2"],
}
)

def test_app_starts(self):
"""
Test the streamlit app starts.
        Note: streamlit's testing support doesn't currently allow
        mimicking user interactions with the visualisation.
"""
with mock.patch(
"cyto_ml.visualisation.visualisation_app.image_ids",
return_value=self.data,
):
AppTest.from_file("cyto_ml/visualisation/visualisation_app.py").run(
timeout=30
)

def test_create_figure(self):
"""
Ensure figure is created appropriately using dummy data.
"""
fig = create_figure(self.data)

scatter_data = fig.data[0]
assert scatter_data.type == "scatter", "The plot type should be scatter"
assert all(scatter_data.x == self.data["x"]), "X data should match"
assert all(scatter_data.y == self.data["y"]), "Y data should match"
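
To run just these tests locally (a quick check, assuming pytest is available in the environment; the unittest-style classes run fine under it):

```
pytest cyto_ml/tests/test_visualisation_app.py
```
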
98 changes: 98 additions & 0 deletions cyto_ml/visualisation/pages/02_kmeans.py
@@ -0,0 +1,98 @@
from sklearn.cluster import KMeans
import streamlit as st
from cyto_ml.visualisation.visualisation_app import (
image_embeddings,
image_ids,
cached_image,
)

DEPTH = 8


@st.cache_resource
def kmeans_cluster() -> KMeans:
"""
    K-means cluster the embeddings; the number of clusters is read from session state
"""
X = image_embeddings("plankton")
n_clusters = st.session_state["n_clusters"]
# Initialize and fit KMeans
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(X)
return kmeans


@st.cache_data
def image_labels() -> dict:
"""
TODO good form to move all this into cyto_ml, call from there?
"""
km = kmeans_cluster()
    clusters = {label: [] for label in set(km.labels_)}

    for index, image_id in enumerate(image_ids("plankton")):
        label = km.labels_[index]
        clusters[label].append(image_id)
    return clusters


def add_more() -> None:
st.session_state["depth"] += DEPTH


def do_less() -> None:
st.session_state["depth"] -= DEPTH


def show_cluster() -> None:

# TODO n_clusters configurable with selector
fitted = image_labels()
closest = fitted[st.session_state["cluster"]]

# TODO figure out why this renders twice
for _ in range(0, st.session_state["depth"]):
cols = st.columns(DEPTH)
for c in cols:
c.image(cached_image(closest.pop()), width=60)


# TODO some visualisation, actual content, etc
def main() -> None:

# start with this cluster label
if "cluster" not in st.session_state:
st.session_state["cluster"] = 1

# start kmeans with this many target clusters
if "n_clusters" not in st.session_state:
st.session_state["n_clusters"] = 5

    # show this many rows of images, DEPTH (8) across
if "depth" not in st.session_state:
st.session_state["depth"] = 8

st.selectbox(
"cluster label",
[x for x in range(0, st.session_state["n_clusters"])],
key="cluster",
on_change=show_cluster,
)

st.selectbox(
"n_clusters",
[3, 5, 8],
key="n_clusters",
on_change=kmeans_cluster,
)

st.button("more", on_click=add_more)

st.button("less", on_click=do_less)

show_cluster()


if __name__ == "__main__":
main()
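
A quick way to sanity-check the clustering above (a minimal sketch, not part of this commit; it assumes the `kmeans_cluster` helper defined in this page module) is to count how many images land in each label:

```
from collections import Counter


def cluster_sizes() -> dict:
    """Count how many images k-means assigned to each cluster label (illustrative only)."""
    km = kmeans_cluster()  # cached fit from the page above
    return dict(Counter(int(label) for label in km.labels_))
```
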
166 changes: 166 additions & 0 deletions cyto_ml/visualisation/visualisation_app.py
@@ -0,0 +1,166 @@
"""
Streamlit application to visualise how plankton cluster
based on their embeddings from a deep learning model
* Metadata in intake catalogue (basically a dataframe of filenames
  - later this could have lon/lat, date, depth read from Exif headers)
* Embeddings in chromadb, linked by filename
"""

import random
import requests
from io import BytesIO
from typing import Optional

import pandas as pd
import numpy as np

from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from scivision import load_dataset
from dotenv import load_dotenv
import intake
from cyto_ml.data.vectorstore import vector_store

load_dotenv()

STORE = vector_store("plankton")


@st.cache_data
def image_ids(collection_name: str) -> list:
"""
    Retrieve image IDs from the chroma database.
TODO Revisit our available metadata
"""
result = STORE.get()
return result["ids"]


@st.cache_data
def image_embeddings(collection_name: str) -> list:
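    """
    Retrieve the stored embeddings from the chroma database and return
    them as a numpy array (one row per image).
    """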
result = STORE.get(include=["embeddings"])
return np.array(result["embeddings"])


@st.cache_data
def intake_dataset(catalog_yml: str) -> intake.catalog.local.YAMLFileCatalog:
"""
    Option to load an intake catalog from a URL; feels superfluous right now
"""
dataset = load_dataset(catalog_yml)
return dataset


def closest_n(url: str, n: Optional[int] = 26) -> list:
"""
Given an image URL return the N closest ones by cosine distance
"""
embed = STORE.get([url], include=["embeddings"])["embeddings"]
results = STORE.query(query_embeddings=embed, n_results=n)
return results["ids"][0] # by index because API assumes query always multiple


@st.cache_data
def cached_image(url: str) -> Image:
"""
Read an image URL from s3 and return a PIL Image
    Hopefully this caches per-image, which should speed things up
    We tried streamlit_clickable_images but it has no TIFF support
"""
response = requests.get(url)
return Image.open(BytesIO(response.content))


def closest_grid(start_url: str, size: Optional[int] = 65):
"""
Given an image URL, render a grid of the N nearest images
by cosine distance between embeddings
    size defaults to 65
"""
closest = closest_n(start_url, size)

# TODO understand where layout should happen
rows = []
for _ in range(0, 8):
rows.append(st.columns(8))

# TODO error handling
for index, _ in enumerate(rows):
for c in rows[index]:
c.image(cached_image(closest.pop()), width=60)


def create_figure(df: pd.DataFrame) -> go.Figure:
"""
Creates scatter plot based on handed data frame
TODO replace this layout with
a) most basic image grid, switch between clusters
b) ...
"""
color_dict = {i: px.colors.qualitative.Alphabet[i] for i in range(0, 20)}
color_dict[-1] = "#ABABAB"
topic_color = df["topic_number"].map(color_dict)
fig = go.Figure(
data=go.Scatter(
x=df["x"],
y=df["y"],
mode="markers",
marker_color=topic_color,
customdata=df["doc_id"],
text=df["short_title"],
hovertemplate="<b>%{text}</b>",
)
)
fig.update_layout(height=600)
return fig
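
# NOTE (illustrative, not part of this commit): create_figure isn't wired into
# main() yet; one way to render it would be something like
#   st.plotly_chart(create_figure(df), use_container_width=True)
# where df is a dataframe with x, y, topic_number, doc_id and short_title columns.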


def random_image() -> str:
ids = image_ids("plankton")
# starting image
test_image_url = random.choice(ids)
return test_image_url


def show_random_image():
if st.session_state["random_img"]:
st.image(cached_image(st.session_state["random_img"]))


def main() -> None:
"""
Main method that sets up the streamlit app and builds the visualisation.
"""
if "random_img" not in st.session_state:
st.session_state["random_img"] = None

st.set_page_config(layout="wide", page_title="Plankton image embeddings")
st.title("Plankton image embeddings")
# it starts much slower on adding this
# the generated HTML is not lovely at all

# catalog = "untagged-images-lana/intake.yml"
# catalog_url = f"{os.environ.get('ENDPOINT')}/{catalog}"
# ds = intake_dataset(catalog_url)
# This way we've got a dataframe of the whole catalogue
# Do we gain even slightly from this when we have the same index in the embeddings
# index = ds.plankton().to_dask().compute()

st.session_state["random_img"] = random_image()
show_random_image()

st.text("<-- random plankton")

st.button("try again", on_click=random_image)

# TODO figure out how streamlit is supposed to work
closest_grid(st.session_state["random_img"])


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions environment.yml
@@ -20,7 +20,10 @@ dependencies:
- python-dotenv
- s3fs
- scikit-image
- scikit-learn
- xarray
- pip
- streamlit
- plotly
- pip:
- git+https://github.com/jmarshrossney/resnet50-cefas
