Skip to content

Commit

Permalink
Merge pull request #15 from bento-platform/search_basic
Browse files Browse the repository at this point in the history
feat: search basic
  • Loading branch information
v-rocheleau authored Jan 7, 2025
2 parents 293d131 + b154ffe commit 39fbcd0
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 46 deletions.
2 changes: 1 addition & 1 deletion docker-compose.dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ services:
depends_on:
- tds-db
environment:
- BENTO_UID=1001
- BENTO_UID=${UID}
- DATABASE_URI=postgres://tds_user:tds_password@tds-db:5432/tds_db
- CORS_ORIGINS="*"
- BENTO_AUTHZ_SERVICE_URL=""
Expand Down
104 changes: 78 additions & 26 deletions transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, AsyncIterator, List, Tuple
from typing import Annotated, AsyncIterator, List, Tuple, Optional
import asyncpg
from bento_lib.db.pg_async import PgAsyncDatabase
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -111,9 +111,6 @@ async def create_gene_expressions(self, expressions: list[GeneExpression], trans
await transaction_conn.executemany(query, records)
self.logger.info(f"Inserted {len(records)} gene expression records.")

async def fetch_expressions(self) -> tuple[GeneExpression, ...]:
return tuple([r async for r in self._select_expressions(None)])

async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExpression]:
conn: asyncpg.Connection
where_clause = "WHERE experiment_result_id = $1" if exp_id is not None else ""
Expand All @@ -123,28 +120,6 @@ async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExp
for r in map(lambda g: self._deserialize_gene_expression(g), res):
yield r

async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]:
"""
Fetch gene expressions for a specific experiment_result_id.
"""
conn: asyncpg.Connection
async with self.connect() as conn:
query = """
SELECT * FROM gene_expressions WHERE experiment_result_id = $1
"""
res = await conn.fetch(query, experiment_result_id)
return tuple([self._deserialize_gene_expression(record) for record in res])

async def fetch_gene_expressions(
self, experiments: list[str], method: str = "raw", paginate: bool = False
) -> Tuple[Tuple[GeneExpression, ...], int]:
if not experiments:
return (), 0
# TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination
experiment_result_id = experiments[0]
expressions = await self.fetch_gene_expressions_by_experiment_id(experiment_result_id)
return expressions, len(expressions)

def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
return GeneExpression(
gene_code=rec["gene_code"],
Expand Down Expand Up @@ -219,6 +194,83 @@ async def transaction_connection(self):
# operations must be made using this connection for the transaction to apply
yield conn

async def fetch_gene_expressions(
self,
genes: Optional[List[str]] = None,
experiments: Optional[List[str]] = None,
sample_ids: Optional[List[str]] = None,
method: str = "raw",
page: int = 1,
page_size: int = 100,
paginate: bool = True,
) -> Tuple[List[GeneExpression], int]:
"""
Fetch gene expressions based on genes, experiments, sample_ids, and method, with optional pagination.
Returns a tuple of (expressions list, total_records count).
"""
conn: asyncpg.Connection
async with self.connect() as conn:
# Query builder
base_query = """
SELECT gene_code, sample_id, experiment_result_id, raw_count, tpm_count, tmm_count, getmm_count
FROM gene_expressions
"""
count_query = "SELECT COUNT(*) FROM gene_expressions"
conditions = []
params = []
param_counter = 1

if genes:
conditions.append(f"gene_code = ANY(${param_counter}::text[])")
params.append(genes)
param_counter += 1

if experiments:
conditions.append(f"experiment_result_id = ANY(${param_counter}::text[])")
params.append(experiments)
param_counter += 1

if sample_ids:
conditions.append(f"sample_id = ANY(${param_counter}::text[])")
params.append(sample_ids)
param_counter += 1

if method != "raw":
conditions.append(f"{method}_count IS NOT NULL")

where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

order_clause = " ORDER BY gene_code, sample_id"

query = base_query + where_clause + order_clause
count_query += where_clause

# Pagination
if paginate:
limit_offset_clause = f" LIMIT ${param_counter} OFFSET ${param_counter + 1}"
params.extend([page_size, (page - 1) * page_size])
query += limit_offset_clause

total_records_params = params[:-2] if paginate else params
total_records = await conn.fetchval(count_query, *total_records_params)

res = await conn.fetch(query, *params)

expressions = [
GeneExpression(
gene_code=record["gene_code"],
sample_id=record["sample_id"],
experiment_result_id=record["experiment_result_id"],
raw_count=record["raw_count"],
tpm_count=record["tpm_count"],
tmm_count=record["tmm_count"],
getmm_count=record["getmm_count"],
)
for record in res
]

return expressions, total_records


@lru_cache()
def get_db(config: ConfigDependency, logger: LoggerDependency) -> Database:
Expand Down
4 changes: 2 additions & 2 deletions transcriptomics_data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from transcriptomics_data_service.db import get_db
from transcriptomics_data_service.routers.experiment_results import experiment_router
from transcriptomics_data_service.routers.expressions import expression_router
from transcriptomics_data_service.routers.ingest import ingest_router
from transcriptomics_data_service.routers.normalization import normalization_router
from transcriptomics_data_service.routers.expressions import expressions_router
from . import __version__
from .config import get_config
from .constants import BENTO_SERVICE_KIND, SERVICE_TYPE
Expand Down Expand Up @@ -42,7 +42,7 @@ async def lifespan(_app: FastAPI):
lifespan=lifespan,
)

