Skip to content

Commit

Permalink
implement fewshotregressor
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Oct 5, 2023
1 parent 0c60155 commit 373bda1
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/chemlift/icl/fewshotregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from chemlift.icl.utils import LangChainChatModelWrapper
from langchain.llms import BaseLLM
from .fewshotpredictor import Strategy
import numpy as np


class FewShotClassifier(FewShotPredictor):
"""A few-shot classifier using in-context learning."""
class FewShotRegressor(FewShotPredictor):
"""A few-shot regressor using in-context learning."""

def __init__(
self,
Expand All @@ -19,6 +20,7 @@ def __init__(
seed: int = 42,
prefix: str = "You are an expert chemist. ",
max_test: int = 5,
num_digits: int = 3,
):
"""Initialize the few-shot predictor.
Expand All @@ -34,18 +36,20 @@ def __init__(
Defaults to "You are an expert chemist. ".
max_test (int, optional): The maximum number of examples to predict at once.
Defaults to 5.
num_digits (int, optional): The number of decimal places to which target values are rounded when examples are formatted into a string.
Defaults to 3.
Raises:
ValueError: If the strategy is unknown.
Examples:
>>> from chemlift.icl.fewshotpredictor import FewShotPredictor
>>> from chemlift.icl.fewshotregressor import FewShotRegressor
>>> from langchain.llms import OpenAI
>>> llm = OpenAI(model_name="text-ada-001")
>>> predictor = FewShotPredictor(llm, property_name="melting point")
>>> predictor.fit(["water", "ethanol"], [0, 1])
>>> predictor = FewShotRegressor(llm, property_name="melting point")
>>> predictor.fit(["water", "ethanol"], [0.1, 1.4])
>>> predictor.predict(["methanol"])
[0]
[0.5]
"""
self._support_set = None
self._llm = llm
Expand All @@ -57,14 +61,20 @@ def __init__(
self._materialclass = "molecules"
self._max_test = max_test
self._prefix = prefix
self._num_digits = num_digits

def _format_examples(self, examples, targets):
"""Format examples and targets into a single string.

By default, the result is a multiline string with one line per
(example, target) pair, of the form
- example: target
In the rounded variant below, each target is passed through
``np.round(target, self._num_digits)`` before being formatted.
"""
# NOTE(review): this view is a rendered diff; the one-line return is
# presumably the removed (unrounded) version -- confirm against the file.
return "\n".join([f"- {example}: {target}" for example, target in zip(examples, targets)])
# NOTE(review): presumably the added version; rounds each target to
# self._num_digits decimal places before interpolating it.
return "\n".join(
[
f"- {example}: {np.round(target, self._num_digits)}"
for example, target in zip(examples, targets)
]
)

def _extract(self, generations, expected_len):
generations = sum(
Expand Down

0 comments on commit 373bda1

Please sign in to comment.