Skip to content

Commit

Permalink
Start implementing the few-shot regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Oct 4, 2023
1 parent 7016362 commit 0c60155
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 2 deletions.
6 changes: 4 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ classifiers =
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3 :: Only
keywords =
llm
llm
chemistry
lift

fine-tuning
icl
gpt


[options]
Expand Down
101 changes: 101 additions & 0 deletions src/chemlift/icl/fewshotregressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from loguru import logger
from numpy.typing import ArrayLike
from chemlift.icl.fewshotpredictor import FewShotPredictor
from typing import Union
from chemlift.icl.utils import LangChainChatModelWrapper
from langchain.llms import BaseLLM
from .fewshotpredictor import Strategy


class FewShotClassifier(FewShotPredictor):
    """A few-shot regressor using in-context learning.

    Completions returned by the language model are parsed into floats
    (see ``_extract``).

    NOTE(review): this class lives in ``fewshotregressor.py`` and parses
    floating-point answers, but is still named ``FewShotClassifier``.
    The name is kept so existing imports keep working -- confirm whether
    it should be renamed (with an alias) in a follow-up.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, LangChainChatModelWrapper],
        property_name: str,
        n_support: int = 5,
        strategy: Strategy = Strategy.RANDOM,
        seed: int = 42,
        prefix: str = "You are an expert chemist. ",
        max_test: int = 5,
    ):
        """Initialize the few-shot regressor.

        Args:
            llm (Union[BaseLLM, LangChainChatModelWrapper]): The language model to use.
            property_name (str): The property to predict.
            n_support (int, optional): The number of examples to use as support set.
                Defaults to 5.
            strategy (Strategy, optional): The strategy to use to pick the support set.
                Defaults to Strategy.RANDOM.
            seed (int, optional): The random seed to use. Defaults to 42.
            prefix (str, optional): The prefix to use for the prompt.
                Defaults to "You are an expert chemist. ".
            max_test (int, optional): The maximum number of examples to predict at once.
                Defaults to 5.

        Examples:
            >>> from langchain.llms import OpenAI
            >>> llm = OpenAI(model_name="text-ada-001")
            >>> regressor = FewShotClassifier(llm, property_name="melting point")
            >>> regressor.fit(["water", "ethanol"], [0.0, -114.1])  # doctest: +SKIP
            >>> regressor.predict(["methanol"])  # doctest: +SKIP
        """
        # NOTE(review): the parent initializer is not called; the state is
        # set up directly here. Confirm this stays in sync with
        # FewShotPredictor.__init__.
        self._support_set = None
        self._llm = llm
        self._n_support = n_support
        self._strategy = strategy
        self._seed = seed
        self._property_name = property_name
        self._allowed_values = None
        self._materialclass = "molecules"
        self._max_test = max_test
        self._prefix = prefix

    def _format_examples(self, examples, targets):
        """Format examples and targets into a string.

        Per default, it is a multiline string with one
        ``- example: target`` entry per line.

        Args:
            examples: Iterable of example representations.
            targets: Iterable of target values, paired with ``examples``
                via ``zip`` (extra items in the longer one are ignored).

        Returns:
            str: The formatted, newline-joined examples.
        """
        return "\n".join(f"- {example}: {target}" for example, target in zip(examples, targets))

    def _extract(self, generations, expected_len):
        """Parse raw language-model results into a flat list of predictions.

        For every result, the text of the first candidate of each prompt
        is taken, everything up to the last ``:`` is dropped (this also
        removes a leading ``Answer:`` label), and the remainder is split
        on commas.

        Args:
            generations: Iterable of result objects exposing a
                ``generations`` attribute (a list of candidate lists whose
                items have a ``text`` attribute).
            expected_len (int): Number of predictions that must be found.

        Returns:
            list: ``expected_len`` entries. If the number of parsed
            answers does not match ``expected_len``, a warning is logged
            and a list of ``None`` is returned. Otherwise each answer is
            converted to ``float`` (``None`` where it cannot be parsed),
            unless ``self.intify`` is defined and falsy, in which case the
            raw strings are returned.
        """
        # Flatten in a single pass; the text after the last ":" can no
        # longer contain "Answer: ", so no extra replace is needed.
        answers = [
            answer
            for generation in generations
            for g in generation.generations
            for answer in g[0].text.split(":")[-1].strip().split(",")
        ]
        if len(answers) != expected_len:
            logger.warning(f"Expected {expected_len} generations, got {len(answers)}")
            return [None] * expected_len

        # `intify` is not set in this class' __init__; default to numeric
        # parsing instead of raising AttributeError when it is missing.
        if not getattr(self, "intify", True):
            return answers

        parsed = []
        for answer in answers:
            try:
                parsed.append(float(answer.strip()))
            except ValueError:
                # Non-numeric answer: keep the position, mark it missing.
                parsed.append(None)
        return parsed

    def predict(self, X: ArrayLike, generation_kwargs: Union[dict, None] = None):
        """Predict the property value for a list of examples.

        Args:
            X: A list of examples.
            generation_kwargs: Keyword arguments to pass to the language model.
                ``None`` (the default) is treated as an empty dict.

        Returns:
            list: One prediction per example, as produced by ``_extract``.
        """
        # Avoid a shared mutable default; build a fresh dict per call.
        if generation_kwargs is None:
            generation_kwargs = {}
        generations = self._predict(X, generation_kwargs)
        return self._extract(generations, expected_len=len(X))

0 comments on commit 0c60155

Please sign in to comment.