
Handle empty tensors during definition of cat #3313

Draft: wants to merge 11 commits into main


jacobhinkle (Collaborator)

In #3292, a PadOp is created for a pad of two inputs, one of which has zero size in the cat dimension. This caused an error when we replaced the empty input with a FullOp output; that error is addressed in the remove_empty pass by PR #3301. This PR additionally simplifies the fusion definition when concrete sizes are available: the arguments to pad and cat are inspected for empty dimensions, and if any are found, we avoid using PadOp or CatOp unless necessary. The included CatOfEmpty test demonstrates the behavior:

- A cat of tensors where some non-cat dimension is zero results in a FullOp of the proper output size.
- A cat of tensors that are all empty in the cat dimension results in a LoadStoreOp of the first input.
- A cat of tensors, some empty and some not, ignores the empty inputs. In this case that leaves only a set op.
- A cat of tensors, some empty and some not, ignores the empty inputs. In this case that leaves a CatOp with two inputs.
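The four cases above can be sketched as a small decision procedure. This is a minimal illustration under stated assumptions, not the PR's implementation: it works on concrete integer shapes instead of nvFuser Val* extents, and Shape, CatLowering, CatPlan and planCat are hypothetical names invented for this sketch. It checks non-cat dimensions first (following the review thread's suggestion to fire on any input with a zero non-cat extent), then drops inputs that are empty along the cat dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for illustration; these are NOT nvFuser types.
using Shape = std::vector<int64_t>;  // concrete extents of one cat input

enum class CatLowering { Full, Set, Cat };

struct CatPlan {
  CatLowering kind;
  std::vector<std::size_t> kept_inputs;  // inputs the emitted op reads
};

// Decide what a cat of inputs with concrete shapes could be simplified to.
CatPlan planCat(const std::vector<Shape>& inputs, std::size_t cat_dim) {
  assert(!inputs.empty());
  // A zero extent in any non-cat dimension makes the output empty, so a
  // FullOp of the output shape suffices and no input needs to be read.
  for (const Shape& s : inputs) {
    assert(cat_dim < s.size());
    for (std::size_t d = 0; d < s.size(); ++d) {
      if (d != cat_dim && s[d] == 0) {
        return {CatLowering::Full, {}};
      }
    }
  }
  // Inputs that are empty along cat_dim contribute nothing; drop them.
  std::vector<std::size_t> non_empty;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i][cat_dim] != 0) {
      non_empty.push_back(i);
    }
  }
  if (non_empty.empty()) {
    // Every input is empty along cat_dim: a set (LoadStoreOp) of input 0.
    return {CatLowering::Set, {0}};
  }
  if (non_empty.size() == 1) {
    // Only one input is left, so the cat degenerates to a set of it.
    return {CatLowering::Set, non_empty};
  }
  // Two or more non-empty inputs: a CatOp over just those inputs.
  return {CatLowering::Cat, non_empty};
}
```

For example, planCat({{0, 4}, {3, 4}}, 0) keeps only input 1 and returns a Set plan, matching the third case; with inputs {2, 4}, {0, 4}, {3, 4} it returns a Cat plan over inputs 0 and 2, matching the fourth.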

jacobhinkle requested a review from naoyam on October 30, 2024 16:14
jacobhinkle changed the title from "Empty cat at definition" to "Handle empty tensors during definition of cat" on Oct 30, 2024
jacobhinkle (Collaborator, Author)

!build --diff

@@ -479,6 +481,21 @@ bool hasSimilarDtype(DataType base, DataType dt) {
NVF_THROW("Unrecognized base dtype.");
}

Val* zeroForDtype(DataType dtype) {
Collaborator

Can we use Fusion::zeroVal(DataType)?

if (dim == cat_dim) {
shape[dim] = SimplifyingIrBuilder::addExpr(shape[dim], extent);
} else if (shape[dim] == nullptr) {
shape[dim] = extent;
Collaborator

I'm not sure what the correct behavior should be here. Can we just use the first extent for a non-cat dimension? If the first cat input has size zero in that dimension, the output of the cat would also have size zero in that dimension. Would that be expected?

Collaborator Author

In that case it should be caught by the inputs.empty() condition at line 678, since all of the inputs should be empty. Actually, I just realized that this assumes that if one of the inputs has constant size 0 in a dimension, all the other inputs will also have constant size 0 in that dimension. If some are symbolic, this definition should prove that they're equal via exact mapping, but they won't be detected at line 660. I think what I should do instead is fire this condition if any input has a zero non-cat dimension.

Collaborator

> if one of the inputs has constant size 0 in one dimension that all the other inputs will have constant size zero in the same dimension.

Yeah, this was what I was thinking about. I'm not sure if that's actually allowed. Is it?

jacobhinkle (Collaborator, Author), Oct 30, 2024

I think we could have something like this:

tv0[ 2, 0 ]
tv1[ 3, i0 ]
tv2 = cat({tv0, tv1}, /*axis=*/0)

In this case we would normally exact-map tv0->axis(1) with tv1->axis(1), so i0 must be 0 or else we'll hit an error in ExpressionEvaluator::propagateBoundValuesThroughExactMaps(). But now, since we'll just translate this to full({5, 0}), I don't think there's any such constraint, so it might be legal to pass in a tv1 with shape (3, 2) without hitting an error.
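For concrete extents, the shape logic in the diff hunk (sum extents along the cat dimension, otherwise take the first extent) combined with the rule proposed in this thread (any input with a zero non-cat extent empties that output dimension) can be sketched as follows. This is a hypothetical illustration, not the PR's code: catOutputShape and Shape are invented names, std::optional stands in for the nullptr checks on Val*, and the real code builds symbolic extents with SimplifyingIrBuilder rather than integers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;  // hypothetical: concrete extents

// Output shape of cat(inputs, cat_dim) with concrete extents. Extents are
// summed along cat_dim; a non-cat dimension takes the first input's extent,
// except that a zero extent from ANY input wins (assumed rule from this
// thread: the cat would then be translated to a full() of an empty shape).
Shape catOutputShape(const std::vector<Shape>& inputs, std::size_t cat_dim) {
  assert(!inputs.empty());
  const std::size_t rank = inputs.front().size();
  assert(cat_dim < rank);
  std::vector<std::optional<int64_t>> shape(rank);  // nullopt ~ nullptr
  for (const Shape& s : inputs) {
    assert(s.size() == rank);
    for (std::size_t dim = 0; dim < rank; ++dim) {
      const int64_t extent = s[dim];
      if (dim == cat_dim) {
        shape[dim] = shape[dim].value_or(0) + extent;
      } else if (!shape[dim].has_value() || extent == 0) {
        // First extent seen, or a zero that empties this dimension.
        shape[dim] = extent;
      }
    }
  }
  Shape out;
  out.reserve(rank);
  for (const auto& e : shape) {
    out.push_back(*e);  // every entry was set by the first input
  }
  return out;
}
```

Applied to the tv0/tv1 example above with i0 = 2, catOutputShape({{2, 0}, {3, 2}}, 0) yields {5, 0}, the shape of the full() the cat would be translated to, and swapping the input order gives the same result.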

Collaborator

OK, makes sense.

jacobhinkle (Collaborator, Author)

The failure in jit_binary_tests_17_A100_1/3 is real. I have a fix but am waiting to push until the codediff is done.

Base automatically changed from remove_empty_deep_extent_eval to main October 30, 2024 20:49
jacobhinkle (Collaborator, Author)

!test --diff
