diff --git a/factgenie/config/default_prompts.yml b/factgenie/config/default_prompts.yml
index 5f936220..5f200917 100644
--- a/factgenie/config/default_prompts.yml
+++ b/factgenie/config/default_prompts.yml
@@ -21,7 +21,7 @@ llm_eval: |
   ```
   Instructions for annotating the text:
-  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "type". The value of "reason" is the reason for the annotation. The value of "text" is the literal value of the text inside the highlighted span, so that the span can later be identified using string matching. The value of "type" is an integer index of the error based on the following list:
+  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "annotation_type". The value of "reason" is the reason for the annotation. The value of "text" is the literal value of the text inside the highlighted span, so that the span can later be identified using string matching. The value of "annotation_type" is an integer index of the error based on the following list:

   {error_list}
diff --git a/factgenie/config/llm-eval/example-ollama-llama3-eval.yaml b/factgenie/config/llm-eval/example-ollama-llama3-eval.yaml
index ab1afacb..3f44efda 100644
--- a/factgenie/config/llm-eval/example-ollama-llama3-eval.yaml
+++ b/factgenie/config/llm-eval/example-ollama-llama3-eval.yaml
@@ -1,5 +1,5 @@
 type: ollama_metric
-model: llama3
+model: llama3.1:8b
 # You can run ollama alson on other machine than factgenie
 # e.g. we run it on a machine tdll-3gpu3 and access it from any machine which is withing the same firewall
 # in that case we use api_url: http://tdll-3gpu3.ufal.hide.ms.mff.cuni.cz:11434/api/
@@ -33,7 +33,7 @@ prompt_template: |
   ```
   {text}
   ```
-  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "type". The value of "text" is the text of the error. The value of "reason" is the reason for the error. The value of "type" is one of {0, 1, 2, 3} based on the following list:
+  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "annotation_type". The value of "text" is the text of the error. The value of "reason" is the reason for the error. The value of "annotation_type" is one of {0, 1, 2, 3} based on the following list:
   - 0: Incorrect fact: The fact in the text contradicts the data.
   - 1: Not checkable: The fact in the text cannot be checked in the data.
   - 2: Misleading: The fact in the text is misleading in the given context.
   - 3: Other: The text is problematic for another reason, e.g. grammatically or stylistically incorrect, irrelevant, or repetitive.
@@ -54,6 +54,6 @@ prompt_template: |
   Nokia 3310 is produced in Finland and features a 320x320 display. It is available in black color. The data seem to provide only partial information about the phone.
   ```
   output:
-  ```{ "annotations": [{"reason": "The country where the phone is produced is not mentioned in the data.", "text": "produced in Finland", "type": 1}, {"reason": "The data mentions that the display has resolution 320x240px.", "text": "320x320", type: 0}, {"reason": "Misleadingly suggests that the phone is not available in other colors.", "text": "available in black color", type: 2}, {"reason": "The note is irrelevant for the phone description.", "text": "The data seem to provide only partial information about the phone.", type: 3}] }
+  ```{ "annotations": [{"reason": "The country where the phone is produced is not mentioned in the data.", "text": "produced in Finland", "annotation_type": 1}, {"reason": "The data mentions that the display has resolution 320x240px.", "text": "320x320", "annotation_type": 0}, {"reason": "Misleadingly suggests that the phone is not available in other colors.", "text": "available in black color", "annotation_type": 2}, {"reason": "The note is irrelevant for the phone description.", "text": "The data seem to provide only partial information about the phone.", "annotation_type": 3}] }
   ```
   Note that some details may not be mentioned in the text: do not count omissions as errors. Also do not be too strict: some facts can be less specific than in the data (rounded values, shortened or abbreviated text, etc.), do not count these as errors. If there are no errors in the text, "annotations" will be an empty list.
diff --git a/factgenie/config/llm-eval/example-openai-gpt3.5-eval.yaml b/factgenie/config/llm-eval/example-openai-gpt-4o-mini-eval.yaml
similarity index 76%
rename from factgenie/config/llm-eval/example-openai-gpt3.5-eval.yaml
rename to factgenie/config/llm-eval/example-openai-gpt-4o-mini-eval.yaml
index 74e3ab5a..7f81d254 100644
--- a/factgenie/config/llm-eval/example-openai-gpt3.5-eval.yaml
+++ b/factgenie/config/llm-eval/example-openai-gpt-4o-mini-eval.yaml
@@ -25,7 +25,7 @@ prompt_template: |
   ```
   {text}
   ```
-  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "type". The value of "text" is the text of the error. The value of "reason" is the reason for the error. The value of "type" is one of {0, 1, 2, 3} based on the following list:
+  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "annotation_type". The value of "text" is the text of the error. The value of "reason" is the reason for the error. The value of "annotation_type" is one of {0, 1, 2, 3} based on the following list:
   - 0: Incorrect fact: The fact in the text contradicts the data.
   - 1: Not checkable: The fact in the text cannot be checked in the data.
   - 2: Misleading: The fact in the text is misleading in the given context.
   - 3: Other: The text is problematic for another reason, e.g. grammatically or stylistically incorrect, irrelevant, or repetitive.
@@ -46,6 +46,6 @@ prompt_template: |
   Nokia 3310 is produced in Finland and features a 320x320 display. It is available in black color. The data seem to provide only partial information about the phone.
   ```
   output:
