Merge pull request #3 from tansey-lab/jq_improve_downsampling
Improve Downsampling / Tile Preprocessing
jeffquinn-msk authored Feb 12, 2024
2 parents 26ef3aa + 1add7f3 commit 69efc90
Showing 14 changed files with 1,122 additions and 923 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -36,7 +36,9 @@ dependencies = [
"matplotlib",
"autograd-minimize",
"scikit-learn",
"pyarrow"
"pyarrow",
"h5py",
"blended-tiling"
]

[project.optional-dependencies]
155 changes: 155 additions & 0 deletions src/nuc2seg/celltyping.py
@@ -0,0 +1,155 @@
import numpy as np
import tqdm
from scipy.special import logsumexp, softmax

from nuc2seg.xenium import logger


def aic_bic(gene_counts, expression_profiles, prior_probs):

n_components = expression_profiles.shape[0]
n_genes = expression_profiles.shape[1]
n_samples = gene_counts.shape[0]

dof = n_components * (n_genes - 1) + n_components - 1
log_likelihood = logsumexp(
np.log(prior_probs[None])
+ (gene_counts[:, None] * np.log(expression_profiles[None])).sum(axis=2),
axis=1,
).sum()
aic = -2 * log_likelihood + 2 * dof
bic = -2 * log_likelihood + dof * np.log(n_samples)

return aic, bic


## TODO: Maybe replace this with a library that does the same thing
def estimate_cell_types(
gene_counts,
min_components=2,
max_components=25,
max_em_steps=100,
tol=1e-4,
warm_start=False,
):

n_nuclei, n_genes = gene_counts.shape

# Randomly initialize cell type profiles
cur_expression_profiles = np.random.dirichlet(np.ones(n_genes), size=min_components)

# Initialize probabilities to be uniform
cur_prior_probs = np.ones(min_components) / min_components

# No need to initialize cell type assignments
cur_cell_types = None

# Track BIC and AIC scores
aic_scores = np.zeros(max_components - min_components + 1)
bic_scores = np.zeros(max_components - min_components + 1)

    # Accumulators for the fitted parameters at each number of components
    final_expression_profiles, final_prior_probs, final_cell_types = [], [], []
    prev_expression_profiles = np.zeros_like(cur_expression_profiles)

    # Iterate through every possible number of components
for idx, n_components in enumerate(
tqdm.trange(
min_components, max_components + 1, desc="estimate_cell_types", position=0
)
):

# Warm start from the previous iteration
if warm_start and n_components > min_components:
# Expand and copy over the current parameters
next_prior_probs = np.zeros(n_components)
next_expression_profiles = np.zeros((n_components, n_genes))
next_prior_probs[: n_components - 1] = cur_prior_probs
next_expression_profiles[: n_components - 1] = cur_expression_profiles

# Split the dominant cluster
dominant = np.argmax(cur_prior_probs)
split_prob = np.random.random()
next_prior_probs[-1] = next_prior_probs[dominant] * split_prob
next_prior_probs[dominant] = next_prior_probs[dominant] * (1 - split_prob)
next_expression_profiles[-1] = cur_expression_profiles[
dominant
] * split_prob + (1 - split_prob) * np.random.dirichlet(np.ones(n_genes))

cur_prior_probs = next_prior_probs
cur_expression_profiles = next_expression_profiles

            logger.debug("priors: %s", cur_prior_probs)
            logger.debug("expression: %s", cur_expression_profiles)
else:
# Cold start from a random location
cur_expression_profiles = np.random.dirichlet(
np.ones(n_genes), size=n_components
)
cur_prior_probs = np.ones(n_components) / n_components

converge = tol + 1
for step in tqdm.trange(
max_em_steps,
desc=f"EM for n_components {n_components}",
unit="step",
position=1,
):

# E-step: estimate cell type assignments (posterior probabilities)
logits = np.log(cur_prior_probs[None]) + (
gene_counts[:, None] * np.log(cur_expression_profiles[None])
).sum(axis=2)
cur_cell_types = softmax(logits, axis=1)

# M-step (part 1): estimate cell type profiles
prev_expression_profiles = np.array(cur_expression_profiles)
cur_expression_profiles = (
cur_cell_types[..., None] * gene_counts[:, None]
).sum(axis=0) + 1
cur_expression_profiles = (
cur_expression_profiles / (cur_cell_types.sum(axis=0) + 1)[:, None]
)

# M-step (part 2): estimate cell type probabilities
cur_prior_probs = (cur_cell_types.sum(axis=0) + 1) / (
cur_cell_types.shape[0] + n_components
)
logger.debug(f"cur_prior_probs: {cur_prior_probs}")

# Track convergence of the cell type profiles
converge = np.linalg.norm(
prev_expression_profiles - cur_expression_profiles
) / np.linalg.norm(prev_expression_profiles)

logger.debug(f"Convergence: {converge:.4f}")
if converge <= tol:
                logger.debug("Stopping early.")
break

# Save the results
final_expression_profiles.append(cur_expression_profiles)
final_cell_types.append(cur_cell_types)
final_prior_probs.append(cur_prior_probs)

# Calculate BIC and AIC
aic_scores[idx], bic_scores[idx] = aic_bic(
gene_counts, cur_expression_profiles, cur_prior_probs
)

logger.debug(f"K={n_components}")
logger.debug(f"AIC: {aic_scores[idx]:.4f}")
logger.debug(f"BIC: {bic_scores[idx]:.4f}")

return {
"bic": bic_scores,
"aic": aic_scores,
"expression_profiles": final_expression_profiles,
"prior_probs": final_prior_probs,
"cell_types": final_cell_types,
}
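
In symbols, aic_bic scores a K-component multinomial mixture over G genes and n nuclei: with mixture weights \pi_k, expression profiles p_{kg}, and per-nucleus counts x_{ig},

\log L = \sum_{i=1}^{n} \log \sum_{k=1}^{K} \exp\Big( \log \pi_k + \sum_{g=1}^{G} x_{ig} \log p_{kg} \Big), \qquad \mathrm{dof} = K(G - 1) + (K - 1),

\mathrm{AIC} = -2 \log L + 2\,\mathrm{dof}, \qquad \mathrm{BIC} = -2 \log L + \mathrm{dof} \cdot \log n.

The degrees of freedom match the code: K profiles on the (G - 1)-simplex plus K - 1 free mixture weights.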
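
A minimal smoke test for the sweep above, assuming the module is importable as nuc2seg.celltyping (the synthetic counts and sizes here are illustrative only):

import numpy as np

from nuc2seg.celltyping import estimate_cell_types

rng = np.random.default_rng(0)
gene_counts = rng.poisson(2.0, size=(200, 50))  # 200 nuclei x 50 genes, toy data

results = estimate_cell_types(gene_counts, min_components=2, max_components=5)
best = int(np.argmin(results["bic"]))  # index into the K = 2..5 sweep
print("Best K by BIC:", best + 2)
print("Posterior shape:", results["cell_types"][best].shape)  # (200, best K)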
73 changes: 44 additions & 29 deletions src/nuc2seg/cli/preprocess.py
@@ -1,9 +1,15 @@
import argparse
import logging
import numpy as np
import pandas

from nuc2seg import log_config
from nuc2seg.xenium_utils import spatial_as_sparse_arrays
from nuc2seg.xenium import (
load_nuclei,
load_and_filter_transcripts,
create_shapely_rectangle,
)
from nuc2seg.preprocessing import create_rasterized_dataset

logger = logging.getLogger(__name__)

@@ -26,8 +32,8 @@ def get_parser():
required=True,
)
parser.add_argument(
"--output-dir",
help="Output directory.",
"--output",
help="Output path.",
type=str,
required=True,
)
@@ -38,10 +44,10 @@
default=0,
)
parser.add_argument(
"--pixel-stride",
help="Stride for the pixel grid.",
type=int,
default=1,
"--resolution",
help="Size of a pixel in microns for rasterization.",
type=float,
default=1.0,
)
parser.add_argument(
"--min-qv",
@@ -74,22 +80,10 @@
default=5,
)
parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-stride",
help="Stride of the tiles.",
type=int,
default=48,
"--sample-area",
default=None,
type=str,
        help='Crop the dataset to this rectangle, provided in "x1,x2,y1,y2" format.',
)
return parser

@@ -109,17 +103,38 @@ def main():

prng = np.random.default_rng(args.seed)

spatial_as_sparse_arrays(
if args.sample_area:
sample_area = create_shapely_rectangle(
*[float(x) for x in args.sample_area.split(",")]
)

else:
df = pandas.read_parquet(args.transcripts_file)
y_max = df["y_location"].max()
x_max = df["x_location"].max()

sample_area = create_shapely_rectangle(0, 0, x_max, y_max)

nuclei_geo_df = load_nuclei(
nuclei_file=args.nuclei_file,
sample_area=sample_area,
)

tx_geo_df = load_and_filter_transcripts(
transcripts_file=args.transcripts_file,
outdir=args.output_dir,
pixel_stride=args.pixel_stride,
sample_area=sample_area,
min_qv=args.min_qv,
)

ds = create_rasterized_dataset(
nuclei_geo_df=nuclei_geo_df,
tx_geo_df=tx_geo_df,
sample_area=sample_area,
resolution=args.resolution,
foreground_nucleus_distance=args.foreground_nucleus_distance,
background_nucleus_distance=args.background_nucleus_distance,
background_pixel_transcripts=args.background_pixel_transcripts,
background_transcript_distance=args.background_transcript_distance,
tile_width=args.tile_width,
tile_height=args.tile_height,
tile_stride=args.tile_stride,
)

ds.save_h5(args.output)
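
create_shapely_rectangle itself is not part of this diff; a plausible sketch, assuming it wraps shapely.geometry.box using the (min-x, min-y, max-x, max-y) order implied by the bounding-box call above. Note that the --sample-area help advertises "x1,x2,y1,y2", so this argument order is an assumption worth verifying against the real helper:

from shapely.geometry import box


def create_shapely_rectangle(x1, y1, x2, y2):
    # Axis-aligned rectangle spanning [x1, x2] x [y1, y2] in microns.
    return box(x1, y1, x2, y2)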
53 changes: 42 additions & 11 deletions src/nuc2seg/cli/train.py
@@ -6,6 +6,7 @@
from nuc2seg import log_config
from nuc2seg.train import train
from nuc2seg.unet_model import SparseUNet
from nuc2seg.data import Nuc2SegDataset, TiledDataset

logger = logging.getLogger(__name__)

@@ -16,8 +17,8 @@ def get_parser():
)
log_config.add_logging_args(parser)
parser.add_argument(
"--preprocessed-tiles-dir",
help="Directory containing preprocessed tiles.",
"--dataset",
help="Path to dataset in h5 format.",
type=str,
required=True,
)
@@ -27,12 +28,6 @@
type=str,
required=True,
)
parser.add_argument(
"--n-classes",
help="Number of classes to segment.",
type=int,
required=True,
)
parser.add_argument(
"--seed",
help="Seed to use for PRNG.",
@@ -112,7 +107,30 @@
default="cpu",
choices=["cpu", "cuda"],
)

parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--overlap-percentage",
        help="Fraction (0-1) of each tile dimension that overlaps the adjacent tile.",
type=float,
default=0.25,
)
parser.add_argument(
"--num-dataloader-workers",
help="Number of workers to use for the data loader.",
type=int,
default=0,
)
return parser


@@ -130,12 +148,24 @@ def main():
log_config.configure_logging(args)

np.random.seed(args.seed)
model = SparseUNet(600, args.n_classes + 2, (64, 64))

logger.info(f"Loading dataset from {args.dataset}")

ds = Nuc2SegDataset.load_h5(args.dataset)

tiled_dataset = TiledDataset(
ds,
tile_height=args.tile_height,
tile_width=args.tile_width,
tile_overlap=args.overlap_percentage,
)

model = SparseUNet(600, ds.n_classes + 2, (64, 64))

train(
model,
device=args.device,
tiles_dir=args.preprocessed_tiles_dir,
dataset=tiled_dataset,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
Expand All @@ -147,6 +177,7 @@ def main():
gradient_clipping=args.gradient_clipping,
max_workers=args.max_workers,
validation_frequency=args.validation_frequency,
num_dataloader_workers=args.num_dataloader_workers,
)

logger.info(f"Saving model weights to {args.model_weights_output}")
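
How the new flags fit together, assuming TiledDataset behaves like a standard torch.utils.data.Dataset (the path and batch size are illustrative):

from torch.utils.data import DataLoader

from nuc2seg.data import Nuc2SegDataset, TiledDataset

ds = Nuc2SegDataset.load_h5("preprocessed.h5")  # file written by the preprocess CLI
tiled = TiledDataset(ds, tile_height=64, tile_width=64, tile_overlap=0.25)
loader = DataLoader(tiled, batch_size=4, shuffle=True, num_workers=0)  # mirrors --num-dataloader-workers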