Skip to content

Commit

Permalink
refactor normalization process
Browse files (browse the repository at this point in the history)
  • Loading branch information
noctillion committed Dec 12, 2024
1 parent 9e9e5c4 commit 8c8fd2c
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions transcriptomics_data_service/routers/normalization.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,7 @@
from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.models import GeneExpression
from transcriptomics_data_service.scripts.normalize import (
read_counts2tpm,
tpm_normalization,
tmm_normalization,
getmm_normalization,
)
Expand Down Expand Up @@ -36,35 +36,38 @@ async def normalize(
"""
Normalize gene expressions using the specified method for a given experiment_result_id.
"""
# method validation
if method not in VALID_METHODS:
# Method validation
if method.lower() not in VALID_METHODS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}"
)

# load gene lengths
if method in [NORM_TPM, NORM_GETMM]:
# Load gene lengths if required
if method.lower() in [NORM_TPM, NORM_GETMM]:
if gene_lengths_file is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Gene lengths file is required for {method.upper()} normalization.",
)
gene_lengths = await _load_gene_lengths(gene_lengths_file)
else:
gene_lengths = None

# Fetch raw counts from the database
raw_counts_df = await _fetch_raw_counts(db, experiment_result_id)

# normalization
if method == NORM_TPM:
# Perform normalization
if method.lower() == NORM_TPM:
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series)
elif method == NORM_TMM:
normalized_df = tpm_normalization(raw_counts_df, gene_lengths_series)
elif method.lower() == NORM_TMM:
normalized_df = tmm_normalization(raw_counts_df)
elif method == NORM_GETMM:
elif method.lower() == NORM_GETMM:
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series)

# database update using normalized values
await _update_normalized_values(db, normalized_df, experiment_result_id, method=method)
# Update database with normalized values
await _update_normalized_values(db, normalized_df, experiment_result_id, method=method.lower())

return {"message": f"{method.upper()} normalization completed successfully"}

Expand All @@ -74,8 +77,13 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series:
Load gene lengths from the uploaded file.
"""
content = await gene_lengths_file.read()
gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID")
gene_lengths_series = gene_lengths_df["GeneLength"]
gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col=0)
if gene_lengths_df.shape[1] != 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Gene lengths file should contain exactly one column of gene lengths.",
)
gene_lengths_series = gene_lengths_df.iloc[:, 0]
gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise")
return gene_lengths_series

Expand All @@ -85,7 +93,9 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:
Fetch raw counts from the database for the given experiment_result_id.
Returns a DataFrame with genes as rows and samples as columns.
"""
expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
expressions, _ = await db.fetch_gene_expressions(
experiments=[experiment_result_id], method="raw", paginate=False
)
if not expressions:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.")

Expand Down Expand Up @@ -116,10 +126,12 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series):

async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str):
"""
Update the normalized values in the database
Update the normalized values in the database.
"""
# Fetch existing expressions to get raw_count values
existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
existing_expressions, _ = await db.fetch_gene_expressions(
experiments=[experiment_result_id], method="raw", paginate=False
)
raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions}

normalized_df = normalized_df.reset_index().melt(
Expand All @@ -138,6 +150,7 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_
detail=f"Raw count not found for gene {gene_code}, sample {sample_id}",
)

# Create a GeneExpression object with the normalized value
gene_expression = GeneExpression(
gene_code=gene_code,
sample_id=sample_id,
Expand Down

0 comments on commit 8c8fd2c

Please sign in to comment.