Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Jones authored and Matthew Jones committed Feb 6, 2024
1 parent 81bf0ee commit 9de4088
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions cassiopeia/solver/missing_data_methods.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This file contains included missing data imputation methods."""

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -59,33 +60,52 @@ def assign_missing_average(
def score_side(subset_character_states, query_states, weights):
    """Score how well a query sample's states match one side of a split.

    For every character index, each informative state of the query (any
    state that is neither ``0`` nor ``missing_state_indicator``) contributes
    the number of entries in that character's subset states that equal it.
    When ``weights`` is truthy, each contribution is multiplied by
    ``weights[char][state]``.

    Args:
        subset_character_states: Indexable by character; each entry is the
            collection of states observed for that character on this side.
        query_states: Indexable by character; each entry is an iterable of
            the query sample's state(s) for that character.
        weights: Optional nested mapping ``weights[char][state]``; when
            falsy (``None`` or empty) every match counts as 1.

    Returns:
        The accumulated (optionally weighted) match count.
    """
    # NOTE(review): `missing_state_indicator` is not a parameter; it is
    # presumably bound in the enclosing function's scope -- confirm.
    score = 0
    for char, side_states in enumerate(subset_character_states):
        # Uncut (0) and missing states carry no signal for either side.
        informative_states = [
            q
            for q in query_states[char]
            if q != 0 and q != missing_state_indicator
        ]
        all_states = np.array(side_states)
        for q in informative_states:
            matches = np.count_nonzero(all_states == q)
            if weights:
                score += weights[char][q] * matches
            else:
                score += matches
    return score

subset_character_array_left = character_array[left_indices, :]
subset_character_array_right = character_array[right_indices, :]

all_left_states = [unravel_ambiguous_states(subset_character_array_left[:,char]) for char in range(subset_character_array_left.shape[1])]
all_right_states = [unravel_ambiguous_states(subset_character_array_right[:,char]) for char in range(subset_character_array_right.shape[1])]
all_left_states = [
unravel_ambiguous_states(subset_character_array_left[:, char])
for char in range(subset_character_array_left.shape[1])
]
all_right_states = [
unravel_ambiguous_states(subset_character_array_right[:, char])
for char in range(subset_character_array_right.shape[1])
]


for sample_index in missing_indices:

all_states_for_sample = [unravel_ambiguous_states([character_array[sample_index, char]]) for char in range(character_array.shape[1])]

left_score = score_side(np.array(all_left_states, dtype=object), np.array(all_states_for_sample, dtype=object), weights)
right_score = score_side(np.array(all_right_states, dtype=object), np.array(all_states_for_sample, dtype=object), weights)
all_states_for_sample = [
unravel_ambiguous_states([character_array[sample_index, char]])
for char in range(character_array.shape[1])
]

left_score = score_side(
np.array(all_left_states, dtype=object),
np.array(all_states_for_sample, dtype=object),
weights,
)
right_score = score_side(
np.array(all_right_states, dtype=object),
np.array(all_states_for_sample, dtype=object),
weights,
)

if (left_score / len(left_set)) > (right_score / len(right_set)):
left_set.append(sample_names[sample_index])
Expand Down

0 comments on commit 9de4088

Please sign in to comment.