Parallel normalization #14

Open · wants to merge 7 commits into base: normalization

Changes from all commits
27 changes: 12 additions & 15 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ jsonschema = "^4.21.1"
 pydantic-settings = "^2.1.0"
 asyncpg = "^0.29.0"
 pandas = "^2.2.3"
-conorm = "^1.2.0"
+joblib = "^1.4.2"

 [tool.poetry.group.dev.dependencies]
 aioresponses = "^0.7.6"
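Context for the dependency swap: the rewritten normalize.py below parallelizes per-sample work with joblib's Parallel/delayed pair instead of calling conorm. A minimal sketch of that joblib pattern (the square function and inputs are illustrative, not from the PR):

from joblib import Parallel, delayed

def square(x):
    return x * x

# Run square over the inputs on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(delayed(square)(x) for x in range(8))
# results == [0, 1, 4, 9, 16, 25, 36, 49]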
10 changes: 10 additions & 0 deletions transcriptomics_data_service/db.py
@@ -135,6 +135,16 @@ async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: st
         res = await conn.fetch(query, experiment_result_id)
         return tuple([self._deserialize_gene_expression(record) for record in res])

+    async def fetch_gene_expressions(
+        self, experiments: list[str], method: str = "raw", paginate: bool = False
+    ) -> Tuple[Tuple[GeneExpression, ...], int]:
+        if not experiments:
+            return (), 0
+        # TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination
+        experiment_result_id = experiments[0]
+        expressions = await self.fetch_gene_expressions_by_experiment_id(experiment_result_id)
+        return expressions, len(expressions)
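For illustration only, a sketch of how a caller consumes the new method from inside an async endpoint (the experiment ID is hypothetical; db is the service's injected database instance):

expressions, total = await db.fetch_gene_expressions(
    experiments=["exp-123"],  # hypothetical experiment_result_id
    method="raw",
    paginate=False,
)
# expressions: tuple of GeneExpression; total: their count (pagination is still a TODO)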

     def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
         return GeneExpression(
             gene_code=rec["gene_code"],
45 changes: 28 additions & 17 deletions transcriptomics_data_service/routers/normalization.py
@@ -5,7 +5,7 @@
 from transcriptomics_data_service.db import DatabaseDependency
 from transcriptomics_data_service.models import GeneExpression
 from transcriptomics_data_service.scripts.normalize import (
-    read_counts2tpm,
+    tpm_normalization,
     tmm_normalization,
     getmm_normalization,
 )
@@ -36,35 +36,38 @@ async def normalize(
     """
     Normalize gene expressions using the specified method for a given experiment_result_id.
     """
-    # method validation
-    if method not in VALID_METHODS:
+    # Method validation
+    if method.lower() not in VALID_METHODS:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}"
         )

-    # load gene lengths
-    if method in [NORM_TPM, NORM_GETMM]:
+    # Load gene lengths if required
+    if method.lower() in [NORM_TPM, NORM_GETMM]:
         if gene_lengths_file is None:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=f"Gene lengths file is required for {method.upper()} normalization.",
             )
         gene_lengths = await _load_gene_lengths(gene_lengths_file)
+    else:
+        gene_lengths = None

+    # Fetch raw counts from the database
     raw_counts_df = await _fetch_raw_counts(db, experiment_result_id)

-    # normalization
-    if method == NORM_TPM:
+    # Perform normalization
+    if method.lower() == NORM_TPM:
         raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
-        normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series)
-    elif method == NORM_TMM:
+        normalized_df = tpm_normalization(raw_counts_df, gene_lengths_series)
+    elif method.lower() == NORM_TMM:
         normalized_df = tmm_normalization(raw_counts_df)
-    elif method == NORM_GETMM:
+    elif method.lower() == NORM_GETMM:
         raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
         normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series)

-    # database update using normalized values
-    await _update_normalized_values(db, normalized_df, experiment_result_id, method=method)
+    # Update database with normalized values
+    await _update_normalized_values(db, normalized_df, experiment_result_id, method=method.lower())

     return {"message": f"{method.upper()} normalization completed successfully"}

@@ -74,8 +77,13 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series:
     Load gene lengths from the uploaded file.
     """
     content = await gene_lengths_file.read()
-    gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID")
-    gene_lengths_series = gene_lengths_df["GeneLength"]
+    gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col=0)
+    if gene_lengths_df.shape[1] != 1:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Gene lengths file should contain exactly one column of gene lengths.",
+        )
+    gene_lengths_series = gene_lengths_df.iloc[:, 0]
     gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise")
     return gene_lengths_series
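For illustration, the relaxed parser now keys on column position rather than the GeneID/GeneLength header names: with index_col=0 the first CSV column becomes the index, and the single remaining column is taken as the lengths. A toy file (gene IDs and headers are made up):

import pandas as pd
from io import StringIO

content = "gene,length\nENSG0001,1500\nENSG0002,2300\n"
df = pd.read_csv(StringIO(content), index_col=0)  # first column becomes the index
lengths = df.iloc[:, 0]                           # numeric Series indexed by gene ID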

@@ -85,7 +93,7 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:
     Fetch raw counts from the database for the given experiment_result_id.
     Returns a DataFrame with genes as rows and samples as columns.
     """
-    expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
+    expressions, _ = await db.fetch_gene_expressions(experiments=[experiment_result_id], method="raw", paginate=False)
     if not expressions:
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.")
@@ -116,10 +124,12 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series):

 async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str):
     """
-    Update the normalized values in the database
+    Update the normalized values in the database.
     """
     # Fetch existing expressions to get raw_count values
-    existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
+    existing_expressions, _ = await db.fetch_gene_expressions(
+        experiments=[experiment_result_id], method="raw", paginate=False
+    )
     raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions}

     normalized_df = normalized_df.reset_index().melt(

@@ -138,6 +148,7 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_
                 detail=f"Raw count not found for gene {gene_code}, sample {sample_id}",
             )

+        # Create a GeneExpression object with the normalized value
         gene_expression = GeneExpression(
             gene_code=gene_code,
             sample_id=sample_id,
188 changes: 129 additions & 59 deletions transcriptomics_data_service/scripts/normalize.py
@@ -1,72 +1,142 @@
 import pandas as pd
 import numpy as np
+from joblib import Parallel, delayed


-def read_counts2tpm(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3):
-    """
-    Convert raw read counts to TPM (Transcripts Per Million).
+def filter_counts(counts_df):
+    """Filter out genes (rows) and samples (columns) with zero total counts."""
+    row_filter = counts_df.sum(axis=1) > 0
+    col_filter = counts_df.sum(axis=0) > 0
+    return counts_df.loc[row_filter, col_filter]
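A toy run of the new filter (data invented; assumes the helpers are importable from transcriptomics_data_service.scripts.normalize):

import pandas as pd

counts = pd.DataFrame({"s1": [10, 0], "s2": [5, 0]}, index=["g1", "g2"])
# g2 has zero counts in every sample, so its row is dropped.
filtered = filter_counts(counts)  # keeps only g1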

-    Parameters:
-    counts_df (DataFrame): DataFrame with genes as rows and samples as columns.
-    gene_lengths (Series): Series with gene lengths, index matches counts_df.index.
-    scale_library (int or float): Scaling factor for library size normalization (default 1e6).
-    scale_length (int or float): Scaling factor for gene length scaling (default 1e3).
-
-    Returns:
-    DataFrame: TPM-normalized values.
-    """
-    # Ensure counts_df and gene_lengths are aligned
+def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None):
+    """Align counts and gene_lengths, drop zeros, and optionally scale gene lengths."""
     counts_df = counts_df.loc[gene_lengths.index]
+    valid_lengths = gene_lengths.replace(0, pd.NA).dropna()
+    counts_df = counts_df.loc[valid_lengths.index]
+    gene_lengths = valid_lengths
+    if scale_length is not None:
+        gene_lengths = gene_lengths / scale_length
+    return filter_counts(counts_df), gene_lengths
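And a sketch of the alignment helper on similar toy data (values invented):

import pandas as pd

counts = pd.DataFrame({"s1": [10, 20], "s2": [5, 8]}, index=["g1", "g2"])
lengths = pd.Series({"g1": 2000, "g2": 0})  # g2 has length 0 and is dropped
counts_f, kb = prepare_counts_and_lengths(counts, lengths, scale_length=1e3)
# counts_f keeps only g1; kb == pd.Series({"g1": 2.0}), lengths now in kilobases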

-    # Scale gene lengths
-    gene_lengths_scaled = gene_lengths / scale_length
-
-    # Calculate Reads Per Scaled Kilobase (RPK)
-    rpk = counts_df.div(gene_lengths_scaled, axis=0)
+def parallel_apply(columns, func, n_jobs=-1):
+    """Apply a function to each column in parallel and combine results."""
+    results = Parallel(n_jobs=n_jobs)(delayed(func)(col) for col in columns)
+    return pd.concat(results, axis=1)
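For instance (toy frame; func returns one Series per column, which pd.concat stitches back into a DataFrame):

import pandas as pd

df = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]})

def double(col):
    return df[col] * 2

doubled = parallel_apply(df.columns, double, n_jobs=1)
# doubled equals df * 2, reassembled column by column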

-    # Calculate scaling factors
-    scaling_factors = rpk.sum(axis=0) / scale_library
-
-    # Calculate TPM
-    tpm = rpk.div(scaling_factors, axis=1)
-
-    return tpm
+def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim):
+    """Perform log ratio and sum trimming."""
+    n = len(log_ratio)
+    loL = int(np.floor(n * logratio_trim / 2))
+    hiL = n - loL
+    lr_order = np.argsort(log_ratio)
+    trimmed_idx = lr_order[loL:hiL]
+
+    lr_t = log_ratio[trimmed_idx]
+    w_t = w[trimmed_idx]
+    mean_t = log_mean[trimmed_idx]
+
+    n_t = len(mean_t)
+    loS = int(np.floor(n_t * sum_trim / 2))
+    hiS = n_t - loS
+    mean_order = np.argsort(mean_t)
+    final_idx = mean_order[loS:hiS]
+
+    return lr_t[final_idx], w_t[final_idx]
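To make the double trim concrete: with logratio_trim=0.3 the most extreme 15% of M-values (log ratios) are cut from each tail, and the survivors are then trimmed by A-value (log mean) with sum_trim=0.05, i.e. 2.5% per tail. A numeric sketch with invented values:

import numpy as np

log_ratio = np.array([-3.0, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 4.0])  # M-values
log_mean = np.linspace(-10, -1, 10)                                          # A-values
w = np.ones(10)
lr_kept, w_kept = trim_values(log_ratio, log_mean, w, logratio_trim=0.3, sum_trim=0.05)
# n=10: floor(10*0.3/2)=1 cut per tail on M (drops -3.0 and 4.0);
# then floor(8*0.05/2)=0 cut per tail on A, leaving 8 values.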


+def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
+    """Compute TMM normalization factors for counts data."""
+    lib_sizes = counts_df.sum(axis=0)
+    median_lib = lib_sizes.median()
+    ref_sample = (lib_sizes - median_lib).abs().idxmin()
+
+    ref_counts = counts_df[ref_sample].values
+    sample_names = counts_df.columns
+    data_values = counts_df.values
+
+    norm_factors = pd.Series(index=sample_names, dtype="float64")
+    norm_factors[ref_sample] = 1.0
+
+    def compute_norm_factor(sample):
+        if sample == ref_sample:
+            return sample, 1.0
+
+        i = sample_names.get_loc(sample)
+        data_i = data_values[:, i]
+
+        mask = (data_i > 0) & (ref_counts > 0)
+        data_i_masked = data_i[mask]
+        data_r_masked = ref_counts[mask]
+
+        N_i = data_i_masked.sum()
+        N_r = data_r_masked.sum()
+
+        data_i_norm = data_i_masked / N_i
+        data_r_norm = data_r_masked / N_r
+
+        log_ratio = np.log2(data_i_norm) - np.log2(data_r_norm)
+        log_mean = 0.5 * (np.log2(data_i_norm) + np.log2(data_r_norm))
+
+        w = 1.0 / (data_i_norm + data_r_norm) if weighting else np.ones_like(log_ratio)
+
+        lr_final, w_final = trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim)
+
+        mean_M = np.sum(w_final * lr_final) / np.sum(w_final)
+        norm_factor = 2**mean_M
+        return sample, norm_factor

-def tmm_normalization(counts_df):
-    """
-    Perform TMM normalization on counts data.
-
-    Parameters:
-    counts_df (DataFrame): DataFrame with genes as rows and samples as columns.
-
-    Returns:
-    DataFrame: TMM-normalized values.
-    """
-    try:
-        import conorm
-    except ImportError:
-        raise ImportError("The 'conorm' package is required for this function but is not installed.")
-    normalized_array = conorm.tmm(counts_df)
-    normalized_df = pd.DataFrame(normalized_array, columns=counts_df.columns, index=counts_df.index)
-    return normalized_df
-
-
-def getmm_normalization(counts_df, gene_lengths):
-    """
-    Perform GeTMM normalization on counts data.
-
-    Parameters:
-    counts_df (DataFrame): DataFrame with genes as rows and samples as columns.
-    gene_lengths (Series): Series with gene lengths, index matches counts_df.index.
-
-    Returns:
-    DataFrame: GeTMM-normalized values.
-    """
-    try:
-        import conorm
-    except ImportError:
-        raise ImportError("The 'conorm' package is required for this function but is not installed.")
-
-    normalized_array = conorm.getmm(counts_df, gene_lengths)
-    normalized_df = pd.DataFrame(normalized_array, columns=counts_df.columns, index=counts_df.index)
-    return normalized_df
+    samples = [s for s in sample_names if s != ref_sample]
+    results = Parallel(n_jobs=n_jobs)(delayed(compute_norm_factor)(s) for s in samples)
+
+    for sample, nf in results:
+        norm_factors[sample] = nf
+
+    norm_factors = norm_factors / np.exp(np.mean(np.log(norm_factors)))
+    return norm_factors
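A usage sketch on toy counts (values invented; n_jobs=1 keeps it single-process). The closing line of the function centers the factors so their geometric mean is 1, which we can assert:

import numpy as np
import pandas as pd

counts = pd.DataFrame(
    {"s1": [100, 200, 300], "s2": [50, 100, 150], "s3": [10, 400, 90]},
    index=["g1", "g2", "g3"],
)
factors = compute_TMM_normalization_factors(counts, n_jobs=1)
# factors is a pd.Series indexed by sample name, centered to geometric mean 1.
assert np.isclose(np.exp(np.mean(np.log(factors))), 1.0)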


+def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
+    """Perform TMM normalization on counts data."""
+    counts_df = filter_counts(counts_df)
+    norm_factors = compute_TMM_normalization_factors(counts_df, logratio_trim, sum_trim, weighting, n_jobs)
+    lib_sizes = counts_df.sum(axis=0)
+    normalized_data = counts_df.div(lib_sizes, axis=1).div(norm_factors, axis=1) * lib_sizes.mean()
+    return normalized_data
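Continuing with the same toy counts, the full TMM call is then:

normalized = tmm_normalization(counts, n_jobs=1)
# Same shape as counts: each column is divided by its library size and its
# TMM factor, then rescaled by the mean library size.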


+def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1):
+    """Perform GeTMM normalization on counts data."""
+    counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
+    rpk = counts_df.mul(scaling_factor).div(gene_lengths, axis=0)
+    return tmm_normalization(rpk, logratio_trim, sum_trim, weighting, n_jobs)


+def compute_rpk(counts_df, gene_lengths_scaled, n_jobs=-1):
+    """Compute RPK values in parallel."""
+    columns = counts_df.columns
+
+    def rpk_col(col):
+        return counts_df[col] / gene_lengths_scaled
+
+    rpk = parallel_apply(columns, rpk_col, n_jobs)
+    rpk.columns = columns
+    return rpk


+def tpm_normalization(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3, n_jobs=-1):
+    """Convert raw read counts to TPM in parallel."""
+    counts_df, gene_lengths_scaled = prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=scale_length)
+    rpk = compute_rpk(counts_df, gene_lengths_scaled, n_jobs)
+    scaling_factors = rpk.sum(axis=0).replace(0, pd.NA)
+    scaling_factors_norm = scaling_factors / scale_library
+
+    def tpm_col(col):
+        return rpk[col] / scaling_factors_norm[col]
+
+    tpm = parallel_apply(rpk.columns, tpm_col, n_jobs)
+    tpm.columns = rpk.columns
+    return tpm
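An end-to-end sketch of the TPM path (toy data again; gene lengths in bases). A useful sanity check is that every TPM column sums to scale_library:

import numpy as np
import pandas as pd

counts = pd.DataFrame({"s1": [100, 200], "s2": [400, 100]}, index=["g1", "g2"])
lengths = pd.Series({"g1": 1000, "g2": 2000})
tpm = tpm_normalization(counts, lengths, n_jobs=1)
assert np.allclose(tpm.sum(axis=0), 1e6)  # each sample sums to one million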