Allow linear to take a >2D weight and a >1D bias.
As long as the extra dimensions are DID-parallel.
wujingyue committed Oct 2, 2024
1 parent 08d4482 commit 64ffd24
Showing 5 changed files with 100 additions and 41 deletions.
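
In short, the frontend can now pass a weight shaped like [d, n, k] and a bias shaped like [d, n], where the leading d is a device-parallel (DID) dimension. A minimal, definition-only sketch in the nvFuser Python frontend, modeled on the updated tests/python/test_multidevice.py in this commit (the class name is made up and the DID scheduling body is elided):

from nvfuser import FusionDefinition


class ColumnParallelLinear(FusionDefinition):  # hypothetical name, for illustration
    def __init__(self, d, b, s, h):
        super().__init__()
        self._d, self._b, self._s, self._h = d, b, s, h

    def definition(self):
        d, b, s, h = self._d, self._b, self._s, self._h
        self.inp = self.define_tensor([b, s, h])
        # >2D weight and >1D bias: the extra leading dimension is the device dim.
        self.weight = self.define_tensor([d, h, h], contiguity=[True, True, True])
        self.bias = self.define_tensor([d, h], contiguity=[True, True])
        out = self.ops.linear(self.inp, self.weight, self.bias)  # logical shape [d, b, s, h]
        self.add_output(out)

    def multidevice_schedule(self):
        # The leading `d` axis of weight, bias, and out is parallelized on DIDs here;
        # details omitted, see tests/python/test_multidevice.py below.
        ...
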
36 changes: 31 additions & 5 deletions csrc/ir/nodes.cpp
@@ -4303,14 +4303,40 @@ std::string LinearOp::toInlineString(int indent_size) const {
 std::vector<PolymorphicValue> LinearOp::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
-  const auto a = inputs.at(0).as<at::Tensor>();
-  const auto b = inputs.at(1).as<at::Tensor>();
+  const auto in = inputs.at(0).as<at::Tensor>();
+  auto weight = inputs.at(1).as<at::Tensor>();
+
+  auto squeeze_device_dims = [](at::Tensor& t,
+                                int64_t num_device_dims) -> void {
+    // Record the initial shape for the error message.
+    std::vector<int64_t> shape = t.sizes().vec();
+    for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
+      NVF_CHECK(
+          t.size(0) == 1,
+          "When the weight is >2D, expect its preceding dimensions and "
+          "the bias's preceding dimensions to "
+          "be DID-parallel and therefore size-1: ",
+          shape);
+      t = t.squeeze(0);
+    }
+  };
+
+  auto num_device_dims = weight.dim() - 2;
+  squeeze_device_dims(weight, num_device_dims);
 
+  at::Tensor out;
   if (has_bias()) {
-    const auto bias = inputs.at(2).as<at::Tensor>();
-    return {at::linear(a, b, bias)};
+    auto bias = inputs.at(2).as<at::Tensor>();
+    squeeze_device_dims(bias, num_device_dims);
+    out = at::linear(in, weight, bias);
+  } else {
+    out = at::linear(in, weight);
   }
-  return {at::linear(a, b)};
+
+  for ([[maybe_unused]] auto _ : c10::irange(num_device_dims)) {
+    out = out.unsqueeze(0);
+  }
+  return {out};
 }
 
 SdpaFwdOp::SdpaFwdOp(
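
At evaluation time the new code squeezes the size-1 device dimensions off the weight and bias, runs the ordinary 2D linear, and re-inserts those dimensions on the output. A rough PyTorch analogue of the updated LinearOp::evaluate, for illustration only (not an nvFuser API):

from typing import Optional

import torch
import torch.nn.functional as F


def evaluate_linear(
    inp: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    num_device_dims = weight.dim() - 2

    def squeeze_device_dims(t: torch.Tensor) -> torch.Tensor:
        for _ in range(num_device_dims):
            # Mirrors the NVF_CHECK: DID-parallel dimensions must be size-1 locally.
            assert t.size(0) == 1, f"expected a size-1 leading dim, got {list(t.shape)}"
            t = t.squeeze(0)
        return t

    weight = squeeze_device_dims(weight)
    if bias is not None:
        bias = squeeze_device_dims(bias)
    out = F.linear(inp, weight, bias)

    # Put the device dimensions back at the front of the output.
    for _ in range(num_device_dims):
        out = out.unsqueeze(0)
    return out


# Per-device slice of a column-parallel linear:
# weight [1, n, k], bias [1, n], inp [b, s, k] -> out [1, b, s, n].
out = evaluate_linear(torch.randn(2, 3, 8), torch.randn(1, 16, 8), torch.randn(1, 16))
assert out.shape == (1, 2, 3, 16)
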
6 changes: 4 additions & 2 deletions csrc/ops/composite.cpp
@@ -113,11 +113,13 @@ static TensorView* newForLinear(
 TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias) {
   auto input_ndims =
       TensorDomain::noReductions(input->getLogicalDomain()).size();
-  NVF_CHECK(input_ndims > 0, "Input A must be atleast 1D.");
+  NVF_CHECK(input_ndims > 0, "Input A must be at least 1D.");
 
   auto weight_ndims =
       TensorDomain::noReductions(weight->getLogicalDomain()).size();
-  NVF_CHECK(weight_ndims == 2, "Input B must be a 2D tensor.");
+  NVF_CHECK(
+      weight_ndims >= 2,
+      "Input B must be at least 2D. The extra, preceding dimensions are expected to be parallelized on DIDs during scheduling.");
 
   NVF_CHECK(
       input->dtype() == weight->dtype(),
71 changes: 49 additions & 22 deletions csrc/ops/utils.cpp
@@ -232,44 +232,71 @@ std::vector<IterDomain*> mapLinearOpIterDomains(
     size_t out_size,
     bool k_bcast) {
   std::vector<IterDomain*> mapping(out_size, nullptr);
-  auto inp_size = input_domain.size();
 
   NVF_ERROR(
       input_position == 0 || input_position == 1 || input_position == 2,
       "Input position must be 0, 1, or 2. Found ",
       input_position);
 
-  auto red_dims = k_bcast ? 0 : 1;
-
-  // Input A: {*, M, K}
-  // Input B: {*, N, K} / {K}
-  // Bias: {N} / {}
-
-  // Map K if K is not bcast
-  if (input_position != 2 && !k_bcast) {
-    mapping[out_size - 1] = input_domain.back();
-  }
-
+  // Input: {*_i, K}
+  // Weight: {*_wb, N, K}
+  // Bias: {*_wb, N}
+  // Output: {*_wb, *_i, N, (rK)}. rK exists iff K is not a broadcast.
   switch (input_position) {
     case 0: {
-      // Linear output is same as input for inp_size - 1 dimensions.
-      // K is already mapped above if not broadcast.
-      for (auto inx : c10::irange(inp_size - 1)) {
-        mapping[inx] = input_domain[inx];
+      // Fill `mapping` from the back.
+      // Map K if K is not a broadcast.
+      auto in_index = input_domain.rbegin();
+      auto out_index = static_cast<int64_t>(out_size) - 1;
+      if (!k_bcast) {
+        mapping[out_index] = *in_index;
+        in_index++;
+        out_index--;
      }
+
+      // Skip N.
+      out_index--;
+
+      // Map the rest, i.e., *_i.
+      while (in_index != input_domain.rend()) {
+        mapping[out_index] = *in_index;
+        in_index++;
+        out_index--;
+      }
       break;
     }
     case 1: {
-      // Map N / out_features if present
-      if (inp_size > 1) {
-        mapping[out_size - 1 - red_dims] = input_domain.front();
+      auto in_r_index = static_cast<int64_t>(input_domain.size()) - 1;
+      // Map K if K is not broadcast.
+      auto out_index = static_cast<int64_t>(out_size) - 1;
+      if (!k_bcast) {
+        mapping[out_index] = input_domain[in_r_index];
+        in_r_index--;
+        out_index--;
       }
+      // Fill `N`
+      mapping[out_index] = input_domain[in_r_index];
+
+      // Fill *_wb from the front.
+      out_index = 0;
+      for (auto in_index : c10::irange(in_r_index)) {
+        mapping[out_index] = input_domain[in_index];
+        out_index++;
+      }
       break;
     }
     case 2: {
-      if (inp_size > 0) {
-        // Bias is 1D tensor of shape {out_features}
-        mapping[out_size - 1 - red_dims] = input_domain.front();
+      auto out_index = static_cast<int64_t>(out_size) - 1;
+      // If K is not a broadcast, skip K.
+      if (!k_bcast) {
+        out_index--;
       }
+      // Fill `N`
+      mapping[out_index] = input_domain.back();
+
+      // Fill *_wb from the front.
+      for (auto index : c10::irange(input_domain.size() - 1)) {
+        mapping[index] = input_domain[index];
+      }
       break;
     }
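
The rewritten mapping fills the output layout {*_wb, *_i, N, (rK)} from both ends rather than assuming a 2D weight and a 1D bias. A small Python sketch of the same bookkeeping, using dimension labels in place of IterDomains (names are illustrative):

def map_linear_dims(input_dims, position, out_size, k_bcast):
    # position 0 = input {*_i, K}, 1 = weight {*_wb, N, K}, 2 = bias {*_wb, N}.
    # Output layout: {*_wb, *_i, N, (rK)}; rK exists iff K is not a broadcast.
    mapping = [None] * out_size
    if position == 0:
        dims = list(reversed(input_dims))
        i, out = 0, out_size - 1
        if not k_bcast:  # map K
            mapping[out] = dims[i]
            i, out = i + 1, out - 1
        out -= 1  # skip N
        while i < len(dims):  # map *_i from the back
            mapping[out] = dims[i]
            i, out = i + 1, out - 1
    elif position == 1:
        in_r, out = len(input_dims) - 1, out_size - 1
        if not k_bcast:  # map K
            mapping[out] = input_dims[in_r]
            in_r, out = in_r - 1, out - 1
        mapping[out] = input_dims[in_r]  # map N
        mapping[:in_r] = input_dims[:in_r]  # map *_wb from the front
    else:  # position == 2, bias
        out = out_size - 1
        if not k_bcast:
            out -= 1  # skip rK
        mapping[out] = input_dims[-1]  # map N
        n_wb = len(input_dims) - 1
        mapping[:n_wb] = input_dims[:n_wb]  # map *_wb from the front
    return mapping


# Weight {d, N, K} against output {d, b, s, N, rK}:
print(map_linear_dims(["d", "N", "K"], position=1, out_size=5, k_bcast=False))
# -> ['d', None, None, 'N', 'K']
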
14 changes: 6 additions & 8 deletions tests/python/test_matmul.py
@@ -92,17 +92,15 @@ def fusion_func(
             fd.add_output(t_out)
 
         in_tensors = [inputs_mk_nk, inputs_mk_kn, inputs_km_nk, inputs_km_kn]
-        use_bias = [None, bias0d, bias1d]
-        for [inp, wt], use_bias in list(itertools.product(in_tensors, use_bias)):
-            with self.subTest(inp=inp, wt=wt, use_bias=use_bias):
-                input_tensors = (
-                    (inp, wt, use_bias) if use_bias is not None else (inp, wt)
-                )
+        bias = [None, bias0d, bias1d]
+        for [inp, wt], bias in list(itertools.product(in_tensors, bias)):
+            with self.subTest(inp=inp, wt=wt, bias=bias):
+                input_tensors = (inp, wt, bias) if bias is not None else (inp, wt)
                 nvf_out, _ = self.exec_nvfuser(
-                    partial(fusion_func, inp=inp, wt=wt, bias=use_bias),
+                    partial(fusion_func, inp, wt, bias),
                     input_tensors,
                 )
-                eager_out = F.linear(input=inp, weight=wt, bias=use_bias)
+                eager_out = F.linear(inp, wt, bias)
                 fp16_nvf_out = nvf_out[0]
                 torch.testing.assert_close(fp16_nvf_out, eager_out, atol=1e-3, rtol=0)
 
14 changes: 10 additions & 4 deletions tests/python/test_multidevice.py
@@ -73,8 +73,8 @@ def __init__(self, num_devices, batch, sequence, hidden):
     def definition(self):
         d, b, s, h = self._num_devices, self._batch, self._sequence, self._hidden
         self.inp = self.define_tensor([b, s, h])
-        self.weight = self.define_tensor([d, h, h])
-        self.bias = self.define_tensor([d, h])
+        self.weight = self.define_tensor([d, h, h], contiguity=[True, True, True])
+        self.bias = self.define_tensor([d, h], contiguity=[True, True])
         out = self.ops.linear(self.inp, self.weight, self.bias)
         self.add_output(out)
 
@@ -101,5 +101,11 @@ def multidevice_schedule(self):
     fn = Model(d, b, s, h)
     out_tensors = fn.execute([inp_tensor, weight_tensor, bias_tensor])
 
-    unsharded_out_tensor = torch.linear(inp_tensor, unsharded_weight_tensor, unsharded_bias_tensor)
-    torch.testing.assert_close(out_tensors[0], unsharded_out_tensor[..., rank * h: (rank + 1) * h])
+    # [b, s, d*h]
+    unsharded_out_tensor = torch.nn.functional.linear(
+        inp_tensor, unsharded_weight_tensor, unsharded_bias_tensor
+    )
+    sharded_out_tensor = unsharded_out_tensor.view([b, s, d, h]).permute(2, 0, 1, 3)[
+        rank : rank + 1
+    ]
+    torch.testing.assert_close(out_tensors[0], sharded_out_tensor)
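
The reshaped reference above is the old per-rank column slice with the leading device dimension kept, matching the [1, b, s, h]-shaped output the fusion now produces. A quick standalone check of that identity (shapes are illustrative):

import torch

d, b, s, h, rank = 2, 3, 4, 5, 1
unsharded_out = torch.randn(b, s, d * h)  # [b, s, d*h]

# New form: expose the device dim, move it to the front, keep it as size 1.
sharded = unsharded_out.view(b, s, d, h).permute(2, 0, 1, 3)[rank : rank + 1]  # [1, b, s, h]

# Old form: slice this rank's columns; same values, no device dim.
column_slice = unsharded_out[..., rank * h : (rank + 1) * h]  # [b, s, h]

torch.testing.assert_close(sharded.squeeze(0), column_slice)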
