Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
acycliq committed Nov 1, 2024
1 parent 32bb257 commit 5a69137
Showing 1 changed file with 175 additions and 75 deletions.
250 changes: 175 additions & 75 deletions pciSeq/src/preprocess/spot_labels.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,219 @@
"""
Functions to prepare the data for pciSeq. The label image and spots are parsed and if a spot
lies within the cell boundaries then the corresponding cell id is recorded.
Cell centroids and cell areas are also calculated.

Cell and Spot Label Processing Module for pciSeq

This module provides core functionality for processing and analyzing spatial transcriptomics data,
specifically handling the relationship between cell segmentation and RNA spot detection.

Key Functions:
-------------
- inside_cell: Maps RNA spots to their containing cells
- reorder_labels: Normalizes cell labels to sequential integers
- stage_data: Main processing pipeline that integrates all functionality

Data Processing Steps:
--------------------
1. Cell Label Processing:
    - Validates and normalizes cell segmentation labels
    - Ensures sequential integer labeling
    - Computes cell properties (centroids, areas)

2. Spot Assignment:
    - Maps each RNA spot to its containing cell
    - Validates spot coordinates against image boundaries
    - Links spots with cell metadata

3. Boundary Processing:
    - Extracts and validates cell boundaries
    - Ensures consistency between properties and boundaries
"""

from dataclasses import dataclass


import numpy as np
import pandas as pd
from typing import Tuple, Union
import skimage.measure as skmeas
from typing import Tuple
from scipy.sparse import coo_matrix, csr_matrix
from pciSeq.src.preprocess.cell_borders import extract_borders_dip, extract_borders
from scipy.sparse import coo_matrix, csr_matrix, spmatrix
from pciSeq.src.preprocess.cell_borders import extract_borders
import logging


# Module-level logger, named after this module via __name__; used by the
# functions below to report progress and relabelling decisions.
spot_labels_logger = logging.getLogger(__name__)


def inside_cell(label_image: Union[spmatrix, np.ndarray],
                spots: pd.DataFrame) -> np.ndarray:
    """
    Determine which cell contains each spot.

    The label image is looked up at position (row=y, col=x) for every spot.

    Args:
        label_image: Cell segmentation mask. Accepted types are a dense
            ``np.ndarray``, a ``coo_matrix`` or a ``csr_matrix``; the first
            two are converted to ``csr_matrix`` because COO matrices do not
            support element indexing.
        spots: DataFrame with spot coordinates in columns 'x' and 'y'. The
            values are used directly as (column, row) indices into
            ``label_image``.

    Returns:
        1-D ``np.uint32`` array with one entry per spot holding the label
        found under that spot (0 where the label image stores no value).

    Raises:
        TypeError: If label_image is not of a supported type.
    """
    if isinstance(label_image, np.ndarray):
        label_image = csr_matrix(label_image)
    elif isinstance(label_image, coo_matrix):
        label_image = label_image.tocsr()
    elif not isinstance(label_image, csr_matrix):
        raise TypeError('label_image must be ndarray, coo_matrix, or csr_matrix')

    # Nothing to look up; also avoids indexing with the float64 columns that
    # pandas creates for an empty frame.
    if len(spots) == 0:
        return np.zeros(0, dtype=np.uint32)

    rows = np.asarray(spots.y)
    cols = np.asarray(spots.x)
    labels = label_image[rows, cols]

    # Fancy indexing on a csr_matrix yields a 1xN np.matrix; ravel() flattens
    # it without relying on the matrix-specific "[0] is the first row" quirk.
    return np.asarray(labels, dtype=np.uint32).ravel()

def remap_labels(coo: coo_matrix) -> coo_matrix:
    """
    Randomly reshuffle label assignments (for testing/debugging only).

    Every label in ``1..max(coo.data)`` is mapped to another label in the
    same range through a random permutation drawn with ``np.random.shuffle``
    (i.e. from NumPy's global random state). The sparsity pattern is kept.

    Args:
        coo: Input label matrix.

    Returns:
        ``coo_matrix`` of the same shape and sparsity pattern whose data
        (dtype ``np.uint64``) holds the remapped labels. Explicitly stored
        zeros stay zero.
    """
    if coo.data.size == 0:
        # No labels to permute; return an equivalent empty matrix.
        return coo_matrix(
            (coo.data.astype(np.uint64), (coo.row, coo.col)), shape=coo.shape
        )

    coo_max = int(coo.data.max())
    new_labels = 1 + np.arange(coo_max)
    np.random.shuffle(new_labels)

    # lookup[old_label] -> new_label; slot 0 keeps background at 0.
    lookup = np.zeros(coo_max + 1, dtype=np.uint64)
    lookup[1:] = new_labels
    new_data = lookup[coo.data.astype(np.int64)]

    return coo_matrix((new_data, (coo.row, coo.col)), shape=coo.shape)


def reorder_labels(coo: coo_matrix) -> Tuple[coo_matrix, pd.DataFrame]:
    """
    Normalize labels to be sequential integers starting from 1.

    Background (label 0) stays 0. The remaining labels are ranked in
    ascending order and replaced by 1, 2, 3, ...

    Args:
        coo: Sparse matrix containing cell labels. Labels are assumed to be
            non-negative.

    Returns:
        Tuple containing:
            - Relabeled sparse matrix (same shape as the input)
            - DataFrame with ``np.uint32`` columns 'old_label' and
              'new_label', sorted by 'old_label'. It contains a 0 -> 0 row
              whenever the image has background pixels.
    """
    label_image = coo.toarray()
    unique_labels, inverse = np.unique(label_image.ravel(), return_inverse=True)

    new_labels = np.arange(len(unique_labels))
    if len(unique_labels) > 0 and unique_labels[0] != 0:
        # No background pixel in the image: without this shift the smallest
        # real label would be mapped to 0 and silently turned into background.
        new_labels = new_labels + 1

    relabelled = new_labels[inverse].reshape(label_image.shape)

    # np.unique returns its values sorted, so the map is already ordered by
    # 'old_label'.
    label_map = pd.DataFrame({
        'old_label': unique_labels.astype(np.uint32),
        'new_label': new_labels.astype(np.uint32),
    })

    return coo_matrix(relabelled), label_map


