From 57761987703d6e924ca76e441f512b7679054f8b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 26 Oct 2024 11:42:01 -0700 Subject: [PATCH] patch potential segfault (#3262) updating fusion inside a generator loop causes segfault. defer to update to after the loop. --- csrc/ir/utils.cpp | 7 ++ csrc/ir/utils.h | 3 +- csrc/preseg_passes/consecutive_cast.cpp | 3 - csrc/preseg_passes/move_pad.cpp | 9 +- tests/python/test_optimization_passes.py | 143 +++++++++++++++++++++++ 5 files changed, 160 insertions(+), 5 deletions(-) create mode 100644 tests/python/test_optimization_passes.py diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index eb160d2e2de..a9b0dfb3d45 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -475,6 +475,12 @@ class ValReplacementMutator : private OptOutMutator { for (auto stmt : more_stmts) { dispatchMutate(stmt); } + + for (const auto& [old_v, new_v] : replacement_map_) { + if (old_v->isFusionOutput()) { + fusion->replaceOutput(old_v, new_v); + } + } } private: @@ -523,6 +529,7 @@ class ValReplacementMutator : private OptOutMutator { void replaceValue( Fusion* fusion, const std::unordered_map& replacement_map) { + // NOLINTNEXTLINE(bugprone-unused-raii) ValReplacementMutator(fusion, replacement_map); } diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 6ce59fb95c4..8c674e211f0 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -82,7 +82,8 @@ struct MatmulInputs { namespace nvfuser::ir_utils { -// Replace values in fusion using ValReplacementMutator +// Replace values in fusion using ValReplacementMutator, it also updates fusion +// output according to the replacement_map void replaceValue( Fusion*, const std::unordered_map& replacement_map); diff --git a/csrc/preseg_passes/consecutive_cast.cpp b/csrc/preseg_passes/consecutive_cast.cpp index e41dbe2d6bd..a0901d3deaf 100644 --- a/csrc/preseg_passes/consecutive_cast.cpp +++ b/csrc/preseg_passes/consecutive_cast.cpp @@ -170,9 +170,6 @@ void castOptimizationPass(Fusion* fusion) { // if 
output dtype is identical to starting_anchor dtype, we can't keep // the last cast op and will need to re-write all uses here ir_utils::replaceValue(fusion, {{expr->output(0), starting_anchor}}); - if (expr->output(0)->isFusionOutput()) { - fusion->replaceOutput(expr->output(0), starting_anchor); - } } else { replaceInputInCast(expr->output(0), starting_anchor); } diff --git a/csrc/preseg_passes/move_pad.cpp b/csrc/preseg_passes/move_pad.cpp index 0dee18d21a7..bb277e62566 100644 --- a/csrc/preseg_passes/move_pad.cpp +++ b/csrc/preseg_passes/move_pad.cpp @@ -533,16 +533,23 @@ void replaceCat(Fusion* fusion) { std::vector exprs = fusion->exprs(); // sanitizing CatOp with series of binary add + std::unordered_map replacement_map; for (auto* cat : ir_utils::filterByType(exprs)) { if (std::any_of(cat->inputs().begin(), cat->inputs().end(), [](Val* val) { + NVF_ERROR( + val->definition() != nullptr, + "CatOp shouldn't take fusion input as argument"); return !val->definition()->isA(); })) { Val* res = replaceCatOpWithBinaryOp(cat->inputs()); // replace `CatOp` with the replay result. - ir_utils::replaceValInAllExprInputsAndFusionOutputs(cat->output(0), res); + replacement_map[cat->output(0)] = res; } } + // defer the update to after the for loop on a generator to avoid deleting + // nodes in the replacement + ir_utils::replaceValue(fusion, replacement_map); } } // namespace diff --git a/tests/python/test_optimization_passes.py b/tests/python/test_optimization_passes.py new file mode 100644 index 00000000000..3fdc01a979f --- /dev/null +++ b/tests/python/test_optimization_passes.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +import torch +from nvfuser import FusionDefinition, DataType + + +# this example hits a segfault in nvfuser::preseg_passes::MovePadPass::replaceCat, where the old optimization pass updates the fusion within the filterByType generator, causing a dynamic cast on a dangling pointer. +def test_litgpt_variants_gpt_neox_like(): + def nvfuser_fusion_id4(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[128, 16], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[128, 16], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.define_tensor( + shape=[5, 5, 192], + contiguity=[True, True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T12 = fd.ops.slice( + T0, start_indices=[0, 0], end_indices=[5, 16], strides=[1, 1] + ) + T22 = fd.ops.slice( + T1, start_indices=[0, 0], end_indices=[5, 16], strides=[1, 1] + ) + T29 = fd.ops.reshape(T2, new_shape=[5, 5, 4, 3, 16]) + T30 = fd.ops.permute(T29, dims=[0, 2, 3, 1, 4]) + T49 = fd.ops.slice( + T30, + start_indices=[0, 0, 0, 0, 0], + end_indices=[5, 4, 1, 5, 16], + strides=[1, 1, 1, 1, 1], + ) + T68 = fd.ops.slice( + T30, + start_indices=[0, 0, 1, 0, 0], + end_indices=[5, 4, 2, 5, 16], + strides=[1, 1, 1, 1, 1], + ) + T87 = fd.ops.slice( + T30, + start_indices=[0, 0, 2, 0, 0], + end_indices=[5, 4, 3, 5, 16], + strides=[1, 1, 1, 1, 1], + ) + T93 = fd.ops.reshape(T49, new_shape=[5, 4, 5, 16]) + T99 = fd.ops.reshape(T68, new_shape=[5, 4, 5, 16]) + T105 = fd.ops.reshape(T87, new_shape=[5, 4, 5, 16]) + T121 = fd.ops.slice( + T93, + start_indices=[0, 0, 0, 0], + end_indices=[5, 4, 5, 8], + strides=[1, 1, 1, 1], + ) + T137 = fd.ops.slice( + T93, + start_indices=[0, 0, 0, 8], + end_indices=[5, 4, 5, 16], + strides=[1, 1, 1, 1], + ) + T138 = fd.ops.neg(T137) + T139 = fd.ops.cat([T138, T121], dim=-1) + S140 = fd.define_scalar(5, 
dtype=DataType.Int) + S141 = fd.define_scalar(4, dtype=DataType.Int) + S142 = fd.define_scalar(5, dtype=DataType.Int) + S143 = fd.define_scalar(16, dtype=DataType.Int) + T145 = fd.ops.broadcast_in_dim( + T12, shape=[S140, S141, S142, S143], broadcast_dims=[2, 3] + ) + T146 = fd.ops.mul(T93, T145) + S147 = fd.define_scalar(5, dtype=DataType.Int) + S148 = fd.define_scalar(4, dtype=DataType.Int) + S149 = fd.define_scalar(5, dtype=DataType.Int) + S150 = fd.define_scalar(16, dtype=DataType.Int) + T152 = fd.ops.broadcast_in_dim( + T22, shape=[S147, S148, S149, S150], broadcast_dims=[2, 3] + ) + T153 = fd.ops.mul(T139, T152) + T154 = fd.ops.add(T146, T153) + T170 = fd.ops.slice( + T99, + start_indices=[0, 0, 0, 0], + end_indices=[5, 4, 5, 8], + strides=[1, 1, 1, 1], + ) + T186 = fd.ops.slice( + T99, + start_indices=[0, 0, 0, 8], + end_indices=[5, 4, 5, 16], + strides=[1, 1, 1, 1], + ) + T187 = fd.ops.neg(T186) + T188 = fd.ops.cat([T187, T170], dim=-1) + T189 = fd.ops.mul(T99, T145) + T190 = fd.ops.mul(T188, T152) + T191 = fd.ops.add(T189, T190) + T207 = fd.ops.slice( + T93, + start_indices=[0, 0, 0, 0], + end_indices=[5, 4, 5, 0], + strides=[1, 1, 1, 1], + ) + T208 = fd.ops.cat([T154, T207], dim=-1) + T224 = fd.ops.slice( + T99, + start_indices=[0, 0, 0, 0], + end_indices=[5, 4, 5, 0], + strides=[1, 1, 1, 1], + ) + T225 = fd.ops.cat([T191, T224], dim=-1) + S226 = fd.define_scalar(0.500000, dtype=DataType.Double) + T227 = fd.ops.mul(T208, S226) + T228 = fd.ops.permute(T225, dims=[0, 1, 3, 2]) + S229 = fd.define_scalar(0.500000, dtype=DataType.Double) + T230 = fd.ops.mul(T228, S229) + fd.add_output(T105) + fd.add_output(T145) + fd.add_output(T152) + fd.add_output(T227) + fd.add_output(T230) + + with FusionDefinition() as fd: + nvfuser_fusion_id4(fd) + + inputs = [ + torch.testing.make_tensor((128, 16), dtype=torch.float32, device="cuda:0"), + torch.testing.make_tensor((128, 16), dtype=torch.float32, device="cuda:0"), + torch.testing.make_tensor((5, 5, 192), 
dtype=torch.float32, device="cuda:0"), + ] + # TODO: I wish we had an easy way for validation like in cpp tests. + fd.execute(inputs)