Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3D Feature Extraction debugging #122

Merged
merged 11 commits into from
Dec 10, 2024
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ install_requires =
scikit-learn
scipy
tqdm
vtk
vtk==9.3.1
python_requires = >=3.8
package_dir =
= src
Expand All @@ -56,7 +56,7 @@ where = src
[options.extras_require]
fractal-tasks =
anndata
fractal-tasks-core==1.2.1
fractal-tasks-core==1.3.3
plotting =
anndata
dask
Expand Down
6 changes: 6 additions & 0 deletions src/scmultiplex/__FRACTAL_MANIFEST__.json
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,12 @@
"title": "Expand By Factor",
"type": "number",
"description": "Multiplier that specifies pixels by which to expand each label. Float in range [0, 1 or higher], e.g. 0.2 means that 20% of mean equivalent diameter of labels in region is used."
},
"mask_expansion_by_parent": {
"default": false,
"title": "Mask Expansion By Parent",
"type": "boolean",
"description": "If True, final expanded labels are masked by group_by object. Recommended to set to True for child/parent masking."
}
},
"required": [
Expand Down
7 changes: 7 additions & 0 deletions src/scmultiplex/fractal/expand_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def expand_labels(
expand_by_pixels: Union[int, None] = None,
calculate_image_based_expansion_distance: bool = False,
expand_by_factor: Union[float, None] = None,
mask_expansion_by_parent: bool = False,
) -> dict[str, Any]:
"""
Expand labels in 2D or 3D image without overlap.
Expand All @@ -72,6 +73,8 @@ def expand_labels(
be supplied.
expand_by_factor: Multiplier that specifies pixels by which to expand each label. Float in range
[0, 1 or higher], e.g. 0.2 means that 20% of mean equivalent diameter of labels in region is used.
mask_expansion_by_parent: If True, final expanded labels are masked by group_by object. Recommended to set
to True for child/parent masking.
"""

logger.info(
Expand Down Expand Up @@ -243,6 +246,10 @@ def expand_labels(
expandby,
expansion_distance_image_based=calculate_image_based_expansion_distance,
)

if mask_expansion_by_parent and group_by is not None:
seg_expanded = seg_expanded * parent_mask

logger.info(f"Expanded label(s) in region {label_str} by {distance} pixels.")

##############
Expand Down
55 changes: 48 additions & 7 deletions src/scmultiplex/fractal/scmultiplex_feature_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

__OME_NGFF_VERSION__ = fractal_tasks_core.__OME_NGFF_VERSION__

from scmultiplex.meshing.FilterFunctions import mask_by_parent_object

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -196,6 +197,27 @@ def scmultiplex_feature_measurements( # noqa: C901
f"the label image {label_image}"
)

# If relevant, load parent object segmentation to mask child objects
if use_ROI_masks:
# TODO: Abstract this operation using ngio. It currently loads the
# masking label image based on the masking roi table metadata
roi_table_path = f"{zarr_url}/tables/{input_ROI_table}"
with zarr.open(roi_table_path, mode="r") as zarr_store:
attrs = zarr_store.attrs
masking_label_url = (
f"{zarr_url}/tables/{dict(attrs)['region']['path']}/{label_level}"
)
# Load well image as dask array for parent objects
mask_dask = da.from_zarr(masking_label_url)

# Upscale masking image to target resolution
mask_dask = upscale_array(
array=mask_dask,
target_shape=target_shape,
axis=axis,
pad_with_zeros=True,
)

# Loop over ROIs to make measurements
df_well = pd.DataFrame()
df_info_well = pd.DataFrame()
Expand All @@ -205,8 +227,9 @@ def scmultiplex_feature_measurements( # noqa: C901

logger.debug(f"ROI {i_ROI+1}/{num_ROIs}: {region=}")

# Define some constant values to be added as a separat column to
# Define some constant values to be added as a separate column to
# the obs table
# TODO: consider changing "ROI_name" to "ROI_index"
extra_values = {
"ROI_table_name": input_ROI_table,
"ROI_name": ROI_table.obs.index[i_ROI],
Expand All @@ -218,10 +241,26 @@ def scmultiplex_feature_measurements( # noqa: C901

label_img = input_label_image[region].compute()
if use_ROI_masks:
current_label = int(ROI_table.obs.iloc[i_ROI]["label"])
background = label_img != current_label
label_img[background] = 0
current_label = int(float(ROI_table.obs.iloc[i_ROI]["label"]))
extra_values["ROI_label"] = current_label
# For feature extraction of child objects (e.g. nuclei) masked
# by parent (e.g. organoid), mask by parent image
label_img, parent_mask = mask_by_parent_object(
label_img, mask_dask, list_indices, i_ROI, current_label
)
# Only proceed if labelmap is not empty
if np.amax(label_img) == 0:
logger.warning(
f"Skipping region label {current_label}. Label image "
"contains no labeled objects."
)
# Skip this object
continue
else:
logger.info(
f"Calculating features for {label_image} object(s) masked "
f"by region label {current_label}"
)

if label_img.shape[0] == 1:
logger.debug("Label image is 2D only, processing with 2D options")
Expand All @@ -239,7 +278,6 @@ def scmultiplex_feature_measurements( # noqa: C901
f"Loaded an image of shape {label_img.shape}. "
"Processing is only supported for 2D & 3D images"
)

# Set inputs
df_roi = pd.DataFrame()
df_info_roi = pd.DataFrame()
Expand Down Expand Up @@ -273,7 +311,6 @@ def scmultiplex_feature_measurements( # noqa: C901
channel_prefix=input_name,
extra_values=extra_values,
)

# Only measure morphology for the first intensity channel provided
# => just once per label image
first_channel = False
Expand Down Expand Up @@ -331,7 +368,7 @@ def scmultiplex_feature_measurements( # noqa: C901
df_well = df_well.astype(measurement_dtype)
df_well.index = df_well.index.map(str)
# Convert to anndata
measurement_table = ad.AnnData(df_well, dtype=measurement_dtype)
measurement_table = ad.AnnData(df_well)
measurement_table.obs = df_info_well
else:
# Create empty anndata table
Expand All @@ -352,6 +389,10 @@ def scmultiplex_feature_measurements( # noqa: C901
),
)

logger.info(f"End feature_measurement task for {zarr_url}/labels/{label_image}")

return {}


if __name__ == "__main__":
from fractal_tasks_core.tasks._utils import run_fractal_task
Expand Down
2 changes: 1 addition & 1 deletion tests/fractal/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def linking_zenodo_zarrs(testdata_path: Path) -> list[str]:
"""

# 1 Download Zarrs from Zenodo
DOI = "10.5281/zenodo.10683086"
DOI = "10.5281/zenodo.13982701"
DOI_slug = DOI.replace("/", "_").replace(".", "_")
platenames = [
"220605_151046.zarr",
Expand Down
68 changes: 34 additions & 34 deletions tests/fractal/integration/test_fractal_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
name_mip = "220605_151046_mip.zarr"

test_calculate_object_linking_expected_output = np.array(
[[1.0, 1.0, 0.9305816], [2.0, 2.0, 0.78365386], [3.0, 3.0, 0.95049506]]
[[1.0, 1.0, 0.9727956], [2.0, 2.0, 0.8249731], [3.0, 3.0, 0.9644809]]
)

test_calculate_linking_consensus_expected_output = np.array(
Expand All @@ -45,59 +45,60 @@
test_relabel_by_linking_consensus_output_dict = {
"0": np.array(
[
[13.0, 20.8, 0.0, 22.533333, 21.666666, 0.6],
[13.0, 90.13333, 0.0, 20.8, 18.2, 0.6],
[38.133335, 103.13333, 0.0, 25.133333, 29.466667, 0.6],
[13.216666, 21.016666, 0.0, 22.1, 20.583334, 0.6],
[13.866667, 91.433334, 0.0, 18.85, 15.816667, 0.6],
[39.0, 103.566666, 0.0, 23.183332, 28.166666, 0.6],
]
),
"1": np.array(
[
[49.4, 10.4, 0.0, 23.4, 22.533333, 0.6],
[51.133335, 81.46667, 0.0, 19.933332, 16.466667, 0.6],
[74.53333, 92.73333, 0.0, 24.266666, 29.466667, 0.6],
[49.616665, 11.05, 0.0, 22.533333, 20.583334, 0.6],
[51.783333, 81.9, 0.0, 18.85, 15.816667, 0.6],
[75.4, 93.816666, 0.0, 23.4, 27.95, 0.6],
]
),
}

test_calculate_platymatch_registration_output = np.array(
[
[1.0, 1.0],
[3.0, 2.0],
[2.0, 3.0],
[2.0, 2.0],
[3.0, 3.0],
[4.0, 4.0],
[5.0, 5.0],
[6.0, 6.0],
[7.0, 7.0],
[8.0, 8.0],
[9.0, 9.0],
[10.0, 10.0],
[11.0, 11.0],
[12.0, 12.0],
[13.0, 13.0],
[14.0, 14.0],
[15.0, 15.0],
[16.0, 16.0],
[8.0, 7.0],
[9.0, 8.0],
[10.0, 9.0],
[11.0, 10.0],
[13.0, 12.0],
[14.0, 13.0],
[15.0, 14.0],
[16.0, 15.0],
[17.0, 16.0],
[19.0, 17.0],
[18.0, 18.0],
[20.0, 19.0],
]
)

test_sphr_harmonics_from_labelimg_expected_output = np.array(
[10.87666, 9.29089, 13.33493]
[10.86649, 9.2556, 13.3412]
)

test_sphr_harmonics_from_mesh_expected_output = np.array(
[10.116642, 8.294999, 12.791296]
[10.115658, 8.254395, 12.78826]
)

test_scmultiplex_mesh_measurements_expected_output = np.array(
[
4.31434180e03,
1.30477405e03,
1.01806331e00,
4.67207789e-01,
9.94542837e-01,
5.45713631e-03,
1.40581485e-02,
1.12836015e00,
1.00899124e00,
4.3092324e03,
1.3029731e03,
1.0174618e00,
4.6224737e-01,
9.9401093e-01,
5.9890612e-03,
9.5799491e-03,
1.1295289e00,
1.0086931e00,
]
)

Expand Down Expand Up @@ -142,7 +143,6 @@ def test_calculate_object_linking(linking_zenodo_zarrs, name=name_mip):
output_table_path = f"{zarr_url}/tables/{label_name}_match_table"

output = ad.read_zarr(output_table_path).to_df().to_numpy()

assert_almost_equal(output, test_calculate_object_linking_expected_output)


Expand Down Expand Up @@ -231,7 +231,7 @@ def test_calculate_platymatch_registration(linking_zenodo_zarrs, name=name_3d):
calculate_ffd=True,
seg_channel=channel,
volume_filter=True,
volume_filter_threshold=0.05,
volume_filter_threshold=0.10,
)

output_table_path_affine = (
Expand Down
45 changes: 44 additions & 1 deletion tests/fractal/test_fractal_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,20 @@ def test_3D_fractal_measurements(


inputs_masked = [{}, single_input_channels]
inputs_masked = [single_input_channels]


@pytest.mark.filterwarnings("ignore:Transforming to str index.")
@pytest.mark.parametrize("input_channels,", inputs_masked)
def test_masked_measurements(
def test_masked_measurements_test(
tiny_zenodo_zarrs_base_path,
metadata_tiny_zenodo,
column_names,
input_channels,
):
# FIXME: Get a smaller test dataset to test this on. This test takes ~40s.
# Criteria: Has masking ROI table, resolution of label image != resolution
# of intensity image.
# Test measuring when using a ROI table with masks
allow_duplicate_labels = False
zarr_url = f"{tiny_zenodo_zarrs_base_path}/{image_path_2D}"
Expand Down Expand Up @@ -287,6 +291,45 @@ def test_masked_measurements(
assert_frame_equal(df, df_expected)


@pytest.mark.filterwarnings("ignore:Transforming to str index.")
def test_masked_measurements_with_orgs_and_nuc(
    linking_zenodo_zarrs,
):
    """
    Run masked feature measurements where the masking ROI table belongs to a
    different label image than the labels being measured.

    The task is run on the "nuc" label image using "org_ROI_table" as the
    input ROI table, and the resulting feature table is expected to contain
    one row per nucleus. See
    https://github.com/fmi-basel/gliberal-scMultipleX/pull/122 for details.
    """
    # First downloaded plate, well C/02, image 0
    well_image_url = f"{linking_zenodo_zarrs[0]}/C/02/0"
    table_name = "measurements_nuc_masked"

    # Run the fractal task; it writes its result table into the zarr
    scmultiplex_feature_measurements(
        zarr_url=well_image_url,
        input_ROI_table="org_ROI_table",
        input_channels={"C01": ChannelInputModel(wavelength_id="A04_C01")},
        label_image="nuc",
        label_level=label_level,
        level=level,
        output_table_name=table_name,
        measure_morphology=True,
        allow_duplicate_labels=False,
    )

    # Verify the output table: all 20 nuclei must have measurements
    # (before #122, only a single measurement was produced)
    features = load_features_for_well(
        Path(well_image_url) / "tables" / table_name
    )
    assert len(features) == 20


inputs_empty = [
({}, True),
({}, False),
Expand Down
Loading