Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search basic #15

Open
wants to merge 14 commits into
base: paralell_normalization
Choose a base branch
from
104 changes: 78 additions & 26 deletions transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, AsyncIterator, List, Tuple
from typing import Annotated, AsyncIterator, List, Tuple, Optional
import asyncpg
from bento_lib.db.pg_async import PgAsyncDatabase
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -111,9 +111,6 @@ async def create_gene_expressions(self, expressions: list[GeneExpression], trans
await transaction_conn.executemany(query, records)
self.logger.info(f"Inserted {len(records)} gene expression records.")

async def fetch_expressions(self) -> tuple[GeneExpression, ...]:
return tuple([r async for r in self._select_expressions(None)])

async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExpression]:
conn: asyncpg.Connection
where_clause = "WHERE experiment_result_id = $1" if exp_id is not None else ""
Expand All @@ -123,28 +120,6 @@ async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExp
for r in map(lambda g: self._deserialize_gene_expression(g), res):
yield r

async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]:
"""
Fetch gene expressions for a specific experiment_result_id.
"""
conn: asyncpg.Connection
async with self.connect() as conn:
query = """
SELECT * FROM gene_expressions WHERE experiment_result_id = $1
"""
res = await conn.fetch(query, experiment_result_id)
return tuple([self._deserialize_gene_expression(record) for record in res])

async def fetch_gene_expressions(
self, experiments: list[str], method: str = "raw", paginate: bool = False
) -> Tuple[Tuple[GeneExpression, ...], int]:
if not experiments:
return (), 0
# TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination
experiment_result_id = experiments[0]
expressions = await self.fetch_gene_expressions_by_experiment_id(experiment_result_id)
return expressions, len(expressions)

def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
return GeneExpression(
gene_code=rec["gene_code"],
Expand Down Expand Up @@ -219,6 +194,83 @@ async def transaction_connection(self):
# operations must be made using this connection for the transaction to apply
yield conn

async def fetch_gene_expressions(
self,
genes: Optional[List[str]] = None,
experiments: Optional[List[str]] = None,
sample_ids: Optional[List[str]] = None,
method: str = "raw",
page: int = 1,
page_size: int = 100,
paginate: bool = True,
) -> Tuple[List[GeneExpression], int]:
"""
Fetch gene expressions based on genes, experiments, sample_ids, and method, with optional pagination.
Returns a tuple of (expressions list, total_records count).
"""
conn: asyncpg.Connection
async with self.connect() as conn:
# Query builder
base_query = """
SELECT gene_code, sample_id, experiment_result_id, raw_count, tpm_count, tmm_count, getmm_count
FROM gene_expressions
"""
count_query = "SELECT COUNT(*) FROM gene_expressions"
conditions = []
params = []
param_counter = 1

if genes:
conditions.append(f"gene_code = ANY(${param_counter}::text[])")
params.append(genes)
param_counter += 1

if experiments:
conditions.append(f"experiment_result_id = ANY(${param_counter}::text[])")
params.append(experiments)
param_counter += 1

if sample_ids:
conditions.append(f"sample_id = ANY(${param_counter}::text[])")
params.append(sample_ids)
param_counter += 1

if method != "raw":
conditions.append(f"{method}_count IS NOT NULL")

where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

order_clause = " ORDER BY gene_code, sample_id"

query = base_query + where_clause + order_clause
count_query += where_clause

# Pagination
if paginate:
limit_offset_clause = f" LIMIT ${param_counter} OFFSET ${param_counter + 1}"
params.extend([page_size, (page - 1) * page_size])
query += limit_offset_clause

total_records_params = params[:-2] if paginate else params
total_records = await conn.fetchval(count_query, *total_records_params)

res = await conn.fetch(query, *params)

expressions = [
GeneExpression(
gene_code=record["gene_code"],
sample_id=record["sample_id"],
experiment_result_id=record["experiment_result_id"],
raw_count=record["raw_count"],
tpm_count=record["tpm_count"],
tmm_count=record["tmm_count"],
getmm_count=record["getmm_count"],
)
for record in res
]

return expressions, total_records


@lru_cache()
def get_db(config: ConfigDependency, logger: LoggerDependency) -> Database:
Expand Down
4 changes: 2 additions & 2 deletions transcriptomics_data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from transcriptomics_data_service.db import get_db
from transcriptomics_data_service.routers.experiment_results import experiment_router
from transcriptomics_data_service.routers.expressions import expression_router
from transcriptomics_data_service.routers.ingest import ingest_router
from transcriptomics_data_service.routers.normalization import normalization_router
from transcriptomics_data_service.routers.query import query_router
from . import __version__
from .config import get_config
from .constants import BENTO_SERVICE_KIND, SERVICE_TYPE
Expand Down Expand Up @@ -42,7 +42,7 @@ async def lifespan(_app: FastAPI):
lifespan=lifespan,
)

app.include_router(expression_router)
app.include_router(ingest_router)
app.include_router(experiment_router)
app.include_router(normalization_router)
app.include_router(query_router)
71 changes: 61 additions & 10 deletions transcriptomics_data_service/models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,73 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from enum import Enum

