diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py
index 272f7c8..f15813b 100644
--- a/aleph_alpha_client/completion.py
+++ b/aleph_alpha_client/completion.py
@@ -262,11 +262,12 @@ class CompletionResponse:
 
     model_version: str
     completions: Sequence[CompletionResult]
+    num_tokens_prompt_total: int
     optimized_prompt: Optional[Prompt] = None
 
     @staticmethod
     def from_json(json: Dict[str, Any]) -> "CompletionResponse":
         optimized_prompt_json = json.get("optimized_prompt")
         return CompletionResponse(
             model_version=json["model_version"],
             completions=[
@@ -275,6 +276,7 @@
             optimized_prompt=Prompt.from_json(optimized_prompt_json)
             if optimized_prompt_json
             else None,
+            num_tokens_prompt_total=json["num_tokens_prompt_total"],
         )
 
     def to_json(self) -> Mapping[str, Any]:
diff --git a/tests/test_complete.py b/tests/test_complete.py
index 863a9fe..a05b8c4 100644
--- a/tests/test_complete.py
+++ b/tests/test_complete.py
@@ -127,3 +127,17 @@ def test_complete_with_echo(sync_client: Client, model_name: str, prompt_image:
     assert len(completion_result.completion_tokens) > 0
     assert completion_result.log_probs is not None
     assert len(completion_result.log_probs) > 0
+
+
+@pytest.mark.system_test
+def test_num_tokens_prompt_total_with_best_of(sync_client: Client, model_name: str):
+    # "Hello world" as pre-tokenized input, so the expected prompt-token
+    # count is known exactly: len(tokens) per best_of candidate.
+    tokens = [49222, 2998]
+    best_of = 2
+    request = CompletionRequest(
+        prompt=Prompt.from_tokens(tokens),
+        best_of=best_of,
+        maximum_tokens=1,
+    )
+
+    response = sync_client.complete(request, model=model_name)
+    assert response.num_tokens_prompt_total == len(tokens) * best_of