diff --git a/README.md b/README.md
index 731797f..e637785 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,15 @@ If you modify the contents of a notebook, run the command after closing the note
 For more information see the [Jupytext docs](https://jupytext.readthedocs.io/en/latest/).
 
+## Visualisation
+
+A Streamlit app based on the [text embeddings for EIDC catalogue metadata](https://github.com/NERC-CEH/embeddings_app/) app.
+
+```
+streamlit run cyto_ml/visualisation/visualisation_app.py
+```
+
+The demo should automatically open in your browser when you run Streamlit. If it does not, connect using: http://localhost:8501.
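+
+Streamlit serves on port 8501 by default; if that port is already in use, pass a different one (8502 here is an arbitrary free port):
+
+```
+streamlit run cyto_ml/visualisation/visualisation_app.py --server.port 8502
+```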
+ """ + km = kmeans_cluster() + clusters = dict(zip(set(km.labels_), [[] for _ in range(len(set(km.labels_)))])) + + for index, id in enumerate(image_ids("plankton")): + label = km.labels_[index] + clusters[label].append(id) + return clusters + + +def add_more() -> None: + st.session_state["depth"] += DEPTH + + +def do_less() -> None: + st.session_state["depth"] -= DEPTH + + +def show_cluster() -> None: + + # TODO n_clusters configurable with selector + fitted = image_labels() + closest = fitted[st.session_state["cluster"]] + + # TODO figure out why this renders twice + for _ in range(0, st.session_state["depth"]): + cols = st.columns(DEPTH) + for c in cols: + c.image(cached_image(closest.pop()), width=60) + + +# TODO some visualisation, actual content, etc +def main() -> None: + + # start with this cluster label + if "cluster" not in st.session_state: + st.session_state["cluster"] = 1 + + # start kmeans with this many target clusters + if "n_clusters" not in st.session_state: + st.session_state["n_clusters"] = 5 + + # show this many images * 8 across + if "depth" not in st.session_state: + st.session_state["depth"] = 8 + + st.selectbox( + "cluster label", + [x for x in range(0, st.session_state["n_clusters"])], + key="cluster", + on_change=show_cluster, + ) + + st.selectbox( + "n_clusters", + [3, 5, 8], + key="n_clusters", + on_change=kmeans_cluster, + ) + + st.button("more", on_click=add_more) + + st.button("less", on_click=do_less) + + show_cluster() + + +if __name__ == "__main__": + main() diff --git a/cyto_ml/visualisation/visualisation_app.py b/cyto_ml/visualisation/visualisation_app.py new file mode 100644 index 0000000..c10ed0f --- /dev/null +++ b/cyto_ml/visualisation/visualisation_app.py @@ -0,0 +1,166 @@ +""" +Streamlit application to visualise how plankton cluster +based on their embeddings from a deep learning model + +* Metadata in intake catalogue (basically a dataframe of filenames + - later this could have lon/lat, date, depth read from Exif headers +* Embeddings in chromadb, linked by filename + +""" + +import random +import requests +from io import BytesIO +from typing import Optional + +import pandas as pd +import numpy as np + +from PIL import Image +import plotly.express as px +import plotly.graph_objects as go +import streamlit as st + +from scivision import load_dataset +from dotenv import load_dotenv +import intake +from cyto_ml.data.vectorstore import vector_store + +load_dotenv() + +STORE = vector_store("plankton") + + +@st.cache_data +def image_ids(collection_name: str) -> list: + """ + Retrieve image embeddings from chroma database. 
+    TODO Revisit our available metadata
+    """
+    result = STORE.get()
+    return result["ids"]
+
+
+@st.cache_data
+def image_embeddings(collection_name: str) -> np.ndarray:
+    result = STORE.get(include=["embeddings"])
+    return np.array(result["embeddings"])
+
+
+@st.cache_data
+def intake_dataset(catalog_yml: str) -> intake.catalog.local.YAMLFileCatalog:
+    """
+    Option to load an intake catalog from a URL; feels superfluous right now
+    """
+    dataset = load_dataset(catalog_yml)
+    return dataset
+
+
+def closest_n(url: str, n: Optional[int] = 26) -> list:
+    """
+    Given an image URL, return the n closest ones by cosine distance
+    """
+    embed = STORE.get([url], include=["embeddings"])["embeddings"]
+    results = STORE.query(query_embeddings=embed, n_results=n)
+    return results["ids"][0]  # by index because the API assumes queries are always multiple
+
+
+@st.cache_data
+def cached_image(url: str) -> Image.Image:
+    """
+    Read an image URL from s3 and return a PIL Image.
+    Cached per URL, which should speed up repeat renders.
+    We tried streamlit_clickable_images but it has no TIFF support.
+    """
+    response = requests.get(url)
+    return Image.open(BytesIO(response.content))
+
+
+def closest_grid(start_url: str, size: Optional[int] = 65) -> None:
+    """
+    Given an image URL, render a grid of its nearest images
+    by cosine distance between embeddings.
+    size defaults to 65, one more than the 8x8 grid renders.
+    """
+    closest = closest_n(start_url, size)
+
+    # TODO understand where layout should happen
+    rows = []
+    for _ in range(0, 8):
+        rows.append(st.columns(8))
+
+    # TODO error handling
+    for index, _ in enumerate(rows):
+        for c in rows[index]:
+            c.image(cached_image(closest.pop()), width=60)
+
+
+def create_figure(df: pd.DataFrame) -> go.Figure:
+    """
+    Creates a scatter plot from the given data frame.
+    TODO replace this layout with
+    a) most basic image grid, switch between clusters
+    b) ...
+    """
+    color_dict = {i: px.colors.qualitative.Alphabet[i] for i in range(0, 20)}
+    color_dict[-1] = "#ABABAB"
+    topic_color = df["topic_number"].map(color_dict)
+    fig = go.Figure(
+        data=go.Scatter(
+            x=df["x"],
+            y=df["y"],
+            mode="markers",
+            marker_color=topic_color,
+            customdata=df["doc_id"],
+            text=df["short_title"],
+            hovertemplate="%{text}",
+        )
+    )
+    fig.update_layout(height=600)
+    return fig
+
+
+def random_image() -> str:
+    ids = image_ids("plankton")
+    # starting image
+    test_image_url = random.choice(ids)
+    return test_image_url
+
+
+def show_random_image():
+    if st.session_state["random_img"]:
+        st.image(cached_image(st.session_state["random_img"]))
+
+
+def main() -> None:
+    """
+    Main method that sets up the streamlit app and builds the visualisation.
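+
+    Flow: pick a random image from the collection, display it, then
+    render a grid of its nearest neighbours by embedding distance.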
+ """ + if "random_img" not in st.session_state: + st.session_state["random_img"] = None + + st.set_page_config(layout="wide", page_title="Plankton image embeddings") + st.title("Plankton image embeddings") + # it starts much slower on adding this + # the generated HTML is not lovely at all + + # catalog = "untagged-images-lana/intake.yml" + # catalog_url = f"{os.environ.get('ENDPOINT')}/{catalog}" + # ds = intake_dataset(catalog_url) + # This way we've got a dataframe of the whole catalogue + # Do we gain even slightly from this when we have the same index in the embeddings + # index = ds.plankton().to_dask().compute() + + st.session_state["random_img"] = random_image() + show_random_image() + + st.text("<-- random plankton") + + st.button("try again", on_click=random_image) + + # TODO figure out how streamlit is supposed to work + closest_grid(st.session_state["random_img"]) + + +if __name__ == "__main__": + main() diff --git a/environment.yml b/environment.yml index 9bccdc0..75d90c1 100644 --- a/environment.yml +++ b/environment.yml @@ -10,6 +10,7 @@ dependencies: - chromadb=0.5.3 - intake-xarray - scikit-image + - scikit-learn - pandas - pytest - python-dotenv @@ -17,6 +18,8 @@ dependencies: - jupyterlab - jupytext - pip + - streamlit + - plotly - pip: - scivision - git+https://github.com/alan-turing-institute/plankton-cefas-scivision@main diff --git a/notebooks/Clustering.md b/notebooks/Clustering.md new file mode 100644 index 0000000..9e0be1a --- /dev/null +++ b/notebooks/Clustering.md @@ -0,0 +1,244 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.16.4 + kernelspec: + display_name: cyto_39 + language: python + name: python3 +--- + +Notebook examination of applying unsupervised clustering methods to vector embeddings and visualising the results + +Keywords: K-means, DBScan, T-SNE, other? + +Paper reference for approach: https://aslopubs.onlinelibrary.wiley.com/doi/full/10.1002/lno.12101#lno12101-sec-0025-title + +Possibly interesting if we try the transformer-based plankton model from Turing: https://link.springer.com/chapter/10.1007/978-3-030-74251-5_23 + + +```python +import sys +sys.path.append('../') +from cyto_ml.data.vectorstore import vector_store, client +from sklearn.cluster import DBSCAN, KMeans +from sklearn.preprocessing import StandardScaler +from matplotlib import pyplot as plt +from skimage import io +import numpy as np +``` + +Load our embeddings into a form suitable for throwing at clustering algorithms, 2048 features might be optimistic and we need to first reduce them! + +```python +store = vector_store('plankton') +res = store.get(include=['embeddings']) +X = np.array(res['embeddings']) +``` + +This doesn't work with such a high number of features even with `make_blobs` generating pre-clustered data with 2048 features, and tips like scaling values. + +So either PCA first, or we just stick with K-means as a simpler effort and work our way back here. + +```python +def do_dbscan(X): + + scaler = StandardScaler() + X_scaled = scaler.fit_transform(X) + db = DBSCAN(eps=0.7, min_samples=100).fit(X_scaled) + + labels = db.labels_ + + # Number of clusters in labels, ignoring noise if present. 
+
+Sense check on a generated dataset with the same number of features and three natural clusters - uncomment to see it.
+
+```python
+# from sklearn.datasets import make_blobs
+# X, y = make_blobs(n_samples=1000, centers=3, n_features=2048, random_state=0)
+# do_dbscan(X)
+```
+
+```python
+type(X)
+```
+
+Fall back to a K-means approach, just to try and get some visual feedback.
+
+```python
+# Set the number of clusters
+num_clusters = 10  # Adjust based on your data
+
+# Initialize and fit KMeans
+kmeans = KMeans(n_clusters=num_clusters, random_state=42)
+kmeans.fit(X)
+
+# Get cluster labels
+labels = kmeans.labels_
+```
+
+```python
+len(labels)
+```
+
+```python
+
+clusters = dict(zip(set(labels), [[] for _ in range(len(set(labels)))]))
+
+for index, id in enumerate(res['ids']):
+    label = labels[index]
+    clusters[label].append(id)
+```
+
+```python
+i = 3  # picked at random
+from mpl_toolkits.axes_grid1 import ImageGrid
+fig = plt.figure(figsize=(10., 10.))
+grid = ImageGrid(fig, 111,  # similar to subplot(111)
+                 nrows_ncols=(5, 5),  # creates a 5x5 grid of axes
+                 axes_pad=0.2,  # pad between axes in inches
+                 )
+
+for index, ax in enumerate(grid):
+    # Iterating over the grid returns the Axes.
+    ax.imshow(io.imread(clusters[i][index]))
+
+```
+
+To be continued:
+
+* Iteration with cluster sizes - 10 was picked arbitrarily; clusters 1 and 2 look like detritus
+* Proper look at image quality - what's getting lost between the FlowCam and here
+* A nicer way of doing this than a notebook, one that has some reuse value for other image projects
+
+
+
+Silhouette analysis, as per https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html
+
+```python
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+import numpy as np
+
+from sklearn.cluster import KMeans
+from sklearn.datasets import make_blobs
+from sklearn.metrics import silhouette_samples, silhouette_score
+
+range_n_clusters = [3, 5, 7, 8, 10]
+
+for n_clusters in range_n_clusters:
+    # Create a subplot with 1 row and 2 columns
+    fig, (ax1, ax2) = plt.subplots(1, 2)
+    fig.set_size_inches(18, 7)
+
+    # The 1st subplot is the silhouette plot
+    # The silhouette coefficient can range from -1, 1 but in this example all
+    # lie within [-0.1, 1]
+    ax1.set_xlim([-0.1, 1])
+    # The (n_clusters+1)*10 is for inserting blank space between silhouette
+    # plots of individual clusters, to demarcate them clearly.
+    ax1.set_ylim([0, len(X) + (n_clusters + 1) * 10])
+
+    # Initialize the clusterer with n_clusters value and a random generator
+    # seed of 10 for reproducibility.
+    clusterer = KMeans(n_clusters=n_clusters, random_state=10)
+    cluster_labels = clusterer.fit_predict(X)
+
+    # The silhouette_score gives the average value for all the samples.
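+    # (For reference, the per-sample silhouette is s = (b - a) / max(a, b),
+    # where a is the mean intra-cluster distance and b the mean distance to
+    # the nearest other cluster, so values near 1 mean tight, well-separated
+    # clusters.)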
+    # This gives a perspective into the density and separation of the formed
+    # clusters
+    silhouette_avg = silhouette_score(X, cluster_labels)
+    print(
+        "For n_clusters =",
+        n_clusters,
+        "The average silhouette_score is :",
+        silhouette_avg,
+    )
+
+    # Compute the silhouette scores for each sample
+    sample_silhouette_values = silhouette_samples(X, cluster_labels)
+
+    y_lower = 10
+    for i in range(n_clusters):
+        # Aggregate the silhouette scores for samples belonging to
+        # cluster i, and sort them
+        ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
+
+        ith_cluster_silhouette_values.sort()
+
+        size_cluster_i = ith_cluster_silhouette_values.shape[0]
+        y_upper = y_lower + size_cluster_i
+
+        color = cm.nipy_spectral(float(i) / n_clusters)
+        ax1.fill_betweenx(
+            np.arange(y_lower, y_upper),
+            0,
+            ith_cluster_silhouette_values,
+            facecolor=color,
+            edgecolor=color,
+            alpha=0.7,
+        )
+
+        # Label the silhouette plots with their cluster numbers at the middle
+        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
+
+        # Compute the new y_lower for next plot
+        y_lower = y_upper + 10  # 10 for the 0 samples
+
+    ax1.set_title("The silhouette plot for the various clusters.")
+    ax1.set_xlabel("The silhouette coefficient values")
+    ax1.set_ylabel("Cluster label")
+
+    # The vertical line for average silhouette score of all the values
+    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
+
+    ax1.set_yticks([])  # Clear the yaxis labels / ticks
+    ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
+
+    # 2nd Plot showing the actual clusters formed
+    colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
+    ax2.scatter(
+        X[:, 0], X[:, 1], marker=".", s=30, lw=0, alpha=0.7, c=colors, edgecolor="k"
+    )
+
+    # Labeling the clusters
+    centers = clusterer.cluster_centers_
+    # Draw white circles at cluster centers
+    ax2.scatter(
+        centers[:, 0],
+        centers[:, 1],
+        marker="o",
+        c="white",
+        alpha=1,
+        s=200,
+        edgecolor="k",
+    )
+
+    for i, c in enumerate(centers):
+        ax2.scatter(c[0], c[1], marker="$%d$" % i, alpha=1, s=50, edgecolor="k")
+
+    ax2.set_title("The visualization of the clustered data.")
+    ax2.set_xlabel("Feature space for the 1st feature")
+    ax2.set_ylabel("Feature space for the 2nd feature")
+
+    plt.suptitle(
+        "Silhouette analysis for KMeans clustering on sample data with n_clusters = %d"
+        % n_clusters,
+        fontsize=14,
+        fontweight="bold",
+    )
+
+plt.show()
+```
diff --git a/notebooks/ImageEmbeddings.md b/notebooks/ImageEmbeddings.md
index bb9f13a..889a9f2 100644
--- a/notebooks/ImageEmbeddings.md
+++ b/notebooks/ImageEmbeddings.md
@@ -5,7 +5,7 @@ jupyter:
       extension: .md
       format_name: markdown
       format_version: '1.3'
-      jupytext_version: 1.16.3
+      jupytext_version: 1.16.4
   kernelspec:
     display_name: Python 3 (ipykernel)
     language: python