app.include_router(expression_router)
app.include_router(ingest_router)
app.include_router(experiment_router)
app.include_router(normalization_router)
app.include_router(expressions_router)
63 changes: 53 additions & 10 deletions transcriptomics_data_service/models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,65 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field, validator, field_validator
from typing import List, Optional
from enum import Enum

__all__ = [
"ExperimentResult",
"GeneExpression",
"GeneExpressionData",
"PaginationMeta",
"GeneExpressionResponse",
"MethodEnum",
"ExpressionQueryBody",
]


class MethodEnum(str, Enum):
raw = "raw"
tpm = "tpm"
tmm = "tmm"
getmm = "getmm"


class PaginatedRequest(BaseModel):
page: int = Field(1, ge=1, description="Current page number")
page_size: int = Field(100, ge=1, le=1000, description="Number of records per page")


class PaginatedResponse(PaginatedRequest):
total_records: int = Field(..., ge=0, description="Total number of records")
total_pages: int = Field(..., ge=1, description="Total number of pages")


class ExperimentResult(BaseModel):
experiment_result_id: str
assembly_id: str | None = None
assembly_name: str | None = None
experiment_result_id: str = Field(..., min_length=1, max_length=255)
assembly_id: Optional[str] = Field(None, max_length=255)
assembly_name: Optional[str] = Field(None, max_length=255)


class GeneExpression(BaseModel):
gene_code: str
sample_id: str
experiment_result_id: str
gene_code: str = Field(..., min_length=1, max_length=255)
sample_id: str = Field(..., min_length=1, max_length=255)
experiment_result_id: str = Field(..., min_length=1, max_length=255)
raw_count: int
tpm_count: float | None = None
tmm_count: float | None = None
getmm_count: float | None = None
tpm_count: Optional[float] = None
tmm_count: Optional[float] = None
getmm_count: Optional[float] = None


class GeneExpressionData(BaseModel):
gene_code: str = Field(..., min_length=1, max_length=255, description="Gene code")
sample_id: str = Field(..., min_length=1, max_length=255, description="Sample ID")
experiment_result_id: str = Field(..., min_length=1, max_length=255, description="Experiment result ID")
count: float = Field(..., description="Expression count")


class ExpressionQueryBody(PaginatedRequest):
genes: Optional[List[str]] = Field(None, description="List of gene codes to retrieve")
experiments: Optional[List[str]] = Field(None, description="List of experiment result IDs to retrieve data from")
sample_ids: Optional[List[str]] = Field(None, description="List of sample IDs to retrieve data from")
method: MethodEnum = Field(MethodEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'")


class GeneExpressionResponse(PaginatedResponse):
query: ExpressionQueryBody = Field(..., description="The query that produced this response")
expressions: List[GeneExpressionData] = Field(..., description="List of gene expressions")
89 changes: 83 additions & 6 deletions transcriptomics_data_service/routers/expressions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,89 @@
from fastapi import APIRouter, status
from fastapi import APIRouter, HTTPException, status, Query

from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.logger import LoggerDependency
from transcriptomics_data_service.models import (
GeneExpressionData,
GeneExpressionResponse,
ExpressionQueryBody,
)

__all__ = ["expression_router"]
expressions_router = APIRouter(prefix="/expressions")

expression_router = APIRouter(prefix="/expressions")

async def get_expressions_handler(
query_body: ExpressionQueryBody,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Handler for fetching and returning gene expression data.
"""
logger.info(f"Received query parameters: {query_body}")

@expression_router.get("", status_code=status.HTTP_200_OK)
async def expressions_list(db: DatabaseDependency):
return await db.fetch_expressions()
expressions, total_records = await db.fetch_gene_expressions(
genes=query_body.genes,
experiments=query_body.experiments,
sample_ids=query_body.sample_ids,
method=query_body.method.value,
page=query_body.page,
page_size=query_body.page_size,
)

if not expressions:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No gene expression data found for the given parameters.",
)

method = query_body.method.value
method_count_field = f"{method}_count"
response_data = [
GeneExpressionData(
gene_code=expr.gene_code,
sample_id=expr.sample_id,
experiment_result_id=expr.experiment_result_id,
count=getattr(expr, method_count_field),
)
for expr in expressions
]

total_pages = (total_records + query_body.page_size - 1) // query_body.page_size

return GeneExpressionResponse(
# pagination
page=query_body.page,
page_size=query_body.page_size,
total_records=total_records,
total_pages=total_pages,
# data
expressions=response_data,
query=query_body,
)


@expressions_router.post(
"",
status_code=status.HTTP_200_OK,
response_model=GeneExpressionResponse,
)
async def get_expressions_post(
params: ExpressionQueryBody,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Retrieve gene expression data via POST request.
Using POST instead of GET in order to add a body of type ExpressionQueryBody
Example JSON body:
{
"genes": ["gene1", "gene2"],
"experiments": ["exp1"],
"sample_ids": ["sample1"],
"method": "tmm",
"page": 1,
"page_size": 100
}
"""
return await get_expressions_handler(params, db, logger)
4 changes: 3 additions & 1 deletion transcriptomics_data_service/scripts/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru
return normalized_data


def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1):
def getmm_normalization(
counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1
):
"""Perform GeTMM normalization on counts data."""
counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
rpk = counts_df.mul(scaling_factor).div(gene_lengths, axis=0)
Expand Down

0 comments on commit 39fbcd0

Please sign in to comment.