-  ```{ "annotations": [{"reason": "The country where the phone is produced is not mentioned in the data.", "text": "produced in Finland", "type": 1}, {"reason": "The data mentions that the display has resolution 320x240px.", "text": "320x320", type: 0}, {"reason": "Misleadingly suggests that the phone is not available in other colors.", "text": "available in black color", type: 2}, {"reason": "The note is irrelevant for the phone description.", "text": "The data seem to provide only partial information about the phone.", type: 3}] }
+  ```{ "annotations": [{"reason": "The country where the phone is produced is not mentioned in the data.", "text": "produced in Finland", "annotation_type": 1}, {"reason": "The data mentions that the display has resolution 320x240px.", "text": "320x320", "annotation_type": 0}, {"reason": "Misleadingly suggests that the phone is not available in other colors.", "text": "available in black color", "annotation_type": 2}, {"reason": "The note is irrelevant for the phone description.", "text": "The data seem to provide only partial information about the phone.", "annotation_type": 3}] }
   ```
   Note that some details may not be mentioned in the text: do not count omissions as errors. Also do not be too strict: some facts can be less specific than in the data (rounded values, shortened or abbreviated text, etc.), do not count these as errors. If there are no errors in the text, "annotations" will be an empty list.
diff --git a/factgenie/config/llm-eval/example-vllm-llama3-eval.yaml b/factgenie/config/llm-eval/example-vllm-llama3-eval.yaml
new file mode 100644
index 00000000..b46703b6
--- /dev/null
+++ b/factgenie/config/llm-eval/example-vllm-llama3-eval.yaml
@@ -0,0 +1,59 @@
+type: vllm_metric
+model: meta-llama/Meta-Llama-3-8B-Instruct
+# You can also run vllm on a different machine than factgenie,
+# e.g. we run it on a machine tdll-3gpu3 and access it from any machine which is within the same firewall.
+# in that case we use api_url: http://tdll-3gpu3.ufal.hide.ms.mff.cuni.cz:8000/v1/
+# If you run vllm on the same machine as factgenie, just use localhost.
+api_url: http://localhost:8000/v1/
+model_args:
+  num_predict: 1024
+  temperature: 0.0
+  top_p: 1.0
+  top_k: 0.0
+  seed: 42
+annotation_span_categories:
+  - name: "Incorrect"
+    color: "#ffbcbc"
+    description: "The fact in the text contradicts the data."
+  - name: "Not checkable"
+    color: "#e9d2ff"
+    description: "The fact in the text cannot be checked given the data."
+  - name: "Misleading"
+    color: "#fff79f"
+    description: "The fact in the text is misleading in the given context."
+  - name: "Other"
+    color: "#bbbbbb"
+    description: "The text is problematic for another reason, e.g. grammatically or stylistically incorrect, irrelevant, or repetitive."
+prompt_template: |
+  Given the data:
+  ```
+  {data}
+  ```
+  Annotate all the errors in the following text:
+  ```
+  {text}
+  ```
+  Output the errors as a JSON list "annotations" in which each object contains fields "reason", "text", and "annotation_type". The value of "text" is the text of the error. The value of "reason" is the reason for the error. The value of "annotation_type" is one of {0, 1, 2, 3} based on the following list:
+  - 0: Incorrect fact: The fact in the text contradicts the data.
+  - 1: Not checkable: The fact in the text cannot be checked in the data.
+  - 2: Misleading: The fact in the text is misleading in the given context.
+  - 3: Other: The text is problematic for another reason, e.g. grammatically or stylistically incorrect, irrelevant, or repetitive.
+
+  The list should be sorted by the position of the error in the text. Make sure that the annotations are not overlapping.
+
+  *Example:*
+  data:
+  ```
+  Nokia 3310
+  -----
+  - **color**: black, blue, grey
+  - **display**: 320x240px
+  ```
+  text (product description):
+  ```
+  Nokia 3310 is produced in Finland and features a 320x320 display. It is available in black color. The data seem to provide only partial information about the phone.
+  ```
+  output:
+  ```{ "annotations": [{"reason": "The country where the phone is produced is not mentioned in the data.", "text": "produced in Finland", "annotation_type": 1}, {"reason": "The data mentions that the display has resolution 320x240px.", "text": "320x320", "annotation_type": 0}, {"reason": "Misleadingly suggests that the phone is not available in other colors.", "text": "available in black color", "annotation_type": 2}, {"reason": "The note is irrelevant for the phone description.", "text": "The data seem to provide only partial information about the phone.", "annotation_type": 3}] }
+  ```
+  Note that some details may not be mentioned in the text: do not count omissions as errors. Also do not be too strict: some facts can be less specific than in the data (rounded values, shortened or abbreviated text, etc.), do not count these as errors. If there are no errors in the text, "annotations" will be an empty list.
diff --git a/factgenie/data/datasets_TEMPLATE.yml b/factgenie/data/datasets_TEMPLATE.yml
index 5c8007a3..d3e90269 100644
--- a/factgenie/data/datasets_TEMPLATE.yml
+++ b/factgenie/data/datasets_TEMPLATE.yml
@@ -11,4 +11,4 @@ example-dataset-id:
   - list
   - of
   - dataset
-  - splits
+  - splits
\ No newline at end of file
diff --git a/factgenie/models.py b/factgenie/models.py
index 0bf6036e..b6d8afab 100644
--- a/factgenie/models.py
+++ b/factgenie/models.py
@@ -1,22 +1,21 @@
 #!/usr/bin/env python3
+from abc import abstractmethod
 import traceback
+from typing import Optional
 from openai import OpenAI
 from textwrap import dedent
 import json
-from pathlib import Path
 import os
-import coloredlogs
 import logging
-import time
+from pydantic import BaseModel, Field, ValidationError
 import requests
 import copy
 from ast import literal_eval

 from factgenie.campaigns import CampaignMode

-# logging.basicConfig(format="%(message)s", level=logging.INFO, datefmt="%H:%M:%S")
 logger = logging.getLogger(__name__)

 DIR_PATH = os.path.dirname(__file__)
@@ -33,6 +32,7 @@ def model_classes():
         CampaignMode.LLM_EVAL: {
             "openai_metric": OpenAIMetric,
             "ollama_metric": OllamaMetric,
+            "vllm_metric": VLLMMetric,
         },
         CampaignMode.LLM_GEN: {
             "openai_gen": OpenAIGen,
@@ -52,16 +52,25 @@ def from_config(config, mode):

     return classes[metric_type](config)


+class SpanAnnotation(BaseModel):
+    text: str = Field(description="The text which is annotated.")
+    # Do not name it type since it is a reserved keyword in JSON schema
+    annotation_type: int = Field(
+        description="Index to the list of span annotation types defined for the annotation campaign."
+    )
+    reason: str = Field(description="The reason for the annotation.")
+
+
+class OutputAnnotations(BaseModel):
+    annotations: list[SpanAnnotation] = Field(description="The list of annotations.")
+
+
 class Model:
     def __init__(self, config):
         self.validate_config(config)
         self.config = config
         self.parse_model_args()

-        if "extra_args" in config:
-            # the key in the model output that contains the annotations
-            self.annotation_key = config["extra_args"].get("annotation_key", "annotations")
-
     @property
     def new_connection_error_advice_docstring(self):
         return """Please check the LLM engine documentation. The call to the LLM API server failed."""
@@ -124,26 +133,33 @@ def get_optional_fields(self):
             "extra_args": dict,
         }

-    def postprocess_annotations(self, text, model_json):
+    def parse_annotations(self, text, annotations_json):
+        try:
+            annotations_obj = OutputAnnotations.parse_raw(annotations_json)
+            annotations = annotations_obj.annotations
+        except ValidationError as e:
+            logger.error(f"LLM response is not in the expected format: {e}\n\t{annotations_json=}")
+            return []
+
         annotation_list = []
         current_pos = 0
-
-        if self.annotation_key not in model_json:
-            logger.error(f"Cannot find the key `{self.annotation_key}` in {model_json=}")
-            return annotation_list
-
-        for annotation in model_json[self.annotation_key]:
+        for annotation in annotations:
             # find the `start` index of the error in the text
-            start_pos = text.lower().find(annotation["text"].lower(), current_pos)
+            start_pos = text.lower().find(annotation.text.lower(), current_pos)

             if start_pos == -1:
                 logger.warning(f"Cannot find {annotation=} in text {text}, skipping")
                 continue

-            annotation["start"] = start_pos
-            annotation_list.append(copy.deepcopy(annotation))
+            annotation_d = annotation.dict()
+            # For backward compatibility, use the shorter key "type"
+            # We do not use the name "type" in the JSON schema for error types because it has a much broader sense in the schema (e.g. string or integer)
+            annotation_d["type"] = annotation.annotation_type
+            del annotation_d["annotation_type"]
+            # log where the annotation starts to disambiguate errors on the same string in different places
+            annotation_d["start"] = start_pos
+            annotation_list.append(annotation_d)

-            current_pos = start_pos + len(annotation["text"])
+            current_pos = start_pos + len(annotation.text)  # does not allow for overlapping annotations

         return annotation_list

@@ -181,10 +197,34 @@ def annotate_example(self, data, text):
         raise NotImplementedError("Override this method in the subclass to call the LLM API")


-class OpenAIMetric(LLMMetric):
-    def __init__(self, config):
+class OpenAIClientMetric(LLMMetric):
+    def __init__(self, config, **kwargs):
         super().__init__(config)
-        self.client = OpenAI()
+        self.client = OpenAI(**kwargs)
+
+        config_schema = config.get("extra_args", {}).get("schema", {})
+        pydantic_schema = OutputAnnotations.model_json_schema()
+        if config_schema:
+            self._schema = config_schema
+            logger.warning(
+                f"We expect parsing according to \n{pydantic_schema=}\n but got another schema from config\n{config_schema=}"
+                "\nAdapt parsing accordingly!"
+            )
+        else:
+            self._schema = pydantic_schema
+
+        # Required for the OpenAI API but makes sense in general too
+        # TODO make it more pydantic / Python friendly
+        self._schema["additionalProperties"] = False
+        self._schema["$defs"]["SpanAnnotation"]["additionalProperties"] = False
+
+        logger.warning(f"The schema is set to\n{self._schema}.\n\tCheck that your prompt is compatible!!!")
") + # access the later used config keys early to log them once and test if they are present + logger.info(f"Using {config['model']=} with {config['system_msg']=}") + + @property + def schema(self): + return self._schema def get_required_fields(self): return { @@ -202,31 +242,79 @@ def get_optional_fields(self): "extra_args": dict, # TODO we receive it from the UI, but can be removed } + @abstractmethod + def _prepare_chat_completions_create_args(self): + raise NotImplementedError("Override this method in the subclass to prepare the arguments for the OpenAI API") + def annotate_example(self, data, text): try: prompt = self.prompt(data, text) logger.debug(f"Calling OpenAI API with prompt: {prompt}") + + model = self.config["model"] + response = self.client.chat.completions.create( - model=self.config["model"], - response_format={"type": "json_object"}, + model=model, messages=[ {"role": "system", "content": self.config["system_msg"]}, {"role": "user", "content": prompt}, ], - **self.config.get("model_args", {}), + **self._prepare_chat_completions_create_args(), ) annotation_str = response.choices[0].message.content - j = json.loads(annotation_str) - logger.info(j) + logger.info(annotation_str) - return {"prompt": prompt, "annotations": self.postprocess_annotations(text=text, model_json=j)} + return { + "prompt": prompt, + "annotations": self.parse_annotations(text=text, annotations_json=annotation_str), + } except Exception as e: traceback.print_exc() logger.error(e) raise e +class VLLMMetric(OpenAIClientMetric): + def __init__(self, config, **kwargs): + base_url = config["api_url"] # Mandatory for VLLM + api_key = config.get("api_key", None) # Optional authentication for VLLM + + super().__init__(config, base_url=base_url, api_key=api_key, **kwargs) + + def _prepare_chat_completions_create_args(self): + guided_json = self.schema + # # works well with vllm https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters + config_args = {"extra_body": {"guided_json": guided_json}} + return config_args + + +class OpenAIMetric(OpenAIClientMetric): + def _prepare_chat_completions_create_args(self): + model = self.config["model"] + + model_supported = any(model.startswith(prefix) for prefix in ["gpt-4o", "gpt-4o-mini"]) + if not model_supported: + logger.warning( + f"Model {model} does not support structured output. It is probablye there will be SOME OF PARSING ERRORS" + ) + response_format = {"type": "json_object"} + else: + # Details at https://platform.openai.com/docs/guides/structured-outputs?context=without_parse + json_schema = dict(name="OutputNLGAnnotations", strict=True, schema=self.schema) + response_format = { + "type": "json_schema", + "json_schema": json_schema, + } + + config_args = self.config.get("model_args", {}) + if "response_format" in config_args and config_args["response_format"] != response_format: + logger.warning(f"Not using the default {response_format=} but using {config_args['response_format']=}") + else: + config_args["response_format"] = response_format + return config_args + + class OllamaMetric(LLMMetric): def __init__(self, config): super().__init__(config) @@ -250,13 +338,21 @@ def postprocess_output(self, output): output = output.strip() j = json.loads(output) + ANNOTATION_STR = "annotations" + assert ( + ANNOTATION_STR in OutputAnnotations.model_json_schema()["properties"] + ), f"Has the {OutputAnnotations=} schema changed?" + + # Required for OllamaMetric. You may want to switch to VLLMMetric which uses constrained decoding. 
+        # It is especially useful for weaker models which have problems producing valid JSON output.
         if self.config["model"].startswith("llama3"):
             # the model often tends to produce a nested list
-            annotations = j[self.annotation_key]
+
+            annotations = j[ANNOTATION_STR]
             if isinstance(annotations, list) and len(annotations) >= 1 and isinstance(annotations[0], list):
-                j[self.annotation_key] = j[self.annotation_key][0]
+                j[ANNOTATION_STR] = j[ANNOTATION_STR][0]

-        return j
+        return json.dumps(j)

     def annotate_example(self, data, text):
         prompt = self.prompt(data=data, text=text)
@@ -282,12 +378,11 @@ def annotate_example(self, data, text):
                 return []

             annotation_str = response_json["response"]
-
-            j = self.postprocess_output(annotation_str)
-            logger.info(j)
+            annotation_postprocessed = self.postprocess_output(annotation_str)
+            logger.info(annotation_postprocessed)
             return {
                 "prompt": prompt,
-                "annotations": self.postprocess_annotations(text=text, model_json=j),
+                "annotations": self.parse_annotations(text=text, annotations_json=annotation_postprocessed),
             }
         except (ConnectionError, requests.exceptions.ConnectionError) as e:
             # notifiy the user that the API is down