Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search basic #15

Merged
merged 20 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 78 additions & 26 deletions transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, AsyncIterator, List, Tuple
from typing import Annotated, AsyncIterator, List, Tuple, Optional
import asyncpg
from bento_lib.db.pg_async import PgAsyncDatabase
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -111,9 +111,6 @@ async def create_gene_expressions(self, expressions: list[GeneExpression], trans
await transaction_conn.executemany(query, records)
self.logger.info(f"Inserted {len(records)} gene expression records.")

async def fetch_expressions(self) -> tuple[GeneExpression, ...]:
return tuple([r async for r in self._select_expressions(None)])

async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExpression]:
conn: asyncpg.Connection
where_clause = "WHERE experiment_result_id = $1" if exp_id is not None else ""
Expand All @@ -123,28 +120,6 @@ async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExp
for r in map(lambda g: self._deserialize_gene_expression(g), res):
yield r

async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]:
"""
Fetch gene expressions for a specific experiment_result_id.
"""
conn: asyncpg.Connection
async with self.connect() as conn:
query = """
SELECT * FROM gene_expressions WHERE experiment_result_id = $1
"""
res = await conn.fetch(query, experiment_result_id)
return tuple([self._deserialize_gene_expression(record) for record in res])

async def fetch_gene_expressions(
self, experiments: list[str], method: str = "raw", paginate: bool = False
) -> Tuple[Tuple[GeneExpression, ...], int]:
if not experiments:
return (), 0
# TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination
experiment_result_id = experiments[0]
expressions = await self.fetch_gene_expressions_by_experiment_id(experiment_result_id)
return expressions, len(expressions)

def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
return GeneExpression(
gene_code=rec["gene_code"],
Expand Down Expand Up @@ -219,6 +194,83 @@ async def transaction_connection(self):
# operations must be made using this connection for the transaction to apply
yield conn

async def fetch_gene_expressions(
self,
genes: Optional[List[str]] = None,
experiments: Optional[List[str]] = None,
sample_ids: Optional[List[str]] = None,
method: str = "raw",
page: int = 1,
page_size: int = 100,
paginate: bool = True,
) -> Tuple[List[GeneExpression], int]:
"""
Fetch gene expressions based on genes, experiments, sample_ids, and method, with optional pagination.
Returns a tuple of (expressions list, total_records count).
"""
conn: asyncpg.Connection
async with self.connect() as conn:
# Query builder
base_query = """
SELECT gene_code, sample_id, experiment_result_id, raw_count, tpm_count, tmm_count, getmm_count
FROM gene_expressions
"""
count_query = "SELECT COUNT(*) FROM gene_expressions"
conditions = []
params = []
param_counter = 1

if genes:
conditions.append(f"gene_code = ANY(${param_counter}::text[])")
params.append(genes)
param_counter += 1

if experiments:
conditions.append(f"experiment_result_id = ANY(${param_counter}::text[])")
params.append(experiments)
param_counter += 1

if sample_ids:
conditions.append(f"sample_id = ANY(${param_counter}::text[])")
params.append(sample_ids)
param_counter += 1

if method != "raw":
conditions.append(f"{method}_count IS NOT NULL")

where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

order_clause = " ORDER BY gene_code, sample_id"

query = base_query + where_clause + order_clause
count_query += where_clause

# Pagination
if paginate:
limit_offset_clause = f" LIMIT ${param_counter} OFFSET ${param_counter + 1}"
params.extend([page_size, (page - 1) * page_size])
query += limit_offset_clause

total_records_params = params[:-2] if paginate else params
total_records = await conn.fetchval(count_query, *total_records_params)

res = await conn.fetch(query, *params)

expressions = [
GeneExpression(
gene_code=record["gene_code"],
sample_id=record["sample_id"],
experiment_result_id=record["experiment_result_id"],
raw_count=record["raw_count"],
tpm_count=record["tpm_count"],
tmm_count=record["tmm_count"],
getmm_count=record["getmm_count"],
)
for record in res
]

return expressions, total_records


@lru_cache()
def get_db(config: ConfigDependency, logger: LoggerDependency) -> Database:
Expand Down
4 changes: 2 additions & 2 deletions transcriptomics_data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from transcriptomics_data_service.db import get_db
from transcriptomics_data_service.routers.experiment_results import experiment_router
from transcriptomics_data_service.routers.expressions import expression_router
from transcriptomics_data_service.routers.ingest import ingest_router
from transcriptomics_data_service.routers.normalization import normalization_router
from transcriptomics_data_service.routers.query import query_router
from . import __version__
from .config import get_config
from .constants import BENTO_SERVICE_KIND, SERVICE_TYPE
Expand Down Expand Up @@ -42,7 +42,7 @@ async def lifespan(_app: FastAPI):
lifespan=lifespan,
)

app.include_router(expression_router)
app.include_router(ingest_router)
app.include_router(experiment_router)
app.include_router(normalization_router)
app.include_router(query_router)
80 changes: 70 additions & 10 deletions transcriptomics_data_service/models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,82 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from enum import Enum

