Remove reduction axis when K=1 (#2737)
When K=1, we would map a reduction axis to a broadcast axis in
`LinearOp`/`MatmulOp`, which is not allowed in nvFuser. This was first
noticed in Issue #2532 but exists generally.

PR #2534 is included in this PR for checking `LinearOp`.
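As an illustration (not part of this commit), the K=1 case can be reproduced with plain PyTorch using the shapes from Issue #2532; when the contracted dimension is 1, the output needs no reduction axis at all:

import torch

# Batched matmul where the contracted dimension K is 1 (shapes as in Issue #2532).
a = torch.randn(1025, 256, 1)   # {iB, iM, bK}
b = torch.randn(1025, 1, 1024)  # {iB, bK, iN}

out = torch.matmul(a, b)        # {iB, iM, iN}
print(out.shape)                # torch.Size([1025, 256, 1024])

# With K == 1 there is nothing left to sum over, so the fused output's logical
# domain should not carry a reduction axis rK mapped to the broadcast K; that
# mapping is exactly what this PR removes.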
Priya2698 authored Aug 2, 2024
1 parent 3c8a3c4 commit 4a2987e
Showing 8 changed files with 132 additions and 51 deletions.
1 change: 1 addition & 0 deletions csrc/executor.cpp
@@ -1861,6 +1861,7 @@ std::vector<at::Tensor> FusionExecutor::evaluateFusionOutputs(
for (const auto& out_val : fusion()->outputs()) {
auto out_tensor =
expr_eval.evaluate(out_val->as<TensorView>()).as<at::Tensor>();
expr_eval.bind(out_val, out_tensor);
outputs.emplace_back(out_tensor);
}
}
4 changes: 3 additions & 1 deletion csrc/logical_domain_map.cpp
@@ -214,14 +214,16 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseLogicalDomainMap::map(
NVF_ERROR(false, "Producer did not match any LinearOp input.")
}

bool k_bcast = op->inA()->as<TensorView>()->axis(-1)->isBroadcast();
// LinearOp:
// inputs (0) = {*, in_features}
// weight (1) = {out_features, in_features} / {in_features}
// bias (2) = {out_features} / {}
// output = {*, out_features} / {*}

const std::vector<IterDomain*>& aligned_producer_ids =
ops::mapLinearOpIterDomains(producer_logical, input_position, out_size);
ops::mapLinearOpIterDomains(
producer_logical, input_position, out_size, k_bcast);
pairwiseMapAllIds(aligned_producer_ids, consumer_root);
return dom_map;
}
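The shape conventions in the comment above can be sanity-checked against PyTorch's `torch.nn.functional.linear`, whose shape rules the `LinearOp` comment mirrors; a small illustrative sketch (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8)    # input  {*, in_features}
w2d = torch.randn(16, 8)    # weight {out_features, in_features}
w1d = torch.randn(8)        # weight {in_features}
bias = torch.randn(16)      # bias   {out_features}

print(F.linear(x, w2d, bias).shape)  # torch.Size([4, 3, 16]) -> {*, out_features}
print(F.linear(x, w1d).shape)        # torch.Size([4, 3])     -> {*}

# nvFuser's output additionally carries a reduction axis rK for in_features,
# unless K is a broadcast dimension -- the case this PR addresses.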
63 changes: 43 additions & 20 deletions csrc/ops/composite.cpp
@@ -64,32 +64,43 @@ static TensorView* newForLinear(
auto input_domain = TensorDomain::noReductions(input->getLogicalDomain());
auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain());

// Output has a reduction axis rK if K is not bcast
NVF_CHECK(
input_domain.back()->isBroadcast() == weight_domain.back()->isBroadcast(),
"K should be broadcast in both inputs and weights, or neither.");
bool k_bcast = input_domain.back()->isBroadcast();
size_t red_dims = k_bcast ? 0 : 1;

// Linear: a = {*, in_features}, b = {out_features, in_features} /
// {in_features}.The linear output is {*, (out_features), rK}.
// The first out_size -2 dimensions are as the first input, followed by
// out_features (if present) and an additional reduction axis K.
auto ndims_out = input_domain.size() + weight_domain.size() - 1;
// {in_features}.The linear output is {*, (out_features), rK?}.
// Reduction K is present only when K is not bcast.
auto ndims_out =
(input_domain.size() - 1) + (weight_domain.size() - 1) + red_dims;

const std::vector<IterDomain*>& mapping_a =
ops::mapLinearOpIterDomains(input_domain, 0, ndims_out);
ops::mapLinearOpIterDomains(input_domain, 0, ndims_out, k_bcast);
const std::vector<IterDomain*>& mapping_b =
ops::mapLinearOpIterDomains(weight_domain, 1, ndims_out);
ops::mapLinearOpIterDomains(weight_domain, 1, ndims_out, k_bcast);
std::vector<IterDomain*> mapping_bias(ndims_out, nullptr);
if (bias != nullptr) {
auto bias_domain = TensorDomain::noReductions(bias->getLogicalDomain());
mapping_bias = ops::mapLinearOpIterDomains(bias_domain, 2, ndims_out);
mapping_bias =
ops::mapLinearOpIterDomains(bias_domain, 2, ndims_out, k_bcast);
}

std::vector<IterDomain*> out_domain(ndims_out, nullptr);

for (auto idx : c10::irange(ndims_out - 1)) {
for (auto idx : c10::irange(ndims_out - red_dims)) {
out_domain[idx] = ops::newOutputIterDomain(
{mapping_a.at(idx), mapping_b.at(idx), mapping_bias.at(idx)});
}
// Specify the iterdomain for K as reduction
out_domain[ndims_out - 1] = ops::newOutputIterDomain(
{mapping_a.back(), mapping_b.back()},
/*force_iter_type=*/IterType::Reduction);

if (!k_bcast) {
// Specify the iterdomain for K as reduction
out_domain[ndims_out - 1] = ops::newOutputIterDomain(
{mapping_a.back(), mapping_b.back()},
/*force_iter_type=*/IterType::Reduction);
}

TensorDomain* td = IrBuilder::create<TensorDomain>(
out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
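A pure-Python restatement of the output-rank rule used in `newForLinear` above; the helper name `linear_out_ndims` is made up for illustration and is not nvFuser code:

def linear_out_ndims(input_ndims: int, weight_ndims: int, k_bcast: bool) -> int:
    # {*, in_features} x {out_features, in_features} (or {in_features})
    # -> {*, (out_features), rK?}; rK is present only when K is not broadcast.
    red_dims = 0 if k_bcast else 1
    return (input_ndims - 1) + (weight_ndims - 1) + red_dims

assert linear_out_ndims(3, 2, k_bcast=False) == 4  # {iB, iM, iN, rK}
assert linear_out_ndims(3, 2, k_bcast=True) == 3   # {iB, iM, iN}
assert linear_out_ndims(2, 1, k_bcast=False) == 2  # {iM, rK}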
@@ -335,14 +346,25 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) {
auto ndims_a = orig_domain_a.size();
auto ndims_b = orig_domain_b.size();

auto b_kpos = orig_domain_b.size() > 1 ? ndims_b - 2 : ndims_b - 1;
NVF_CHECK(
orig_domain_a.back()->isBroadcast() ==
orig_domain_b.at(b_kpos)->isBroadcast(),
"K should be broadcast in both A and B, or neither.");

// Output has a reduction axis rK if K is not bcast
bool k_bcast = orig_domain_a.back()->isBroadcast();
size_t red_dims = k_bcast ? 0 : 1;

// Matmul output size is same as the higher dimensional input size if both A/B
// > 1D, but with 1 additional IterType::Reduction axis rK.
auto ndims_out = std::max(ndims_a, ndims_b) + 1;
// > 1D, but with 1 additional IterType::Reduction axis rK if K is not
// broadcast.
auto ndims_out = std::max(ndims_a, ndims_b) + red_dims;
if (std::min(ndims_a, ndims_b) == 1) {
// If one of the inputs is 1D, the output size is the same as the higher
// dimensional input size, since we will include a Reduction axis for K in
// the output. For example: [iM, iK] x [iK] -> [iM, rK]
ndims_out = std::max(ndims_a, ndims_b);
ndims_out = std::max(ndims_a, ndims_b) - 1 + red_dims;
}

std::vector<IterDomain*> out_domain(ndims_out, nullptr);
@@ -352,14 +374,15 @@
const std::vector<IterDomain*>& mapping_b =
ops::mapMatmulOpIterDomains(orig_domain_b, 1, ndims_out);

for (auto idx : c10::irange(ndims_out - 1)) {
for (auto idx : c10::irange(ndims_out - red_dims)) {
out_domain[idx] =
ops::newOutputIterDomain({mapping_a.at(idx), mapping_b.at(idx)});
}

out_domain[ndims_out - 1] = ops::newOutputIterDomain(
{mapping_a.back(), mapping_b.back()},
/*force_iter_type=*/IterType::Reduction);
if (!k_bcast) {
out_domain[ndims_out - 1] = ops::newOutputIterDomain(
{mapping_a.back(), mapping_b.back()},
/*force_iter_type=*/IterType::Reduction);
}

TensorDomain* td = IrBuilder::create<TensorDomain>(
out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
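The corresponding rank rule for `newForMatmul`, again as an illustrative Python model (the helper name `matmul_out_ndims` is hypothetical):

def matmul_out_ndims(ndims_a: int, ndims_b: int, k_bcast: bool) -> int:
    red_dims = 0 if k_bcast else 1
    if min(ndims_a, ndims_b) == 1:
        # e.g. [iM, iK] x [iK] -> [iM, rK?]
        return max(ndims_a, ndims_b) - 1 + red_dims
    return max(ndims_a, ndims_b) + red_dims

assert matmul_out_ndims(2, 2, k_bcast=False) == 3  # [iM, iN, rK]
assert matmul_out_ndims(2, 2, k_bcast=True) == 2   # [iM, iN]
assert matmul_out_ndims(2, 1, k_bcast=False) == 2  # [iM, rK]
assert matmul_out_ndims(3, 2, k_bcast=True) == 3   # [iB, iM, iN]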
51 changes: 33 additions & 18 deletions csrc/ops/utils.cpp
@@ -188,20 +188,26 @@ std::vector<IterDomain*> mapMatmulOpIterDomains(
std::vector<IterDomain*> mapping(out_size, nullptr);
auto inp_size = (int64_t)input_domain.size();

if (inp_size == 1) {
// Only reduction axis {K}
mapping[out_size - 1] = input_domain[0];
return mapping;
}

// Input A to matmul: {*, M, K}
// Input B to matmul: {*, K, N}
auto kpos = input_position == 0 ? inp_size - 1 : inp_size - 2;
auto kpos = inp_size - 1;
if (input_position == 1 && inp_size > 1) {
kpos = inp_size - 2;
}
bool k_bcast = input_domain.at(kpos)->isBroadcast();
int64_t red_dims = k_bcast ? 0 : 1;

// Last position is a reduction dimension mapping to K
mapping[out_size - 1] = input_domain.at(kpos);
// Last position is a reduction dimension mapping to K if K is not broadcast.
if (!k_bcast) {
mapping[out_size - 1] = input_domain.at(kpos);
}

for (auto out_idx = (int64_t)out_size - 2, inp_idx = inp_size - 1;
if (inp_size == 1) {
return mapping;
}

for (auto out_idx = (int64_t)out_size - 1 - red_dims, inp_idx = inp_size - 1;
inp_idx >= 0;
inp_idx--) {
if (inp_idx != kpos) {
@@ -211,7 +217,7 @@ std::vector<IterDomain*> mapMatmulOpIterDomains(
// Consider [iM, iK] x [iK]: [iM, rK]. Since out_size < inp_size,
// input A and output are not right-aligned. In this case, the output index
// pointer should not be moved when the reduction axis is encountered.
else if (inp_size <= (int64_t)out_size - 1) {
else if (inp_size <= (int64_t)out_size - red_dims) {
out_idx--;
}
}
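A rough Python model of the alignment rule in `mapMatmulOpIterDomains`, using dimension labels instead of IterDomains (labels starting with "b" stand for broadcast); this is a sketch of the rule as described above, not the actual implementation:

def map_matmul_ids(dims, input_position, out_size):
    mapping = [None] * out_size
    inp_size = len(dims)

    # A is {*, M, K}; B is {*, K, N} (or 1D {K}).
    kpos = inp_size - 1
    if input_position == 1 and inp_size > 1:
        kpos = inp_size - 2
    k_bcast = dims[kpos].startswith("b")
    red_dims = 0 if k_bcast else 1

    if not k_bcast:
        mapping[out_size - 1] = dims[kpos]  # reduction axis rK
    if inp_size == 1:
        return mapping

    # Right-align the remaining dims in front of the (optional) reduction axis.
    out_idx = out_size - 1 - red_dims
    for inp_idx in range(inp_size - 1, -1, -1):
        if inp_idx != kpos:
            mapping[out_idx] = dims[inp_idx]
            out_idx -= 1
        elif inp_size <= out_size - red_dims:
            out_idx -= 1
    return mapping

# [iB, iM, iK] x [iK, iN] -> [iB, iM, iN, rK]
assert map_matmul_ids(["iB", "iM", "iK"], 0, 4) == ["iB", "iM", None, "iK"]
assert map_matmul_ids(["iK", "iN"], 1, 4) == [None, None, "iN", "iK"]
# Broadcast K: [iM, bK] x [bK, iN] -> [iM, iN], no reduction axis.
assert map_matmul_ids(["iM", "bK"], 0, 2) == ["iM", None]
assert map_matmul_ids(["bK", "iN"], 1, 2) == [None, "iN"]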
@@ -222,7 +228,8 @@ std::vector<IterDomain*> mapMatmulOpIterDomains(
std::vector<IterDomain*> mapLinearOpIterDomains(
const std::vector<IterDomain*>& input_domain,
int64_t input_position,
size_t out_size) {
size_t out_size,
bool k_bcast) {
std::vector<IterDomain*> mapping(out_size, nullptr);
auto inp_size = input_domain.size();

Expand All @@ -231,29 +238,37 @@ std::vector<IterDomain*> mapLinearOpIterDomains(
"Input position must be 0, 1, or 2. Found ",
input_position);

auto red_dims = k_bcast ? 0 : 1;

// Input A: {*, M, K}
// Input B: {*, N, K} / {K}
// Bias: {N} / {}

// Map K if K is not bcast
if (input_position != 2 && !k_bcast) {
mapping[out_size - 1] = input_domain.back();
}

switch (input_position) {
case 0: {
// Linear output is same as input for all but the last dimension
// Linear output is same as input for inp_size - 1 dimensions.
// K is already mapped above if not broadcast.
for (auto inx : c10::irange(inp_size - 1)) {
mapping[inx] = input_domain[inx];
}
mapping[out_size - 1] = input_domain.back();
break;
}
case 1: {
for (auto inx : c10::irange(inp_size)) {
// Map N, K to the last two positions of the output.
mapping[out_size - 1 - inx] = input_domain[inp_size - 1 - inx];
// Map N / out_features if present
if (inp_size > 1) {
mapping[out_size - 1 - red_dims] = input_domain.front();
}
break;
}
case 2: {
if (inp_size > 0) {
// Bias is 1D tensor of shape {out_features}
mapping[out_size - 2] = input_domain[0];
mapping[out_size - 1 - red_dims] = input_domain.front();
}
break;
}
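And the analogous sketch for `mapLinearOpIterDomains` with the new `k_bcast` parameter (again labels instead of IterDomains; the helper is illustrative only):

def map_linear_ids(dims, input_position, out_size, k_bcast):
    mapping = [None] * out_size
    inp_size = len(dims)
    red_dims = 0 if k_bcast else 1

    # K (the innermost dim of input and weight) maps to the last output
    # position, but only when it is not a broadcast.
    if input_position != 2 and not k_bcast:
        mapping[out_size - 1] = dims[-1]

    if input_position == 0:    # input {*, M, K}
        for i in range(inp_size - 1):
            mapping[i] = dims[i]
    elif input_position == 1:  # weight {N, K} or {K}
        if inp_size > 1:
            mapping[out_size - 1 - red_dims] = dims[0]
    else:                      # bias {N} or {}
        if inp_size > 0:
            mapping[out_size - 1 - red_dims] = dims[0]
    return mapping

# {iB, iM, iK} x {iN, iK} + {iN} -> {iB, iM, iN, rK}
assert map_linear_ids(["iB", "iM", "iK"], 0, 4, False) == ["iB", "iM", None, "iK"]
assert map_linear_ids(["iN", "iK"], 1, 4, False) == [None, None, "iN", "iK"]
assert map_linear_ids(["iN"], 2, 4, False) == [None, None, "iN", None]
# Broadcast K: no reduction axis in the output.
assert map_linear_ids(["iM", "bK"], 0, 2, True) == ["iM", None]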
3 changes: 2 additions & 1 deletion csrc/ops/utils.h
@@ -73,7 +73,8 @@ std::vector<IterDomain*> mapMatmulOpIterDomains(
std::vector<IterDomain*> mapLinearOpIterDomains(
const std::vector<IterDomain*>& input_domain,
int64_t input_position,
size_t out_size);
size_t out_size,
bool k_bcast);

// Takes a vector of aligned input iterdomains to create the output iterdomain.
// This is used if the input iterdomains are not trivially mapped to the output
5 changes: 3 additions & 2 deletions csrc/scheduler/mma_utils.cpp
@@ -1809,13 +1809,14 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {

} else if (output->definition()->isA<LinearOp>()) {
const std::vector<IterDomain*>& out_logical = output->getLogicalDomain();
bool k_bcast = A->getLogicalDomain().back()->isBroadcast();
return matmulOrLinearOpDimRoles(
permissive_graph,
out_logical,
ops::mapLinearOpIterDomains(
A->getLogicalDomain(), 0, out_logical.size()),
A->getLogicalDomain(), 0, out_logical.size(), k_bcast),
ops::mapLinearOpIterDomains(
B->getLogicalDomain(), 1, out_logical.size()));
B->getLogicalDomain(), 1, out_logical.size(), k_bcast));
}

// The code below handles MmaOp or mul-sum patterns
23 changes: 14 additions & 9 deletions tests/cpp/test_matmul_aten_evaluation.cpp
@@ -58,21 +58,22 @@ void checkMatmulOpIdMapping(
const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT);
vg.validateConsistency();

// If K is Broadcast then we will not have a reduction dim
// If K is Broadcast then we will not have a reduction dim.
bool k_bcast = A->axis(-1)->isBroadcast();
auto out_ndims = std::max(A->nDims(), B->nDims()) + 1;
int64_t red_dims = k_bcast ? 0 : 1;
auto out_ndims = std::max(A->nDims(), B->nDims()) + red_dims;
if (std::min(A->nDims(), B->nDims()) == 1) {
out_ndims = std::max(A->nDims(), B->nDims());
out_ndims = std::max(A->nDims(), B->nDims()) - 1 + red_dims;
}
ASSERT_EQ(output->nDims(), out_ndims);

if (A->nDims() > 1) {
int out_mpos = B->nDims() > 1 ? -3 : -2;
int out_mpos = B->nDims() > 1 ? -2 - red_dims : -1 - red_dims;
EXPECT_TRUE(checkMapped(vg, A->axis(-2), output->axis(out_mpos))); // M
}

if (B->nDims() > 1) {
EXPECT_TRUE(checkMapped(vg, B->axis(-1), output->axis(-2))); // N
EXPECT_TRUE(checkMapped(vg, B->axis(-1), output->axis(-1 - red_dims))); // N
}

if (!k_bcast) {
@@ -85,7 +86,8 @@ void checkLinearOpIdMapping(
// Note that A and B can have different dimensions, so here we count
// backwards from the innermost batch dimension. Then we check that the axis
// exists (is not negative) and is not Broadcast before checking mapping.
int batch_ndims = output->nDims() - (B->nDims() > 1) - (A->nDims() > 1) - 1;
int batch_ndims =
output->nDims() - (B->nDims() > 1) - (A->nDims() > 1) - red_dims;
for (int64_t i : c10::irange(batch_ndims)) {
int64_t i_a = A->nDims() - 3 - i;
int64_t i_b = B->nDims() - 3 - i;
@@ -115,7 +117,9 @@ void checkLinearOpIdMapping(
// bias (optional): [out_features]/[]
// output = [*, (out_features), rK]

ASSERT_EQ(output->nDims(), input->nDims() + weight->nDims() - 1);
bool k_bcast = input->axis(-1)->isBroadcast();
int64_t red_dims = k_bcast ? 0 : 1;
ASSERT_EQ(output->nDims(), input->nDims() + weight->nDims() - 2 + red_dims);

// Check that the first input_size - 1 dims are mapped for input
for (auto i : c10::irange(input->nDims() - 1)) {
Expand All @@ -126,10 +130,11 @@ void checkLinearOpIdMapping(
// Check out_features dim is mapped in weight & bias if present.
if (weight->nDims() > 1) {
if (!weight->axis(0)->isBroadcast()) {
EXPECT_TRUE(checkMapped(vg, weight->axis(0), output->axis(-2)));
EXPECT_TRUE(
checkMapped(vg, weight->axis(0), output->axis(-1 - red_dims)));
}
if (bias != nullptr && bias->nDims() > 0 && !bias->axis(0)->isBroadcast()) {
EXPECT_TRUE(checkMapped(vg, bias->axis(0), output->axis(-2)));
EXPECT_TRUE(checkMapped(vg, bias->axis(0), output->axis(-1 - red_dims)));
}
}
// Check mapping for reduction axis in input and weight
33 changes: 33 additions & 0 deletions tests/python/test_python_frontend.py
@@ -4458,6 +4458,39 @@ def fusion_func(fd: FusionDefinition):
for i in range(num_out):
self.assertEqual(nvf_out[i].data_ptr(), inputs[0].data_ptr())

# Tests broadcast reduction axis in matmul: Issue #2532.
def test_repro_issue2532(self):
def fusion_func(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
shape=[-1, -1, 1],
contiguity=[True, None, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[2, 0, 1],
)
T1 = fd.define_tensor(
shape=[-1, 1, -1],
contiguity=[True, None, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[2, 1, 0],
)
T2 = fd.ops.sum(T1, dims=[0, 1], keepdim=False, dtype=DataType.Null)
T3 = fd.ops.matmul(T0, T1)
T4 = fd.ops.sum(T3, dims=[0], keepdim=False, dtype=DataType.Null)
fd.add_output(T2)
fd.add_output(T4)

inputs = [
torch.randn((262400,), dtype=torch.float32, device="cuda:0").as_strided(
(1025, 256, 1), (256, 1, 256)
),
torch.randn((1049600,), dtype=torch.float32, device="cuda:0").as_strided(
(1025, 1, 1024), (1024, 1024, 1)
),
]
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)


if __name__ == "__main__":
run_tests()
