
Commit: Data wrangle benchmark (#95)
* benchmarked 6.7B.

* some new scripts

* update some results of opt30b

* update some results.

* added some 175B results, one piece left.

* updated the missing results.

* Update README.md

* update readme
BinhangYuan authored Mar 8, 2023
1 parent 74bdca7 commit 3834bb3
Showing 14 changed files with 514 additions and 39 deletions.
13 changes: 13 additions & 0 deletions flexgen/apps/README.md
@@ -7,6 +7,19 @@ python completion.py --model facebook/opt-30b --percent 100 0 100 0 100 0 --comp
python completion.py --model facebook/opt-66b --percent 50 10 100 0 100 0 --compress-weight
```

### Data Wrangling

Run the data wrangling tasks from the [fm_data_tasks](https://github.com/HazyResearch/fm_data_tasks) repo by [HazyResearch](https://github.com/HazyResearch). See the [data wrangling README](./data_wrangle/README.md) for more details.
```
cd data_wrangle
bash install.sh
bash test_batch_query_all_opt6.7b.sh
bash test_batch_query_all_opt30b.sh
bash test_batch_query_all_opt175b.sh
```


### HELM benchmark
Run Massive Multitask Language Understanding (MMLU) scenario.
```
73 changes: 70 additions & 3 deletions flexgen/apps/data_wrangle/README.md
@@ -1,6 +1,6 @@
# FlexGen for Data Wrangling Tasks.

Here we show how to use FlexGen for the data wrangling tasks. The implementation follows the [fm_data_tasks](https://github.com/HazyResearch/fm_data_tasks) repo from [HazyResearch](https://github.com/HazyResearch).
Here we show how to use FlexGen for data wrangling tasks, including entity matching (EM), data imputation (DI), and error detection (ED). The implementation follows the [fm_data_tasks](https://github.com/HazyResearch/fm_data_tasks) repo from [HazyResearch](https://github.com/HazyResearch).

## Install

@@ -9,10 +9,77 @@ Here we show how to use FlexGen for the data wrangling tasks. The implementation

## Examples

- To check the outcome and verify the result of a data imputation task (e.g., Restaurant), run:
- To verify the result of a data imputation task (e.g., Restaurant on OPT-6.7B), run:

bash test_single_query_case.sh

- To test FlexGen Throughput of a data imputation task (e.g., Restaurant), run:
- To test the throughput of FlexGen for a data imputation task (e.g., Restaurant on OPT-6.7B), run:

bash test_batch_query_case.sh

- To run the complete tests of all tasks on OPT-6.7B:

bash test_batch_query_all_opt6.7b.sh

- To run the complete tests of all tasks on OPT-30B:

bash test_batch_query_all_opt30b.sh

- To run the complete tests of all tasks on OPT-175B:

bash test_batch_query_all_opt175b.sh



## Benchmark Results

- Note that in these data wrangling tasks, such as entity matching (EM), data imputation (DI), and error detection (ED), the input sequences are **very long** (123 to 1274 tokens) while the outputs are **very short** (e.g., 3, 5, or 10 tokens). Most of the inference time is therefore spent in the prefill phase, so we report a throughput that counts both input and output tokens; see the sketch after this list.

- We run the experiments in the same setting as the HELM benchmark: a single T4 (16 GB) GPU, 200 GB of DRAM, and a 1.5 TB SSD connected over NVMe.
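
A minimal sketch of how the reported metric is computed, mirroring the logic in `data_wrangle_run.py`; the variable names follow that script, and the numbers are illustrative placeholders rather than measurements from the tables below:

```
# Total throughput = (prompt tokens + generated tokens) / wall-clock time.
num_run = 100                # samples actually run
max_prompt_seq_length = 500  # padded prompt length in tokens
max_tokens = 5               # generated tokens per sample
total_time = 50.0            # wall-clock seconds for the whole batched run

total_prompt_tokens = num_run * max_prompt_seq_length
total_generate_tokens = num_run * max_tokens
total_throughput = (total_prompt_tokens + total_generate_tokens) / total_time
print(f"{total_throughput:.2f} token/s")  # -> 1010.00 token/s
```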

### OPT-6.7B

| Task | Tested Samples | Input Length (tokens) | Output Length (tokens) | Time (s) | Input + Output Throughput (token/s) |
|------------------------|----------------|------------------------|-------------------------|----------|--------------------------------------|
| EM: Fodors-Zagats | 189 | 744 | 3 | 109.556 | 1281.871 |
| EM: Beer | 91 | 592 | 3 | 42.087 | 1272.360 |
| EM: iTunes-Amazon | 109 | 529 | 3 | 59.467 | 966.178 |
| EM: Walmart-Amazon | 200 | 748 | 3 | 126.538 | 1186.992 |
| EM: Amazon-Google | 200 | 876 | 3 | 144.593 | 1215.828 |
| EM: DBLP-ACM | 200 | 1274 | 3 | 207.513 | 1230.767 |
| EM: DBLP-GoogleScholar | 200 | 1209 | 3 | 232.65 | 1097.78 |
| DI: Restaurant | 86 | 123 | 5 | 10.397 | 984.865 |
| DI: Buy | 65 | 488 | 10 | 43.077 | 739.876 |
| ED: Hospital | 200 | 200 | 3 | 30.137 | 1347.203 |


### OPT-30B

| Task | Tested Samples | Input Length (tokens) | Output Length (tokens) | Time (s) | Input + Output Throughput (token/s) |
|------------------------|----------------|------------------------|-------------------------|----------|--------------------------------------|
| EM: Fodors-Zagats | 189 | 744 | 3 | 541.550 | 248.287 |
| EM: Beer | 91 | 592 | 3 | 238.58 | 224.450 |
| EM: iTunes-Amazon | 109 | 529 | 3 | 267.639 | 198.775 |
| EM: Walmart-Amazon | 200 | 748 | 3 | 682.635 | 220.030 |
| EM: Amazon-Google | 200 | 876 | 3 | 799.514 | 219.884 |
| EM: DBLP-ACM | 200 | 1274 | 3 | 1119.272 | 228.184 |
| EM: DBLP-GoogleScholar | 200 | 1209 | 3 | 1271.534 | 190.636 |
| DI: Restaurant | 86 | 123 | 5 | 60.310 | 169.790 |
| DI: Buy | 65 | 488 | 10 | 185.882 | 160.747 |
| ED: Hospital | 200 | 200 | 3 | 158.329 | 256.429 |


### OPT-175B

| Task | Tested Samples | Input Length (tokens) | Output Length (tokens) | Time (s) | Input + Output Throughput (token/s) |
|------------------------|----------------|------------------------|-------------------------|----------|--------------------------------------|
| EM: Fodors-Zagats | 189 | 744 | 3 | 3928.310 | 34.228 |
| EM: Beer | 91 | 592 | 3 | 1356.786 | 35.083 |
| EM: iTunes-Amazon | 109 | 529 | 3 | 1569.062 | 33.906 |
| EM: Walmart-Amazon | 200 | 748 | 3 | 4171.319 | 36.008 |
| EM: Amazon-Google | 200 | 876 | 3 | 4893.572 | 35.925 |
| EM: DBLP-ACM | 200 | 1274 | 3 | 7624.726 | 33.496 |
| EM: DBLP-GoogleScholar | 200 | 1209 | 3 | 8275.828 | 29.290 |
| DI: Restaurant | 86 | 123 | 5 | 648.762 | 16.968 |
| DI: Buy | 65 | 488 | 10 | 2086.961 | 14.317 |
| ED: Hospital | 200 | 200 | 3 | 1154.133 | 35.178 |
94 changes: 65 additions & 29 deletions flexgen/apps/data_wrangle/data_wrangle_run.py
@@ -1,13 +1,17 @@
# The source code in this file is partially adapted from
# https://github.com/HazyResearch/fm_data_tasks/blob/main/fm_data_tasks/utils/prompt_utils.py
# which is under Apache License Version 2.0.

"""Run inference."""
import argparse
from tqdm import tqdm
import json
import math
import logging
from pathlib import Path
import time
import numpy as np
from transformers import AutoTokenizer, AutoConfig
# from manifest import Manifest

import flexgen.apps.data_wrangle.utils.data_utils as data_utils
import flexgen.apps.data_wrangle.utils.prompt_utils as prompt_utils
from flexgen.apps.data_wrangle.utils import constants
@@ -174,7 +178,10 @@ def parse_args() -> argparse.Namespace:


def get_tokenizer(name):
tokenizer = AutoTokenizer.from_pretrained(name, padding_side="left")
if name == 'facebook/opt-175b':
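# facebook/opt-175b is not hosted on the Hugging Face Hub; all OPT models
# share the same tokenizer, so fall back to the OPT-30B tokenizer here.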
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-30b', padding_side="left")
else:
tokenizer = AutoTokenizer.from_pretrained(name, padding_side="left")
tokenizer.add_bos_token = False
if 'galactica' in name:
config = AutoConfig.from_pretrained(name)
@@ -204,12 +211,11 @@ def single_query_test(args, task_instruction, test_data, task, pd_data_files, te
num_bits=4, group_size=64,
group_dim=2, symmetric=False))

print(f"Init weights begin.")
logger.info(f"Init weights begin.")
tic = time.time()
model = OptLM(args.model, env, args.path, policy)
print(f"Init weights end. Elapsed: {time.time() - tic:.2f} s", flush=True)
logger.info(f"Init weights end. Elapsed: {time.time() - tic:.2f} s")


if args.add_task_instruction:
prompt = lambda x: f"{task_instruction} {x}"
else:
@@ -246,14 +252,13 @@ def single_query_test(args, task_instruction, test_data, task, pd_data_files, te
gt = test_data["label_str"]
preds = []
idx = 0
# Run a few for printing -- they are cached
for _ in range(args.num_print):
logger.info(prompt(queries[idx]))
tic = time.time()
input_ids_tmp = tokenizer(prompt(queries[idx]), padding="max_length",
return_tensors="np",
max_length=args.pad_to_seq_len).input_ids
print(input_ids_tmp.shape)
logger.info(input_ids_tmp.shape)
output_ids_tmp = model.generate(input_ids_tmp,
do_sample=True,
temperature=args.temperature,
@@ -292,6 +297,7 @@ def single_query_test(args, task_instruction, test_data, task, pd_data_files, te
f"_{int(args.add_task_instruction)}inst"
f"_{int(args.class_balanced)}cb"
f"_{args.sample_method}"
f"_{args.model}"
f"_{args.num_print}run"
f"_{int(args.dry_run)}dry" / f"trial_{trial_num}.feather"
)
@@ -336,16 +342,17 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
num_bits=4, group_size=64,
group_dim=2, symmetric=False))

print(f"Init weights begin.")
logger.info(f"Init weights begin.")
tic = time.time()
model = OptLM(args.model, env, args.path, policy)
print(f"Init weights end. Elapsed: {time.time() - tic:.2f} s", flush=True)
logger.info(f"Init weights end. Elapsed: {time.time() - tic:.2f} s.")

if args.add_task_instruction:
prompt = lambda x: f"{task_instruction} {x}"
else:
prompt = lambda x: f"{x}"
trial_metrics = {"prec": [], "rec": [], "f1": [], "acc": [], "throughput": []}
trial_metrics = {"prec": [], "rec": [], "f1": [], "acc": [], "total_time": [],
"output_throughput": [], "total_throughput": []}

saved_prefix = None

@@ -377,31 +384,54 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
preds = []
idx = 0

# Run a few for printing -- they are cached
max_prompt_seq_length = 0
prompt_strs = []
for _ in range(args.num_run):
if idx == 0:
logger.info(f"This is a sample prompt: {prompt(queries[idx])}")
# if idx == 0:
# logger.info(f"This is a sample prompt: {prompt(queries[idx])}")
prompt_strs.append(prompt(queries[idx]))
idx += 1

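# With padding="max_length", every prompt is padded to args.pad_to_seq_len, so
# max_prompt_seq_length only exceeds the configured pad length if some prompt is longer than it.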
current_prompt_tmp = tokenizer(prompt(queries[idx]), padding="max_length",
return_tensors="np", max_length=args.pad_to_seq_len).input_ids
# logger.info(f"Current prompt <{idx}> length: {current_prompt_tmp.shape[1]}")
max_prompt_seq_length = max(max_prompt_seq_length, current_prompt_tmp.shape[1])
idx += 1

logger.info(f"max_prompt_seq_length: {max_prompt_seq_length}")
tic = time.time()

input_ids_tmp = tokenizer(prompt_strs, padding="max_length",
input_ids = tokenizer(prompt_strs, padding="max_length",
return_tensors="np",
max_length=args.pad_to_seq_len).input_ids
output_ids_tmp = model.generate(input_ids_tmp,
do_sample=True,
temperature=args.temperature,
max_new_tokens=args.max_tokens,
stop=args.stop_token)
max_length=max_prompt_seq_length).input_ids
output_ids = []

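# FlexGen consumes fixed batches of gpu_batch_size * num_gpu_batches prompts, so
# round num_run down to a full multiple and drop any remainder samples.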
flexgen_batch_size = args.gpu_batch_size*args.num_gpu_batches
num_batched_run = math.floor(args.num_run/flexgen_batch_size)
args.num_run = num_batched_run * flexgen_batch_size
input_ids = input_ids[0:args.num_run]

for i in tqdm(range(num_batched_run)):
input_ids_tmp = input_ids[i*flexgen_batch_size: (i+1)*flexgen_batch_size]
output_ids_tmp = model.generate(input_ids_tmp,
do_sample=True,
temperature=args.temperature,
max_new_tokens=args.max_tokens,
stop=args.stop_token)
output_ids.extend(output_ids_tmp)

toc = time.time()
input_strs = tokenizer.batch_decode(input_ids_tmp, skip_special_tokens=True)
output_strs = tokenizer.batch_decode(output_ids_tmp, skip_special_tokens=True)
preds = [ output_strs[i][len(input_strs[i]):] for i in range(len(input_strs))]
input_strs = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
output_strs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
preds = [output_strs[i][len(input_strs[i]):] for i in range(len(input_strs))]

throughput = args.num_run * args.max_tokens/(time.time() - tic)
print(f"Batch inference run end. Elapsed: { toc - tic:.2f} s, Throughput: {throughput:.2f} token/s")
total_time = time.time() - tic
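# Prompt tokens are counted at the padded batch length, so padding tokens
# contribute to the total-throughput figure below.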
total_prompt_tokens = args.num_run * max_prompt_seq_length
total_generate_tokens = args.num_run * args.max_tokens
output_throughput = total_generate_tokens/total_time
total_throughput = (total_prompt_tokens+total_generate_tokens)/total_time
logger.info(f"Batch inference run end. Elapsed: {total_time:.2f} s;")
logger.info(f"Output throughput: {output_throughput:.2f} token/s;")
logger.info(f"Total throughput: {total_throughput:.2f} token/s;")
# Save trial predictions
save_data = test_data.iloc[:args.num_run].copy(deep=True).reset_index()
gt = gt[:args.num_run]
@@ -412,13 +442,18 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes

logger.info(
f"Metrics Trial {trial_num}\n"
f"Prec: {prec:.3f} Recall: {rec:.3f} Acc: {acc:.3f} F1: {f1:.3f} FlexGen Throughput: {throughput:.3f}"
f"Prec: {prec:.3f} Recall: {rec:.3f} Acc: {acc:.3f} F1: {f1:.3f} \n"
f"<FlexGen> time: {total_time:.3f} \n"
f"<FlexGen> output throughput: {output_throughput:.3f} \n"
f"<FlexGen> total throughput: {total_throughput:.3f}"
)
trial_metrics["rec"].append(rec)
trial_metrics["prec"].append(prec)
trial_metrics["acc"].append(acc)
trial_metrics["f1"].append(f1)
trial_metrics["throughput"].append(throughput)
trial_metrics["total_time"].append(total_time)
trial_metrics["output_throughput"].append(output_throughput)
trial_metrics["total_throughput"].append(total_throughput)

output_file = (
Path(args.output_dir)
@@ -429,6 +464,7 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
f"_{int(args.add_task_instruction)}inst"
f"_{int(args.class_balanced)}cb"
f"_{args.sample_method}"
f"_{args.model}"
f"_{args.num_run}run"
f"_{int(args.dry_run)}dry" / f"trial_{trial_num}.feather"
)
2 changes: 1 addition & 1 deletion flexgen/apps/data_wrangle/install.sh
@@ -1,5 +1,5 @@
pip install pandas==1.4.2
pip install sentence_transformers==2.2.0
pip install sentence-transformers==2.2.2
pip install rich==12.2.0
pip install pyarrow==7.0.0

88 changes: 88 additions & 0 deletions flexgen/apps/data_wrangle/test_batch_query_all_opt175b.sh
@@ -0,0 +1,88 @@
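# Note on the flags used below (see the FlexGen README for the authoritative description):
#   --percent takes six numbers: the percent of weights, KV cache, and activations
#   kept on GPU/CPU, as three GPU/CPU pairs; the remainder of each pair spills to disk.
#   "0 50 0 0 0 100" thus keeps half the weights on CPU (half on disk), the whole
#   cache on disk, and all activations on CPU.
# data_wrangle_run.py rounds --num_run down to a multiple of
# --gpu-batch-size * --num-gpu-batches, so trailing samples beyond that multiple are skipped.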
python3 ./data_wrangle_run.py \
--num_run 189 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/Fodors-Zagats \
--batch_run --pad-to-seq-len 744 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 15 --num-gpu-batches 6

python3 ./data_wrangle_run.py \
--num_run 91 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/Beer \
--batch_run --pad-to-seq-len 592 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 20 --num-gpu-batches 4

python3 ./data_wrangle_run.py \
--num_run 109 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/iTunes-Amazon \
--batch_run --pad-to-seq-len 529 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 20 --num-gpu-batches 5

python3 ./data_wrangle_run.py \
--num_run 200 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/Walmart-Amazon \
--batch_run --pad-to-seq-len 748 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 10 --num-gpu-batches 10

python3 ./data_wrangle_run.py \
--num_run 200 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/Amazon-Google \
--batch_run --pad-to-seq-len 876 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 10 --num-gpu-batches 10

python3 ./data_wrangle_run.py \
--num_run 200 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/DBLP-ACM \
--batch_run --pad-to-seq-len 1274 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 8 --num-gpu-batches 5

python3 ./data_wrangle_run.py \
--num_run 200 \
--num_trials 1 \
--nan_tok "" \
--do_test \
--sample_method manual \
--data_dir data/datasets/entity_matching/structured/DBLP-GoogleScholar \
--batch_run --pad-to-seq-len 1209 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 8 --num-gpu-batches 5

python3 ./data_wrangle_run.py \
--num_run 86 \
--num_trials 1 \
--max_tokens 5 \
--do_test \
--sample_method manual \
--data_dir data/datasets/data_imputation/Restaurant \
--batch_run --pad-to-seq-len 123 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 86 --num-gpu-batches 1

python3 ./data_wrangle_run.py \
--num_run 65 \
--num_trials 1 \
--max_tokens 10 \
--do_test \
--sample_method manual \
--data_dir data/datasets/data_imputation/Buy \
--batch_run --pad-to-seq-len 488 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 30 --num-gpu-batches 2

python3 ./data_wrangle_run.py \
--num_run 200 \
--num_trials 1 \
--do_test \
--sample_method manual \
--data_dir data/datasets/error_detection/Hospital \
--batch_run --pad-to-seq-len 200 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 50 --num-gpu-batches 4
