diff --git a/.gitignore b/.gitignore index 6db56a4a..ea9c1977 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ results ./*.json client_configs/*.yaml old_results +results_evaluators # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/src/alpaca_eval/decoders/openai.py b/src/alpaca_eval/decoders/openai.py index df8a0f6c..c3975235 100644 --- a/src/alpaca_eval/decoders/openai.py +++ b/src/alpaca_eval/decoders/openai.py @@ -132,10 +132,10 @@ def openai_completions( prompt_batches = [prompts[batch_id * batch_size : (batch_id + 1) * batch_size] for batch_id in range(n_batches)] - if isinstance(max_tokens, int): - max_tokens = [max_tokens] * n_examples - - inputs = zip(prompt_batches, max_tokens) + try: + inputs = zip(prompt_batches, max_tokens) + except TypeError: + inputs = zip(prompt_batches, [max_tokens] * n_batches) kwargs = dict(model=model_name, **decoding_kwargs) kwargs_to_log = {k: v for k, v in kwargs.items() if "api_key" not in k} @@ -216,7 +216,11 @@ def _openai_completion_helper( client = all_clients[curr_client_idx] # copy shared_kwargs to avoid modifying it - kwargs.update(dict(max_tokens=max_tokens, top_p=top_p, temperature=temperature)) + to_update = dict() + for k in ["max_tokens", "top_p", "temperature"]: + if locals()[k] is not None: + to_update[k] = locals()[k] + kwargs.update(to_update) curr_kwargs = copy.deepcopy(kwargs) # ensure no infinite loop @@ -242,7 +246,7 @@ def _openai_completion_helper( # currently we only use function calls to get a JSON object => return raw text of json choices[i]["text"] = choice.message.function_call.arguments - if choice.message.tool_calls is not None: + if choice.message.tool_calls: # currently we only use function calls to get a JSON object => return raw text of json choices[i]["text"] = choice.message.tool_calls[0].function.arguments @@ -273,10 +277,11 @@ def _openai_completion_helper( return choices else: - if "rate limit" in str(e).lower(): + if "rate " in str(e).lower(): logging.warning(f"Hit request rate limit; retrying...") else: - logging.warning(f"Unknown error. \n It's likely a rate limit so we are retrying...") + logging.exception("Unknown error:") + raise e if len(all_clients) > 1: curr_client_idx = random.choice([idx for idx in client_idcs if idx != curr_client_idx]) client = all_clients[curr_client_idx]