Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Baysor Benchmark Preprocessing/Postprocessing Tiling #46

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions nextflow/modules/nf-core/baysor_postprocess/main.nf
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Runs the `baysor_postprocess` command on the Baysor shapefiles produced for
// one sample and emits a single segmentation parquet file.
process BAYSOR_POSTPROCESS {
// Tag each task with the sample id so it is identifiable in the run log.
tag "$meta.id"
label 'process_low'
// Use the docker:// URI when the engine is singularity and
// task.ext.singularity_pull_docker_container is not set; otherwise use the
// docker.io image reference.
container "${ workflow.containerEngine == 'singularity' && !task.ext.singularity_pull_docker_container ?
'docker://jeffquinnmsk/nuc2seg:latest' :
'docker.io/jeffquinnmsk/nuc2seg:latest' }"

input:
// meta: sample metadata map (its `id` is read below)
// xenium_dir: directory expected to contain transcripts.parquet
// shapefiles: the Baysor shapefile path(s) handed to --baysor-shapefiles
tuple val(meta), path(xenium_dir), path(shapefiles)

output:
// The parquet file written by the script below, emitted together with meta.
tuple val(meta), path("${prefix}/baysor/segmentation.parquet"), emit: segmentation


script:
// `prefix` is declared without `def`; the output block above also interpolates it.
prefix = task.ext.prefix ?: "${meta.id}"
// Extra command-line arguments, taken from task.ext.args when provided.
def args = task.ext.args ?: ""
"""
mkdir -p "${prefix}/baysor"
baysor_postprocess \
--transcripts ${xenium_dir}/transcripts.parquet \
--output ${prefix}/baysor/segmentation.parquet \
--baysor-shapefiles ${shapefiles} \
${args}
"""
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@ process BAYSOR_PREPROCESS_TRANSCRIPTS {
tuple val(meta), path(xenium_dir)

output:
tuple val(meta), path("${prefix}/baysor/baysor_transcripts.csv"), emit: baysor_transcripts
tuple val(meta), path("${prefix}/baysor/input/*.csv"), emit: transcripts


script:
prefix = task.ext.prefix ?: "${meta.id}"
def args = task.ext.args ?: ""
def sample_area_flag = params.sample_area == null ? "" : "--sample-area ${params.sample_area}"
"""
mkdir -p "${prefix}/baysor"
mkdir -p "${prefix}/baysor/input"
baysor_preprocess_transcripts \
--transcripts ${xenium_dir}/transcripts.parquet \
--output-path ${prefix}/baysor/baysor_transcripts.csv \
--output-dir ${prefix}/baysor/input \
${args}
"""
}
8 changes: 7 additions & 1 deletion nextflow/workflows/nf-core/baysor/main.nf
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
include { BAYSOR } from '../../../modules/nf-core/baysor/main'
include { BAYSOR_PREPROCESS_TRANSCRIPTS } from '../../../modules/nf-core/baysor_preprocess_transcripts/main'
include { BAYSOR_POSTPROCESS } from '../../../modules/nf-core/baysor_postprocess/main'


