Allow linear to take a >2D weight and a >1D bias. #3073

Merged · 7 commits · Oct 4, 2024
38 changes: 33 additions & 5 deletions csrc/ir/nodes.cpp
@@ -4326,14 +4326,42 @@ std::string LinearOp::toInlineString(int indent_size) const {
std::vector<PolymorphicValue> LinearOp::evaluate(
    const ExpressionEvaluator& ee,
    const std::vector<PolymorphicValue>& inputs) const {
-  const auto a = inputs.at(0).as<at::Tensor>();
-  const auto b = inputs.at(1).as<at::Tensor>();
+  const auto in = inputs.at(0).as<at::Tensor>();
+  auto weight = inputs.at(1).as<at::Tensor>();
+
+  auto squeeze_device_dims = [](at::Tensor& t,
+                                int64_t num_device_dims) -> void {
+    // Record the initial shape for the error message.
+    std::vector<int64_t> shape = t.sizes().vec();
+    for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
+      NVF_CHECK(
+          t.size(0) == 1,
+          "When the weight is >2D, expect its preceding dimensions and "
+          "the bias's preceding dimensions to "
+          "be DID-parallel and therefore size-1: ",
+          shape);
+      t = t.squeeze(0);
+    }
+  };
+
+  // The squeezes and unsqueezes are currently required to support a sharded
+  // linear layer. Remove them after #2563.
+  auto num_device_dims = weight.dim() - 2;
+  squeeze_device_dims(weight, num_device_dims);
+
+  at::Tensor out;
  if (has_bias()) {
-    const auto bias = inputs.at(2).as<at::Tensor>();
-    return {at::linear(a, b, bias)};
+    auto bias = inputs.at(2).as<at::Tensor>();
+    squeeze_device_dims(bias, num_device_dims);
+    out = at::linear(in, weight, bias);
+  } else {
+    out = at::linear(in, weight);
  }
+
+  for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
+    out = out.unsqueeze(0);
+  }
-  return {at::linear(a, b)};
+  return {out};
}

SdpaFwdOp::SdpaFwdOp(
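For readers following the workaround: the new evaluation path squeezes the leading size-1 (DID-parallel) dimensions off the weight and bias, runs a plain 2D linear, and unsqueezes the result back. Below is a minimal PyTorch sketch of the same logic; `evaluate_linear` is a hypothetical stand-in for `LinearOp::evaluate`, which operates on ATen tensors.

```python
import torch
import torch.nn.functional as F
from typing import Optional

def evaluate_linear(inp: torch.Tensor, weight: torch.Tensor,
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Leading weight dims beyond [N, K] are expected to be DID-parallel,
    # hence size-1 on each device; squeeze them off before the 2D linear.
    num_device_dims = weight.dim() - 2
    for _ in range(num_device_dims):
        assert weight.size(0) == 1, "device dims must be size-1"
        weight = weight.squeeze(0)
        if bias is not None:
            assert bias.size(0) == 1, "device dims must be size-1"
            bias = bias.squeeze(0)
    out = F.linear(inp, weight, bias)
    # Unsqueeze the device dims back so the output rank matches the IR.
    for _ in range(num_device_dims):
        out = out.unsqueeze(0)
    return out

# Example: one device dim. weight [1, N, K], bias [1, N], input [M, K].
out = evaluate_linear(torch.randn(24, 8), torch.randn(1, 16, 8), torch.randn(1, 16))
assert out.shape == (1, 24, 16)
```

Since squeeze and unsqueeze are shape-only meta ops run on the host, the round trip adds no device work.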
51 changes: 33 additions & 18 deletions csrc/ops/composite.cpp
@@ -7,6 +7,7 @@
// clang-format on
#include <ATen/cuda/CUDAContext.h>
#include <ir/builder.h>
+#include <ir/iostream.h>
#include <ops/all_ops.h>
#include <ops/utils.h>
#include <transform_view.h>
@@ -113,34 +114,48 @@ TensorView* newForLinear(
TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias) {
  auto input_ndims =
      TensorDomain::noReductions(input->getLogicalDomain()).size();
-  NVF_CHECK(input_ndims > 0, "Input A must be atleast 1D.");
+  NVF_CHECK(input_ndims > 0, "Input A must be at least 1D.");

+  // `linear` previously supported 1D weight and 0D bias. The support was
+  // however removed by #3073 to support sharded linear layers, yet-another
+  // workaround of #2563. Otherwise, it would be unclear whether a 2D weight is
+  // one device dimension plus a non-device or two non-devices.
+  //
+  // If needed, we can still support 1D weight and 0D bias in Thunder by
+  // changing the thunder-to-nvFuser bridge to convert a 1D/0D linear to
+  // unsqueeze followed by a 2D/1D linear followed by a squeeze. It'll likely
+  // be the same speed because nvFuser treats squeezes and unsqueezes as meta
+  // ops and runs them on the host.
  auto weight_ndims =
      TensorDomain::noReductions(weight->getLogicalDomain()).size();
  NVF_CHECK(
-      weight_ndims == 1 || weight_ndims == 2,
-      "Input B must be a 1D / 2D tensor.");
-
-  // Note: This constraint is not documented but F.linear errors out if bias is
-  // given with 1D weights.
-  NVF_CHECK(
-      weight_ndims == 2 || bias == nullptr,
-      "Expected B to be a 2D matrix if bias is given, got 1D.")
-
+      weight_ndims >= 2,
+      "Input B must be at least 2D. The last two dimensions represent out "
+      "features and in features. The extra, preceding dimensions are expected "
+      "to be parallelized on DIDs during scheduling: ",
+      weight);
  NVF_CHECK(
      input->dtype() == weight->dtype(),
      "Expected input and weight dtypes to have the same dtype, got: ",
      input->dtype(),
      " and ",
      weight->dtype());

-  NVF_CHECK(
-      bias == nullptr || bias->dtype() == input->dtype(),
-      "Expected bias to have the same dtype as A and B, got: ",
-      bias->dtype(),
-      " and ",
-      input->dtype());
-  // For all other cases, create a new LinearOp
+  if (bias != nullptr) {
+    NVF_CHECK(
+        !TensorDomain::noReductions(bias->getLogicalDomain()).empty(),
+        "Input bias must be at least 1D. The last dimension represents out "
+        "features. The extra, preceding dimensions are expected to be "
+        "parallelized on DIDs during scheduling: ",
+        bias);
+    NVF_CHECK(
+        bias->dtype() == input->dtype(),
+        "Expected bias to have the same dtype as A and B, got: ",
+        bias->dtype(),
+        " and ",
+        input->dtype());
+  }
+
  TensorView* out = newForLinear(input, weight, bias);
  IrBuilder::create<LinearOp>(out, input, weight, bias);
  return out;
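The comment above sketches how Thunder could keep supporting 1D weights after this change. As a rough illustration of that bridge-side rewrite, here is a sketch in PyTorch terms; `linear_1d_weight` is a hypothetical helper, not part of this PR or the nvFuser API.

```python
import torch
import torch.nn.functional as F

def linear_1d_weight(inp: torch.Tensor, weight_1d: torch.Tensor) -> torch.Tensor:
    assert weight_1d.dim() == 1  # shape [K]
    # [K] -> [1, K]: a 2D weight with out_features == 1.
    out = F.linear(inp, weight_1d.unsqueeze(0))  # shape [..., 1]
    # Drop the size-1 out_features dim to match the 1D-weight result.
    return out.squeeze(-1)

x = torch.randn(24, 8)
w = torch.randn(8)
assert torch.allclose(linear_1d_weight(x, w), F.linear(x, w))
```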
86 changes: 50 additions & 36 deletions csrc/ops/utils.cpp
@@ -232,49 +232,63 @@ std::vector<IterDomain*> mapLinearOpIterDomains(
    size_t out_size,
    bool k_bcast) {
  std::vector<IterDomain*> mapping(out_size, nullptr);
-  auto inp_size = input_domain.size();
-
-  NVF_ERROR(
-      input_position == 0 || input_position == 1 || input_position == 2,
-      "Input position must be 0, 1, or 2. Found ",
-      input_position);
-
-  auto red_dims = k_bcast ? 0 : 1;
-
-  // Input A: {*, M, K}
-  // Input B: {*, N, K} / {K}
-  // Bias: {N} / {}
-
-  // Map K if K is not bcast
-  if (input_position != 2 && !k_bcast) {
-    mapping[out_size - 1] = input_domain.back();
-  }
-
-  switch (input_position) {
-    case 0: {
-      // Linear output is same as input for inp_size - 1 dimensions.
-      // K is already mapped above if not broadcast.
-      for (auto inx : c10::irange(inp_size - 1)) {
-        mapping[inx] = input_domain[inx];
-      }
-      break;
-    }
-    case 1: {
-      // Map N / out_features if present
-      if (inp_size > 1) {
-        mapping[out_size - 1 - red_dims] = input_domain.front();
-      }
-      break;
-    }
-    case 2: {
-      if (inp_size > 0) {
-        // Bias is 1D tensor of shape {out_features}
-        mapping[out_size - 1 - red_dims] = input_domain.front();
-      }
-      break;
-    }
-    default:
-      NVF_ERROR("Unexpected input type.");
-  }
+
+  // Input: {*_i, K}
+  // Weight: {*_wb, N, K}
+  // Bias: {*_wb, N}
+  // Output: {*_wb, *_i, N, (rK)}. rK exists iff K is not a broadcast.
+  if (input_position == 0) {
+    // Fill `mapping` from the back.
+    auto in_r_index = static_cast<int64_t>(input_domain.size()) - 1;
+    auto out_index = static_cast<int64_t>(out_size) - 1;
+    // Map K if K is not a broadcast.
+    if (!k_bcast) {
+      mapping[out_index] = input_domain[in_r_index];
+      out_index--;
+    }
+    in_r_index--;
+
+    // Skip N because it's not in the input.
+    out_index--;
+
+    // Map the rest, i.e., *_i.
+    while (in_r_index >= 0) {
+      mapping[out_index] = input_domain[in_r_index];
+      in_r_index--;
+      out_index--;
+    }
+  } else {
+    NVF_ERROR(
+        input_position == 1 || input_position == 2,
+        "Input position must be 0, 1, or 2. Found ",
+        input_position);
+
+    auto in_r_index = static_cast<int64_t>(input_domain.size()) - 1;
+    auto out_index = static_cast<int64_t>(out_size) - 1;
+    if (k_bcast) {
+      // If K is a broadcast, don't map K.
+      if (input_position == 1) {
+        // Skip K in the weight.
+        in_r_index--;
+      }
+    } else {
+      // Otherwise, map K in the weight.
+      if (input_position == 1) {
+        mapping[out_index] = input_domain[in_r_index];
+        in_r_index--;
+      }
+      out_index--;
+    }
+
+    // Fill `N`
+    mapping[out_index] = input_domain[in_r_index];
+
+    // Fill *_wb from the front.
+    out_index = 0;
+    for (auto in_index : c10::irange(in_r_index)) {
+      mapping[out_index] = input_domain[in_index];
+      out_index++;
+    }
+  }
  return mapping;
}
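The rewritten mapping follows the convention spelled out in the comment: output IDs are laid out as {*_wb, *_i, N, (rK)}, where rK exists iff K is not a broadcast. A small Python mirror of the new logic, with strings standing in for IterDomains; `map_linear_ids` is a hypothetical illustration, not nvFuser code.

```python
from typing import List, Optional

def map_linear_ids(input_domain: List[str], input_position: int,
                   out_size: int, k_bcast: bool) -> List[Optional[str]]:
    mapping: List[Optional[str]] = [None] * out_size
    if input_position == 0:  # input: {*_i, K}
        in_r, out_i = len(input_domain) - 1, out_size - 1
        if not k_bcast:              # map K -> rK
            mapping[out_i] = input_domain[in_r]
            out_i -= 1
        in_r -= 1
        out_i -= 1                   # skip N; it's not in the input
        while in_r >= 0:             # map *_i from the back
            mapping[out_i] = input_domain[in_r]
            in_r, out_i = in_r - 1, out_i - 1
    else:  # weight (1): {*_wb, N, K}, or bias (2): {*_wb, N}
        in_r, out_i = len(input_domain) - 1, out_size - 1
        if k_bcast:
            if input_position == 1:
                in_r -= 1            # skip K in the weight
        else:
            if input_position == 1:  # map K -> rK
                mapping[out_i] = input_domain[in_r]
                in_r -= 1
            out_i -= 1
        mapping[out_i] = input_domain[in_r]  # map N
        for i in range(in_r):                # map *_wb from the front
            mapping[i] = input_domain[i]
    return mapping

# Weight {d, N, K} against output {d, M, N, rK}: M stays unmapped.
print(map_linear_ids(["d", "N", "K"], 1, 4, False))  # ['d', None, 'N', 'K']
```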
10 changes: 5 additions & 5 deletions tests/cpp/test_matmul_aten_evaluation.cpp
@@ -113,8 +113,8 @@ void checkLinearOpIdMapping(
  vg.validateConsistency();

  // input: [* , in_features]
-  // weight: [out_features, in_features] / [out_features]
-  // bias (optional): [out_features]/[]
+  // weight: [out_features, in_features]
+  // bias (optional): [out_features]
  // output = [*, (out_features), rK]

  bool k_bcast = input->axis(-1)->isBroadcast();
@@ -341,7 +341,7 @@ INSTANTIATE_TEST_SUITE_P(
        Sizes({b, m, k}),
        Sizes({1, k}),
        Sizes({b, 1, k})),
-    testing::Values(Sizes({k}), Sizes({n, k}), Sizes({1, k})),
+    testing::Values(Sizes({n, k}), Sizes({1, k})),
    testing::Values(std::nullopt)));

INSTANTIATE_TEST_SUITE_P(
@@ -355,7 +355,7 @@ INSTANTIATE_TEST_SUITE_P(
        Sizes({1, k}),
        Sizes({b, 1, k})),
    testing::Values(Sizes({n, k})),
-    testing::Values(Sizes({}), Sizes({n}))));
+    testing::Values(Sizes({n}))));

INSTANTIATE_TEST_SUITE_P(
    LinearReductionAxisIsOne,
@@ -368,6 +368,6 @@ INSTANTIATE_TEST_SUITE_P(
        Sizes({1, 1}),
        Sizes({b, 1, 1})),
    testing::Values(Sizes({n, 1})),
-    testing::Values(Sizes({}), Sizes({n}))));
+    testing::Values(Sizes({n}))));

} // namespace nvfuser
6 changes: 0 additions & 6 deletions tests/cpp/test_translate_mma.cpp
@@ -590,9 +590,7 @@
        std::make_tuple(2l, 2l, -1l, false, false, true),
        std::make_tuple(2l, 2l, -1l, false, true, true),
        std::make_tuple(1l, 2l, -1l, false, false, true),
-        std::make_tuple(2l, 1l, -1l, false, false, true),
        std::make_tuple(2l, 2l, 1l, false, false, true),
-        std::make_tuple(1l, 1l, -1l, false, false, true),
        std::make_tuple(3l, 2l, 1l, false, false, true),
        std::make_tuple(4l, 2l, 1l, false, false, true),

@@ -603,15 +601,11 @@
        std::make_tuple(2l, 2l, -1l, true, true, true),
        // We don't fuse 1D inputs
        std::make_tuple(1l, 2l, -1l, true, false, true),
-        std::make_tuple(2l, 1l, -1l, true, false, true),
-        // Check that zero-dim output fusion is not claimed by NoOp scheduler
-        std::make_tuple(1l, 1l, -1l, true, false, true),
        // Batch dims in input
        // mixed length inputs via broadcasted batch dims
        std::make_tuple(3l, 2l, -1l, true, false, false),
        std::make_tuple(4l, 2l, -1l, true, false, false),
        // Bias cases
-        std::make_tuple(2l, 2l, 0l, true, false, false),
        std::make_tuple(2l, 2l, 1l, true, false, false),
        std::make_tuple(3l, 2l, 1l, true, false, false),
        std::make_tuple(4l, 2l, 1l, true, false, false)),
12 changes: 3 additions & 9 deletions tests/python/opinfo_input_generators.py
@@ -1523,13 +1523,13 @@ def linear_input_generator(

    # Cases without bias
    shapes_input = ((K), (M, K), (B, M, K), (B, 1, M, K))
-    shapes_weight = ((K), (N, K), (1, K))
+    shapes_weight = ((N, K), (1, K))
    for shape_input, shape_weight in itertools.product(shapes_input, shapes_weight):
        yield SampleInput(make_arg(shape_input), make_arg(shape_weight))

    # Cases with bias
    shape_weight = (N, K)
-    shapes_bias = ((), (N,))
+    shapes_bias = ((N,),)
    for shape_input, shape_bias in itertools.product(shapes_input, shapes_bias):
        yield SampleInput(
            make_arg(shape_input), make_arg(shape_weight), make_arg(shape_bias)
@@ -1547,19 +1547,13 @@ def linear_error_generator(
    N = 256
    K = 32

-    bias_with_1dweight = (
-        ((M, K), (K), (N)),
-        RuntimeError,
-        "Expected B to be a 2D matrix if bias is given, got 1D.",
-    )
-
    mismatched_bias_extent = (
        ((M, K), (1, K), (N)),
        RuntimeError,
        f"The expanded size of the tensor (1) must match the existing size ({N}) at non-singleton dimension 1. Target sizes: [{M}, 1]. Tensor sizes: [{N}]",
    )

-    error_cases = [bias_with_1dweight, mismatched_bias_extent]
+    error_cases = [mismatched_bias_extent]

    for input_shapes, ex_type, ex_str in error_cases:
        shape_input, shape_weight, shape_bias = input_shapes
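The surviving error case also reproduces in eager PyTorch: `F.linear` expands the bias to the matmul output shape (addmm-style), so a (1, K) weight yields an (M, 1) output that an (N,) bias cannot expand to. A quick repro, with M, N, K values chosen for illustration:

```python
import torch
import torch.nn.functional as F

M, N, K = 512, 256, 32
try:
    F.linear(torch.randn(M, K), torch.randn(1, K), torch.randn(N))
except RuntimeError as e:
    # "The expanded size of the tensor (1) must match the existing size (256)
    # at non-singleton dimension 1. Target sizes: [512, 1]. Tensor sizes: [256]"
    print(e)
```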
19 changes: 8 additions & 11 deletions tests/python/test_matmul.py
@@ -53,7 +53,6 @@ def test_linear(self):
        m = 24
        n = 16
        k = 8
-        bias0d = torch.tensor(3.14, device="cuda", dtype=torch.float16)
        bias1d = torch.randn(n, device="cuda", dtype=torch.float16)

        inputs_mk_nk = [
@@ -92,17 +91,15 @@ def fusion_func(
            fd.add_output(t_out)

        in_tensors = [inputs_mk_nk, inputs_mk_kn, inputs_km_nk, inputs_km_kn]
-        use_bias = [None, bias0d, bias1d]
-        for [inp, wt], use_bias in list(itertools.product(in_tensors, use_bias)):
-            with self.subTest(inp=inp, wt=wt, use_bias=use_bias):
-                input_tensors = (
-                    (inp, wt, use_bias) if use_bias is not None else (inp, wt)
-                )
+        bias = [None, bias1d]
+        for [inp, wt], bias in list(itertools.product(in_tensors, bias)):
+            with self.subTest(inp=inp, wt=wt, bias=bias):
+                input_tensors = (inp, wt, bias) if bias is not None else (inp, wt)
                nvf_out, _ = self.exec_nvfuser(
-                    partial(fusion_func, inp=inp, wt=wt, bias=use_bias),
+                    partial(fusion_func, inp=inp, wt=wt, bias=bias),
                    input_tensors,
                )
-                eager_out = F.linear(input=inp, weight=wt, bias=use_bias)
+                eager_out = F.linear(inp, wt, bias)
                fp16_nvf_out = nvf_out[0]
                torch.testing.assert_close(fp16_nvf_out, eager_out, atol=1e-3, rtol=0)

@@ -141,7 +138,7 @@ def fusion_func(fd: FusionDefinition):
        fd.add_output(T2)
        fd.add_output(T4)

-        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+        self.exec_nvfuser(fusion_func, inputs)

    # Tests broadcast reduction axis in matmul: Issue #2532.
    def test_repro_issue2532(self):
@@ -174,7 +171,7 @@ def fusion_func(fd: FusionDefinition) -> None:
                (1025, 1, 1024), (1024, 1024, 1)
            ),
        ]
-        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+        self.exec_nvfuser(fusion_func, inputs)

    def test_linear_slice(self):
        def fusion_func(fd: FusionDefinition) -> None: