Skip to content

Commit

Permalink
added mixin utilities test
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjones315 committed Sep 27, 2023
1 parent 085a409 commit ac3deb4
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 14 deletions.
44 changes: 34 additions & 10 deletions cassiopeia/solver/GreedySolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import pandas as pd

from cassiopeia.data import CassiopeiaTree
from cassiopeia.mixins import GreedySolverError, is_ambiguous_state, unravel_ambiguous_states
from cassiopeia.mixins import (
GreedySolverError,
is_ambiguous_state,
unravel_ambiguous_states,
)
from cassiopeia.solver import CassiopeiaSolver, solver_utilities


Expand Down Expand Up @@ -41,7 +45,6 @@ class GreedySolver(CassiopeiaSolver.CassiopeiaSolver):
"""

def __init__(self, prior_transformation: str = "negative_log"):

super().__init__(prior_transformation)
self.allow_ambiguous = False

Expand Down Expand Up @@ -155,13 +158,29 @@ def _solve(
character_matrix = cassiopeia_tree.character_matrix.copy()

# Raise exception if the character matrix has ambiguous states.
if any(
is_ambiguous_state(state)
for state in character_matrix.values.flatten()
) and not self.allow_ambiguous:
raise GreedySolverError("Ambiguous states are not currently supported with this solver.")
if (
any(
is_ambiguous_state(state)
for state in character_matrix.values.flatten()
)
and not self.allow_ambiguous
):
raise GreedySolverError(
"Ambiguous states are not currently supported with this solver."
)

keep_rows = character_matrix.apply(lambda x: [set(s) if is_ambiguous_state(s) else set([s]) for s in x.values], axis=0).apply(tuple, axis=1).drop_duplicates().index.values
keep_rows = (
character_matrix.apply(
lambda x: [
set(s) if is_ambiguous_state(s) else set([s])
for s in x.values
],
axis=0,
)
.apply(tuple, axis=1)
.drop_duplicates()
.index.values
)
unique_character_matrix = character_matrix.loc[keep_rows].copy()

tree = nx.DiGraph()
Expand Down Expand Up @@ -198,9 +217,14 @@ def compute_mutation_frequencies(
Generates a dictionary that maps each character to a dictionary of state/
sample frequency pairs, allowing quick lookup. Subsets the character matrix
to only include the samples in the sample set.
This currently supports ambiguous states, for the GreedySolvers that
support ambiguous states during inference.
Args:
samples: The set of relevant samples in calculating frequencies
unique_character_matrix: The character matrix from which to calculate frequencies
unique_character_matrix: The character matrix from which to
calculate frequencies
missing_state_indicator: The character representing missing values
Returns:
A dictionary containing frequency information for each character/state
Expand All @@ -210,7 +234,7 @@ def compute_mutation_frequencies(
freq_dict = {}
for char in range(subset_cm.shape[1]):
char_dict = {}
all_states = unravel_ambiguous_states(subset_cm[:,char])
all_states = unravel_ambiguous_states(subset_cm[:, char])
state_counts = np.unique(all_states, return_counts=True)

for i in range(len(state_counts[0])):
Expand Down
7 changes: 3 additions & 4 deletions cassiopeia/solver/missing_data_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def score_side(subset_character_matrix, missing_sample):
for char in range(character_matrix.shape[1]):
state = character_array[missing_sample, char]
if state != missing_state_indicator and state != 0:
all_states = unravel_ambiguous_states(subset_character_matrix[:, char])
all_states = unravel_ambiguous_states(
subset_character_matrix[:, char]
)
state_counts = np.unique(all_states, return_counts=True)
# state_counts = np.unique(
# subset_character_matrix[:, char], return_counts=True
# )
ind = np.where(state_counts[0] == state)
if len(ind[0]) > 0:
if weights:
Expand Down
11 changes: 11 additions & 0 deletions test/mixin_tests/mixin_utilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ def test_is_ambiguous_state(self):
self.assertTrue(utilities.is_ambiguous_state((1, 2)))
self.assertFalse(utilities.is_ambiguous_state(1))

def test_unravel_states(self):
state_array = [0, (1, 2), 3, 4, 5]
self.assertListEqual(
[0, 1, 2, 3, 4, 5], utilities.unravel_ambiguous_states(state_array)
)

state_array = [0, 1, 2, 3, 4, 5]
self.assertListEqual(
[0, 1, 2, 3, 4, 5], utilities.unravel_ambiguous_states(state_array)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit ac3deb4

Please sign in to comment.