Shard backprop. (#3232)
This PR parallelizes inputs following the Megatron-LM tensor-parallelism
scheme (https://arxiv.org/pdf/1909.08053). `propagate_shardings` then
propagates that parallelization to intermediate tensors and outputs; a toy
sketch of the propagation follows below.

Fixes #2199.
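
For intuition, here is a minimal, self-contained C++ sketch of the propagation idea. Everything in it is hypothetical (a toy IR, not nvFuser's actual data structures or the real `propagate_shardings` pass): a sharding annotation placed on an input is copied forward, op by op, onto any un-annotated intermediates and outputs.

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy IR: a tensor may be annotated with the axis it is sharded
// (device-parallelized) on.
struct Tensor {
  std::string name;
  std::optional<int> sharded_axis;
};

// An op consumes input tensors and produces output tensors.
struct Op {
  std::vector<Tensor*> inputs;
  std::vector<Tensor*> outputs;
};

// Toy analogue of propagate_shardings: walk ops in topological order and
// let every un-annotated output inherit the sharding of the first sharded
// input, so intermediates and outputs pick up the inputs' parallelization.
void propagateShardings(std::vector<Op>& ops_in_topo_order) {
  for (Op& op : ops_in_topo_order) {
    std::optional<int> axis;
    for (const Tensor* in : op.inputs) {
      if (in->sharded_axis.has_value()) {
        axis = in->sharded_axis;
        break;
      }
    }
    if (!axis.has_value()) {
      continue;
    }
    for (Tensor* out : op.outputs) {
      if (!out->sharded_axis.has_value()) {
        out->sharded_axis = axis;
      }
    }
  }
}

int main() {
  Tensor w{"weight", /*sharded_axis=*/0};  // the user shards only the weight
  Tensor x{"input", std::nullopt};
  Tensor h{"hidden", std::nullopt};
  Tensor y{"output", std::nullopt};
  std::vector<Op> graph = {
      {{&w, &x}, {&h}},  // h = matmul(w, x)
      {{&h}, {&y}},      // y = gelu(h)
  };
  propagateShardings(graph);
  std::cout << h.name << " sharded on axis " << *h.sharded_axis << "\n";  // 0
  std::cout << y.name << " sharded on axis " << *y.sharded_axis << "\n";  // 0
  return 0;
}

Running this prints that both hidden and output inherit axis 0, mirroring how the pass lets a user annotate only the inputs and have the sharding flow through the rest of the fusion.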
wujingyue authored Oct 22, 2024
1 parent c13bfb8 commit 96fe83d
Showing 3 changed files with 136 additions and 79 deletions.
7 changes: 7 additions & 0 deletions csrc/ir/nodes.cpp
@@ -4985,6 +4985,13 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
// logsumexp, max_q/k
// Temporary handling of DID parallelization. See
// https://github.com/NVIDIA/Fuser/issues/2563
bool first_dim_is_did = this->key()->as<TensorView>()->axis(0)->isDeviceDim();
auto out_grad = inputs[0].as<at::Tensor>();
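// When axis 0 is a device (DID) dimension, each tensor carries an extra
// leading device axis, so the sharded gradient is 5-D rather than the
// usual 4-D SDPA layout.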
if (first_dim_is_did) {
NVF_CHECK(out_grad.dim() == 5, "Expected 5D but found ", out_grad.sizes());
} else {
NVF_CHECK(out_grad.dim() == 4, "Expected 4D but found ", out_grad.sizes());
}

std::vector<at::Tensor> bwd_inputs;
for (auto idx : c10::irange(6)) {
auto in_tensor = inputs.at(idx).as<at::Tensor>();
4 changes: 3 additions & 1 deletion csrc/python_frontend/fusion_record.h
@@ -751,7 +751,9 @@ struct BroadcastInDimOpRecord : RecordFunctor {
"The broadcast dimensions should match the input dimensions: ",
arg_ndims,
" vs ",
broadcast_dims_.size());
broadcast_dims_.size(),
". arg = ",
arg->toString());

std::vector<bool> is_broadcast_dim(output_ndims_, true);
for (const auto idx : c10::irange(broadcast_dims_.size())) {
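
As an aside, the extra message arguments above work because NVF_CHECK, like TORCH_CHECK, streams all trailing arguments into the failure message. A minimal sketch of such a variadic check macro (hypothetical; not nvFuser's actual definition):

#include <sstream>
#include <stdexcept>
#include <string>

namespace detail {
// Stream every argument into one string (C++17 fold expression).
template <typename... Args>
std::string concat(const Args&... args) {
  std::ostringstream oss;
  (oss << ... << args);
  return oss.str();
}
}  // namespace detail

// If the condition fails, concatenate all trailing arguments into the
// exception message, so callers can append context such as arg->toString().
#define MY_CHECK(cond, ...)                                  \
  do {                                                       \
    if (!(cond)) {                                           \
      throw std::runtime_error(detail::concat(__VA_ARGS__)); \
    }                                                        \
  } while (0)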
