Skip to content

Commit

Permalink
Fixing indexing issue with almost-exact mapping (#3494)
Browse files Browse the repository at this point in the history
This is a follow-up to #3454. Specifically, as we now allow updating of
index mappings in `TensorIndexer::setIndex`, there can be a case like
below:

```
merge b0(1), i1(8) -> i2(8)
```

When propagating the index of `i2`, `i2_idx`, backward, the input IDs
would get `i2_idx / 8` and `i2_idx % 8`, respectively. However, if
`i2_idx` is not guaranteed to be less than 8 (for example, due to a
non-divisible split of `i2`), the broadcast `b0` id would potentially
get a non-zero index, which means that we would need to predicate `b0`
as well, i.e., `i2_idx / 8 < 1`, if it's part of the allocation domain.
However, this would not be predicated as we ignore broadcast IDs. The
new unit test would fail at the validation due to this predication
problem.

To fix the issue, we could also predicate broadcast allocation IDs.
Instead, this PR takes a simpler approach that just forwards a given
index to its almost-exactly mapped ID as is. In the above case, `b0` and
`i1` would get `0` and `i2_idx`, respectively.

Tested manually on an H100.
  • Loading branch information
naoyam authored Nov 28, 2024
1 parent 50bac23 commit 889262e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 10 deletions.
45 changes: 35 additions & 10 deletions csrc/id_model/id_model_index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@ void IdGraphIndexCompute::handle(Split* split) {
auto inner_extent = split->inner()->extent();

if (is_forward) {
auto in_idx = getIndex(split->in());
auto outer_idx = SimplifyingIrBuilder::divExpr(in_idx, inner_extent);
Val* inner_idx = SimplifyingIrBuilder::modExpr(in_idx, inner_extent);
setIndex(split->outer(), outer_idx);
setIndex(split->inner(), inner_idx);
// When propagating Split forward, if one of the outputs is mapped
// with the input (because of the almost-exact mapping), don't
// update the index and just set 0 as the index of the other
// output. This is necessary when the other output is a broadcast
// ID, which is ignored for predication. See
// IndexingTest.AlmostExactIndexingUpdate for a concrete example.
if (traversal_graph_.disjointValSets().strictAreMapped(
split->in(), split->inner())) {
setIndex(split->outer(), split->fusion()->zeroVal());
} else if (traversal_graph_.disjointValSets().strictAreMapped(
split->in(), split->outer())) {
setIndex(split->inner(), split->fusion()->zeroVal());
} else {
auto in_idx = getIndex(split->in());
auto outer_idx = SimplifyingIrBuilder::divExpr(in_idx, inner_extent);
Val* inner_idx = SimplifyingIrBuilder::modExpr(in_idx, inner_extent);
setIndex(split->outer(), outer_idx);
setIndex(split->inner(), inner_idx);
}
} else {
auto outer_idx = getIndex(split->outer());
auto inner_idx = getIndex(split->inner());
Expand All @@ -43,11 +57,22 @@ void IdGraphIndexCompute::handle(Merge* merge) {
SimplifyingIrBuilder::mulExpr(outer_idx, inner_ext), inner_idx);
setIndex(merge->out(), out_idx);
} else {
auto out_idx = getIndex(merge->out());
auto outer_idx = SimplifyingIrBuilder::divExpr(out_idx, inner_ext);
setIndex(merge->outer(), outer_idx);
Val* inner_idx = SimplifyingIrBuilder::modExpr(out_idx, inner_ext);
setIndex(merge->inner(), inner_idx);
// Similar to the forward propagation of Split, when propagating Merge
// backward, if one of the inputs is mapped with the output, don't update
// the index and just set 0 as the index of the other input.
if (traversal_graph_.disjointValSets().strictAreMapped(
merge->out(), merge->inner())) {
setIndex(merge->outer(), merge->fusion()->zeroVal());
} else if (traversal_graph_.disjointValSets().strictAreMapped(
merge->out(), merge->outer())) {
setIndex(merge->inner(), merge->fusion()->zeroVal());
} else {
auto out_idx = getIndex(merge->out());
auto outer_idx = SimplifyingIrBuilder::divExpr(out_idx, inner_ext);
setIndex(merge->outer(), outer_idx);
Val* inner_idx = SimplifyingIrBuilder::modExpr(out_idx, inner_ext);
setIndex(merge->inner(), inner_idx);
}
}
}

Expand Down
38 changes: 38 additions & 0 deletions tests/cpp/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5406,4 +5406,42 @@ TEST_F(IndexingTest, ResizeRotation) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

// Regression test for indexing through an almost-exact mapping: an
// extent-1 axis produced by a slice is merged with a regular axis, and the
// merged axis is then split by a factor that is not guaranteed to divide
// it. The compiled kernel's output is validated against the reference.
TEST_F(IndexingTest, AlmostExactIndexingUpdate) {
  EnableOptionsGuard options_guard;
  EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});

  Fusion fusion;
  FusionGuard fusion_guard(&fusion);

  auto input_tv = makeContigConcreteTensor({4, 8});
  fusion.addInput(input_tv);

  // Slice bounds: [1, 2) on the outer dimension (a single row) and the
  // full [0, extent) range on the inner dimension.
  auto* outer_start = IrBuilder::create<Val>(1L);
  auto* outer_stop = IrBuilder::create<Val>(2L);
  auto* inner_start = IrBuilder::create<Val>(0L);
  auto* inner_stop = input_tv->axis(1)->extent();

  auto sliced_tv =
      slice(input_tv, {{outer_start, outer_stop}, {inner_start, inner_stop}});

  fusion.addOutput(sliced_tv);

  // Loop-domain evolution, starting from [b0, i1]:
  //   split(-1, 5) -> [b0, i1/5, 5]
  sliced_tv->split(-1, 5);
  //   split(-1, 3) -> [b0, i1/5, 5/3, 3]
  sliced_tv->split(-1, 3);
  //   merge(0, -1) combines the outermost axis with the innermost one.
  //   NOTE(review): presumably this yields [b0*3, i1/5, 5/3] -- confirm
  //   where the merged axis is placed.
  sliced_tv->merge(0, -1);
  //   split(0, 2) splits the merged axis by 2; presumably
  //   [b0*3/2, 2, i1/5, 5/3] -- confirm together with the note above.
  sliced_tv->split(0, 2);

  const auto tensor_opts =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto host_input = at::randn({4, 8}, tensor_opts);
  std::vector<c10::IValue> args{host_input};

  KernelExecutor executor;
  executor.compile(&fusion, args);
  auto results = executor.run(args);

  testValidate(&fusion, results, args, __LINE__, __FILE__);
}

} // namespace nvfuser

0 comments on commit 889262e

Please sign in to comment.