Skip to content

Commit

Permalink
Fix Preprocessing OOM (#157)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jeffquinn-msk authored Nov 25, 2024
1 parent 42284c1 commit b8fec05
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 27 deletions.
3 changes: 1 addition & 2 deletions nextflow.config
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ profiles {
}
}
iris {
- apptainer {
+ podman {
enabled = true
- autoMounts = true
}
process {
executor = 'slurm'
Expand Down
6 changes: 5 additions & 1 deletion nextflow/config/base.config
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ process {
withLabel:gpu {
accelerator = 1
clusterOptions = { getGpuClusterOptions( task.executor ) }
- containerOptions = { workflow.containerEngine == "singularity" ? '--nv': ( workflow.containerEngine == "docker" ? '--gpus all': null ) }
+ containerOptions = {
+ workflow.containerEngine == "singularity" ?
+ '--nv': ( workflow.containerEngine == "docker" ?
+ '--gpus all': ( workflow.containerEngine == "podman" ?
+ '--device nvidia.com/gpu=all' : null ) ) }
}
}
35 changes: 11 additions & 24 deletions src/nuc2seg/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,55 +82,38 @@ def create_rasterized_dataset(
)
labels_geo_df.rename(columns={"index_right": "nucleus_id_xenium"}, inplace=True)

# Calculate the nearest transcript neighbors
logger.info("Calculating the nearest transcript neighbors")
transcript_xy = np.array(
[tx_geo_df["x_location"].values, tx_geo_df["y_location"].values]
).T
kdtree = KDTree(transcript_xy)

# Get the distance to the k'th nearest transcript
logger.info("Get the distance to the k'th nearest transcript")
pixels_xy = np.array([labels_geo_df["X"].values, labels_geo_df["Y"].values]).T
labels_geo_df["transcript_distance"] = kdtree.query(
pixels_xy, k=background_pixel_transcripts + 1
)[0][:, -1]

# Assign pixels roughly on top of nuclei to belong to that nuclei label
logger.info("Assign pixels roughly on top of nuclei to belong to that nuclei label")
pixel_labels = np.zeros(labels_geo_df.shape[0], dtype=int) - 1
nucleus_pixels = labels_geo_df["nucleus_distance"] <= foreground_nucleus_distance
pixel_labels[nucleus_pixels] = labels_geo_df["nucleus_label"][nucleus_pixels]

# Assign pixels to the background if they are far from nuclei and not near a dense region of transcripts
logger.info(
"Assign pixels to the background if they are far from nuclei and not near a dense region of transcripts"
)
background_pixels = (
labels_geo_df["nucleus_distance"] > background_nucleus_distance
) & (labels_geo_df["transcript_distance"] > background_transcript_distance)
pixel_labels[background_pixels] = 0

# Convert back over to the grid format
logger.info("Convert pixel labels to a grid")
labels = np.zeros((x_size, y_size), dtype=int)
labels[labels_geo_df["X"] - x_min, labels_geo_df["Y"] - y_min] = pixel_labels

# Create a nuclei x gene count matrix
tx_nuclei_geo_df = gpd.sjoin_nearest(
tx_geo_df, nuclei_geo_df, distance_col="nucleus_distance"
)
nuclei_count_geo_df = tx_nuclei_geo_df[
tx_nuclei_geo_df["nucleus_distance"] <= foreground_nucleus_distance
]

# I think we have enough memory to just store this as a dense array
nuclei_count_matrix = np.zeros((nuclei_geo_df.shape[0] + 1, n_genes), dtype=int)
np.add.at(
nuclei_count_matrix,
(
nuclei_count_geo_df["nucleus_label"].values.astype(int),
nuclei_count_geo_df["gene_id"].values.astype(int),
),
1,
)

# Assume for simplicity that it's a homogeneous poisson process for transcripts.
# Add up all the transcripts in each pixel.
logger.info("Add up all transcripts in each pixel")
tx_count_grid = np.zeros((x_size, y_size), dtype=int)
np.add.at(
tx_count_grid,
Expand All @@ -142,6 +125,7 @@ def create_rasterized_dataset(
)

# Estimate the background rate
logger.info("Estimating background rate")
tx_background_mask = (
labels[
tx_geo_df["x_location"].values.astype(int) - x_min,
Expand All @@ -156,6 +140,9 @@ def create_rasterized_dataset(

# Calculate the angle at which each pixel faces to point at its nearest nucleus centroid.
# Normalize it to be in [0,1]
logger.info(
"Calculating the angle at which each pixel faces to point at its nearest nucleus centroid"
)
labels_geo_df["nucleus_angle"] = (
cart2pol(
labels_geo_df["nucleus_centroid_x"].values - labels_geo_df["X"].values,
Expand Down

0 comments on commit b8fec05

Please sign in to comment.