Search instruments #9

Merged
merged 4 commits on Sep 6, 2024
147 changes: 101 additions & 46 deletions harmony_api/helpers.py
@@ -79,7 +79,7 @@ def get_example_instruments() -> List[Instrument]:
example_instruments = []

with open(
str(os.getcwd()) + "/example_questionnaires.json", "r", encoding="utf-8"
) as file:
for line in file:
instrument = Instrument.model_validate_json(line)
@@ -108,24 +108,24 @@ def get_mhc_embeddings(model_name: str) -> tuple:
data_path = os.path.join(dir_path, "../mhc_embeddings") # submodule

with open(
os.path.join(data_path, "mhc_questions.txt"), "r", encoding="utf-8"
os.path.join(data_path, "mhc_questions.txt"), "r", encoding="utf-8"
) as file:
for line in file:
mhc_question = Question(question_text=line)
mhc_questions.append(mhc_question)

with open(
os.path.join(data_path, "mhc_all_metadatas.json"), "r", encoding="utf-8"
os.path.join(data_path, "mhc_all_metadatas.json"), "r", encoding="utf-8"
) as file:
for line in file:
mhc_meta = json.loads(line)
mhc_all_metadata.append(mhc_meta)

with open(
os.path.join(
data_path, f"mhc_embeddings_{model_name.replace('/', '-')}.npy"
),
"rb",
) as file:
mhc_embeddings = np.load(file, allow_pickle=True)
except (Exception,) as e:
@@ -153,8 +153,8 @@ def get_catalogue_data_default() -> dict:
else:
if settings.AZURE_STORAGE_URL:
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_questions_ever_seen_json}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_questions_ever_seen_json}",
stream=True,
) as response:
if response.ok:
buffer = BytesIO()
@@ -171,8 +171,8 @@ def get_catalogue_data_default() -> dict:
else:
if settings.AZURE_STORAGE_URL:
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{instrument_idx_to_question_idxs_json}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{instrument_idx_to_question_idxs_json}",
stream=True,
) as response:
if response.ok:
buffer = BytesIO()
@@ -193,8 +193,8 @@ def get_catalogue_data_default() -> dict:
else:
if settings.AZURE_STORAGE_URL:
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_instruments_preprocessed_json}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{all_instruments_preprocessed_json}",
stream=True,
) as response:
if response.ok:
buffer = BytesIO()
@@ -237,8 +237,8 @@ def get_catalogue_data_model_embeddings(model: dict) -> np.ndarray:
decompressor_results = []
decompressor = bz2.BZ2Decompressor()
with requests.get(
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{embeddings_filename}",
stream=True,
url=f"{settings.AZURE_STORAGE_URL}/catalogue_data/{embeddings_filename}",
stream=True,
) as response:
if response.ok:
for chunk in response.iter_content(chunk_size=1024):
@@ -252,12 +252,22 @@ def get_catalogue_data_model_embeddings(model: dict) -> np.ndarray:
return all_embeddings_concatenated
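# --- Illustrative sketch (editor's note, not part of the diff) --------------
# get_catalogue_data_model_embeddings above streams a bz2-compressed .npy file
# and decompresses it chunk by chunk; part of that loop is collapsed in this
# view. The snippet below is a minimal, self-contained sketch of the same
# pattern under the assumption that the hidden lines append decompressed
# chunks and then deserialise them with NumPy. The helper name
# download_bz2_numpy is made up for illustration.
import bz2
from io import BytesIO

import numpy as np
import requests


def download_bz2_numpy(url: str) -> np.ndarray:
    decompressor = bz2.BZ2Decompressor()
    decompressed_chunks: list[bytes] = []
    with requests.get(url=url, stream=True) as response:
        if response.ok:
            for chunk in response.iter_content(chunk_size=1024):
                # Feed each compressed chunk through the incremental decompressor
                decompressed_chunks.append(decompressor.decompress(chunk))
    # Reassemble the decompressed bytes and load them as a NumPy array
    return np.load(BytesIO(b"".join(decompressed_chunks)), allow_pickle=True)
# -----------------------------------------------------------------------------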


-def filter_catalogue_data(catalogue_data: dict, sources: List[str]) -> dict:
+def filter_catalogue_data(
+    catalogue_data: dict,
+    sources: List[str] | None = None,
+    topics: List[str] | None = None,
+    instrument_length_min: int | None = None,
+    instrument_length_max: int | None = None,
+) -> dict:
"""
Filter catalogue data to only keep instruments with the sources.

:param catalogue_data: Catalogue data.
:param sources: Only keep instruments from sources.
+    :param topics: Only keep instruments with these topics. Topics can be found in the metadata of each instrument.
+    :param instrument_length_min: Only keep instruments with min number of questions.
+    :param instrument_length_max: Only keep instruments with max number of questions.
+    :return: The filtered catalogue data.
"""

def normalize_text(text: str):
@@ -266,17 +276,30 @@ def normalize_text(text: str):

return text

-    # Lowercase sources
-    sources_set = {x.strip().lower() for x in sources if x.strip()}
+    if not sources:
+        sources = []
+    if not topics:
+        topics = []
+
+    # If the value for any of these is less than 1, set it to 1
+    if instrument_length_min and (instrument_length_min < 1):
+        instrument_length_min = 1
+    if instrument_length_max and (instrument_length_max < 1):
+        instrument_length_max = 1

-    # Nothing to filter
-    if not sources_set:
-        return catalogue_data
+    # If min length is bigger than max length, set the min length to equal the max length
+    if instrument_length_min and instrument_length_max:
+        if instrument_length_min > instrument_length_max:
+            instrument_length_min = instrument_length_max

+    # Lowercase sources and topics
+    sources_set = {x.strip().lower() for x in sources if x.strip()}
+    topics_set = {x.strip().lower() for x in topics if x.strip()}

# Create a dictionary with questions and their vectors
question_normalized_to_vector: dict[str, List[float]] = {}
for question, vector in zip(
catalogue_data["all_questions"], catalogue_data["all_embeddings_concatenated"]
catalogue_data["all_questions"], catalogue_data["all_embeddings_concatenated"]
):
question_normalized = normalize_text(question)
if question_normalized not in question_normalized_to_vector:
@@ -285,13 +308,45 @@ def normalize_text(text: str):
# Find instrument indexes to remove
idxs_instruments_to_remove: List[int] = []
for instrument_idx, catalogue_instrument in enumerate(
catalogue_data["all_instruments"]
catalogue_data["all_instruments"]
):
-        if (
-            catalogue_instrument["metadata"]["source"].strip().lower()
-            not in sources_set
-        ):
-            idxs_instruments_to_remove.append(instrument_idx)
+        questions_len = len(catalogue_instrument["questions"])
+
+        # By min instrument questions length
+        if instrument_length_min:
+            if questions_len < instrument_length_min:
+                idxs_instruments_to_remove.append(instrument_idx)
+                continue
+
+        # By max instrument questions length
+        if instrument_length_max:
+            if questions_len > instrument_length_max:
+                idxs_instruments_to_remove.append(instrument_idx)
+                continue
+
+        # By sources
+        if sources_set:
+            if (
+                catalogue_instrument["metadata"]["source"].strip().lower()
+                not in sources_set
+            ):
+                idxs_instruments_to_remove.append(instrument_idx)
+                continue
+
+        # By topics
+        if topics_set:
+            not_found_topics_len = 0
+            catalogue_instrument_topics: list[str] = catalogue_instrument[
+                "metadata"
+            ].get("topics", [])
+            for topic in topics_set:
+                if topic not in [
+                    x.strip().lower() for x in catalogue_instrument_topics if x.strip()
+                ]:
+                    not_found_topics_len += 1
+            if not_found_topics_len == len(topics_set):
+                idxs_instruments_to_remove.append(instrument_idx)
+                continue

# Remove instruments
for idx_instrument_to_remove in sorted(idxs_instruments_to_remove, reverse=True):
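# --- Illustrative usage (editor's note, not part of the diff) ----------------
# A minimal sketch of how the widened filter_catalogue_data signature might be
# called after this change. catalogue_data is assumed to be the dict returned
# by get_catalogue_data_default() earlier in this module; the source, topic and
# length values are made-up examples.
catalogue_data = get_catalogue_data_default()
filtered = filter_catalogue_data(
    catalogue_data=catalogue_data,
    sources=["mhc"],           # compared case-insensitively after strip()/lower()
    topics=["anxiety"],        # matched against each instrument's metadata["topics"]
    instrument_length_min=5,   # drop instruments with fewer than 5 questions
    instrument_length_max=50,  # drop instruments with more than 50 questions
)
# -----------------------------------------------------------------------------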
@@ -375,7 +430,7 @@ def check_model_availability(model: dict) -> bool:


def get_cached_text_vectors(
instruments: List[Instrument], model: dict, query: str | None = None
) -> dict[str, List[float]]:
"""
Get cached text vectors.
@@ -432,63 +487,63 @@ def get_vectorisation_function_for_model(model: dict) -> Callable | None:
vectorisation_function: Callable | None = None

if (
model["framework"] == HUGGINGFACE_MINILM_L12_V2["framework"]
and model["model"] == HUGGINGFACE_MINILM_L12_V2["model"]
model["framework"] == HUGGINGFACE_MINILM_L12_V2["framework"]
and model["model"] == HUGGINGFACE_MINILM_L12_V2["model"]
):
vectorisation_function = (
hugging_face_embeddings.get_hugging_face_embeddings_minilm_l12_v2
)

elif (
model["framework"] == HUGGINGFACE_MPNET_BASE_V2["framework"]
and model["model"] == HUGGINGFACE_MPNET_BASE_V2["model"]
model["framework"] == HUGGINGFACE_MPNET_BASE_V2["framework"]
and model["model"] == HUGGINGFACE_MPNET_BASE_V2["model"]
):
vectorisation_function = (
hugging_face_embeddings.get_hugging_face_embeddings_mpnet_base_v2
)

elif (
model["framework"] == OPENAI_ADA_02["framework"]
and model["model"] == OPENAI_ADA_02["model"]
model["framework"] == OPENAI_ADA_02["framework"]
and model["model"] == OPENAI_ADA_02["model"]
):
vectorisation_function = openai_embeddings.get_openai_embeddings_ada_02
elif (
model["framework"] == OPENAI_3_LARGE["framework"]
and model["model"] == OPENAI_3_LARGE["model"]
model["framework"] == OPENAI_3_LARGE["framework"]
and model["model"] == OPENAI_3_LARGE["model"]
):
vectorisation_function = openai_embeddings.get_openai_embeddings_3_large
elif (
model["framework"] == AZURE_OPENAI_3_LARGE["framework"]
and model["model"] == AZURE_OPENAI_3_LARGE["model"]
model["framework"] == AZURE_OPENAI_3_LARGE["framework"]
and model["model"] == AZURE_OPENAI_3_LARGE["model"]
):
vectorisation_function = (
azure_openai_embeddings.get_azure_openai_embeddings_3_large
)
elif (
model["framework"] == AZURE_OPENAI_ADA_02["framework"]
and model["model"] == AZURE_OPENAI_ADA_02["model"]
model["framework"] == AZURE_OPENAI_ADA_02["framework"]
and model["model"] == AZURE_OPENAI_ADA_02["model"]
):
vectorisation_function = (
azure_openai_embeddings.get_azure_openai_embeddings_ada_02
)
elif (
model["framework"] == GOOGLE_GECKO_MULTILINGUAL["framework"]
and model["model"] == GOOGLE_GECKO_MULTILINGUAL["model"]
model["framework"] == GOOGLE_GECKO_MULTILINGUAL["framework"]
and model["model"] == GOOGLE_GECKO_MULTILINGUAL["model"]
):
vectorisation_function = (
google_embeddings.get_google_embeddings_gecko_multilingual
)
elif (
model["framework"] == GOOGLE_GECKO_003["framework"]
and model["model"] == GOOGLE_GECKO_003["model"]
model["framework"] == GOOGLE_GECKO_003["framework"]
and model["model"] == GOOGLE_GECKO_003["model"]
):
vectorisation_function = google_embeddings.get_google_embeddings_gecko_003

return vectorisation_function
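# --- Illustrative usage (editor's note, not part of the diff) ----------------
# Sketch of how the framework/model dispatch above is typically consumed:
# resolve the vectorisation callable for a model dict and handle the None
# fallback. HUGGINGFACE_MINILM_L12_V2 is one of the model constants imported
# in this module; passing a list of texts to the returned callable is an
# assumption about its signature, not something shown in this diff.
vectorisation_function = get_vectorisation_function_for_model(
    HUGGINGFACE_MINILM_L12_V2
)
if vectorisation_function is None:
    raise ValueError("No vectorisation function available for this model")
embeddings = vectorisation_function(["How often do you feel anxious?"])
# -----------------------------------------------------------------------------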


def assign_missing_ids_to_instruments(
instruments: List[Instrument],
) -> List[Instrument]:
"""
Assign missing IDs to instruments.