diff --git a/src/harmony/matching/cluster.py b/src/harmony/matching/cluster.py new file mode 100644 index 0000000..2e18f80 --- /dev/null +++ b/src/harmony/matching/cluster.py @@ -0,0 +1,89 @@ +import os +import sys +import numpy as np +import pandas as pd +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score +from sklearn.decomposition import PCA +from sentence_transformers import SentenceTransformer + +if ( + os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) is not None + and os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) != "" +): + sentence_transformer_path = os.environ["HARMONY_SENTENCE_TRANSFORMER_PATH"] +else: + sentence_transformer_path = ( + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + ) + +model = SentenceTransformer(sentence_transformer_path) + +# questions_in should be a list of question strings +def get_embeddings(questions_in): + # Generate embeddings using HuggingFace model + embedding_result = model.encode(questions_in, show_progress_bar=True) + questions_df = pd.DataFrame() + + # Add embeddings to df and convert the embeddings to numpy arrays + questions_df["embedding"] = [embedding.tolist() for embedding in embedding_result] + questions_df["embedding"] = questions_df["embedding"].apply(np.array) + + # Stack embeddings into a matrix + matrix = np.vstack(questions_df.embedding.values) + return matrix + + +def perform_kmeans(embeddings_in, num_clusters=5): + kmeans = KMeans(n_clusters=num_clusters) + kmeans_labels = kmeans.fit_predict(embeddings_in) + return kmeans_labels + + +def visualize_clusters(embeddings_in, kmeans_labels): + try: + import matplotlib.pyplot as plt + pca = PCA(n_components=2) + reduced_embeddings = pca.fit_transform(embeddings_in) + plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], c=kmeans_labels, cmap='viridis', s=50) + plt.colorbar() + plt.title("Question Clusters") + + for i, point in enumerate(reduced_embeddings): + plt.annotate( + str(i), # Label each point with its question number + (point[0], point[1]), # Coordinates from reduced_embeddings + fontsize=8, + ha="center" + ) + + plt.show() + except ImportError as e: + print( + "Matplotlib is not installed. Please install it using:\n" + "pip install matplotlib==3.7.0" + ) + sys.exit(1) + +def cluster_questions(instrument_in, num_clusters: int, graph: bool): + # convert instruments into a list of questions + questions_list = [] + for question in instrument_in.questions: + questions_list.append(question.question_text) + embedding_matrix = get_embeddings(questions_list) + kmeans_labels = perform_kmeans(embedding_matrix, num_clusters) + df = pd.DataFrame({ + "question_text": questions_list, + "cluster_number": kmeans_labels + }) + + # silhouette score requires at least 2 clusters + if num_clusters > 1: + sil_score = silhouette_score(embedding_matrix, kmeans_labels) + else: + sil_score = None + + if graph: + visualize_clusters(embedding_matrix, kmeans_labels) + + return df, sil_score diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 0000000..882b108 --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,54 @@ +''' +MIT License + +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). +Project: Harmony (https://harmonydata.ac.uk) +Maintainer: Thomas Wood (https://fastdatascience.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import sys +import unittest + +sys.path.append("../src") + +from harmony.matching.cluster import cluster_questions +from harmony import create_instrument_from_list, import_instrument_into_harmony_web +from harmony.schemas.requests.text import Instrument, Question + + +class TestCluster(unittest.TestCase): + def setUp(self): + self. all_questions_real = [Question(question_no="1", question_text="Feeling nervous, anxious, or on edge"), + Question(question_no="2", question_text="Not being able to stop or control worrying"), + Question(question_no="3", question_text="Little interest or pleasure in doing things"), + Question(question_no="4", question_text="Feeling down, depressed, or hopeless"), + Question(question_no="5", + question_text="Trouble falling/staying asleep, sleeping too much"), ] + self.instruments = Instrument(questions=self.all_questions_real) + + def test_cluster(self): + clusters_out, score_out = cluster_questions(self.instruments, 2, False) + assert(len(clusters_out) == 5) + assert score_out + +if __name__ == '__main__': + unittest.main()