From 9045f08ac1af38c13100cfabdd17412f57dbffd2 Mon Sep 17 00:00:00 2001 From: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Date: Tue, 28 May 2024 11:40:19 -0400 Subject: [PATCH] Add revised benchmarking logic and results (#9) * Revised estimation of batch count, directly retrieving from len(train_dataloader). Deleted unused timer_handle argument in Trainer. Revised handling of "max_seq_len" override in benchmarking. Added support for automatic switching between lora and full-rank sharding scheme in benchmarking. * Revised handling of unspecified max_seq_length. Added llama-3 to benchmark model_list. * Benchmarking: Revised benchmark script to ensure consistent per-device train batch size. * Benchmarking: replaced trainer.step with trainer.train_step to avoid eval overhead in benchmarking. Revised benchmark parsing logic; display optimal batch size for each context width value. * Benchmarking: Updated reference throughput based on updated logic. * Benchmarking: Updated reference throughput descriptions. --- docs/reference_throughput.md | 55 ++++++++++++------------ profiling/README.md | 2 +- profiling/benchmark.py | 60 ++++++++++++++++----------- profiling/configs/benchmark.yaml | 1 - profiling/configs/lora-benchmark.yaml | 1 - profiling/launch_benchmark.py | 24 ++++++----- profiling/launch_benchmark.sh | 2 +- profiling/parse_benchmark.py | 57 +++++++++++++++++++------ vectorlm/utils/model_utils.py | 17 ++++++-- 9 files changed, 138 insertions(+), 81 deletions(-) diff --git a/docs/reference_throughput.md b/docs/reference_throughput.md index 40b5a4f..5eb692b 100644 --- a/docs/reference_throughput.md +++ b/docs/reference_throughput.md @@ -1,33 +1,36 @@ # Reference Throughput We've benchmarked VectorLM on the Vaughan cluster for a number of model architectures across a variety of node configurations. -In experiments labelled as LoRA, we set hidden dimension to 8. During the testing, the NVIDIA driver version was 525.105.17, CUDA Runtime 12.1.105, and torch 2.2.2. +In experiments labelled as LoRA, we set hidden dimension to 8. Below are version numbers of the testing environment: -For consistency, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. Note that especially for smaller models, it might be possible to further increase throughput by switching to a larger batch size. +```bash +$ pip3 freeze|grep -E "(torch|flash-attn|nvidia)" +flash-attn==2.5.8 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-ml-py==12.550.52 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu12==12.1.105 +torch==2.2.1 +``` -Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning. +For each context width and hardware configuration, we experiment with a per-device batch size of 2, 4, and 8. In the table below, we report the batch size that maximizes training throughput. All values in the table represent the median training throughput in tokens/second across all training steps, aggregated across all GPU devices. 
-All values in the table below represent the median training throughput in tokens per second across all training steps, aggregated across all GPU devices.
+| | Meta-Llama-3-8B (2048) | Meta-Llama-3-8B (4096) | Meta-Llama-3-8B (8192) |
+| :----------------------------------- | :--------------------- | :--------------------- | :--------------------- |
+| (full_rank) NVIDIA A100-SXM4-80GB x1 | 3550.48 (batch: 8) | 3461.64 (batch: 4) | 3204.21 (batch: 2) |
+| (full_rank) NVIDIA A100-SXM4-80GB x2 | 6346.00 (batch: 8) | 6182.59 (batch: 4) | 5772.91 (batch: 2) |
+| (full_rank) NVIDIA A100-SXM4-80GB x4 | 12688.44 (batch: 8) | 12249.74 (batch: 4) | 11463.46 (batch: 2) |
+| (lora) NVIDIA A100-SXM4-80GB x1 | 4079.28 (batch: 8) | 3682.15 (batch: 4) | 3528.93 (batch: 2) |
+| (lora) NVIDIA A100-SXM4-80GB x2 | 7182.97 (batch: 8) | 6955.58 (batch: 4) | 6452.96 (batch: 2) |
+| (lora) NVIDIA A100-SXM4-80GB x4 | 14299.47 (batch: 8) | 13834.43 (batch: 4) | 12769.23 (batch: 2) |
-| | Llama-2-13b-hf | Llama-2-7b-hf | Mistral-7B-v0.1 | Mixtral-8x7B-Instruct-v0.1 | gemma-2b | opt-350m |
-| :----------------------------------- | -------------: | ------------: | --------------: | -------------------------: | -------: | -------: |
-| (full_rank) NVIDIA A100-SXM4-80GB x1 | 424.726 | 570.818 | 528.747 | nan | nan | 780.045 |
-| (full_rank) NVIDIA A100-SXM4-80GB x2 | 660.355 | 919.19 | 794.566 | 275.459 | nan | 1227.67 |
-| (full_rank) NVIDIA A100-SXM4-80GB x4 | 1309.4 | 1744.39 | 1577.09 | 817.162 | nan | 2181.46 |
-| (full_rank) NVIDIA A40 x1 | nan | 47.6435 | 107.503 | nan | nan | 666.881 |
-| (full_rank) NVIDIA A40 x2 | nan | 313.074 | 322.624 | nan | nan | 854.672 |
-| (full_rank) NVIDIA A40 x4 | 345.96 | 570.977 | 553.658 | nan | nan | 1765.49 |
-| (full_rank) Tesla T4 x1 | nan | nan | nan | nan | nan | 475.51 |
-| (full_rank) Tesla T4 x2 | nan | nan | nan | nan | nan | 768.008 |
-| (full_rank) Tesla T4 x4 | nan | nan | nan | nan | nan | 1383.6 |
-| (full_rank) Tesla T4 x8 | nan | nan | nan | nan | nan | 2414.68 |
-| (lora) NVIDIA A100-SXM4-80GB x1 | 560.167 | 646.801 | 525.802 | nan | 851.678 | 859.379 |
-| (lora) NVIDIA A100-SXM4-80GB x2 | 871.993 | 1157.17 | 1105.68 | 239.431 | 1724.57 | 1463.82 |
-| (lora) NVIDIA A100-SXM4-80GB x4 | 1783.53 | 2091.03 | 2150.06 | 1309.74 | 2719.24 | 2381.01 |
-| (lora) NVIDIA A40 x1 | 272.931 | 435.386 | 336.507 | nan | 983.256 | 652.611 |
-| (lora) NVIDIA A40 x2 | 105.442 | 457.183 | 356.263 | nan | 725.723 | 1136.17 |
-| (lora) NVIDIA A40 x4 | 543.22 | 715.416 | 642.642 | nan | 1302.62 | 1647.57 |
-| (lora) Tesla T4 x1 | nan | nan | nan | nan | 148.272 | 571.471 |
-| (lora) Tesla T4 x2 | nan | 101.126 | 102.859 | nan | 256.534 | 811.159 |
-| (lora) Tesla T4 x4 | nan | 188.575 | 190.127 | nan | 495.755 | 1506.05 |
-| (lora) Tesla T4 x8 | 196.709 | 372.375 | 351.361 | nan | 897.81 | 2945.86 |
+We provide tools for evaluating throughput across different context windows and hardware/model configurations. Refer to the profiling folder in this repository to get started.
\ No newline at end of file
diff --git a/profiling/README.md b/profiling/README.md
index 47d87b7..7a75f38 100644
--- a/profiling/README.md
+++ b/profiling/README.md
@@ -13,7 +13,7 @@ $ python3 launch_benchmark.py
 # to accept and automatically invoke the commands.
 ```
-After the SLURM jobs complete, profiler output can be found under `data/benchmark`. Invoke the following the to generate a Markdown summary of the results:
+After the SLURM jobs complete, profiler output can be found under `data/benchmark`. Invoke the following to generate a Markdown summary of the results. If the benchmark results include multiple batch sizes for each (model, context window, hardware) pair, the table lists the "optimal" batch size, i.e., the one that achieved the highest training throughput for that combination.
 ```bash
 $ python3 profiling/parse_benchmark.py --folder data/benchmark
diff --git a/profiling/benchmark.py b/profiling/benchmark.py
index 58906af..e68cf57 100644
--- a/profiling/benchmark.py
+++ b/profiling/benchmark.py
@@ -25,7 +25,6 @@
 from vectorlm.utils.model_utils import (
     get_lora_model_from_base_model,
     get_submodule_by_pattern,
-    hook_activation_checkpointing,
     load_model_and_tokenizer,
     shard_model,
 )
@@ -67,7 +66,7 @@ def parse_args() -> Namespace:
         default=1000,
     )
     parser.add_argument("--max_length", type=int)
-    parser.add_argument("--training_batch_size", type=int)
+    parser.add_argument("--per_device_batch_size", type=int)
     return parser.parse_args()
@@ -273,9 +272,26 @@ def load_datasets(self) -> None:
     setup(config.train_parameters.output_dir)
-    if args.training_batch_size is not None:
-        config.dataset.train_bs = args.training_batch_size
-        write_metrics("training_batch_size", args.training_batch_size)
+    training_args = config.train_parameters
+
+    # set a seed
+    set_seed(training_args.seed)
+
+    # set CUDA related dependencies
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+
+    if args.per_device_batch_size is not None:
+        config.dataset.train_bs = args.per_device_batch_size
+        config.dataset.eval_bs = args.per_device_batch_size
+
+    write_metrics("training_batch_size", config.dataset.train_bs)
+    write_metrics("eval_batch_size", config.dataset.eval_bs)
+    write_metrics(
+        "training_batch_size_global",
+        config.dataset.train_bs * world_size,
+    )
     print(f"Writing metrics to {output_path}")
     write_metrics("model_name", args.model_name)
@@ -291,16 +307,6 @@ def load_datasets(self) -> None:
         repeat=2,
     )
-    training_args = config.train_parameters
-
-    # set a seed
-    set_seed(training_args.seed)
-
-    # set CUDA related dependencies
-    local_rank = int(os.environ["LOCAL_RANK"])
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-
     with track_time("dist_init"):
         print(f"Rank: {rank}, World size: {world_size}")
         if dist.is_initialized():
@@ -314,17 +320,18 @@ def load_datasets(self) -> None:
     # load model and tokenizer
     lora_peft_config = config.train_parameters.get("lora_peft_config")
+    is_lora_enabled = lora_peft_config is not None
     with track_time("model_load"):
         model, tokenizer = load_model_and_tokenizer(
             args.model_name,
             training_args.use_mp,
             get_is_flash_attention_supported(),
-            training_args.max_seq_len,
+            args.max_length,
             local_rank,
             training_args.low_cpu_mem_usage,
         )
-    if lora_peft_config is not None:
+    if is_lora_enabled:
         print("Enabling LoRA Wrapper.")
         write_metrics("peft_method", "lora")
         model = get_lora_model_from_base_model(model, lora_peft_config)
@@ -348,12 +355,9 @@ def load_datasets(self) -> None:
             training_args.sharding_strategy,
             local_rank,
             training_args.low_cpu_mem_usage,
+            is_lora_enabled=is_lora_enabled,
         )
-    with track_time("set_activation_checkpointing"):
-        if training_args.use_activation_checkpointing:
-            hook_activation_checkpointing(model, decoder_layer_module)
-
     # load dataset
     with track_time("dataset_load"):
         dataset = 
BenchmarkingDataset( @@ -364,6 +368,10 @@ def load_datasets(self) -> None: max_length=args.max_length, ) + print( + f"Sequence length: {dataset.max_length};" + f"Batch Size (per device): {config.dataset.train_bs}", + ) write_metrics("max_length", dataset.max_length) # instantiate trainer @@ -371,7 +379,6 @@ def load_datasets(self) -> None: config=training_args, enable_wandb_logging=config.enable_wandb_logging, original_dataset_length=dataset.original_length, - timer_handle=track_time, ) # load optimizer @@ -412,15 +419,18 @@ def load_datasets(self) -> None: trainer.model.train() train_dl_iterator = iter(dataset.train_dataloader) for _ in tqdm( - range(args.num_train_examples), + range(len(dataset.train_dataloader)), disable=rank != 0, file=sys.__stdout__, ): batch = next(train_dl_iterator) num_tokens = len(batch["input_ids"].flatten()) - with track_time("train_step", {"num_tokens": num_tokens}): - trainer.step(batch, epoch) + with track_time( + "train_step", + {"num_tokens": num_tokens * world_size}, + ): + trainer.train_step(batch, epoch) profile_handle.step() write_metrics( diff --git a/profiling/configs/benchmark.yaml b/profiling/configs/benchmark.yaml index 76c3b8d..30dee32 100644 --- a/profiling/configs/benchmark.yaml +++ b/profiling/configs/benchmark.yaml @@ -6,7 +6,6 @@ wandb_config: train_parameters: output_dir: /dev/shm/lora-benchmark - max_seq_len: 128 epochs: 1 seed: 11 diff --git a/profiling/configs/lora-benchmark.yaml b/profiling/configs/lora-benchmark.yaml index 4105404..b30ad45 100644 --- a/profiling/configs/lora-benchmark.yaml +++ b/profiling/configs/lora-benchmark.yaml @@ -6,7 +6,6 @@ wandb_config: train_parameters: output_dir: /dev/shm/lora-benchmark - max_seq_len: 128 epochs: 1 seed: 11 diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index e9509ed..8833d1a 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -22,12 +22,13 @@ model_list = [ "/model-weights/" + model_name for model_name in [ - "opt-350m", - "gemma-2b", - "Llama-2-7b-hf", - "Llama-2-13b-hf", - "Mistral-7B-v0.1", - "Mixtral-8x7B-Instruct-v0.1", + # "opt-350m", + # "gemma-2b", + # "Llama-2-7b-hf", + "Meta-Llama-3-8B", + # "Llama-2-13b-hf", + # "Mistral-7B-v0.1", + # "Mixtral-8x7B-Instruct-v0.1", ] ] @@ -37,19 +38,20 @@ ] # Set to (-1) to fall back to the max context length of the pre-trained model. 
-max_length_list = [1024, 2048, 4096, -1] -batch_size = [8, 16, 32, 64, 128] +max_length_list = [8192, 4096, 2048] +# Per-device batch size for training +per_device_batch_size = [2, 4, 8] slurm_flags_options = { "nodes": [1], "mem-per-gpu": ["16GB"], "ntasks-per-node": [1], "cpus-per-gpu": [3], - "gres": [f"gpu:{n}" for n in [1, 2, 4, 8]], + "gres": [f"gpu:{n}" for n in [4, 2, 1]], "partition": partitions, } -num_repeats = 2 +num_repeats = 1 slurm_flags_extra = {"time": "01:00:00", "qos": qos_selected} slurm_pos_args_options = [ @@ -57,7 +59,7 @@ config_list, model_list, max_length_list, - batch_size, + per_device_batch_size, ] timestamp = int(time.time()) diff --git a/profiling/launch_benchmark.sh b/profiling/launch_benchmark.sh index 66a9e8e..8f00445 100644 --- a/profiling/launch_benchmark.sh +++ b/profiling/launch_benchmark.sh @@ -28,7 +28,7 @@ profiling/benchmark.py \ --yaml_path $1 \ --model_name $2 \ --max_length $3 \ ---training_batch_size $4 +--per_device_batch_size $4 # clean up benchmarking artifacts as ops have requested rm -rf /dev/shm/lora-benchmark diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index 887140d..0efd916 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -15,8 +15,13 @@ Numbers = Union[int, float] NumericalTypes = Union[Numbers, np.ndarray] -Numerical = TypeVar("Numerical", bound=NumericalTypes) +V = TypeVar("V") Aggregator = TypeVar("Aggregator") +Numerical = TypeVar("Numerical", bound=NumericalTypes) + + +# Skip first N train steps (warmup, profiling, etc.) in throughput eval. +NUM_SKIPPED_STEPS = 80 @dataclass @@ -97,7 +102,7 @@ def _reduce_metric( def get_quantiles(values: list[Numbers]) -> np.ndarray: - """Given a list of numerical values, return (min, 25%, 50%, 75%, and max). + """Given a list of numerical values, return (min, 25%, 50%, 75%, 95%, max). Params ------ @@ -108,9 +113,14 @@ def get_quantiles(values: list[Numbers]) -> np.ndarray: np.ndarray. """ + percentiles = [0.25, 0.5, 0.75, 0.95] + + if len(values) == 0: + return [np.nan] * (1 + len(percentiles) + 1) + output_list = [ np.min(values), - *[np.percentile(values, q) for q in [0.25, 0.5, 0.75]], + *[np.percentile(values, q) for q in percentiles], np.max(values), ] @@ -137,8 +147,11 @@ def get_quantiles(values: list[Numbers]) -> np.ndarray: # Set of tuples the form (model_name, device) benchmarked_combinations: set[tuple[str, str]] = set() -aggregated_output: dict[tuple[str, str], RunningAverage] = defaultdict( - lambda: RunningAverage(), +# Map (model, device) pair to dict mapping (batch_size, seq_len) to aggregator. 
+aggregated_output: dict[tuple[str, str], dict[str, RunningAverage]] = ( + defaultdict( + lambda: defaultdict(lambda: RunningAverage()), + ) ) profiler_tables = defaultdict(dict) @@ -161,15 +174,15 @@ def get_quantiles(values: list[Numbers]) -> np.ndarray: benchmark_output[name] = new_value model_name = benchmark_output.get("model_name") + context_window = benchmark_output.get("max_length") if model_name is None: continue model_name = model_name.split("/")[-1] + model_name = f"{model_name} ({context_window})" source_filename = benchmark_output["_source"] peft_method = benchmark_output.get("peft_method") - if peft_method == "lora" and model_name == "gemma-2b": - print(source_filename) if peft_method is None: continue @@ -181,7 +194,7 @@ def get_quantiles(values: list[Numbers]) -> np.ndarray: world_size = device_info["world_size"] device_description = f"({peft_method}) {device_name} x{world_size}" - # Training throughput can be noisy. Report quantiles instead of avg, + # Training throughput can be noisy. Report median throughput, # and discard instances with only one training step logged. train_step = benchmark_output.get("train_step") if train_step is not None: @@ -189,9 +202,11 @@ def get_quantiles(values: list[Numbers]) -> np.ndarray: time_elapsed = np.asarray(train_step["time_elapsed"]) if num_tokens.flatten().shape[0] > 1: train_throughput = get_quantiles( - world_size * num_tokens / time_elapsed, + (num_tokens / time_elapsed)[NUM_SKIPPED_STEPS:], ) - aggregated_output[(model_name, device_description)].add( + aggregated_output[(model_name, device_description)][ + "batch: " + str(benchmark_output.get("training_batch_size")) + ].add( train_throughput[2], ) @@ -205,8 +220,26 @@ def get_quantiles(values: list[Numbers]) -> np.ndarray: aggregated_output_nested = defaultdict(dict) for combination in benchmarked_combinations: model_name, device_description = combination - throughput = aggregated_output[combination].get_average() - aggregated_output_nested[model_name][device_description] = throughput + # there might be more than one run for each batch size option + # average median throughput over all runs for each option. + # report batch size that achieves optimal (avg) throughput. + throughput: list[tuple[NumericalTypes, str]] = [ + (average, batch_size) + for (average, batch_size) in ( + (aggregation.get_average(), batch_size) + for batch_size, aggregation in aggregated_output[ + combination + ].items() + ) + if average is not None + ] + if len(throughput) == 0: + continue + + optimal_throughput, optimal_batch_size = sorted(throughput, reverse=True)[0] + aggregated_output_nested[model_name][device_description] = ( + f"{optimal_throughput:.2f} ({optimal_batch_size})" + ) throughput_table = ( diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 5013769..e9ad8f0 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -76,7 +76,7 @@ def load_model_and_tokenizer( path: str, use_mp: bool, use_fa: bool, - max_seq_len: int, + max_seq_len: int | None, local_rank: int, low_cpu_mem_usage: bool, use_safetensors: bool = True, @@ -88,7 +88,9 @@ def load_model_and_tokenizer( path: The path where the model and tokenizer are stored. use_mp: Whether to use mixed-precision. use_fa: Whether to use Flash Attention 2. - max_seq_len: The maximum sequence length. + max_seq_len: Override the maximum sequence length of the tokenizer. + Set to None or a negative value to fall back to the + `max_position_embeddings` value from the pretrained model config. 
        local_rank: The local rank of the current worker.
        low_cpu_mem_usage: Whether to only load model weights on main rank, and
            then scatter them to the other workers.
@@ -129,7 +131,16 @@
     tokenizer = AutoTokenizer.from_pretrained(path)
     if not tokenizer.pad_token:
         tokenizer.pad_token_id = tokenizer.eos_token_id
-    tokenizer.model_max_length = max_seq_len
+    if (max_seq_len is not None) and (max_seq_len > 0):
+        tokenizer.model_max_length = max_seq_len
+    else:
+        if not hasattr(model.config, "max_position_embeddings"):
+            msg = (
+                "A concrete max_seq_len value is required in your training yaml "
+                "as max_position_embeddings is not specified in model config."
+            )
+            raise ValueError(msg)
+        tokenizer.model_max_length = model.config.max_position_embeddings
     # extend embeddings to a multiple so we use Tensor cores
     multiple = 64 if "A100" in torch.cuda.get_device_name() else 8
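
For reference, the following is a minimal, self-contained sketch of the `max_seq_len` fallback behaviour that `load_model_and_tokenizer` implements after this change. The `ModelConfigStub` class and `resolve_max_seq_len` helper are hypothetical names used only for illustration; they are not part of the repository.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ModelConfigStub:
    """Stand-in for a pretrained model config (illustrative only)."""

    max_position_embeddings: int | None = None


def resolve_max_seq_len(max_seq_len: int | None, config: ModelConfigStub) -> int:
    """Mirror the fallback applied to tokenizer.model_max_length above.

    A positive override wins; None or a negative value falls back to the
    pretrained max_position_embeddings, and a missing value raises.
    """
    if (max_seq_len is not None) and (max_seq_len > 0):
        return max_seq_len
    if config.max_position_embeddings is None:
        msg = (
            "A concrete max_seq_len value is required in your training yaml "
            "as max_position_embeddings is not specified in model config."
        )
        raise ValueError(msg)
    return config.max_position_embeddings


if __name__ == "__main__":
    # Explicit override, e.g. benchmark.py invoked with --max_length 4096.
    assert resolve_max_seq_len(4096, ModelConfigStub(8192)) == 4096
    # -1 (or None) defers to the pretrained context length.
    assert resolve_max_seq_len(-1, ModelConfigStub(8192)) == 8192
    assert resolve_max_seq_len(None, ModelConfigStub(8192)) == 8192
```

An explicit positive override (for example, `--max_length 4096` in `benchmark.py`) takes precedence; `-1` or an unset value defers to the pretrained `max_position_embeddings`, and a model config without that attribute raises an error rather than silently picking a default.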