
Cannot define an MmaOp where batch dimension is Broadcast #2273

Closed
jacobhinkle opened this issue May 20, 2024 · 3 comments · Fixed by #3322

Comments

jacobhinkle (Collaborator) commented May 20, 2024

The following test fails when trying to create an MmaOp:

// Single batch dimension which is broadcast
TEST_F(GPUTTensorCoreTest, FusionAmpereBroadcastBatchMatmul_CUDA) {
  auto layout = MmaLayout::TN;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  auto tv2 = fusedMultiplySum(
      broadcast(tv0, {true, false, false, false}),
      broadcast(tv1, {true, false, false, false}),
      {-1});
/*
C++ exception with description "details.bcasts.empty() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir/utils.cpp":1268, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. MmaOp output: has broadcast domains.
Exception raised from operator() at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1268 (most recent call first):
*/

  fusion.addOutput(tv2);
}
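For reference, a minimal NumPy sketch of the computation the test is asking for: a TN matmul where a size-1 batch dimension is broadcast onto both operands before contracting over K. The shapes here are illustrative stand-ins for whatever matmulAtInputShape3DTuring returns; only the broadcast-and-contract structure mirrors the test.

```python
import numpy as np

# Illustrative dimensions (stand-ins, not the actual Turing test shapes).
M, N, K = 8, 16, 4
a = np.random.rand(M, K).astype(np.float16)  # TN layout: A is [M, K]
b = np.random.rand(N, K).astype(np.float16)  # TN layout: B is [N, K]

# Prepend a broadcast (size-1) batch dimension, as the test does with
# broadcast(tv, {true, false, false, false}) after canonicalizing to BMNK.
a_b = a[np.newaxis, :, :]  # [1, M, K]
b_b = b[np.newaxis, :, :]  # [1, N, K]

# fusedMultiplySum over the last (K) axis is an einsum-style contraction,
# producing a [B, M, N] output with B == 1.
out = np.einsum("bmk,bnk->bmn",
                a_b.astype(np.float32),
                b_b.astype(np.float32))
assert out.shape == (1, M, N)
```

The broadcast batch dimension carries no data, so the result is just the plain [M, N] matmul with an extra leading axis; there is no mathematical obstacle to supporting it in the MmaOp constructor.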

This caused TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2) to fail, which is why that test currently checks that we cannot translate that case. However, I think that case should be supported, and we should instead fix the MmaOp constructor so it does not balk at broadcast batch dimensions.

jacobhinkle (Collaborator, Author) commented

Relevant comment from the PR introducing this check: https://github.com/NVIDIA/Fuser/pull/131/files#r1164511926. It seems we do plan to support broadcast batch dims, but getMmaOpDetails and getInputLayout don't currently support them. This might change with #2272, since that uses IdModel and the allocation domain to determine layout instead of pattern matching.

kevinstephano (Collaborator) commented

Is this still an issue?

jacobhinkle (Collaborator, Author) commented

This does still fail, now in getInputLayout with "Unsupported input layout". However, thanks for bringing this to my attention: if I simply comment out the following call, the test succeeds and we can schedule the fusion.

Fuser/csrc/ir/utils.cpp

Lines 1410 to 1415 in 621e146

details.input_layout = getInputLayout(
    in_a_details,
    in_b_details,
    details.m_axes,
    details.n_axes,
    details.k_axes);

Since details.input_layout is not used anywhere, we can remove it. I will do that in a new PR that will close this issue.

jacobhinkle added a commit that referenced this issue Oct 31, 2024
There is no reason for us to check the Mma layout anymore when defining an MmaOp, since that is all handled in the scheduler now. I also added a test where a new batch dimension is broadcast before defining the MmaOp.

Fixes #2273.