Skip to content

Commit

Permalink
fix(llms) :Google Palm top_k value fix (#609)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanmaypatil123 authored Oct 3, 2023
1 parent 42b8256 commit b31dda7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ class BaseGoogle(LLM):

temperature: Optional[float] = 0
top_p: Optional[float] = 0.8
top_k: Optional[float] = 0.3
top_k: Optional[int] = 40
max_output_tokens: Optional[int] = 1000

def _valid_params(self):
Expand Down Expand Up @@ -409,8 +409,8 @@ def _validate(self):
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")

if self.top_k is not None and not 0 <= self.top_k <= 1:
raise ValueError("top_k must be in the range [0.0, 1.0]")
if self.top_k is not None and not 0 <= self.top_k <= 100:
raise ValueError("top_k must be in the range [0.0, 100.0]")

if self.max_output_tokens is not None and self.max_output_tokens <= 0:
raise ValueError("max_output_tokens must be greater than zero")
Expand Down
1 change: 1 addition & 0 deletions test_file.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello, world!
12 changes: 6 additions & 6 deletions tests/llms/test_google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ def test_params_setting(self):
model="models/text-bison-001",
temperature=0.5,
top_p=1.0,
top_k=0.5,
top_k=50,
max_output_tokens=64,
)

assert llm.model == "models/text-bison-001"
assert llm.temperature == 0.5
assert llm.top_p == 1.0
assert llm.top_k == 0.5
assert llm.top_k == 50
assert llm.max_output_tokens == 64

def test_validations(self, prompt):
Expand All @@ -69,14 +69,14 @@ def test_validations(self, prompt):
GooglePalm(api_key="test", top_p=1.1).call(prompt, "World")

with pytest.raises(
ValueError, match=re.escape("top_k must be in the range [0.0, 1.0]")
ValueError, match=re.escape("top_k must be in the range [0.0, 100.0]")
):
GooglePalm(api_key="test", top_k=-1).call(prompt, "World")
GooglePalm(api_key="test", top_k=-100).call(prompt, "World")

with pytest.raises(
ValueError, match=re.escape("top_k must be in the range [0.0, 1.0]")
ValueError, match=re.escape("top_k must be in the range [0.0, 100.0]")
):
GooglePalm(api_key="test", top_k=1.1).call(prompt, "World")
GooglePalm(api_key="test", top_k=110).call(prompt, "World")

with pytest.raises(
ValueError, match=re.escape("max_output_tokens must be greater than zero")
Expand Down

0 comments on commit b31dda7

Please sign in to comment.