diff --git a/pyproject.toml b/pyproject.toml
index 3b095a2..1290de0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,9 @@ dependencies = [
     "matplotlib",
     "autograd-minimize",
     "scikit-learn",
-    "pyarrow"
+    "pyarrow",
+    "h5py",
+    "blended-tiling"
 ]
 
 [project.optional-dependencies]
diff --git a/src/nuc2seg/celltyping.py b/src/nuc2seg/celltyping.py
new file mode 100644
index 0000000..5c3489b
--- /dev/null
+++ b/src/nuc2seg/celltyping.py
@@ -0,0 +1,155 @@
+import logging
+
+import numpy as np
+import tqdm
+from scipy.special import logsumexp, softmax
+
+logger = logging.getLogger(__name__)
+
+
+def aic_bic(gene_counts, expression_profiles, prior_probs):
+
+    n_components = expression_profiles.shape[0]
+    n_genes = expression_profiles.shape[1]
+    n_samples = gene_counts.shape[0]
+
+    dof = n_components * (n_genes - 1) + n_components - 1
+    log_likelihood = logsumexp(
+        np.log(prior_probs[None])
+        + (gene_counts[:, None] * np.log(expression_profiles[None])).sum(axis=2),
+        axis=1,
+    ).sum()
+    aic = -2 * log_likelihood + 2 * dof
+    bic = -2 * log_likelihood + dof * np.log(n_samples)
+
+    return aic, bic
+
+
+## TODO: Maybe replace this with a library that does the same thing
+def estimate_cell_types(
+    gene_counts,
+    min_components=2,
+    max_components=25,
+    max_em_steps=100,
+    tol=1e-4,
+    warm_start=False,
+):
+
+    n_nuclei, n_genes = gene_counts.shape
+
+    # Randomly initialize cell type profiles
+    cur_expression_profiles = np.random.dirichlet(np.ones(n_genes), size=min_components)
+
+    # Initialize probabilities to be uniform
+    cur_prior_probs = np.ones(min_components) / min_components
+
+    # No need to initialize cell type assignments
+    cur_cell_types = None
+
+    # Track BIC and AIC scores
+    aic_scores = np.zeros(max_components - min_components + 1)
+    bic_scores = np.zeros(max_components - min_components + 1)
+
+    # Iterate through every possible number of components
+    (
+        final_expression_profiles,
+        final_prior_probs,
+        final_cell_types,
+        final_expression_rates,
+    ) = ([], [], [], [])
+    prev_expression_profiles = np.zeros_like(cur_expression_profiles)
+    for idx, n_components in enumerate(
+        tqdm.trange(
+            min_components, max_components + 1, desc="estimate_cell_types", position=0
+        )
+    ):
+
+        # Warm start from the previous iteration
+        if warm_start and n_components > min_components:
+            # Expand and copy over the current parameters
+            next_prior_probs = np.zeros(n_components)
+            next_expression_profiles = np.zeros((n_components, n_genes))
+            next_prior_probs[: n_components - 1] = cur_prior_probs
+            next_expression_profiles[: n_components - 1] = cur_expression_profiles
+
+            # Split the dominant cluster
+            dominant = np.argmax(cur_prior_probs)
+            split_prob = np.random.random()
+            next_prior_probs[-1] = next_prior_probs[dominant] * split_prob
+            next_prior_probs[dominant] = next_prior_probs[dominant] * (1 - split_prob)
+            next_expression_profiles[-1] = cur_expression_profiles[
+                dominant
+            ] * split_prob + (1 - split_prob) * np.random.dirichlet(np.ones(n_genes))
+
+            cur_prior_probs = next_prior_probs
+            cur_expression_profiles = next_expression_profiles
+
+            logger.debug(f"priors: {cur_prior_probs}")
+            logger.debug(f"expression: {cur_expression_profiles}")
+        else:
+            # Cold start from a random location
+            cur_expression_profiles = np.random.dirichlet(
+                np.ones(n_genes), size=n_components
+            )
+            cur_prior_probs = np.ones(n_components) / n_components
+
+        converge = tol + 1
+        for step in tqdm.trange(
+            max_em_steps,
+            desc=f"EM for n_components {n_components}",
+            unit="step",
+            position=1,
+        ):
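+
+            # One EM sweep: the E-step converts the current priors and profiles
+            # into per-nucleus responsibilities; the M-step re-estimates the
+            # profiles and priors from those responsibilities with add-one
+            # (Laplace) smoothing.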
+
+            # E-step: estimate cell type assignments (posterior probabilities)
+            logits = np.log(cur_prior_probs[None]) + (
+                gene_counts[:, None] * np.log(cur_expression_profiles[None])
+            ).sum(axis=2)
+            cur_cell_types = softmax(logits, axis=1)
+
+            # M-step (part 1): estimate cell type profiles
+            prev_expression_profiles = np.array(cur_expression_profiles)
+            cur_expression_profiles = (
+                cur_cell_types[..., None] * gene_counts[:, None]
+            ).sum(axis=0) + 1
+            cur_expression_profiles = (
+                cur_expression_profiles / (cur_cell_types.sum(axis=0) + 1)[:, None]
+            )
+
+            # M-step (part 2): estimate cell type probabilities
+            cur_prior_probs = (cur_cell_types.sum(axis=0) + 1) / (
+                cur_cell_types.shape[0] + n_components
+            )
+            logger.debug(f"cur_prior_probs: {cur_prior_probs}")
+
+            # Track convergence of the cell type profiles
+            converge = np.linalg.norm(
+                prev_expression_profiles - cur_expression_profiles
+            ) / np.linalg.norm(prev_expression_profiles)
+
+            logger.debug(f"Convergence: {converge:.4f}")
+            if converge <= tol:
+                logger.debug("Stopping early.")
+                break
+
+        # Save the results
+        final_expression_profiles.append(cur_expression_profiles)
+        final_cell_types.append(cur_cell_types)
+        final_prior_probs.append(cur_prior_probs)
+
+        # Calculate BIC and AIC
+        aic_scores[idx], bic_scores[idx] = aic_bic(
+            gene_counts, cur_expression_profiles, cur_prior_probs
+        )
+
+        logger.debug(f"K={n_components}")
+        logger.debug(f"AIC: {aic_scores[idx]:.4f}")
+        logger.debug(f"BIC: {bic_scores[idx]:.4f}")
+
+    return {
+        "bic": bic_scores,
+        "aic": aic_scores,
+        "expression_profiles": final_expression_profiles,
+        "prior_probs": final_prior_probs,
+        "cell_types": final_cell_types,
+    }
diff --git a/src/nuc2seg/cli/preprocess.py b/src/nuc2seg/cli/preprocess.py
index 4096086..4c17aca 100644
--- a/src/nuc2seg/cli/preprocess.py
+++ b/src/nuc2seg/cli/preprocess.py
@@ -1,9 +1,15 @@
 import argparse
 import logging
 import numpy as np
+import pandas
 
 from nuc2seg import log_config
-from nuc2seg.xenium_utils import spatial_as_sparse_arrays
+from nuc2seg.xenium import (
+    load_nuclei,
+    load_and_filter_transcripts,
+    create_shapely_rectangle,
+)
+from nuc2seg.preprocessing import create_rasterized_dataset
 
 logger = logging.getLogger(__name__)
 
@@ -26,8 +32,8 @@ def get_parser():
         required=True,
     )
     parser.add_argument(
-        "--output-dir",
-        help="Output directory.",
+        "--output",
+        help="Output path.",
         type=str,
         required=True,
     )
@@ -38,10 +44,10 @@ def get_parser():
         default=0,
     )
     parser.add_argument(
-        "--pixel-stride",
-        help="Stride for the pixel grid.",
-        type=int,
-        default=1,
+        "--resolution",
+        help="Size of a pixel in microns for rasterization.",
+        type=float,
+        default=1.0,
    )
     parser.add_argument(
         "--min-qv",
@@ -74,22 +80,10 @@ def get_parser():
         default=5,
     )
     parser.add_argument(
-        "--tile-height",
-        help="Height of the tiles.",
-        type=int,
-        default=64,
-    )
-    parser.add_argument(
-        "--tile-width",
-        help="Width of the tiles.",
-        type=int,
-        default=64,
-    )
-    parser.add_argument(
-        "--tile-stride",
-        help="Stride of the tiles.",
-        type=int,
-        default=48,
+        "--sample-area",
+        default=None,
+        type=str,
+        help='Crop the dataset to this rectangle, provided in "x1,y1,x2,y2" format.',
     )
     return parser
 
@@ -109,17 +103,38 @@ def main():
 
     prng = np.random.default_rng(args.seed)
 
-    spatial_as_sparse_arrays(
+    if args.sample_area:
+        sample_area = create_shapely_rectangle(
+            *[float(x) for x in args.sample_area.split(",")]
+        )
+
+    else:
+        df = pandas.read_parquet(args.transcripts_file)
+        y_max = df["y_location"].max()
+        x_max = df["x_location"].max()
+
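+        # No --sample-area given: fall back to the full transcript extent,
+        # anchored at the origin.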
+ sample_area = create_shapely_rectangle(0, 0, x_max, y_max) + + nuclei_geo_df = load_nuclei( nuclei_file=args.nuclei_file, + sample_area=sample_area, + ) + + tx_geo_df = load_and_filter_transcripts( transcripts_file=args.transcripts_file, - outdir=args.output_dir, - pixel_stride=args.pixel_stride, + sample_area=sample_area, min_qv=args.min_qv, + ) + + ds = create_rasterized_dataset( + nuclei_geo_df=nuclei_geo_df, + tx_geo_df=tx_geo_df, + sample_area=sample_area, + resolution=args.resolution, foreground_nucleus_distance=args.foreground_nucleus_distance, background_nucleus_distance=args.background_nucleus_distance, background_pixel_transcripts=args.background_pixel_transcripts, background_transcript_distance=args.background_transcript_distance, - tile_width=args.tile_width, - tile_height=args.tile_height, - tile_stride=args.tile_stride, ) + + ds.save_h5(args.output) diff --git a/src/nuc2seg/cli/train.py b/src/nuc2seg/cli/train.py index d5eea60..2b8ccca 100644 --- a/src/nuc2seg/cli/train.py +++ b/src/nuc2seg/cli/train.py @@ -6,6 +6,7 @@ from nuc2seg import log_config from nuc2seg.train import train from nuc2seg.unet_model import SparseUNet +from nuc2seg.data import Nuc2SegDataset, TiledDataset logger = logging.getLogger(__name__) @@ -16,8 +17,8 @@ def get_parser(): ) log_config.add_logging_args(parser) parser.add_argument( - "--preprocessed-tiles-dir", - help="Directory containing preprocessed tiles.", + "--dataset", + help="Path to dataset in h5 format.", type=str, required=True, ) @@ -27,12 +28,6 @@ def get_parser(): type=str, required=True, ) - parser.add_argument( - "--n-classes", - help="Number of classes to segment.", - type=int, - required=True, - ) parser.add_argument( "--seed", help="Seed to use for PRNG.", @@ -112,7 +107,30 @@ def get_parser(): default="cpu", choices=["cpu", "cuda"], ) - + parser.add_argument( + "--tile-height", + help="Height of the tiles.", + type=int, + default=64, + ) + parser.add_argument( + "--tile-width", + help="Width of the tiles.", + type=int, + default=64, + ) + parser.add_argument( + "--overlap-percentage", + help="What percent of each tile dimension overlaps with the next tile.", + type=float, + default=0.25, + ) + parser.add_argument( + "--num-dataloader-workers", + help="Number of workers to use for the data loader.", + type=int, + default=0, + ) return parser @@ -130,12 +148,24 @@ def main(): log_config.configure_logging(args) np.random.seed(args.seed) - model = SparseUNet(600, args.n_classes + 2, (64, 64)) + + logger.info(f"Loading dataset from {args.dataset}") + + ds = Nuc2SegDataset.load_h5(args.dataset) + + tiled_dataset = TiledDataset( + ds, + tile_height=args.tile_height, + tile_width=args.tile_width, + tile_overlap=args.overlap_percentage, + ) + + model = SparseUNet(600, ds.n_classes + 2, (64, 64)) train( model, device=args.device, - tiles_dir=args.preprocessed_tiles_dir, + dataset=tiled_dataset, epochs=args.epochs, batch_size=args.batch_size, learning_rate=args.learning_rate, @@ -147,6 +177,7 @@ def main(): gradient_clipping=args.gradient_clipping, max_workers=args.max_workers, validation_frequency=args.validation_frequency, + num_dataloader_workers=args.num_dataloader_workers, ) logger.info(f"Saving model weights to {args.model_weights_output}") diff --git a/src/nuc2seg/data.py b/src/nuc2seg/data.py new file mode 100644 index 0000000..856c13d --- /dev/null +++ b/src/nuc2seg/data.py @@ -0,0 +1,213 @@ +import h5py +import logging +import torch + +import numpy as np +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import 
Dataset
+from blended_tiling import TilingModule
+
+logger = logging.getLogger(__name__)
+
+
+class Nuc2SegDataset:
+    def __init__(
+        self, labels, angles, classes, transcripts, bbox, n_classes, n_genes, resolution
+    ):
+        self.labels = labels
+        self.angles = angles
+        self.classes = classes
+        self.transcripts = transcripts
+        self.bbox = bbox
+        self.n_classes = n_classes
+        self.n_genes = n_genes
+        self.resolution = resolution
+
+    def save_h5(self, path):
+        with h5py.File(path, "w") as f:
+            f.create_dataset("labels", data=self.labels, compression="gzip")
+            f.create_dataset("angles", data=self.angles, compression="gzip")
+            f.create_dataset("classes", data=self.classes, compression="gzip")
+            f.create_dataset("transcripts", data=self.transcripts, compression="gzip")
+            f.create_dataset("bbox", data=self.bbox)
+            f.attrs["n_classes"] = self.n_classes
+            f.attrs["n_genes"] = self.n_genes
+            f.attrs["resolution"] = self.resolution
+
+    @property
+    def x_extent_pixels(self):
+        return self.labels.shape[0]
+
+    @property
+    def y_extent_pixels(self):
+        return self.labels.shape[1]
+
+    @staticmethod
+    def load_h5(path):
+        with h5py.File(path, "r") as f:
+            labels = f["labels"][:]
+            angles = f["angles"][:]
+            classes = f["classes"][:]
+            transcripts = f["transcripts"][:]
+            bbox = f["bbox"][:]
+            n_classes = f.attrs["n_classes"]
+            n_genes = f.attrs["n_genes"]
+            resolution = f.attrs["resolution"]
+            return Nuc2SegDataset(
+                labels=labels,
+                angles=angles,
+                classes=classes,
+                transcripts=transcripts,
+                bbox=bbox,
+                n_classes=n_classes,
+                n_genes=n_genes,
+                resolution=resolution,
+            )
+
+
+def generate_tiles(
+    tiler: TilingModule, x_extent, y_extent, tile_size, overlap_fraction, tile_ids=None
+):
+    """
+    A generator function to yield overlapping tile extents over a 2D pixel grid.
+
+    Parameters:
+    - tiler: TilingModule that defines the tiling of the base image.
+    - x_extent, y_extent: Extent of the grid in pixels along each axis.
+    - tile_size: Tuple of (tile_width, tile_height), the size of each tile.
+    - overlap_fraction: Fraction of overlap between tiles (0 to 1).
+    - tile_ids: List of tile IDs to generate. If None, all tiles are generated.
+
+    Yields:
+    - BBox extent in pixels for each tile (non-inclusive end): x1, y1, x2, y2
+    """
+    # Generate tiles
+    # NOTE: this reuses TilingModule's private _calc_tile_coords so the
+    # bounding boxes stay aligned with how the tiler splits the base image.
+    tile_id = 0
+    for x in tiler._calc_tile_coords(x_extent, tile_size[0], overlap_fraction)[0]:
+        for y in tiler._calc_tile_coords(y_extent, tile_size[1], overlap_fraction)[0]:
+            if tile_ids is not None:
+                if tile_id in tile_ids:
+                    yield x, y, x + tile_size[0], y + tile_size[1]
+            else:
+                yield x, y, x + tile_size[0], y + tile_size[1]
+            tile_id += 1
+
+
+def collate_tiles(data):
+    outputs = {key: [] for key in data[0].keys()}
+    for sample in data:
+        for key, val in sample.items():
+            outputs[key].append(val)
+    outputs["X"] = pad_sequence(outputs["X"], batch_first=True, padding_value=-1)
+    outputs["Y"] = pad_sequence(outputs["Y"], batch_first=True, padding_value=-1)
+    outputs["gene"] = pad_sequence(outputs["gene"], batch_first=True, padding_value=-1)
+    outputs["labels"] = torch.stack(outputs["labels"])
+    outputs["angles"] = torch.stack(outputs["angles"])
+    outputs["classes"] = torch.stack(outputs["classes"])
+    outputs["label_mask"] = torch.stack(outputs["label_mask"]).type(torch.bool)
+    outputs["nucleus_mask"] = torch.stack(outputs["nucleus_mask"]).type(torch.bool)
+    outputs["location"] = torch.tensor(np.stack(outputs["location"])).type(torch.long)
+
+    # Edge case: pad_sequence will squeeze tensors if there are no entries.
+    # In that case, we just need to add the dimension back.
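+    # Re-expand to a (batch, sequence) layout so downstream indexing is uniform.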
+    if len(outputs["gene"].shape) == 1:
+        outputs["X"] = outputs["X"][:, None]
+        outputs["Y"] = outputs["Y"][:, None]
+        outputs["gene"] = outputs["gene"][:, None]
+
+    return outputs
+
+
+class TiledDataset(Dataset):
+    def __init__(
+        self,
+        dataset: Nuc2SegDataset,
+        tile_height: int,
+        tile_width: int,
+        tile_overlap: float = 0.25,
+    ):
+        self.ds = dataset
+        self.tile_height = tile_height
+        self.tile_width = tile_width
+        self.tile_overlap = tile_overlap
+
+        self._tiler = TilingModule(
+            tile_size=(tile_width, tile_height),
+            tile_overlap=(tile_overlap, tile_overlap),
+            base_size=(dataset.x_extent_pixels, dataset.y_extent_pixels),
+        )
+
+    def __len__(self):
+        return self._tiler.num_tiles()
+
+    @property
+    def per_tile_class_histograms(self):
+        class_tiles = (
+            self._tiler.split_into_tiles(torch.tensor(self.ds.classes[None, None, ...]))
+            .squeeze()
+            .detach()
+            .numpy()
+            .astype(int)
+        )
+
+        class_tiles_flattened = class_tiles.reshape(
+            (self._tiler.num_tiles(), class_tiles.shape[1] * class_tiles.shape[2])
+        )
+
+        return np.apply_along_axis(
+            np.bincount, 1, class_tiles_flattened + 1, minlength=self.ds.n_classes + 2
+        )
+
+    def __getitem__(self, idx):
+        x1, y1, x2, y2 = next(
+            generate_tiles(
+                tiler=self._tiler,
+                x_extent=self.ds.x_extent_pixels,
+                y_extent=self.ds.y_extent_pixels,
+                tile_size=(self.tile_width, self.tile_height),
+                overlap_fraction=self.tile_overlap,
+                tile_ids=[idx],
+            )
+        )
+        transcripts = self.ds.transcripts
+        labels = self.ds.labels
+        angles = self.ds.angles
+        classes = self.ds.classes
+
+        selection_criteria = (
+            (transcripts[:, 0] < x2)
+            & (transcripts[:, 0] > x1)
+            & (transcripts[:, 1] < y2)
+            & (transcripts[:, 1] > y1)
+        )
+        tile_transcripts = transcripts[selection_criteria]
+
+        tile_transcripts[:, 0] = tile_transcripts[:, 0] - x1
+        tile_transcripts[:, 1] = tile_transcripts[:, 1] - y1
+
+        # Copy the slices so the relabeling below does not mutate the
+        # underlying dataset arrays across repeated __getitem__ calls.
+        tile_labels = labels[x1:x2, y1:y2].copy()
+
+        local_ids = np.unique(tile_labels)
+        local_ids = local_ids[local_ids > 0]
+        for i, c in enumerate(local_ids):
+            tile_labels[tile_labels == c] = i + 1
+
+        tile_angles = angles[x1:x2, y1:y2].copy()
+
+        tile_angles[tile_labels == -1] = -1
+
+        tile_classes = classes[x1:x2, y1:y2]
+
+        labels_mask = tile_labels > -1
+        nucleus_mask = tile_labels > 0
+
+        return {
+            "X": torch.as_tensor(tile_transcripts[:, 0]).long().contiguous(),
+            "Y": torch.as_tensor(tile_transcripts[:, 1]).long().contiguous(),
+            "gene": torch.as_tensor(tile_transcripts[:, 2]).long().contiguous(),
+            "labels": torch.as_tensor(tile_labels).long().contiguous(),
+            "angles": torch.as_tensor(tile_angles).float().contiguous(),
+            "classes": torch.as_tensor(tile_classes).long().contiguous(),
+            "label_mask": torch.as_tensor(labels_mask).bool().contiguous(),
+            "nucleus_mask": torch.as_tensor(nucleus_mask).bool().contiguous(),
+            "location": np.array([x1, y1]),
+        }
diff --git a/src/nuc2seg/data_loading.py b/src/nuc2seg/data_loading.py
deleted file mode 100755
index f745d50..0000000
--- a/src/nuc2seg/data_loading.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import logging
-import os
-import torch
-from os.path import join
-from pathlib import Path
-
-import numpy as np
-from torch.nn.utils.rnn import pad_sequence
-from torch.utils.data import Dataset
-
-logger = logging.getLogger(__name__)
-
-
-def xenium_collate_fn(data):
-    outputs = {key: [] for key in data[0].keys()}
-    for sample in data:
-        for key, val in sample.items():
-            outputs[key].append(val)
-    outputs["X"] = pad_sequence(outputs["X"], batch_first=True, padding_value=-1)
-    outputs["Y"] = pad_sequence(outputs["Y"], batch_first=True, 
padding_value=-1) - outputs["gene"] = pad_sequence(outputs["gene"], batch_first=True, padding_value=-1) - outputs["labels"] = torch.stack(outputs["labels"]) - outputs["angles"] = torch.stack(outputs["angles"]) - outputs["classes"] = torch.stack(outputs["classes"]) - outputs["label_mask"] = torch.stack(outputs["label_mask"]).type(torch.bool) - outputs["nucleus_mask"] = torch.stack(outputs["nucleus_mask"]).type(torch.bool) - outputs["location"] = torch.tensor(np.stack(outputs["location"])).type(torch.long) - - # Edge case: pad_sequence will squeeze tensors if there are no entries. - # In that case, we just need to add the dimension back. - if len(outputs["gene"].shape) == 1: - outputs["X"] = outputs["X"][:, None] - outputs["Y"] = outputs["Y"][:, None] - outputs["gene"] = outputs["gene"][:, None] - - return outputs - - -class XeniumDataset(Dataset): - def __init__(self, tiles_dir: str): - self.transcripts_dir = Path(join(tiles_dir, "transcripts/")) - self.labels_dir = Path(join(tiles_dir, "labels/")) - self.angles_dir = Path(join(tiles_dir, "angles/")) - self.classes_dir = Path(join(tiles_dir, "classes/")) - - self.locations = np.load(join(tiles_dir, "locations.npy")) - - self.ids = np.arange(self.locations.shape[0]) - - logging.info(f"Creating dataset with {len(self.ids)} examples") - self.class_counts = np.load(join(tiles_dir, "class_counts.npy")) - self.transcript_counts = np.load(join(tiles_dir, "transcript_counts.npy")) - self.max_length = self.transcript_counts.max() - self.label_values = np.arange(self.class_counts.shape[1]) - 1 - self.n_classes = self.class_counts.shape[1] - 2 - self.gene_ids = {int(i): j for i, j in np.load(join(tiles_dir, "gene_ids.npy"))} - self.n_genes = max(self.gene_ids) + 1 - - # Note: class IDs are 1-based since ID=0 is background - logging.info(f"Unique label values: {self.label_values}") - - def __len__(self): - return len(self.ids) - - def __getitem__(self, idx): - transcripts_file = os.path.join(self.transcripts_dir, f"{idx}.npz") - labels_file = os.path.join(self.labels_dir, f"{idx}.npz") - angles_file = os.path.join(self.angles_dir, f"{idx}.npz") - classes_file = os.path.join(self.classes_dir, f"{idx}.npz") - - xyg = np.load(transcripts_file)["arr_0"] - labels = np.load(labels_file)["arr_0"] - angles = np.load(angles_file)["arr_0"] - classes = np.load(classes_file)["arr_0"] - labels_mask = labels > -1 - nucleus_mask = labels > 0 - - return { - "X": torch.as_tensor(np.array(xyg[:, 0])).long().contiguous(), - "Y": torch.as_tensor(np.array(xyg[:, 1])).long().contiguous(), - "gene": torch.as_tensor(np.array(xyg[:, 2])).long().contiguous(), - "labels": torch.as_tensor(labels).long().contiguous(), - "angles": torch.as_tensor(angles).float().contiguous(), - "classes": torch.as_tensor(classes).long().contiguous(), - "label_mask": torch.as_tensor(labels_mask).bool().contiguous(), - "nucleus_mask": torch.as_tensor(nucleus_mask).bool().contiguous(), - "location": self.locations[idx], - } diff --git a/src/nuc2seg/data_test.py b/src/nuc2seg/data_test.py new file mode 100644 index 0000000..f06d629 --- /dev/null +++ b/src/nuc2seg/data_test.py @@ -0,0 +1,99 @@ +import shutil +import pytest +from nuc2seg.data import Nuc2SegDataset, TiledDataset, generate_tiles +import numpy as np +import tempfile +import os.path +from blended_tiling import TilingModule + + +@pytest.fixture(scope="package") +def test_dataset(): + return Nuc2SegDataset( + labels=np.ones((10, 20)), + angles=np.ones((10, 20)), + classes=np.ones((10, 20)), + transcripts=np.array([[0, 0, 0], [5, 5, 1], [10, 
10, 2]]), + bbox=np.array([100, 100, 110, 120]), + n_classes=3, + n_genes=3, + resolution=1, + ) + + +def test_Nuc2SegDataset(): + ds = Nuc2SegDataset( + labels=np.ones((10, 20)), + angles=np.ones((10, 20)), + classes=np.ones((10, 20)), + transcripts=np.array([[0, 0, 0], [5, 5, 1], [10, 10, 2]]), + bbox=np.array([100, 100, 110, 120]), + n_classes=3, + n_genes=3, + resolution=1, + ) + + tmpdir = tempfile.mkdtemp() + output_path = os.path.join(tmpdir, "test.h5") + + assert ds.n_classes == 3 + assert ds.n_genes == 3 + assert ds.x_extent_pixels == 10 + assert ds.y_extent_pixels == 20 + + try: + ds.save_h5(output_path) + ds2 = Nuc2SegDataset.load_h5(output_path) + + np.testing.assert_array_equal(ds.labels, ds2.labels) + np.testing.assert_array_equal(ds.angles, ds2.angles) + np.testing.assert_array_equal(ds.classes, ds2.classes) + np.testing.assert_array_equal(ds.transcripts, ds2.transcripts) + np.testing.assert_array_equal(ds.bbox, ds2.bbox) + assert ds.n_classes == ds2.n_classes + assert ds.n_genes == ds2.n_genes + finally: + shutil.rmtree(tmpdir) + + +def test_generate_tiles(): + tiler = TilingModule( + base_size=(10, 20), tile_size=(5, 10), tile_overlap=(0.25, 0.25) + ) + + tile_bboxes = list( + generate_tiles( + tiler=tiler, + x_extent=10, + y_extent=20, + tile_size=(5, 10), + overlap_fraction=0.25, + ) + ) + assert len(tile_bboxes) == 9 + + +def test_tiled_dataset(test_dataset): + td = TiledDataset( + dataset=test_dataset, tile_height=10, tile_width=5, tile_overlap=0.25 + ) + + assert len(td) == 9 + + first_tile = td[0] + + assert first_tile["angles"].shape == (5, 10) + assert first_tile["labels"].shape == (5, 10) + assert first_tile["classes"].shape == (5, 10) + assert first_tile["location"].size == 2 + assert first_tile["nucleus_mask"].shape == (5, 10) + + assert td.per_tile_class_histograms.shape == (len(td), test_dataset.n_classes + 2) + + second_tile = td[1] + + assert second_tile["angles"].shape == (5, 10) + assert second_tile["labels"].shape == (5, 10) + assert second_tile["classes"].shape == (5, 10) + assert second_tile["location"].size == 2 + assert second_tile["nucleus_mask"].shape == (5, 10) diff --git a/src/nuc2seg/evaluate.py b/src/nuc2seg/evaluate.py index 6e64cd9..063d656 100755 --- a/src/nuc2seg/evaluate.py +++ b/src/nuc2seg/evaluate.py @@ -5,7 +5,7 @@ from matplotlib import pyplot as plt import tqdm -from nuc2seg.xenium_utils import pol2cart +from nuc2seg.preprocessing import pol2cart def dice_coeff( @@ -51,33 +51,25 @@ def evaluate(net, dataloader, device, amp): # with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): # for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False): for idx, batch in enumerate( - tqdm.tqdm(dataloader, desc="Validation", unit="batch", position=3) + tqdm.tqdm(dataloader, desc="Validation", unit="batch", position=2, leave=False) ): x, y, z, labels, label_mask = ( - batch["X"], - batch["Y"], - batch["gene"], - batch["labels"], - batch["label_mask"], + batch["X"].to(device), + batch["Y"].to(device), + batch["gene"].to(device), + batch["labels"].to(device), + batch["label_mask"].to(device), ) label_mask = label_mask.type(torch.bool) mask_true = (labels > 0).type(torch.float) - # move images and labels to correct device and type - # image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) - # mask_true = mask_true.to(device=device, dtype=torch.long) - # predict the mask mask_pred = net(x, y, z) if mask_pred.dim() == 3: mask_pred = mask_pred[None] 
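+        # mask_pred layout: channel 0 is the foreground logit, channel 1 the
+        # angle logit (sigmoid, then rescaled to [-pi, pi]), and the remaining
+        # channels are cell-type logits.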
- # mask_pred = mask_pred.detach().numpy().copy() - foreground_pred = torch.sigmoid(mask_pred[..., 0]) - angles_pred = torch.sigmoid(mask_pred[..., 1]) * 2 * np.pi - np.pi - # class_pred = softmax(mask_pred[...,2:], axis=-1) for im_pred, im_true, im_label_mask in zip( foreground_pred, mask_true, label_mask @@ -91,13 +83,6 @@ def evaluate(net, dataloader, device, amp): / mask_true.shape[0] ) - # assert im_true.min() >= 0 and im_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes[' - # # convert to one-hot format - # im_true = F.one_hot(im_true, net.n_classes).permute(0, 3, 1, 2).float() - # im_pred = F.one_hot(im_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() - # # compute the Dice score, ignoring background - # dice_score += multiclass_dice_coeff(im_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=True) - net.train() return dice_score / max(num_val_batches, 1) diff --git a/src/nuc2seg/plotting.py b/src/nuc2seg/plotting.py new file mode 100644 index 0000000..a898ef5 --- /dev/null +++ b/src/nuc2seg/plotting.py @@ -0,0 +1,14 @@ +import geopandas +from shapely import box +from matplotlib import pyplot as plt + + +def plot_tiling(bboxes): + polygons = [box(*x) for x in bboxes] + gdf = geopandas.GeoDataFrame(geometry=polygons) + + fig, ax = plt.subplots() + + gdf.boundary.plot(ax=ax, color="red", alpha=0.5) + + fig.savefig("/tmp/tiling.pdf") diff --git a/src/nuc2seg/preprocessing.py b/src/nuc2seg/preprocessing.py new file mode 100644 index 0000000..a5c9168 --- /dev/null +++ b/src/nuc2seg/preprocessing.py @@ -0,0 +1,194 @@ +import math + +import geopandas +import geopandas as gpd +import numpy as np +import shapely +from scipy.spatial import KDTree +import pandas as pd + +from nuc2seg.data import Nuc2SegDataset +from nuc2seg.xenium import get_bounding_box, logger +from nuc2seg.celltyping import estimate_cell_types + + +def cart2pol(x, y): + """Convert Cartesian coordinates to polar coordinates""" + rho = np.sqrt(x**2 + y**2) + phi = np.arctan2(y, x) + return (rho, phi) + + +def pol2cart(rho, phi): + x = rho * np.cos(phi) + y = rho * np.sin(phi) + return (x, y) + + +def create_pixel_geodf(x_min, x_max, y_min, y_max): + # Create the list of all pixels + grid_df = pd.DataFrame( + np.array( + np.meshgrid(np.arange(x_min, x_max + 1), np.arange(y_min, y_max + 1)) + ).T.reshape(-1, 2), + columns=["X", "Y"], + ) + + # Convert the xy locations to a geopandas data frame + idx_geo_df = gpd.GeoDataFrame( + grid_df, + geometry=gpd.points_from_xy(grid_df["X"], grid_df["Y"]), + ) + + return idx_geo_df + + +def create_rasterized_dataset( + nuclei_geo_df: geopandas.GeoDataFrame, + tx_geo_df: geopandas.GeoDataFrame, + sample_area: shapely.Polygon, + resolution=1, + foreground_nucleus_distance=1, + background_nucleus_distance=10, + background_transcript_distance=4, + background_pixel_transcripts=5, +): + n_genes = tx_geo_df["gene_id"].max() + 1 + + x_min, x_max, y_min, y_max = get_bounding_box(sample_area) + x_min, x_max = math.floor(x_min), math.ceil(x_max) + y_min, y_max = math.floor(y_min), math.ceil(y_max) + + x_size = (x_max - x_min) + 1 + y_size = (y_max - y_min) + 1 + + logger.info("Creating pixel geometry dataframe") + # Create a dataframe with an entry for every pixel + idx_geo_df = create_pixel_geodf(x_max=x_max, x_min=x_min, y_min=y_min, y_max=y_max) + + logger.info("Find the nearest nucleus to each pixel") + # Find the nearest nucleus to each pixel + labels_geo_df = gpd.sjoin_nearest( + idx_geo_df, nuclei_geo_df, how="left", distance_col="nucleus_distance" + ) + 
labels_geo_df.rename(columns={"index_right": "nucleus_id_xenium"}, inplace=True) + + # Calculate the nearest transcript neighbors + + logger.info("Calculating the nearest transcript neighbors") + transcript_xy = np.array( + [tx_geo_df["x_location"].values, tx_geo_df["y_location"].values] + ).T + kdtree = KDTree(transcript_xy) + + # Get the distance to the k'th nearest transcript + pixels_xy = np.array([labels_geo_df["X"].values, labels_geo_df["Y"].values]).T + labels_geo_df["transcript_distance"] = kdtree.query( + pixels_xy, k=background_pixel_transcripts + 1 + )[0][:, -1] + + # Assign pixels roughly on top of nuclei to belong to that nuclei label + pixel_labels = np.zeros(labels_geo_df.shape[0], dtype=int) - 1 + nucleus_pixels = labels_geo_df["nucleus_distance"] <= foreground_nucleus_distance + pixel_labels[nucleus_pixels] = labels_geo_df["nucleus_label"][nucleus_pixels] + + # Assign pixels to the background if they are far from nuclei and not near a dense region of transcripts + background_pixels = ( + labels_geo_df["nucleus_distance"] > background_nucleus_distance + ) & (labels_geo_df["transcript_distance"] > background_transcript_distance) + pixel_labels[background_pixels] = 0 + + # Convert back over to the grid format + labels = np.zeros((x_size, y_size), dtype=int) + labels[labels_geo_df["X"] - x_min, labels_geo_df["Y"] - y_min] = pixel_labels + + # Create a nuclei x gene count matrix + tx_nuclei_geo_df = gpd.sjoin_nearest( + tx_geo_df, nuclei_geo_df, distance_col="nucleus_distance" + ) + nuclei_count_geo_df = tx_nuclei_geo_df[ + tx_nuclei_geo_df["nucleus_distance"] <= foreground_nucleus_distance + ] + + # I think we have enough memory to just store this as a dense array + nuclei_count_matrix = np.zeros((nuclei_geo_df.shape[0] + 1, n_genes), dtype=int) + np.add.at( + nuclei_count_matrix, + ( + nuclei_count_geo_df["nucleus_label"].values.astype(int), + nuclei_count_geo_df["gene_id"].values.astype(int), + ), + 1, + ) + + # Assume for simplicity that it's a homogeneous poisson process for transcripts. + # Add up all the transcripts in each pixel. + tx_count_grid = np.zeros((x_size, y_size), dtype=int) + np.add.at( + tx_count_grid, + ( + tx_geo_df["x_location"].values.astype(int) - x_min, + tx_geo_df["y_location"].values.astype(int) - y_min, + ), + 1, + ) + + logger.info("Estimating cell types") + # Estimate the cell types + results = estimate_cell_types(nuclei_count_matrix) + best_k = 12 + + # Estimate the background rate + tx_background_mask = ( + labels[ + tx_geo_df["x_location"].values.astype(int) - x_min, + tx_geo_df["y_location"].values.astype(int) - y_min, + ] + == 0 + ) + background_probs = np.zeros(n_genes) + tx_geo_df_background = tx_geo_df[tx_background_mask] + for g in range(n_genes): + background_probs[g] = (tx_geo_df_background["gene_id"] == g).sum() + 1 + + # Estimate the density of each cell type + cell_type_probs = results["cell_types"][best_k] + + # Assign hard labels to nuclei + cell_type_labels = np.argmax(cell_type_probs, axis=1) + 1 + pixel_types = np.copy(labels) + nuclei_mask = labels > 0 + pixel_types[nuclei_mask] = cell_type_labels[labels[nuclei_mask]] + + # Calculate the angle at which each pixel faces to point at its nearest nucleus centroid. 
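+    # cart2pol returns (rho, phi), where phi = arctan2(dy, dx) lies in
+    # [-pi, pi] and points from the pixel toward the nucleus centroid.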
+    # Normalize it to be in [0,1]
+    labels_geo_df["nucleus_angle"] = (
+        cart2pol(
+            labels_geo_df["nucleus_centroid_x"].values - labels_geo_df["X"].values,
+            labels_geo_df["nucleus_centroid_y"].values - labels_geo_df["Y"].values,
+        )[1]
+        + np.pi
+    ) / (2 * np.pi)
+    angles = np.zeros(labels.shape)
+    angles[labels_geo_df["X"].values - x_min, labels_geo_df["Y"].values - y_min] = (
+        labels_geo_df["nucleus_angle"].values
+    )
+
+    n_classes = cell_type_probs.shape[-1]
+
+    logger.info("Creating dataset")
+    X = tx_geo_df["x_location"].values.astype(int) - x_min
+    Y = tx_geo_df["y_location"].values.astype(int) - y_min
+    G = tx_geo_df["gene_id"].values.astype(int)
+    # NOTE: the pixel grid above is built at a fixed 1-micron spacing; the
+    # `resolution` argument is recorded on the dataset but not yet applied
+    # when rasterizing.
+    ds = Nuc2SegDataset(
+        labels=labels,
+        angles=angles,
+        classes=pixel_types,
+        transcripts=np.array([X, Y, G]).T,
+        bbox=np.array([x_min, x_max, y_min, y_max]),
+        n_classes=n_classes,
+        n_genes=n_genes,
+        resolution=resolution,
+    )
+
+    return ds
diff --git a/src/nuc2seg/preprocessing_test.py b/src/nuc2seg/preprocessing_test.py
new file mode 100644
index 0000000..93f2cd6
--- /dev/null
+++ b/src/nuc2seg/preprocessing_test.py
@@ -0,0 +1,178 @@
+import pytest
+import pandas
+import geopandas
+import shapely
+from nuc2seg.preprocessing import create_rasterized_dataset
+import numpy as np
+
+
+@pytest.fixture(scope="package")
+def test_nuclei_df():
+    RECORDS = [
+        {
+            "geometry": shapely.Polygon(
+                [(7.0, 7.0), (13.0, 7.0), (13.0, 13.0), (7.0, 13.0), (7.0, 7.0)]
+            ),
+            "nucleus_label": 1,
+            "nucleus_centroid": shapely.Point(10.0, 10.0),
+            "nucleus_centroid_x": 10.0,
+            "nucleus_centroid_y": 10.0,
+        },
+        {
+            "geometry": shapely.Polygon(
+                [(17.0, 7.0), (23.0, 7.0), (23.0, 13.0), (17.0, 13.0), (17.0, 7.0)]
+            ),
+            "nucleus_label": 2,
+            "nucleus_centroid": shapely.Point(20.0, 10.0),
+            "nucleus_centroid_x": 20.0,
+            "nucleus_centroid_y": 10.0,
+        },
+    ]
+
+    return geopandas.GeoDataFrame(RECORDS).set_geometry("geometry")
+
+
+@pytest.fixture(scope="package")
+def test_transcripts_df():
+    RECORDS = [
+        # cell 1
+        {
+            "transcript_id": 1,
+            "cell_id": "cell1",
+            "overlaps_nucleus": 1,
+            "feature_name": "gene1",
+            "x_location": 10.0,
+            "y_location": 11.0,
+            "z_location": 20.691404342651367,
+            "qv": 40.0,
+            "fov_name": "C10",
+            "codeword_index": 28,
+            "gene_id": 0,
+        },
+        {
+            "transcript_id": 2,
+            "cell_id": "cell1",
+            "overlaps_nucleus": 1,
+            "feature_name": "gene1",
+            "x_location": 12.0,
+            "y_location": 11.0,
+            "z_location": 20.691404342651367,
+            "qv": 40.0,
+            "fov_name": "C10",
+            "codeword_index": 28,
+            "gene_id": 0,
+        },
+        {
+            "transcript_id": 3,
+            "cell_id": "cell1",
+            "overlaps_nucleus": 1,
+            "feature_name": "gene1",
+            "x_location": 11.0,
+            "y_location": 10.0,
+            "z_location": 20.691404342651367,
+            "qv": 40.0,
+            "fov_name": "C10",
+            "codeword_index": 28,
+            "gene_id": 0,
+        },
+        # cell 2
+        {
+            "transcript_id": 4,
+            "cell_id": "cell2",
+            "overlaps_nucleus": 1,
+            "feature_name": "gene2",
+            "x_location": 20.0,
+            "y_location": 11.0,
+            "z_location": 20.691404342651367,
+            "qv": 40.0,
+            "fov_name": "C10",
+            "codeword_index": 28,
+            "gene_id": 1,
+        },
+        {
+            "transcript_id": 5,
+            "cell_id": "cell1",
+            "overlaps_nucleus": 1,
+            "feature_name": "gene2",
+            "x_location": 22.0,
+            "y_location": 11.0,
+            "z_location": 20.691404342651367,
+            "qv": 40.0,
+            "fov_name": "C10",
+            "codeword_index": 28,
+            "gene_id": 1,
+        },
+        {
+            "transcript_id": 6,
+            "cell_id": "cell1",
+            "overlaps_nucleus": 1,
+            "feature_name": "gene2",
+            "x_location": 21.0,
+            "y_location": 10.0,
+            "z_location": 20.691404342651367,
+            "qv": 40.0,
+            "fov_name": "C10",
+            "codeword_index": 28,
+            "gene_id": 1,
+        },
+        # 
unlabeled transcripts + { + "transcript_id": 7, + "cell_id": "UNASSIGNED", + "overlaps_nucleus": 0, + "feature_name": "gene1", + "x_location": 10.0, + "y_location": 5.0, + "z_location": 13.079690933227539, + "qv": 40.0, + "fov_name": "C18", + "codeword_index": 54, + "gene_id": 0, + }, + { + "transcript_id": 8, + "cell_id": "UNASSIGNED", + "overlaps_nucleus": 0, + "feature_name": "gene2", + "x_location": 20.0, + "y_location": 5.0, + "z_location": 13.079690933227539, + "qv": 40.0, + "fov_name": "C18", + "codeword_index": 54, + "gene_id": 1, + }, + ] + + df = pandas.DataFrame(RECORDS) + + return geopandas.GeoDataFrame( + df, + geometry=geopandas.points_from_xy(df["x_location"], df["y_location"]), + ) + + +def test_create_rasterized_dataset(test_nuclei_df, test_transcripts_df): + np.random.seed(0) + ds = create_rasterized_dataset( + nuclei_geo_df=test_nuclei_df, + tx_geo_df=test_transcripts_df, + sample_area=shapely.Polygon([(1, 1), (30, 1), (30, 20), (1, 20), (1, 1)]), + resolution=1, + foreground_nucleus_distance=1, + background_nucleus_distance=5, + background_transcript_distance=3, + background_pixel_transcripts=5, + ) + + assert ds.labels.shape == (30, 20) + assert ds.classes.shape == (30, 20) + assert ds.transcripts.shape == (8, 3) + assert ds.x_extent_pixels == 30 + assert ds.y_extent_pixels == 20 + assert ds.n_genes == 2 + assert ds.n_classes == 14 + + # Assert coordinated are transformed relative to the bbox + assert ds.transcripts[:, 0].min() == 9.0 + assert ds.transcripts[:, 1].min() == 4.0 diff --git a/src/nuc2seg/train.py b/src/nuc2seg/train.py index ea947d2..f35a6f8 100755 --- a/src/nuc2seg/train.py +++ b/src/nuc2seg/train.py @@ -2,11 +2,10 @@ import torch.nn as nn import tqdm -from nuc2seg.data_loading import XeniumDataset, xenium_collate_fn +from nuc2seg.data import TiledDataset, collate_tiles from torch import optim from torch.utils.data import DataLoader, random_split -from nuc2seg.unet_model import SparseUNet from nuc2seg.evaluate import evaluate @@ -19,24 +18,10 @@ def angle_loss(predictions, targets): return torch.minimum(torch.minimum(delta**2, (delta - 1) ** 2), (delta + 1) ** 2) -""" -device = 'cpu' -epochs = 5 -batch_size = 3 -learning_rate = 1e-5 -val_percent = 0.1 -save_checkpoint = True -amp = False -weight_decay = 1e-8 -momentum = 0.999 -gradient_clipping = 1.0 -max_workers = 1""" - - def train( model, device, - tiles_dir: str, + dataset: TiledDataset, epochs: int = 50, batch_size: int = 1, learning_rate: float = 1e-5, @@ -48,9 +33,8 @@ def train( gradient_clipping: float = 1.0, max_workers: int = 1, validation_frequency: int = 500, + num_dataloader_workers: int = 4, ): - # Create the dataset - dataset = XeniumDataset(tiles_dir) # 2. Split into train / validation partitions n_val = int(len(dataset) * val_percent) @@ -61,22 +45,14 @@ def train( # 3. Create data loaders loader_args = dict( - batch_size=batch_size, pin_memory=True, collate_fn=xenium_collate_fn + batch_size=batch_size, + pin_memory=True, + collate_fn=collate_tiles, + num_workers=num_dataloader_workers, ) # TODO: add num_workers back; cut out to work in ipython train_loader = DataLoader(train_set, shuffle=True, **loader_args) val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args) - # logging.info(f'''Starting training: - # Epochs: {epochs} - # Batch size: {batch_size} - # Learning rate: {learning_rate} - # Training size: {n_train} - # Validation size: {n_val} - # Checkpoints: {save_checkpoint} - # Device: {device.type} - # Mixed Precision: {amp} - # ''') - # 4. 
Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP optimizer = optim.RMSprop( model.parameters(), @@ -90,8 +66,8 @@ def train( celltype_criterion = nn.CrossEntropyLoss( reduction="mean", weight=torch.Tensor( - dataset.class_counts[:, 2:].mean() - / dataset.class_counts[:, 2:].mean(axis=0) + dataset.per_tile_class_histograms[:, 2:].mean() + / dataset.per_tile_class_histograms[:, 2:].mean(axis=0) ), ) # Class imbalance reweighting @@ -99,28 +75,23 @@ def train( validation_scores = [] # 5. Begin training - for epoch in tqdm.trange(1, epochs + 1, position=0, desc="Epoch"): + for epoch in tqdm.trange(0, epochs, position=0, desc="Epoch"): model.train() epoch_loss = 0 - # with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar: - for batch in tqdm.tqdm(train_loader, position=1, desc="Batch"): + for batch in tqdm.tqdm(train_loader, position=1, desc="Batch", leave=False): x, y, z, labels, angles, classes, label_mask, nucleus_mask = ( - batch["X"], - batch["Y"], - batch["gene"], - batch["labels"], - batch["angles"], - batch["classes"], - batch["label_mask"], - batch["nucleus_mask"], + batch["X"].to(device), + batch["Y"].to(device), + batch["gene"].to(device), + batch["labels"].to(device), + batch["angles"].to(device), + batch["classes"].to(device), + batch["label_mask"].to(device), + batch["nucleus_mask"].to(device), ) label_mask = label_mask.type(torch.bool) - # TODO: move everything to the appropriate device if GPU training - # images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) - # true_masks = true_masks.to(device=device, dtype=torch.long) - # with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): mask_pred = model(x, y, z) foreground_pred = mask_pred[..., 0] @@ -159,24 +130,3 @@ def train( if global_step % validation_frequency == 0: validation_score = evaluate(model, val_loader, device, amp) validation_scores.append(validation_score) - - -if __name__ == "__main__": - transcripts_dir = "data/tiles/transcripts/" - labels_dir = "data/tiles/labels/" - angles_dir = "data/tiles/angles/" - classes_dir = "data/tiles/classes/" - - n_classes = 12 - - # Create the model - # Outputs: - # Channel 0: Foreground vs background logit - # Channel 1: Angle logit pointing to the nucleus - # Channel 2-K+2: Class label prediction - # TODO: first parameter should be the number of unique transcripts - model = SparseUNet(600, n_classes + 2, (64, 64)) - - device = "cpu" - - train(model, device, transcripts_dir, labels_dir, angles_dir, classes_dir) diff --git a/src/nuc2seg/xenium.py b/src/nuc2seg/xenium.py new file mode 100755 index 0000000..af05bef --- /dev/null +++ b/src/nuc2seg/xenium.py @@ -0,0 +1,154 @@ +import logging +import os + +import shapely +import geopandas as gpd +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from shapely import Polygon +from shapely.geometry import box + +logger = logging.getLogger(__name__) + + +def create_shapely_rectangle(x1, y1, x2, y2): + return box(x1, y1, x2, y2) + + +def get_bounding_box(poly: shapely.Polygon): + coords = list(poly.exterior.coords) + + # Find the extreme vertices + leftmost = min(coords, key=lambda point: point[0]) + rightmost = max(coords, key=lambda point: point[0]) + topmost = max(coords, key=lambda point: point[1]) + bottommost = min(coords, key=lambda point: point[1]) + + return leftmost[0], rightmost[0], bottommost[1], topmost[1] + + +def filter_gdf_to_inside_polygon(gdf, polygon): + selection_vector = 
gdf.geometry.apply(lambda x: polygon.contains(x)) + logging.info( + f"Filtering {len(gdf)} points to {selection_vector.sum()} inside polygon" + ) + return gdf[selection_vector] + + +def read_boundaries_into_polygons( + boundaries_file, + cell_id_column="cell_id", + x_column_name="vertex_x", + y_column_name="vertex_y", +): + boundaries = pd.read_parquet(boundaries_file) + geo_df = gpd.GeoDataFrame( + boundaries, + geometry=gpd.points_from_xy( + boundaries[x_column_name], boundaries[y_column_name] + ), + ) + polys = geo_df.groupby(cell_id_column)["geometry"].apply( + lambda x: Polygon(x.tolist()) + ) + return gpd.GeoDataFrame(polys) + + +def read_transcripts_into_points( + transcripts_file, x_column_name="x_location", y_column_name="y_location" +): + transcripts = pd.read_parquet(transcripts_file) + + tx_geo_df = gpd.GeoDataFrame( + transcripts, + geometry=gpd.points_from_xy( + transcripts[x_column_name], transcripts[y_column_name] + ), + ) + return tx_geo_df + + +def load_nuclei(nuclei_file: str, sample_area: shapely.Polygon): + nuclei_geo_df = read_boundaries_into_polygons(nuclei_file) + + original_n_nuclei = nuclei_geo_df.shape[0] + + nuclei_geo_df = filter_gdf_to_inside_polygon(nuclei_geo_df, sample_area) + + logger.info( + f"{original_n_nuclei-nuclei_geo_df.shape[0]} nuclei filtered after bounding to {sample_area}" + ) + + if nuclei_geo_df.empty: + raise ValueError("No nuclei found in the sample area") + + nuclei_geo_df["nucleus_label"] = np.arange(1, nuclei_geo_df.shape[0] + 1) + nuclei_geo_df["nucleus_centroid"] = nuclei_geo_df["geometry"].centroid + nuclei_geo_df["nucleus_centroid_x"] = nuclei_geo_df["geometry"].centroid.x + nuclei_geo_df["nucleus_centroid_y"] = nuclei_geo_df["geometry"].centroid.y + + logger.info(f"Loaded {nuclei_geo_df.shape[0]} nuclei.") + + return nuclei_geo_df + + +def load_and_filter_transcripts( + transcripts_file: str, sample_area: shapely.Polygon, min_qv=20.0 +): + transcripts_df = read_transcripts_into_points(transcripts_file) + + original_count = len(transcripts_df) + + transcripts_df = filter_gdf_to_inside_polygon(transcripts_df, sample_area) + transcripts_df.drop(columns=["nucleus_distance"], inplace=True) + + count_after_bbox = len(transcripts_df) + + logger.info( + f"{original_count-count_after_bbox} tx filtered after bounding to {sample_area}" + ) + + # Filter out controls and low quality transcripts + transcripts_df = transcripts_df[ + (transcripts_df["qv"] >= min_qv) + & (~transcripts_df["feature_name"].str.startswith("NegControlProbe_")) + & (~transcripts_df["feature_name"].str.startswith("antisense_")) + & (~transcripts_df["feature_name"].str.startswith("NegControlCodeword_")) + & (~transcripts_df["feature_name"].str.startswith("BLANK_")) + ] + + count_after_quality_filtering = len(transcripts_df) + + logger.info( + f"{count_after_bbox-count_after_quality_filtering} tx filtered after quality filtering" + ) + + if transcripts_df.empty: + raise ValueError("No transcripts found in the sample area") + + # Assign a unique integer ID to each gene + gene_ids = transcripts_df["feature_name"].unique() + n_genes = len(gene_ids) + mapping = dict(zip(gene_ids, np.arange(len(gene_ids)))) + transcripts_df["gene_id"] = transcripts_df["feature_name"].apply( + lambda x: mapping.get(x, 0) + ) + + logger.info( + f"Loaded {count_after_quality_filtering} transcripts. {n_genes} unique genes." 
+ ) + + return transcripts_df + + +def plot_distribution_of_cell_types(cell_type_probs): + n_rows = cell_type_probs.shape[1] // 3 + int(cell_type_probs.shape[1] % 3 > 0) + n_cols = 3 + fig, axarr = plt.subplots(n_rows, n_cols, sharex=True) + for idx in range(cell_type_probs.shape[1]): + i, j = idx // 3, idx % 3 + axarr[i, j].hist( + cell_type_probs[np.argmax(cell_type_probs, axis=1) == idx, idx], bins=200 + ) + plt.show() diff --git a/src/nuc2seg/xenium_utils.py b/src/nuc2seg/xenium_utils.py deleted file mode 100755 index c5bf0d3..0000000 --- a/src/nuc2seg/xenium_utils.py +++ /dev/null @@ -1,703 +0,0 @@ -import argparse -import logging -import os -import sys - -import tqdm -from scipy.special import softmax -import geopandas as gpd -import numpy as np -import pandas as pd -from matplotlib import pyplot as plt -from shapely import Polygon -from scipy.special import logsumexp -from scipy.stats import poisson -from scipy.spatial import KDTree - -logger = logging.getLogger(__name__) - - -def configure_logging(): - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - -def create_shapely_rectangle(x1, y1, x2, y2): - return Polygon([(x1, y1), (x1, y2), (x2, y2), (x2, y1)]) - - -def filter_gdf_to_inside_polygon(gdf, polygon): - selection_vector = gdf.geometry.apply(lambda x: polygon.contains(x)) - logging.info( - f"Filtering {len(gdf)} points to {selection_vector.sum()} inside polygon" - ) - return gdf[selection_vector] - - -def read_boundaries_into_polygons( - boundaries_file, - cell_id_column="cell_id", - x_column_name="vertex_x", - y_column_name="vertex_y", -): - boundaries = pd.read_parquet(boundaries_file) - geo_df = gpd.GeoDataFrame( - boundaries, - geometry=gpd.points_from_xy( - boundaries[x_column_name], boundaries[y_column_name] - ), - ) - polys = geo_df.groupby(cell_id_column)["geometry"].apply( - lambda x: Polygon(x.tolist()) - ) - return gpd.GeoDataFrame(polys) - - -def read_transcripts_into_points( - transcripts_file, x_column_name="x_location", y_column_name="y_location" -): - transcripts = pd.read_parquet(transcripts_file) - - tx_geo_df = gpd.GeoDataFrame( - transcripts, - geometry=gpd.points_from_xy( - transcripts[x_column_name], transcripts[y_column_name] - ), - ) - return tx_geo_df - - -def spatial_join_polygons_and_transcripts( - boundaries: gpd.GeoDataFrame, transcripts: gpd.GeoDataFrame -): - joined_gdf = gpd.sjoin(transcripts, boundaries, how="left") - - unique_vals = joined_gdf["index_right"][joined_gdf["index_right"].notna()].unique() - - mapping = dict(zip(unique_vals, np.arange(len(unique_vals)))) - - joined_gdf["nucleus_id"] = joined_gdf["index_right"].apply( - lambda x: mapping.get(x, 0) - ) - - return joined_gdf - - -def plot_spatial_join(joined_gdf, polygon_gdf, output_dir, cell_id_column="nucleus_id"): - fig, ax = plt.subplots(figsize=(10, 10), dpi=1000) - polygon_gdf.plot(ax=ax, edgecolor="black") - joined_gdf[joined_gdf["nucleus_id"] == 0].plot(ax=ax, color="black", markersize=1) - joined_gdf[joined_gdf["nucleus_id"] > 0].plot( - ax=ax, - column=cell_id_column, - categorical=True, - legend=False, - markersize=1, - cmap="rainbow", - ) - - plt.savefig(os.path.join(output_dir, "transcript_and_boundary.pdf")) - - -## TODO: Maybe replace this with a library that does the same thing -def estimate_cell_types( - gene_counts, - min_components=2, - max_components=25, - max_em_steps=100, - tol=1e-4, - warm_start=False, -): - - n_nuclei, n_genes = gene_counts.shape - - # Randomly 
initialize cell type profiles - cur_expression_profiles = np.random.dirichlet(np.ones(n_genes), size=min_components) - - # Initialize probabilities to be uniform - cur_prior_probs = np.ones(min_components) / min_components - - # No need to initialize cell type assignments - cur_cell_types = None - - # Track BIC and AIC scores - aic_scores = np.zeros(max_components - min_components + 1) - bic_scores = np.zeros(max_components - min_components + 1) - - # Iterate through every possible number of components - ( - final_expression_profiles, - final_prior_probs, - final_cell_types, - final_expression_rates, - ) = ([], [], [], []) - prev_expression_profiles = np.zeros_like(cur_expression_profiles) - for idx, n_components in enumerate( - tqdm.trange( - min_components, max_components + 1, desc="estimate_cell_types", position=0 - ) - ): - - # Warm start from the previous iteration - if warm_start and n_components > min_components: - # Expand and copy over the current parameters - next_prior_probs = np.zeros(n_components) - next_expression_profiles = np.zeros((n_components, n_genes)) - next_prior_probs[: n_components - 1] = cur_prior_probs - next_expression_profiles[: n_components - 1] = cur_expression_profiles - - # Split the dominant cluster - dominant = np.argmax(cur_prior_probs) - split_prob = np.random.random() - next_prior_probs[-1] = next_prior_probs[dominant] * split_prob - next_prior_probs[dominant] = next_prior_probs[dominant] * (1 - split_prob) - next_expression_profiles[-1] = cur_expression_profiles[ - dominant - ] * split_prob + (1 - split_prob) * np.random.dirichlet(np.ones(n_genes)) - - cur_prior_probs = next_prior_probs - cur_expression_profiles = next_expression_profiles - - logger.debug("priors:", cur_prior_probs) - logger.debug("expression:", cur_expression_profiles) - else: - # Cold start from a random location - cur_expression_profiles = np.random.dirichlet( - np.ones(n_genes), size=n_components - ) - cur_prior_probs = np.ones(n_components) / n_components - - converge = tol + 1 - for step in tqdm.trange( - max_em_steps, - desc=f"EM for n_components {n_components}", - unit="step", - position=1, - ): - - # E-step: estimate cell type assignments (posterior probabilities) - logits = np.log(cur_prior_probs[None]) + ( - gene_counts[:, None] * np.log(cur_expression_profiles[None]) - ).sum(axis=2) - cur_cell_types = softmax(logits, axis=1) - - # M-step (part 1): estimate cell type profiles - prev_expression_profiles = np.array(cur_expression_profiles) - cur_expression_profiles = ( - cur_cell_types[..., None] * gene_counts[:, None] - ).sum(axis=0) + 1 - cur_expression_profiles = ( - cur_expression_profiles / (cur_cell_types.sum(axis=0) + 1)[:, None] - ) - - # M-step (part 2): estimate cell type probabilities - cur_prior_probs = (cur_cell_types.sum(axis=0) + 1) / ( - cur_cell_types.shape[0] + n_components - ) - logger.debug(f"cur_prior_probs: {cur_prior_probs}") - - # Track convergence of the cell type profiles - converge = np.linalg.norm( - prev_expression_profiles - cur_expression_profiles - ) / np.linalg.norm(prev_expression_profiles) - - logger.debug(f"Convergence: {converge:.4f}") - if converge <= tol: - logger.debug(f"Stopping early.") - break - - # Save the results - final_expression_profiles.append(cur_expression_profiles) - final_cell_types.append(cur_cell_types) - final_prior_probs.append(cur_prior_probs) - - # Calculate BIC and AIC - aic_scores[idx], bic_scores[idx] = aic_bic( - gene_counts, cur_expression_profiles, cur_prior_probs - ) - - logger.debug(f"K={n_components}") - 
logger.debug(f"AIC: {aic_scores[idx]:.4f}") - logger.debug(f"BIC: {bic_scores[idx]:.4f}") - - return { - "bic": bic_scores, - "aic": aic_scores, - "expression_profiles": final_expression_profiles, - "prior_probs": final_prior_probs, - "cell_types": final_cell_types, - } - - -def aic_bic(gene_counts, expression_profiles, prior_probs): - - n_components = expression_profiles.shape[0] - n_genes = expression_profiles.shape[1] - n_samples = gene_counts.shape[0] - - dof = n_components * (n_genes - 1) + n_components - 1 - log_likelihood = logsumexp( - np.log(prior_probs[None]) - + (gene_counts[:, None] * np.log(expression_profiles[None])).sum(axis=2), - axis=1, - ).sum() - aic = -2 * log_likelihood + 2 * dof - bic = -2 * log_likelihood + dof * np.log(n_samples) - - return aic, bic - - -""" -nuclei_file = 'data/nucleus_boundaries.parquet' -transcripts_file = 'data/transcripts.csv' -out_dir = 'data/' -pixel_stride=1 -min_qv=20.0 -foreground_nucleus_distance=1 -background_nucleus_distance=10 -background_transcript_distance=4 -background_pixel_transcripts=5 -tile_height=64 -tile_width=64 -""" - - -def calculate_pixel_loglikes( - tx_geo_df, tx_count_grid, expression_profiles, expression_rates -): - # Create zero-padded arrays of pixels x transcript counts and IDs - # NOTE: this is a painfully slow way of doing this. it's just a quick and dirty solution. - max_transcript_count = tx_count_grid.max() + 1 - transcript_ids = np.zeros(tx_count_grid.shape + (max_transcript_count,), dtype=int) - transcript_counts = np.zeros( - tx_count_grid.shape + (max_transcript_count,), dtype=int - ) - for row_idx, row in tqdm.tqdm( - tx_geo_df.iterrows(), desc="calculate_pixel_loglikes" - ): - x, y = int(row["x_location"]), int(row["y_location"]) - tx_mask = transcript_ids[x, y] == row["gene_id"] - if tx_mask.sum() != 0: - gene_idx = np.argmax(tx_mask) - transcript_counts[x, y, gene_idx] += 1 - else: - tx_mask = transcript_counts[x, y] == 0 - gene_idx = np.argmax(tx_mask) - transcript_ids[x, y, gene_idx] += row["gene_id"] - transcript_counts[x, y, gene_idx] += 1 - - # Track how many transcripts are in each pixel and how much padding we need - totals = transcript_counts.sum(axis=-1) - max_count = totals.max() - - # Calculate the log-likelihoods for each pixel having that many transcripts. - # Be clever about this to avoid recalculating lots of duplicate PMF values. - rate_loglike_uniques = poisson.logpmf( - np.arange(max_count + 1)[:, None], expression_rates[None] - ) - rate_loglikes = rate_loglike_uniques[totals.reshape(-1)].reshape( - totals.shape + (expression_rates.shape[0],) - ) - - # Calculate the log-likelihoods for each pixel generating a certain set of transcripts. - # NOTE: yes we could vectorize this, but that leads to memory issues. 
- expression_loglikes = np.zeros_like(rate_loglikes) - log_expression = np.log(expression_profiles) - total = expression_loglikes.shape[0] * expression_loglikes.shape[1] - pbar = tqdm.tqdm(total=total, desc="calculate_pixel_loglikes") - - for i in range(expression_loglikes.shape[0]): - for j in range(expression_loglikes.shape[1]): - expression_loglikes[i, j] = ( - transcript_counts[i, j, :, None] - * log_expression.T[transcript_ids[i, j]] - ).sum(axis=0) - pbar.update(1) - - # Add the two log-likelihoods together to get the pixelwise log-likelihood - return rate_loglikes, expression_loglikes - - -def cart2pol(x, y): - """Convert Cartesian coordinates to polar coordinates""" - rho = np.sqrt(x**2 + y**2) - phi = np.arctan2(y, x) - return (rho, phi) - - -def pol2cart(rho, phi): - x = rho * np.cos(phi) - y = rho * np.sin(phi) - return (x, y) - - -def load_nuclei(nuclei_file: str): - nuclei_geo_df = read_boundaries_into_polygons(nuclei_file) - nuclei_geo_df["nucleus_label"] = np.arange(1, nuclei_geo_df.shape[0] + 1) - nuclei_geo_df["nucleus_centroid"] = nuclei_geo_df["geometry"].centroid - nuclei_geo_df["nucleus_centroid_x"] = nuclei_geo_df["geometry"].centroid.x - nuclei_geo_df["nucleus_centroid_y"] = nuclei_geo_df["geometry"].centroid.y - return nuclei_geo_df - - -def load_and_filter_transcripts(transcripts_file: str, min_qv=20.0): - transcripts_df = pd.read_parquet(transcripts_file) - transcripts_df.drop(columns=["nucleus_distance"], inplace=True) - - # Filter out controls and low quality transcripts - transcripts_df = transcripts_df[ - (transcripts_df["qv"] >= min_qv) - & (~transcripts_df["feature_name"].str.startswith("NegControlProbe_")) - & (~transcripts_df["feature_name"].str.startswith("antisense_")) - & (~transcripts_df["feature_name"].str.startswith("NegControlCodeword_")) - & (~transcripts_df["feature_name"].str.startswith("BLANK_")) - ] - - # Convert to a geopandas object - tx_geo_df = gpd.GeoDataFrame( - transcripts_df, - geometry=gpd.points_from_xy( - transcripts_df["x_location"], transcripts_df["y_location"] - ), - ) - - # Assign a unique integer ID to each gene - gene_ids = tx_geo_df["feature_name"].unique() - n_genes = len(gene_ids) - mapping = dict(zip(gene_ids, np.arange(len(gene_ids)))) - tx_geo_df["gene_id"] = tx_geo_df["feature_name"].apply(lambda x: mapping.get(x, 0)) - - return tx_geo_df, gene_ids - - -def create_pixel_geodf(x_max, y_max): - # Create the list of all pixels - grid_df = pd.DataFrame( - np.array(np.meshgrid(np.arange(x_max + 1), np.arange(y_max + 1))).T.reshape( - -1, 2 - ), - columns=["X", "Y"], - ) - - # Convert the xy locations to a geopandas data frame - idx_geo_df = gpd.GeoDataFrame( - grid_df, - geometry=gpd.points_from_xy(grid_df["X"], grid_df["Y"]), - ) - - return idx_geo_df - - -def plot_distribution_of_cell_types(cell_type_probs): - n_rows = cell_type_probs.shape[1] // 3 + int(cell_type_probs.shape[1] % 3 > 0) - n_cols = 3 - fig, axarr = plt.subplots(n_rows, n_cols, sharex=True) - for idx in range(cell_type_probs.shape[1]): - i, j = idx // 3, idx % 3 - axarr[i, j].hist( - cell_type_probs[np.argmax(cell_type_probs, axis=1) == idx, idx], bins=200 - ) - plt.show() - - -def spatial_as_sparse_arrays( - nuclei_file: str, - transcripts_file: str, - outdir: str, - pixel_stride=1, - min_qv=20.0, - foreground_nucleus_distance=1, - background_nucleus_distance=10, - background_transcript_distance=4, - background_pixel_transcripts=5, - tile_height=64, - tile_width=64, - tile_stride=48, -): - """Creates a list of sparse CSC arrays. 
-def spatial_as_sparse_arrays(
-    nuclei_file: str,
-    transcripts_file: str,
-    outdir: str,
-    pixel_stride=1,
-    min_qv=20.0,
-    foreground_nucleus_distance=1,
-    background_nucleus_distance=10,
-    background_transcript_distance=4,
-    background_pixel_transcripts=5,
-    tile_height=64,
-    tile_width=64,
-    tile_stride=48,
-):
-    """Creates a list of sparse CSC arrays.
-    First array is the nuclei mask.
-    All other arrays are the transcripts."""
-    # Load the nuclei boundaries and assign them unique integer IDs
-    logger.info("Loading nuclei boundaries")
-    nuclei_geo_df = load_nuclei(nuclei_file)
-
-    # Load the transcript locations
-    logger.info("Loading transcript locations")
-    tx_geo_df, gene_ids = load_and_filter_transcripts(transcripts_file, min_qv=min_qv)
-    n_genes = tx_geo_df["gene_id"].max() + 1
-
-    # Get the approx bounds of the image
-    x_max, y_max = int(tx_geo_df["x_location"].max() + 1), int(
-        tx_geo_df["y_location"].max() + 1
-    )
-    x_min, y_min = int(tx_geo_df["x_location"].min()), int(
-        tx_geo_df["y_location"].min()
-    )
-
-    logger.info("Creating pixel geometry dataframe")
-    # Create a dataframe with an entry for every pixel
-    idx_geo_df = create_pixel_geodf(x_max, y_max)
-
-    logger.info("Find the nearest nucleus to each pixel")
-    # Find the nearest nucleus to each pixel
-    labels_geo_df = gpd.sjoin_nearest(
-        idx_geo_df, nuclei_geo_df, how="left", distance_col="nucleus_distance"
-    )
-    labels_geo_df.rename(columns={"index_right": "nucleus_id_xenium"}, inplace=True)
-
-    # Calculate the nearest transcript neighbors
-
-    logger.info("Calculating the nearest transcript neighbors")
-    transcript_xy = np.array(
-        [tx_geo_df["x_location"].values, tx_geo_df["y_location"].values]
-    ).T
-    kdtree = KDTree(transcript_xy)
-
-    # Get the distance to the k'th nearest transcript
-    pixels_xy = np.array([labels_geo_df["X"].values, labels_geo_df["Y"].values]).T
-    labels_geo_df["transcript_distance"] = kdtree.query(
-        pixels_xy, k=background_pixel_transcripts + 1
-    )[0][:, -1]
-
-    # Assign pixels roughly on top of nuclei to belong to that nuclei label
-    pixel_labels = np.zeros(labels_geo_df.shape[0], dtype=int) - 1
-    nucleus_pixels = labels_geo_df["nucleus_distance"] <= foreground_nucleus_distance
-    pixel_labels[nucleus_pixels] = labels_geo_df["nucleus_label"][nucleus_pixels]
-
-    # Assign pixels to the background if they are far from nuclei and not near a dense region of transcripts
-    background_pixels = (
-        labels_geo_df["nucleus_distance"] > background_nucleus_distance
-    ) & (labels_geo_df["transcript_distance"] > background_transcript_distance)
-    pixel_labels[background_pixels] = 0
-
-    # Convert back over to the grid format
-    labels = np.zeros(
-        (labels_geo_df["X"].max() + 1, labels_geo_df["Y"].max() + 1), dtype=int
-    )
-    labels[labels_geo_df["X"], labels_geo_df["Y"]] = pixel_labels
-
-    # Create a nuclei x gene count matrix
-    tx_nuclei_geo_df = gpd.sjoin_nearest(
-        tx_geo_df, nuclei_geo_df, distance_col="nucleus_distance"
-    )
-    nuclei_count_geo_df = tx_nuclei_geo_df[
-        tx_nuclei_geo_df["nucleus_distance"] <= foreground_nucleus_distance
-    ]
-
-    # I think we have enough memory to just store this as a dense array
-    nuclei_count_matrix = np.zeros((nuclei_geo_df.shape[0] + 1, n_genes), dtype=int)
-    np.add.at(
-        nuclei_count_matrix,
-        (
-            nuclei_count_geo_df["nucleus_label"].values.astype(int),
-            nuclei_count_geo_df["gene_id"].values.astype(int),
-        ),
-        1,
-    )
-
-    # Assume for simplicity that it's a homogeneous poisson process for transcripts.
-    # Add up all the transcripts in each pixel.
-    tx_count_grid = np.zeros((x_max + 1, y_max + 1), dtype=int)
-    np.add.at(
-        tx_count_grid,
-        (
-            tx_geo_df["x_location"].values.astype(int),
-            tx_geo_df["y_location"].values.astype(int),
-        ),
-        1,
-    )
-
-    logger.info("Estimating cell types")
-    # Estimate the cell types
-    results = estimate_cell_types(nuclei_count_matrix)
-    best_k = 12
-
-    # Estimate the background rate
-    tx_background_mask = (
-        labels[
-            tx_geo_df["x_location"].values.astype(int),
-            tx_geo_df["y_location"].values.astype(int),
-        ]
-        == 0
-    )
-    background_probs = np.zeros(n_genes)
-    tx_geo_df_background = tx_geo_df[tx_background_mask]
-    for g in range(n_genes):
-        background_probs[g] = (tx_geo_df_background["gene_id"] == g).sum() + 1
-    background_probs = background_probs / background_probs.sum()
-    background_rate = tx_background_mask.sum() / (pixel_labels == 0).sum()
-
-    # Estimate the density of each cell type
-    cell_type_probs = results["cell_types"][best_k]
-    pixel_mask = pixel_labels > 0
-    labeled_pixels = labels_geo_df[pixel_mask]
-    X, Y, L = labeled_pixels["X"], labeled_pixels["Y"], pixel_labels[pixel_mask]
-    cell_type_rates = (cell_type_probs[L] * tx_count_grid[X, Y][:, None]).sum(
-        axis=0
-    ) / cell_type_probs[L].sum(axis=0)
-
-    # Combine the background and foreground classes
-    all_expression_profiles = np.vstack(
-        [background_probs, results["expression_profiles"][best_k]]
-    )
-    all_expression_rates = np.concatenate([[background_rate], cell_type_rates])
-
-    # Assign hard labels to nuclei
-    cell_type_labels = np.argmax(cell_type_probs, axis=1) + 1
-    pixel_types = np.copy(labels)
-    nuclei_mask = labels > 0
-    pixel_types[nuclei_mask] = cell_type_labels[labels[nuclei_mask]]
-
-    # Plot the results for a small region
-    # arr = np.array(pixel_types[2000:2400,2000:2400], dtype=float)
-    # for i,c in enumerate(np.random.choice(np.unique(arr[arr > 0]), replace=False, size=len(np.unique(arr[arr > 0])))):
-    #     if c <= 0:
-    #         continue
-    #     arr[arr == c] = i+1
-    # arr[arr == -1] = np.nan
-    # plt.imshow(arr, cmap='tab20b')
-    # plt.show()
-
-    # Get the log-likelihood of each pixel being generated by each mixture component
-    logger.info("Calculating pixel log-likelihoods")
-    rate_loglikes, expression_loglikes = calculate_pixel_loglikes(
-        tx_geo_df, tx_count_grid, all_expression_profiles, all_expression_rates
-    )
-
-    # # Plot the watershed results
-    # from skimage.segmentation import watershed
-    # watershed_segments = watershed(smooth_probs[...,0])
-    # arr = np.copy(watershed_segments)
-    # for i,c in enumerate(np.random.choice(np.unique(arr), replace=False, size=len(np.unique(arr)))):
-    #     arr[arr == c] = i+1
-    # plt.imshow(arr, cmap='tab20b')
-    # plt.show()
-
-    #### Calculate the direction that each nucleus pixel is pointing in
-    # nuclei_geo_df['nucleus_centroid'] = nuclei_geo_df['geometry'].centroid
-    # centroid_mapping = {row['nucleus_label']: (row['nucleus_centroid'],row['nucleus_centroid_x'],row['nucleus_centroid_y']) for _,row in nuclei_geo_df.iterrows()}
-    # labels_geo_df['nucleus_centroid'] = labels_geo_df['nucleus_label'].apply(lambda x: centroid_mapping.get(x)[0])
-    # labels_geo_df['nucleus_centroid_x'] = labels_geo_df['nucleus_label'].apply(lambda x: centroid_mapping.get(x)[1])
-    # labels_geo_df['nucleus_centroid_y'] = labels_geo_df['nucleus_label'].apply(lambda x: centroid_mapping.get(x)[2])
-
-    # Calculate the angle at which each pixel faces to point at its nearest nucleus centroid.
-    # Normalize it to be in [0,1]
-    labels_geo_df["nucleus_angle"] = (
-        cart2pol(
-            labels_geo_df["nucleus_centroid_x"].values - labels_geo_df["X"].values,
-            labels_geo_df["nucleus_centroid_y"].values - labels_geo_df["Y"].values,
-        )[1]
-        + np.pi
-    ) / (2 * np.pi)
-    angles = np.zeros(labels.shape)
-    angles[labels_geo_df["X"].values, labels_geo_df["Y"].values] = labels_geo_df[
-        "nucleus_angle"
-    ].values
-
-    # Create tiled images and labels from the giant image
-    image_id = 0
-    tile_locations = []
-    tx_local_counts = []
-    class_local_counts = []
-    n_classes = cell_type_probs.shape[-1]
-    if not os.path.exists(f"{outdir}/tiles/transcripts/"):
-        os.makedirs(f"{outdir}/tiles/transcripts/")
-    if not os.path.exists(f"{outdir}/tiles/labels/"):
-        os.makedirs(f"{outdir}/tiles/labels/")
-    if not os.path.exists(f"{outdir}/tiles/angles/"):
-        os.makedirs(f"{outdir}/tiles/angles/")
-    if not os.path.exists(f"{outdir}/tiles/classes/"):
-        os.makedirs(f"{outdir}/tiles/classes/")
-
-    n_tiles_total = (
-        np.arange(x_min, x_max + 1, tile_stride).shape[0]
-        * np.arange(y_min, y_max + 1, tile_stride).shape[0]
-    )
-
-    progress_bar = tqdm.tqdm(total=n_tiles_total, desc="Processing tiles")
-
-    logger.info("Creating tiles")
-
-    for x_start in np.arange(x_min, x_max + 1, tile_stride):
-        # Handle edge cases
-        x_start = min(x_start, x_max - tile_width)
-
-        # Filter transcripts and labels
-        tx_local_x = tx_geo_df[
-            tx_geo_df["x_location"].between(
-                x_start, x_start + tile_width, inclusive="left"
-            )
-        ]
-
-        for y_start in np.arange(y_min, y_max + 1, tile_stride):
-            # Handle edge cases
-            y_start = min(y_start, y_max - tile_height)
-
-            # Filter transcripts and labels
-            tx_local = tx_local_x[
-                tx_local_x["y_location"].between(
-                    y_start, y_start + tile_height, inclusive="left"
-                )
-            ]
-
-            # Save a numpy array of pixel x, pixel y, and gene ID
-            X = tx_local["x_location"].values.astype(int) - x_start
-            Y = tx_local["y_location"].values.astype(int) - y_start
-            G = tx_local["gene_id"].values.astype(int)
-            np.savez(f"{outdir}/tiles/transcripts/{image_id}", np.array([X, Y, G]).T)
-
-            # Save a numpy matrix with every entry being the ID of the cell, background (0), or unknown (-1)
-            labels_local = np.array(
-                labels[x_start : x_start + tile_width, y_start : y_start + tile_height]
-            )
-            local_ids = np.unique(labels_local)
-            local_ids = local_ids[local_ids > 0]
-            for i, c in enumerate(local_ids):
-                labels_local[labels_local == c] = i + 1
-            np.savez(f"{outdir}/tiles/labels/{image_id}", labels_local)
-
-            # Save a numpy matrix with every entry being the angle from this pixel to the centroid
-            # of the nucleus (in [0,1]) or unknown (-1)
-            angles_local = np.array(
-                angles[x_start : x_start + tile_width, y_start : y_start + tile_height]
-            )
-            angles_local[labels_local == -1] = -1
-            np.savez(f"{outdir}/tiles/angles/{image_id}", angles_local)
-
-            # Save a numpy matrix with every entry being the class ID (cell type) of the cell (1...K),
-            # background (0), or unknown (-1)
-            classes_local = np.array(
-                pixel_types[
-                    x_start : x_start + tile_width, y_start : y_start + tile_height
-                ]
-            )
-            np.savez(f"{outdir}/tiles/classes/{image_id}", classes_local)
-
-            # Track the original coordinates that this tile belongs to
-            tile_locations.append([x_start, y_start])
-
-            # Track how many transcripts this tile has
-            tx_local_counts.append(len(G))
-
-            # Track how many of each class label this tile has
-            temp = np.zeros(n_classes + 2, dtype=int)
-            uniques = np.unique(classes_local, return_counts=True)
-            logger.debug(f"uniques: {uniques}")
-            temp[uniques[0] + 1] = uniques[1]
-            class_local_counts.append(temp)
-
-            # Update the image filename ID
-            image_id += 1
-            progress_bar.update(1)
-
-    # Save the tile (x,y) locations
-    logger.info(f"Saving {outdir}/tiles/locations.npy")
-    np.save(f"{outdir}/tiles/locations.npy", np.array(tile_locations))
-    logger.info(f"Saving {outdir}/tiles/class_counts.npy")
-    np.save(f"{outdir}/tiles/class_counts.npy", class_local_counts)
-    logger.info(f"Saving {outdir}/tiles/transcript_counts.npy")
-    np.save(f"{outdir}/tiles/transcript_counts.npy", tx_local_counts)
-    logger.info(f"Saving {outdir}/tiles/gene_ids.npy")
-    np.save(f"{outdir}/tiles/gene_ids.npy", np.array(list(enumerate(gene_ids))))
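For intuition about the tiling loop above: tile origins advance by tile_stride=48 while tiles are tile_width=64 wide, so neighboring tiles overlap by 16 pixels, and the min(x_start, x_max - tile_width) clamp pins the final origins inside the image, which can emit duplicate edge tiles. A standalone sketch with made-up bounds (only the width/stride defaults come from the deleted function):

    import numpy as np

    # Hypothetical axis bounds; tile_width/tile_stride match the deleted defaults
    x_min, x_max = 0, 200
    tile_width, tile_stride = 64, 48

    starts = np.minimum(np.arange(x_min, x_max + 1, tile_stride), x_max - tile_width)
    print(starts)  # [  0  48  96 136 136]

    # Adjacent tiles overlap by tile_width - tile_stride = 16 pixels; the clamp
    # keeps every tile inside the bounds but repeats the last origin.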