Skip to content

Commit

Permalink
Add enable_options and disable_options to fd.execute (#3270)
Browse files Browse the repository at this point in the history
This PR adds `enable_options` and `disable_options` to `fd.execute` to
allow setting them through the python frontend in lieu of the
environment variables.
This work will be used to allow enabling nvfuser matmul codegen from
within Thunder.

Inspired by @jacobhinkle's PR #1905!

Tracking Issue: #3022

---------

Co-authored-by: Ryan Spring <[email protected]>
  • Loading branch information
Priya2698 and rdspring1 authored Nov 13, 2024
1 parent 2fb5539 commit 7212038
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 44 deletions.
106 changes: 68 additions & 38 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,53 +148,73 @@ std::unordered_map<DebugDumpOption, std::vector<std::string>> Options<
return parseEnvOptions("DUMP", available_options);
}

const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
static const std::unordered_map<std::string, EnableOption> available_options =
{
{"fuse_matmul", EnableOption::FuseMatmul},
{"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls},
{"id_model", EnableOption::IdModel},
{"kernel_db", EnableOption::KernelDb},
{"kernel_profile", EnableOption::KernelProfile},
{"memory_promotion", EnableOption::MemoryPromotion},
{"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory},
{"static_fusion_count", EnableOption::StaticFusionCount},
{"warn_register_spill", EnableOption::WarnRegisterSpill},
{"io_to_lower_precision", EnableOption::IoToLowerPrecision},
{"kernel_debug", EnableOption::KernelDebug},
{"kernel_lineinfo", EnableOption::KernelLineInfo},
};
return available_options;
}

// Parses the NVFUSER_ENABLE environment variable into a map from each
// enabled option to its (possibly empty) list of string arguments.
//
// Fix: the block contained two conflicting declarations of
// `available_options` — a stale inline copy of the option table on top of
// the `getEnableOptions()` lookup — which is a redefinition error and lets
// the two tables drift apart. Only the shared table is kept.
template <>
std::unordered_map<EnableOption, std::vector<std::string>> Options<
    EnableOption>::getOptionsFromEnv() {
  // Single source of truth, shared with stringToEnableOption().
  const auto& available_options = getEnableOptions();
  return parseEnvOptions("ENABLE", available_options);
}

// Translates a user-facing option name (as passed to fd.execute's
// _enable_options) into its EnableOption value, or std::nullopt when the
// name is not a recognized option.
std::optional<EnableOption> stringToEnableOption(
    const std::string& enable_option) {
  const auto& name_to_option = getEnableOptions();
  const auto match = name_to_option.find(enable_option);
  return match == name_to_option.end() ? std::nullopt
                                       : std::make_optional(match->second);
}

// Returns the mapping from user-facing option names to DisableOption values.
// Shared by the NVFUSER_DISABLE env-var parser and stringToDisableOption(),
// so both lookup paths accept the same set of names.
const std::unordered_map<std::string, DisableOption>& getDisableOptions() {
  // Built once on first use (function-local static) and never mutated.
  static const std::unordered_map<std::string, DisableOption>
      available_options = {
          {"compile_to_sass", DisableOption::CompileToSass},
          {"contig_indexing", DisableOption::ContigIndexing},
          {"expr_simplify", DisableOption::ExprSimplify},
          {"fallback", DisableOption::Fallback},
          {"fma", DisableOption::Fma},
          {"grouped_grid_welford_outer_opt",
           DisableOption::GroupedGridWelfordOuterOpt},
          {"index_hoist", DisableOption::IndexHoist},
          {"magic_zero", DisableOption::MagicZero},
          {"matmul_expr_eval", DisableOption::MatmulExprEval},
          {"nvtx", DisableOption::Nvtx},
          {"parallel_compile", DisableOption::ParallelCompile},
          {"parallel_serde", DisableOption::ParallelSerde},
          {"predicate_elimination", DisableOption::PredicateElimination},
          {"python_inline_definitions", DisableOption::PythonInlineDefinitions},
          {"kernel_reuse", DisableOption::KernelReuse},
          {"var_name_remapping", DisableOption::VarNameRemapping},
          {"welford_vectorization", DisableOption::WelfordVectorization},
          {"reuse_mismatched_type_registers",
           DisableOption::ReuseMismatchedTypeRegisters},
          {"multidevice", DisableOption::Multidevice}};
  return available_options;
}

template <>
std::unordered_map<DisableOption, std::vector<std::string>> Options<
DisableOption>::getOptionsFromEnv() {
const std::unordered_map<std::string, DisableOption> available_options = {
{"compile_to_sass", DisableOption::CompileToSass},
{"contig_indexing", DisableOption::ContigIndexing},
{"expr_simplify", DisableOption::ExprSimplify},
{"fallback", DisableOption::Fallback},
{"fma", DisableOption::Fma},
{"grouped_grid_welford_outer_opt",
DisableOption::GroupedGridWelfordOuterOpt},
{"index_hoist", DisableOption::IndexHoist},
{"magic_zero", DisableOption::MagicZero},
{"matmul_expr_eval", DisableOption::MatmulExprEval},
{"nvtx", DisableOption::Nvtx},
{"parallel_compile", DisableOption::ParallelCompile},
{"parallel_serde", DisableOption::ParallelSerde},
{"predicate_elimination", DisableOption::PredicateElimination},
{"python_inline_definitions", DisableOption::PythonInlineDefinitions},
{"kernel_reuse", DisableOption::KernelReuse},
{"var_name_remapping", DisableOption::VarNameRemapping},
{"welford_vectorization", DisableOption::WelfordVectorization},
{"reuse_mismatched_type_registers",
DisableOption::ReuseMismatchedTypeRegisters},
{"multidevice", DisableOption::Multidevice}};

const auto& available_options = getDisableOptions();
auto options = parseEnvOptions("DISABLE", available_options);

if (options.count(DisableOption::Fma)) {
Expand All @@ -205,6 +225,16 @@ std::unordered_map<DisableOption, std::vector<std::string>> Options<
return options;
}

// Translates a user-facing option name (as passed to fd.execute's
// _disable_options) into its DisableOption value, or std::nullopt when the
// name is not a recognized option.
std::optional<DisableOption> stringToDisableOption(
    const std::string& disable_option) {
  const auto& name_to_option = getDisableOptions();
  const auto match = name_to_option.find(disable_option);
  return match == name_to_option.end() ? std::nullopt
                                       : std::make_optional(match->second);
}

template <>
std::unordered_map<ProfilerOption, std::vector<std::string>> Options<
ProfilerOption>::getOptionsFromEnv() {
Expand Down
6 changes: 6 additions & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ NVF_API std::unordered_map<EnableOption, std::vector<std::string>> Options<

using EnableOptions = Options<EnableOption>;

std::optional<EnableOption> stringToEnableOption(
const std::string& enable_option);

bool isOptionEnabled(EnableOption option);

const std::vector<std::string>& getEnableOptionArguments(EnableOption option);
Expand All @@ -268,6 +271,9 @@ NVF_API std::unordered_map<DisableOption, std::vector<std::string>> Options<

using DisableOptions = Options<DisableOption>;

std::optional<DisableOption> stringToDisableOption(
const std::string& disable_option);

NVF_API bool isOptionDisabled(DisableOption option);

const std::vector<std::string>& getDisableOptionArguments(DisableOption option);
Expand Down
19 changes: 18 additions & 1 deletion csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ std::vector<at::Tensor> FusionDefinition::execute(
std::optional<int8_t> selected_device,
bool override_user_schedule,
bool capture_debug_output,
bool profile) const {
bool profile,
std::vector<std::string> _enable_options,
std::vector<std::string> _disable_options) const {
debug_output_ = std::nullopt;
std::stringstream debug_ss;
DebugStreamGuard dsg(capture_debug_output ? debug_ss : std::cout);
Expand All @@ -351,6 +353,21 @@ std::vector<at::Tensor> FusionDefinition::execute(
ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable);
}

EnableOptionsGuard enable_opt_guard;
for (const auto& _enable_option : _enable_options) {
std::optional<EnableOption> opt = stringToEnableOption(_enable_option);
NVF_CHECK(opt.has_value(), "Unrecognized enable_option: ", _enable_option);
EnableOptionsGuard::getCurOptions().set(opt.value());
}

DisableOptionsGuard disable_opt_guard;
for (const auto& _disable_option : _disable_options) {
std::optional<DisableOption> opt = stringToDisableOption(_disable_option);
NVF_CHECK(
opt.has_value(), "Unrecognized disable_option: ", _disable_option);
DisableOptionsGuard::getCurOptions().set(opt.value());
}

if (!override_user_schedule) {
auto device = getCommonDeviceCUDA(inputs, selected_device);
NVF_CHECK(
Expand Down
4 changes: 3 additions & 1 deletion csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ class NVF_API FusionDefinition : public FusionState {
std::optional<int8_t> device,
bool override_user_schedule,
bool capture_debug_output,
bool profile) const;
bool profile,
std::vector<std::string> _enable_options,
std::vector<std::string> _disable_options) const;
//! Return debugging output captured through exeuction with
//! capture_debug_output=true
std::optional<std::string> getDebugOutput() const {
Expand Down
10 changes: 8 additions & 2 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,9 @@ void initNvFuserPythonBindings(PyObject* module) {
std::optional<int64_t> device,
bool override_user_schedule,
bool capture_debug_output,
bool profile) {
bool profile,
std::vector<std::string> _enable_options,
std::vector<std::string> _disable_options) {
std::vector<c10::IValue> inputs;
for (py::handle obj : iter) {
// Allows for a Vector of Sizes to be inputed as a list/tuple
Expand All @@ -1041,14 +1043,18 @@ void initNvFuserPythonBindings(PyObject* module) {
int8_device,
override_user_schedule,
capture_debug_output,
profile);
profile,
_enable_options,
_disable_options);
},
py::arg("inputs"),
py::kw_only(),
py::arg("device") = py::none(),
py::arg("override_user_schedule") = false,
py::arg("capture_debug_output") = false,
py::arg("profile") = false,
py::arg("_enable_options") = py::none(),
py::arg("_disable_options") = py::none(),
py::return_value_policy::reference)
.def_static(
"_profile",
Expand Down
16 changes: 16 additions & 0 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import sys
from typing import Callable, Optional, Union, List # noqa: F401
import warnings

import torch

Expand Down Expand Up @@ -77,6 +78,8 @@ def execute(
print_repro=False,
profile=False,
save_repro_inputs=False,
_enable_options: list[str] = [],
_disable_options: list[str] = [],
):
"""
Executes an nvFuser set of kernels for a given Fusion
Expand Down Expand Up @@ -119,6 +122,11 @@ def execute(
profile (bool): Captures a CUPTI based profile of a fusion.
save_repro_inputs (bool): Saves the inputs for last_repro_script() to
provide a reproduction script.
_enable_options/_disable_options (list): NVFUSER_ENABLE/DISABLE options to use.
This is an alternative to environment variables.
Note: Currently, we do not cache/store these options in the FusionCache which makes it
possible to reuse kernels when executing the same fusion definition with different sets of options.
Reset the FusionCache manually to avoid inadvertent kernel reuse when switching between different sets of options.
Returns:
List[Tensor]
Expand Down Expand Up @@ -176,15 +184,23 @@ def execute(
self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

results = None

try:
if print_repro:
print(self.repro_script_for(inputs))
if len(_enable_options) or len(_disable_options):
warnings.warn(
"Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
)

results = self._execute(
inputs,
device=device,
override_user_schedule=override_user_schedule,
capture_debug_output=capture_debug_output,
profile=profile,
_enable_options=_enable_options,
_disable_options=_disable_options,
)
return results
except Exception as err:
Expand Down
41 changes: 41 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4709,3 +4709,44 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T223)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

def test_enable_disable_options(self):
    """Verify that fd.execute's _enable_options/_disable_options take effect.

    Enabling `fuse_matmul` and disabling `matmul_expr_eval` routes the
    matmul through the matmul scheduler instead of the default expr_eval
    scheduler; the matmul scheduler rejects the Float outputs of this
    fusion, so execution is expected to raise a RuntimeError.
    """
    m = 24
    n = 16
    k = 8
    inps = [
        torch.randn(m, k, device="cuda", dtype=torch.float),
        torch.randn(k, n, device="cuda", dtype=torch.float),
    ]

    def fusion_func(fd: FusionDefinition, inps) -> None:
        t0 = fd.from_pytorch(inps[0])
        t1 = fd.from_pytorch(inps[1])
        t2 = fd.ops.matmul(t0, t1)
        fd.add_output(t2)

    with FusionDefinition() as fd:
        fusion_func(fd, inps=inps)

    # By default, matmul will be run through the expr_eval scheduler.
    # By setting the enable and disable options as below, we can execute it
    # through the matmul scheduler. The above fusion will not be accepted by
    # the matmul scheduler since the outputs are of type Float, which raises
    # a RuntimeError.
    # Note: We use this error-based test since, for compatible dtypes
    # (float16/bfloat16), the matmul scheduler ran into a scheduling error
    # on H100. This test might be more robust against changes in the matmul
    # scheduler in the interim.

    self.assertRaisesRegex(
        RuntimeError,
        "Can not find a scheduler to schedule fusion segment",
        self.exec_nvfuser,
        partial(fusion_func, inps=inps),
        inps,
        _enable_options=["fuse_matmul"],
        _disable_options=["matmul_expr_eval"],
        skip_serde_check=True,
    )

    # Serializing error test cases corrupts the serialized binary causing
    # subsequent tests to fail. Reset the fusion cache to avoid this.
    FusionCache.reset()
9 changes: 8 additions & 1 deletion tests/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ def exec_nvfuser(
fusion_func,
inputs,
*,
_enable_options=[],
_disable_options=[],
new_fusion_expected=True,
device=None,
is_clonable=True,
Expand All @@ -432,7 +434,12 @@ def exec_nvfuser(
with FusionDefinition() as fd:
fusion_func(fd)
torch.manual_seed(0)
out = fd.execute(inputs, device=device)
out = fd.execute(
inputs,
device=device,
_enable_options=_enable_options,
_disable_options=_disable_options,
)

self.assertTrue(
check_captured_python_definition(out, fd, inputs_captured, device)
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.22
0.2.23

0 comments on commit 7212038

Please sign in to comment.