diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py index a507208..d845ce7 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -5,7 +5,7 @@ from transcriptomics_data_service.db import DatabaseDependency from transcriptomics_data_service.models import GeneExpression from transcriptomics_data_service.scripts.normalize import ( - read_counts2tpm, + tpm_normalization, tmm_normalization, getmm_normalization, ) @@ -36,35 +36,38 @@ async def normalize( """ Normalize gene expressions using the specified method for a given experiment_result_id. """ - # method validation - if method not in VALID_METHODS: + # Method validation + if method.lower() not in VALID_METHODS: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}" ) - # load gene lengths - if method in [NORM_TPM, NORM_GETMM]: + # Load gene lengths if required + if method.lower() in [NORM_TPM, NORM_GETMM]: if gene_lengths_file is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Gene lengths file is required for {method.upper()} normalization.", ) gene_lengths = await _load_gene_lengths(gene_lengths_file) + else: + gene_lengths = None + # Fetch raw counts from the database raw_counts_df = await _fetch_raw_counts(db, experiment_result_id) - # normalization - if method == NORM_TPM: + # Perform normalization + if method.lower() == NORM_TPM: raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) - normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series) - elif method == NORM_TMM: + normalized_df = tpm_normalization(raw_counts_df, gene_lengths_series) + elif method.lower() == NORM_TMM: normalized_df = tmm_normalization(raw_counts_df) - elif method == NORM_GETMM: + elif method.lower() == NORM_GETMM: raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series) - # database update using normalized values - await _update_normalized_values(db, normalized_df, experiment_result_id, method=method) + # Update database with normalized values + await _update_normalized_values(db, normalized_df, experiment_result_id, method=method.lower()) return {"message": f"{method.upper()} normalization completed successfully"} @@ -74,8 +77,13 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series: Load gene lengths from the uploaded file. """ content = await gene_lengths_file.read() - gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID") - gene_lengths_series = gene_lengths_df["GeneLength"] + gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col=0) + if gene_lengths_df.shape[1] != 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Gene lengths file should contain exactly one column of gene lengths.", + ) + gene_lengths_series = gene_lengths_df.iloc[:, 0] gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise") return gene_lengths_series @@ -85,7 +93,9 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: Fetch raw counts from the database for the given experiment_result_id. Returns a DataFrame with genes as rows and samples as columns. """ - expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id) + expressions, _ = await db.fetch_gene_expressions( + experiments=[experiment_result_id], method="raw", paginate=False + ) if not expressions: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.") @@ -116,10 +126,12 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series): async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str): """ - Update the normalized values in the database + Update the normalized values in the database. """ # Fetch existing expressions to get raw_count values - existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id) + existing_expressions, _ = await db.fetch_gene_expressions( + experiments=[experiment_result_id], method="raw", paginate=False + ) raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions} normalized_df = normalized_df.reset_index().melt( @@ -138,6 +150,7 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_ detail=f"Raw count not found for gene {gene_code}, sample {sample_id}", ) + # Create a GeneExpression object with the normalized value gene_expression = GeneExpression( gene_code=gene_code, sample_id=sample_id,