Skip to content

Commit

Permalink
Merge pull request #10 from sphinxbio/nls/refactor-input
Browse files Browse the repository at this point in the history
Nls/refactor input
  • Loading branch information
nlarusstone authored Oct 30, 2023
2 parents d6c2136 + 0432ea9 commit 1bbffac
Show file tree
Hide file tree
Showing 6 changed files with 363 additions and 220 deletions.
406 changes: 232 additions & 174 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[tool.poetry]
name = "platechain"
version = "0.0.2"
version = "0.0.4"
description = "A library of universal parsers for microplate data"
authors = ["Nicholas Larus-Stone <[email protected]>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8"
langchain = "^0.0.279"
pandas = "^2.1.0"
python = "^3.8.1"
langchain = ">=0.0.313,<0.1"
pandas = "^2.0.1"
python-dotenv = "^1.0.0"
openai = "^0.28.0"
tabulate = "^0.9.0"
Expand Down
161 changes: 120 additions & 41 deletions src/platechain/chain.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import json

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.schema.output_parser import StrOutputParser
import pandas as pd
from pydantic import BaseModel, Field
from platechain.constants import COLS_TO_WELLS_DICT, ROWS_TO_WELLS_DICT

from platechain.prompts import (
AI_REPONSE_DICT,
Expand All @@ -12,7 +16,7 @@
create_prompt,
)
from platechain.utils import (
get_plate_dimensions,
pluck_plate_from_df,
parse_llm_output,
tidy_rectangular_plate_data,
)
Expand All @@ -29,35 +33,115 @@
("human", "{input}"),
],
)


class ParsePlateRequest(BaseModel):
    """Validated input bundle for the plate-parsing chain.

    ``num_plates``/``num_rows``/``num_cols`` are optional layout hints;
    ``None`` means "unknown" and downstream helpers fall back to defaults
    (e.g. a 96-well layout in ``_get_json_format``).
    """

    # Raw tabular data containing one or more plates. Assumed to have a
    # numeric index (see _load_df) -- TODO confirm at call sites.
    df: pd.DataFrame
    num_plates: int | None
    num_rows: int | None
    num_cols: int | None

    class Config:
        # Needed to allow pandas dataframes as a type
        arbitrary_types_allowed = True


def _load_df(request: ParsePlateRequest):
"""
Assumes the dataframe has a numeric index
"""
return request.df.to_csv(header=False)


def _load_prompt(request: ParsePlateRequest):
    """Build the layout-hint prompt from the request's optional dimensions."""
    layout_hints = {
        "num_plates": request.num_plates,
        "num_rows": request.num_rows,
        "num_cols": request.num_cols,
    }
    return create_prompt(**layout_hints)


def _get_col_range_str(request: ParsePlateRequest):
if request.num_cols:
return f"from 1 to {request.num_cols}"
else:
return ""


def _get_json_format(request: ParsePlateRequest):
"""
Defaults to a 96-well plate example if no num_rows or num_cols are provided
"""
num_rows = request.num_rows or 8
num_cols = request.num_cols or 12
row_start = 10
col_start = 1
return json.dumps(
[
{
"row_start": row_start,
"row_end": row_start + num_rows - 1,
"col_start": col_start,
"col_end": col_start + num_cols - 1,
"contents": "Entity ID",
}
]
)


def _get_user_example(request: ParsePlateRequest):
    """Pick the few-shot *user* example matching the requested plate shape.

    Defaults to the 96-well example when neither dimension is given. When only
    one dimension is given, it must be a standard value so the total well
    count can be inferred from it.
    """
    rows, cols = request.num_rows, request.num_cols

    if rows is None and cols is None:
        return USER_EXAMPLE_DICT[96]

    if rows is None:
        assert (
            cols in COLS_TO_WELLS_DICT.keys()
        ), f"If num_rows is not provided, num_cols must be a standard value: {COLS_TO_WELLS_DICT.keys()}"  # noqa: E501
        return USER_EXAMPLE_DICT[COLS_TO_WELLS_DICT[cols]]

    if cols is None:
        assert (
            rows in ROWS_TO_WELLS_DICT.keys()
        ), f"If num_cols is not provided, num_rows must be a standard value: {ROWS_TO_WELLS_DICT.keys()}"  # noqa: E501
        return USER_EXAMPLE_DICT[ROWS_TO_WELLS_DICT[rows]]

    assert (
        cols * rows in USER_EXAMPLE_DICT.keys()
    ), f"Invalid plate size -- must be one of {USER_EXAMPLE_DICT.keys()}"
    return USER_EXAMPLE_DICT[rows * cols]


