diff --git a/baselines/generate_figures.py b/baselines/generate_figures.py index b99d8462..d1ad7b67 100644 --- a/baselines/generate_figures.py +++ b/baselines/generate_figures.py @@ -130,16 +130,24 @@ def parse_xml_file(file_path): def fetch_site_and_method(input_string, pred_type): """ - Fetch the file and method from the input string + Fetch the subject_id, site and method from the input string :param input_string: input string, e.g. 'sub-5416_T2w_seg_nnunet' + :return subject_id: subject id, e.g. 'sub-5416' :return site: site name, e.g. 'zurich' or 'colorado' :return method: segmentation method, e.g. 'nnunet' """ + + # Fetch subject id + subject = re.search('sub-(.*?)[_/]', input_string) # [_/] slash or underscore + subject_id = subject.group(0)[:-1] if subject else "" # [:-1] removes the last underscore or slash + + # Fetch site if 'sub-zh' in input_string: site = 'zurich' else: site = 'colorado' - + + # Fetch method if pred_type == 'sc': method = input_string.split('_seg_')[1] elif pred_type == 'lesion': @@ -147,7 +155,7 @@ def fetch_site_and_method(input_string, pred_type): else: raise ValueError(f'Unknown pred_type: {pred_type}') - return site, method + return subject_id, site, method def print_mean_and_std(df, list_of_metrics, pred_type): @@ -383,9 +391,11 @@ def main(): list_of_metrics.append('ExecutionTime[s]') # Apply the fetch_filename_and_method function to each row using a lambda function - df[['site', 'method']] = df['filename'].apply(lambda x: pd.Series(fetch_site_and_method(x, pred_type))) + df[['subject_id', 'site', 'method']] = df['filename'].\ + apply(lambda x: pd.Series(fetch_site_and_method(x, pred_type))) # Reorder the columns - df = df[['filename', 'site', 'method'] + [col for col in df.columns if col not in ['filename', 'site', 'method']]] + df = df[['filename', 'subject_id', 'site', 'method'] + [col for col in df.columns if col not in + ['filename', 'subject_id', 'site', 'method']]] # remove '_fullres' from the method column df['method'] = df['method'].apply(lambda x: x.replace('_fullres', '')) @@ -395,6 +405,10 @@ def main(): # Concatenate the list of dataframes into a single dataframe df_concat = pd.concat(list_of_df, ignore_index=True) + # Remove 'sub-5740' (https://github.com/ivadomed/model_seg_sci/issues/59) + logger.info(f'Removing subject sub-5740 from the dataframe.') + df_concat = df_concat[df_concat['subject_id'] != 'sub-5740'] + # Print colorado subjects with Dice=0 print_colorado_subjects_with_dice_0(df_concat)