patch potential segfault (#3262)
Updating the fusion inside a generator loop causes a segfault; defer the update to after the loop.
jjsjann123 authored Oct 26, 2024
1 parent ff75845 commit 5776198
Showing 5 changed files with 160 additions and 5 deletions.
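
The fix is the standard deferred-mutation pattern: record the intended replacements while iterating, then apply them once iteration is done. Below is a minimal C++ sketch of both the failure mode and the fix; Node, CatNode, and the containers are illustrative stand-ins, not nvfuser's actual IR types or API.

#include <memory>
#include <unordered_map>
#include <vector>

struct Node {
  virtual ~Node() = default;
};
struct CatNode : Node {};

int main() {
  std::vector<std::unique_ptr<Node>> ir;
  ir.push_back(std::make_unique<CatNode>());
  ir.push_back(std::make_unique<Node>());

  // BAD: replacing ir entries inside this loop would free memory the loop
  // still reads, so the next dynamic_cast would run on a dangling pointer
  // (the segfault this commit patches). GOOD: only record the replacement.
  std::unordered_map<Node*, std::unique_ptr<Node>> replacement_map;
  for (auto& n : ir) {
    if (dynamic_cast<CatNode*>(n.get()) != nullptr) {
      replacement_map[n.get()] = std::make_unique<Node>();
    }
  }

  // Apply the replacements only after the iteration has finished.
  for (auto& n : ir) {
    if (auto it = replacement_map.find(n.get()); it != replacement_map.end()) {
      n = std::move(it->second);
    }
  }
  return 0;
}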
7 changes: 7 additions & 0 deletions csrc/ir/utils.cpp
@@ -475,6 +475,12 @@ class ValReplacementMutator : private OptOutMutator {
for (auto stmt : more_stmts) {
dispatchMutate(stmt);
}

for (const auto& [old_v, new_v] : replacement_map_) {
if (old_v->isFusionOutput()) {
fusion->replaceOutput(old_v, new_v);
}
}
}

private:
@@ -523,6 +529,7 @@
void replaceValue(
Fusion* fusion,
const std::unordered_map<Val*, Val*>& replacement_map) {
// NOLINTNEXTLINE(bugprone-unused-raii)
ValReplacementMutator(fusion, replacement_map);
}
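
The NOLINTNEXTLINE added above suppresses clang-tidy's bugprone-unused-raii check: replaceValue does all of its work inside the constructor of an unnamed ValReplacementMutator temporary, and the check normally treats such a statement as a forgotten variable name. A standalone sketch of the same pattern, with Worker as a made-up type for illustration:

// All of the work happens in the constructor; the object itself is unused.
struct Worker {
  explicit Worker(int x) {
    (void)x; // side effects would happen here
  }
};

void run() {
  // Without the suppression, clang-tidy assumes `Worker(42);` is a typo for
  // `Worker w(42);` because the temporary is destroyed immediately. Here the
  // constructor's side effects are exactly the point.
  // NOLINTNEXTLINE(bugprone-unused-raii)
  Worker(42);
}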

3 changes: 2 additions & 1 deletion csrc/ir/utils.h
@@ -82,7 +82,8 @@ struct MatmulInputs {

namespace nvfuser::ir_utils {

// Replace values in fusion using ValReplacementMutator
// Replace values in fusion using ValReplacementMutator. This also updates
// fusion outputs according to the replacement_map
void replaceValue(
Fusion*,
const std::unordered_map<Val*, Val*>& replacement_map);
3 changes: 0 additions & 3 deletions csrc/preseg_passes/consecutive_cast.cpp
@@ -170,9 +170,6 @@ void castOptimizationPass(Fusion* fusion) {
// if output dtype is identical to starting_anchor dtype, we can't keep
// the last cast op and will need to re-write all uses here
ir_utils::replaceValue(fusion, {{expr->output(0), starting_anchor}});
if (expr->output(0)->isFusionOutput()) {
fusion->replaceOutput(expr->output(0), starting_anchor);
}
} else {
replaceInputInCast(expr->output(0), starting_anchor);
}
9 changes: 8 additions & 1 deletion csrc/preseg_passes/move_pad.cpp
@@ -533,16 +533,23 @@ void replaceCat(Fusion* fusion) {
std::vector<Expr*> exprs = fusion->exprs();

// sanitize CatOp with a series of binary adds
std::unordered_map<Val*, Val*> replacement_map;
for (auto* cat : ir_utils::filterByType<CatOp>(exprs)) {
if (std::any_of(cat->inputs().begin(), cat->inputs().end(), [](Val* val) {
NVF_ERROR(
val->definition() != nullptr,
"CatOp shouldn't take fusion input as argument");
return !val->definition()->isA<PadOp>();
})) {
Val* res = replaceCatOpWithBinaryOp(cat->inputs());

// replace `CatOp` with the replay result.
ir_utils::replaceValInAllExprInputsAndFusionOutputs(cat->output(0), res);
replacement_map[cat->output(0)] = res;
}
}
// defer the update to after the loop over the generator, so the replacement
// doesn't delete nodes while we are still iterating over them
ir_utils::replaceValue(fusion, replacement_map);
}

} // namespace
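
For context, the rewrite above leans on the identity that a concatenation equals the elementwise sum of its inputs once each input is zero-padded out to the full output extent; the loop rewrites any CatOp whose inputs are not all produced by a PadOp into that padded-sum form. A minimal sketch of the identity in plain C++, where pad_to is a made-up helper rather than nvfuser code:

#include <cassert>
#include <cstddef>
#include <vector>

// Zero-pad v so it occupies [left, left + v.size()) of a length-total vector.
std::vector<float> pad_to(const std::vector<float>& v, size_t left, size_t total) {
  std::vector<float> out(total, 0.0f);
  for (size_t i = 0; i < v.size(); ++i) {
    out[left + i] = v[i];
  }
  return out;
}

int main() {
  std::vector<float> a{1, 2, 3};
  std::vector<float> b{4, 5};
  std::vector<float> cat{1, 2, 3, 4, 5}; // cat([a, b]) along the only axis

  // cat([a, b]) == pad(a, right) + pad(b, left), elementwise.
  auto pa = pad_to(a, 0, cat.size());
  auto pb = pad_to(b, a.size(), cat.size());
  for (size_t i = 0; i < cat.size(); ++i) {
    assert(pa[i] + pb[i] == cat[i]);
  }
  return 0;
}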
143 changes: 143 additions & 0 deletions tests/python/test_optimization_passes.py
@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import torch
from nvfuser import FusionDefinition, DataType


# This example used to hit a segfault in
# nvfuser::preseg_passes::MovePadPass::replaceCat: the old optimization pass
# updated the fusion within the filterByType generator, causing a dynamic_cast
# on a dangling pointer.
def test_litgpt_variants_gpt_neox_like():
def nvfuser_fusion_id4(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
shape=[128, 16],
contiguity=[True, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[1, 0],
)
T1 = fd.define_tensor(
shape=[128, 16],
contiguity=[True, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[1, 0],
)
T2 = fd.define_tensor(
shape=[5, 5, 192],
contiguity=[True, True, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[2, 1, 0],
)
T12 = fd.ops.slice(
T0, start_indices=[0, 0], end_indices=[5, 16], strides=[1, 1]
)
T22 = fd.ops.slice(
T1, start_indices=[0, 0], end_indices=[5, 16], strides=[1, 1]
)
T29 = fd.ops.reshape(T2, new_shape=[5, 5, 4, 3, 16])
T30 = fd.ops.permute(T29, dims=[0, 2, 3, 1, 4])
T49 = fd.ops.slice(
T30,
start_indices=[0, 0, 0, 0, 0],
end_indices=[5, 4, 1, 5, 16],
strides=[1, 1, 1, 1, 1],
)
T68 = fd.ops.slice(
T30,
start_indices=[0, 0, 1, 0, 0],
end_indices=[5, 4, 2, 5, 16],
strides=[1, 1, 1, 1, 1],
)
T87 = fd.ops.slice(
T30,
start_indices=[0, 0, 2, 0, 0],
end_indices=[5, 4, 3, 5, 16],
strides=[1, 1, 1, 1, 1],
)
T93 = fd.ops.reshape(T49, new_shape=[5, 4, 5, 16])
T99 = fd.ops.reshape(T68, new_shape=[5, 4, 5, 16])
T105 = fd.ops.reshape(T87, new_shape=[5, 4, 5, 16])
T121 = fd.ops.slice(
T93,
start_indices=[0, 0, 0, 0],
end_indices=[5, 4, 5, 8],
strides=[1, 1, 1, 1],
)
T137 = fd.ops.slice(
T93,
start_indices=[0, 0, 0, 8],
end_indices=[5, 4, 5, 16],
strides=[1, 1, 1, 1],
)
T138 = fd.ops.neg(T137)
T139 = fd.ops.cat([T138, T121], dim=-1)
S140 = fd.define_scalar(5, dtype=DataType.Int)
S141 = fd.define_scalar(4, dtype=DataType.Int)
S142 = fd.define_scalar(5, dtype=DataType.Int)
S143 = fd.define_scalar(16, dtype=DataType.Int)
T145 = fd.ops.broadcast_in_dim(
T12, shape=[S140, S141, S142, S143], broadcast_dims=[2, 3]
)
T146 = fd.ops.mul(T93, T145)
S147 = fd.define_scalar(5, dtype=DataType.Int)
S148 = fd.define_scalar(4, dtype=DataType.Int)
S149 = fd.define_scalar(5, dtype=DataType.Int)
S150 = fd.define_scalar(16, dtype=DataType.Int)
T152 = fd.ops.broadcast_in_dim(
T22, shape=[S147, S148, S149, S150], broadcast_dims=[2, 3]
)
T153 = fd.ops.mul(T139, T152)
T154 = fd.ops.add(T146, T153)
T170 = fd.ops.slice(
T99,
start_indices=[0, 0, 0, 0],
end_indices=[5, 4, 5, 8],
strides=[1, 1, 1, 1],
)
T186 = fd.ops.slice(
T99,
start_indices=[0, 0, 0, 8],
end_indices=[5, 4, 5, 16],
strides=[1, 1, 1, 1],
)
T187 = fd.ops.neg(T186)
T188 = fd.ops.cat([T187, T170], dim=-1)
T189 = fd.ops.mul(T99, T145)
T190 = fd.ops.mul(T188, T152)
T191 = fd.ops.add(T189, T190)
T207 = fd.ops.slice(
T93,
start_indices=[0, 0, 0, 0],
end_indices=[5, 4, 5, 0],
strides=[1, 1, 1, 1],
)
T208 = fd.ops.cat([T154, T207], dim=-1)
T224 = fd.ops.slice(
T99,
start_indices=[0, 0, 0, 0],
end_indices=[5, 4, 5, 0],
strides=[1, 1, 1, 1],
)
T225 = fd.ops.cat([T191, T224], dim=-1)
S226 = fd.define_scalar(0.500000, dtype=DataType.Double)
T227 = fd.ops.mul(T208, S226)
T228 = fd.ops.permute(T225, dims=[0, 1, 3, 2])
S229 = fd.define_scalar(0.500000, dtype=DataType.Double)
T230 = fd.ops.mul(T228, S229)
fd.add_output(T105)
fd.add_output(T145)
fd.add_output(T152)
fd.add_output(T227)
fd.add_output(T230)

with FusionDefinition() as fd:
nvfuser_fusion_id4(fd)

inputs = [
torch.testing.make_tensor((128, 16), dtype=torch.float32, device="cuda:0"),
torch.testing.make_tensor((128, 16), dtype=torch.float32, device="cuda:0"),
torch.testing.make_tensor((5, 5, 192), dtype=torch.float32, device="cuda:0"),
]
# TODO: I wish we had an easy way to validate results here, like in the cpp tests.
fd.execute(inputs)