def _get_ai_response(request: ParsePlateRequest):
    """Pick the few-shot *assistant* response matching the requested plate shape.

    Mirrors ``_get_user_example``: defaults to the 96-well response when
    neither dimension is given; a single provided dimension must be a
    standard value so the well count can be inferred.

    Raises:
        AssertionError: if the provided dimensions do not map to a supported
            standard plate size.
    """
    if request.num_rows is None and request.num_cols is None:
        # No layout hints -- fall back to the standard 96-well response.
        return AI_REPONSE_DICT[96]

    if request.num_rows is None:
        assert (
            request.num_cols in COLS_TO_WELLS_DICT.keys()
        ), f"If num_rows is not provided, num_cols must be a standard value: {COLS_TO_WELLS_DICT.keys()}"  # noqa: E501
        return AI_REPONSE_DICT[COLS_TO_WELLS_DICT[request.num_cols]]

    if request.num_cols is None:
        assert (
            request.num_rows in ROWS_TO_WELLS_DICT.keys()
        ), f"If num_cols is not provided, num_rows must be a standard value: {ROWS_TO_WELLS_DICT.keys()}"  # noqa: E501
        return AI_REPONSE_DICT[ROWS_TO_WELLS_DICT[request.num_rows]]

    # Fix: validate against AI_REPONSE_DICT -- the dict actually indexed
    # below and named in the error message -- not USER_EXAMPLE_DICT, so the
    # check, the message, and the lookup stay consistent.
    assert (
        request.num_cols * request.num_rows in AI_REPONSE_DICT.keys()
    ), f"Invalid plate size -- must be one of {AI_REPONSE_DICT.keys()}"
    return AI_REPONSE_DICT[request.num_rows * request.num_cols]


chain = (
{
# Should add validation to ensure numeric indices
"input": lambda x: x["input"].to_csv(header=False),
"hint": lambda x: create_prompt(
num_plates=x.get("num_plates"),
num_rows=x.get("num_rows"),
num_cols=x.get("num_cols"),
),
"col_range_str": lambda x: f"from 1 to {x.get('num_cols')}"
if x.get("num_cols")
else "",
"json_format": lambda x: json.dumps(
[
{
"row_start": 12,
"row_end": 12 + x.get("num_rows", 8) - 1,
"col_start": 1,
"col_end": 1 + x.get("num_cols", 12) - 1,
"contents": "Entity ID",
}
]
),
"user_example": lambda x: USER_EXAMPLE_DICT[
x.get("num_rows", 8) * x.get("num_cols", 12)
],
"ai_response": lambda x: AI_REPONSE_DICT[
x.get("num_rows", 8) * x.get("num_cols", 12)
],
"input": _load_df,
"hint": _load_prompt,
"col_range_str": _get_col_range_str,
"json_format": _get_json_format,
"user_example": _get_user_example,
"ai_response": _get_ai_response,
}
| prompt
| llm
Expand All @@ -75,22 +159,17 @@ def parse_plates(
"""
df must have a numeric index
"""
# TODO: add validation around num_rows and num_cols
inp_dict = {
"input": df,
}
# Only add if not None so that `.get` can use the default value in our chain
if num_plates is not None:
inp_dict["num_plates"] = num_plates
if num_rows is not None:
inp_dict["num_rows"] = num_rows
if num_cols is not None:
inp_dict["num_cols"] = num_cols
result = chain.invoke(inp_dict)
req = ParsePlateRequest(
df=df,
num_plates=num_plates,
num_rows=num_rows,
num_cols=num_cols,
)
result = chain.invoke(req)

plates: list[pd.DataFrame] = []
for llm_response in result:
plate_data = get_plate_dimensions(df, llm_response)
plate_data = pluck_plate_from_df(df, llm_response)
plates.append(tidy_rectangular_plate_data(plate_data))
# Returns a list of "tidy" plates so that a downstream user can decide what to do with them
return plates
4 changes: 4 additions & 0 deletions src/platechain/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def generate_row_letters(num_strings: int, max_length: int = 2):

# 1536 well plates have 32 rows, which we will assume is the max for now
ROW_LETTERS = list(generate_row_letters(32))

# We only support standard plate sizes for now.
# Maps a standard row count to the total well count of that plate format
# (e.g. 8 rows -> 96-well plate).
ROWS_TO_WELLS_DICT = {4: 24, 8: 96, 16: 384, 32: 1536}
# Maps a standard column count to the total well count of that plate format
# (e.g. 12 columns -> 96-well plate).
COLS_TO_WELLS_DICT = {6: 24, 12: 96, 24: 384, 48: 1536}
2 changes: 2 additions & 0 deletions src/platechain/prompts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

FULL_PROMPT = """# Context
- Plate-based data is rectangular and could be situated anywhere within the dataset.
- The first item in every row is the row index
Expand Down
2 changes: 1 addition & 1 deletion src/platechain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def tidy_rectangular_plate_data(
return pd.DataFrame(new_rows)


def get_plate_dimensions(df: pd.DataFrame, plate_loc: LLMPlateResponse) -> pd.DataFrame:
def pluck_plate_from_df(df: pd.DataFrame, plate_loc: LLMPlateResponse) -> pd.DataFrame:
row_start, row_end = plate_loc.row_start, plate_loc.row_end + 1
col_start, col_end = plate_loc.col_start, plate_loc.col_end + 1
proposed_plate = df.iloc[
Expand Down

0 comments on commit 1bbffac

Please sign in to comment.