Add enable_options and disable_options to fd.execute #3270

Merged · 13 commits · Nov 13, 2024
106 changes: 68 additions & 38 deletions csrc/options.cpp
@@ -148,53 +148,73 @@ std::unordered_map<DebugDumpOption, std::vector<std::string>> Options<
return parseEnvOptions("DUMP", available_options);
}

const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
static const std::unordered_map<std::string, EnableOption> available_options =
{
{"fuse_matmul", EnableOption::FuseMatmul},
{"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls},
{"id_model", EnableOption::IdModel},
{"kernel_db", EnableOption::KernelDb},
{"kernel_profile", EnableOption::KernelProfile},
{"memory_promotion", EnableOption::MemoryPromotion},
{"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory},
{"static_fusion_count", EnableOption::StaticFusionCount},
{"warn_register_spill", EnableOption::WarnRegisterSpill},
{"io_to_lower_precision", EnableOption::IoToLowerPrecision},
{"kernel_debug", EnableOption::KernelDebug},
{"kernel_lineinfo", EnableOption::KernelLineInfo},
};
return available_options;
}

template <>
std::unordered_map<EnableOption, std::vector<std::string>> Options<
EnableOption>::getOptionsFromEnv() {
const std::unordered_map<std::string, EnableOption> available_options = {
{"fuse_matmul", EnableOption::FuseMatmul},
{"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls},
{"id_model", EnableOption::IdModel},
{"kernel_db", EnableOption::KernelDb},
{"kernel_profile", EnableOption::KernelProfile},
{"memory_promotion", EnableOption::MemoryPromotion},
{"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory},
{"static_fusion_count", EnableOption::StaticFusionCount},
{"warn_register_spill", EnableOption::WarnRegisterSpill},
{"io_to_lower_precision", EnableOption::IoToLowerPrecision},
{"kernel_debug", EnableOption::KernelDebug},
{"kernel_lineinfo", EnableOption::KernelLineInfo},
};

const auto& available_options = getEnableOptions();
return parseEnvOptions("ENABLE", available_options);
}

std::optional<EnableOption> stringToEnableOption(
const std::string& enable_option) {
const auto& opts = getEnableOptions();
auto it = opts.find(enable_option);
if (it != opts.end()) {
return it->second;
}
return std::nullopt;
}

const std::unordered_map<std::string, DisableOption>& getDisableOptions() {
static const std::unordered_map<std::string, DisableOption>
available_options = {
{"compile_to_sass", DisableOption::CompileToSass},
{"contig_indexing", DisableOption::ContigIndexing},
{"expr_simplify", DisableOption::ExprSimplify},
{"fallback", DisableOption::Fallback},
{"fma", DisableOption::Fma},
{"grouped_grid_welford_outer_opt",
DisableOption::GroupedGridWelfordOuterOpt},
{"index_hoist", DisableOption::IndexHoist},
{"magic_zero", DisableOption::MagicZero},
{"matmul_expr_eval", DisableOption::MatmulExprEval},
{"nvtx", DisableOption::Nvtx},
{"parallel_compile", DisableOption::ParallelCompile},
{"parallel_serde", DisableOption::ParallelSerde},
{"predicate_elimination", DisableOption::PredicateElimination},
{"python_inline_definitions", DisableOption::PythonInlineDefinitions},
{"kernel_reuse", DisableOption::KernelReuse},
{"var_name_remapping", DisableOption::VarNameRemapping},
{"welford_vectorization", DisableOption::WelfordVectorization},
{"reuse_mismatched_type_registers",
DisableOption::ReuseMismatchedTypeRegisters},
{"multidevice", DisableOption::Multidevice}};
return available_options;
}

template <>
std::unordered_map<DisableOption, std::vector<std::string>> Options<
DisableOption>::getOptionsFromEnv() {
const std::unordered_map<std::string, DisableOption> available_options = {
{"compile_to_sass", DisableOption::CompileToSass},
{"contig_indexing", DisableOption::ContigIndexing},
{"expr_simplify", DisableOption::ExprSimplify},
{"fallback", DisableOption::Fallback},
{"fma", DisableOption::Fma},
{"grouped_grid_welford_outer_opt",
DisableOption::GroupedGridWelfordOuterOpt},
{"index_hoist", DisableOption::IndexHoist},
{"magic_zero", DisableOption::MagicZero},
{"matmul_expr_eval", DisableOption::MatmulExprEval},
{"nvtx", DisableOption::Nvtx},
{"parallel_compile", DisableOption::ParallelCompile},
{"parallel_serde", DisableOption::ParallelSerde},
{"predicate_elimination", DisableOption::PredicateElimination},
{"python_inline_definitions", DisableOption::PythonInlineDefinitions},
{"kernel_reuse", DisableOption::KernelReuse},
{"var_name_remapping", DisableOption::VarNameRemapping},
{"welford_vectorization", DisableOption::WelfordVectorization},
{"reuse_mismatched_type_registers",
DisableOption::ReuseMismatchedTypeRegisters},
{"multidevice", DisableOption::Multidevice}};

const auto& available_options = getDisableOptions();
auto options = parseEnvOptions("DISABLE", available_options);

if (options.count(DisableOption::Fma)) {
@@ -205,6 +225,16 @@ std::unordered_map<DisableOption, std::vector<std::string>> Options<
return options;
}

std::optional<DisableOption> stringToDisableOption(
const std::string& disable_option) {
const auto& opts = getDisableOptions();
auto it = opts.find(disable_option);
if (it != opts.end()) {
return it->second;
}
return std::nullopt;
}

template <>
std::unordered_map<ProfilerOption, std::vector<std::string>> Options<
ProfilerOption>::getOptionsFromEnv() {
6 changes: 6 additions & 0 deletions csrc/options.h
@@ -250,6 +250,9 @@ NVF_API std::unordered_map<EnableOption, std::vector<std::string>> Options<

using EnableOptions = Options<EnableOption>;

std::optional<EnableOption> stringToEnableOption(
const std::string& enable_option);

bool isOptionEnabled(EnableOption option);

const std::vector<std::string>& getEnableOptionArguments(EnableOption option);
@@ -268,6 +271,9 @@ NVF_API std::unordered_map<DisableOption, std::vector<std::string>> Options<

using DisableOptions = Options<DisableOption>;

std::optional<DisableOption> stringToDisableOption(
const std::string& disable_option);

NVF_API bool isOptionDisabled(DisableOption option);

const std::vector<std::string>& getDisableOptionArguments(DisableOption option);
19 changes: 18 additions & 1 deletion csrc/python_frontend/fusion_definition.cpp
@@ -337,7 +337,9 @@ std::vector<at::Tensor> FusionDefinition::execute(
std::optional<int8_t> selected_device,
bool override_user_schedule,
bool capture_debug_output,
bool profile) const {
bool profile,
std::vector<std::string> _enable_options,
std::vector<std::string> _disable_options) const {
debug_output_ = std::nullopt;
std::stringstream debug_ss;
DebugStreamGuard dsg(capture_debug_output ? debug_ss : std::cout);
@@ -351,6 +353,21 @@ std::vector<at::Tensor> FusionDefinition::execute(
ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable);
}

EnableOptionsGuard enable_opt_guard;
for (const auto& _enable_option : _enable_options) {
std::optional<EnableOption> opt = stringToEnableOption(_enable_option);
NVF_CHECK(opt.has_value(), "Unrecognized enable_option: ", _enable_option);
EnableOptionsGuard::getCurOptions().set(opt.value());
}

DisableOptionsGuard disable_opt_guard;
for (const auto& _disable_option : _disable_options) {
std::optional<DisableOption> opt = stringToDisableOption(_disable_option);
NVF_CHECK(
opt.has_value(), "Unrecognized disable_option: ", _disable_option);
DisableOptionsGuard::getCurOptions().set(opt.value());
}

if (!override_user_schedule) {
auto device = getCommonDeviceCUDA(inputs, selected_device);
NVF_CHECK(
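Both validation loops above fail fast on unknown names: stringToEnableOption() and stringToDisableOption() return std::nullopt for any string missing from the tables in options.cpp, and NVF_CHECK turns that into an error carrying the offending name. A minimal sketch of how this failure mode surfaces in the Python frontend (assuming, as is typical when NVF_CHECK failures cross the bindings, that the error arrives as a RuntimeError):

import torch
from nvfuser import FusionDefinition

inputs = [torch.ones(2, 2, device="cuda")]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])
    fd.add_output(fd.ops.relu(t0))

try:
    # "no_such_option" is not in getEnableOptions(), so the NVF_CHECK in
    # FusionDefinition::execute rejects it before any scheduling or
    # compilation happens for this call.
    fd.execute(inputs, _enable_options=["no_such_option"])
except RuntimeError as err:
    assert "Unrecognized enable_option" in str(err)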
4 changes: 3 additions & 1 deletion csrc/python_frontend/fusion_definition.h
@@ -189,7 +189,9 @@ class NVF_API FusionDefinition : public FusionState {
std::optional<int8_t> device,
bool override_user_schedule,
bool capture_debug_output,
bool profile) const;
bool profile,
std::vector<std::string> _enable_options,
std::vector<std::string> _disable_options) const;
//! Return debugging output captured through execution with
//! capture_debug_output=true
std::optional<std::string> getDebugOutput() const {
10 changes: 8 additions & 2 deletions csrc/python_frontend/python_bindings.cpp
@@ -1016,7 +1016,9 @@ void initNvFuserPythonBindings(PyObject* module) {
std::optional<int64_t> device,
bool override_user_schedule,
bool capture_debug_output,
bool profile) {
bool profile,
std::vector<std::string> _enable_options,
std::vector<std::string> _disable_options) {
std::vector<c10::IValue> inputs;
for (py::handle obj : iter) {
// Allows for a vector of sizes to be input as a list/tuple
@@ -1041,14 +1043,18 @@
int8_device,
override_user_schedule,
capture_debug_output,
profile);
profile,
_enable_options,
_disable_options);
},
py::arg("inputs"),
py::kw_only(),
py::arg("device") = py::none(),
py::arg("override_user_schedule") = false,
py::arg("capture_debug_output") = false,
py::arg("profile") = false,
py::arg("_enable_options") = py::none(),
py::arg("_disable_options") = py::none(),
py::return_value_policy::reference)
.def_static(
"_profile",
16 changes: 16 additions & 0 deletions nvfuser/__init__.py
@@ -7,6 +7,7 @@
import re
import sys
from typing import Callable, Optional, Union, List # noqa: F401
import warnings

import torch

@@ -77,6 +78,8 @@ def execute(
print_repro=False,
profile=False,
save_repro_inputs=False,
_enable_options: list[str] = [],
_disable_options: list[str] = [],
):
"""
Executes an nvFuser set of kernels for a given Fusion
@@ -119,6 +122,11 @@
profile (bool): Captures a CUPTI based profile of a fusion.
save_repro_inputs (bool): Saves the inputs for last_repro_script() to
provide a reproduction script.
_enable_options/_disable_options (list): NVFUSER_ENABLE/DISABLE options to use,
as an alternative to setting the corresponding environment variables.
Note: Currently, these options are not cached/stored in the FusionCache, so it is
possible for a kernel compiled under one set of options to be reused when the same
fusion definition is executed with a different set.
Reset the FusionCache manually to avoid inadvertent kernel reuse when switching between different sets of options.

Returns:
List[Tensor]
@@ -176,15 +184,23 @@ def execute(
self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

results = None

try:
if print_repro:
print(self.repro_script_for(inputs))
if len(_enable_options) or len(_disable_options):
warnings.warn(
"Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
)

results = self._execute(
inputs,
device=device,
override_user_schedule=override_user_schedule,
capture_debug_output=capture_debug_output,
profile=profile,
_enable_options=_enable_options,
_disable_options=_disable_options,
[Review comment from the collaborator author on lines +191 to +203: As discussed offline, moving forward with renaming the kwargs to _enable/_disable_options. I added a warning about the possibility of kernel reuse so that we are aware of this when utilizing this feature.]

)
return results
except Exception as err:
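To make the new keyword arguments concrete, here is a minimal usage sketch. It assumes a CUDA device and a build containing this PR; the option strings come from the tables in options.cpp, and per the docstring above they act as per-call alternatives to the NVFUSER_ENABLE/NVFUSER_DISABLE environment variables ("nvtx" and "id_model" are chosen purely as illustrative entries from those tables):

import torch
from nvfuser import FusionDefinition, FusionCache

inputs = [
    torch.randn(4, 4, device="cuda"),
    torch.randn(4, 4, device="cuda"),
]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])
    t1 = fd.from_pytorch(inputs[1])
    fd.add_output(fd.ops.add(t0, t1))

# Per-call equivalent of NVFUSER_DISABLE=nvtx; the EnableOptionsGuard and
# DisableOptionsGuard locals in FusionDefinition::execute scope the
# overrides to this call.
out = fd.execute(inputs, _disable_options=["nvtx"])

# The options are not part of the FusionCache key, so reset the cache
# manually before re-running the same definition with different options,
# as the warning above advises.
FusionCache.reset()
out = fd.execute(inputs, _enable_options=["id_model"])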
41 changes: 41 additions & 0 deletions tests/python/test_python_frontend.py
@@ -4709,3 +4709,44 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T223)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

def test_enable_disable_options(self):
m = 24
n = 16
k = 8
inps = [
torch.randn(m, k, device="cuda", dtype=torch.float),
torch.randn(k, n, device="cuda", dtype=torch.float),
]

def fusion_func(fd: FusionDefinition, inps) -> None:
t0 = fd.from_pytorch(inps[0])
t1 = fd.from_pytorch(inps[1])
t2 = fd.ops.matmul(t0, t1)
fd.add_output(t2)

with FusionDefinition() as fd:
fusion_func(fd, inps=inps)

# By default, matmul is run through the expr_eval scheduler.
# By setting the enable and disable options as below, we can execute it
# through the matmul scheduler instead. The matmul scheduler will not accept
# the above fusion since its outputs are of type Float, which raises a RuntimeError.
# Note: We use this error-based test because, for compatible dtypes
# (float16/bfloat16), the matmul scheduler ran into a scheduling error on H100.
# In the interim, this test should be more robust against changes to the
# matmul scheduler.

self.assertRaisesRegex(
RuntimeError,
"Can not find a scheduler to schedule fusion segment",
self.exec_nvfuser,
partial(fusion_func, inps=inps),
inps,
_enable_options=["fuse_matmul"],
_disable_options=["matmul_expr_eval"],
skip_serde_check=True,
)

# Serializing error test cases corrupts the serialized binary, causing subsequent tests to fail.
# Reset the fusion cache to avoid this.
FusionCache.reset()
9 changes: 8 additions & 1 deletion tests/python/utils.py
@@ -416,6 +416,8 @@ def exec_nvfuser(
fusion_func,
inputs,
*,
_enable_options=[],
_disable_options=[],
new_fusion_expected=True,
device=None,
is_clonable=True,
@@ -432,7 +434,12 @@
with FusionDefinition() as fd:
fusion_func(fd)
torch.manual_seed(0)
out = fd.execute(inputs, device=device)
out = fd.execute(
inputs,
device=device,
_enable_options=_enable_options,
_disable_options=_disable_options,
)

self.assertTrue(
check_captured_python_definition(out, fd, inputs_captured, device)
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
0.2.22
0.2.23