pointwise scheduler fails to validate reference tv #3513

Open
jjsjann123 wants to merge 72 commits into main from pw_scheduler_reference_find_patch
Changes from 69 commits
Commits (72), all by jjsjann123:
f0ce0e3 - will this work? (Dec 2, 2024)
70e31bf - errr (Dec 2, 2024)
5f09e36 - missed a few renaming (Dec 2, 2024)
9ad9edb - WIP (Dec 2, 2024)
6540201 - test added (Dec 2, 2024)
89dc741 - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 4, 2024)
ed56c75 - WIP (Dec 4, 2024)
f1e7e0a - WIP (Dec 4, 2024)
9d174c9 - declaration (Dec 4, 2024)
bf425eb - WIP (Dec 4, 2024)
aef13ac - WIP (Dec 4, 2024)
a9ae516 - refactor the traversal (Dec 4, 2024)
d9e8dc0 - WIP (Dec 4, 2024)
7333806 - scratch that, it's getting out of hand (Dec 4, 2024)
f6ad363 - Revert "scratch that, it's getting out of hand" (Dec 4, 2024)
cef0b83 - try focus on expanded dimensions (Dec 4, 2024)
a557a8b - wip (Dec 5, 2024)
65cb621 - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 5, 2024)
f88ebf7 - lintrunner (Dec 5, 2024)
ea89b69 - comment added (Dec 5, 2024)
0fc0dc1 - fixing (Dec 5, 2024)
07797a4 - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 5, 2024)
94e2ddf - Apply suggestions from code review (Dec 6, 2024)
3e2b43e - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 6, 2024)
63284b6 - reverting unintended changes (Dec 6, 2024)
e39ec58 - adding unit tests (Dec 6, 2024)
fa4d8ab - WIP (Dec 6, 2024)
66bc533 - unit test (Dec 6, 2024)
26054c3 - WIP (Dec 6, 2024)
3d2b926 - WIP, seems to found another issue here (Dec 7, 2024)
bb659f8 - revert unsafe exception (Dec 7, 2024)
45bb785 - moving tests to uniform (Dec 7, 2024)
b744086 - Revert "moving tests to uniform" (Dec 7, 2024)
3a16c65 - do not use random for validation (Dec 7, 2024)
3b9c97f - fixing tests (Dec 7, 2024)
54176a7 - fixing tests and comments (Dec 7, 2024)
5b668d6 - skip the check for transpose scheduler to ensure no performance regre… (Dec 9, 2024)
3112ebd - allowing unmatched broadcast dimension (Dec 9, 2024)
db4cabc - CLANGFORMAT (Dec 9, 2024)
a60bdc6 - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 9, 2024)
55ddfc8 - TYPO (Dec 9, 2024)
0972ca2 - Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (Dec 9, 2024)
3e4cf86 - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 9, 2024)
325f5bb - lifting the broadcast exception, in case we change how expand is mode… (Dec 9, 2024)
8a92a15 - Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (Dec 9, 2024)
880d73e - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 9, 2024)
22a7561 - fixing false negative tests (Dec 9, 2024)
2bc97bf - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 10, 2024)
49767da - WIP addressing review comments (Dec 10, 2024)
2cb3372 - typo (Dec 10, 2024)
73c66f8 - refactor the logic per review comments/discussions (Dec 11, 2024)
dbd5995 - fixing signature (Dec 11, 2024)
b7f628f - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 11, 2024)
4ba5baa - updating tests, removing asserts (Dec 11, 2024)
1d021c7 - Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (Dec 11, 2024)
742f7f3 - removing checks that are not exposed by scheduler (Dec 11, 2024)
145d902 - renaming things (Dec 11, 2024)
19291c8 - err somehow I missed this one (Dec 11, 2024)
526e6b7 - updating tests (Dec 12, 2024)
0ecc1f6 - adding another test (Dec 12, 2024)
6abaa1d - test fixing (Dec 12, 2024)
e8a4ddd - fixing tests (Dec 12, 2024)
6ada657 - CLANGFORMAT (Dec 12, 2024)
d797df8 - removing python test since it's already covered in cpp test (Dec 12, 2024)
25362cd - oops, assert was placed in the wrong spot (Dec 12, 2024)
a129e72 - CLANGFORMAT (Dec 12, 2024)
b7f2efb - adding naoya's example (Dec 12, 2024)
307569f - I was padding the wrong dimension here (Dec 12, 2024)
af315c7 - made a small refactor to avoid regression (Dec 12, 2024)
7668d4e - Merge branch 'main' into pw_scheduler_reference_find_patch (Dec 12, 2024)
d46323c - committing something so I can trigger CI again (Dec 12, 2024)
c70f160 - Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (Dec 12, 2024)
6 changes: 3 additions & 3 deletions csrc/scheduler/pointwise.cpp
@@ -43,7 +43,7 @@ class DomainMap : public pointwise_utils::DomainMap {
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
int64_t n_dims = pointwise_utils::nRootDims(output_tv);
int64_t n_dims = pointwise_utils::nLogicalDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
@@ -529,11 +529,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {

int64_t max_dims = 0;
for (auto inp : input_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.
147 changes: 146 additions & 1 deletion csrc/scheduler/pointwise_utils.cpp
@@ -142,6 +142,137 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
return in_concrete_ids.empty();
}

// Note: ideally we would want to check that reference_tv contains all iter
// domains in target_tv, so that transformation applied on reference_tv can be
// propagated to target_tv. But we don't have an easy way to check that. Instead
// of that, this function checks that all source iter domains involved in
// transformations on target_tv are covered by reference_tv. Source iter domains
// of a TensorView are IDs that don't have a definition and are producers of
// any IDs on the logical domain of the given TensorView.
//
// ------
//
// e.g 0.
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// T198 [i0, b3(ex)*i1] = reshape(T192)
// output(T34)
// output(T198)
//
// If we consider taking T34 as reference_tv, then T198 is the target_tv. We
// can't replay T34's transform of merging all the dimensions onto T198, since
// b3(ex)*i1 can't be reversed. The check in this function would give us T34
// with sources i0, i1, while T198 has sources i0, b3, i1, where b3 isn't
// contained in T34. Hence we'll reject this reference_tv.
//
// ------
//
// e.g 1.
// T0 [i0, i1]
// T1 [i2, i0, i1]
// T2 [i0*i1] = reshape(T0)
// T3 [b3, i0, i1] = broadcast(T0)
// T4 [i2, i0, i1] = add(T1, T3)
// output(T2)
// output(T4)
//
// The example above should be able to pick T4 as reference_tv. T2's sources i0
// and i1 are both contained by the sources of T4, so this example can be scheduled
// as a single fusion.
bool DomainMap::areAllTargetIdsCoveredBy(
TensorView* target_tv,
TensorView* reference_tv) const {
auto get_source_iter_domains = [this](TensorView* tv) {
// traverse back to collect all disjoint set producer IDs for each ID in the
// logical domain of tv.
VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
all_producer_sets;
std::for_each(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
[&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
all_producer_sets.pushBack(
ca_map_.getAllDisjointSetProducers(all_producer_sets));

std::vector<IterDomain*> source_ids;
// filtering all producer IDs with empty definition to get source iter
// domains
std::for_each(
all_producer_sets.vector().begin(),
all_producer_sets.vector().end(),
[&source_ids,
this](const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>&
producer_set_ptr) {
IterDomain* producer_id = producer_set_ptr->front();
if (ca_map_.uniqueExactDefinitions(producer_id).empty()) {
source_ids.push_back(producer_id);
}
});
return source_ids;
};

// This contains all source iter domains covered by reference_tv, so
// it's safe for target_tv to have them.
std::unordered_set<IterDomain*> covered_source_ids;
for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) {
covered_source_ids.insert(source_id_ref);
}
// It's safe to have an unmapped broadcast IterDomain. There are quite a few
// tests expecting the pointwise scheduler to handle this pattern

Collaborator:
This seems to cover only broadcast IDs that have no transformation. For example, if a broadcast ID is split, the two split output IDs should be safe to have in the other output tensors, but since the original broadcast ID is not included in covered_source_ids, I suspect the reference would be considered invalid.

Author (jjsjann123):
TL;DR: good call, let me add a test to verify and change it.

Earlier I had this inside the loop below (line 246); I think that wouldn't have the problem.

for (IterDomain* id : get_source_producers(output_tv)) {
  if (id->isBroadcast()) {
    continue;
  }
...
}

I changed that to this version because I was concerned that might be too loose a check.
I was mostly worried that if we later decide to add expand as an ID op, we won't be able to identify expand ops. But I think I can add another check along with ca_map_.uniqueExactDefinitions for that. At least we have CI guarding that. I'll put a comment there so we'll remember how to fix it then.

Author (jjsjann123):
Logging an offline conversation: due to issue #1126, we don't see split/merge on broadcast here.
I'm keeping this code as-is. We added an assert that broadcast IDs on the logical domain cannot have a definition. 🤞 Let's see if that causes any CI issues.

Collaborator:
The assertion may not hold when a non-broadcast ID is sliced to a broadcast ID. See #3513 (comment)
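For illustration, here is a rough Python-frontend sketch of the pattern raised above: slicing a regular dimension down to extent 1 leaves a logical ID that has a Resize definition (which is why the assert just below also allows Resize). The shapes, names, and the assumption that such a sliced ID behaves like a broadcast are illustrative only and are not taken from this PR's tests.

import torch
from nvfuser import FusionDefinition, DataType

def sliced_to_extent_one(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[16, 32], contiguity=[True, True],
                          dtype=DataType.Float, is_cpu=False)
    # Slice the second dimension down to extent 1; the resulting logical ID is
    # produced by a Resize rather than being definition-free.
    T1 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[16, 1], strides=[1, 1])
    T2 = fd.ops.neg(T0)
    fd.add_output(T1)  # output whose logical domain contains the resized ID
    fd.add_output(T2)  # output mapped to the original input domains

with FusionDefinition() as fd:
    sliced_to_extent_one(fd)

outputs = fd.execute([torch.randn(16, 32, device="cuda")])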

for (IterDomain* id_out : target_tv->getLogicalDomain()) {
if (id_out->isBroadcast()) {
NVF_ERROR(
id_out->definition() == nullptr ||
id_out->definition()->isA<Resize>());

// Note that ideally we should also be able to handle merge/split on
// broadcast IDs, so we should really move this skip inside the loop below
// `get_source_iter_domains(target_tv)` and skip broadcast source IDs.
// Currently we have the issue that split/merge does not preserve expanded
// broadcasts, see issue: https://github.com/NVIDIA/Fuser/issues/1126
covered_source_ids.insert(id_out);
}
}
// Note: there are certain cases where it's safe to have dangling IDs,
// e.g
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// It's safe to propagate T34 to T192, since b3(ex) is not involved in the
// propagation. But this isn't generally safe. If the above example is changed
// to e.g
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T186 [i0, i4, i1] = ones({i0, i4, i1})
// T193 [i0, i4, i1] = add(T185, T186)
// It's unsafe to propagate from T34 to T193, see issue
// https://github.com/NVIDIA/Fuser/issues/3542

// Check all source iter domains involved in producing target_tv
for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) {
// NOTE: we use the concrete id instead. This allows us to link indirect
// broadcasts. In the example below:
//   T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
//   T3[i0, i9] = pad(T0[i0, b0])
// i9 in T3
//   -> source ID b0
//   -> concrete-mapped to i1
// So T3 is contained by T2. See test `PointwiseTest.DomainMapPad1`
auto concrete_source_id_out =
ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE);

Author (jjsjann123):
This is the change I made in order to avoid a regression in the added test, PointwiseTest.DomainMapPad1.

Author (jjsjann123):
Looks like this did work with our tests.

@naoyam Let me know what you think about this change.
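For reference, a minimal Python-frontend sketch of the pad-on-broadcast pattern described in the NOTE above; this is not the actual PointwiseTest.DomainMapPad1 C++ test, and the shapes, pad widths, and function name are made-up assumptions.

import torch
from nvfuser import FusionDefinition, DataType

def pad_broadcast_fusion(fd: FusionDefinition) -> None:
    # T0[i0] and T1[i0, i1]
    T0 = fd.define_tensor(shape=[16], contiguity=[True],
                          dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[16, 32], contiguity=[True, True],
                          dtype=DataType.Float, is_cpu=False)
    # T0b[i0, b0]: introduce the broadcast dimension
    T0b = fd.ops.broadcast_in_dim(T0, shape=[16, 1], broadcast_dims=[0])
    T2 = fd.ops.add(T0b, T1)       # T2[i0, i1]: b0 is resolved against i1
    T3 = fd.ops.pad(T0b, [0, 31])  # T3[i0, i9]: pad the broadcast dimension
    fd.add_output(T2)
    fd.add_output(T3)

with FusionDefinition() as fd:
    pad_broadcast_fusion(fd)

outputs = fd.execute([
    torch.randn(16, device="cuda"),
    torch.randn(16, 32, device="cuda"),
])

As the NOTE describes, with the PERMISSIVE concrete mapping the source b0 of i9 maps to i1 of T2, so T2 can still serve as the reference.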

// If we find any source_id_out that's not covered, our propagation could
// fail, since transformations involving this iter domain can't be
// resolved.
if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) {
return false;
}
}
return true;
}

// Reference domains must exactly match with the input domains. See
// also PR #661
IterDomain* DomainMap::getMappedInputConcreteID(
@@ -233,7 +364,7 @@ IterDomain* DomainMap::anyMapped(
}

// Determine if output TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
// The reference tensor must map to all the iterDomains in each input and output.
bool DomainMap::isValidReference(TensorView* tv) const {
for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
@@ -245,6 +376,20 @@ bool DomainMap::isValidReference(TensorView* tv) const {
return false;
}
}
// The check on outputs is optional; the transpose scheduler might propose a
// secondary reference that only applies to a subset of IO tensors. Ideally we
// should have a more robust check and consider the IO groups instead of
// blindly skipping outputs.
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
// no need to check for self.
if (output_tv == tv) {
continue;
}
if (!areAllTargetIdsCoveredBy(output_tv, tv)) {
return false;
}
}
return true;
}

11 changes: 9 additions & 2 deletions csrc/scheduler/pointwise_utils.h
@@ -28,14 +28,21 @@ class DomainMap {
}

// Determine if a TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
// The reference tensor must map to all the iterDomains in each input and
// output.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

// Determine if all source IterDomains in target_tv are contained by
// reference_tv; this ensures transformations from reference_tv can be
// propagated to target_tv.
bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv)
const;

virtual IterDomain* getMappedInputConcreteID(
const std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
@@ -63,7 +70,7 @@

// Returns the number of non-reduction/non-broadcast/non-device dims in the logical
// domain
inline int64_t nRootDims(const TensorView* tv) {
inline int64_t nLogicalDims(const TensorView* tv) {
auto logical_dom = tv->getLogicalDomain();
int64_t tv_n_dims = 0;
for (auto dim : logical_dom) {
18 changes: 13 additions & 5 deletions csrc/scheduler/transpose.cpp
@@ -170,8 +170,16 @@ class DomainMap : public pointwise_utils::DomainMap {
TensorView* result = nullptr;
int64_t max_dims = -1;
for (auto tv : group) {
// Since the transpose scheduler has a different set of references, we skip the ID

Collaborator:
I'm not sure why. The DomainMap of the transpose scheduler has findReferenceFor that tries to find a reference for each group, so there are multiple references. Each of the reference tensors is supposed to pass isValidReference, so I think each reference should be mapped with all the input tensors. Why isn't it the case with fusion outputs?

Author (jjsjann123):
> Each of the reference tensors is supposed to pass isValidReference, so I think each reference should be mapped with all the input tensors.

What I was trying to say is that this is what the transpose scheduler is doing today, and I don't think that's the right thing to do either.

If you look at how the scheduler propagates the parallelization strategy for group 2 and for group 1, the propagation is limited to the other TVs in its own group. It's not supposed to be propagated to all inputs and outputs.

I think we took a shortcut in the first place by reusing pointwise_utils::isValidReference when we check the legitimacy of each reference.

TL;DR: if we apply the same check to all outputs, we are going to fail some CI tests that check that the transpose scheduler takes the fusion.
I think we need a follow-up PR that adds a transpose_utils::isValidReference instead of reusing pointwise_utils::isValidReference.

Collaborator:
I'm not sure why transposing would create disconnected tensors. Can you show an example?

Author (jjsjann123), Dec 10, 2024:
In this example

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id38(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1024, 128], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[1024, 128], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[0, 1])

    T2 = fd.ops.broadcast_in_dim(T1, shape=[1, 1024, 128], broadcast_dims=[1, 2])

    T3 = fd.ops.broadcast_in_dim(T0, shape=[1, 1024, 128], broadcast_dims=[1, 2])
    T4 = fd.ops.broadcast_in_dim(T3, shape=[32, 1024, 128], broadcast_dims=[0, 1, 2])

    T5 = fd.ops.add(T4, T2)
    T6 = fd.ops.reshape(T5, new_shape=[32*1024, 128])
    fd.add_output(T6)

with FusionDefinition() as fd:
    nvfuser_fusion_id38(fd)

inputs = [
    torch.testing.make_tensor((1024, 128), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1024*128), dtype=torch.float32, device='cuda:0').as_strided((1024, 128), (1, 1024)),
]
o = fd.execute(inputs)

def foo(inputs):
    T2 = inputs[1].unsqueeze(0)
    T3 = inputs[0].expand((32, 1024, 128))
    T4 = T2 + T3
    return T4.reshape(32*1024, 128)

o_ref = foo(inputs)

assert o[0].allclose(o_ref)
g{(transpose)
group id: 0
inputs:
  T0_g_float[iS0{1024}, iS1{128}] float
  T1_g_float[iS2{1024}, iS3{128}] float
outputs:
  T9_g_float[iS30{32768}rf, iS27{128}] float


T4_l_float[bS10{1}, iS11{1024}, iS12{128}]
   = broadcast( T0_g_float[iS0{1024}, iS1{128}] )
(2)
T7_l_float[bS19{1 ex 32}, iS20{1024}, iS21{128}] = expand( T4_l_float[bS10{1}, iS11{1024}, iS12{128}], {32, 1024, 128} )
(29)
T2_l_float[bS4{1}, iS5{1024}, iS6{128}]
   = broadcast( T1_g_float[iS2{1024}, iS3{128}] )
(0)
T8_l_float[bS22{1 ex 32}, iS23{1024}, iS24{128}]
   = T7_l_float[bS19{1 ex 32}, iS20{1024}, iS21{128}]
   + T2_l_float[bS4{1}, iS5{1024}, iS6{128}];
(27)
T9_g_float[iS30{32768}rf, iS27{128}] = view( T8_l_float[bS22{1 ex 32}, iS23{1024}, iS24{128}] )
(8)
}

} //Segmented_Fusion


===== Transpose Stats ========
inputs: T0_g_float[iS0{1024}, iS1{128}], T1_g_float[iS2{1024}, iS3{128}]
outputs: T9_g_float[iS30{32768}rf, iS27{128}]
shape: 32768 128
num_elems: 4194304
n_io_tensors: 3
max_io_dtype_size: 4
group 1: T9_g_float[iS30{32768}rf, iS27{128}], T0_g_float[iS0{1024}, iS1{128}]
reference1: T9_g_float[iS30{32768}rf, iS27{128}]
inner_most_id1 position: 1 (in reference 1)
group 2: T1_g_float[iS2{1024}, iS3{128}]
reference2: T1_g_float[iS2{1024}, iS3{128}]
inner_most_id2 position: 0 (in reference 1)

Here reference2, T1_g_float[iS2{1024}, iS3{128}], doesn't cover all the source IDs of output T9, but the transpose scheduler only propagates/parallelizes reference2 within its own group, so it's still a valid reference.

If we don't skip the check here, we'll end up not using the pointwise scheduler and lose part of the vectorization.

Collaborator:
I see. It seems this would miss a case where a reference tensor indeed has an output like T9 in the same group. If so, we should not just always skip the check even for the transpose scheduler, should we?

Author (jjsjann123):
Indeed. It's not safe to blindly skip it. We want the check to be consistent with how the transformation is propagated.

I have a comment in the transpose scheduler where we pass false to skip the check, noting that it's not correct.
The transpose scheduler is a bit strange, since it also does global scheduling at Step 2, but only for a subset of IDs on reference1. So the right check to apply is a bit blurry.

I feel I need to look at it further to figure out the right check for the transpose scheduler, and I'm not ready to commit to that in this PR.

Collaborator:
I see. I thought you were saying this would be less than ideal but still acceptable, but it looks like this is a real bug. I agree that fixing the issue doesn't need to be done in this PR, but I think we should take the safer option, i.e., not skip the check even with the transpose scheduler. Failing to use a vectorized transpose scheduler is a bummer, but it is better than failing to run.

Author (jjsjann123):
Sounds fair, I'll go with that route and temporarily disable some CI tests. I'll make sure I open an issue to track the disabled tests and restore them when we fix the validation.

Author (jjsjann123):
Opened issue #3570. I'll have it in the code comment as well.

// coverage check of the reference on outputs of the fusion. Note that
// this is not ideal; we would instead want the reference tensor
// checked against all its target IO tensors.
// TODO: open an issue for this one. The transpose scheduler is not supposed
// to reuse pointwise_utils::DomainMap::isValidReference. This function is
// too restrictive and doesn't align well with the scheme of the transpose
// scheduler.
if (isValidReference(tv)) {
int64_t dims = (int64_t)pointwise_utils::nRootDims(tv);
int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv);
if (dims > max_dims) {
result = tv;
max_dims = dims;
@@ -990,12 +998,12 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
<< "max_io_dtype_size: " << max_io_dtype_size << "\n"
<< "group 1: " << ir_utils::toString(grouped_inputs_outputs[0])
<< "\n"
<< "reference1: " << reference1 << "\n"
<< "reference1: " << reference1->toString() << "\n"
<< "inner_most_id1 position: " << inner_most_pos1_in_ref1
<< " (in reference 1)\n"
<< "group 2: " << ir_utils::toString(grouped_inputs_outputs[1])
<< "\n"
<< "reference2: " << reference2 << "\n"
<< "reference2: " << reference2->toString() << "\n"
<< "inner_most_id2 position: " << inner_most_pos2_in_ref1
<< " (in reference 1)" << std::endl;
if (hasSmallTransposeDimensions(tparams)) {
@@ -1045,11 +1053,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {

int64_t max_dims = 0;
for (auto inp : input_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.