pointwise scheduler fails to validate reference tv #3513
@@ -142,6 +142,137 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
  return in_concrete_ids.empty();
}

// Note: ideally we would want to check that reference_tv contains all iter
// domains in target_tv, so that a transformation applied on reference_tv can
// be propagated to target_tv. But we don't have an easy way to check that.
// Instead, this function checks that all source iter domains involved in
// transformations on target_tv are covered by reference_tv. Source iter
// domains of a TensorView are IDs that don't have a definition and are
// producers of some ID on the logical domain of the given TensorView.
//
// ------
//
// e.g. 0:
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// T198 [i0, b3(ex)*i1] = reshape(T192)
// output(T34)
// output(T198)
//
// If we take T34 as reference_tv, T198 is the target_tv. We can't replay
// T34's transform of merging all the dimensions onto T198, since b3(ex)*i1
// can't be reversed. The check in this function gives us T34 with sources
// i0, i1, while T198 has sources i0, b3, i1, where b3 isn't contained in
// T34. Hence we reject this reference_tv.
//
// ------
//
// e.g. 1:
// T0 [i0, i1]
// T1 [i2, i0, i1]
// T2 [i0*i1] = reshape(T0)
// T3 [b3, i0, i1] = broadcast(T0)
// T4 [i2, i0, i1] = add(T1, T3)
// output(T2)
// output(T4)
//
// The example above should be able to pick T4 as reference_tv. T2's sources
// i0 and i1 are both contained in the sources of T4, so this example can be
// scheduled as a single fusion.
bool DomainMap::areAllTargetIdsCoveredBy(
    TensorView* target_tv,
    TensorView* reference_tv) const {
  auto get_source_iter_domains = [this](TensorView* tv) {
    // Traverse back to collect all disjoint-set producer IDs for each ID in
    // the logical domain of tv.
    VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
        all_producer_sets;
    std::for_each(
        tv->getLogicalDomain().begin(),
        tv->getLogicalDomain().end(),
        [&](IterDomain* tv_logical_id) {
          all_producer_sets.pushBack(
              ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
        });
    all_producer_sets.pushBack(
        ca_map_.getAllDisjointSetProducers(all_producer_sets));

    std::vector<IterDomain*> source_ids;
    // Filter for producer IDs with an empty definition to get the source
    // iter domains.
    std::for_each(
        all_producer_sets.vector().begin(),
        all_producer_sets.vector().end(),
        [&source_ids,
         this](const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>&
                   producer_set_ptr) {
          IterDomain* producer_id = producer_set_ptr->front();
          if (ca_map_.uniqueExactDefinitions(producer_id).empty()) {
            source_ids.push_back(producer_id);
          }
        });
    return source_ids;
  };

  // This contains all source iter domains that are covered by reference_tv,
  // so it's safe for target_tv to have them.
  std::unordered_set<IterDomain*> covered_source_ids;
  for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) {
    covered_source_ids.insert(source_id_ref);
  }
  // It's safe to have an unmapped broadcast IterDomain. There are quite a
  // few tests expecting the pointwise scheduler to handle this pattern.
  for (IterDomain* id_out : target_tv->getLogicalDomain()) {
    if (id_out->isBroadcast()) {
      NVF_ERROR(
          id_out->definition() == nullptr ||
          id_out->definition()->isA<Resize>());

      // Note that ideally we should also be able to handle merge/split on
      // broadcast IDs, so we should really move this skip inside the
      // `get_source_iter_domains(target_tv)` loop below and skip broadcast
      // source IDs. Currently we have the issue that split/merge does not
      // preserve expanded broadcasts, see issue:
      // https://github.com/NVIDIA/Fuser/issues/1126
      covered_source_ids.insert(id_out);
    }
  }
  // Note: there are certain cases where it's safe to have dangling IDs,
  // e.g.
  // T34 [i0, i1]
  // T185 [i0, b2, i1] = broadcast(T34)
  // T192 [i0, b3(ex), i1] = expand(T185)
  // It's safe to propagate T34 to T192, since b3(ex) is not involved in the
  // propagation. But this isn't generally safe. If the above example is
  // changed to, e.g.,
  // T34 [i0, i1]
  // T185 [i0, b2, i1] = broadcast(T34)
  // T186 [i0, i4, i1] = ones({i0, i4, i1})
  // T193 [i0, i4, i1] = add(T185, T186)
  // it's unsafe to propagate from T34 to T193, see issue
  // https://github.com/NVIDIA/Fuser/issues/3542

  // Check all source iter domains involved in producing target_tv.
  for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) {
    // NOTE: we use the concrete ID instead. This allows us to link indirect
    // broadcasts. So in the example below:
    //   T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
    //   T3[i0, i9] = pad(T0[i0, b0])
    // we have i9 in T3
    //   -> source ID b0
    //   -> concrete mapped to i1,
    // so T3 is contained by T2. See test `PointwiseTest.DomainMapPad1`.
    auto concrete_source_id_out =
        ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE);

Review thread:

> This is the change I made in order to avoid the regression in the added
> test.

> Looks like this did work with our tests. @naoyam Let me know what you
> think about this change.

    // If we find any source_id_out that's not contained, it's possible our
    // propagation would fail, since transformations involving this iter
    // domain can't be resolved.
    if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) {
      return false;
    }
  }
  return true;
}

// Reference domains must exactly match with the input domains. See
// also PR #661
IterDomain* DomainMap::getMappedInputConcreteID(
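The collect-then-compare structure of `areAllTargetIdsCoveredBy` can be summarized in a short standalone sketch. This is a toy model, not the nvfuser API: `ToyId`, `sourceIds`, and `coveredBy` are hypothetical stand-ins, pointer identity replaces the exact-map disjoint sets of `ca_map_`, and the PERMISSIVE concrete-ID mapping is omitted.

```cpp
#include <unordered_set>
#include <vector>

// Toy stand-in for IterDomain; purely illustrative.
struct ToyId {
  bool is_broadcast = false;
  std::vector<ToyId*> producers;  // empty => a "source" iter domain
};

// Walk back from a logical domain to the IDs with no definition, i.e. the
// source iter domains (the real code does this over exact-mapped disjoint
// sets via ca_map_.getAllDisjointSetProducers).
std::unordered_set<ToyId*> sourceIds(const std::vector<ToyId*>& logical) {
  std::unordered_set<ToyId*> visited, sources;
  std::vector<ToyId*> stack(logical.begin(), logical.end());
  while (!stack.empty()) {
    ToyId* id = stack.back();
    stack.pop_back();
    if (!visited.insert(id).second) {
      continue;  // already seen through another path
    }
    if (id->producers.empty()) {
      sources.insert(id);
    } else {
      stack.insert(stack.end(), id->producers.begin(), id->producers.end());
    }
  }
  return sources;
}

// target is coverable by reference if every source ID of target is also a
// source ID of reference; broadcast IDs in target's logical domain are
// exempt, mirroring covered_source_ids.insert(id_out) above. The real check
// additionally maps both sides through PERMISSIVE concrete IDs.
bool coveredBy(const std::vector<ToyId*>& target_logical,
               const std::vector<ToyId*>& reference_logical) {
  std::unordered_set<ToyId*> covered = sourceIds(reference_logical);
  for (ToyId* id : target_logical) {
    if (id->is_broadcast) {
      covered.insert(id);
    }
  }
  for (ToyId* src : sourceIds(target_logical)) {
    if (covered.count(src) == 0) {
      return false;
    }
  }
  return true;
}
```
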
@@ -233,7 +364,7 @@ IterDomain* DomainMap::anyMapped(
}

// Determine if output TensorView is a valid reference tensor for this fusion.
-// The reference tensor must map to all the iterDomains in each input.
+// The reference tensor must map to all the iterDomains in each input and output.
bool DomainMap::isValidReference(TensorView* tv) const {
  for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
    if (input_tv->uses().empty()) {
@@ -245,6 +376,20 @@ bool DomainMap::isValidReference(TensorView* tv) const
      return false;
    }
  }
  // The check on outputs is optional; the transpose scheduler might propose
  // a secondary reference that only applies to a subset of IO tensors.
  // Ideally we should have a more robust check and consider the IO groups
  // instead of blindly skipping outputs.
  for (auto output_tv :
       ir_utils::filterByType<TensorView>(fusion_->outputs())) {
    // No need to check for self.
    if (output_tv == tv) {
      continue;
    }
    if (!areAllTargetIdsCoveredBy(output_tv, tv)) {
      return false;
    }
  }
  return true;
}

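Driving the toy sketch from the previous hunk with e.g. 0 shows the rejection path end to end (again hypothetical toy types, not nvfuser code; T34 and T198 reduced to bare ID graphs):

```cpp
// e.g. 0 in the toy model: T34 [i0, i1] as candidate reference,
// T198 [i0, b3*i1] as target, where b3*i1 is the merge of b3 and i1.
int main() {
  ToyId i0, i1;
  ToyId b3;
  b3.is_broadcast = true;
  ToyId b3_mul_i1;
  b3_mul_i1.producers = {&b3, &i1};

  std::vector<ToyId*> t34 = {&i0, &i1};
  std::vector<ToyId*> t198 = {&i0, &b3_mul_i1};

  // T198's sources are {i0, b3, i1}, but T34 only covers {i0, i1}; b3 is
  // not a broadcast in T198's *logical* domain (only its merge product
  // is), so the exemption doesn't apply and T34 is rejected.
  bool t34_ok = coveredBy(t198, t34);   // false
  // The reverse direction is fine: T34's sources {i0, i1} are all covered
  // by T198's, matching the e.g. 1 discussion of picking the wider tensor.
  bool t198_ok = coveredBy(t34, t198);  // true
  return (!t34_ok && t198_ok) ? 0 : 1;
}
```
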
@@ -170,8 +170,16 @@ class DomainMap : public pointwise_utils::DomainMap {
    TensorView* result = nullptr;
    int64_t max_dims = -1;
    for (auto tv : group) {
      // Since the transpose scheduler has a different set of references, we
      // skip the ID coverage check of the reference on the outputs of the
      // fusion. Note that this is not ideal; we would want to instead have
      // the reference tensor checked against all its target IO tensors.
      // TODO: open an issue for this one. The transpose scheduler is not
      // supposed to reuse pointwise_utils::DomainMap::isValidReference.
      // That function is too restrictive and doesn't align well with the
      // scheme of the transpose scheduler.
      if (isValidReference(tv)) {
-       int64_t dims = (int64_t)pointwise_utils::nRootDims(tv);
+       int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv);
        if (dims > max_dims) {
          result = tv;
          max_dims = dims;

Review thread:

> I'm not sure why. The […]

> What I was trying to say is that this is what the transpose scheduler is
> doing today, and I don't think that's the right thing to do either, if you
> look at how the scheduler propagates the parallelization strategy. I think
> we took a shortcut in the first place to reuse the […]. TL;DR: if we apply
> the same check for all outputs, we are going to fail some CI tests that
> check that the transpose scheduler takes the fusion.

> I'm not sure why transposing would create disconnected tensors. Can you
> show an example?

> In this example […] reference2 : T1_g_float[iS2{1024}, iS3{128}] doesn't
> cover all the source IDs from the output. If we don't skip the check here,
> we'll end up not using the pointwise scheduler and lose part of the
> vectorization.

> I see. It seems this would miss a case where a reference tensor indeed has
> an output like […]

> Indeed. It's not safe to blindly skip it. We want the check to be
> consistent with how the transformation is propagated. I have a comment in
> the transpose scheduler where we are passing […]. I feel I need to look at
> it further to figure out what the right check for the transpose scheduler
> is, and I'm not ready to commit to that in this PR.

> I see. I thought you were saying this would be less than ideal but still
> acceptable, but it looks like this is a real bug. I agree that fixing the
> issue doesn't need to be done in this PR, but I think we should take the
> safer option, i.e., don't skip the check even with the transpose
> scheduler. Failing to use a vectorized transpose scheduler is a bummer,
> but it is better than failing to run.

> Sounds fair, I'll go with that route and temporarily disable some CI
> tests. I'll make sure I open an issue to track the disabled tests and
> restore them when we fix the validation.

> Opened issue #3570. I'll have it in the code comment as well.
@@ -990,12 +998,12 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
<< "max_io_dtype_size: " << max_io_dtype_size << "\n" | ||
<< "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) | ||
<< "\n" | ||
<< "reference1: " << reference1 << "\n" | ||
<< "reference1: " << reference1->toString() << "\n" | ||
<< "inner_most_id1 position: " << inner_most_pos1_in_ref1 | ||
<< " (in reference 1)\n" | ||
<< "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) | ||
<< "\n" | ||
<< "reference2: " << reference2 << "\n" | ||
<< "reference2: " << reference2->toString() << "\n" | ||
<< "inner_most_id2 position: " << inner_most_pos2_in_ref1 | ||
<< " (in reference 1)" << std::endl; | ||
if (hasSmallTransposeDimensions(tparams)) { | ||
|
@@ -1045,11 +1053,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {

  int64_t max_dims = 0;
  for (auto inp : input_tvs) {
-   max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
+   max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
  }

  for (auto out : output_tvs) {
-   max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
+   max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
  }

  // If everything is zero dim tensors, just return.

Review thread:

> This seems to cover only broadcast IDs that have no transformation. For
> example, if a broadcast ID is split, the two split output IDs should be
> safe to have in the other output tensors, but since the original broadcast
> ID is not included in covered_source_ids, I suspect the reference would be
> considered invalid.

> TL;DR: good call, let me add a test to verify and change it.
> Earlier I had this inside the loop below at line 246; I think that version
> wouldn't have the problem. I changed it to this version because I was
> concerned that might be too loose a check. I was mostly worried that if we
> later decide to add expand as an ID op, then we won't be able to identify
> expand ops. But I think I can add another check along with
> ca_map_.uniqueExactDefinitions for that. At least we have CI guarding
> that. I'll put a comment there so we'll remember how to fix it then.

> Logging the offline conversation: due to issue #1126, we don't see
> split/merge on broadcast here. I'm keeping this code as-is. We added an
> assert that broadcast IDs on the logical domain cannot have a definition.
> 🤞 Let's see if that causes any CI issue.

> The assertion may not hold when a non-broadcast ID is sliced to a
> broadcast ID. See #3513 (comment)
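The caveat in that last comment can be pictured with a small hedged sketch. It assumes nvfuser's `slice` op and `Slice` range struct behave as the thread describes (a width-1 slice yielding a broadcast ID defined by a Resize); the names and exact call shape are illustrative, inside a `Fusion`/`FusionGuard` scope as in the earlier sketch:

```cpp
// Hypothetical illustration, not code from this PR: slicing a regular
// iteration domain down to extent 1 produces a broadcast ID that *does*
// have a definition (a Resize), so a bare `definition() == nullptr`
// assert on broadcast logical IDs would not hold for it.
TensorView* tv0 = makeConcreteTensor({8, 16});  // [i0{8}, i1{16}]
fusion.addInput(tv0);

TensorView* tv1 = slice(
    tv0,
    {Slice{fusion.zeroVal(), fusion.oneVal()},  // dim 0: keep [0, 1)
     Slice{}});                                 // dim 1: full range

// Expected per the comment above (assumed, not verified here):
//   tv1->axis(0)->isBroadcast()               -> true
//   tv1->axis(0)->definition()->isA<Resize>() -> true
// which is why the NVF_ERROR in this PR also accepts Resize definitions.
```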