From 1c78f7d517a8d567115d9024024f690e6c9c18c6 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Mon, 4 Nov 2024 04:58:45 +0000 Subject: [PATCH] Support 2D break_point configurations Support Gelu-Bias, Silu-Mul, Bcast-Add, Mul Fusions --- .../python_scheduling/autotune_pointwise.py | 599 ++++++++++-------- 1 file changed, 330 insertions(+), 269 deletions(-) diff --git a/doc/dev/python_scheduling/autotune_pointwise.py b/doc/dev/python_scheduling/autotune_pointwise.py index 014ae8197a2..cbda494b3cb 100644 --- a/doc/dev/python_scheduling/autotune_pointwise.py +++ b/doc/dev/python_scheduling/autotune_pointwise.py @@ -1,15 +1,42 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Owner(s): ["module: nvfuser"] import torch import itertools -import random -from nvfuser import FusionCache, FusionDefinition, SchedulerType +import math +from nvfuser import FusionDefinition, SchedulerType, DataType +from dataclasses import dataclass +from enum import Enum + +from autotune_utils import ( + ScriptConfiguration, + collect_data, + separate_data, + test_model_rmse, + test_model, +) + # ============================ Description ============================ +# This script defines four pointwise fusions: +# +# 1. GELU with Outer-Broadcast Bias Addition +# y = gelu(x + bias[broadcast, i], approximate='tanh') +# +# 2. SILU with Pointwise Multiplication +# z = silu(x) * y +# +# 3. Inner-Broadcast Addition +# z = x + y[i, broadcast] +# +# 4. Pointwise Multiplication +# z = x * y +# +# Script Sequence: +# # 1. Define a nvfuser fusion and its pytorch eager mode reference. # # 2. Profile the CUDA kernel performance by iterating over a set of input @@ -25,287 +52,321 @@ # - For a specific batch size, gather performance across a range of hidden # sizes.
Calculate performance for best predicted and nvfuser # configurations. Plot a chart comparing performance using matplotlib. - +# # The selected performance metric is effective_bandwidth_gbs. The empirical # scheduler selects the configuration that has the highest predicted # effective_bandwidth_gbs. # ============================ Configurations ============================ -# Settings for input tensor generation -num_dimensions = 2 -outer_shapes = [512] -inner_shapes = [2**i for i in range(5, 15)] - -# For pointwise scheduler, we test the cartesian product of vectorization and -# unroll factors. -parameter_configurations = [ - vectorize_range := [1, 2, 4], - unroll_range := list(range(1, 10)), -] - -# We profile a range of input shapes with various configurations. -# This argument determines how much of the profiled data to keep as a test set. -test_data_percentage = 0.1 - -# The selected batch size for empirical and nvfuser comparison. -empirical_batch_size = 512 - -# The range of hidden sizes for empirical and nvfuser comparision. -empirical_hidden_sizes = list(range(256, 28672, 256)) - - -# A decorator to create a pointwise fusion given some input arguments. -def create_fusion_func(inputs): - def fusion_func(fd: FusionDefinition): - t0 = fd.from_pytorch(inputs[0]) - t1 = fd.from_pytorch(inputs[1]) - c0 = fd.define_scalar(3.0) - t2 = fd.ops.add(t0, t1) - t3 = fd.ops.mul(t2, c0) - fd.add_output(t3) - - return fusion_func - - -# The pytorch eager mode reference used to validating nvfuser kernel. 
-def eager_reference(inputs): - return (inputs[0] + inputs[1]) * 3 - - -# ============================ Function Definitions ============================ - - -# Apply scheduler with custom parameters using decorator -def custom_pointwise_scheduler(fd, config): - def inner_fn(): - # Check if compatible with pointwise scheduler - status, _ = fd.sched.can_schedule(SchedulerType.pointwise) - assert status - - schedule_params = fd.sched.compute_pointwise_heuristics() - - # Modify original parameters - if config is not None: - vectorization_factor, unroll_factor = config - schedule_params.vectorization_factor = vectorization_factor - schedule_params.unroll_factor_inner = unroll_factor - - # Schedule fusion - fd.sched.schedule() - - fd.schedule = inner_fn - return fd - -# Apply schedule decorator, run fusion, and profile performance -def run_profile(presched_fd, inputs, config=None): - scheduled_fd = custom_pointwise_scheduler(presched_fd, config) - nvf_outputs = scheduled_fd.execute(inputs, profile=True) - - # validate correctness - assert torch.allclose(nvf_outputs[0], eager_reference(inputs)) - - prof = scheduled_fd.profile() - bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs - time = prof.kernel_profiles[0].time_ms - return bandwidth, time - - -def argmax(map_config_to_perf): - best_perf = -1 - best_config = None - for config, perf in map_config_to_perf.items(): - if perf > best_perf: - best_perf = perf - best_config = config - return best_config - - -# Given a prediction model, input_shape, and set of parameter configurations, -# find the best parameters -def find_best_parameters(predictor, input_shape, parameter_configurations): - map_config_to_performance = { - config: predictor.predict([[*input_shape, *config]]) - for config in itertools.product(*parameter_configurations) - } - return argmax(map_config_to_performance) - - -# ============================ Run Experiments ================================ - -# Collect data for decision tree -parameters = [] 
-performance = [] - -for shape in itertools.product(outer_shapes, inner_shapes): - print(shape) - inputs = [ - torch.randn(*shape, device="cuda"), - torch.randn(*shape, device="cuda"), - ] - - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) - - # unroll and vectorization configurations - for config in itertools.product(vectorize_range, unroll_range): - perf_metric, _ = run_profile(presched_fd, inputs, config) - parameters.append((*shape, *config)) - performance.append(perf_metric) - -# ============================ Separate Data ================================== - -# Separate collected data into training and test sets -train_data = [] -test_data = [] -train_perf = [] -test_perf = [] -test_shapes = set() -all_test_config = {} # key: input_shape, value: (config, perf) - -for data, perf in zip(parameters, performance): - shape = data[:num_dimensions] - config = data[num_dimensions:] - - if shape in all_test_config: - all_test_config[shape][config] = perf - else: - all_test_config[shape] = {config: perf} - - if random.random() < test_data_percentage: - test_data.append(data) - test_perf.append(perf) - else: - test_shapes.add(shape) - train_data.append(data) - train_perf.append(perf) - -# key: input_shape, value: best_config -best_test_config = {shape: argmax(all_test_config[shape]) for shape in test_shapes} - -# ========================= Train Regression Models =========================== - -# Apply decision tree regressor -# Given input shapes and scheduler parameters, predict performance metric. 
-from sklearn import tree - -clf = tree.DecisionTreeRegressor() -clf = clf.fit(train_data, train_perf) -test_pred = clf.predict(test_data) - -print("===================== measure performance rmse ========================") - -# Estimate prediction error with RMSE -import numpy as np - -test_perf = np.array(test_perf) -print( - "Test prediction error (RMSE)", - np.sqrt(np.mean(np.power(test_perf - test_pred, 2))), -) -print("Test performance", test_perf) -print("Test prediction", test_pred) +class AutotunePointwise: + class FUSION(Enum): + GELU_BIAS = 1 + SILU_MUL = 2 + BCAST_ADD = 3 + MUL = 4 + + @dataclass(unsafe_hash=True) + class PointwiseConfiguration: + break_point: int + bdim: [int] + vectorize_factor: int + outer_unroll: int + inner_unroll: int + + def __init__(self, selected_fusion): + self.selected_fusion = selected_fusion + + def __repr__(self): + return f"pointwise_{self.selected_fusion.name}" + + # For pointwise scheduler, we test the cartesian product of vectorization and + # unroll factors. 
+ def generate_scheduler_configurations(self, input_shape): + def _named_product(**items): + return itertools.starmap( + self.PointwiseConfiguration, itertools.product(*items.values()) + ) + + num_dimensions = len(input_shape) + warp_size = 32 + warp_group = warp_size * 4 + # limited to a maximum of 128 threads because of pointwise scheduler + max_threads_per_cta = 128 + threads_per_cta = list(range(warp_group, max_threads_per_cta + 1, warp_group)) + + scheduler_configs = [] + for bp in range(num_dimensions): + for num_threads in threads_per_cta: + if bp == 0: + # 1D scheduler configurations + bdim_shapes = [(num_threads, 1)] + outer_unroll_range = [1] + # unroll_factor is between [1, 9] + inner_unroll_range = range(1, 10) + else: + # 2D scheduler configurations + max_bdimy = num_threads // warp_size + log2_max_bdimy = int(math.log2(max_bdimy)) + bdimy_configs = [ + 2**log_bdimy for log_bdimy in range(1, log2_max_bdimy + 1) + ] + + bdim_shapes = [ + (max(warp_size, num_threads // bdimy), bdimy) + for bdimy in bdimy_configs + ] + # total_unroll_factor is between [1, 9] given that outer and + # inner unroll factors are between [1, 3]. 
+ outer_unroll_range = range(1, 4) + inner_unroll_range = range(1, 4) + + scheduler_config = _named_product( + break_point=[bp], + bdim=bdim_shapes, + vectorize_factor=[1, 2, 4, 8], + outer_unroll=outer_unroll_range, + inner_unroll=inner_unroll_range, + ) + scheduler_configs.append(scheduler_config) + return itertools.chain(*scheduler_configs) + + def create_inputs(self, shape, tensor_datatype): + def outer_bcast(): + return [ + torch.randn(1, shape[-1], dtype=tensor_datatype, device="cuda"), + torch.randn(*shape, dtype=tensor_datatype, device="cuda"), + ] + + def inner_bcast(): + return [ + torch.randn(shape[0], 1, dtype=tensor_datatype, device="cuda"), + torch.randn(*shape, dtype=tensor_datatype, device="cuda"), + ] + + def full(): + return [ + torch.randn(*shape, dtype=tensor_datatype, device="cuda"), + torch.randn(*shape, dtype=tensor_datatype, device="cuda"), + ] + + if self.selected_fusion == self.FUSION.GELU_BIAS: + return outer_bcast() + elif self.selected_fusion in [self.FUSION.SILU_MUL, self.FUSION.MUL]: + return full() + elif self.selected_fusion == self.FUSION.BCAST_ADD: + return inner_bcast() + else: + assert False + + # A decorator to create a pointwise fusion given some input arguments.
+ def create_fusion_func(self, inputs): + def gelu_bias(fd: FusionDefinition): + T0 = fd.define_tensor( + shape=[1, -1], + contiguity=[None, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T6 = fd.ops.cast(T1, dtype=DataType.Float) + T7 = fd.ops.cast(T0, dtype=DataType.Float) + T8 = fd.ops.add(T6, T7) + T9 = fd.ops.mul(T8, T8) + T10 = fd.ops.mul(T9, T8) + S11 = fd.define_scalar(0.500000, dtype=DataType.Double) + T12 = fd.ops.mul(S11, T8) + S13 = fd.define_scalar(0.0447150, dtype=DataType.Double) + T14 = fd.ops.mul(S13, T10) + T15 = fd.ops.add(T8, T14) + S16 = fd.define_scalar(0.797885, dtype=DataType.Double) + T17 = fd.ops.mul(S16, T15) + T18 = fd.ops.tanh(T17) + S19 = fd.define_scalar(1.00000, dtype=DataType.Double) + T20 = fd.ops.add(S19, T18) + T21 = fd.ops.mul(T12, T20) + T22 = fd.ops.cast(T21, dtype=DataType.BFloat16) + fd.add_output(T22) + + def silu_mul(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.ops.cast(T0, dtype=DataType.Float) + T3 = fd.ops.neg(T2) + T4 = fd.ops.exp(T3) + S5 = fd.define_scalar(1.00000, dtype=DataType.Double) + T6 = fd.ops.add(S5, T4) + T7 = fd.ops.reciprocal(T6) + T8 = fd.ops.mul(T2, T7) + T9 = fd.ops.cast(T1, dtype=DataType.Float) + T10 = fd.ops.mul(T8, T9) + T11 = fd.ops.cast(T10, dtype=DataType.BFloat16) + fd.add_output(T11) + + def bcast_add(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, 1], + contiguity=[True, None], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + 
dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.ops.cast(T0, dtype=DataType.Float) + T3 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.add(T2, T3) + T5 = fd.ops.cast(T4, dtype=DataType.BFloat16) + fd.add_output(T5) + + def mul(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.ops.cast(T0, dtype=DataType.Float) + T3 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.mul(T2, T3) + T5 = fd.ops.cast(T4, dtype=DataType.BFloat16) + fd.add_output(T5) + + if self.selected_fusion == self.FUSION.GELU_BIAS: + return gelu_bias + elif self.selected_fusion == self.FUSION.SILU_MUL: + return silu_mul + elif self.selected_fusion == self.FUSION.BCAST_ADD: + return bcast_add + elif self.selected_fusion == self.FUSION.MUL: + return mul + else: + assert False + + # The pytorch eager mode reference used to validate the nvfuser kernel.
+ def eager_reference(self, inputs): + def gelu_bias(inputs): + return torch.nn.functional.gelu( + inputs[0] + inputs[1], approximate="tanh" + ) + + def silu_mul(inputs): + return torch.nn.functional.silu(inputs[0]) * inputs[1] + + def bcast_add(inputs): + return inputs[0] + inputs[1] + + def mul(inputs): + return inputs[0] * inputs[1] + + if self.selected_fusion == self.FUSION.GELU_BIAS: + return gelu_bias(inputs) + elif self.selected_fusion == self.FUSION.SILU_MUL: + return silu_mul(inputs) + elif self.selected_fusion == self.FUSION.BCAST_ADD: + return bcast_add(inputs) + elif self.selected_fusion == self.FUSION.MUL: + return mul(inputs) + else: + assert False + + # Apply scheduler with custom parameters using decorator + def custom_scheduler(self, fd, scheduler_config): + def inner_fn(): + # Check if compatible with pointwise scheduler + status, _ = fd.sched.can_schedule(SchedulerType.pointwise) + assert status + + schedule_params = fd.sched.compute_pointwise_heuristics() + + # Modify original parameters + if scheduler_config is not None: + schedule_params.break_point = scheduler_config.break_point + schedule_params.vectorization_factor = scheduler_config.vectorize_factor + schedule_params.unroll_factor_outer = scheduler_config.outer_unroll + schedule_params.unroll_factor_inner = scheduler_config.inner_unroll + schedule_params.lparams.bdimx = scheduler_config.bdim[0] + schedule_params.lparams.bdimy = scheduler_config.bdim[1] + + # Schedule fusion + fd.sched.schedule() + + fd.schedule = inner_fn + return fd + + +# Run sequence of steps to collect data, train and test model +def main(): + # ====================== Setup Script Configuration ======================= + script_config = ScriptConfiguration( + num_dimensions=2, + outer_shapes=[16384], + inner_shapes=[128, 1024, 4096, 16384], + tensor_datatype=torch.bfloat16, + test_data_percentage=0.1, + empirical_batch_size=16384, + empirical_hidden_sizes=list(range(256, 32768, 256)), + )
-print("======================= compare configurations =======================") -# Find best configuration for test_shapes -print( - "input shape, estimate_config:(vectorization, unroll), actual_config:(vectorization, unroll), correct" -) -correctness_count = 0 -mismatch_configs = [] -for shape in test_shapes: - estimate_config = find_best_parameters(clf, shape, parameter_configurations) - - match_config = estimate_config == best_test_config[shape] - if not match_config: - mismatch_configs.append((shape, estimate_config)) - - correctness_count += int(match_config) - print(f"{shape}, {estimate_config}, {best_test_config[shape]}, {match_config}") -print("% of predictions match nvfuser parameters", correctness_count / len(test_shapes)) -print(correctness_count, "out of", len(test_shapes)) - -print("======================= compare performance =========================") - -for shape, estimate_config in mismatch_configs: - inputs = [ - torch.randn(*shape, device="cuda"), - torch.randn(*shape, device="cuda"), - ] - - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) - - _, est_perf = run_profile(presched_fd, inputs, estimate_config) - _, nvf_perf = run_profile(presched_fd, inputs) - est_perf_faster = est_perf < nvf_perf - print( - f"{shape} \t estimate_perf:{est_perf:.5f} \t nvfuser_perf:{nvf_perf:.5f} \t is_estimated_config_faster:\t{est_perf_faster}" + autotune_config = AutotunePointwise( + selected_fusion=AutotunePointwise.FUSION.GELU_BIAS ) -print("=====================================================================") + # ============================ Run Experiments ============================ -# For a specific batch size, gather performance across a range of hidden sizes. -# Calculate performance for best predicted and nvfuser configurations. Plot a -# chart comparing performance using matplotlib. 
+ parameters, performance = collect_data(script_config, autotune_config) -# NOTE: The matplotlib experiment plots the kernel runtime, which could be -# different than the selected performance metric. Currently, the performance -# metric is effective_bandwidth_gbs. + # ============================ Separate Data ============================== -import matplotlib.pyplot as plt -import numpy as np + train_data, test_data = separate_data(script_config, parameters, performance) -FusionCache.reset() -est_perfs = [] -for hidden_shape in empirical_hidden_sizes: - inputs = [ - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - ] - estimate_config = find_best_parameters( - clf, (empirical_batch_size, hidden_shape), parameter_configurations - ) + # ========================= Train Regression Models ======================= - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) + # Apply machine learning regressor + # Given input shapes and scheduler parameters, predict performance metric. 
+ from sklearn import ensemble - _, est_time_ms = run_profile(presched_fd, inputs, estimate_config) - est_perfs.append(est_time_ms) - print( - f"{empirical_batch_size}, {hidden_shape}, {estimate_config}, {est_time_ms:.3f}" - ) + train_inputs, train_perf = train_data + clf = ensemble.RandomForestRegressor() + clf = clf.fit(train_inputs, train_perf) + + # ========================= Test Regression Models ======================== + test_model_rmse(clf, script_config, autotune_config, test_data) + test_model(clf, script_config, autotune_config) -FusionCache.reset() -nvf_perfs = [] -for hidden_shape in empirical_hidden_sizes: - inputs = [ - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - ] - - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) - - _, nvf_time_ms = run_profile(presched_fd, inputs) - nvf_perfs.append(nvf_time_ms) - print(f"{empirical_batch_size}, {hidden_shape}, {nvf_time_ms:.3f}") - -# Get mean speed-up from nvfuser to empirical configurations across all input shapes. -# Negative value mean empirical configurations are slower than nvfuser. -print("Mean speed-up", np.mean(np.array(nvf_perfs) - np.array(est_perfs))) - -np_hidden_size = np.array(empirical_hidden_sizes) -plt.plot(np_hidden_size, np.array(est_perfs)) -plt.plot(np_hidden_size, np.array(nvf_perfs)) - -plt.xlabel("Hidden Size") -plt.ylabel("Time(ms)") -plt.title( - f"Batch Size = {empirical_batch_size}, Compare Decision Tree Heuristic vs NvFuser" -) -plt.legend(["decision_tree", "nvfuser"], loc="lower right") -plt.savefig(f"pointwise_empirical_batchsize{empirical_batch_size}.png") -# ============================================================================= +if __name__ == "__main__": + main()