diff --git a/cassiopeia/solver/missing_data_methods.py b/cassiopeia/solver/missing_data_methods.py index 71a07880..4cf10351 100644 --- a/cassiopeia/solver/missing_data_methods.py +++ b/cassiopeia/solver/missing_data_methods.py @@ -1,4 +1,5 @@ """This file contains included missing data imputation methods.""" + from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -59,33 +60,52 @@ def assign_missing_average( def score_side(subset_character_states, query_states, weights): score = 0 for char in range(len(subset_character_states)): - - query_state = [q for q in query_states[char] if q != 0 and q != missing_state_indicator] + + query_state = [ + q + for q in query_states[char] + if q != 0 and q != missing_state_indicator + ] all_states = np.array(subset_character_states[char]) for q in query_state: if weights: - score += ( - weights[char][q] - * np.count_nonzero(all_states == q) + score += weights[char][q] * np.count_nonzero( + all_states == q ) else: score += np.count_nonzero(all_states == q) - return score - + return score + subset_character_array_left = character_array[left_indices, :] subset_character_array_right = character_array[right_indices, :] - all_left_states = [unravel_ambiguous_states(subset_character_array_left[:,char]) for char in range(subset_character_array_left.shape[1])] - all_right_states = [unravel_ambiguous_states(subset_character_array_right[:,char]) for char in range(subset_character_array_right.shape[1])] + all_left_states = [ + unravel_ambiguous_states(subset_character_array_left[:, char]) + for char in range(subset_character_array_left.shape[1]) + ] + all_right_states = [ + unravel_ambiguous_states(subset_character_array_right[:, char]) + for char in range(subset_character_array_right.shape[1]) + ] - for sample_index in missing_indices: - all_states_for_sample = [unravel_ambiguous_states([character_array[sample_index, char]]) for char in range(character_array.shape[1])] - - left_score = score_side(np.array(all_left_states, dtype=object), np.array(all_states_for_sample, dtype=object), weights) - right_score = score_side(np.array(all_right_states, dtype=object), np.array(all_states_for_sample, dtype=object), weights) + all_states_for_sample = [ + unravel_ambiguous_states([character_array[sample_index, char]]) + for char in range(character_array.shape[1]) + ] + + left_score = score_side( + np.array(all_left_states, dtype=object), + np.array(all_states_for_sample, dtype=object), + weights, + ) + right_score = score_side( + np.array(all_right_states, dtype=object), + np.array(all_states_for_sample, dtype=object), + weights, + ) if (left_score / len(left_set)) > (right_score / len(right_set)): left_set.append(sample_names[sample_index])