From b4f24a85ce4d5a4bb507c772e895398a70480f88 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Wed, 4 Sep 2024 07:16:40 +0200 Subject: [PATCH 1/4] move print statement in embeddings modules --- harmony_api/services/azure_openai_embeddings.py | 2 +- harmony_api/services/google_embeddings.py | 2 +- harmony_api/services/openai_embeddings.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/harmony_api/services/azure_openai_embeddings.py b/harmony_api/services/azure_openai_embeddings.py index 42d5a43..4143d44 100644 --- a/harmony_api/services/azure_openai_embeddings.py +++ b/harmony_api/services/azure_openai_embeddings.py @@ -15,9 +15,9 @@ API_VERSION = "2023-12-01-preview" # This might change in the future # Check available models -print("INFO:\t Checking Azure OpenAI models...") HARMONY_API_AVAILABLE_AZURE_OPENAI_MODELS_LIST: List[str] = [] if settings.AZURE_OPENAI_API_KEY and settings.AZURE_OPENAI_ENDPOINT: + print("INFO:\t Checking Azure OpenAI models...") for harmony_api_azure_openai_model in HARMONY_API_AZURE_OPENAI_MODELS_LIST: try: AzureOpenAI( diff --git a/harmony_api/services/google_embeddings.py b/harmony_api/services/google_embeddings.py index 3a9f51d..7a84b3b 100644 --- a/harmony_api/services/google_embeddings.py +++ b/harmony_api/services/google_embeddings.py @@ -26,9 +26,9 @@ ) # Check available models -print("INFO:\t Checking Google models...") HARMONY_API_AVAILABLE_GOOGLE_MODELS_LIST: List[str] = [] if settings.GOOGLE_APPLICATION_CREDENTIALS: + print("INFO:\t Checking Google models...") for harmony_api_google_model in HARMONY_API_GOOGLE_MODELS_LIST: try: TextEmbeddingModel.from_pretrained(harmony_api_google_model["model"]) diff --git a/harmony_api/services/openai_embeddings.py b/harmony_api/services/openai_embeddings.py index 2ab27b1..85b2288 100644 --- a/harmony_api/services/openai_embeddings.py +++ b/harmony_api/services/openai_embeddings.py @@ -17,9 +17,9 @@ openai.api_key = settings.OPENAI_API_KEY # Check available models -print("INFO:\t Checking OpenAI models...") HARMONY_API_AVAILABLE_OPENAI_MODELS_LIST: List[str] = [] if settings.OPENAI_API_KEY: + print("INFO:\t Checking OpenAI models...") openai_client = OpenAI() OPENAI_MODELS: List[str] = [x.id for x in openai_client.models.list()] for harmony_api_openai_model in HARMONY_API_OPENAI_MODELS_LIST: From 985cbd96b54e5d56c7ff22cde0d0d5b228f0943c Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Wed, 4 Sep 2024 08:43:08 +0200 Subject: [PATCH 2/4] add search_instruments endpoint --- harmony_api/helpers.py | 70 +++++++++++++--- harmony_api/routers/text_router.py | 113 ++++++++++++++++++++++++-- harmony_api/services/vectors_cache.py | 17 +++- 3 files changed, 178 insertions(+), 22 deletions(-) diff --git a/harmony_api/helpers.py b/harmony_api/helpers.py index bf20bbe..1e0299d 100644 --- a/harmony_api/helpers.py +++ b/harmony_api/helpers.py @@ -252,12 +252,22 @@ def get_catalogue_data_model_embeddings(model: dict) -> np.ndarray: return all_embeddings_concatenated -def filter_catalogue_data(catalogue_data: dict, sources: List[str]) -> dict: +def filter_catalogue_data( + catalogue_data: dict, + sources: List[str] | None = None, + topics: List[str] | None = None, + instrument_length_min: int | None = None, + instrument_length_max: int | None = None, +) -> dict: """ Filter catalogue data to only keep instruments with the sources. :param catalogue_data: Catalogue data. :param sources: Only keep instruments from sources. + :param topics: Only keep instruments with these topics. Topics can be found in the metadata of each instrument. + :param instrument_length_min: Only keep instruments with min number of questions. + :param instrument_length_max: Only keep instruments with max number of questions. + :return: The filtered catalogue data. """ def normalize_text(text: str): @@ -266,17 +276,19 @@ def normalize_text(text: str): return text - # Lowercase sources - sources_set = {x.strip().lower() for x in sources if x.strip()} + if not sources: + sources = [] + if not topics: + topics = [] - # Nothing to filter - if not sources_set: - return catalogue_data + # Lowercase sources and topics + sources_set = {x.strip().lower() for x in sources if x.strip()} + topics_set = {x.strip().lower() for x in topics if x.strip()} # Create a dictionary with questions and their vectors question_normalized_to_vector: dict[str, List[float]] = {} for question, vector in zip( - catalogue_data["all_questions"], catalogue_data["all_embeddings_concatenated"] + catalogue_data["all_questions"], catalogue_data["all_embeddings_concatenated"] ): question_normalized = normalize_text(question) if question_normalized not in question_normalized_to_vector: @@ -285,13 +297,45 @@ def normalize_text(text: str): # Find instrument indexes to remove idxs_instruments_to_remove: List[int] = [] for instrument_idx, catalogue_instrument in enumerate( - catalogue_data["all_instruments"] + catalogue_data["all_instruments"] ): - if ( - catalogue_instrument["metadata"]["source"].strip().lower() - not in sources_set - ): - idxs_instruments_to_remove.append(instrument_idx) + questions_len = len(catalogue_instrument["questions"]) + + # By min instrument questions length + if instrument_length_min: + if questions_len < instrument_length_min: + idxs_instruments_to_remove.append(instrument_idx) + continue + + # By max instrument questions length + if instrument_length_max: + if questions_len > instrument_length_max: + idxs_instruments_to_remove.append(instrument_idx) + continue + + # By sources + if sources_set: + if ( + catalogue_instrument["metadata"]["source"].strip().lower() + not in sources_set + ): + idxs_instruments_to_remove.append(instrument_idx) + continue + + # By topics + if topics_set: + not_found_topics_len = 0 + catalogue_instrument_topics: list[str] = catalogue_instrument[ + "metadata" + ].get("topics", []) + for topic in topics_set: + if topic not in [ + x.strip().lower() for x in catalogue_instrument_topics if x.strip() + ]: + not_found_topics_len += 1 + if not_found_topics_len == len(topics_set): + idxs_instruments_to_remove.append(instrument_idx) + continue # Remove instruments for idx_instrument_to_remove in sorted(idxs_instruments_to_remove, reverse=True): diff --git a/harmony_api/routers/text_router.py b/harmony_api/routers/text_router.py index cbd74bc..90874bd 100644 --- a/harmony_api/routers/text_router.py +++ b/harmony_api/routers/text_router.py @@ -31,16 +31,21 @@ from fastapi import APIRouter, Body, status, Depends, Query from harmony.matching.default_matcher import match_instruments_with_function -from harmony.matching.matcher import match_instruments_with_catalogue_instruments +from harmony.matching.matcher import ( + match_instruments_with_catalogue_instruments, + match_query_with_catalogue_instruments, +) from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments from harmony.schemas.requests.text import ( RawFile, Instrument, MatchBody, + SearchInstrumentsBody, ) from harmony.schemas.responses.text import ( MatchResponse, CacheResponse, + SearchInstrumentsResponse, ) from harmony_api import helpers, dependencies, constants @@ -245,7 +250,7 @@ def match( catalogue_data = {"all_embeddings_concatenated": catalogue_embeddings} catalogue_data.update(catalogue_data_default) - # Filter catalogue data for sources + # Filter catalogue data if catalogue_sources: catalogue_data = helpers.filter_catalogue_data( catalogue_data=copy.deepcopy(catalogue_data), sources=catalogue_sources @@ -278,12 +283,11 @@ def match( ) # Add new vectors to cache - for key, value in new_text_vectors.items(): - vector_key = vectors_cache.generate_key( - text=key, model_framework=model.framework, model_name=model.model - ) - if not vectors_cache.has(vector_key): - vectors_cache.set(vector_key, {key: value}) + vectors_cache.add( + new_text_vectors=new_text_vectors, + model_name=model.model, + framework=model.framework + ) # List of matches matches_jsonable = matches.tolist() @@ -332,3 +336,96 @@ def get_cache() -> CacheResponse: instruments=instruments_list, vectors=vectors_list, ) + + +@router.post( + path="/search_instruments", + response_model=SearchInstrumentsResponse, + status_code=status.HTTP_200_OK, + response_model_exclude_none=True, +) +def search_instruments( + search_instruments_body: SearchInstrumentsBody = SearchInstrumentsBody(), + query: str | None = Query(default=None), + instrument_length_min: int = Query(default=5, gt=0), + instrument_length_max: int = Query(default=10, gt=0), + sources: List[str] = Query(default=[]), + topics: List[str] = Query(default=[]), +) -> SearchInstrumentsResponse: + """ + Search instruments. + """ + + # If min length is bigger than max length, set the min length to equal the max length + if instrument_length_min and instrument_length_max: + if instrument_length_min > instrument_length_max: + instrument_length_min = instrument_length_max + + # Model + model = search_instruments_body.parameters + model_dict = model.model_dump(mode="json") + + # Get vect function + vectorisation_function = helpers.get_vectorisation_function_for_model( + model=model_dict + ) + if not vectorisation_function: + raise http_exceptions.CouldNotFindResourceHTTPException( + "Could not find a vectorisation function for model." + ) + + # Currently only the model "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" is supported for + # searching instruments. + if model_dict["model"] != constants.HUGGINGFACE_MINILM_L12_V2["model"]: + return SearchInstrumentsResponse(instruments=[]) + + # Catalogue data + catalogue_embeddings = catalogue_data_embeddings_for_model[model_dict["model"]] + if catalogue_embeddings.size == 0: + return SearchInstrumentsResponse(instruments=[]) + catalogue_data = {"all_embeddings_concatenated": catalogue_embeddings} + catalogue_data.update(catalogue_data_default) + + # Filter catalogue data + if sources or topics or instrument_length_min or instrument_length_max: + catalogue_data = helpers.filter_catalogue_data( + catalogue_data=copy.deepcopy(catalogue_data), + sources=sources, + topics=topics, + instrument_length_min=instrument_length_min, + instrument_length_max=instrument_length_max, + ) + + # Query is provided: Match the query with the catalogue instruments + if query: + texts_cached_vectors = helpers.get_cached_text_vectors( + instruments=[], query=query, model=model_dict + ) + + match_result = match_query_with_catalogue_instruments( + query=query, + catalogue_data=catalogue_data, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors, + ) + + # Add new vectors to cache + vectors_cache.add( + new_text_vectors=match_result["new_text_vectors"], + model_name=model.model, + framework=model.framework, + ) + + return SearchInstrumentsResponse(instruments=match_result["instruments"]) + + # No query provided: Get the first n catalogue instruments + else: + top_n = 100 + catalogue_instruments = catalogue_data["all_instruments"][:top_n] + instruments = [ + Instrument.model_validate(catalogue_instrument) + for catalogue_instrument in catalogue_instruments + ] + + return SearchInstrumentsResponse(instruments=instruments) + diff --git a/harmony_api/services/vectors_cache.py b/harmony_api/services/vectors_cache.py index 88a7e77..12df63d 100644 --- a/harmony_api/services/vectors_cache.py +++ b/harmony_api/services/vectors_cache.py @@ -64,7 +64,22 @@ def __load(self): self.__cache = cache - def set(self, key: str, value: dict[str, List[float]]): + def add(self, new_text_vectors: dict[str, List[List]], model_name: str, framework: str) -> None: + """ + Add new text vectors to cache. + + :param new_text_vectors: A dict of new text vectors. + :param model_name: The model name. + :param framework: The framework. + """ + + for key, value in new_text_vectors.items(): + vector_key = self.generate_key( + text=key, model_framework=framework, model_name=model_name + ) + self.__set(vector_key, {key: value}) + + def __set(self, key: str, value: dict[str, List[float]]): """ :param key: The cache key. :param value: The cache value. From 62f73a3cdba3441ce53a2773804afa92b1ad7a51 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Thu, 5 Sep 2024 07:26:38 +0200 Subject: [PATCH 3/4] add max_results variable to search_instruments func --- harmony_api/routers/text_router.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/harmony_api/routers/text_router.py b/harmony_api/routers/text_router.py index 90874bd..5ebcb1b 100644 --- a/harmony_api/routers/text_router.py +++ b/harmony_api/routers/text_router.py @@ -396,6 +396,9 @@ def search_instruments( instrument_length_max=instrument_length_max, ) + # Return up to 100 instruments + max_results = 100 + # Query is provided: Match the query with the catalogue instruments if query: texts_cached_vectors = helpers.get_cached_text_vectors( @@ -407,6 +410,7 @@ def search_instruments( catalogue_data=catalogue_data, vectorisation_function=vectorisation_function, texts_cached_vectors=texts_cached_vectors, + max_results=max_results, ) # Add new vectors to cache @@ -420,8 +424,7 @@ def search_instruments( # No query provided: Get the first n catalogue instruments else: - top_n = 100 - catalogue_instruments = catalogue_data["all_instruments"][:top_n] + catalogue_instruments = catalogue_data["all_instruments"][:max_results] instruments = [ Instrument.model_validate(catalogue_instrument) for catalogue_instrument in catalogue_instruments From 23c8086134a8095456ab8c358399603b87bd3883 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Thu, 5 Sep 2024 07:34:25 +0200 Subject: [PATCH 4/4] check instrument length in catalogue filter func --- harmony_api/helpers.py | 77 +++++++++++++++++------------- harmony_api/routers/text_router.py | 5 -- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/harmony_api/helpers.py b/harmony_api/helpers.py index 1e0299d..82d75e4 100644 --- a/harmony_api/helpers.py +++ b/harmony_api/helpers.py @@ -79,7 +79,7 @@ def get_example_instruments() -> List[Instrument]: example_instruments = [] with open( - str(os.getcwd()) + "/example_questionnaires.json", "r", encoding="utf-8" + str(os.getcwd()) + "/example_questionnaires.json", "r", encoding="utf-8" ) as file: for line in file: instrument = Instrument.model_validate_json(line) @@ -108,24 +108,24 @@ def get_mhc_embeddings(model_name: str) -> tuple: data_path = os.path.join(dir_path, "../mhc_embeddings") # submodule with open( - os.path.join(data_path, "mhc_questions.txt"), "r", encoding="utf-8" + os.path.join(data_path, "mhc_questions.txt"), "r", encoding="utf-8" ) as file: for line in file: mhc_question = Question(question_text=line) mhc_questions.append(mhc_question) with open( - os.path.join(data_path, "mhc_all_metadatas.json"), "r", encoding="utf-8" + os.path.join(data_path, "mhc_all_metadatas.json"), "r", encoding="utf-8" ) as file: for line in file: mhc_meta = json.loads(line) mhc_all_metadata.append(mhc_meta) with open( - os.path.join( - data_path, f"mhc_embeddings_{model_name.replace('/', '-')}.npy" - ), - "rb", + os.path.join( + data_path, f"mhc_embeddings_{model_name.replace('/', '-')}.npy" + ), + "rb", ) as file: mhc_embeddings = np.load(file, allow_pickle=True) except (Exception,) as e: @@ -153,8 +153,8 @@ def get_catalogue_data_default() -> dict: else: if settings.AZURE_STORAGE_URL: with requests.get( - url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_questions_ever_seen_json}", - stream=True, + url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_questions_ever_seen_json}", + stream=True, ) as response: if response.ok: buffer = BytesIO() @@ -171,8 +171,8 @@ def get_catalogue_data_default() -> dict: else: if settings.AZURE_STORAGE_URL: with requests.get( - url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{instrument_idx_to_question_idxs_json}", - stream=True, + url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{instrument_idx_to_question_idxs_json}", + stream=True, ) as response: if response.ok: buffer = BytesIO() @@ -193,8 +193,8 @@ def get_catalogue_data_default() -> dict: else: if settings.AZURE_STORAGE_URL: with requests.get( - url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_instruments_preprocessed_json}", - stream=True, + url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_instruments_preprocessed_json}", + stream=True, ) as response: if response.ok: buffer = BytesIO() @@ -237,8 +237,8 @@ def get_catalogue_data_model_embeddings(model: dict) -> np.ndarray: decompressor_results = [] decompressor = bz2.BZ2Decompressor() with requests.get( - url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{embeddings_filename}", - stream=True, + url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{embeddings_filename}", + stream=True, ) as response: if response.ok: for chunk in response.iter_content(chunk_size=1024): @@ -281,6 +281,17 @@ def normalize_text(text: str): if not topics: topics = [] + # If the value for any of these is less than 1, set it to 1 + if instrument_length_min and (instrument_length_min < 1): + instrument_length_min = 1 + if instrument_length_max and (instrument_length_max < 1): + instrument_length_max = 1 + + # If min length is bigger than max length, set the min length to equal the max length + if instrument_length_min and instrument_length_max: + if instrument_length_min > instrument_length_max: + instrument_length_min = instrument_length_max + # Lowercase sources and topics sources_set = {x.strip().lower() for x in sources if x.strip()} topics_set = {x.strip().lower() for x in topics if x.strip()} @@ -419,7 +430,7 @@ def check_model_availability(model: dict) -> bool: def get_cached_text_vectors( - instruments: List[Instrument], model: dict, query: str | None = None + instruments: List[Instrument], model: dict, query: str | None = None ) -> dict[str, List[float]]: """ Get cached text vectors. @@ -476,55 +487,55 @@ def get_vectorisation_function_for_model(model: dict) -> Callable | None: vectorisation_function: Callable | None = None if ( - model["framework"] == HUGGINGFACE_MINILM_L12_V2["framework"] - and model["model"] == HUGGINGFACE_MINILM_L12_V2["model"] + model["framework"] == HUGGINGFACE_MINILM_L12_V2["framework"] + and model["model"] == HUGGINGFACE_MINILM_L12_V2["model"] ): vectorisation_function = ( hugging_face_embeddings.get_hugging_face_embeddings_minilm_l12_v2 ) elif ( - model["framework"] == HUGGINGFACE_MPNET_BASE_V2["framework"] - and model["model"] == HUGGINGFACE_MPNET_BASE_V2["model"] + model["framework"] == HUGGINGFACE_MPNET_BASE_V2["framework"] + and model["model"] == HUGGINGFACE_MPNET_BASE_V2["model"] ): vectorisation_function = ( hugging_face_embeddings.get_hugging_face_embeddings_mpnet_base_v2 ) elif ( - model["framework"] == OPENAI_ADA_02["framework"] - and model["model"] == OPENAI_ADA_02["model"] + model["framework"] == OPENAI_ADA_02["framework"] + and model["model"] == OPENAI_ADA_02["model"] ): vectorisation_function = openai_embeddings.get_openai_embeddings_ada_02 elif ( - model["framework"] == OPENAI_3_LARGE["framework"] - and model["model"] == OPENAI_3_LARGE["model"] + model["framework"] == OPENAI_3_LARGE["framework"] + and model["model"] == OPENAI_3_LARGE["model"] ): vectorisation_function = openai_embeddings.get_openai_embeddings_3_large elif ( - model["framework"] == AZURE_OPENAI_3_LARGE["framework"] - and model["model"] == AZURE_OPENAI_3_LARGE["model"] + model["framework"] == AZURE_OPENAI_3_LARGE["framework"] + and model["model"] == AZURE_OPENAI_3_LARGE["model"] ): vectorisation_function = ( azure_openai_embeddings.get_azure_openai_embeddings_3_large ) elif ( - model["framework"] == AZURE_OPENAI_ADA_02["framework"] - and model["model"] == AZURE_OPENAI_ADA_02["model"] + model["framework"] == AZURE_OPENAI_ADA_02["framework"] + and model["model"] == AZURE_OPENAI_ADA_02["model"] ): vectorisation_function = ( azure_openai_embeddings.get_azure_openai_embeddings_ada_02 ) elif ( - model["framework"] == GOOGLE_GECKO_MULTILINGUAL["framework"] - and model["model"] == GOOGLE_GECKO_MULTILINGUAL["model"] + model["framework"] == GOOGLE_GECKO_MULTILINGUAL["framework"] + and model["model"] == GOOGLE_GECKO_MULTILINGUAL["model"] ): vectorisation_function = ( google_embeddings.get_google_embeddings_gecko_multilingual ) elif ( - model["framework"] == GOOGLE_GECKO_003["framework"] - and model["model"] == GOOGLE_GECKO_003["model"] + model["framework"] == GOOGLE_GECKO_003["framework"] + and model["model"] == GOOGLE_GECKO_003["model"] ): vectorisation_function = google_embeddings.get_google_embeddings_gecko_003 @@ -532,7 +543,7 @@ def get_vectorisation_function_for_model(model: dict) -> Callable | None: def assign_missing_ids_to_instruments( - instruments: List[Instrument], + instruments: List[Instrument], ) -> List[Instrument]: """ Assign missing IDs to instruments. diff --git a/harmony_api/routers/text_router.py b/harmony_api/routers/text_router.py index 5ebcb1b..63b0155 100644 --- a/harmony_api/routers/text_router.py +++ b/harmony_api/routers/text_router.py @@ -356,11 +356,6 @@ def search_instruments( Search instruments. """ - # If min length is bigger than max length, set the min length to equal the max length - if instrument_length_min and instrument_length_max: - if instrument_length_min > instrument_length_max: - instrument_length_min = instrument_length_max - # Model model = search_instruments_body.parameters model_dict = model.model_dump(mode="json")