
Commit

Merge pull request #47 from MELDProject/dev_docker
Fix test and small issues
kwagstyl authored Sep 24, 2024
2 parents abbe536 + e36ff63 commit 11816ea
Showing 13 changed files with 50 additions and 85 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -24,8 +24,9 @@ It is not appropriate to use this algorithm on patients with:
- hypothalamic hamartoma
- periventricular nodular heterotopia
- other focal epilepsy pathologies
- previous resection cavities

** Harmonisation ** - MRI data from different MRI scanners looks subtly different. This means that feature measurements, e.g. cortical thickness measurements, differ depending on which MRI scanner a patient was scanned on. We harmonise features (using NeuroCombat) to adjust for site-based differences. We advise new users to harmonise data from their MRI scanner to the MELD graph dataset. Please follow the guidelines to harmonise the data from your site. Note: the model will still produce predictions on new, unharmonised subjects but the number of false positive predictions is higher if the data is not harmonised.
**Harmonisation** - MRI data from different MRI scanners looks subtly different. This means that feature measurements, e.g. cortical thickness measurements, differ depending on which MRI scanner a patient was scanned on. We harmonise features (using NeuroCombat) to adjust for site-based differences. We advise new users to harmonise data from their MRI scanner to the MELD graph dataset. Please follow the guidelines to harmonise the data from your site. Note: the model will still produce predictions on new, unharmonised subjects but the number of false positive predictions is higher if the data is not harmonised.
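
A minimal sketch of the kind of NeuroCombat call this describes (illustrative only: the covariates, column names, and matrix shapes below are assumptions, not the MELD pipeline's actual configuration):

```python
import numpy as np
import pandas as pd
from neuroCombat import neuroCombat

# dat is a (features x subjects) matrix, e.g. per-vertex cortical thickness
dat = np.random.rand(100, 20)
covars = pd.DataFrame({
    "site": ["H1"] * 10 + ["H2"] * 10,        # scanner/site label per subject
    "age": np.random.randint(5, 60, size=20),
    "sex": np.random.randint(0, 2, size=20),
})

# Remove site effects while preserving biological variance (age, sex)
harmonised = neuroCombat(
    dat=dat,
    covars=covars,
    batch_col="site",
    continuous_cols=["age"],
    categorical_cols=["sex"],
)["data"]
```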

This package also contains code for training and evaluating graph-based U-net lesion segmentation models operating on icosphere meshes. \
In addition to lesion segmentation, the model also contains auxiliary distance regression and hemisphere classification losses.
2 changes: 1 addition & 1 deletion meld_graph/data_preprocessing.py
@@ -508,7 +508,7 @@ def transfer_lesion(self):
)

def make_boundary_zones(self, smoothing=0, boundary_feature_name=".on_lh.boundary_zone.mgh"):
import pp3d
import potpourri3d as pp3d
# preload geodesic distance solver
solver = pp3d.MeshHeatMethodDistanceSolver(self.cohort.surf["coords"], self.cohort.surf["faces"])
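
For reference, a minimal sketch of how this potpourri3d solver is typically driven (the toy mesh and vertex indices are made up for illustration):

```python
import numpy as np
import potpourri3d as pp3d

# A toy mesh: four vertices, two triangular faces
verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2], [1, 3, 2]])

solver = pp3d.MeshHeatMethodDistanceSolver(verts, faces)
dist = solver.compute_distance(0)                         # geodesic distance from vertex 0
dist_multi = solver.compute_distance_multisource([0, 3])  # e.g. all boundary-zone vertices
```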

2 changes: 1 addition & 1 deletion meld_graph/download_data.py
@@ -20,7 +20,7 @@ def download_test_data():
"""
Download test data from figshare
"""
url = "https://figshare.com/ndownloader/files/49265704?private_link=3b790cfb027f4036f19a"
url = "https://figshare.com/ndownloader/files/49366198?private_link=3b790cfb027f4036f19a"
test_data_dir = MELD_DATA_PATH
os.makedirs(test_data_dir, exist_ok=True)
print('downloading test data to '+ test_data_dir)
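
A hedged sketch of what the fetch step might look like (the repo's actual downloader may use a different HTTP client and archive handling; `urllib` and the file name below are assumptions):

```python
import os
import urllib.request

url = "https://figshare.com/ndownloader/files/49366198?private_link=3b790cfb027f4036f19a"
test_data_dir = "/tmp/meld_data"  # stand-in for MELD_DATA_PATH
os.makedirs(test_data_dir, exist_ok=True)

archive = os.path.join(test_data_dir, "test_data.zip")
urllib.request.urlretrieve(url, archive)  # download the figshare bundle
```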
2 changes: 1 addition & 1 deletion meld_graph/meld_cohort.py
@@ -236,7 +236,7 @@ def get_subject_ids(self, **kwargs):
else:
groups = [kwargs.get("group", "both")]
# get sites
site_codes = kwargs.get("harmo code", self.get_sites())
site_codes = kwargs.get("site_codes", self.get_sites())
if isinstance(site_codes, str):
site_codes = [site_codes]
# get scanners
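
With this fix the `site_codes` keyword is actually honoured; a small usage sketch (cohort file root and site code are illustrative):

```python
from meld_graph.meld_cohort import MeldCohort

c = MeldCohort(hdf5_file_root="{site_code}_{group}_featurematrix_combat.hdf5")
# Before the fix, get_subject_ids looked up the wrong key ("harmo code"),
# silently ignored site_codes, and fell back to all sites.
patients = c.get_subject_ids(group="patient", site_codes=["TEST"])
```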
54 changes: 17 additions & 37 deletions meld_graph/test/test_data_exists.py
@@ -12,56 +12,36 @@
import numpy as np
import warnings
from meld_graph.test.utils import create_test_demos

sites = [
"H1",
"H2",
"H3",
"H4",
"H5",
"H6",
"H7",
"H9",
"H10",
"H11",
"H12",
"H14",
"H15",
"H16",
"H17",
"H18",
"H19",
"H21",
"H23",
"H24",
"H26",
]
hdf5_file_roots = ["{site_code}_{group}_featurematrix.hdf5", DEFAULT_HDF5_FILE_ROOT]
from meld_graph.tools_pipeline import create_dataset_file
sites = ["TEST"]

#create demo tmp
create_test_demos()

@pytest.mark.data
@pytest.mark.parametrize("hdf5_file_root", hdf5_file_roots)
def test_cohort_exists(hdf5_file_root):
c = MeldCohort(hdf5_file_root=hdf5_file_root)
@pytest.mark.parametrize("site", sites)
def test_cohort_exists(site):
hdf5_file_root = "{site_code}_{group}_featurematrix_combat.hdf5"
c = MeldCohort(hdf5_file_root=hdf5_file_root, dataset='/tmp/dataset_test.csv')
# does exist at all?
if len(c.get_subject_ids()) == 0:
if len(c.get_subject_ids(group='all', site_codes=[site])) == 0:
warnings.warn(f"hdf5_file_root {hdf5_file_root} does not exist on this system.")
return
for site in sites:
patient_ids = c.get_subject_ids(group="patient", site_codes=[site])
if len(patient_ids) == 0:
warnings.warn(f"cohort for {hdf5_file_root} does not have patients for site {site}")
control_ids = c.get_subject_ids(group="control", site_codes=[site])
if len(control_ids) == 0:
warnings.warn(f"cohort for {hdf5_file_root} does not have controls for site {site}")
patient_ids = c.get_subject_ids(group="patient", site_codes=[site])
print(patient_ids)
if len(patient_ids) == 0:
warnings.warn(f"cohort for {hdf5_file_root} does not have patients for site {site}")
control_ids = c.get_subject_ids(group="control", site_codes=[site])
if len(control_ids) == 0:
warnings.warn(f"cohort for {hdf5_file_root} does not have controls for site {site}")

@pytest.mark.data
@pytest.mark.parametrize("site", sites)
def test_borderzone_exists(site):
c = MeldCohort(hdf5_file_root="{site_code}_{group}_featurematrix.hdf5")
hdf5_file_root = "{site_code}_{group}_featurematrix_combat.hdf5"
c = MeldCohort(hdf5_file_root=hdf5_file_root, dataset='/tmp/dataset_test.csv')
subject_ids = c.get_subject_ids(group="patient", lesional_only=True, site_codes=[site])
print(subject_ids)
# get a few random subject_ids to test if has borderzone
for subj_id in np.random.default_rng().choice(subject_ids, size=min(len(subject_ids), 3), replace=False):
subj = MeldSubject(subj_id, cohort=c)
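
These data-dependent tests are tagged with the `data` marker, so they can be run selectively; a sketch (assuming pytest is installed and run from the repo root):

```python
# Equivalent to `pytest -m data meld_graph/test/test_data_exists.py`
import pytest

pytest.main(["-m", "data", "meld_graph/test/test_data_exists.py"])
```
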
2 changes: 1 addition & 1 deletion meld_graph/test/test_dataset.py
@@ -74,7 +74,7 @@ def data_parameters():
# Dataset class tests
def test_dataset_flags(data_parameters):
create_test_demos()
c = MeldCohort(hdf5_file_root=data_parameters["hdf5_file_root"])
c = MeldCohort(hdf5_file_root=data_parameters["hdf5_file_root"], dataset='/tmp/dataset_test.csv')

subject_ids = c.get_subject_ids(**data_parameters)
subject_ids = subject_ids[0:5]
20 changes: 8 additions & 12 deletions meld_graph/test/test_meld_cohort.py
@@ -21,10 +21,6 @@
import pandas as pd
from meld_graph.test.utils import create_test_demos

# @pytest.fixture(autouse=True)
# def setup_teardown_tests():
# get_test_data()
# yield

create_test_demos()

@@ -33,7 +29,7 @@
)
def test_features_overlap_with_list(hdf5_file_root):
"""test if computed full_feature_list overlaps with manual feature list"""
c = MeldCohort(hdf5_file_root=hdf5_file_root)
c = MeldCohort(hdf5_file_root=hdf5_file_root, dataset='/tmp/dataset_test.csv')
if len(c.get_subject_ids()) == 0:
warnings.warn("hdf5_file_root {hdf5_file_root} does not seem to exist on this system. Skipping this test.")
return
@@ -98,7 +94,7 @@ def test_features_overlap_with_list(hdf5_file_root):

def test_features_consistent():
"""test that all subjects in cohort have same features"""
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
features = c.full_feature_list
for subj_id in c.get_subject_ids():
for hemi in ["lh", "rh"]:
@@ -108,13 +104,13 @@ def test_features_consistent():


def test_get_sites():
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
# TEST must be in sites, because we just created test data
assert "TEST" in c.get_sites()


def test_cortex_label():
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
# check that cortex is ordered list
assert ((c.cortex_label[1:] - c.cortex_label[:-1]) > 0).all()

@@ -124,7 +120,7 @@ def test_get_subject_ids():
tests MeldCohort.get_subject_ids
ensure that returned subjects are filtered correctly
"""
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')

# test site_codes flag
all_subjects = c.get_subject_ids(site_codes="TEST")
@@ -139,7 +135,7 @@
assert len(all_subjects) == len(patients) + len(controls)

# test subject_features_to_exclude flag
flair_subjects = c.get_subject_ids(site_codes="H2", subject_features_to_exclude=["FLAIR"])
flair_subjects = c.get_subject_ids(site_codes="TEST", subject_features_to_exclude=["FLAIR"])
_, flair_features = c._filter_features(features_to_exclude=["FLAIR"], return_excluded=True)
for subj_id in flair_subjects:
# does this subject have flair features?
@@ -154,7 +150,7 @@


def test_get_subject_ids_with_dataset():
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
ds_site_codes = ["TEST"]

# create temp dataset file
@@ -176,7 +172,7 @@


def test_split_hemispheres():
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
# test that splits data correctly
# create data with ones for left hemi, twoes for right hemi
input_data = np.concatenate([np.ones(len(c.cortex_label)), np.ones(len(c.cortex_label)) * 2])
10 changes: 5 additions & 5 deletions meld_graph/test/test_meld_subject.py
@@ -5,7 +5,7 @@
# load_feature_lesion_data
# other MeldSubject functions (just syntax)
# NOTE:
# these tests require a test dataset that is created with get_test_data()
# these tests require a test dataset that is downloaded with prepare_classifier.py
# executing this function may take a while the first time (while the test data is being created)
# MISSING TESTS:
# - more extensive tests for functions tested in test_meldsubject_api (just tests for syntax)
@@ -23,15 +23,15 @@ def test_subject_parse():
"""
test if MeldSubject.site_code, .group, .scanner work as expected
"""
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
subj = MeldSubject("MELD_TEST_15T_FCD_0002", c)
assert subj.site_code == "TEST"
assert subj.scanner == "15T"
assert subj.group == "patient"


def test_get_lesion_hemisphere():
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT)
c = MeldCohort(hdf5_file_root=DEFAULT_HDF5_FILE_ROOT, dataset='/tmp/dataset_test.csv')
# ensure that patients have a lesional hemisphere
patients = c.get_subject_ids(site_codes="TEST", group="patient", lesional_only=True)
for subj_id in patients:
@@ -54,7 +54,7 @@ def test_meldsubject_api(subj_id, hdf5_file_root):
"""
# get_test_data()

c = MeldCohort(hdf5_file_root=hdf5_file_root)
c = MeldCohort(hdf5_file_root=hdf5_file_root, dataset='/tmp/dataset_test.csv')
subj = MeldSubject(subj_id, cohort=c)

subj.get_demographic_features("Age")
@@ -73,7 +73,7 @@ def test_load_feature_lesion_data(subj_id, hdf5_file_root):
# TODO also test on TEST site where we know the expected feature values?
# get_test_data()

c = MeldCohort(hdf5_file_root=hdf5_file_root)
c = MeldCohort(hdf5_file_root=hdf5_file_root, dataset='/tmp/dataset_test.csv')
subj = MeldSubject(subj_id, cohort=c)
lesion_hemi = subj.get_lesion_hemisphere()
for hemi in ["lh", "rh"]:
5 changes: 4 additions & 1 deletion meld_graph/test/utils.py
@@ -3,18 +3,21 @@
import numpy as np
import pandas as pd
from meld_graph.paths import BASE_PATH, DEFAULT_HDF5_FILE_ROOT, MELD_DATA_PATH
from meld_graph.tools_pipeline import create_dataset_file

def create_test_demos():
data = {"ID":{"0":"sub-test001","1":"MELD_TEST_15T_C_0001","2":"MELD_TEST_15T_C_0002","3":"MELD_TEST_15T_C_0003","4":"MELD_TEST_15T_C_0004","5":"MELD_TEST_15T_C_0005","6":"MELD_TEST_3T_C_0001","7":"MELD_TEST_3T_C_0002","8":"MELD_TEST_3T_C_0003","9":"MELD_TEST_3T_C_0004","10":"MELD_TEST_3T_C_0005","11":"MELD_TEST_15T_FCD_0002","12":"MELD_TEST_15T_FCD_0003","13":"MELD_TEST_15T_FCD_0004","14":"MELD_TEST_15T_FCD_0005","15":"MELD_TEST_15T_FCD_0006","16":"MELD_TEST_3T_FCD_0002","17":"MELD_TEST_3T_FCD_0003","18":"MELD_TEST_3T_FCD_0004","19":"MELD_TEST_3T_FCD_0005","20":"MELD_TEST_3T_FCD_0006"},
"Harmo code":{"0":"TEST","1":"TEST","2":"TEST","3":"TEST","4":"TEST","5":"TEST","6":"TEST","7":"TEST","8":"TEST","9":"TEST","10":"TEST","11":"TEST","12":"TEST","13":"TEST","14":"TEST","15":"TEST","16":"TEST","17":"TEST","18":"TEST","19":"TEST","20":"TEST"},
"Group ":{"0":"patient","1":"control","2":"control","3":"control","4":"control","5":"control","6":"control","7":"control","8":"control","9":"control","10":"control","11":"patient","12":"patient","13":"patient","14":"patient","15":"patient","16":"patient","17":"patient","18":"patient","19":"patient","20":"patient"},
"Age at preoperative":{"0":25,"1":7,"2":9,"3":14,"4":3,"5":15,"6":26,"7":22,"8":4,"9":5,"10":12,"11":4,"12":6,"13":20,"14":12,"15":7,"16":4,"17":6,"18":10,"19":9,"20":12},
"Sex":{"0":1,"1":0,"2":1,"3":0,"4":0,"5":1,"6":0,"7":1,"8":0,"9":0,"10":1,"11":0,"12":1,"13":1,"14":0,"15":1,"16":0,"17":0,"18":1,"19":0,"20":0},
"Scanner":{"0":"3T","1":"15T","2":"15T","3":"15T","4":"15T","5":"15T","6":"3T","7":"3T","8":"3T","9":"3T","10":"3T","11":"15T","12":"15T","13":"15T","14":"15T","15":"15T","16":"3T","17":"3T","18":"3T","19":"3T","20":"3T"}}
"Scanner":{"0":"3T","1":"15T","2":"15T","3":"15T","4":"15T","5":"15T","6":"3T","7":"3T","8":"3T","9":"3T","10":"3T","11":"15T","12":"15T","13":"15T","14":"15T","15":"15T","16":"3T","17":"3T","18":"3T","19":"3T","20":"3T"},}

df = pd.DataFrame(data)
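# note: os.path.join drops MELD_DATA_PATH here because the second argument
# is an absolute path, so the file is written to /tmp/demographics_file.csv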
df.to_csv(os.path.join(MELD_DATA_PATH,'/tmp/demographics_file.csv'))

create_dataset_file(df['ID'].values, '/tmp/dataset_test.csv')
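
Once `create_test_demos()` has run, the two files it writes can be sanity-checked like this (a quick illustrative snippet, not part of the repo):

```python
import pandas as pd

demos = pd.read_csv("/tmp/demographics_file.csv")
dataset = pd.read_csv("/tmp/dataset_test.csv")

print(demos["ID"].head())                # MELD_TEST_* subject IDs
print(dataset[["subject_id", "split"]])  # every row has split == "test"
```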

def create_test_data():
"""
This function was initially used to create the random test dataset.
8 changes: 8 additions & 0 deletions meld_graph/tools_pipeline.py
@@ -108,4 +108,12 @@ def create_demographic_file(subjects_ids, save_file, harmo_code='noHarmo'):
df['Harmo code']=[str(harmo_code) for subject in subjects_ids]
df['Group']=['patient' for subject in subjects_ids]
df['Scanner']=['3T' for subject in subjects_ids]
df.to_csv(save_file)

def create_dataset_file(subjects_ids, save_file):
df=pd.DataFrame()
if isinstance(subjects_ids, str):
subjects_ids=[subjects_ids]
df['subject_id']=subjects_ids
df['split']=['test' for subject in subjects_ids]
df.to_csv(save_file)
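
A quick usage sketch of the now-shared helper (the subject ID is illustrative):

```python
from meld_graph.tools_pipeline import create_dataset_file

# A single string is wrapped into a list; the CSV gains "subject_id" and
# "split" columns, with every row marked "test".
create_dataset_file("MELD_TEST_3T_FCD_0002", "/tmp/dataset_test.csv")
```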
9 changes: 1 addition & 8 deletions scripts/new_patient_pipeline/run_script_prediction.py
@@ -31,15 +31,8 @@
from scripts.manage_results.register_back_to_xhemi import register_subject_to_xhemi
from scripts.manage_results.move_predictions_to_mgh import move_predictions_to_mgh
from scripts.manage_results.plot_prediction_report import generate_prediction_report
from meld_graph.tools_pipeline import get_m, create_demographic_file
from meld_graph.tools_pipeline import get_m, create_demographic_file, create_dataset_file

def create_dataset_file(subjects_ids, save_file):
df=pd.DataFrame()
if isinstance(subjects_ids, str):
subjects_ids=[subjects_ids]
df['subject_id']=subjects_ids
df['split']=['test' for subject in subjects_ids]
df.to_csv(save_file)

def predict_subjects(subject_ids, output_dir, plot_images = False, saliency=False,
experiment_path=EXPERIMENT_PATH, hdf5_file_root= DEFAULT_HDF5_FILE_ROOT,):
10 changes: 1 addition & 9 deletions scripts/new_patient_pipeline/run_script_preprocessing.py
@@ -17,7 +17,7 @@
from os.path import join as opj
from meld_graph.meld_cohort import MeldCohort
from meld_graph.data_preprocessing import Preprocess, Feature
from meld_graph.tools_pipeline import get_m, create_demographic_file
from meld_graph.tools_pipeline import get_m, create_demographic_file, create_dataset_file
from meld_graph.paths import (
BASE_PATH,
MELD_PARAMS_PATH,
@@ -28,14 +28,6 @@
)


def create_dataset_file(subjects_ids, save_file):
df=pd.DataFrame()
if isinstance(subjects_ids, str):
subjects_ids=[subjects_ids]
df['subject_id']=subjects_ids
df['split']=['test' for subject in subjects_ids]
df.to_csv(save_file)

def which_combat_file(harmo_code):
file_site=os.path.join(BASE_PATH, f'MELD_{harmo_code}', f'{harmo_code}_combat_parameters.hdf5')
if os.path.isfile(file_site):
8 changes: 0 additions & 8 deletions scripts/new_patient_pipeline/run_script_segmentation.py
@@ -224,14 +224,6 @@ def extract_features(subject_id, fs_folder, output_dir, verbose=False):
result = create_training_data_hdf5(subject_id, fs_folder, output_dir )
if result == False:
return False

def create_dataset_file(subjects, output_path):
df=pd.DataFrame()
subjects_id = [subject for subject in subjects]
df['subject_id']=subjects_id
df['split']=['test' for subject in subjects]
df.to_csv(output_path)


def run_subjects_segmentation_parallel(subject_ids, num_procs=10, harmo_code="noHarmo", use_fastsurfer=False, verbose=False):
# parallel version of the pipeline, finish each stage for all subjects first
