Skip to content

Commit

Permalink
work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
martinreinhardt01 committed Dec 5, 2023
1 parent 14b9a42 commit 1750f91
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aleph_alpha_client/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,13 @@ class CompletionResponse:

model_version: str
completions: Sequence[CompletionResult]
num_tokens_prompt_total: int
optimized_prompt: Optional[Prompt] = None

@staticmethod
def from_json(json: Dict[str, Any]) -> "CompletionResponse":
optimized_prompt_json = json.get("optimized_prompt")
print(json)
return CompletionResponse(
model_version=json["model_version"],
completions=[
Expand All @@ -275,6 +277,7 @@ def from_json(json: Dict[str, Any]) -> "CompletionResponse":
optimized_prompt=Prompt.from_json(optimized_prompt_json)
if optimized_prompt_json
else None,
num_tokens_prompt_total=json["num_tokens_prompt_total"],
)

def to_json(self) -> Mapping[str, Any]:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,18 @@ def test_complete_with_echo(sync_client: Client, model_name: str, prompt_image:
assert len(completion_result.completion_tokens) > 0
assert completion_result.log_probs is not None
assert len(completion_result.log_probs) > 0

@pytest.mark.system_test
def test_num_tokes_prompt_total_with_best_of(sync_client: Client, model_name: str):
tokens = [49222, 2998] # Hello world
best_of = 2
request = CompletionRequest(
prompt= Prompt.from_tokens(tokens),
best_of = best_of,
maximum_tokens=1,
)

response = sync_client.complete(request, model=model_name)
completion_result = response.completions[0]
assert response.num_tokens_prompt_total == len(tokens) * best_of

0 comments on commit 1750f91

Please sign in to comment.