Merge branch 'main' into drop-scivision
Showing 6 changed files with 576 additions and 0 deletions.
@@ -0,0 +1,56 @@
"""
Unit tests for the visualisation application.
"""

from unittest import TestCase, mock

import pandas as pd
from streamlit.testing.v1 import AppTest

from cyto_ml.visualisation.visualisation_app import create_figure


class TestClusteringApp(TestCase):
    """
    Test class for the visualisation streamlit app.
    """

    def setUp(self):
        """
        Create some dummy data for testing.
        """
        self.data = pd.DataFrame(
            {
                "x": [1, 2],
                "y": [10, 11],
                "topic_number": [1, 2],
                "doc_id": ["id1", "id2"],
                "short_title": ["stitle1", "stitle2"],
            }
        )

    def test_app_starts(self):
        """
        Test the streamlit app starts.
        Note: current support for streamlit testing doesn't allow us to
        mimic user interactions with the visualisation.
""" | ||
with mock.patch( | ||
"cyto_ml.visualisation.visualisation_app.image_ids", | ||
return_value=self.data, | ||
): | ||
AppTest.from_file("cyto_ml/visualisation/visualisation_app.py").run( | ||
timeout=30 | ||
) | ||
|
||
def test_create_figure(self): | ||
""" | ||
Ensure figure is created appropriately using dummy data. | ||
""" | ||
fig = create_figure(self.data) | ||
|
||
scatter_data = fig.data[0] | ||
assert scatter_data.type == "scatter", "The plot type should be scatter" | ||
assert all(scatter_data.x == self.data["x"]), "X data should match" | ||
assert all(scatter_data.y == self.data["y"]), "Y data should match" |
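
For context, the smoke test above relies on Streamlit's AppTest harness, which runs the script headlessly without a browser. A minimal sketch of how a failure could be asserted explicitly; the `at.exception` attribute is assumed from Streamlit's testing API and is not used in the committed test:

from streamlit.testing.v1 import AppTest

# Run the app script headlessly; no browser is involved.
at = AppTest.from_file("cyto_ml/visualisation/visualisation_app.py")
at.run(timeout=30)

# Assumed check: AppTest collects uncaught exceptions raised during the run,
# so an empty list means the script executed cleanly.
assert not at.exception, f"App raised: {at.exception}"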
@@ -0,0 +1,98 @@
from sklearn.cluster import KMeans
import streamlit as st
from cyto_ml.visualisation.visualisation_app import (
    image_embeddings,
    image_ids,
    cached_image,
)

DEPTH = 8


@st.cache_resource
def kmeans_cluster() -> KMeans:
    """
    K-means cluster the embeddings; the cluster count is read from session state
""" | ||
X = image_embeddings("plankton") | ||
n_clusters = st.session_state["n_clusters"] | ||
# Initialize and fit KMeans | ||
kmeans = KMeans(n_clusters=n_clusters, random_state=42) | ||
kmeans.fit(X) | ||
return kmeans | ||
|
||
|
||
@st.cache_data | ||
def image_labels() -> dict: | ||
""" | ||
TODO good form to move all this into cyto_ml, call from there? | ||
""" | ||
km = kmeans_cluster() | ||
clusters = dict(zip(set(km.labels_), [[] for _ in range(len(set(km.labels_)))])) | ||
|
||
for index, id in enumerate(image_ids("plankton")): | ||
label = km.labels_[index] | ||
clusters[label].append(id) | ||
return clusters | ||
|
||
|
||
def add_more() -> None: | ||
st.session_state["depth"] += DEPTH | ||
|
||
|
||
def do_less() -> None: | ||
st.session_state["depth"] -= DEPTH | ||
|
||
|
||
def show_cluster() -> None: | ||
|
||
# TODO n_clusters configurable with selector | ||
fitted = image_labels() | ||
closest = fitted[st.session_state["cluster"]] | ||
|
||
# TODO figure out why this renders twice | ||
for _ in range(0, st.session_state["depth"]): | ||
cols = st.columns(DEPTH) | ||
for c in cols: | ||
c.image(cached_image(closest.pop()), width=60) | ||
|
||
|
||
# TODO some visualisation, actual content, etc | ||
def main() -> None: | ||
|
||
# start with this cluster label | ||
if "cluster" not in st.session_state: | ||
st.session_state["cluster"] = 1 | ||
|
||
# start kmeans with this many target clusters | ||
if "n_clusters" not in st.session_state: | ||
st.session_state["n_clusters"] = 5 | ||
|
||
# show this many images * 8 across | ||
if "depth" not in st.session_state: | ||
st.session_state["depth"] = 8 | ||
|
||
st.selectbox( | ||
"cluster label", | ||
[x for x in range(0, st.session_state["n_clusters"])], | ||
key="cluster", | ||
on_change=show_cluster, | ||
) | ||
|
||
st.selectbox( | ||
"n_clusters", | ||
[3, 5, 8], | ||
key="n_clusters", | ||
on_change=kmeans_cluster, | ||
) | ||
|
||
st.button("more", on_click=add_more) | ||
|
||
st.button("less", on_click=do_less) | ||
|
||
show_cluster() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
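
For reference, the clustering pattern above (fit KMeans on the embedding matrix, then bucket image ids by their assigned label) can be exercised outside Streamlit. A small self-contained sketch with random stand-in embeddings; the array shapes and ids are invented for illustration, and a defaultdict replaces the dict(zip(...)) construction used in image_labels():

from collections import defaultdict

import numpy as np
from sklearn.cluster import KMeans

# Stand-in for image_embeddings("plankton"): 100 vectors of dimension 512.
rng = np.random.default_rng(42)
embeddings = rng.normal(size=(100, 512))
ids = [f"image_{i}.tif" for i in range(100)]

kmeans = KMeans(n_clusters=5, random_state=42).fit(embeddings)

# Group ids by cluster label, equivalent to what image_labels() builds.
clusters = defaultdict(list)
for image_id, label in zip(ids, kmeans.labels_):
    clusters[label].append(image_id)

print({label: len(members) for label, members in clusters.items()})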
@@ -0,0 +1,166 @@
"""
Streamlit application to visualise how plankton cluster
based on their embeddings from a deep learning model
* Metadata in intake catalogue (basically a dataframe of filenames)
  - later this could have lon/lat, date, depth read from Exif headers
* Embeddings in chromadb, linked by filename
"""

import random
import requests
from io import BytesIO
from typing import Optional

import pandas as pd
import numpy as np

from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from scivision import load_dataset
from dotenv import load_dotenv
import intake
from cyto_ml.data.vectorstore import vector_store

load_dotenv()

STORE = vector_store("plankton")


@st.cache_data
def image_ids(collection_name: str) -> list:
    """
    Retrieve image ids from the chroma database.
    TODO Revisit our available metadata
    """
    result = STORE.get()
    return result["ids"]


@st.cache_data
def image_embeddings(collection_name: str) -> list:
    result = STORE.get(include=["embeddings"])
    return np.array(result["embeddings"])


@st.cache_data
def intake_dataset(catalog_yml: str) -> intake.catalog.local.YAMLFileCatalog:
    """
    Option to load an intake catalog from a URL; feels superfluous right now
""" | ||
dataset = load_dataset(catalog_yml) | ||
return dataset | ||
|
||
|
||
def closest_n(url: str, n: Optional[int] = 26) -> list: | ||
""" | ||
Given an image URL return the N closest ones by cosine distance | ||
""" | ||
embed = STORE.get([url], include=["embeddings"])["embeddings"] | ||
results = STORE.query(query_embeddings=embed, n_results=n) | ||
return results["ids"][0] # by index because API assumes query always multiple | ||
|
||
|
||
@st.cache_data | ||
def cached_image(url: str) -> Image.Image:
""" | ||
Read an image URL from s3 and return a PIL Image | ||
Hopefully caches this per-image, so it'll speed up | ||
We tried streamlit_clickable_images but no tiff support | ||
""" | ||
response = requests.get(url) | ||
return Image.open(BytesIO(response.content)) | ||
|
||
|
||
def closest_grid(start_url: str, size: Optional[int] = 65): | ||
""" | ||
Given an image URL, render a grid of the N nearest images | ||
by cosine distance between embeddings | ||
    size defaults to 65
""" | ||
closest = closest_n(start_url, size) | ||
|
||
# TODO understand where layout should happen | ||
rows = [] | ||
for _ in range(0, 8): | ||
rows.append(st.columns(8)) | ||
|
||
# TODO error handling | ||
for index, _ in enumerate(rows): | ||
for c in rows[index]: | ||
c.image(cached_image(closest.pop()), width=60) | ||
|
||
|
||
def create_figure(df: pd.DataFrame) -> go.Figure: | ||
""" | ||
    Creates a scatter plot from the given data frame
    TODO replace this layout with
    a) most basic image grid, switch between clusters
    b) ...
    """
    color_dict = {i: px.colors.qualitative.Alphabet[i] for i in range(0, 20)}
    color_dict[-1] = "#ABABAB"
    topic_color = df["topic_number"].map(color_dict)
    fig = go.Figure(
        data=go.Scatter(
            x=df["x"],
            y=df["y"],
            mode="markers",
            marker_color=topic_color,
            customdata=df["doc_id"],
            text=df["short_title"],
            hovertemplate="<b>%{text}</b>",
        )
    )
    fig.update_layout(height=600)
    return fig


def random_image() -> str:
    ids = image_ids("plankton")
    # starting image
    test_image_url = random.choice(ids)
    return test_image_url


def show_random_image():
    if st.session_state["random_img"]:
        st.image(cached_image(st.session_state["random_img"]))


def main() -> None:
    """
    Main method that sets up the streamlit app and builds the visualisation.
    """
    if "random_img" not in st.session_state:
        st.session_state["random_img"] = None

    st.set_page_config(layout="wide", page_title="Plankton image embeddings")
    st.title("Plankton image embeddings")
    # it starts much slower on adding this
    # the generated HTML is not lovely at all

    # catalog = "untagged-images-lana/intake.yml"
    # catalog_url = f"{os.environ.get('ENDPOINT')}/{catalog}"
    # ds = intake_dataset(catalog_url)
    # This way we've got a dataframe of the whole catalogue
    # Do we gain even slightly from this when we have the same index in the embeddings
    # index = ds.plankton().to_dask().compute()

    st.session_state["random_img"] = random_image()
    show_random_image()

    st.text("<-- random plankton")

    st.button("try again", on_click=random_image)

    # TODO figure out how streamlit is supposed to work
    closest_grid(st.session_state["random_img"])


if __name__ == "__main__":
    main()
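
For context, the nearest-neighbour lookup in closest_n follows the usual chromadb pattern: fetch the stored embedding for one id, then query the collection with it. A minimal sketch against a throwaway in-memory collection; the ids and vectors are invented for illustration, whereas the committed code goes through the project's vector_store helper:

import chromadb

client = chromadb.Client()  # in-memory, throwaway client
collection = client.create_collection("demo")

# Invented example data: three tiny "embeddings" keyed by image URL.
collection.add(
    ids=["img_a.tif", "img_b.tif", "img_c.tif"],
    embeddings=[[0.1, 0.2], [0.1, 0.25], [0.9, 0.9]],
)

# Same two steps as closest_n: read back one embedding, query with it.
embed = collection.get(ids=["img_a.tif"], include=["embeddings"])["embeddings"]
results = collection.query(query_embeddings=embed, n_results=2)
print(results["ids"][0])  # nearest neighbours of img_a.tif, including itself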