diff --git a/README.md b/README.md index ef5fb9c..7a0341f 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ ## News and Updates +- [13/03/2024] Add support for multi-modal models and datasets. - [05/01/2024] Add support for BigBench Hard, DROP, ARC datasets. - [16/12/2023] Add support for Gemini, Mistral, Mixtral, Baichuan, Yi models. - [15/12/2023] Add detailed instructions for users to add new modules (models, datasets, etc.) [examples/add_new_modules.md](examples/add_new_modules.md). @@ -161,7 +162,7 @@ import promptbench as pb We provide tutorials for: -1. **evaluate models on existing benchmarks:** please refer to the [examples/basic.ipynb](examples/basic.ipynb) for constructing your evaluation pipeline. +1. **evaluate models on existing benchmarks:** please refer to the [examples/basic.ipynb](examples/basic.ipynb) for constructing your evaluation pipeline. For a multi-modal evaluation pipeline, please refer to [examples/multimodal.ipynb](examples/multimodal.ipynb) 2. **test the effects of different prompting techniques:** 3. **examine the robustness for prompt attacks**, please refer to [examples/prompt_attack.ipynb](examples/prompt_attack.ipynb) to construct the attacks. 4. **use DyVal for evaluation:** please refer to [examples/dyval.ipynb](examples/dyval.ipynb) to construct DyVal datasets. @@ -185,6 +186,13 @@ PromptBench currently supports different datasets, models, prompt engineering me - Numersense - QASC - Last Letter Concatenate +- VQAv2 +- NoCaps +- MMMU +- MathVista +- AI2D +- ChartQA +- ScienceQA ### Models @@ -203,6 +211,18 @@ PromptBench currently supports different datasets, models, prompt engineering me - GPT-4 - Gemini Pro +### Models (Multi-Modal) + +- Open-source models: + - BLIP2 + - LLaVA + - Qwen-VL, Qwen-VL-Chat + - InternLM-XComposer2-VL +- Proprietary models + - GPT-4v + - GeminiProVision + - Qwen-VL-Max, Qwen-VL-Plus + ### Prompt Engineering - Chain-of-thought (COT) [1] @@ -239,10 +259,6 @@ PromptBench currently supports different datasets, models, prompt engineering me Please refer to our [benchmark website](https://llm-eval.github.io/) for benchmark results on Prompt Attacks, Prompt Engineering and Dynamic Evaluation DyVal. -## TODO - -- [ ] Add support for multi-modal models such as LlaVa and BLIP2. - ## Acknowledgements - [TextAttack](https://github.com/QData/TextAttack) diff --git a/docs/examples/multimodal.md b/docs/examples/multimodal.md new file mode 100644 index 0000000..c360cd8 --- /dev/null +++ b/docs/examples/multimodal.md @@ -0,0 +1,134 @@ +# Multi-Modal Models + +This example will walk you throught the basic usage of MULTI-MODAL models in PromptBench. We hope that you can get familiar with the APIs and use it in your own projects later. + +First, there is a unified import of `import promptbench as pb` that easily imports the package. + + +```python +import promptbench as pb +``` + + /anaconda/envs/promptbench_1/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html + from .autonotebook import tqdm as notebook_tqdm + + +## Load dataset + +First, PromptBench supports easy load of datasets. + + +```python +# print all supported datasets in promptbench +print('All supported datasets: ') +print(pb.SUPPORTED_DATASETS_VLM) + +# load a dataset, MMMMU, for instance. +# if the dataset is not available locally, it will be downloaded automatically. +dataset = pb.DatasetLoader.load_dataset("mmmu") + +# print the first 5 examples +dataset[:5] +``` + + All supported datasets: + ['vqav2', 'nocaps', 'science_qa', 'math_vista', 'ai2d', 'mmmu', 'chart_qa'] + + + + + + [{'images': [], + 'answer': 'B', + 'question': ' Baxter Company has a relevant range of production between 15,000 and 30,000 units. The following cost data represents average variable costs per unit for 25,000 units of production. If 30,000 units are produced, what are the per unit manufacturing overhead costs incurred?\nA: $6\nB: $7\nC: $8\nD: $9'}, + {'images': [], + 'answer': 'C', + 'question': 'Assume accounts have normal balances, solve for the one missing account balance: Dividends. Equipment was recently purchased, so there is neither depreciation expense nor accumulated depreciation. \nA: $194,815\nB: $182,815\nC: $12,000\nD: $9,000'}, + {'images': [], + 'answer': 'B', + 'question': 'Maxwell Software, Inc., has the following mutually exclusive projects.Suppose the company uses the NPV rule to rank these two projects. Which project should be chosen if the appropriate discount rate is 15 percent?\nA: Project A\nB: Project B'}, + {'images': [], + 'answer': 'D', + 'question': "Each situation below relates to an independent company's Owners' Equity. Calculate the missing values of company 2.\nA: $1,620\nB: $12,000\nC: $51,180\nD: $0"}, + {'images': [], + 'answer': 'B', + 'question': 'The following data show the units in beginning work in process inventory, the number of units started, the number of units transferred, and the percent completion of the ending work in process for conversion. Given that materials are added at the beginning of the process, what are the equivalent units for conversion costs for each quarter using the weighted-average method? Assume that the quarters are independent.\nA: 132,625\nB: 134,485\nC: 135,332\nD: 132,685'}] + + + +## Load models + +Then, you can easily load VLM models via promptbench. + + +```python +# print all supported models in promptbench +print('All supported models: ') +print(pb.SUPPORTED_MODELS_VLM) + +# load a model, llava-1.5-7b, for instance. +model = pb.VLMModel(model='llava-hf/llava-1.5-7b-hf', max_new_tokens=2048, temperature=0.0001, device='cuda') +``` + + All supported models: + ['Salesforce/blip2-opt-2.7b', 'Salesforce/blip2-opt-6.7b', 'Salesforce/blip2-flan-t5-xl', 'Salesforce/blip2-flan-t5-xxl', 'llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-13b-hf', 'gemini-pro-vision', 'gpt-4-vision-preview', 'Qwen/Qwen-VL', 'Qwen/Qwen-VL-Chat', 'qwen-vl-plus', 'qwen-vl-max', 'internlm/internlm-xcomposer2-vl-7b'] + + + Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. + Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00, 1.48s/it] + + +## Construct prompts + +Prompts are the key interaction interface to VLMs. You can easily construct a prompt by call the Prompt API. + + +```python +# Prompt API supports a list, so you can pass multiple prompts at once. +prompts = pb.Prompt([ + "You are a helpful assistant. Here is the question:{question}\nANSWER:", + "USER:{question}\nANSWER:", +]) +``` + +## Perform evaluation using prompts, datasets, and models + +Finally, you can perform standard evaluation using the loaded prompts, datasets, and labels. + + +```python +from tqdm import tqdm +for prompt in prompts: + preds = [] + labels = [] + for data in tqdm(dataset): + # process input + input_text = pb.InputProcess.basic_format(prompt, data) + input_images = data['images'] + label = data['answer'] + raw_pred = model(input_images, input_text) + # process output + pred = pb.OutputProcess.pattern_split(raw_pred, 'ANSWER:') + preds.append(pred) + labels.append(label) + + # evaluate + score = pb.Eval.compute_cls_accuracy(preds, labels) + print(f"{score:.3f}, {repr(prompt)}") +``` + + 0%| | 0/900 [00:00],\n", + " 'answer': 'B',\n", + " 'question': ' Baxter Company has a relevant range of production between 15,000 and 30,000 units. The following cost data represents average variable costs per unit for 25,000 units of production. If 30,000 units are produced, what are the per unit manufacturing overhead costs incurred?\\nA: $6\\nB: $7\\nC: $8\\nD: $9'},\n", + " {'images': [],\n", + " 'answer': 'C',\n", + " 'question': 'Assume accounts have normal balances, solve for the one missing account balance: Dividends. Equipment was recently purchased, so there is neither depreciation expense nor accumulated depreciation. \\nA: $194,815\\nB: $182,815\\nC: $12,000\\nD: $9,000'},\n", + " {'images': [],\n", + " 'answer': 'B',\n", + " 'question': 'Maxwell Software, Inc., has the following mutually exclusive projects.Suppose the company uses the NPV rule to rank these two projects. Which project should be chosen if the appropriate discount rate is 15 percent?\\nA: Project A\\nB: Project B'},\n", + " {'images': [],\n", + " 'answer': 'D',\n", + " 'question': \"Each situation below relates to an independent company's Owners' Equity. Calculate the missing values of company 2.\\nA: $1,620\\nB: $12,000\\nC: $51,180\\nD: $0\"},\n", + " {'images': [],\n", + " 'answer': 'B',\n", + " 'question': 'The following data show the units in beginning work in process inventory, the number of units started, the number of units transferred, and the percent completion of the ending work in process for conversion. Given that materials are added at the beginning of the process, what are the equivalent units for conversion costs for each quarter using the weighted-average method? Assume that the quarters are independent.\\nA: 132,625\\nB: 134,485\\nC: 135,332\\nD: 132,685'}]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# print all supported datasets in promptbench\n", + "print('All supported datasets: ')\n", + "print(pb.SUPPORTED_DATASETS_VLM)\n", + "\n", + "# load a dataset, MMMMU, for instance.\n", + "# if the dataset is not available locally, it will be downloaded automatically.\n", + "dataset = pb.DatasetLoader.load_dataset(\"mmmu\")\n", + "\n", + "# print the first 5 examples\n", + "dataset[:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load models\n", + "\n", + "Then, you can easily load VLM models via promptbench." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All supported models: \n", + "['Salesforce/blip2-opt-2.7b', 'Salesforce/blip2-opt-6.7b', 'Salesforce/blip2-flan-t5-xl', 'Salesforce/blip2-flan-t5-xxl', 'llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-13b-hf', 'gemini-pro-vision', 'gpt-4-vision-preview', 'Qwen/Qwen-VL', 'Qwen/Qwen-VL-Chat', 'qwen-vl-plus', 'qwen-vl-max', 'internlm/internlm-xcomposer2-vl-7b']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00, 1.48s/it]\n" + ] + } + ], + "source": [ + "# print all supported models in promptbench\n", + "print('All supported models: ')\n", + "print(pb.SUPPORTED_MODELS_VLM)\n", + "\n", + "# load a model, llava-1.5-7b, for instance.\n", + "model = pb.VLMModel(model='llava-hf/llava-1.5-7b-hf', max_new_tokens=2048, temperature=0.0001, device='cuda')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Construct prompts\n", + "\n", + "Prompts are the key interaction interface to VLMs. You can easily construct a prompt by call the Prompt API." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Prompt API supports a list, so you can pass multiple prompts at once.\n", + "prompts = pb.Prompt([\n", + " \"You are a helpful assistant. Here is the question:{question}\\nANSWER:\",\n", + " \"USER:{question}\\nANSWER:\",\n", + "])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Perform evaluation using prompts, datasets, and models\n", + "\n", + "Finally, you can perform standard evaluation using the loaded prompts, datasets, and labels." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/900 [00:00 + } + """ + def __init__(self): + data = load_dataset("HuggingFaceM4/VQAv2", split="validation") + self.data = data + + def __getitem__(self, idx): + assert len(self.data) > 0, "Empty dataset. Please load data first." + return {"images": [self.data[idx]['image']], + "answers": self.data[idx]['answers'], + "question": self.data[idx]['question'],} + +class NoCaps(Dataset): + """ + NoCaps is a dataset class for the NoCaps dataset. + This dataset is loaded from huggingface datasets: nocaps (validation set). + + Reference: + https://huggingface.co/datasets/HuggingFaceM4/NoCaps + nocaps: novel object captioning at scale (https://arxiv.org/abs/1812.08658) + + Example data format: + { + 'image': , + 'image_coco_url': 'https://s3.amazonaws.com/nocaps/val/0013ea2087020901.jpg', + 'image_date_captured': '2018-11-06 11:04:33', + 'image_file_name': '0013ea2087020901.jpg', + 'image_height': 1024, + 'image_width': 732, + 'image_id': 0, + 'image_license': 0, + 'image_open_images_id': '0013ea2087020901', + 'annotations_ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + 'annotations_captions': ['A baby is standing in front of a house.', + 'A little girl in a white jacket and sandals.', + 'A young child stands in front of a house.', + 'A child is wearing a white shirt and standing on a side walk. ', + 'A little boy is standing in his diaper with a white shirt on.', + 'A child wearing a diaper and shoes stands on the sidewalk.', + 'A child is wearing a light-colored shirt during the daytime.', + 'A little kid standing on the pavement in a shirt. ', + 'Black and white photo of a little girl smiling.', + 'a cute baby is standing alone with white shirt'] + } + """ + def __init__(self): + data = load_dataset("HuggingFaceM4/NoCaps", split="validation") + self.data = data + + def __getitem__(self, idx): + assert len(self.data) > 0, "Empty dataset. Please load data first." + return {"images": [self.data[idx]['image']], + "answers": self.data[idx]['annotations_captions']} + +class MathVista(Dataset): + """ + MathVista is a dataset class for the MathVista dataset. + This dataset is loaded from huggingface datasets: math_vista (testmini set). + + Reference: + https://huggingface.co/datasets/AI4Math/MathVista + MathVista: Evaluating Mathematical Reasoning of Foundation Models in Visual Contexts (https://arxiv.org/abs/2310.02255) + + Example data format: + { + 'pid': '1', + 'question': "When a spring does work on an object, we cannot find the work by simply multiplying the spring force by the object's displacement. The reason is that there is no one value for the force-it changes. However, we can split the displacement up into an infinite number of tiny parts and then approximate the force in each as being constant. Integration sums the work done in all those parts. Here we use the generic result of the integration.\r\n\r\nIn Figure, a cumin canister of mass $m=0.40 \\mathrm{~kg}$ slides across a horizontal frictionless counter with speed $v=0.50 \\mathrm{~m} / \\mathrm{s}$. It then runs into and compresses a spring of spring constant $k=750 \\mathrm{~N} / \\mathrm{m}$. When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?", + 'image': 'images/1.jpg', + 'decoded_image': , + 'choices': None, + 'unit': None, + 'precision': 1.0, + 'answer': '1.2', + 'question_type': 'free_form', + 'answer_type': 'float', + 'metadata': {'category': 'math-targeted-vqa', + 'context': 'scientific figure', + 'grade': 'college', + 'img_height': 720, + 'img_width': 1514, + 'language': 'english', + 'skills': ['scientific reasoning'], + 'source': 'SciBench', + 'split': 'testmini', + 'task': 'textbook question answering'}, + 'query': "Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.\nQuestion: When a spring does work on an object, we cannot find the work by simply multiplying the spring force by the object's displacement. The reason is that there is no one value for the force-it changes. However, we can split the displacement up into an infinite number of tiny parts and then approximate the force in each as being constant. Integration sums the work done in all those parts. Here we use the generic result of the integration.\r\n\r\nIn Figure, a cumin canister of mass $m=0.40 \\mathrm{~kg}$ slides across a horizontal frictionless counter with speed $v=0.50 \\mathrm{~m} / \\mathrm{s}$. It then runs into and compresses a spring of spring constant $k=750 \\mathrm{~N} / \\mathrm{m}$. When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?" + } + + """ + def __init__(self): + data = load_dataset("AI4Math/MathVista", split="testmini") + self.data = data + + def __getitem__(self, idx): + assert len(self.data) > 0, "Empty dataset. Please load data first." + return {"images": [self.data[idx]['decoded_image']], + "answer": self.data[idx]['answer'], + "question": self.data[idx]['question'] + "\nANSWER TYPE: " + self.data[idx]['answer_type'],} + +class AI2D(Dataset): + """ + AI2D is a dataset class for the AI2D dataset. + This dataset is loaded from huggingface datasets: ai2d (test set). + + Reference: + https://huggingface.co/datasets/lmms-lab/ai2d + A Diagram Is Worth A Dozen Images (https://arxiv.org/abs/1603.07396) + + Example data format: + { + 'question': 'which of these define dairy item', + 'options': ['c', 'D', 'b', 'a'], + 'answer': '1', + 'image': + } + """ + def __init__(self): + data = load_dataset("lmms-lab/ai2d", split="test") + self.data = [] + + for d in data: + choices_dict = dict(enumerate(d['options'])) + choices = '' + for k, v in choices_dict.items(): + choices += f"\n{k}: {v}" + + self.data.append({ + "images": [d['image']], + "question": d['question'] + choices, + "answer": d['answer'] + }) + +class ChartQA(Dataset): + """ + ChartQA is a dataset class for the ChartQA dataset. + This dataset is loaded from huggingface datasets: chart_qa (test set). + + Reference: + https://huggingface.co/datasets/lmms-lab/ChartQA + ChartQA: A Benchmark for Question Answering about Charts with Visual and Logical Reasoning (https://arxiv.org/abs/2203.10244) + + Example data format: + { + 'type': 'human_test', + 'question': 'How many food item is shown in the bar graph?', + 'answer': '14', + 'image': + } + """ + def __init__(self): + data = load_dataset("lmms-lab/ChartQA", split="test") + self.data = data + + def __getitem__(self, idx): + assert len(self.data) > 0, "Empty dataset. Please load data first." + return {"images": [self.data[idx]['image']], + "answer": self.data[idx]['answer'], + "question": self.data[idx]['question'],} + +class ScienceQA(Dataset): + """ + ScienceQA is a dataset class for the ScienceQA dataset. + This dataset is loaded from huggingface datasets: science_qa (validation set). + + Reference: + https://huggingface.co/datasets/derek-thomas/ScienceQA + Learn to Explain: Multimodal Reasoning via Thought Chains for Science Question Answering (https://arxiv.org/abs/2209.09513) + + Example data format: + { + 'image': None, + 'question': 'Which figure of speech is used in this text?\nSing, O goddess, the anger of Achilles son of Peleus, that brought countless ills upon the Achaeans.\n—Homer, The Iliad', + 'choices': ['chiasmus', 'apostrophe'], + 'answer': 1, + 'hint': '', + 'task': 'closed choice', + 'grade': 'grade11', + 'subject': 'language science', + 'topic': 'figurative-language', + 'category': 'Literary devices', + 'skill': 'Classify the figure of speech: anaphora, antithesis, apostrophe, assonance, chiasmus, understatement', + 'lecture': 'Figures of speech are words or phrases that use language in a nonliteral or unusual way. They can make writing more expressive.\nAnaphora is the repetition of the same word or words at the beginning of several phrases or clauses.\nWe are united. We are powerful. We are winners.\nAntithesis involves contrasting opposing ideas within a parallel grammatical structure.\nI want to help, not to hurt.\nApostrophe is a direct address to an absent person or a nonhuman entity.\nOh, little bird, what makes you sing so beautifully?\nAssonance is the repetition of a vowel sound in a series of nearby words.\nTry to light the fire.\nChiasmus is an expression in which the second half parallels the first but reverses the order of words.\nNever let a fool kiss you or a kiss fool you.\nUnderstatement involves deliberately representing something as less serious or important than it really is.\nAs you know, it can get a little cold in the Antarctic.', + 'solution': 'The text uses apostrophe, a direct address to an absent person or a nonhuman entity.\nO goddess is a direct address to a goddess, a nonhuman entity.'} + } + """ + def __init__(self): + data = load_dataset("derek-thomas/ScienceQA", split="validation") + self.data = [] + + for d in data: + if d['image'] is not None: + + choices_dict = dict(enumerate(d['choices'])) + choices = '' + for k, v in choices_dict.items(): + choices += f"\n{k}: {v}" + + self.data.append({ + "images": [d['image']], + "question": d['question'] + choices, + "answer": d['answer'] + }) + +class MMMU(Dataset): + """ + MMMU is a dataset class for the MMMU dataset. + This dataset is loaded from huggingface datasets: mmlu (validation set). + + Reference: + https://huggingface.co/datasets/lmms-lab/MMMU + MMMU: A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI (https://arxiv.org/abs/2311.16502) + + { + 'id': 'validation_Accounting_1', + 'question': ' Baxter Company has a relevant range of production between 15,000 and 30,000 units. The following cost data represents average variable costs per unit for 25,000 units of production. If 30,000 units are produced, what are the per unit manufacturing overhead costs incurred?', + 'options': "['$6', '$7', '$8', '$9']", + 'explanation': '', + 'image_1': , + 'image_2': None, + 'image_3': None, + 'image_4': None, + 'image_5': None, + 'image_6': None, + 'image_7': None, + 'img_type': "['Tables']", + 'answer': 'B', + 'topic_difficulty': 'Medium', + 'question_type': 'multiple-choice', + 'subfield': 'Managerial Accounting' + } + """ + def __init__(self): + data = load_dataset("lmms-lab/MMMU", split="validation") + self.data = [] + + for d in data: + + choices_dict = dict(enumerate(eval(d['options']))) + choices = '' + for k, v in choices_dict.items(): + choices += f"\n{chr(ord('A') + int(k))}: {v}" + question = d['question'] + choices + + images = [] + for i in range(1, 7): + if f'image {i}' in question: + if d[f'image_{i}'].mode == 'P': + d[f'image_{i}'] = d[f'image_{i}'].convert('RGBA') + images.append(d[f'image_{i}']) + + self.data.append({"images": images, + "answer": d['answer'], + "question": question}) \ No newline at end of file diff --git a/promptbench/metrics/cider/cider.py b/promptbench/metrics/cider/cider.py new file mode 100644 index 0000000..5fb4650 --- /dev/null +++ b/promptbench/metrics/cider/cider.py @@ -0,0 +1,57 @@ +""" +This is copied from https://github.com/tylin/coco-caption/tree/master/pycocoevalcap/cider. +""" +# Filename: cider.py +# +# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin + +from .cider_scorer import CiderScorer +import pdb + +class Cider: + """ + Main Class to compute the CIDEr metric + + """ + def __init__(self, test=None, refs=None, n=4, sigma=6.0): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param hypo_for_image (dict) : dictionary with key and value + ref_for_image (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ + + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) + + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. + assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + assert(len(ref) > 0) + + cider_scorer += (hypo[0], ref) + + (score, scores) = cider_scorer.compute_score() + + return score, scores + + def method(self): + return "CIDEr" \ No newline at end of file diff --git a/promptbench/metrics/cider/cider_scorer.py b/promptbench/metrics/cider/cider_scorer.py new file mode 100644 index 0000000..0aacb8e --- /dev/null +++ b/promptbench/metrics/cider/cider_scorer.py @@ -0,0 +1,196 @@ +""" +This is copied from https://github.com/tylin/coco-caption/tree/master/pycocoevalcap/cider. +""" +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam + +import copy +from collections import defaultdict +import numpy as np +import pdb +import math + +def precook(s, n=4, out=False): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. + :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1,n+1): + for i in range(len(words)-k+1): + ngram = tuple(words[i:i+k]) + counts[ngram] += 1 + return counts + +def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. + :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. + :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n, True) + +class CiderScorer(object): + """CIDEr scorer. + """ + + def copy(self): + ''' copy the refs.''' + new = CiderScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + return new + + def __init__(self, test=None, refs=None, n=4, sigma=6.0): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.document_frequency = defaultdict(float) + self.cook_append(test, refs) + self.ref_len = None + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + self.ctest.append(cook_test(test)) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + def size(self): + assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + ## avoid creating new CiderScorer instances + self.cook_append(other[0], other[1]) + else: + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + + return self + def compute_doc_freq(self): + ''' + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + ''' + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): + self.document_frequency[ngram] += 1 + # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + + def compute_cider(self): + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. + :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram, term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.document_frequency[ngram])) + # ngram index + n = len(ngram)-1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq)*(self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. + :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram,count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n]*norm_ref[n]) + + assert(not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) + return val + + # compute log reference length + self.ref_len = np.log(float(len(self.crefs))) + if len(self.crefs) == 1: + self.ref_len = 1 + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self, option=None, verbose=0): + # compute idf + self.compute_doc_freq() + # assert to check document frequency + assert(len(self.ctest) >= max(self.document_frequency.values())) + # compute cider score + score = self.compute_cider() + # debug + # print score + return np.mean(np.array(score)), np.array(score) \ No newline at end of file diff --git a/promptbench/metrics/eval.py b/promptbench/metrics/eval.py index 720a39c..f997538 100644 --- a/promptbench/metrics/eval.py +++ b/promptbench/metrics/eval.py @@ -133,3 +133,51 @@ def compute_math_accuracy(preds, gts): processed_gts.append(gt) return sum(a == b for a, b in zip(processed_preds, processed_gts)) / len(processed_gts) + + @staticmethod + def compute_vqa_accuracy(preds, gts): + """ + Computes vqa accuracy for the VQAv2 dataset. + + Parameters: + ----------- + preds : list + A list of predictions. + gts : list + A list of answers. + + Returns: + -------- + float + The vqa accuracy. + """ + from .vqa.eval_vqa import VQAEval + metric = VQAEval(n=3) + dict_gts = {i: {"answers": val} for i, val in enumerate(gts)} + dict_preds = {i: {"answer": val} for i, val in enumerate(preds)} + score = metric.evaluate(dict_gts, dict_preds, list(range(len(preds)))) + return score + + @staticmethod + def compute_cider(preds, gts): + """ + Computes the CIDEr score for image captioning tasks. + + Parameters: + ----------- + preds : list + A list of predictions. + gts : list + A list of ground truth captions. + + Returns: + -------- + float + The CIDEr score. + """ + from .cider.cider import Cider + metric = Cider() + dict_gts = {i: val for i, val in enumerate(gts)} + dict_preds = {i: [val] for i, val in enumerate(preds)} + score, _ = metric.compute_score(gts=dict_gts, res=dict_preds) + return score \ No newline at end of file diff --git a/promptbench/metrics/vqa/eval_vqa.py b/promptbench/metrics/vqa/eval_vqa.py new file mode 100644 index 0000000..51b7057 --- /dev/null +++ b/promptbench/metrics/vqa/eval_vqa.py @@ -0,0 +1,315 @@ +""" +This is copied from https://github.com/salesforce/LAVIS/blob/main/lavis/common/vqa_tools/vqa_eval.py. +""" + +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +# coding=utf-8 + +__author__ = "aagrawal" + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys +import re + + +class VQAEval: + def __init__(self, vqa=None, vqaRes=None, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + if vqa is not None: + self.params = {"question_id": vqa.getQuesIds()} + self.contractions = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + self.manualMap = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + self.articles = ["a", "an", "the"] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def evaluate(self, gts, res, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params["question_id"]] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + # print("computing accuracy") + step = 0 + for quesId in quesIds: + resAns = res[quesId]["answer"] + resAns = resAns.replace("\n", " ") + resAns = resAns.replace("\t", " ") + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = self.processPunctuation(ansDic["answer"]) + for gtAnsDatum in gts[quesId]["answers"]: + otherGTAns = [ + item for item in gts[quesId]["answers"] if item != gtAnsDatum + ] + matchingAns = [item for item in otherGTAns if item["answer"] == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + # if step % 100 == 0: + # self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA) + # print("Done computing accuracy") + + return self.accuracy["overall"] + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = " ".join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType=None, accAnsType=None): + self.accuracy["overall"] = round(float(sum(accQA)) / len(accQA), self.n) + if accQuesType is not None: + self.accuracy["perQuestionType"] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + if accAnsType is not None: + self.accuracy["perAnswerType"] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n + ) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength * progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( + "#" * block + "-" * (barLength - block), int(progress * 100), status + ) + sys.stdout.write(text) + sys.stdout.flush() \ No newline at end of file diff --git a/promptbench/models/__init__.py b/promptbench/models/__init__.py index e55b033..406261f 100644 --- a/promptbench/models/__init__.py +++ b/promptbench/models/__init__.py @@ -19,8 +19,19 @@ BaichuanModel: ['baichuan-inc/Baichuan2-7B-Base', 'baichuan-inc/Baichuan2-13B-Base', 'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat'], } +MODEL_LIST_VLM = { + BLIP2Model: ['Salesforce/blip2-opt-2.7b', 'Salesforce/blip2-opt-6.7b', + 'Salesforce/blip2-flan-t5-xl', 'Salesforce/blip2-flan-t5-xxl'], + LLaVAModel: ['llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-13b-hf'], + GeminiVisionModel: ['gemini-pro-vision'], + OpenAIVisionModel: ['gpt-4-vision-preview'], + QwenVLModel: ['Qwen/Qwen-VL', 'Qwen/Qwen-VL-Chat', + 'qwen-vl-plus', 'qwen-vl-max'], + InternLMVisionModel: ['internlm/internlm-xcomposer2-vl-7b'], +} SUPPORTED_MODELS = [model for model_class in MODEL_LIST.keys() for model in MODEL_LIST[model_class]] +SUPPORTED_MODELS_VLM = [model for model_class in MODEL_LIST_VLM.keys() for model in MODEL_LIST_VLM[model_class]] class LLMModel(object): @@ -161,3 +172,142 @@ def _other_concat_prompts(self, prompt_list): def __call__(self, input_text, **kwargs): """Predicts the output based on the given input text using the loaded model.""" return self.model.predict(input_text, **kwargs) + + +class VLMModel(object): + """ + A class providing an interface for various vision language models. + + This class supports creating and interfacing with different vision language models, handling prompt engineering, and performing model inference. + + Parameters: + ----------- + model : str + The name of the model to be used. + max_new_tokens : int, optional + The maximum number of new tokens to be generated (default is 20). + temperature : float, optional + The temperature for text generation (default is 0). + device : str, optional + The device to be used for inference (default is "cuda"). + dtype : str, optional + The loaded data type of the language model (default is "auto"). + model_dir : str or None, optional + The directory containing the model files (default is None). + system_prompt : str or None, optional + The system prompt to be used (default is None). + api_key : str or None, optional + The API key for API-based models (GPT series, Gemini series and Qwen series), if required (default is None). + + Methods: + -------- + _create_model(max_new_tokens, temperature, device, dtype, model_dir, system_prompt, api_key) + Creates and returns the appropriate model instance. + convert_text_to_prompt(text, role) + Constructs a prompt based on the text and role. + concat_prompts(prompt_list) + Concatenates multiple prompts into a single prompt. + _gpt_concat_prompts(prompt_list) + Concatenates prompts for GPT models. + _other_concat_prompts(prompt_list) + Concatenates prompts for non-GPT models. + __call__(input_images, input_text, **kwargs) + Makes a prediction based on the input images and input text using the loaded model. + """ + + @staticmethod + def model_list(): + return SUPPORTED_MODELS_VLM + + def __init__(self, model: str, max_new_tokens: int=20, temperature: float=0.0, device: str="cuda", dtype: str="auto", system_prompt: str=None, api_key:str =None): + self.model_name = model + self.model = self._create_model(max_new_tokens, temperature, device, dtype, system_prompt, api_key) + + def _create_model(self, max_new_tokens, temperature, device, dtype, system_prompt, api_key): + """Creates and returns the appropriate model based on the model name.""" + + # Dictionary mapping of model names to their respective classes + model_mapping = {model: model_class for model_class in MODEL_LIST_VLM.keys() for model in MODEL_LIST_VLM[model_class]} + + # Get the model class based on the model name and instantiate it + model_class = model_mapping.get(self.model_name) + if model_class: + if model_class in [OpenAIVisionModel]: + return model_class(self.model_name, max_new_tokens, temperature, system_prompt, api_key) + elif model_class in [GeminiVisionModel]: + return model_class(self.model_name, max_new_tokens, temperature, api_key) + elif model_class in [QwenVLModel]: + return model_class(self.model_name, max_new_tokens, temperature, device, dtype, system_prompt, api_key) + else: + return model_class(self.model_name, max_new_tokens, temperature, device, dtype) + else: + raise ValueError("The model is not supported!") + + def convert_text_to_prompt(self, text, role): + """Constructs multi_turn conversation for complex methods in prompt engineering.""" + if self.model_name == ['gpt-4-vision-preview']: + return {'role': role, 'content': text} + else: + # return str(role) + ': ' + str(text) + '\n' + return str(text) + '\n' + + def concat_prompts(self, prompt_list): + """Concatenates the prompts into a single prompt.""" + if self.model_name == ['gpt-4-vision-preview']: + return self._gpt_concat_prompts(prompt_list) + else: + return self._other_concat_prompts(prompt_list) + + def _gpt_concat_prompts(self, prompt_list): + """ + Concatenate prompts from various inputs into a single list of dictionaries. + + The function accepts any number of keyword arguments, each of which can be either + a dictionary or a list of dictionaries. It concatenates all inputs into a single list. + + Returns: + A list of dictionaries containing all the prompts from the input arguments. + """ + # Initialize an empty list to hold all dictionaries + all_prompts = [] + + # Iterate over each keyword argument + for arg in prompt_list: + # Check if the argument is a dictionary, and if so, add it to the list + if isinstance(arg, dict): + all_prompts.append(arg) + # Check if the argument is a list of dictionaries + elif isinstance(arg, list) and all(isinstance(item, dict) for item in arg): + # Extend the list with the dictionaries from the current argument + all_prompts.extend(arg) + else: + raise ValueError("All arguments must be dictionaries or lists of dictionaries.") + + return all_prompts + + def _other_concat_prompts(self, prompt_list): + """ + Concatenate prompts from various inputs into a single strings. + + The function accepts any number of keyword arguments, each of which must be + a string. It concatenates all inputs into a single string. + + Returns: + A string containing all the prompts from the input arguments. + """ + # Initialize an empty string to hold all prompts + all_prompts = "" + + # Iterate over each keyword argument + for arg in prompt_list: + # Check if the argument is a string, and if so, add it to the list + if isinstance(arg, str): + all_prompts = all_prompts + '\n' + arg + else: + raise ValueError("All arguments must be strings.") + + return all_prompts + + def __call__(self, input_images, input_text, **kwargs): + """Predicts the output based on the given input text using the loaded model.""" + return self.model.predict(input_images, input_text, **kwargs) \ No newline at end of file diff --git a/promptbench/models/models.py b/promptbench/models/models.py index df48447..a9087e1 100644 --- a/promptbench/models/models.py +++ b/promptbench/models/models.py @@ -544,3 +544,459 @@ def predict(self, input_text, **kwargs): response = model.generate_content(input_text).text return response + + +class VLMBaseModel(ABC): + """ + Abstract base class for vision language model interfaces. + + This class provides a common interface for various vision language models and includes methods for prediction. + + Parameters: + ----------- + model : str + The name of the vision language model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float + The temperature for text generation (default is 0). + device: str + The device to use for inference (default is 'auto'). + + Methods: + -------- + predict(input_images, input_text, **kwargs) + Generates a prediction based on the input images and text. + __call__(input_image, input_text, **kwargs) + Shortcut for predict method. + """ + def __init__(self, model_name, max_new_tokens, temperature, device='auto'): + self.model_name = model_name + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.device = device + self.placeholder = "" + + def predict(self, input_images, input_text, **kwargs): + if self.device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + device = self.device + + for i in range(len(input_images)): + input_text = self.placeholder + input_text + + if self.enable_multiple_images: + input_ids = self.processor(text=input_text, images=input_images, return_tensors="pt").to(device) + else: + input_ids = self.processor(text=input_text, images=input_images[0], return_tensors="pt").to(device) + + outputs = self.model.generate(**input_ids, + max_new_tokens=self.max_new_tokens, + temperature=self.temperature, + do_sample=True, + **kwargs) + + out = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] + return out + + def __call__(self, input_images, input_text, **kwargs): + return self.predict(input_images, input_text, **kwargs) + +class BLIP2Model(VLMBaseModel): + """ + Vision Language model class for the BLIP2 model. + + Inherits from VLMBaseModel and sets up the BLIP2 vision language model for use. + + Parameters: + ----------- + model : str + The name of the BLIP2 model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float, optional + The temperature for text generation (default is 0). + device: str + The device to use for inference (default is 'auto'). + dtype: str + The dtype to use for inference (default is 'auto'). + + Parameters of predict method: + ---------------- + input_images: list of PIL.Image + The input images. + input_text: str + The input text. + """ + def __init__(self, model_name, max_new_tokens, temperature, device, dtype): + super(BLIP2Model, self).__init__(model_name, max_new_tokens, temperature, device) + from transformers import Blip2Processor, Blip2ForConditionalGeneration + self.processor = Blip2Processor.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device, use_fast=False) + self.model = Blip2ForConditionalGeneration.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device) + self.enable_multiple_images = False + +class LLaVAModel(VLMBaseModel): + """ + Vision Language model class for the LLaVA model. + + Inherits from VLMBaseModel and sets up the LLaVA vision language model for use. + + Parameters: + ----------- + model : str + The name of the LLaVA model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float + The temperature for text generation (default is 0). + device: str + The device to use for inference (default is 'auto'). + dtype: str + The dtype to use for inference (default is 'auto'). + + Parameters of predict method: + ---------------- + input_image: list of PIL.Image + The input images. + input_text: str + The input text. Using as the placeholder for the image. + """ + def __init__(self, model_name, max_new_tokens, temperature, device, dtype): + super(LLaVAModel, self).__init__(model_name, max_new_tokens, temperature, device) + from transformers import AutoProcessor, LlavaForConditionalGeneration + self.processor = AutoProcessor.from_pretrained(self.model_name, device_map=device, trust_remote_code=True) + self.model = LlavaForConditionalGeneration.from_pretrained(self.model_name, device_map=device) + self.enable_multiple_images = True + self.placeholder = "" # a specialized placeholder of LLaVA model + +class GeminiVisionModel(VLMBaseModel): + """ + Vision Language model class for interfacing with Google's Gemini models. + + Inherits from VLMBaseModel and sets up a model interface for Gemini models. + + Parameters: + ----------- + model : str + The name of the PaLM model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float, optional + The temperature for text generation (default is 0). + gemini_key : str, optional + The Gemini API key (default is None). + + Parameters of predict method: + ---------------- + input_image: list of PIL.Image + The input images. + input_text: str + The input text. + """ + def __init__(self, model, max_new_tokens, temperature, gemini_key=None): + super(GeminiVisionModel, self).__init__(model, max_new_tokens, temperature) + self.gemini_key = gemini_key + self.enable_multiple_images = True + + def predict(self, input_images, input_text, **kwargs): + import google.generativeai as genai + + genai.configure(api_key=self.gemini_key) + + # Set up the model + generation_config = { + "temperature": self.temperature, + "top_p": 1, + "top_k": 1, + "max_output_tokens": self.max_new_tokens, + } + + safety_settings = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + } + ] + + model = genai.GenerativeModel(model_name="gemini-pro-vision", + generation_config=generation_config, + safety_settings=safety_settings) + + response = model.generate_content(input_images + [input_text]) + + try: + return response.text + except: + print('Error when generating the response using Gemini model') + return "" + +class OpenAIVisionModel(VLMBaseModel): + """ + Vision Language model class for interfacing with OpenAI's GPT models. + + Inherits from VLMBaseModel and sets up a model interface for OpenAI GPT models. + + Parameters: + ----------- + model : str + The name of the OpenAI model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float + The temperature for text generation (default is 0). + system_prompt : str + The system prompt to be used (default is None). + openai_key : str + The OpenAI API key (default is None). + + Parameters of predict method: + ---------------- + input_image: list of str + The url / local path of the input images. + input_text: str + The input text. + """ + def __init__(self, model_name, max_new_tokens, temperature, system_prompt, openai_key): + super(OpenAIVisionModel, self).__init__(model_name, max_new_tokens, temperature) + self.openai_key = openai_key + self.system_prompt = system_prompt + self.enable_multiple_images = True + + def predict(self, input_images, input_text, **kwargs): + + if self.system_prompt is None: + system_messages = {'role': "system", 'content': "You are a helpful assistant."} + else: + system_messages = {'role': "system", 'content': self.system_prompt} + # extra parameterss + n = kwargs['n'] if 'n' in kwargs else 1 + temperature = kwargs['temperature'] if 'temperature' in kwargs else self.temperature + max_new_tokens = kwargs['max_new_tokens'] if 'max_new_tokens' in kwargs else self.max_new_tokens + + # for input image with url + if "http://" in input_images[0] or "https://" in input_images[0]: + + from openai import OpenAI + client = OpenAI(api_key=self.openai_key) + + messages = [{"role": "user", + "content": [ + {"type": "text", "text": input_text}, + ]}] + messages.insert(0, system_messages) + for input_image in input_images: + messages[1]['content'].append({ + "type": "image_url", + "image_url": {"url": input_image}, + }) + + response = client.chat.completions.create( + model=self.model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + n=n, + ) + + if n > 1: + result = [choice.message.content for choice in response.choices] + else: + result = response.choices[0].message.content + + return result + + # for input image with local path + else: + import base64 + import requests + + api_key = self.openai_key + + def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + + payload = { + "model": "gpt-4-vision-preview", + "messages": [ + system_messages, + { + "role": "user", + "content": [ + { + "type": "text", + "text": input_text + }, + ] + } + ], + "temperature": temperature, + "max_tokens": max_new_tokens, + "n": n, + } + base64_images = [] + for input_image in input_images: + base64_image = encode_image(input_image) # Getting the base64 string + base64_images.append(base64_image) + for base64_img in base64_images: + payload['messages'][1]['content'].append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64, {base64_img}" + } + }) + + response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) + + if n > 1: + result = [choice['message']['content'] for choice in response.json()['choices']] + else: + result = response.json()['choices'][0]['message']['content'] + + return result + +class QwenVLModel(VLMBaseModel): + """ + Vision Language model class for the Qwen model. + + Inherits from VLMBaseModel and sets up the Qwen vision language model for use. + + Parameters: + ----------- + model : str + The name of the Qwen model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float + The temperature for text generation (default is 0). + device: str + The device to use for inference (default is 'auto'). + dtype: str + The dtype to use for inference (default is 'auto'). + system_prompt : str + The system prompt to be used (default is None). + api_key : str + The api key for the Qwen model (default is None). + + Parameters of predict method: + ---------------- + input_image: list of str + The url / local path of the input images. + (Add "file://" prefix for local path when using 'qwen-vl-plus' and 'qwen-vl-max') + input_text: str + The input text. + """ + def __init__(self, model_name, max_new_tokens, temperature, device, dtype, system_prompt, api_key): + if model_name in ['qwen-vl-plus', 'qwen-vl-max']: + super(QwenVLModel, self).__init__(model_name, max_new_tokens, temperature) + assert api_key is not None, f"API key is required for {model_name}" + self.api_key = api_key + self.system_prompt = system_prompt + else: + super(QwenVLModel, self).__init__(model_name, max_new_tokens, temperature, device) + from transformers import AutoModelForCausalLM, AutoTokenizer + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device, trust_remote_code=True).eval() + self.tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=dtype, device_map=device, trust_remote_code=True) + self.enable_multiple_images = True + + def predict(self, input_images, input_text, **kwargs): + + if self.model_name in ['qwen-vl-plus', 'qwen-vl-max']: + from http import HTTPStatus + import dashscope + dashscope.api_key = self.api_key + if self.system_prompt is None: + system_messages = { + 'role': 'system', + 'content': [{ + 'text': 'You are a helpful assistant.' + }] + } + else: + system_messages = { + 'role': 'system', + 'content': [{ + 'text': self.system_prompt + }] + } + messages = [{ + 'role': 'user', + 'content': [{'image': input_image} for input_image in input_images] + [{'text': input_text}] + }] + messages.insert(0, system_messages) + response = dashscope.MultiModalConversation.call(model=self.model_name, + messages=messages) + + if response.status_code == HTTPStatus.OK: + return response['output']['choices'][0]['message']['content'][0]['text'] + else: + print(response.code) # The error code. + print(response.message) # The error message. + return "" + + else: + query = self.tokenizer.from_list_format( + [{'image': input_image} for input_image in input_images] + [{'text': input_text}] + ) + response, _ = self.model.chat(self.tokenizer, query=query, history=None, + max_new_tokens=self.max_new_tokens, temperature=self.temperature) + return response + +class InternLMVisionModel(VLMBaseModel): + """ + Vision Language model class for interfacing with InternLM's vision language models. + + Inherits from VLMBaseModel and sets up a model interface for InternLM's vision language models. + + Parameters: + ----------- + model_name : str + The name of the InternLM model. + max_new_tokens : int + The maximum number of new tokens to be generated. + temperature : float, optional + The temperature for text generation (default is 0). + device: str + The device to use for inference (default is 'auto'). + dtype: str + The dtype to use for inference (default is 'auto'). + + Parameters of predict method: + ---------------- + input_image: list of str + The url / local path of the input images. + input_text: str + The input text. + """ + def __init__(self, model_name, max_new_tokens, temperature, device, dtype): + super(InternLMVisionModel, self).__init__(model_name, max_new_tokens, temperature, device) + from transformers import AutoModel, AutoTokenizer + self.model = AutoModel.from_pretrained(model_name, torch_dtype=dtype, device_map=device, trust_remote_code=True).eval() + self.tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=dtype, device_map=device, trust_remote_code=True) + self.enable_multiple_images = False + self.placeholder = "" # a specialized placeholder of InternLM model + + def predict(self, input_images, input_text, **kwargs): + input_text = self.placeholder + input_text + with torch.cuda.amp.autocast(): + response, _ = self.model.chat(self.tokenizer, query=input_text, image=input_images[0], history=[], do_sample=True, + max_new_tokens=self.max_new_tokens, temperature=self.temperature) + return response \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1bcdc3b..28ca4e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,12 @@ sentencepiece==0.1.99 tokenizers==0.15.0 torch==2.1.1 tqdm==4.66.1 -transformers==4.36.2 \ No newline at end of file +transformers==4.36.2 +Pillow==10.2.0 +google-generativeai==0.4.0 +dashscope==1.14.1 +einops==0.7.0 +transformers_stream_generator==0.0.5 +torchvision==0.16.1 +matplotlib==3.8.3 +tiktoken==0.6.0 \ No newline at end of file