From b31dda7a9424b36756025eae4a3811d0bf2d9860 Mon Sep 17 00:00:00 2001 From: Tanmay patil <77950208+Tanmaypatil123@users.noreply.github.com> Date: Tue, 3 Oct 2023 21:04:56 +0530 Subject: [PATCH] fix(llms): Google Palm top_k value fix (#609) --- pandasai/llm/base.py | 6 +++--- test_file.txt | 1 + tests/llms/test_google_palm.py | 12 ++++++------ 3 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 test_file.txt diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index 5b2e72012..0c742c51f 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -375,7 +375,7 @@ class BaseGoogle(LLM): temperature: Optional[float] = 0 top_p: Optional[float] = 0.8 - top_k: Optional[float] = 0.3 + top_k: Optional[int] = 40 max_output_tokens: Optional[int] = 1000 def _valid_params(self): @@ -409,8 +409,8 @@ def _validate(self): if self.top_p is not None and not 0 <= self.top_p <= 1: raise ValueError("top_p must be in the range [0.0, 1.0]") - if self.top_k is not None and not 0 <= self.top_k <= 1: - raise ValueError("top_k must be in the range [0.0, 1.0]") + if self.top_k is not None and not 0 <= self.top_k <= 100: + raise ValueError("top_k must be in the range [0.0, 100.0]") if self.max_output_tokens is not None and self.max_output_tokens <= 0: raise ValueError("max_output_tokens must be greater than zero") diff --git a/test_file.txt b/test_file.txt new file mode 100644 index 000000000..5dd01c177 --- /dev/null +++ b/test_file.txt @@ -0,0 +1 @@ +Hello, world! 
\ No newline at end of file diff --git a/tests/llms/test_google_palm.py b/tests/llms/test_google_palm.py index 91162819b..9c96d7e82 100644 --- a/tests/llms/test_google_palm.py +++ b/tests/llms/test_google_palm.py @@ -37,14 +37,14 @@ def test_params_setting(self): model="models/text-bison-001", temperature=0.5, top_p=1.0, - top_k=0.5, + top_k=50, max_output_tokens=64, ) assert llm.model == "models/text-bison-001" assert llm.temperature == 0.5 assert llm.top_p == 1.0 - assert llm.top_k == 0.5 + assert llm.top_k == 50 assert llm.max_output_tokens == 64 def test_validations(self, prompt): @@ -69,14 +69,14 @@ def test_validations(self, prompt): GooglePalm(api_key="test", top_p=1.1).call(prompt, "World") with pytest.raises( - ValueError, match=re.escape("top_k must be in the range [0.0, 1.0]") + ValueError, match=re.escape("top_k must be in the range [0.0, 100.0]") ): - GooglePalm(api_key="test", top_k=-1).call(prompt, "World") + GooglePalm(api_key="test", top_k=-100).call(prompt, "World") with pytest.raises( - ValueError, match=re.escape("top_k must be in the range [0.0, 1.0]") + ValueError, match=re.escape("top_k must be in the range [0.0, 100.0]") ): - GooglePalm(api_key="test", top_k=1.1).call(prompt, "World") + GooglePalm(api_key="test", top_k=110).call(prompt, "World") with pytest.raises( ValueError, match=re.escape("max_output_tokens must be greater than zero")