diff --git a/pyproject.toml b/pyproject.toml
index 42dffd7..a10bfc4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,6 +76,7 @@ combine_segmentations = "nuc2seg.cli.combine_segmentations:main"
 combine_predictions = "nuc2seg.cli.combine_predictions:main"
 get_n_tiles = "nuc2seg.cli.get_n_tiles:main"
 segmented_xenium_to_anndata = "nuc2seg.cli.segmented_xenium_to_anndata:main"
+calculate_benchmarks = "nuc2seg.cli.calculate_benchmarks:main"
 
 [build-system]
 requires = ["setuptools>=43.0.0", "wheel"]
diff --git a/src/nuc2seg/cli/calculate_benchmarks.py b/src/nuc2seg/cli/calculate_benchmarks.py
index 49ddfc0..7801bef 100644
--- a/src/nuc2seg/cli/calculate_benchmarks.py
+++ b/src/nuc2seg/cli/calculate_benchmarks.py
@@ -2,10 +2,13 @@
 import logging
 import os
 
+import pandas
+import geopandas
+import tqdm
+
 from nuc2seg import log_config
 from nuc2seg.postprocess import (
-    calculate_average_intersection_over_union,
-    calculate_segmentation_jaccard_index,
+    calculate_benchmarks_with_nuclear_prior,
 )
 from nuc2seg.xenium import load_vertex_file, load_and_filter_transcripts_as_points
 
@@ -18,8 +21,8 @@ def get_parser():
     )
     log_config.add_logging_args(parser)
     parser.add_argument(
-        "--output-dir",
-        help="Directory for output tables and plots.",
+        "--output-file",
+        help="Output parquet file to save the results.",
         type=str,
         required=True,
     )
@@ -56,6 +59,14 @@ def get_parser():
         type=str,
         required=True,
     )
+
+    parser.add_argument(
+        "--xenium-cell-metadata",
+        help="Cell metadata.",
+        type=str,
+        required=True,
+    )
+    parser.add_argument("--chunk-size", help="Chunk size", type=int, default=5_000)
     return parser
 
 
@@ -63,34 +74,48 @@ def main():
     args = get_parser().parse_args()
 
     logger.info("Loading true boundaries.")
-    true_boundaries = load_vertex_file(args.true_boundaries)
+    true_gdf = load_vertex_file(args.true_boundaries)
 
     logger.info("Loading transcripts.")
     transcripts = load_and_filter_transcripts_as_points(args.transcripts)
 
-    segmentations = {}
-    ious = {}
-    jaccards = {}
+    nuclear_gdf = load_vertex_file(args.nuclei_boundaries)
 
-    os.makedirs(args.output_dir, exist_ok=True)
+    cell_metadata_df = pandas.read_parquet(args.xenium_cell_metadata)
+    # One point per metadata row: read centroids from the table being wrapped.
+    cell_metadata_gdf = geopandas.GeoDataFrame(
+        cell_metadata_df,
+        geometry=geopandas.points_from_xy(
+            cell_metadata_df["x_centroid"], cell_metadata_df["y_centroid"]
+        ),
+    )
+    true_gdf = geopandas.sjoin(true_gdf, cell_metadata_gdf, how="inner")
+
+    # Ensure the parent directory of --output-file exists before writing.
+    os.makedirs(os.path.dirname(os.path.abspath(args.output_file)), exist_ok=True)
+    dfs = []
 
     for method_name, seg_fn in zip(
         args.segmentation_method_names, args.segmentation_files
     ):
         logger.info(f"Loading segmentation from {seg_fn}.")
-        segmentation_shapes = load_vertex_file(seg_fn)
+        method_gdf = load_vertex_file(seg_fn)
+        segments_chunk_size = args.chunk_size
 
         logger.info(f"Calculating benchmarks for {method_name}.")
-        iou = calculate_average_intersection_over_union(
-            segmentation_shapes, true_boundaries
-        )
-        jaccard = calculate_segmentation_jaccard_index(
-            transcripts, segmentation_shapes, true_boundaries
-        )
-
-        segmentations[method_name] = segmentation_shapes
-        ious[method_name] = iou
-        jaccards[method_name] = jaccard
-
-        iou.to_parquet(f"{args.output_dir}/{method_name}_iou.parquet")
-        jaccard.to_parquet(f"{args.output_dir}/{method_name}_jaccard.parquet")
+
+        for i in tqdm.tqdm(range(0, len(method_gdf), segments_chunk_size)):
+            chunk = method_gdf[i : i + segments_chunk_size]
+
+            dfs.append(
+                calculate_benchmarks_with_nuclear_prior(
+                    true_segs=true_gdf,
+                    method_segs=chunk,
+                    nuclear_segs=nuclear_gdf,
+                    transcripts_gdf=transcripts,
+                )
+            )
+
+    df = pandas.concat(dfs)
+
+    df.to_parquet(args.output_file)
diff --git a/src/nuc2seg/postprocess.py b/src/nuc2seg/postprocess.py
index 4644641..9971066 100644
--- a/src/nuc2seg/postprocess.py
+++ b/src/nuc2seg/postprocess.py
@@ -664,6 +664,12 @@ def get_jaccard_truth_segment(row):
         right_on="method_segment_id",
     )
 
+    results = results.merge(
+        true_segs[["truth_segment_id", "segmentation_method"]],
+        left_on="truth_segment_id",
+        right_on="truth_segment_id",
+    )
+
     return results
 
 