def stage_data(spots: pd.DataFrame, coo: coo_matrix) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Process spot and cell segmentation data to establish spot-cell relationships.

    Finds which cell (if any) contains each spot and collects the centroid,
    area and boundary coordinates of every cell.

    Args:
        spots: DataFrame with columns ['x', 'y', 'Gene'].
        coo: Sparse matrix containing cell segmentation labels. If the labels
            are not the consecutive integers 1..N the image is relabelled
            with ``reorder_labels`` first.

    Returns:
        Tuple containing:
            - cells: DataFrame with cell properties ('label', 'area', 'x0',
              'y0', plus 'old_label' when the image was relabelled)
            - boundaries: DataFrame with columns 'cell_id' and 'coords'
            - spots: DataFrame with columns 'x_global', 'y_global', 'label',
              'target', 'x_cell', 'y_cell'; spots outside the image are
              dropped

    Raises:
        ValueError: If required columns are missing or data validation fails.
    """
    # Validate inputs
    required_columns = {'x', 'y', 'Gene'}
    missing_columns = required_columns - set(spots.columns)
    if missing_columns:
        raise ValueError(f"Missing required columns in spots DataFrame: {missing_columns}")
    if coo.data.size == 0:
        raise ValueError("Label image does not contain any segmented cells")

    unique_labels = np.unique(coo.data)
    max_label = coo.data.max()

    spot_labels_logger.info('Number of spots passed-in: %s', spots.shape[0])
    spot_labels_logger.info('Number of segmented cells: %s', len(unique_labels))
    spot_labels_logger.info(
        'Segmentation array implies that image has width: %spx and height: %spx',
        coo.shape[1], coo.shape[0])

    # Normalize labels if needed: labels 1..N have max == number of unique labels.
    label_map = None
    if max_label != len(unique_labels):
        spot_labels_logger.info(
            'Detected non-sequential cell labels: found %s unique labels '
            'with maximum value of %s. Normalizing labels to range [1, %s]',
            len(unique_labels), max_label, len(unique_labels))
        coo, label_map = reorder_labels(coo)

    # Filter spots to image bounds. The upper bounds are exclusive: valid
    # column indices are 0..width-1 and valid row indices are 0..height-1.
    in_bounds = (
        (spots.x >= 0) & (spots.x < coo.shape[1])
        & (spots.y >= 0) & (spots.y < coo.shape[0])
    )
    spots = spots[in_bounds].copy()

    # Assign spots to cells
    spots = spots.assign(label=inside_cell(coo, spots))

    # The dense label image is needed twice (properties and borders);
    # materialise it once.
    dense_labels = coo.toarray()

    # Calculate cell properties
    props = skmeas.regionprops_table(
        dense_labels.astype(np.int32),
        properties=['label', 'area', 'centroid']
    )
    props_df = pd.DataFrame(props).rename(
        columns={'centroid-0': 'y_cell', 'centroid-1': 'x_cell'}
    )

    # Apply label mapping if exists, so each cell keeps its original label
    # in the 'old_label' column.
    if label_map is not None:
        props_df = pd.merge(
            props_df,
            label_map,
            left_on='label',
            right_on='new_label',
            how='left'
        ).drop(['new_label'], axis=1)

    # Set datatypes
    props_df = props_df.astype({
        "label": np.uint32,
        "area": np.uint32,
        'y_cell': np.float32,
        'x_cell': np.float32
    })

    # Extract cell boundaries
    cell_boundaries = extract_borders(dense_labels.astype(np.uint32))

    # Validate cell data consistency
    if not (props_df.shape[0] == cell_boundaries.shape[0] == np.unique(coo.data).shape[0]):
        raise ValueError("Inconsistency detected between cell properties and boundaries")

    # Ensure spots are assigned to valid cells
    if not set(spots.label[spots.label > 0]).issubset(set(props_df.label)):
        raise ValueError("Spots assigned to non-existent cell labels")

    # Prepare final data structures
    cells = props_df.merge(cell_boundaries)
    if not (cells.shape[0] == cell_boundaries.shape[0] == props_df.shape[0]):
        raise ValueError("Merging cell properties with boundaries changed the number of cells")
    cells = cells.sort_values(by=['label', 'x_cell', 'y_cell'])

    # Join spots and cells on the cell label to attach the cell centroid
    # to every spot.
    spots = spots.merge(cells, how='left', on=['label'])

    return (
        cells.drop(columns=['coords']).rename(columns={
            'x_cell': 'x0',
            'y_cell': 'y0'
        }),
        cells[['label', 'coords']].rename(columns={
            'label': 'cell_id'
        }),
        spots[['x', 'y', 'label', 'Gene', 'x_cell', 'y_cell']].rename(columns={
            'Gene': 'target',
            'x': 'x_global',
            'y': 'y_global'
        })
    )

0 comments on commit 5a69137

Please sign in to comment.