diff --git a/src/helm/benchmark/adaptation/adapter_spec.py b/src/helm/benchmark/adaptation/adapter_spec.py
index 2a15851a360..ccb85ae62e8 100644
--- a/src/helm/benchmark/adaptation/adapter_spec.py
+++ b/src/helm/benchmark/adaptation/adapter_spec.py
@@ -6,6 +6,7 @@
 
 # Adaptation methods
 ADAPT_GENERATION: str = "generation"
+ADAPT_CHAT: str = "chat"
 ADAPT_LANGUAGE_MODELING: str = "language_modeling"
 ADAPT_MULTIPLE_CHOICE_JOINT: str = "multiple_choice_joint"
 ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT: str = "multiple_choice_joint_chain_of_thought"
diff --git a/src/helm/benchmark/adaptation/adapters/adapter_factory.py b/src/helm/benchmark/adaptation/adapters/adapter_factory.py
index b865f95b599..f5f0df89f00 100644
--- a/src/helm/benchmark/adaptation/adapters/adapter_factory.py
+++ b/src/helm/benchmark/adaptation/adapters/adapter_factory.py
@@ -1,5 +1,6 @@
 from helm.benchmark.adaptation.adapter_spec import (
     ADAPT_GENERATION,
+    ADAPT_CHAT,
     ADAPT_GENERATION_MULTIMODAL,
     ADAPT_LANGUAGE_MODELING,
     ADAPT_MULTIPLE_CHOICE_JOINT,
@@ -13,6 +14,7 @@
 from helm.benchmark.adaptation.adapters.adapter import Adapter
 from helm.benchmark.adaptation.adapters.binary_ranking_adapter import BinaryRankingAdapter
 from helm.benchmark.adaptation.adapters.generation_adapter import GenerationAdapter
+from helm.benchmark.adaptation.adapters.chat_adapter import ChatAdapter
 from helm.benchmark.adaptation.adapters.language_modeling_adapter import LanguageModelingAdapter
 from helm.benchmark.adaptation.adapters.multimodal.generation_multimodal_adapter import GenerationMultimodalAdapter
 from helm.benchmark.adaptation.adapters.multimodal.multiple_choice_joint_multimodal_adapter import (
@@ -38,6 +40,8 @@ def get_adapter(adapter_spec: AdapterSpec, tokenizer_service: TokenizerService)
 
     if method == ADAPT_GENERATION:
         adapter = GenerationAdapter(adapter_spec, tokenizer_service)
+    elif method == ADAPT_CHAT:
+        adapter = ChatAdapter(adapter_spec, tokenizer_service)
     elif method == ADAPT_LANGUAGE_MODELING:
         adapter = LanguageModelingAdapter(adapter_spec, tokenizer_service)
     elif method == ADAPT_MULTIPLE_CHOICE_JOINT:
diff --git a/src/helm/benchmark/adaptation/adapters/chat_adapter.py b/src/helm/benchmark/adaptation/adapters/chat_adapter.py
new file mode 100644
index 00000000000..3c557a18afe
--- /dev/null
+++ b/src/helm/benchmark/adaptation/adapters/chat_adapter.py
@@ -0,0 +1,52 @@
+from typing import List
+
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.scenarios.scenario import Instance
+from helm.common.request import Request
+from helm.benchmark.adaptation.adapters.in_context_learning_adapter import InContextLearningAdapter
+
+
+class ChatAdapter(InContextLearningAdapter):
+    """
+    Each `Instance` in a `Scenario` has a history of the format:
+
+    [
+        {"role": "user", "content": <user content>},
+        {"role": "assistant", "content": <assistant content>},
+        {"role": "user", "content": <user content>},
+        ...
+    ]
+
+    """
+
+    def generate_requests(
+        self, eval_instance: Instance, train_trial_index: int, training_instances: List[Instance]
+    ) -> List[RequestState]:
+        assert eval_instance.extra_data
+        messages = [
+            {"role": message["role"], "content": message["content"]}
+            for message in eval_instance.extra_data["conversation"]
+        ]
+        request = Request(
+            model=self.adapter_spec.model,
+            model_deployment=self.adapter_spec.model_deployment,
+            messages=messages,
+            num_completions=self.adapter_spec.num_outputs,
+            temperature=self.adapter_spec.temperature,
+            max_tokens=self.adapter_spec.max_tokens,
+            stop_sequences=self.adapter_spec.stop_sequences,
+            random=self.adapter_spec.random,
+            image_generation_parameters=self.adapter_spec.image_generation_parameters,
+        )
+        request_state = RequestState(
+            instance=eval_instance,
+            reference_index=None,
+            request_mode=None,
+            train_trial_index=train_trial_index,
+            output_mapping=None,
+            request=request,
+            result=None,
+            num_train_instances=0,
+            prompt_truncated=False,
+        )
+        return [request_state]
diff --git a/src/helm/benchmark/annotation/wildbench/eval_template.pairwise.v2.md b/src/helm/benchmark/annotation/wildbench/eval_template.pairwise.v2.md
new file mode 100644
index 00000000000..f00b6e82d9e
--- /dev/null
+++ b/src/helm/benchmark/annotation/wildbench/eval_template.pairwise.v2.md
@@ -0,0 +1,75 @@
+# Instruction
+
+You are an expert evaluator. Your task is to evaluate the quality of the responses generated by two AI models.
+We will provide you with the user query and a pair of AI-generated responses (Response A and Response B).
+You should first read the user query and the conversation history carefully to analyze the task, and then evaluate the quality of the responses based on the checklist and rules provided below.
+
+# Conversation between User and AI
+
+## History
+<|begin_of_history|>
+
+{$history}
+
+<|end_of_history|>
+
+## Current User Query
+<|begin_of_query|>
+
+{$user_query}
+
+<|end_of_query|>
+
+## Response A
+<|begin_of_response_A|>
+
+{$candidate_A}
+
+<|end_of_response_A|>
+
+## Response B
+<|begin_of_response_B|>
+
+{$candidate_B}
+
+<|end_of_response_B|>
+
+# Evaluation
+
+## Checklist
+
+<|begin_of_checklist|>
+
+{$checklist}
+
+<|end_of_checklist|>
+
+Please use this checklist to guide your evaluation, but do not limit your assessment to the checklist.
+
+## Rules
+
+You should compare the above two responses based on your analysis of the user query and the conversation history.
+You should first write down your analysis and the checklist that you used for the evaluation, and then provide your assessment according to the checklist.
+There are five choices to give your final assessment: ["A++", "A+", "A=B", "B+", "B++"], which correspond to the following meanings:
+
+- `A++`: Response A is much better than Response B.
+- `A+`: Response A is only slightly better than Response B.
+- `A=B`: Response A and B are of the same quality. Please use this choice sparingly.
+- `B+`: Response B is only slightly better than Response A.
+- `B++`: Response B is much better than Response A.
+
+
+## Output Format
+First, please output your analysis for each model response, and then summarize your assessment into three aspects: "reason of A=B", "reason of A>B", and "reason of B>A", and finally make your choice for the final assessment.
+
+Please provide your evaluation results in the following json format by filling in the placeholders in []:
+```
+{
+    "analysis of A": "[analysis of Response A]",
+    "analysis of B": "[analysis of Response B]",
+    "reason of A=B": "[where Response A and B perform equally well]",
+    "reason of A>B": "[where Response A is better than Response B]",
+    "reason of B>A": "[where Response B is better than Response A]",
+    "choice": "[A++ or A+ or A=B or B+ or B++]"
+}
+```
diff --git a/src/helm/benchmark/annotation/wildbench/eval_template.score.v2.md b/src/helm/benchmark/annotation/wildbench/eval_template.score.v2.md
new file mode 100644
index 00000000000..8bbe07ace37
--- /dev/null
+++ b/src/helm/benchmark/annotation/wildbench/eval_template.score.v2.md
@@ -0,0 +1,66 @@
+# Instruction
+
+You are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models.
+We will provide you with the user query and an AI-generated response.
+You should first read the user query and the conversation history carefully to analyze the task, and then evaluate the quality of the response based on the checklist and rules provided below.
+
+# Conversation between User and AI
+
+## History
+<|begin_of_history|>
+
+{$history}
+
+<|end_of_history|>
+
+## Current User Query
+<|begin_of_query|>
+
+{$user_query}
+
+<|end_of_query|>
+
+## AI Response
+<|begin_of_response|>
+
+{$model_output}
+
+<|end_of_response|>
+
+
+# Evaluation
+
+## Checklist
+
+<|begin_of_checklist|>
+
+{$checklist}
+
+<|end_of_checklist|>
+
+Please use this checklist to guide your evaluation, but do not limit your assessment to the checklist.
+
+## Rules
+
+You should evaluate the above response based on your analysis of the user query and the conversation history.
+You should first write down your analysis and the checklist that you used for the evaluation, and then provide your assessment according to the checklist.
+The scores are in the range of 1~10, where 1 means the response is very poor and 10 means the response is perfect.
+Here are more detailed criteria for the scores:
+
+- Score 1~2: The response is very poor and does not make sense at all.
+- Score 3~4: The response is poor and does not help the user solve the problem in a meaningful way.
+- Score 5~6: The response is fair but has some issues (e.g., factual errors, hallucinations, missing key information).
+- Score 7~8: The response is good enough but could be improved in some ways.
+- Score 9~10: The response is perfect and provides helpful information that can help the user solve the problem.
+
+## Output Format
+First, please output your analysis for the model response, and then summarize your assessment into two aspects: "strengths" and "weaknesses"; finally, please write down your rating for the assessment.
+ +Please provide your evaluation results in the following json format by filling in the placeholders in []: +``` +{ + "strengths": "[analysis for the strengths of the response]", + "weaknesses": "[analysis for the weaknesses of the response]", + "score": "[1~10]" +} +``` \ No newline at end of file diff --git a/src/helm/benchmark/annotation/wildbench_annotator.py b/src/helm/benchmark/annotation/wildbench_annotator.py new file mode 100644 index 00000000000..20de62181c9 --- /dev/null +++ b/src/helm/benchmark/annotation/wildbench_annotator.py @@ -0,0 +1,63 @@ +import re +from typing import Any + +from helm.benchmark.adaptation.request_state import RequestState +from helm.benchmark.annotation.annotator import Annotator +from helm.clients.auto_client import AutoClient +from helm.common.request import Request + + +class WildBenchAnnotator(Annotator): + """The WildBench autograder.""" + + name = "wildbench" + + def __init__(self, auto_client: AutoClient): + self._auto_client = auto_client + with open("src/helm/benchmark/annotation/wildbench/eval_template.score.v2.md") as f: + self._score_template = f.read() + self._pattern = re.compile( + r'"strengths"\s*:\s*"(.*?)"\s*,\s*"weaknesses"\s*:\s*"(.*?)"\s*,\s*"score"\s*:\s*(".*?"|\d+)', re.DOTALL + ) + + def annotate(self, request_state: RequestState) -> Any: + assert request_state.result + assert len(request_state.result.completions) == 1 + assert request_state.instance.extra_data + model_output_text = request_state.result.completions[0].text + if not model_output_text.strip(): + # Following https://github.com/allenai/WildBench/blob/d6b8dcaf377d173d031980f97c16e1a82618c03d/src/eval.py + return {"prompt_text": "", "strengths": "N/A", "weaknesses": "The model output is empty.", "score": 1.0} + prompt_template = self._score_template + + annotator_prompt = ( + prompt_template.replace("{$history}", request_state.instance.extra_data["history"]) + .replace("{$user_query}", request_state.instance.extra_data["user_query"]) + .replace("{$model_output}", model_output_text) + .replace("{$checklist}", "\n".join(request_state.instance.extra_data["checklist"])) + ) + annotator_request = Request( + model="openai/gpt-4o-2024-05-13", + model_deployment="openai/gpt-4o-2024-05-13", + prompt=annotator_prompt, + temperature=0.0, + max_tokens=2000, + ) + annotator_response = self._auto_client.make_request(annotator_request) + if not annotator_response.success: + raise Exception(f"Annotation request failed: {annotator_response.error}") + assert len(annotator_response.completions) == 1 + annotator_response_text = annotator_response.completions[0].text + annotator_response_parts = self._pattern.search(annotator_response_text) + if not annotator_response_parts: + raise ValueError(f"Malformed annotator response: {annotator_response_text}") + + strengths = annotator_response_parts[1].strip() + weaknesses = annotator_response_parts[2].strip() + score_text = annotator_response_parts[3].strip().strip('"') + try: + score = float(score_text) + except ValueError: + raise ValueError(f"Malformed score '{score_text}' in annotator response: {annotator_response_text}") + + return {"prompt_text": annotator_prompt, "strengths": strengths, "weaknesses": weaknesses, "score": score} diff --git a/src/helm/benchmark/metrics/ifeval_metrics.py b/src/helm/benchmark/metrics/ifeval_metrics.py index d7c7c936c8f..9eb9479a56d 100644 --- a/src/helm/benchmark/metrics/ifeval_metrics.py +++ b/src/helm/benchmark/metrics/ifeval_metrics.py @@ -1,5 +1,6 @@ from typing import List +from 
helm.common.hierarchical_logger import hlog from helm.benchmark.adaptation.adapter_spec import AdapterSpec from helm.benchmark.adaptation.request_state import RequestState from helm.benchmark.metrics.metric import Metric @@ -40,7 +41,13 @@ def evaluate_generation( if args and "prompt" in args: instruction.build_description(prompt=prompt) - if response.strip() and instruction.check_following(response): + is_following = False + if response.strip(): + try: + is_following = instruction.check_following(response) + except Exception as e: + hlog(f"WARNING: Instruction following checking failed with error message {e}") + if is_following: is_following_list.append(1) else: is_following_list.append(0) diff --git a/src/helm/benchmark/metrics/wildbench_metrics.py b/src/helm/benchmark/metrics/wildbench_metrics.py new file mode 100644 index 00000000000..01b7a2abd00 --- /dev/null +++ b/src/helm/benchmark/metrics/wildbench_metrics.py @@ -0,0 +1,25 @@ +from typing import List + +from helm.benchmark.adaptation.adapter_spec import AdapterSpec +from helm.benchmark.adaptation.request_state import RequestState +from helm.benchmark.metrics.metric import Metric +from helm.benchmark.metrics.metric_name import MetricName +from helm.benchmark.metrics.metric_service import MetricService +from helm.benchmark.metrics.statistic import Stat + + +class WildBenchScoreMetric(Metric): + """Score metrics for WildBench.""" + + def evaluate_generation( + self, + adapter_spec: AdapterSpec, + request_state: RequestState, + metric_service: MetricService, + eval_cache_path: str, + ) -> List[Stat]: + assert request_state.annotations + score = request_state.annotations["wildbench"]["score"] + return [ + Stat(MetricName("wildbench_score")).add(score), + ] diff --git a/src/helm/benchmark/run_specs/lite_run_specs.py b/src/helm/benchmark/run_specs/lite_run_specs.py index c995aa3e36c..6dc83be0f67 100644 --- a/src/helm/benchmark/run_specs/lite_run_specs.py +++ b/src/helm/benchmark/run_specs/lite_run_specs.py @@ -4,6 +4,7 @@ from helm.benchmark.adaptation.adapter_spec import ( ADAPT_GENERATION, + ADAPT_CHAT, ADAPT_MULTIPLE_CHOICE_JOINT, ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT, AdapterSpec, @@ -26,6 +27,7 @@ from helm.benchmark.runner import get_benchmark_output_path from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path from helm.benchmark.metrics.metric import MetricSpec +from helm.benchmark.annotation.annotator import AnnotatorSpec @run_spec_function("narrative_qa") @@ -449,3 +451,30 @@ def get_ifeval_spec() -> RunSpec: metric_specs=metric_specs, groups=["ifeval"], ) + + +@run_spec_function("wildbench") +def get_wildbench_spec(subset: str, use_model_outputs: str = "False") -> RunSpec: + + scenario_spec = ScenarioSpec( + class_name="helm.benchmark.scenarios.wildbench_scenario.WildBenchScenario", + args={ + "subset": subset, + "use_model_outputs": use_model_outputs == "True", + }, + ) + + adapter_spec = AdapterSpec( + method=ADAPT_CHAT, input_prefix="", output_prefix="", max_tokens=2000, num_outputs=1, temperature=0.0 + ) + annotator_specs = [AnnotatorSpec(class_name="helm.benchmark.annotation.wildbench_annotator.WildBenchAnnotator")] + metric_specs = [MetricSpec(class_name="helm.benchmark.metrics.wildbench_metrics.WildBenchScoreMetric")] + + return RunSpec( + name="wildbench", + scenario_spec=scenario_spec, + adapter_spec=adapter_spec, + annotators=annotator_specs, + metric_specs=metric_specs, + groups=["wildbench"], + ) diff --git a/src/helm/benchmark/scenarios/test_wildbench_scenario.py 
b/src/helm/benchmark/scenarios/test_wildbench_scenario.py new file mode 100644 index 00000000000..da71e26ca3e --- /dev/null +++ b/src/helm/benchmark/scenarios/test_wildbench_scenario.py @@ -0,0 +1,18 @@ +import pytest +from tempfile import TemporaryDirectory + +from helm.benchmark.scenarios.wildbench_scenario import WildBenchScenario +from helm.benchmark.scenarios.scenario import TEST_SPLIT + + +@pytest.mark.scenarios +def test_wildbench_scenario_get_instances(): + wildbench_scenario = WildBenchScenario(subset="v2") + with TemporaryDirectory() as tmpdir: + instances = wildbench_scenario.get_instances(tmpdir) + assert len(instances) == 1024 + assert instances[0].split == TEST_SPLIT + assert instances[0].extra_data + + assert instances[0].extra_data["user_query"].startswith("add 10 more balanced governments[aoc2]\n{\n\tGovernment") + assert len(instances[0].extra_data["user_query"]) == 17619 diff --git a/src/helm/benchmark/scenarios/wildbench_scenario.py b/src/helm/benchmark/scenarios/wildbench_scenario.py new file mode 100644 index 00000000000..100efa96617 --- /dev/null +++ b/src/helm/benchmark/scenarios/wildbench_scenario.py @@ -0,0 +1,96 @@ +import datasets +import os +from typing import List + +from helm.benchmark.scenarios.scenario import ( + Scenario, + Instance, + TEST_SPLIT, + Input, +) +from helm.common.general import ensure_directory_exists + + +SUBSETS = ["v2"] +REFERENCE_MODELS = ["gpt-4-turbo-2024-04-09", "claude-3-haiku-20240307", "Llama-2-70b-chat-hf"] + + +class WildBenchScenario(Scenario): + """WildBench: Benchmarking LLMs with Challenging Tasks from Real Users in the Wild + + WildBench is a benchmark for evaluating large language models (LLMs) on challenging tasks + that are more representative of real-world applications. The examples are collected from + real users by the AI2 WildChat project.""" + + name = "wildbench" + description = "Benchmarking LLMs with Challenging Tasks from Real Users in the Wild" + tags = ["instruction following"] + + def __init__(self, subset: str, use_model_outputs: bool = False): + super().__init__() + assert subset in SUBSETS, "Unknown subset: {}".format(subset) + self.subset = subset + self.use_model_outputs = use_model_outputs + + def get_instances(self, output_path: str) -> List[Instance]: + # Get WildBench from HuggingFace + cache_dir = os.path.join(output_path, "data") + ensure_directory_exists(cache_dir) + dataset = datasets.load_dataset( + "allenai/WildBench", + self.subset, + cache_dir=cache_dir, + split="test", + ) + assert isinstance(dataset, datasets.Dataset) + if self.use_model_outputs: + baseline_outputs = { + f"{model}": datasets.load_dataset( + "allenai/WildBench-V2-Model-Outputs", + model, + cache_dir=cache_dir, + split="train", + ) + for model in REFERENCE_MODELS + } + assert all(isinstance(baseline_output, datasets.Dataset) for baseline_output in baseline_outputs.values()) + + # Read all instances + instances: List[Instance] = [] + for idx, row in enumerate(dataset): + + conversation = row["conversation_input"] + + # Following https://github.com/allenai/WildBench/blob/d6b8dcaf377d173d031980f97c16e1a82618c03d/src/eval.py + history = [] + for round in row["conversation_input"][:-1]: + noun = "USER: " if round["role"] == "user" else "ASSISTANT: " + history.append(noun + round["content"]) + history_text = "\n\n".join(history) + user_query_text = row["conversation_input"][-1]["content"] + checklist = [f"- {checklist_item}" for checklist_item in row["checklist"]] + + input = Input( + text=history_text + + "\n\n" + + "USER: " + + 
user_query_text, # For frontend display only, not used for evaluation + ) + instance = Instance( + input=input, + references=[], + split=TEST_SPLIT, + extra_data={ + "conversation": conversation, + "baseline_outputs": { + model: baseline_outputs[model][idx]["output"][0] if self.use_model_outputs else None + for model in REFERENCE_MODELS + }, + "history": history_text, + "user_query": user_query_text, + "checklist": checklist, + }, + ) + instances.append(instance) + + return instances diff --git a/src/helm/benchmark/static/schema_lite_v2.yaml b/src/helm/benchmark/static/schema_lite_v2.yaml index 0d0026e3a30..0252a519f42 100644 --- a/src/helm/benchmark/static/schema_lite_v2.yaml +++ b/src/helm/benchmark/static/schema_lite_v2.yaml @@ -98,6 +98,11 @@ metrics: short_display_name: COT correct description: Allows to do evaluation using chain of thought for mmlu pro and gpqa. lower_is_better: false + - name: wildbench_score + display_name: WildBench Score + short_display_name: WB Score + description: Score of the AI output judged by GPT-4o. + lower_is_better: false ############################################################ perturbations: [] @@ -141,6 +146,7 @@ run_groups: - mmlu_pro - gpqa - ifeval + - wildbench - name: mmlu_pro display_name: MMLU-Pro @@ -159,6 +165,23 @@ run_groups: when: "?" language: English + - name: gpqa + display_name: GPQA + description: GPQA + metric_groups: + - accuracy + - efficiency + - general_information + environment: + main_name: chain_of_thought_correct + main_split: test + taxonomy: + task: "question answering" + what: "graduate-level questions in biology, physics, and chemistry" + who: "domain experts" + when: "2023" + language: English + - name: ifeval display_name: IFEval description: IFEval @@ -170,21 +193,25 @@ run_groups: main_name: ifeval_strict_accuracy main_split: test taxonomy: - task: "?" - - - name: gpqa - display_name: GPQA - description: GPQA + task: "instruction following" + what: "verifiable general domain instruction following" + who: "human annotators" + when: "2023" + language: English + + - name: wildbench + display_name: WildBench + description: WildBench metric_groups: - accuracy - efficiency - general_information environment: - main_name: chain_of_thought_correct # non-CoT + main_name: wildbench_score main_split: test taxonomy: - task: "?" - what: "?" - who: "?" - when: "?" 
+ task: "instruction following" + what: "GPT-judged instruction following with instructions collected from real-user conversations" + who: "real-world users" + when: "2024" language: English diff --git a/src/helm/clients/mistral_client.py b/src/helm/clients/mistral_client.py index e5402406009..02542c5e8ef 100644 --- a/src/helm/clients/mistral_client.py +++ b/src/helm/clients/mistral_client.py @@ -21,7 +21,8 @@ class MistralAIRequest(TypedDict): model: str # The prompt can be either a string or a list of messages that can be multimodal - prompt: Union[str, List[Dict[str, str]]] + prompt: Optional[Union[str, List[Dict[str, str]]]] + messages: Optional[List[Dict[str, Any]]] max_tokens: int temperature: float top_p: float @@ -50,9 +51,13 @@ def __init__( self.mistral_model = mistral_model def _send_request(self, raw_request: MistralAIRequest) -> Dict[str, Any]: + if raw_request["messages"] is not None: + messages = raw_request["messages"] + else: + messages = [{"role": "user", "content": raw_request["prompt"]}] chat_response: Optional[ChatCompletionResponse] = self._client.chat.complete( model=raw_request["model"], - messages=[{"role": "user", "content": raw_request["prompt"]}], # type: ignore + messages=messages, # type: ignore temperature=raw_request["temperature"], max_tokens=raw_request["max_tokens"], top_p=raw_request["top_p"], @@ -114,15 +119,28 @@ def make_request(self, request: Request) -> RequestResult: # `num_completions` is not supported, so instead make `num_completions` separate requests. for completion_index in range(request.num_completions): try: - raw_request: MistralAIRequest = { - "model": self.mistral_model or request.model_engine, - "prompt": prompt, - "max_tokens": request.max_tokens, - "temperature": request.temperature, - "top_p": request.top_p, - "random_seed": self._get_random_seed(request, completion_index), - "stop": request.stop_sequences or None, - } + if request.messages: + raw_request: MistralAIRequest = { + "model": self.mistral_model or request.model_engine, + "prompt": None, + "messages": request.messages, + "max_tokens": request.max_tokens, + "temperature": request.temperature, + "top_p": request.top_p, + "random_seed": self._get_random_seed(request, completion_index), + "stop": request.stop_sequences or None, + } + else: + raw_request = { + "model": self.mistral_model or request.model_engine, + "prompt": prompt, + "messages": None, + "max_tokens": request.max_tokens, + "temperature": request.temperature, + "top_p": request.top_p, + "random_seed": self._get_random_seed(request, completion_index), + "stop": request.stop_sequences or None, + } def do_it() -> Dict[str, Any]: result: Dict[str, Any] = self._send_request(raw_request) diff --git a/src/helm/clients/vertexai_client.py b/src/helm/clients/vertexai_client.py index 09a159dd672..d7cc527da15 100644 --- a/src/helm/clients/vertexai_client.py +++ b/src/helm/clients/vertexai_client.py @@ -13,7 +13,14 @@ try: import vertexai from vertexai.language_models import TextGenerationModel, TextGenerationResponse # PaLM2 - from vertexai.preview.generative_models import GenerativeModel, GenerationResponse, Candidate, Part, Image # Gemini + from vertexai.preview.generative_models import ( + GenerativeModel, + GenerationResponse, + Candidate, + Content, + Part, + Image, + ) # Gemini from google.cloud.aiplatform_v1beta1.types import SafetySetting, HarmCategory except ModuleNotFoundError as e: handle_module_not_found_error(e, ["google"]) @@ -194,12 +201,20 @@ def get_model(model_name: str) -> GenerativeModel: def 
make_request(self, request: Request) -> RequestResult: """Make a request""" - contents: str = request.prompt + contents = [request.prompt] # For the multimodal case, build up the content with the media objects of `request.multimodal_prompt` if request.multimodal_prompt is not None: return self._make_multimodal_request(request) + if request.messages is not None: + contents = [] + role_mapping = {"user": "user", "assistant": "model"} + for msg in request.messages: + contents.append( + Content(role=role_mapping.get(msg["role"], "user"), parts=[Part.from_text(msg["content"])]) + ) + parameters = { "temperature": request.temperature, "max_output_tokens": request.max_tokens, @@ -264,7 +279,7 @@ def do_it() -> Dict[str, Any]: cache_key = self.make_cache_key_with_safety_settings_preset( { "model_name": model_name, - "prompt": request.prompt, + "prompt": request.messages or request.prompt, **parameters, }, request,