workflow BAYSOR_SEGMENTATION {
def name = params.name == null ? "nuc2seg" : params.name
Expand All @@ -24,9 +26,13 @@ workflow BAYSOR_SEGMENTATION {
.tap { baysor_param_sweep }


BAYSOR_PREPROCESS_TRANSCRIPTS.out.baysor_transcripts
BAYSOR_PREPROCESS_TRANSCRIPTS.out.transcripts
.combine( baysor_param_sweep )
.tap { baysor_input }

BAYSOR( baysor_input )

ch_input.join(BAYSOR.out.shapes.groupTuple()).tap { postprocess_input }

BAYSOR_POSTPROCESS( postprocess_input )
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ segment = "nuc2seg.cli.segment:main"
celltyping = "nuc2seg.cli.celltyping:main"
autofluorescence = "nuc2seg.cli.autofluorescence_benchmark:main"
baysor_preprocess_transcripts = "nuc2seg.cli.baysor_preprocess_transcripts:main"

baysor_postprocess = "nuc2seg.cli.baysor_postprocess:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
Expand Down
97 changes: 97 additions & 0 deletions src/nuc2seg/cli/baysor_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
import logging
import geopandas as gpd
import pandas
import math

from nuc2seg import log_config
from nuc2seg.postprocess import stitch_shapes

logger = logging.getLogger(__name__)


def get_parser():
    """Build the argument parser for the ``baysor_postprocess`` command.

    Returns:
        argparse.ArgumentParser: parser with the logging options added by
        ``log_config.add_logging_args`` plus the shapefile, transcript,
        output and tiling options defined below.
    """
    parser = argparse.ArgumentParser(
        # The previous description was copied from the autofluorescence
        # benchmark CLI and did not describe this command.
        description="Stitch the per-tile shapefiles output by Baysor into a single segmentation file."
    )
    log_config.add_logging_args(parser)
    parser.add_argument(
        "--baysor-shapefiles",
        help="One or more shapefiles output by baysor.",
        type=str,
        required=True,
        nargs="+",
    )

    parser.add_argument(
        "--transcripts",
        help="Xenium transcripts in parquet format.",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--output",
        required=True,
        type=str,
        help="Output file.",
    )
    # NOTE(review): main() in this file does not read args.sample_area, so this
    # option currently has no effect on the output -- confirm whether it
    # should be applied or dropped.
    parser.add_argument(
        "--sample-area",
        default=None,
        type=str,
        help='Crop the dataset to this rectangle, provided in "x1,y1,x2,y2" format.',
    )
    parser.add_argument(
        "--tile-height",
        help="Height of the tiles.",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--tile-width",
        help="Width of the tiles.",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--overlap-percentage",
        help="What percent of each tile dimension overlaps with the next tile.",
        type=float,
        default=0.5,
    )

    return parser


def get_args():
    """Parse the process command line with the parser from ``get_parser``.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    return get_parser().parse_args()


def main():
    """Entry point: stitch tiled Baysor shapefiles and write them as parquet.

    Reads the transcript coordinates to determine the extent of the sample,
    loads the shapefiles in tile order, passes them to ``stitch_shapes`` and
    writes the result to ``args.output``.
    """
    args = get_args()

    log_config.configure_logging(args)

    # Only the coordinate columns are needed to compute the extent, so avoid
    # loading every other transcript column into memory.
    transcript_df = pandas.read_parquet(
        args.transcripts, columns=["x_location", "y_location"]
    )

    x_extent = math.ceil(transcript_df["x_location"].astype(float).max())
    y_extent = math.ceil(transcript_df["y_location"].astype(float).max())

    # Order the files by the integer that follows the last underscore in the
    # path (e.g. "..._12.ext" -> 12) so they are passed to stitch_shapes in
    # ascending tile order.
    shapefiles = sorted(
        args.baysor_shapefiles, key=lambda x: int(x.split("_")[-1].split(".")[0])
    )

    gdfs = [gpd.read_file(shapefile) for shapefile in shapefiles]

    # NOTE(review): tile_size is passed as (width, height) here, while the
    # baysor_preprocess_transcripts CLI passes (height, width) to its tiler.
    # The two agree only for square tiles -- confirm the intended ordering.
    stitched_shapes = stitch_shapes(
        shapes=gdfs,
        tile_size=(args.tile_width, args.tile_height),
        base_size=(x_extent, y_extent),
        overlap=args.overlap_percentage,
    )

    stitched_shapes.to_parquet(args.output)
42 changes: 30 additions & 12 deletions src/nuc2seg/cli/baysor_preprocess_transcripts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import argparse
import logging
import pandas
import os.path

from nuc2seg import log_config
from nuc2seg.xenium import (
load_and_filter_transcripts,
create_shapely_rectangle,
)
from nuc2seg.preprocessing import tile_transcripts_to_csv

logger = logging.getLogger(__name__)

Expand All @@ -18,8 +17,8 @@ def get_parser():
)
log_config.add_logging_args(parser)
parser.add_argument(
"--output-path",
help="Destination for transcript CSV.",
"--output-dir",
help="Directory to save baysor input CSVs.",
type=str,
required=True,
)
Expand All @@ -41,6 +40,24 @@ def get_parser():
type=float,
default=20.0,
)
parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=1000,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=1000,
)
parser.add_argument(
"--overlap-percentage",
help="What percent of each tile dimension overlaps with the next tile.",
type=float,
default=0.5,
)

return parser

Expand All @@ -64,23 +81,24 @@ def main():
sample_area = create_shapely_rectangle(
*[float(x) for x in args.sample_area.split(",")]
)

else:
df = pandas.read_parquet(args.transcripts)
y_max = df["y_location"].max()
x_max = df["x_location"].max()

sample_area = create_shapely_rectangle(0, 0, x_max, y_max)
sample_area = None

transcripts = load_and_filter_transcripts(
transcripts_file=args.transcripts,
sample_area=sample_area,
min_qv=args.min_qv,
)

mask = (transcripts["cell_id"] > 0) & (transcripts["overlaps_nucleus"].astype(bool))

transcripts["nucleus_id"] = 0
transcripts.loc[mask, "nucleus_id"] = transcripts["cell_id"][mask]

logger.info(f"Writing CSV to {args.output_path}")
transcripts.to_csv(os.path.join(args.output_path), index=False)
logger.info(f"Writing CSVs to {args.output_path}")
tile_transcripts_to_csv(
transcripts=transcripts,
tile_size=(args.tile_height, args.tile_width),
overlap=args.overlap_percentage,
output_dir=args.output_dir,
)
8 changes: 1 addition & 7 deletions src/nuc2seg/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,7 @@ def generate_tiles(
tiler: TilingModule, x_extent, y_extent, tile_size, overlap_fraction, tile_ids=None
):
"""
A generator function to yield overlapping tiles from a 2D NumPy array (image).

Parameters:
- image: 2D NumPy array representing the image.
- tile_size: Tuple of (tile_height, tile_width), the size of each tile.
- overlap_fraction: Fraction of overlap between tiles (0 to 1).
- tile_ids: List of tile IDs to generate. If None, all tiles are generated.
A generator function to yield overlapping tiles

Yields:
- BBox extent in pixels for each tile (non inclusive end) x1, y1, x2, y2
Expand Down
51 changes: 51 additions & 0 deletions src/nuc2seg/postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import geopandas as gpd
import numpy as np
import pandas as pd

from nuc2seg.data import generate_tiles
from blended_tiling import TilingModule
from shapely import box


def stitch_shapes(shapes: list[gpd.GeoDataFrame], tile_size, base_size, overlap):
    """Combine per-tile shapes into one GeoDataFrame.

    For each tile, only the shapes lying entirely within that tile's selection
    box are kept. The box is the bounding box of the tile-mask entries that are
    not below 1, shifted by the tile's origin.

    Parameters:
    - shapes: One GeoDataFrame per tile, in the order that ``generate_tiles``
      yields the tiles.
    - tile_size: Two-element tile size, passed to ``TilingModule`` and
      ``generate_tiles``.
    - base_size: Two-element size of the full area; element 0 is used as the
      x extent and element 1 as the y extent.
    - overlap: Overlap fraction, applied to both tile dimensions.

    Returns:
    - GeoDataFrame with the retained rows of every tile concatenated under a
      fresh index; an empty GeoDataFrame when no tiles were processed.
    """
    tiler = TilingModule(
        tile_size=tile_size,
        tile_overlap=(overlap, overlap),
        base_size=base_size,
    )

    # Keep only the first channel of each tile mask.
    tile_masks = tiler.get_tile_masks()[:, 0, :, :]

    bboxes = generate_tiles(
        tiler,
        x_extent=base_size[0],
        y_extent=base_size[1],
        tile_size=tile_size,
        overlap_fraction=overlap,
    )

    kept_shapes = []
    # zip stops at the shortest input, so surplus masks, shapes or bounding
    # boxes are ignored. The loop variable is deliberately not named `shapes`
    # so that it does not shadow the parameter.
    for mask, tile_shapes, bbox in zip(tile_masks, shapes, bboxes):
        mask = mask.detach().cpu().numpy()
        mask = ~(mask < 1).astype(bool)

        # Bounding indices of the mask entries that are not below 1.
        x, y = np.where(mask)
        x_min, x_max = x.min(), x.max()
        y_min, y_max = y.min(), y.max()

        offset_x = bbox[0]
        offset_y = bbox[1]

        # The +1 makes the upper edge cover the last selected index.
        selection_box = box(
            x_min + offset_x,
            y_min + offset_y,
            x_max + offset_x + 1,
            y_max + offset_y + 1,
        )

        # Keep only the rows whose geometry lies within the selection box.
        selection_vector = tile_shapes.geometry.within(selection_box)
        kept_shapes.append(tile_shapes[selection_vector])

    if not kept_shapes:
        # pd.concat raises ValueError on an empty list; return an empty result.
        return gpd.GeoDataFrame(geometry=[])

    return gpd.GeoDataFrame(pd.concat(kept_shapes, ignore_index=True))
38 changes: 38 additions & 0 deletions src/nuc2seg/postprocess_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from nuc2seg.postprocess import stitch_shapes
from shapely import box
import geopandas as gpd


def test_stitch_shapes():
    """Check that stitch_shapes keeps a shape only when it is in its tile's box.

    Nine GeoDataFrames are passed for a (20, 20) area with (10, 10) tiles and
    0.5 overlap.
    """
    upper_left = gpd.GeoDataFrame({"geometry": [box(3, 3, 4, 4)]})
    bottom_right = gpd.GeoDataFrame({"geometry": [box(18, 18, 19, 19)]})
    empty = gpd.GeoDataFrame({"geometry": []})

    def blanks(count):
        # Independent copies of the empty frame.
        return [empty.copy() for _ in range(count)]

    # Shapes supplied in the first and last positions are both retained.
    aligned = [upper_left, *blanks(7), bottom_right]
    stitched = stitch_shapes(aligned, (10, 10), (20, 20), 0.5)

    assert len(stitched) == 2

    # With the upper-left shape moved to the second position, only one shape
    # is retained.
    misplaced = [empty.copy(), upper_left, *blanks(6), bottom_right]
    stitched = stitch_shapes(misplaced, (10, 10), (20, 20), 0.5)
    assert len(stitched) == 1
Loading
Loading