diff --git a/src/chemlift/icl/fewshotregressor.py b/src/chemlift/icl/fewshotregressor.py index 6d703a5..9d03eba 100644 --- a/src/chemlift/icl/fewshotregressor.py +++ b/src/chemlift/icl/fewshotregressor.py @@ -5,10 +5,11 @@ from chemlift.icl.utils import LangChainChatModelWrapper from langchain.llms import BaseLLM from .fewshotpredictor import Strategy +import numpy as np -class FewShotClassifier(FewShotPredictor): - """A few-shot classifier using in-context learning.""" +class FewShotRegressor(FewShotPredictor): + """A few-shot regressor using in-context learning.""" def __init__( self, @@ -19,6 +20,7 @@ def __init__( seed: int = 42, prefix: str = "You are an expert chemist. ", max_test: int = 5, + num_digits: int = 3, ): """Initialize the few-shot predictor. @@ -34,18 +36,20 @@ def __init__( Defaults to "You are an expert chemist. ". max_test (int, optional): The maximum number of examples to predict at once. Defaults to 5. + num_digits (int, optional): The number of digits to round to. + Defaults to 3. Raises: ValueError: If the strategy is unknown. Examples: - >>> from chemlift.icl.fewshotpredictor import FewShotPredictor + >>> from chemlift.icl.fewshotregressor import FewShotRegressor >>> from langchain.llms import OpenAI >>> llm = OpenAI(model_name="text-ada-001") - >>> predictor = FewShotPredictor(llm, property_name="melting point") - >>> predictor.fit(["water", "ethanol"], [0, 1]) + >>> predictor = FewShotRegressor(llm, property_name="melting point") + >>> predictor.fit(["water", "ethanol"], [0.1, 1.4]) >>> predictor.predict(["methanol"]) - [0] + [0.5] """ self._support_set = None self._llm = llm @@ -57,6 +61,7 @@ def __init__( self._materialclass = "molecules" self._max_test = max_test self._prefix = prefix + self._num_digits = num_digits def _format_examples(self, examples, targets): """Format examples and targets into a string. @@ -64,7 +69,12 @@ def _format_examples(self, examples, targets): Per default, it is a multiline string with - example: target """ - return "\n".join([f"- {example}: {target}" for example, target in zip(examples, targets)]) + return "\n".join( + [ + f"- {example}: {np.round(target, self._num_digits)}" + for example, target in zip(examples, targets) + ] + ) def _extract(self, generations, expected_len): generations = sum(