__all__ = [
"ExperimentResult",
"GeneExpression",
"GeneExpressionData",
"PaginationMeta",
"GeneExpressionResponse",
"MethodEnum",
"QueryParameters",
]


class PaginatedRequest(BaseModel):
page: int = Field(1, ge=1, description="Current page number")
page_size: int = Field(100, ge=1, le=1000, description="Number of records per page")


class PaginatedResponse(PaginatedRequest):
total_records: int = Field(..., ge=0, description="Total number of records")
total_pages: int = Field(..., ge=1, description="Total number of pages")


class ExperimentResult(BaseModel):
experiment_result_id: str
assembly_id: str | None = None
assembly_name: str | None = None
experiment_result_id: str = Field(..., min_length=1, max_length=255)
assembly_id: Optional[str] = Field(None, max_length=255)
assembly_name: Optional[str] = Field(None, max_length=255)


class GeneExpression(BaseModel):
gene_code: str
sample_id: str
experiment_result_id: str
gene_code: str = Field(..., min_length=1, max_length=255)
sample_id: str = Field(..., min_length=1, max_length=255)
experiment_result_id: str = Field(..., min_length=1, max_length=255)
raw_count: int
tpm_count: float | None = None
tmm_count: float | None = None
getmm_count: float | None = None
tpm_count: Optional[float] = None
tmm_count: Optional[float] = None
getmm_count: Optional[float] = None


class GeneExpressionData(BaseModel):
gene_code: str = Field(..., min_length=1, max_length=255, description="Gene code")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was there a specific reason to only include one count?

sample_id: str = Field(..., min_length=1, max_length=255, description="Sample ID")
experiment_result_id: str = Field(..., min_length=1, max_length=255, description="Experiment result ID")
count: float = Field(..., description="Expression count")
method: str = Field(..., description="Method used to calculate the expression count")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

method: MethodEnum = Field(MethodEnum, ...)


class GeneExpressionResponse(PaginatedResponse):
expressions: List[GeneExpressionData]


class MethodEnum(str, Enum):
raw = "raw"
tpm = "tpm"
tmm = "tmm"
getmm = "getmm"


class QueryParameters(PaginatedRequest):
genes: Optional[List[str]] = Field(None, description="List of gene codes to retrieve")
experiments: Optional[List[str]] = Field(None, description="List of experiment result IDs to retrieve data from")
sample_ids: Optional[List[str]] = Field(None, description="List of sample IDs to retrieve data from")
method: MethodEnum = Field(MethodEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'")

@validator("genes", "experiments", "sample_ids", each_item=True)
def validate_identifiers(cls, value):
if not (1 <= len(value) <= 255):
raise ValueError("Each identifier must be between 1 and 255 characters long.")
if not value.replace("_", "").isalnum():
raise ValueError("Identifiers must contain only alphanumeric characters and underscores.")
return value
87 changes: 87 additions & 0 deletions transcriptomics_data_service/routers/query.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in the expression router instead, with no /query prefix

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be in the expressions router, since it is has the /expressions path.

Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from fastapi import APIRouter, HTTPException, status, Query

from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.logger import LoggerDependency
from transcriptomics_data_service.models import (
GeneExpressionData,
GeneExpressionResponse,
MethodEnum,
QueryParameters,
)

query_router = APIRouter()


async def get_expressions_handler(
params: QueryParameters,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Handler for fetching and returning gene expression data.
"""
logger.info(f"Received query parameters: {params}")

expressions, total_records = await db.fetch_gene_expressions(
genes=params.genes,
experiments=params.experiments,
sample_ids=params.sample_ids,
method=params.method.value,
page=params.page,
page_size=params.page_size,
)

if not expressions:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No gene expression data found for the given parameters.",
)

response_data = []
method_field = f"{params.method.value}_count" if params.method != MethodEnum.raw else "raw_count"
for expr in expressions:
count = getattr(expr, method_field)
response_item = GeneExpressionData(
gene_code=expr.gene_code,
sample_id=expr.sample_id,
experiment_result_id=expr.experiment_result_id,
count=count,
method=method_field,
)
response_data.append(response_item)

total_pages = (total_records + params.page_size - 1) // params.page_size

return GeneExpressionResponse(
expressions=response_data,
total_records=total_records,
page=params.page,
page_size=params.page_size,
total_pages=total_pages,
)


@query_router.post(
"/expressions",
status_code=status.HTTP_200_OK,
response_model=GeneExpressionResponse,
)
async def get_expressions_post(
params: QueryParameters,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Retrieve gene expression data via POST request.

Example JSON body:
{
"genes": ["gene1", "gene2"],
"experiments": ["exp1"],
"sample_ids": ["sample1"],
"method": "tmm",
"page": 1,
"page_size": 100
}
"""
return await get_expressions_handler(params, db, logger)
4 changes: 3 additions & 1 deletion transcriptomics_data_service/scripts/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru
return normalized_data


def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1):
def getmm_normalization(
counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1
):
"""Perform GeTMM normalization on counts data."""
counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
rpk = counts_df.mul(scaling_factor).div(gene_lengths, axis=0)
Expand Down
Loading