Skip to content

Commit

Permalink
lintrunner
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Oct 29, 2024
1 parent 84d8662 commit d39ac44
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 49 deletions.
80 changes: 42 additions & 38 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,21 @@ std::unordered_map<DebugDumpOption, std::vector<std::string>> Options<
}

// Returns the registry mapping option-name strings to their EnableOption
// enum values. These names are the tokens recognized when parsing the
// "ENABLE" environment options (see the parseEnvOptions("ENABLE", ...) call
// below) and by stringToEnableOption().
//
// The map is a function-local static: it is built once, on first use, with
// thread-safe initialization, and the same instance is returned by
// reference on every call.
//
// NOTE: the diff residue showed the initializer twice (pre- and post-lint
// formatting); exactly one copy is kept here.
const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
  static const std::unordered_map<std::string, EnableOption> available_options =
      {
          {"fuse_matmul", EnableOption::FuseMatmul},
          {"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls},
          {"id_model", EnableOption::IdModel},
          {"kernel_db", EnableOption::KernelDb},
          {"kernel_profile", EnableOption::KernelProfile},
          {"memory_promotion", EnableOption::MemoryPromotion},
          {"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory},
          {"static_fusion_count", EnableOption::StaticFusionCount},
          {"warn_register_spill", EnableOption::WarnRegisterSpill},
          {"io_to_lower_precision", EnableOption::IoToLowerPrecision},
          {"kernel_debug", EnableOption::KernelDebug},
          {"kernel_lineinfo", EnableOption::KernelLineInfo},
      };
  return available_options;
}

Expand All @@ -172,7 +173,8 @@ std::unordered_map<EnableOption, std::vector<std::string>> Options<
return parseEnvOptions("ENABLE", available_options);
}

std::optional<EnableOption> stringToEnableOption(const std::string& enable_option) {
std::optional<EnableOption> stringToEnableOption(
const std::string& enable_option) {
const auto& opts = getEnableOptions();
auto it = opts.find(enable_option);
if (it != opts.end()) {
Expand All @@ -182,28 +184,29 @@ std::optional<EnableOption> stringToEnableOption(const std::string& enable_optio
}

// Returns the registry mapping option-name strings to their DisableOption
// enum values. These names are the tokens recognized when parsing the
// disable-options environment variable and by stringToDisableOption().
//
// The map is a function-local static: it is built once, on first use, with
// thread-safe initialization, and the same instance is returned by
// reference on every call.
//
// NOTE: the diff residue showed the initializer twice (pre- and post-lint
// formatting); exactly one copy is kept here.
const std::unordered_map<std::string, DisableOption>& getDisableOptions() {
  static const std::unordered_map<std::string, DisableOption>
      available_options = {
          {"compile_to_sass", DisableOption::CompileToSass},
          {"contig_indexing", DisableOption::ContigIndexing},
          {"expr_simplify", DisableOption::ExprSimplify},
          {"fallback", DisableOption::Fallback},
          {"fma", DisableOption::Fma},
          {"grouped_grid_welford_outer_opt",
           DisableOption::GroupedGridWelfordOuterOpt},
          {"index_hoist", DisableOption::IndexHoist},
          {"magic_zero", DisableOption::MagicZero},
          {"matmul_expr_eval", DisableOption::MatmulExprEval},
          {"nvtx", DisableOption::Nvtx},
          {"parallel_compile", DisableOption::ParallelCompile},
          {"parallel_serde", DisableOption::ParallelSerde},
          {"predicate_elimination", DisableOption::PredicateElimination},
          {"python_inline_definitions", DisableOption::PythonInlineDefinitions},
          {"kernel_reuse", DisableOption::KernelReuse},
          {"var_name_remapping", DisableOption::VarNameRemapping},
          {"welford_vectorization", DisableOption::WelfordVectorization},
          {"reuse_mismatched_type_registers",
           DisableOption::ReuseMismatchedTypeRegisters},
          {"multidevice", DisableOption::Multidevice}};
  return available_options;
}

Expand All @@ -221,7 +224,8 @@ std::unordered_map<DisableOption, std::vector<std::string>> Options<
return options;
}

std::optional<DisableOption> stringToDisableOption(const std::string& disable_option) {
std::optional<DisableOption> stringToDisableOption(
const std::string& disable_option) {
const auto& opts = getDisableOptions();
auto it = opts.find(disable_option);
if (it != opts.end()) {
Expand Down
6 changes: 4 additions & 2 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ NVF_API std::unordered_map<EnableOption, std::vector<std::string>> Options<

using EnableOptions = Options<EnableOption>;

std::optional<EnableOption> stringToEnableOption(const std::string& enable_option);
std::optional<EnableOption> stringToEnableOption(
const std::string& enable_option);

bool isOptionEnabled(EnableOption option);

Expand All @@ -269,7 +270,8 @@ NVF_API std::unordered_map<DisableOption, std::vector<std::string>> Options<

using DisableOptions = Options<DisableOption>;

std::optional<DisableOption> stringToDisableOption(const std::string& disable_option);
std::optional<DisableOption> stringToDisableOption(
const std::string& disable_option);

NVF_API bool isOptionDisabled(DisableOption option);

Expand Down
2 changes: 1 addition & 1 deletion csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ std::vector<at::Tensor> FusionDefinition::execute(
}

DisableOptionsGuard disable_opt_guard;
for (auto disable_option: disable_options) {
for (auto disable_option : disable_options) {
std::optional<DisableOption> opt = stringToDisableOption(disable_option);
NVF_CHECK(opt.has_value(), "Unrecognized disable_option: ", disable_option);
DisableOptionsGuard::getCurOptions().set(opt.value());
Expand Down
4 changes: 2 additions & 2 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def execute(
self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

results = None

try:
results = self._execute(
inputs,
Expand All @@ -188,7 +188,7 @@ def execute(
capture_debug_output=capture_debug_output,
profile=profile,
enable_options=enable_options,
disable_options=disable_options
disable_options=disable_options,
)
if print_repro:
print(self.repro_script_for(inputs))
Expand Down
17 changes: 11 additions & 6 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4600,7 +4600,7 @@ def fusion_func(fd: FusionDefinition) -> None:
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
for out in nvf_out:
self.assertTrue(out.allclose(x[:, 1:, 2:]))

def test_enable_disable_options(self):
m = 24
n = 16
Expand All @@ -4618,14 +4618,19 @@ def fusion_func(fd: FusionDefinition, inps) -> None:

with FusionDefinition() as fd:
fusion_func(fd, inps=inps)

nvf_out = fd.execute(inps, enable_options=["fuse_matmul"], disable_options=["matmul_expr_eval"], profile=True)

nvf_out = fd.execute(
inps,
enable_options=["fuse_matmul"],
disable_options=["matmul_expr_eval"],
profile=True,
)
prof = fd.profile()
self.assertEqual(len(prof.kernel_profiles), 1)

        # By default, matmul will be run through the expr_eval scheduler.
# Through setting the enable and disable options as above,
# Through setting the enable and disable options as above,
# we can execute it through matmul scheduler.
self.assertEqual(prof.kernel_profiles[0].scheduler, 'matmul')
self.assertEqual(prof.kernel_profiles[0].scheduler, "matmul")
eager_out = torch.matmul(inps[0], inps[1])
self.assertEqual(eager_out, nvf_out[0])

0 comments on commit d39ac44

Please sign in to comment.