Skip to content

Commit

Permalink
oversample
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Sep 2, 2024
1 parent df22da3 commit 374cd8d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
47 changes: 22 additions & 25 deletions src/napatrackmater/Trackvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3389,7 +3389,9 @@ def plot_at_mitosis_time(matrix_directory, save_dir, dataset_name, channel):
plt.show()


def plot_histograms_for_groups(matrix_directory, save_dir, dataset_name, channel, name = 'all'):
def plot_histograms_for_groups(
matrix_directory, save_dir, dataset_name, channel, name="all"
):

files = os.listdir(matrix_directory)
sorted_files = natsorted(
Expand Down Expand Up @@ -3440,7 +3442,7 @@ def plot_histograms_for_groups(matrix_directory, save_dir, dataset_name, channel


def plot_histograms_for_cell_type_groups(
matrix_directory, save_dir, dataset_name, channel, label_dict = None, name = 'all'
matrix_directory, save_dir, dataset_name, channel, label_dict=None, name="all"
):

files = os.listdir(matrix_directory)
Expand Down Expand Up @@ -3470,11 +3472,13 @@ def plot_histograms_for_cell_type_groups(
file_path = os.path.join(matrix_directory, file_name)
cell_type = extract_celltype(file_name)
if label_dict is not None:
cell_type_name = label_dict[cell_type]
cell_type_name = label_dict[cell_type]
else:
cell_type_name = cell_type
cell_type_name = cell_type
data = np.load(file_path, allow_pickle=True)
sns.histplot(data, alpha=0.5, kde=True, label=f"Cell_Type: {cell_type_name}")
sns.histplot(
data, alpha=0.5, kde=True, label=f"Cell_Type: {cell_type_name}"
)

plt.xlabel("Value")
plt.ylabel("Counts")
Expand Down Expand Up @@ -4042,7 +4046,6 @@ def inception_model_prediction(
class_map,
dynamic_model=None,
shape_model=None,
num_samples=10,
device="cpu",
):
sub_dataframe = dataframe[dataframe["Track ID"] == track_id]
Expand All @@ -4051,13 +4054,13 @@ def inception_model_prediction(

total_duration = sub_dataframe["Track Duration"].max()

def sample_subarrays(data, num_samples, tracklet_length, total_duration):
if sub_dataframe.shape[0] < tracklet_length:
return "UnClassified"

def sample_subarrays(data, tracklet_length, total_duration):

max_start_index = total_duration - tracklet_length
if max_start_index > num_samples:
start_indices = random.sample(range(max_start_index), num_samples)
else:
start_indices = [0] * num_samples
start_indices = random.sample(range(max_start_index), max_start_index)

subarrays = []
for start_index in start_indices:
Expand All @@ -4070,10 +4073,10 @@ def sample_subarrays(data, num_samples, tracklet_length, total_duration):
return subarrays

sub_arrays_shape = sample_subarrays(
sub_dataframe_shape, num_samples, tracklet_length, total_duration
sub_dataframe_shape, tracklet_length, total_duration
)
sub_arrays_dynamic = sample_subarrays(
sub_dataframe_dynamic, num_samples, tracklet_length, total_duration
sub_dataframe_dynamic, tracklet_length, total_duration
)

def make_prediction(input_data, model):
Expand All @@ -4089,12 +4092,11 @@ def make_prediction(input_data, model):
return predicted_class.item()

def get_most_frequent_prediction(predictions):
if predictions:

prediction_counts = Counter(predictions)
most_common_prediction, count = prediction_counts.most_common(1)[0]
prediction_counts = Counter(predictions)
most_common_prediction, count = prediction_counts.most_common(1)[0]

return most_common_prediction
return most_common_prediction

shape_predictions = []
if shape_model is not None:
Expand All @@ -4109,16 +4111,11 @@ def get_most_frequent_prediction(predictions):
dynamic_predictions.append(predicted_class)

final_predictions = shape_predictions + dynamic_predictions
if len(final_predictions) > 0:
most_frequent_prediction = get_most_frequent_prediction(final_predictions)
if most_frequent_prediction is not None:
most_predicted_class = class_map[int(most_frequent_prediction)]

return most_predicted_class
most_frequent_prediction = get_most_frequent_prediction(final_predictions)

else:
most_predicted_class = class_map[int(most_frequent_prediction)]

return "UnClassified"
return most_predicted_class


def save_cell_type_predictions(
Expand Down
4 changes: 2 additions & 2 deletions src/napatrackmater/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = version = "5.4.9"
__version_tuple__ = version_tuple = (5, 4, 9)
__version__ = version = "5.5.0"
__version_tuple__ = version_tuple = (5, 5, 0)

0 comments on commit 374cd8d

Please sign in to comment.