Skip to content

Commit

Permalink
Merge pull request #361 from MannLabs/refactor_volcano_III
Browse files Browse the repository at this point in the history
Refactor volcano iii
  • Loading branch information
mschwoer authored Nov 14, 2024
2 parents 7a04325 + 230857d commit 0f4694f
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 150 deletions.
52 changes: 50 additions & 2 deletions alphastats/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ def __init__(
self.sample: str = sample
self.preprocessing_info: Dict = preprocessing_info

self._gene_name_to_protein_id_map = (
{
k: v
for k, v in dict(
zip(
self.rawinput[self._gene_names].tolist(),
self.rawinput[self.index_column].tolist(),
)
).items()
if isinstance(k, str) # avoid having NaN as key
}
if self._gene_names
else {}
)

print("DataSet has been created.")

def _get_init_dataset(
Expand Down Expand Up @@ -403,9 +418,32 @@ def plot_volcano(

return volcano_plot.plot

def _get_protein_id_for_gene_name(
    self,
    gene_name: str,
) -> str:
    """Map a gene name to the protein id stored for it.

    The lookup runs against ``self._gene_name_to_protein_id_map``:

    1. An exact key match is returned directly.
    2. Otherwise the first key that contains ``gene_name`` as one of its
       ';'-separated parts is used (e.g. 'HEL114' matches the key
       'VCL;HEL114').
    3. If nothing matches, ``gene_name`` is returned unchanged
       (presumably because the caller may already have passed a protein
       id -- confirm against callers).

    Matching is case-sensitive.

    Example:
        'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1'

    Args:
        gene_name (str): Gene name

    Returns:
        str: Protein id or gene name if not present in the mapping.
    """
    gene_to_protein_id = self._gene_name_to_protein_id_map

    # Exact hit: the full (possibly compound) gene name is a key itself.
    if gene_name in gene_to_protein_id:
        return gene_to_protein_id[gene_name]

    # Partial hit: the gene name is one component of a compound key.
    # The first matching key in the mapping's iteration order wins.
    for gene, protein_id in gene_to_protein_id.items():
        if gene_name in gene.split(";"):
            return protein_id

    # No hit: hand the input back unchanged.
    return gene_name

def plot_intensity(
self,
protein_id: str,
*,
protein_id: str = None,
gene_name: str = None,
group: str = None,
subgroups: list = None,
method: str = "box",
Expand All @@ -416,7 +454,8 @@ def plot_intensity(
"""Plot Intensity of individual Protein/ProteinGroup
Args:
protein_id (str): ProteinGroup ID
protein_id (str): ProteinGroup ID. Mutually exclusive with gene_name.
gene_name (str): Gene Name, will be mapped to a ProteinGroup ID. Mutually exclusive with protein_id.
group (str, optional): A metadata column used for grouping. Defaults to None.
subgroups (list, optional): Select variables from the group column. Defaults to None.
method (str, optional): Violinplot = "violin", Boxplot = "box", Scatterplot = "scatter" or "all". Defaults to "box".
Expand All @@ -434,6 +473,15 @@ def plot_intensity(
# )
# return results

if gene_name is None and protein_id is not None:
pass
elif gene_name is not None and protein_id is None:
protein_id = self._get_protein_id_for_gene_name(gene_name)
else:
raise ValueError(
"Either protein_id or gene_name must be provided, but not both."
)

intensity_plot = IntensityPlot(
mat=self.mat,
metadata=self.metadata,
Expand Down
43 changes: 15 additions & 28 deletions alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ def llm_config():
current_model = st.session_state.get(StateKeys.MODEL_NAME, None)

models = [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]
st.session_state[StateKeys.MODEL_NAME] = st.selectbox(
model_name = st.selectbox(
"Select LLM",
models,
index=models.index(st.session_state.get(StateKeys.MODEL_NAME))
if current_model is not None
else 0,
)
st.session_state[StateKeys.MODEL_NAME] = model_name

base_url = None
if st.session_state[StateKeys.MODEL_NAME] in [Models.GPT4O]:
Expand Down Expand Up @@ -92,49 +93,35 @@ def llm_config():
st.write(f"Parameters used for analysis: {parameter_dict}")
c1, c2 = st.columns((1, 2))

with c1:
genes_of_interest_df = volcano_plot.res
genes_of_interest_df = genes_of_interest_df[genes_of_interest_df["label"] != ""]

gene_names_colname = st.session_state[StateKeys.LOADER].gene_names
prot_ids_colname = st.session_state[StateKeys.LOADER].index_column
with c2:
display_figure(volcano_plot.plot)

gene_to_prot_id_map = dict( # TODO move this logic to dataset
zip(
genes_of_interest_df[gene_names_colname].tolist(),
genes_of_interest_df[prot_ids_colname].tolist(),
)
with c1:
regulated_genes_df = volcano_plot.res[volcano_plot.res["label"] != ""]
regulated_genes_dict = dict(
zip(regulated_genes_df["label"], regulated_genes_df["color"].tolist())
)

with c2:
display_figure(volcano_plot.plot)

labels = [
";".join([i for i in j.split(";") if i])
for j in genes_of_interest_df["label"].tolist()
]
genes_of_interest = dict(zip(labels, genes_of_interest_df["color"].tolist()))

if not genes_of_interest:
if not regulated_genes_dict:
st.text("No genes of interest found.")
st.stop()

upregulated_genes = [
key for key in genes_of_interest if genes_of_interest[key] == "up"
key for key in regulated_genes_dict if regulated_genes_dict[key] == "up"
]
downregulated_genes = [
key for key in genes_of_interest if genes_of_interest[key] == "down"
key for key in regulated_genes_dict if regulated_genes_dict[key] == "down"
]

st.markdown("##### Genes of interest")
c1, c2 = st.columns((1, 2), gap="medium")
with c1:
c11, c12 = st.columns((1, 2), gap="medium")
with c11:
st.write("Upregulated genes")
st.markdown(
get_display_proteins_html(upregulated_genes, True), unsafe_allow_html=True
)

with c2:
with c12:
st.write("Downregulated genes")
st.markdown(
get_display_proteins_html(downregulated_genes, False),
Expand Down Expand Up @@ -194,7 +181,7 @@ def llm_config():
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=OLLAMA_BASE_URL,
dataset=st.session_state[StateKeys.DATASET],
gene_to_prot_id_map=gene_to_prot_id_map,
genes_of_interest=list(regulated_genes_dict.keys()),
)

st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration
Expand Down
9 changes: 4 additions & 5 deletions alphastats/llm/llm_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_general_assistant_functions() -> List[Dict]:


def get_assistant_functions(
gene_to_prot_id_map: Dict,
genes_of_interest: List[str],
metadata: pd.DataFrame,
subgroups_for_each_group: Dict,
) -> List[Dict]:
Expand All @@ -78,13 +78,12 @@ def get_assistant_functions(
For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling
Args:
gene_to_prot_id_map (dict): A dictionary with gene names as keys and protein IDs as values.
genes_of_interest (list): A list with gene names.
metadata (pd.DataFrame): The metadata dataframe (which sample has which disease/treatment/condition/etc).
subgroups_for_each_group (dict): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group().
Returns:
List[Dict]: A list of dictionaries describing the assistant functions.
"""
gene_names = list(gene_to_prot_id_map.keys())
groups = [str(col) for col in metadata.columns.to_list()]
return [
{
Expand All @@ -95,9 +94,9 @@ def get_assistant_functions(
"parameters": {
"type": "object",
"properties": {
"protein_id": { # LLM will provide gene_name, mapping to protein_id is done when calling the function
"gene_name": {
"type": "string",
"enum": gene_names,
"enum": genes_of_interest,
"description": "Identifier for the gene of interest",
},
"group": {
Expand Down
27 changes: 6 additions & 21 deletions alphastats/llm/llm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
get_general_assistant_functions,
)
from alphastats.llm.llm_utils import (
get_protein_id_for_gene_name,
get_subgroups_for_each_group,
)
from alphastats.llm.prompts import get_tool_call_message
Expand Down Expand Up @@ -54,8 +53,8 @@ class LLMIntegration:
The API key for authentication, by default None
dataset : Any, optional
The dataset to be used in the conversation, by default None
gene_to_prot_id_map: optional
Mapping of gene names to protein IDs
genes_of_interest: optional
List of regulated genes
"""

def __init__(
Expand All @@ -67,7 +66,7 @@ def __init__(
system_message: str = None,
load_tools: bool = True,
dataset: Optional[DataSet] = None,
gene_to_prot_id_map: Optional[Dict[str, str]] = None,
genes_of_interest: Optional[List[str]] = None,
):
self._model = model_name

Expand All @@ -81,7 +80,7 @@ def __init__(

self._dataset = dataset
self._metadata = None if dataset is None else dataset.metadata
self._gene_to_prot_id_map = gene_to_prot_id_map
self._genes_of_interest = genes_of_interest

self._tools = self._get_tools() if load_tools else None

Expand All @@ -104,10 +103,10 @@ def _get_tools(self) -> List[Dict[str, Any]]:
tools = [
*get_general_assistant_functions(),
]
if self._metadata is not None and self._gene_to_prot_id_map is not None:
if self._metadata is not None and self._genes_of_interest is not None:
tools += (
*get_assistant_functions(
gene_to_prot_id_map=self._gene_to_prot_id_map,
genes_of_interest=self._genes_of_interest,
metadata=self._metadata,
subgroups_for_each_group=get_subgroups_for_each_group(
self._metadata
Expand Down Expand Up @@ -196,20 +195,6 @@ def _execute_function(
# first try to find the function in the non-Dataset functions
if (function := GENERAL_FUNCTION_MAPPING.get(function_name)) is not None:
return function(**function_args)

# special treatment for this function
elif function_name == "plot_intensity":
# TODO move this logic to dataset
gene_name = function_args.pop(
"protein_id"
) # no typo, the LLM sets "protein_id" to gene_name
protein_id = get_protein_id_for_gene_name(
gene_name, self._gene_to_prot_id_map
)
function_args["protein_id"] = protein_id

return self._dataset.plot_intensity(**function_args)

# look up the function in the DataSet class
else:
function = getattr(
Expand Down
23 changes: 0 additions & 23 deletions alphastats/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,3 @@ def get_subgroups_for_each_group(
for group in groups
}
return group_to_subgroup_values


def get_protein_id_for_gene_name(
gene_name: str, gene_to_prot_id_map: Dict[str, str]
) -> str:
"""Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id.
'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1'
Args:
gene_name (str): Gene id
gene_to_prot_id_map (Dict[str, str]): Gene name to protein id mapping.
Returns:
str: Protein id or gene id if not present in the mapping.
"""
if gene_name in gene_to_prot_id_map:
return gene_to_prot_id_map[gene_name]

for gene, protein_id in gene_to_prot_id_map.items():
if gene_name in gene.split(";"):
return protein_id

return gene_name
24 changes: 1 addition & 23 deletions tests/llm/test_llm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def llm_integration(mock_openai_client):
api_key="test-key", # pragma: allowlist secret
system_message="Test system message",
dataset=dataset,
gene_to_prot_id_map={"GENE1": "PROT1", "GENE2": "PROT2"},
genes_of_interest={"GENE1": "PROT1", "GENE2": "PROT2"},
)


Expand Down Expand Up @@ -286,28 +286,6 @@ def test_execute_general_function(
assert result == expected_result


@pytest.mark.parametrize(
"gene_name,plot_args,expected_protein_id",
[
("GENE1", {"param1": "value1"}, "PROT1"),
("GENE2", {"param2": "value2"}, "PROT2"),
],
)
def test_execute_plot_intensity(
llm_integration, gene_name, plot_args, expected_protein_id
):
"""Test execution of plot_intensity with gene name translation"""
function_args = {"protein_id": gene_name, **plot_args}

result = llm_integration._execute_function("plot_intensity", function_args)

# Verify the dataset's plot_intensity was called with correct protein ID
llm_integration._dataset.plot_intensity.assert_called_once()
call_args = llm_integration._dataset.plot_intensity.call_args[1]
assert call_args["protein_id"] == expected_protein_id
assert result == "Plot created"


def test_execute_dataset_function(llm_integration):
"""Test execution of a function from the dataset"""
result = llm_integration._execute_function("custom_function", {"param1": "value1"})
Expand Down
48 changes: 0 additions & 48 deletions tests/llm/test_llm_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pandas as pd
import pytest

from alphastats.llm.llm_utils import (
get_protein_id_for_gene_name,
get_subgroups_for_each_group,
)

Expand Down Expand Up @@ -52,49 +50,3 @@ def test_get_subgroups_for_each_group_numeric_values():

expected = {"age_group": ["20", "30"], "score": ["1.5", "2.5"]}
assert result == expected


@pytest.fixture
def gene_to_prot_map():
"""Fixture for protein mapping dictionary."""
return {
"VCL": "P18206",
"VCL;HEL114": "P18206;A0A024QZN4",
"MULTI;GENE": "PROT1;PROT2;PROT3",
}


def test_get_protein_id_direct_match(gene_to_prot_map):
"""Test when gene name directly matches a key."""
result = get_protein_id_for_gene_name("VCL", gene_to_prot_map)
assert result == "P18206"


def test_get_protein_id_compound_key(gene_to_prot_map):
"""Test when gene name is part of a compound key."""
result = get_protein_id_for_gene_name("HEL114", gene_to_prot_map)
assert result == "P18206;A0A024QZN4"


def test_get_protein_id_not_found(gene_to_prot_map):
"""Test when gene name is not found in mapping."""
result = get_protein_id_for_gene_name("UNKNOWN", gene_to_prot_map)
assert result == "UNKNOWN"


def test_get_protein_id_empty_map():
"""Test with empty mapping dictionary."""
result = get_protein_id_for_gene_name("VCL", {})
assert result == "VCL"


def test_get_protein_id_multiple_matches(gene_to_prot_map):
"""Test with a gene that appears in multiple compound keys."""
result = get_protein_id_for_gene_name("MULTI", gene_to_prot_map)
assert result == "PROT1;PROT2;PROT3"


def test_get_protein_id_case_sensitivity(gene_to_prot_map):
"""Test case sensitivity of gene name matching."""
result = get_protein_id_for_gene_name("vcl", gene_to_prot_map)
assert result == "vcl" # Should not match 'VCL' due to case sensitivity
Loading

0 comments on commit 0f4694f

Please sign in to comment.