__all__ = [
"ExperimentResult",
"GeneExpression",
"GeneExpressionData",
"PaginationMeta",
"GeneExpressionResponse",
"MethodEnum",
"QueryParameters",
]


class ExperimentResult(BaseModel):
experiment_result_id: str
assembly_id: str | None = None
assembly_name: str | None = None
experiment_result_id: str = Field(..., min_length=1, max_length=255)
assembly_id: Optional[str] = Field(None, max_length=255)
assembly_name: Optional[str] = Field(None, max_length=255)


class GeneExpression(BaseModel):
gene_code: str
sample_id: str
experiment_result_id: str
gene_code: str = Field(..., min_length=1, max_length=255)
sample_id: str = Field(..., min_length=1, max_length=255)
experiment_result_id: str = Field(..., min_length=1, max_length=255)
raw_count: int
tpm_count: float | None = None
tmm_count: float | None = None
getmm_count: float | None = None
tpm_count: Optional[float] = None
tmm_count: Optional[float] = None
getmm_count: Optional[float] = None


class GeneExpressionData(BaseModel):
gene_code: str = Field(..., min_length=1, max_length=255, description="Gene code")
noctillion marked this conversation as resolved.
Show resolved Hide resolved
sample_id: str = Field(..., min_length=1, max_length=255, description="Sample ID")
experiment_result_id: str = Field(..., min_length=1, max_length=255, description="Experiment result ID")
count: float = Field(..., description="Expression count")
method: str = Field(..., description="Method used to calculate the expression count")

noctillion marked this conversation as resolved.
Show resolved Hide resolved

class PaginationMeta(BaseModel):
total_records: int = Field(..., ge=0, description="Total number of records")
page: int = Field(..., ge=1, description="Current page number")
page_size: int = Field(..., ge=1, le=1000, description="Number of records per page")
total_pages: int = Field(..., ge=1, description="Total number of pages")


class GeneExpressionResponse(BaseModel):
expressions: List[GeneExpressionData]
pagination: PaginationMeta


class MethodEnum(str, Enum):
raw = "raw"
tpm = "tpm"
tmm = "tmm"
getmm = "getmm"


class QueryParameters(BaseModel):
genes: Optional[List[str]] = Field(None, description="List of gene codes to retrieve")
noctillion marked this conversation as resolved.
Show resolved Hide resolved
experiments: Optional[List[str]] = Field(None, description="List of experiment result IDs to retrieve data from")
sample_ids: Optional[List[str]] = Field(None, description="List of sample IDs to retrieve data from")
method: MethodEnum = Field(MethodEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'")
page: int = Field(
1,
ge=1,
description="Page number for pagination (must be greater than or equal to 1)",
)
page_size: int = Field(
100,
ge=1,
le=1000,
description="Number of records per page (between 1 and 1000)",
)

noctillion marked this conversation as resolved.
Show resolved Hide resolved
@validator("genes", "experiments", "sample_ids", each_item=True)
def validate_identifiers(cls, value):
if not (1 <= len(value) <= 255):
raise ValueError("Each identifier must be between 1 and 255 characters long.")
if not value.replace("_", "").isalnum():
raise ValueError("Identifiers must contain only alphanumeric characters and underscores.")
return value
88 changes: 88 additions & 0 deletions transcriptomics_data_service/routers/query.py
v-rocheleau marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from fastapi import APIRouter, HTTPException, status, Query

from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.logger import LoggerDependency
from transcriptomics_data_service.models import (
GeneExpressionData,
GeneExpressionResponse,
PaginationMeta,
MethodEnum,
QueryParameters,
)

query_router = APIRouter(prefix="/query")


async def get_expressions_handler(
params: QueryParameters,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Handler for fetching and returning gene expression data.
"""
logger.info(f"Received query parameters: {params}")

expressions, total_records = await db.fetch_gene_expressions(
genes=params.genes,
experiments=params.experiments,
sample_ids=params.sample_ids,
method=params.method.value,
page=params.page,
page_size=params.page_size,
)

if not expressions:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No gene expression data found for the given parameters.",
)

response_data = []
method_field = f"{params.method.value}_count" if params.method != MethodEnum.raw else "raw_count"
for expr in expressions:
count = getattr(expr, method_field)
response_item = GeneExpressionData(
gene_code=expr.gene_code,
sample_id=expr.sample_id,
experiment_result_id=expr.experiment_result_id,
count=count,
method=method_field,
)
response_data.append(response_item)

total_pages = (total_records + params.page_size - 1) // params.page_size
pagination_meta = PaginationMeta(
total_records=total_records,
page=params.page,
page_size=params.page_size,
total_pages=total_pages,
)

return GeneExpressionResponse(expressions=response_data, pagination=pagination_meta)


@query_router.post(
"/expressions",
status_code=status.HTTP_200_OK,
response_model=GeneExpressionResponse,
)
async def get_expressions_post(
params: QueryParameters,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Retrieve gene expression data via POST request.

Example JSON body:
{
"genes": ["gene1", "gene2"],
"experiments": ["exp1"],
"sample_ids": ["sample1"],
"method": "tmm",
"page": 1,
"page_size": 100
}
"""
return await get_expressions_handler(params, db, logger)
Loading