From ab34c7aac8b207156d5101206af5cca2b6b1002d Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Sun, 24 Nov 2024 10:56:18 -0800
Subject: [PATCH 1/8] Create Autotuning utilities

---
 doc/dev/python_scheduling/autotune_utils.py | 298 ++++++++++++++++++++
 1 file changed, 298 insertions(+)
 create mode 100644 doc/dev/python_scheduling/autotune_utils.py

diff --git a/doc/dev/python_scheduling/autotune_utils.py b/doc/dev/python_scheduling/autotune_utils.py
new file mode 100644
index 00000000000..4017c6c87f8
--- /dev/null
+++ b/doc/dev/python_scheduling/autotune_utils.py
@@ -0,0 +1,298 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# Owner(s): ["module: nvfuser"]
+
+import torch
+import itertools
+from nvfuser import FusionCache, FusionDefinition
+from dataclasses import dataclass, astuple
+
+# ================================ Description ================================
+# This file contains the utility functions for the autotuning scripts.
+# =============================================================================
+
+
+@dataclass
+class ScriptConfiguration:
+    # Settings for input tensor generation
+    # number of dimensions in the tensor argument
+    num_dimensions: int
+
+    # the data type for the tensor argument
+    tensor_datatype: torch.dtype
+
+    # During training, the cartesian product of outer_shapes and inner_shapes
+    # is used to define the shape of the input tensor arguments.
+    outer_shapes: list[int]
+    inner_shapes: list[int]
+
+    # We profile a range of input shapes with various configurations.
+    # This argument determines how much of the profiled data to keep as a test
+    # set.
+    test_data_percentage: float
+
+    # The selected batch size for empirical and nvfuser comparison.
+    empirical_batch_size: int
+
+    # The range of hidden sizes for empirical and nvfuser comparison.
+    empirical_hidden_sizes: list[int]
+
+
+# Convert a dataclass to a tuple, flattening any nested tuples. This function
+# is used for compatibility with the machine-learning model.
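+# For example (illustrative values): a configuration whose fields are
+# (1, (128, 2), 4, 1, 2) flattens to (1, 128, 2, 4, 1, 2).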
+def flatten_configuration(scheduler_config): + new_scheduler_config = [] + for item in astuple(scheduler_config): + if type(item) is tuple: + new_scheduler_config.extend(item) + else: + new_scheduler_config.append(item) + return tuple(new_scheduler_config) + + +# Collect data for machine learning +def collect_data(script_config, autotune_config): + parameters = [] + performance = [] + + for shape in itertools.product( + script_config.outer_shapes, script_config.inner_shapes + ): + print(shape) + inputs = autotune_config.create_inputs(shape, script_config.tensor_datatype) + + with FusionDefinition() as presched_fd: + autotune_config.create_fusion_func(inputs)(presched_fd) + + # unroll and vectorization configurations + for parameter_config in autotune_config.generate_scheduler_configurations( + shape + ): + perf_metric, _ = run_profile( + autotune_config, presched_fd, inputs, parameter_config + ) + parameters.append((*shape, *flatten_configuration(parameter_config))) + performance.append(perf_metric) + return parameters, performance + + +# Separate collected data into training and test sets +def separate_data(script_config, parameters, performance): + import random + + train_inputs = [] + test_inputs = [] + train_perf = [] + test_perf = [] + test_shapes = set() + all_test_scheduler_config = {} # key: input_shape, value: (scheduler_config, perf) + + for data, perf in zip(parameters, performance): + shape = data[: script_config.num_dimensions] + scheduler_config = data[script_config.num_dimensions :] + + if shape in all_test_scheduler_config: + all_test_scheduler_config[shape][scheduler_config] = perf + else: + all_test_scheduler_config[shape] = {scheduler_config: perf} + + if ( + script_config.test_data_percentage > 0 + and random.random() < script_config.test_data_percentage + ): + test_shapes.add(shape) + test_inputs.append(data) + test_perf.append(perf) + else: + train_inputs.append(data) + train_perf.append(perf) + + # key: input_shape, value: best_scheduler_config + best_test_scheduler_config = { + shape: argmax(all_test_scheduler_config[shape]) for shape in test_shapes + } + + return (train_inputs, train_perf), ( + test_inputs, + test_perf, + test_shapes, + best_test_scheduler_config, + ) + + +# Apply schedule decorator, run fusion, and profile performance +def run_profile(autotune_config, presched_fd, inputs, scheduler_config=None): + scheduled_fd = autotune_config.custom_scheduler(presched_fd, scheduler_config) + nvf_outputs = scheduled_fd.execute(inputs, profile=True) + + # validate correctness + assert torch.allclose( + nvf_outputs[0], autotune_config.eager_reference(inputs), atol=1e-2, rtol=1e-2 + ) + + prof = scheduled_fd.profile() + bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs + time = prof.kernel_profiles[0].time_ms + return bandwidth, time + + +# Given a map from scheduler configuration to predicted performance, find the +# configuration with the maximum predicted performance +def argmax(map_scheduler_config_to_perf): + best_perf = -1 + best_scheduler_config = None + for scheduler_config, perf in map_scheduler_config_to_perf.items(): + if perf > best_perf: + best_perf = perf + best_scheduler_config = scheduler_config + return best_scheduler_config + + +# Given a prediction model, input_shape, and set of parameter configurations, +# find the best parameters +def find_best_parameters(clf, input_shape, scheduler_configurations): + map_scheduler_config_to_performance = { + scheduler_config: clf.predict( + [[*input_shape, *flatten_configuration(scheduler_config)]] + 
) + for scheduler_config in scheduler_configurations + } + return argmax(map_scheduler_config_to_performance) + + +# Measure model performance with RMSE +def test_model_rmse(clf, script_config, autotune_config, test_data): + test_inputs, test_perf, test_shapes, best_test_scheduler_config = test_data + test_pred = clf.predict(test_inputs) + + # Estimate prediction error with RMSE + import numpy as np + + test_perf = np.array(test_perf) + print( + "Test prediction error (RMSE)", + np.sqrt(np.mean(np.power(test_perf - test_pred, 2))), + ) + print("Test performance", test_perf) + print("Test prediction", test_pred) + + print("======================= compare configurations =======================") + # Find best configuration for test_shapes + print("input shape, estimate_config, actual_config, correct") + correctness_count = 0 + mismatch_configs = [] + for shape in test_shapes: + estimate_config = find_best_parameters( + clf, shape, autotune_config.generate_scheduler_configurations(shape) + ) + + match_config = ( + flatten_configuration(estimate_config) == best_test_scheduler_config[shape] + ) + if not match_config: + mismatch_configs.append((shape, estimate_config)) + + correctness_count += int(match_config) + print( + f"{shape}, {estimate_config}, {best_test_scheduler_config[shape]}, {match_config}" + ) + print( + "% of predictions match nvfuser parameters", + correctness_count / len(test_shapes), + ) + print(correctness_count, "out of", len(test_shapes)) + + print("======================= compare performance =========================") + + for shape, estimate_config in mismatch_configs: + inputs = autotune_config.create_inputs(shape, script_config.tensor_datatype) + + with FusionDefinition() as presched_fd: + autotune_config.create_fusion_func(inputs)(presched_fd) + + _, est_perf = run_profile(autotune_config, presched_fd, inputs, estimate_config) + _, nvf_perf = run_profile(autotune_config, presched_fd, inputs) + est_perf_faster = est_perf < nvf_perf + print( + f"{shape} \t estimate_perf: {est_perf: .5f} \t nvfuser_perf: {nvf_perf: .5f} \t is_estimated_config_faster: {est_perf_faster}" + ) + print("=====================================================================") + + +# Given a machine learning model, compare the performance of its predicted configuration +# against nvfuser on a given fusion +def test_model(clf, script_config, autotune_config): + # For a specific batch size, gather performance across a range of hidden sizes. + # Calculate performance for best predicted and nvfuser configurations. Plot a + # chart comparing performance using matplotlib. + + # NOTE: The matplotlib experiment plots the kernel runtime, which could be + # different than the selected performance metric. Currently, the performance + # metric is effective_bandwidth_gbs. 
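+    # As a rough sketch of how the two metrics relate (illustrative numbers):
+    #   effective_bandwidth_gbs ~ bytes_read_and_written / (time_ms * 1e6)
+    # e.g., a kernel that moves 64 MB in 0.25 ms sustains ~256 GB/s.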
+
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    FusionCache.reset()
+    est_perfs = []
+    for hidden_shape in script_config.empirical_hidden_sizes:
+        inputs = autotune_config.create_inputs(
+            (script_config.empirical_batch_size, hidden_shape),
+            script_config.tensor_datatype,
+        )
+
+        estimate_config = find_best_parameters(
+            clf,
+            (script_config.empirical_batch_size, hidden_shape),
+            autotune_config.generate_scheduler_configurations(
+                (script_config.empirical_batch_size, hidden_shape)
+            ),
+        )
+
+        with FusionDefinition() as presched_fd:
+            autotune_config.create_fusion_func(inputs)(presched_fd)
+
+        _, est_time_ms = run_profile(
+            autotune_config, presched_fd, inputs, estimate_config
+        )
+        est_perfs.append(est_time_ms)
+        print(
+            f"{script_config.empirical_batch_size}, {hidden_shape}, {estimate_config}, {est_time_ms: .3f}"
+        )
+
+    FusionCache.reset()
+    nvf_perfs = []
+    for hidden_shape in script_config.empirical_hidden_sizes:
+        inputs = autotune_config.create_inputs(
+            (script_config.empirical_batch_size, hidden_shape),
+            script_config.tensor_datatype,
+        )
+
+        with FusionDefinition() as presched_fd:
+            autotune_config.create_fusion_func(inputs)(presched_fd)
+
+        _, nvf_time_ms = run_profile(autotune_config, presched_fd, inputs)
+        nvf_perfs.append(nvf_time_ms)
+        print(
+            f"{script_config.empirical_batch_size}, {hidden_shape}, {nvf_time_ms: .3f}"
+        )
+
+    # Get the mean speed-up of the empirical configurations over nvfuser across
+    # all input shapes. A negative value means the empirical configurations are
+    # slower than nvfuser.
+    print("Mean speed-up", np.mean(np.array(nvf_perfs) - np.array(est_perfs)))
+
+    np_hidden_size = np.array(script_config.empirical_hidden_sizes)
+    plt.plot(np_hidden_size, np.array(est_perfs))
+    plt.plot(np_hidden_size, np.array(nvf_perfs))
+
+    plt.xlabel("Hidden Size")
+    plt.ylabel("Time (ms)")
+    plt.title(
+        f"Batch Size = {script_config.empirical_batch_size}, Compare Machine Learning Heuristic vs NvFuser"
+    )
+    plt.legend(["random_forest", "nvfuser"], loc="lower right")
+    plt.savefig(
+        f"{autotune_config}_empirical_batch_size_{script_config.empirical_batch_size}.png"
+    )
+    plt.close("all")

From 45b74f6a0e0763cc0a2f93c4b7c763cc5020b286 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Mon, 4 Nov 2024 04:58:45 +0000
Subject: [PATCH 2/8] Support 2D break_point configurations

Support Gelu-Bias, Silu-Mul, Bcast-Add, Mul Fusions
---
 .../python_scheduling/autotune_pointwise.py   | 599 ++++++++++--------
 1 file changed, 330 insertions(+), 269 deletions(-)

diff --git a/doc/dev/python_scheduling/autotune_pointwise.py b/doc/dev/python_scheduling/autotune_pointwise.py
index 014ae8197a2..cbda494b3cb 100644
--- a/doc/dev/python_scheduling/autotune_pointwise.py
+++ b/doc/dev/python_scheduling/autotune_pointwise.py
@@ -1,15 +1,42 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
 # All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 # Owner(s): ["module: nvfuser"]
 
 import torch
 import itertools
-import random
-from nvfuser import FusionCache, FusionDefinition, SchedulerType
+import math
+from nvfuser import FusionDefinition, SchedulerType, DataType
+from dataclasses import dataclass
+from enum import Enum
+
+from autotune_utils import (
+    ScriptConfiguration,
+    collect_data,
+    separate_data,
+    test_model_rmse,
+    test_model,
+)
+
 
 # ============================ Description ============================
+# This script defines four pointwise fusions:
+#
+# 1.
GELU with Outer-Broadcast Bias Addition
+#    y = gelu(x + bias[broadcast, i], approximate='tanh')
+#
+# 2. SILU with Pointwise Multiplication
+#    z = silu(x) * y
+#
+# 3. Inner-Broadcast Addition
+#    y = x + y[i, broadcast]
+#
+# 4. Pointwise Multiplication
+#    z = x * y
+#
+# Script Sequence:
+#
 # 1. Define a nvfuser fusion and its pytorch eager mode reference.
 #
 # 2. Profile the CUDA kernel performance by iterating over a set of input
@@ -25,287 +52,321 @@
 #    - For a specific batch size, gather performance across a range of hidden
 #      sizes. Calculate performance for best predicted and nvfuser
 #      configurations. Plot a chart comparing performance using matplotlib.
-
+#
 # The selected performance metric is effective_bandwidth_gbs. The empirical
 # scheduler selects the configuration that has the highest predicted
 # effective_bandwidth_gbs.
 
 # ============================ Configurations ============================
 
-# Settings for input tensor generation
-num_dimensions = 2
-outer_shapes = [512]
-inner_shapes = [2**i for i in range(5, 15)]
-
-# For pointwise scheduler, we test the cartesian product of vectorization and
-# unroll factors.
-parameter_configurations = [
-    vectorize_range := [1, 2, 4],
-    unroll_range := list(range(1, 10)),
-]
-
-# We profile a range of input shapes with various configurations.
-# This argument determines how much of the profiled data to keep as a test set.
-test_data_percentage = 0.1
-
-# The selected batch size for empirical and nvfuser comparison.
-empirical_batch_size = 512
-
-# The range of hidden sizes for empirical and nvfuser comparision.
-empirical_hidden_sizes = list(range(256, 28672, 256))
-
-
-# A decorator to create a pointwise fusion given some input arguments.
-def create_fusion_func(inputs):
-    def fusion_func(fd: FusionDefinition):
-        t0 = fd.from_pytorch(inputs[0])
-        t1 = fd.from_pytorch(inputs[1])
-        c0 = fd.define_scalar(3.0)
-        t2 = fd.ops.add(t0, t1)
-        t3 = fd.ops.mul(t2, c0)
-        fd.add_output(t3)
-
-    return fusion_func
-
-
-# The pytorch eager mode reference used to validating nvfuser kernel.
-def eager_reference(inputs): - return (inputs[0] + inputs[1]) * 3 - - -# ============================ Function Definitions ============================ - - -# Apply scheduler with custom parameters using decorator -def custom_pointwise_scheduler(fd, config): - def inner_fn(): - # Check if compatible with pointwise scheduler - status, _ = fd.sched.can_schedule(SchedulerType.pointwise) - assert status - - schedule_params = fd.sched.compute_pointwise_heuristics() - - # Modify original parameters - if config is not None: - vectorization_factor, unroll_factor = config - schedule_params.vectorization_factor = vectorization_factor - schedule_params.unroll_factor_inner = unroll_factor - - # Schedule fusion - fd.sched.schedule() - - fd.schedule = inner_fn - return fd - -# Apply schedule decorator, run fusion, and profile performance -def run_profile(presched_fd, inputs, config=None): - scheduled_fd = custom_pointwise_scheduler(presched_fd, config) - nvf_outputs = scheduled_fd.execute(inputs, profile=True) - - # validate correctness - assert torch.allclose(nvf_outputs[0], eager_reference(inputs)) - - prof = scheduled_fd.profile() - bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs - time = prof.kernel_profiles[0].time_ms - return bandwidth, time - - -def argmax(map_config_to_perf): - best_perf = -1 - best_config = None - for config, perf in map_config_to_perf.items(): - if perf > best_perf: - best_perf = perf - best_config = config - return best_config - - -# Given a prediction model, input_shape, and set of parameter configurations, -# find the best parameters -def find_best_parameters(predictor, input_shape, parameter_configurations): - map_config_to_performance = { - config: predictor.predict([[*input_shape, *config]]) - for config in itertools.product(*parameter_configurations) - } - return argmax(map_config_to_performance) - - -# ============================ Run Experiments ================================ - -# Collect data for decision tree -parameters = [] -performance = [] - -for shape in itertools.product(outer_shapes, inner_shapes): - print(shape) - inputs = [ - torch.randn(*shape, device="cuda"), - torch.randn(*shape, device="cuda"), - ] - - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) - - # unroll and vectorization configurations - for config in itertools.product(vectorize_range, unroll_range): - perf_metric, _ = run_profile(presched_fd, inputs, config) - parameters.append((*shape, *config)) - performance.append(perf_metric) - -# ============================ Separate Data ================================== - -# Separate collected data into training and test sets -train_data = [] -test_data = [] -train_perf = [] -test_perf = [] -test_shapes = set() -all_test_config = {} # key: input_shape, value: (config, perf) - -for data, perf in zip(parameters, performance): - shape = data[:num_dimensions] - config = data[num_dimensions:] - - if shape in all_test_config: - all_test_config[shape][config] = perf - else: - all_test_config[shape] = {config: perf} - - if random.random() < test_data_percentage: - test_data.append(data) - test_perf.append(perf) - else: - test_shapes.add(shape) - train_data.append(data) - train_perf.append(perf) - -# key: input_shape, value: best_config -best_test_config = {shape: argmax(all_test_config[shape]) for shape in test_shapes} - -# ========================= Train Regression Models =========================== - -# Apply decision tree regressor -# Given input shapes and scheduler parameters, predict performance metric. 
-from sklearn import tree
-
-clf = tree.DecisionTreeRegressor()
-clf = clf.fit(train_data, train_perf)
-test_pred = clf.predict(test_data)
-
-print("===================== measure performance rmse ========================")
-
-# Estimate prediction error with RMSE
-import numpy as np
-
-test_perf = np.array(test_perf)
-print(
-    "Test prediction error (RMSE)",
-    np.sqrt(np.mean(np.power(test_perf - test_pred, 2))),
-)
-print("Test performance", test_perf)
-print("Test prediction", test_pred)
+class AutotunePointwise:
+    class FUSION(Enum):
+        GELU_BIAS = 1
+        SILU_MUL = 2
+        BCAST_ADD = 3
+        MUL = 4
+
+    @dataclass(unsafe_hash=True)
+    class PointwiseConfiguration:
+        break_point: int
+        bdim: tuple
+        vectorize_factor: int
+        outer_unroll: int
+        inner_unroll: int
+
+    def __init__(self, selected_fusion):
+        self.selected_fusion = selected_fusion
+
+    def __repr__(self):
+        return f"pointwise_{self.selected_fusion.name}"
+
+    # For pointwise scheduler, we test the cartesian product of vectorization and
+    # unroll factors.
+    def generate_scheduler_configurations(self, input_shape):
+        def _named_product(**items):
+            return itertools.starmap(
+                self.PointwiseConfiguration, itertools.product(*items.values())
+            )
+
+        num_dimensions = len(input_shape)
+        warp_size = 32
+        warp_group = warp_size * 4
+        # limited to a maximum of 128 threads by the pointwise scheduler
+        max_threads_per_cta = 128
+        threads_per_cta = list(range(warp_group, max_threads_per_cta + 1, warp_group))
+
+        scheduler_configs = []
+        for bp in range(num_dimensions):
+            for num_threads in threads_per_cta:
+                if bp == 0:
+                    # 1D scheduler configurations
+                    bdim_shapes = [(num_threads, 1)]
+                    outer_unroll_range = [1]
+                    # unroll_factor is between [1, 9]
+                    inner_unroll_range = range(1, 10)
+                else:
+                    # 2D scheduler configurations
+                    max_bdimy = num_threads // warp_size
+                    log2_max_bdimy = int(math.log2(max_bdimy))
+                    bdimy_configs = [
+                        2**log_bdimy for log_bdimy in range(1, log2_max_bdimy + 1)
+                    ]
+
+                    bdim_shapes = [
+                        (max(warp_size, num_threads // bdimy), bdimy)
+                        for bdimy in bdimy_configs
+                    ]
+                    # total_unroll_factor is between [1, 9] given that outer and
+                    # inner unroll factors are between [1, 3].
+                    outer_unroll_range = range(1, 4)
+                    inner_unroll_range = range(1, 4)
+
+                scheduler_config = _named_product(
+                    break_point=[bp],
+                    bdim=bdim_shapes,
+                    vectorize_factor=[1, 2, 4, 8],
+                    outer_unroll=outer_unroll_range,
+                    inner_unroll=inner_unroll_range,
+                )
+                scheduler_configs.append(scheduler_config)
+        return itertools.chain(*scheduler_configs)
+
+    def create_inputs(self, shape, tensor_datatype):
+        def outer_bcast():
+            return [
+                torch.randn(1, shape[-1], dtype=tensor_datatype, device="cuda"),
+                torch.randn(*shape, dtype=tensor_datatype, device="cuda"),
+            ]
+
+        def inner_bcast():
+            return [
+                torch.randn(shape[0], 1, dtype=tensor_datatype, device="cuda"),
+                torch.randn(*shape, dtype=tensor_datatype, device="cuda"),
+            ]
+
+        def full():
+            return [
+                torch.randn(*shape, dtype=tensor_datatype, device="cuda"),
+                torch.randn(*shape, dtype=tensor_datatype, device="cuda"),
+            ]
+
+        if self.selected_fusion == self.FUSION.GELU_BIAS:
+            return outer_bcast()
+        elif self.selected_fusion in [self.FUSION.SILU_MUL, self.FUSION.MUL]:
+            return full()
+        elif self.selected_fusion == self.FUSION.BCAST_ADD:
+            return inner_bcast()
+        else:
+            assert False
+
+    # A decorator to create a pointwise fusion given some input arguments.
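+    # The returned callable is applied as create_fusion_func(inputs)(fd) inside
+    # collect_data (see autotune_utils.py); each branch below mirrors one of the
+    # four fusions described at the top of this script.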
+    def create_fusion_func(self, inputs):
+        def gelu_bias(fd: FusionDefinition):
+            # computes gelu(x + bias) via the tanh approximation:
+            # 0.5 * x * (1 + tanh(0.797885 * (x + 0.044715 * x**3))),
+            # where 0.797885 ~= sqrt(2 / pi)
+            T0 = fd.define_tensor(
+                shape=[1, -1],
+                contiguity=[None, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T6 = fd.ops.cast(T1, dtype=DataType.Float)
+            T7 = fd.ops.cast(T0, dtype=DataType.Float)
+            T8 = fd.ops.add(T6, T7)
+            T9 = fd.ops.mul(T8, T8)
+            T10 = fd.ops.mul(T9, T8)
+            S11 = fd.define_scalar(0.500000, dtype=DataType.Double)
+            T12 = fd.ops.mul(S11, T8)
+            S13 = fd.define_scalar(0.0447150, dtype=DataType.Double)
+            T14 = fd.ops.mul(S13, T10)
+            T15 = fd.ops.add(T8, T14)
+            S16 = fd.define_scalar(0.797885, dtype=DataType.Double)
+            T17 = fd.ops.mul(S16, T15)
+            T18 = fd.ops.tanh(T17)
+            S19 = fd.define_scalar(1.00000, dtype=DataType.Double)
+            T20 = fd.ops.add(S19, T18)
+            T21 = fd.ops.mul(T12, T20)
+            T22 = fd.ops.cast(T21, dtype=DataType.BFloat16)
+            fd.add_output(T22)
+
+        def silu_mul(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T2 = fd.ops.cast(T0, dtype=DataType.Float)
+            T3 = fd.ops.neg(T2)
+            T4 = fd.ops.exp(T3)
+            S5 = fd.define_scalar(1.00000, dtype=DataType.Double)
+            T6 = fd.ops.add(S5, T4)
+            T7 = fd.ops.reciprocal(T6)
+            T8 = fd.ops.mul(T2, T7)
+            T9 = fd.ops.cast(T1, dtype=DataType.Float)
+            T10 = fd.ops.mul(T8, T9)
+            T11 = fd.ops.cast(T10, dtype=DataType.BFloat16)
+            fd.add_output(T11)
+
+        def bcast_add(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, 1],
+                contiguity=[True, None],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T2 = fd.ops.cast(T0, dtype=DataType.Float)
+            T3 = fd.ops.cast(T1, dtype=DataType.Float)
+            T4 = fd.ops.add(T2, T3)
+            T5 = fd.ops.cast(T4, dtype=DataType.BFloat16)
+            fd.add_output(T5)
+
+        def mul(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T2 = fd.ops.cast(T0, dtype=DataType.Float)
+            T3 = fd.ops.cast(T1, dtype=DataType.Float)
+            T4 = fd.ops.mul(T2, T3)
+            T5 = fd.ops.cast(T4, dtype=DataType.BFloat16)
+            fd.add_output(T5)
+
+        if self.selected_fusion == self.FUSION.GELU_BIAS:
+            return gelu_bias
+        elif self.selected_fusion == self.FUSION.SILU_MUL:
+            return silu_mul
+        elif self.selected_fusion == self.FUSION.BCAST_ADD:
+            return bcast_add
+        elif self.selected_fusion == self.FUSION.MUL:
+            return mul
+        else:
+            assert False
+
+    # The pytorch eager mode reference used to validate the nvfuser kernel.
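+    # run_profile in autotune_utils.py checks the nvfuser output against this
+    # reference with torch.allclose(atol=1e-2, rtol=1e-2), a loose tolerance
+    # that accommodates the bfloat16 inputs used here.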
+    def eager_reference(self, inputs):
+        def gelu_bias(inputs):
+            return torch.nn.functional.gelu(
+                inputs[0] + inputs[1], approximate="tanh"
+            )
+
+        def silu_mul(inputs):
+            return torch.nn.functional.silu(inputs[0]) * inputs[1]
+
+        def bcast_add(inputs):
+            return inputs[0] + inputs[1]
+
+        def mul(inputs):
+            return inputs[0] * inputs[1]
+
+        if self.selected_fusion == self.FUSION.GELU_BIAS:
+            return gelu_bias(inputs)
+        elif self.selected_fusion == self.FUSION.SILU_MUL:
+            return silu_mul(inputs)
+        elif self.selected_fusion == self.FUSION.BCAST_ADD:
+            return bcast_add(inputs)
+        elif self.selected_fusion == self.FUSION.MUL:
+            return mul(inputs)
+        else:
+            assert False
+
+    # Apply scheduler with custom parameters using decorator
+    def custom_scheduler(self, fd, scheduler_config):
+        def inner_fn():
+            # Check if compatible with pointwise scheduler
+            status, _ = fd.sched.can_schedule(SchedulerType.pointwise)
+            assert status
+
+            schedule_params = fd.sched.compute_pointwise_heuristics()
+
+            # Modify original parameters
+            if scheduler_config is not None:
+                schedule_params.break_point = scheduler_config.break_point
+                schedule_params.vectorization_factor = scheduler_config.vectorize_factor
+                schedule_params.unroll_factor_outer = scheduler_config.outer_unroll
+                schedule_params.unroll_factor_inner = scheduler_config.inner_unroll
+                schedule_params.lparams.bdimx = scheduler_config.bdim[0]
+                schedule_params.lparams.bdimy = scheduler_config.bdim[1]
+
+            # Schedule fusion
+            fd.sched.schedule()
+
+        fd.schedule = inner_fn
+        return fd
+
+
+# Run sequence of steps to collect data, train and test model
+def main():
+    # ====================== Setup Script Configuration =======================
+    script_config = ScriptConfiguration(
+        num_dimensions=2,
+        outer_shapes=[16384],
+        inner_shapes=[128, 1024, 4096, 16384],
+        tensor_datatype=torch.bfloat16,
+        test_data_percentage=0.1,
+        empirical_batch_size=16384,
+        empirical_hidden_sizes=list(range(256, 32768, 256)),
+    )
 
-print("======================= compare configurations =======================")
-# Find best configuration for test_shapes
-print(
-    "input shape, estimate_config:(vectorization, unroll), actual_config:(vectorization, unroll), correct"
-)
-correctness_count = 0
-mismatch_configs = []
-for shape in test_shapes:
-    estimate_config = find_best_parameters(clf, shape, parameter_configurations)
-
-    match_config = estimate_config == best_test_config[shape]
-    if not match_config:
-        mismatch_configs.append((shape, estimate_config))
-
-    correctness_count += int(match_config)
-    print(f"{shape}, {estimate_config}, {best_test_config[shape]}, {match_config}")
-print("% of predictions match nvfuser parameters", correctness_count / len(test_shapes))
-print(correctness_count, "out of", len(test_shapes))
-
-print("======================= compare performance =========================")
-
-for shape, estimate_config in mismatch_configs:
-    inputs = [
-        torch.randn(*shape, device="cuda"),
-        torch.randn(*shape, device="cuda"),
-    ]
-
-    with FusionDefinition() as presched_fd:
-        create_fusion_func(inputs)(presched_fd)
-
-    _, est_perf = run_profile(presched_fd, inputs, estimate_config)
-    _, nvf_perf = run_profile(presched_fd, inputs)
-    est_perf_faster = est_perf < nvf_perf
-    print(
-        f"{shape} \t estimate_perf:{est_perf:.5f} \t nvfuser_perf:{nvf_perf:.5f} \t is_estimated_config_faster:\t{est_perf_faster}"
+    autotune_config = AutotunePointwise(
+        selected_fusion=AutotunePointwise.FUSION.GELU_BIAS
     )
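+    # To tune one of the other fusions defined above, swap in another member of
+    # AutotunePointwise.FUSION here (e.g., FUSION.SILU_MUL).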
-print("=====================================================================") + # ============================ Run Experiments ============================ -# For a specific batch size, gather performance across a range of hidden sizes. -# Calculate performance for best predicted and nvfuser configurations. Plot a -# chart comparing performance using matplotlib. + parameters, performance = collect_data(script_config, autotune_config) -# NOTE: The matplotlib experiment plots the kernel runtime, which could be -# different than the selected performance metric. Currently, the performance -# metric is effective_bandwidth_gbs. + # ============================ Separate Data ============================== -import matplotlib.pyplot as plt -import numpy as np + train_data, test_data = separate_data(script_config, parameters, performance) -FusionCache.reset() -est_perfs = [] -for hidden_shape in empirical_hidden_sizes: - inputs = [ - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - ] - estimate_config = find_best_parameters( - clf, (empirical_batch_size, hidden_shape), parameter_configurations - ) + # ========================= Train Regression Models ======================= - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) + # Apply machine learning regressor + # Given input shapes and scheduler parameters, predict performance metric. + from sklearn import ensemble - _, est_time_ms = run_profile(presched_fd, inputs, estimate_config) - est_perfs.append(est_time_ms) - print( - f"{empirical_batch_size}, {hidden_shape}, {estimate_config}, {est_time_ms:.3f}" - ) + train_inputs, train_perf = train_data + clf = ensemble.RandomForestRegressor() + clf = clf.fit(train_inputs, train_perf) + + # ========================= Test Regression Models ======================== + test_model_rmse(clf, script_config, autotune_config, test_data) + test_model(clf, script_config, autotune_config) -FusionCache.reset() -nvf_perfs = [] -for hidden_shape in empirical_hidden_sizes: - inputs = [ - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - torch.randn(empirical_batch_size, hidden_shape, device="cuda"), - ] - - with FusionDefinition() as presched_fd: - create_fusion_func(inputs)(presched_fd) - - _, nvf_time_ms = run_profile(presched_fd, inputs) - nvf_perfs.append(nvf_time_ms) - print(f"{empirical_batch_size}, {hidden_shape}, {nvf_time_ms:.3f}") - -# Get mean speed-up from nvfuser to empirical configurations across all input shapes. -# Negative value mean empirical configurations are slower than nvfuser. 
-print("Mean speed-up", np.mean(np.array(nvf_perfs) - np.array(est_perfs))) - -np_hidden_size = np.array(empirical_hidden_sizes) -plt.plot(np_hidden_size, np.array(est_perfs)) -plt.plot(np_hidden_size, np.array(nvf_perfs)) - -plt.xlabel("Hidden Size") -plt.ylabel("Time(ms)") -plt.title( - f"Batch Size = {empirical_batch_size}, Compare Decision Tree Heuristic vs NvFuser" -) -plt.legend(["decision_tree", "nvfuser"], loc="lower right") -plt.savefig(f"pointwise_empirical_batchsize{empirical_batch_size}.png") -# ============================================================================= +if __name__ == "__main__": + main() From 6fb68320e333bbe374e34f16ad251af2f453bac2 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Wed, 20 Nov 2024 10:51:22 -0800 Subject: [PATCH 3/8] Create 2d inner reduction autotuning script --- .../autotune_inner_reduction.py | 406 ++++++++++++++++++ 1 file changed, 406 insertions(+) create mode 100644 doc/dev/python_scheduling/autotune_inner_reduction.py diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py new file mode 100644 index 00000000000..366118d34a7 --- /dev/null +++ b/doc/dev/python_scheduling/autotune_inner_reduction.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch +import itertools +import math +from nvfuser import FusionDefinition, SchedulerType, DataType, ParallelType +from enum import Enum +from dataclasses import dataclass + +from autotune_utils import ( + ScriptConfiguration, + collect_data, + separate_data, + test_model_rmse, + test_model, +) + + +# ============================ Description ============================ + +# This script defines four inner reduction fusions: +# +# 1. Inner Sum +# y = sum(x, dim=-1) +# +# 2. Add Sum +# z = sum(x1 + x2 + x3 + x4, dim=-1) +# +# 3. Tanh Sum +# y = sum(tanh(x), dim=-1) +# +# 4. Exp Sum +# z = sum(exp(x), dim=-1) +# +# Script Sequence: +# +# 1. Define a nvfuser fusion and its pytorch eager mode reference. +# +# 2. Profile the CUDA kernel performance by iterating over a set of input +# arguments and scheduler configurations. +# +# 3. Train a regression model to predict the desired performance metric given +# some input arguments and a scheduler configuration. +# +# 4. Measure the performance of the regression model. +# - Calculate RMSE of predicted and actual performance on test set. +# - Find the configuration with the best performance using regression model. +# Then, compare against the heuristic configuration selected by nvfuser. +# - For a specific batch size, gather performance across a range of hidden +# sizes. Calculate performance for best predicted and nvfuser +# configurations. Plot a chart comparing performance using matplotlib. +# +# The selected performance metric is effective_bandwidth_gbs. The empirical +# scheduler selects the configuration that has the highest predicted +# effective_bandwidth_gbs. + +# ============================ Configurations ============================ + +# gpu device properties are defined globally +assert torch.cuda.is_available() +gpu_properties = torch.cuda.get_device_properties(device=0) + + +# Returns the result of a/b rounded to the nearest integer in the direction of +# positive infinity. 
+def ceil_div(a, b): + return int(math.ceil(a / b)) + + +# Returns the result of a/b rounded to the nearest integer in the direction of +# negative infinity. +def floor_div(a, b): + return int(math.floor(a / b)) + + +class AutotuneInnerReduction: + class FUSION(Enum): + INNER_SUM = 1 + ADD_SUM = 2 + TANH_SUM = 3 + EXP_SUM = 4 + + @dataclass(unsafe_hash=True) + class InnerReductionConfiguration: + vectorize_factor: int = 1 + unroll_factor: int = 1 + godim: int = -1 + grdim: int = -1 + bdimx: int = -1 + bdimy: int = -1 + + def __init__(self, selected_fusion): + self.selected_fusion = selected_fusion + + def __repr__(self): + return f"inner_reduction_{self.selected_fusion.name}" + + def convert_to_inner_reduction_params(self, scheduler_config, reduction_params): + warp_size = 32 + max_number_of_threads_cta = 1024 + grid_x_limit = 2147483647 + grid_y_limit = 65535 + + reduction_params.schedule_3D = False + reduction_params.fastest_dim = True + reduction_params.cross_block_inner_reduction = True + reduction_params.block_dim_inner_reduction = ParallelType.block_x + reduction_params.cross_grid_inner_reduction = scheduler_config.grdim > 1 + reduction_params.multiple_reds_per_blk = scheduler_config.bdimy > 1 + reduction_params.pad_inner_reduction_to_warp = ( + scheduler_config.bdimx > warp_size + ) and ( + (scheduler_config.bdimx * scheduler_config.bdimy) + < max_number_of_threads_cta + ) + reduction_params.unroll_factor_inner_reduction = ( + scheduler_config.vectorize_factor + ) + reduction_params.vectorize_inner_reduction = ( + scheduler_config.vectorize_factor > 1 + ) + + if scheduler_config.bdimy > 1: + reduction_params.block_dim_iter_dom = ParallelType.block_y + + reduction_params.unroll_factor_iter_dom = scheduler_config.unroll_factor + + gdimx = -1 + gdimy = -1 + + if scheduler_config.grdim > 1: + reduction_params.grid_dim_inner_reduction = ParallelType.grid_x + reduction_params.grid_dim_iter_dom = ParallelType.grid_y + + reduction_params.split_grid_dim_iter_dom_inner = True + gdimx = min(scheduler_config.grdim, grid_x_limit) + gdimy = min(scheduler_config.godim, grid_y_limit) + if scheduler_config.godim > grid_y_limit: + reduction_params.split_grid_dim_iter_dom_outer = True + else: + reduction_params.grid_dim_iter_dom = ParallelType.grid_x + gdimx = min(scheduler_config.godim, grid_x_limit) + if scheduler_config.godim > grid_x_limit: + reduction_params.split_grid_dim_inner_reduction = True + + reduction_params.lparams.gdimx = gdimx + reduction_params.lparams.gdimy = gdimy + + # Reset CTA dimensions to avoid failing LaunchParams::assertValid + reduction_params.lparams.bdimx = -1 + reduction_params.lparams.bdimy = -1 + reduction_params.lparams.bdimz = -1 + + reduction_params.lparams.bdimx = scheduler_config.bdimx + reduction_params.lparams.bdimy = scheduler_config.bdimy + + # For reduction scheduler, we test the cartesian product of vectorization and + # unroll factors. 
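+    # In this version of the script, the sweep below covers 4 thread counts x
+    # 4 vectorization widths x 10 unroll factors (160 base combinations per
+    # shape); combinations with unroll_factor == 1 may additionally yield a
+    # grid-strided (grdim > 1) variant.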
+ def generate_scheduler_configurations(self, input_shape): + threads_per_cta_options = [128, 256, 512, 1024] + vectorization_factor_options = [1, 2, 4, 8] + unroll_factor_options = list(range(1, 11)) + warp_size = 32 + + num_iterations, num_reductions = input_shape + + for threads_per_cta, vectorize_factor, unroll_factor in itertools.product( + threads_per_cta_options, vectorization_factor_options, unroll_factor_options + ): + scheduler_config = self.InnerReductionConfiguration( + vectorize_factor=vectorize_factor, unroll_factor=unroll_factor + ) + scheduler_config.bdimx = min( + threads_per_cta, + max( + warp_size, + ceil_div(num_reductions, scheduler_config.vectorize_factor), + ), + ) + scheduler_config.bdimy = min( + threads_per_cta, + max(1, floor_div(threads_per_cta, scheduler_config.bdimx)), + ) + scheduler_config.godim = ceil_div( + num_iterations, scheduler_config.bdimy * scheduler_config.unroll_factor + ) + + # number of reduction elements not handled by a CTA + remaining_reduction = ceil_div( + num_reductions, + (scheduler_config.bdimx * scheduler_config.vectorize_factor), + ) + + if unroll_factor == 1 and remaining_reduction > 1: + # all remaining reduction goes to grdim + scheduler_config.grdim = remaining_reduction + yield scheduler_config + + # grid stride across reduction iterDomain is 1 + scheduler_config.grdim = 1 + yield scheduler_config + + def create_inputs(self, shape, tensor_datatype): + def inner_fn(num_inputs): + return [ + torch.randn(*shape, dtype=tensor_datatype, device="cuda") + for _ in range(num_inputs) + ] + + if self.selected_fusion == self.FUSION.ADD_SUM: + return inner_fn(num_inputs=4) + elif self.selected_fusion in [ + self.FUSION.INNER_SUM, + self.FUSION.TANH_SUM, + self.FUSION.EXP_SUM, + ]: + return inner_fn(num_inputs=1) + else: + assert False + + # A decorator to create a reduction fusion given some input arguments. 
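+    # Every variant below reduces over dims=[1] (the contiguous inner axis),
+    # matching the torch.sum(..., dim=-1) calls in eager_reference.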
+    def create_fusion_func(self, inputs):
+        def sum_fusion(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.ops.cast(T0, dtype=DataType.Float)
+            T2 = fd.ops.sum(T1, dims=[1], keepdim=False, dtype=DataType.Null)
+            T3 = fd.ops.cast(T2, dtype=DataType.BFloat16)
+            fd.add_output(T3)
+
+        def add_sum(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T2 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T3 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T4 = fd.ops.cast(T0, dtype=DataType.Float)
+            T5 = fd.ops.cast(T1, dtype=DataType.Float)
+            T6 = fd.ops.add(T4, T5)
+            T7 = fd.ops.cast(T2, dtype=DataType.Float)
+            T8 = fd.ops.add(T6, T7)
+            T9 = fd.ops.cast(T3, dtype=DataType.Float)
+            T10 = fd.ops.add(T8, T9)
+            T11 = fd.ops.sum(T10, dims=[1], keepdim=False, dtype=DataType.Null)
+            T12 = fd.ops.cast(T11, dtype=DataType.BFloat16)
+            fd.add_output(T12)
+
+        def tanh_sum(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.ops.cast(T0, dtype=DataType.Float)
+            T2 = fd.ops.tanh(T1)
+            T3 = fd.ops.sum(T2, dims=[1], keepdim=False, dtype=DataType.Null)
+            T4 = fd.ops.cast(T3, dtype=DataType.BFloat16)
+            fd.add_output(T4)
+
+        def exp_sum(fd: FusionDefinition) -> None:
+            T0 = fd.define_tensor(
+                shape=[-1, -1],
+                contiguity=[True, True],
+                dtype=DataType.BFloat16,
+                is_cpu=False,
+                stride_order=[1, 0],
+            )
+            T1 = fd.ops.cast(T0, dtype=DataType.Float)
+            T2 = fd.ops.exp(T1)
+            T3 = fd.ops.sum(T2, dims=[1], keepdim=False, dtype=DataType.Null)
+            T4 = fd.ops.cast(T3, dtype=DataType.BFloat16)
+            fd.add_output(T4)
+
+        if self.selected_fusion == self.FUSION.INNER_SUM:
+            return sum_fusion
+        elif self.selected_fusion == self.FUSION.ADD_SUM:
+            return add_sum
+        elif self.selected_fusion == self.FUSION.TANH_SUM:
+            return tanh_sum
+        elif self.selected_fusion == self.FUSION.EXP_SUM:
+            return exp_sum
+        else:
+            assert False
+
+    # The pytorch eager mode reference used to validate the nvfuser kernel.
+ def eager_reference(self, inputs): + def sum_fusion(inputs): + return torch.sum(inputs[0], dim=-1) + + def add_sum(inputs): + return torch.sum(inputs[0] + inputs[1] + inputs[2] + inputs[3], dim=-1) + + def tanh_sum(inputs): + return torch.sum(torch.tanh(inputs[0]), dim=-1) + + def exp_sum(inputs): + return torch.sum(torch.exp(inputs[0]), dim=-1) + + if self.selected_fusion == self.FUSION.INNER_SUM: + return sum_fusion(inputs) + elif self.selected_fusion == self.FUSION.ADD_SUM: + return add_sum(inputs) + elif self.selected_fusion == self.FUSION.TANH_SUM: + return tanh_sum(inputs) + elif self.selected_fusion == self.FUSION.EXP_SUM: + return exp_sum(inputs) + else: + assert False + + # Apply scheduler with custom parameters using decorator + def custom_scheduler(self, fd, scheduler_config): + def inner_fn(): + # Check if compatible with reduction scheduler + status, _ = fd.sched.can_schedule(SchedulerType.reduction) + assert status + + reduction_params = fd.sched.compute_reduction_heuristics() + + # Modify original parameters + if scheduler_config is not None: + self.convert_to_inner_reduction_params( + scheduler_config, reduction_params + ) + + # Schedule fusion + fd.sched.schedule() + + fd.schedule = inner_fn + return fd + + +# Run sequence of steps to collect data, train and test model +def main(): + # ====================== Setup Script Configuration ======================= + script_config = ScriptConfiguration( + num_dimensions=2, + outer_shapes=[16384], + inner_shapes=[128, 1024, 4096, 16384], + tensor_datatype=torch.bfloat16, + test_data_percentage=0.1, + empirical_batch_size=16384, + empirical_hidden_sizes=list(range(256, 32768, 256)), + ) + + autotune_config = AutotuneInnerReduction( + selected_fusion=AutotuneInnerReduction.FUSION.INNER_SUM + ) + + # ============================ Run Experiments ============================ + + parameters, performance = collect_data(script_config, autotune_config) + + # ============================ Separate Data ============================== + + train_data, test_data = separate_data(script_config, parameters, performance) + + # ========================= Train Regression Models ======================= + + # Apply machine learning regressor + # Given input shapes and scheduler parameters, predict performance metric. 
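+    # Each training sample is the input shape followed by the flattened
+    # scheduler configuration, as assembled in collect_data; for example
+    # (illustrative values), shape (16384, 4096) with a flattened configuration
+    # (8, 2, ...) becomes the feature vector [16384, 4096, 8, 2, ...].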
+ from sklearn import ensemble + + train_inputs, train_perf = train_data + clf = ensemble.RandomForestRegressor() + clf = clf.fit(train_inputs, train_perf) + + # ========================= Test Regression Models ======================== + test_model_rmse(clf, script_config, autotune_config, test_data) + test_model(clf, script_config, autotune_config) + + +if __name__ == "__main__": + main() From a265617b427aef85e186507354c4795e4499e02d Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Sun, 24 Nov 2024 11:48:57 -0800 Subject: [PATCH 4/8] refactor --- .../autotune_inner_reduction.py | 27 ++++++------------- doc/dev/python_scheduling/autotune_utils.py | 13 +++++++++ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py index 366118d34a7..e714cd238ee 100644 --- a/doc/dev/python_scheduling/autotune_inner_reduction.py +++ b/doc/dev/python_scheduling/autotune_inner_reduction.py @@ -5,7 +5,6 @@ import torch import itertools -import math from nvfuser import FusionDefinition, SchedulerType, DataType, ParallelType from enum import Enum from dataclasses import dataclass @@ -16,10 +15,12 @@ separate_data, test_model_rmse, test_model, + ceil_div, + floor_div, ) -# ============================ Description ============================ +# ================================ Description ================================ # This script defines four inner reduction fusions: # @@ -57,23 +58,7 @@ # scheduler selects the configuration that has the highest predicted # effective_bandwidth_gbs. -# ============================ Configurations ============================ - -# gpu device properties are defined globally -assert torch.cuda.is_available() -gpu_properties = torch.cuda.get_device_properties(device=0) - - -# Returns the result of a/b rounded to the nearest integer in the direction of -# positive infinity. -def ceil_div(a, b): - return int(math.ceil(a / b)) - - -# Returns the result of a/b rounded to the nearest integer in the direction of -# negative infinity. -def floor_div(a, b): - return int(math.floor(a / b)) +# ============================================================================= class AutotuneInnerReduction: @@ -95,6 +80,10 @@ class InnerReductionConfiguration: def __init__(self, selected_fusion): self.selected_fusion = selected_fusion + # gpu device properties are defined globally + assert torch.cuda.is_available() + self.gpu_properties = torch.cuda.get_device_properties(device=0) + def __repr__(self): return f"inner_reduction_{self.selected_fusion.name}" diff --git a/doc/dev/python_scheduling/autotune_utils.py b/doc/dev/python_scheduling/autotune_utils.py index 4017c6c87f8..e699f84270e 100644 --- a/doc/dev/python_scheduling/autotune_utils.py +++ b/doc/dev/python_scheduling/autotune_utils.py @@ -4,6 +4,7 @@ # Owner(s): ["module: nvfuser"] import torch +import math import itertools from nvfuser import FusionCache, FusionDefinition from dataclasses import dataclass, astuple @@ -13,6 +14,18 @@ # ============================================================================= +# Returns the result of a/b rounded to the nearest integer in the direction of +# positive infinity. +def ceil_div(a, b): + return int(math.ceil(a / b)) + + +# Returns the result of a/b rounded to the nearest integer in the direction of +# negative infinity. 
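+# For example, floor_div(7, 2) == 3.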
+def floor_div(a, b):
+    return int(math.floor(a / b))
+
+
 @dataclass
 class ScriptConfiguration:
     # Settings for input tensor generation

From e7ffb29abc9e0040140c46bf8feae3569d4e7886 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Tue, 26 Nov 2024 17:09:45 -0800
Subject: [PATCH 5/8] comments

---
 doc/dev/python_scheduling/autotune_inner_reduction.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py
index e714cd238ee..c43a20f5767 100644
--- a/doc/dev/python_scheduling/autotune_inner_reduction.py
+++ b/doc/dev/python_scheduling/autotune_inner_reduction.py
@@ -70,11 +70,21 @@ class FUSION(Enum):
 
     @dataclass(unsafe_hash=True)
     class InnerReductionConfiguration:
+        # The vectorization factor for the inner reduction domain.
         vectorize_factor: int = 1
+        # The unroll factor for the outer iteration domain.
         unroll_factor: int = 1
+        # The grid size for the outer iteration domain.
+        # If grdim > 1, then godim corresponds to the y axis of the grid.
+        # Otherwise, it corresponds to the x axis of the grid.
+        godim: int = -1
+        # The grid size for the inner reduction domain. It corresponds
+        # to the x axis of the grid when it is > 1.
+        grdim: int = -1
+        # The x axis of the CTA. It corresponds to the inner reduction domain.
+        bdimx: int = -1
+        # The y axis of the CTA. It corresponds to the outer iteration domain.
+        # If it is greater than one, then there are multiple reductions per CTA.
+        bdimy: int = -1
 
     def __init__(self, selected_fusion):

From fdcf6a548ed43be9c6df4e853fdd4f0515a6ec08 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Tue, 26 Nov 2024 17:18:04 -0800
Subject: [PATCH 6/8] add reduction_unroll_factor

---
 .../autotune_inner_reduction.py               | 36 ++++++++++++++-----
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py
index c43a20f5767..92430b6a0aa 100644
--- a/doc/dev/python_scheduling/autotune_inner_reduction.py
+++ b/doc/dev/python_scheduling/autotune_inner_reduction.py
@@ -72,8 +72,10 @@ class FUSION(Enum):
     class InnerReductionConfiguration:
         # The vectorization factor for the inner reduction domain.
         vectorize_factor: int = 1
+        # The unroll factor for the inner reduction domain.
+        reduction_unroll_factor: int = 1
         # The unroll factor for the outer iteration domain.
-        unroll_factor: int = 1
+        iteration_unroll_factor: int = 1
         # The grid size for the outer iteration domain.
         # If grdim > 1, then godim corresponds to the y axis of the grid.
         # Otherwise, it corresponds to the x axis of the grid.
@@ -121,11 +123,16 @@ def convert_to_inner_reduction_params(self, scheduler_config, reduction_params): reduction_params.vectorize_inner_reduction = ( scheduler_config.vectorize_factor > 1 ) + reduction_params.unroll_factor_top_of_vectorization = ( + scheduler_config.reduction_unroll_factor + ) if scheduler_config.bdimy > 1: reduction_params.block_dim_iter_dom = ParallelType.block_y - reduction_params.unroll_factor_iter_dom = scheduler_config.unroll_factor + reduction_params.unroll_factor_iter_dom = ( + scheduler_config.iteration_unroll_factor + ) gdimx = -1 gdimy = -1 @@ -161,16 +168,27 @@ def convert_to_inner_reduction_params(self, scheduler_config, reduction_params): def generate_scheduler_configurations(self, input_shape): threads_per_cta_options = [128, 256, 512, 1024] vectorization_factor_options = [1, 2, 4, 8] - unroll_factor_options = list(range(1, 11)) + reduction_unroll_factor_options = list(range(1, 6)) + iteration_unroll_factor_options = list(range(1, 6)) warp_size = 32 num_iterations, num_reductions = input_shape - for threads_per_cta, vectorize_factor, unroll_factor in itertools.product( - threads_per_cta_options, vectorization_factor_options, unroll_factor_options + for ( + threads_per_cta, + vectorize_factor, + reduction_unroll_factor, + iteration_unroll_factor, + ) in itertools.product( + threads_per_cta_options, + vectorization_factor_options, + reduction_unroll_factor_options, + iteration_unroll_factor_options, ): scheduler_config = self.InnerReductionConfiguration( - vectorize_factor=vectorize_factor, unroll_factor=unroll_factor + vectorize_factor=vectorize_factor, + reduction_unroll_factor=reduction_unroll_factor, + iteration_unroll_factor=iteration_unroll_factor, ) scheduler_config.bdimx = min( threads_per_cta, @@ -184,16 +202,16 @@ def generate_scheduler_configurations(self, input_shape): max(1, floor_div(threads_per_cta, scheduler_config.bdimx)), ) scheduler_config.godim = ceil_div( - num_iterations, scheduler_config.bdimy * scheduler_config.unroll_factor + num_iterations, scheduler_config.bdimy * iteration_unroll_factor ) # number of reduction elements not handled by a CTA remaining_reduction = ceil_div( num_reductions, - (scheduler_config.bdimx * scheduler_config.vectorize_factor), + (scheduler_config.bdimx * vectorize_factor * reduction_unroll_factor), ) - if unroll_factor == 1 and remaining_reduction > 1: + if iteration_unroll_factor == 1 and remaining_reduction > 1: # all remaining reduction goes to grdim scheduler_config.grdim = remaining_reduction yield scheduler_config From fc8e57de42600456d80772b80e71e4d8c581b267 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 13 Dec 2024 09:25:44 -0800 Subject: [PATCH 7/8] update remaining_reduction calculation --- doc/dev/python_scheduling/autotune_inner_reduction.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py index 92430b6a0aa..e06de2566b3 100644 --- a/doc/dev/python_scheduling/autotune_inner_reduction.py +++ b/doc/dev/python_scheduling/autotune_inner_reduction.py @@ -207,8 +207,10 @@ def generate_scheduler_configurations(self, input_shape): # number of reduction elements not handled by a CTA remaining_reduction = ceil_div( - num_reductions, - (scheduler_config.bdimx * vectorize_factor * reduction_unroll_factor), + ceil_div( + ceil_div(num_reductions, vectorize_factor), scheduler_config.bdimx + ), + reduction_unroll_factor, ) if iteration_unroll_factor == 1 and remaining_reduction 
> 1: From 5ae857d393e8cc3b250166fc1634ab55624d45dc Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 13 Dec 2024 10:08:36 -0800 Subject: [PATCH 8/8] update grdim --- .../python_scheduling/autotune_inner_reduction.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py index e06de2566b3..260c0dab6ca 100644 --- a/doc/dev/python_scheduling/autotune_inner_reduction.py +++ b/doc/dev/python_scheduling/autotune_inner_reduction.py @@ -218,6 +218,19 @@ def generate_scheduler_configurations(self, input_shape): scheduler_config.grdim = remaining_reduction yield scheduler_config + # When iteration dim is small, there may be unused SMs. We need + # to shift work from block reduction to grid reduction to + # increase SM usage. + godim = scheduler_config.godim + grdim = 1 + while ( + godim * grdim * 2 <= self.gpu_properties.multi_processor_count + and (remaining_reduction / grdim) >= 2 + ): + grdim *= 2 + scheduler_config.grdim = grdim + yield scheduler_config + # grid stride across reduction iterDomain is 1 scheduler_config.grdim = 1 yield scheduler_config
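+            # For example (hypothetical numbers): on a GPU with 132 SMs, with
+            # godim=8 and remaining_reduction=32, the doubling loop above
+            # yields grdim in {2, 4, 8, 16} before the SM budget
+            # (godim * grdim * 2 <= 132) is exhausted.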