Lowering vectorized pad #3261

Merged: jjsjann123 merged 51 commits into main from jjsjann123/resize_vec on Nov 5, 2024. Changes shown are from 46 of the 51 commits.

Commits (51)
8f9708f
relaxing check
jjsjann123 Sep 2, 2024
54826aa
allow cache on inputs for pad
jjsjann123 Sep 3, 2024
e54938c
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Sep 3, 2024
2bc3c7a
cpp example
jjsjann123 Sep 24, 2024
d04e8c3
Merge branch 'jjsjann123/pad_vec' into jjsjann123/resize_vec
jjsjann123 Sep 24, 2024
d0addc4
reverting earlier changes
jjsjann123 Sep 24, 2024
490fdbe
Revert "reverting earlier changes"
jjsjann123 Sep 24, 2024
51c3022
cherry-pick my revert
jjsjann123 Sep 24, 2024
1158ef0
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 2, 2024
fdc6a9a
debug print
jjsjann123 Oct 3, 2024
9a6c03a
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 4, 2024
a9d16ce
removing comments
jjsjann123 Oct 7, 2024
3401119
removing assert
jjsjann123 Oct 8, 2024
5d05284
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 8, 2024
b6587ee
patching test
jjsjann123 Oct 10, 2024
28decac
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Oct 10, 2024
3e53feb
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 20, 2024
ad61ecb
fixing test
jjsjann123 Oct 20, 2024
a8edc56
fixing
jjsjann123 Oct 20, 2024
9cdeb64
fixing test
jjsjann123 Oct 21, 2024
09a2aee
does this work to replace Ternary(where) with IfThenElse
jjsjann123 Oct 21, 2024
895d0bf
fixing build
jjsjann123 Oct 21, 2024
7a15e22
removing print
jjsjann123 Oct 22, 2024
a6e8fb1
restore lower to ternary:where; restore vectorization on tests
jjsjann123 Oct 22, 2024
fe0f263
testing water
jjsjann123 Oct 23, 2024
baa7b09
fixing syntax
jjsjann123 Oct 23, 2024
ca5ced1
now it's functional
jjsjann123 Oct 23, 2024
e0492d3
better formatting on printed code
jjsjann123 Oct 23, 2024
b528429
adding a tab
jjsjann123 Oct 23, 2024
a23e010
supporting local memory
jjsjann123 Oct 23, 2024
57b90d1
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 23, 2024
7a976c7
clangformat
jjsjann123 Oct 23, 2024
f11d662
apparently there are ternary operations on scalars
jjsjann123 Oct 23, 2024
5a83fc6
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Oct 23, 2024
39f83f7
fixing
jjsjann123 Oct 23, 2024
07eafd1
fixing
jjsjann123 Oct 23, 2024
986b361
clangformat
jjsjann123 Oct 23, 2024
7409913
clangformat
jjsjann123 Oct 23, 2024
76cbcd8
clangformat again
jjsjann123 Oct 23, 2024
5f996fc
Merge branch 'main' into jjsjann123/resize_vec
jjsjann123 Oct 31, 2024
803a95b
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Nov 4, 2024
a67fb57
polish PR for review
jjsjann123 Nov 4, 2024
11cd4d1
Merge remote-tracking branch 'origin/jjsjann123/resize_vec' into jjsj…
jjsjann123 Nov 4, 2024
65aa77d
Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec
jjsjann123 Nov 4, 2024
1f75d7a
missed one arg
jjsjann123 Nov 4, 2024
4c92371
oops, fixing the generated code
jjsjann123 Nov 4, 2024
3ec2a6b
review comments
jjsjann123 Nov 4, 2024
d2864ab
fixing code
jjsjann123 Nov 5, 2024
1b4f2c1
I think this is fixed now
jjsjann123 Nov 5, 2024
0e4e61f
adding comments per review request
jjsjann123 Nov 5, 2024
4d4f747
another comment
jjsjann123 Nov 5, 2024
144 changes: 97 additions & 47 deletions csrc/codegen.cpp
@@ -402,6 +402,56 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}
}

void generateVectorizedLdSt(
Val* in,
Val* out,
CacheOp cache_op,
int64_t vector_word_size) {
auto out_tv = out->as<kir::TensorIndex>()->view();
auto in_tv = in->as<kir::TensorIndex>()->view();

bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Local;

bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local &&
in_tv->getMemoryType() == MemoryType::Global;

bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Global;

bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(out_tv).hasBID();

bool is_volatile_from = in_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(in_tv).hasBID();

if (localToGlobal) {
indent() << "loadLocalToGlobal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile=*/"
<< (is_volatile_to ? "true" : "false") << ">(";
code_ << " &" << gen(out) << ", &" << gen(in) << ")";
} else if (globalToLocal) {
indent() << "loadGlobalToLocal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile=*/"
<< (is_volatile_from ? "true" : "false") << ", "
<< "CacheOp::" << cache_op << ">(&" << gen(out) << ", ";
code_ << " &" << gen(in) << ")";
} else if (globalToGlobal) {
indent() << "loadGlobalToGlobal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile_to=*/"
<< (is_volatile_to ? "true" : "false")
<< ", /*is_volatile_from=*/"
<< (is_volatile_from ? "true" : "false") << ">(";
code_ << " &" << gen(out) << ", ";
code_ << " &" << gen(in) << ")";
} else {
indent() << "loadGeneric<" << out->dtype() << ", " << vector_word_size
<< ">(";
code_ << " &" << gen(out) << ", ";
code_ << " &" << gen(in) << ")";
}
}

// Cannot just use ConstIrVisitor::handle as it expects a vector of
// const Expr*, whereas most of the IR API returns a vector of
// non-const Expr*.
@@ -1001,6 +1051,50 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

void handle(const TernaryOp* top) final {
// Get vectorization information
[Resolved review thread]
Collaborator: Can you add some comments about the expectation? IIUC, only in2 is allowed to be vectorized, but technically speaking it should be possible to have vectorized loads in both in2 and in3, right? Not sure if it's worthwhile to allow that as well, although the required change seems minimal.

jjsjann123 (Author): Yes, we can have in2 / in3 as TensorViews. I'm trying to add that, since @zasdfgbnm mentioned having a where test.

if (top->out()->isA<kir::TensorIndex>()) {
auto out_tv = top->out()->as<kir::TensorIndex>()->view();
int64_t vector_word_size = ir_utils::getVectorizeSize(out_tv);
bool is_vector_op = vectorize_scope_ && vector_word_size != 1;

if (is_vector_op) {
NVF_CHECK(!top->in2()->isScalar(), "input2 should be a tensor");
NVF_CHECK(top->in3()->isScalar(), "input3 should be a scalar");
NVF_CHECK(
!top->out()->isScalar(),
"scalar output in vectorization isn't supported");
NVF_CHECK(
top->getTernaryOpType() == TernaryOpType::Where,
"vectorization only works on TernaryOp::where");

indent() << gen(top->in1()) << "\n";
indent() << kTab << "? ";

// TODO: should we have the option to specify cache level?
generateVectorizedLdSt(
top->in2(), top->out(), CacheOp::AllLevels, vector_word_size);
code_ << "\n";

if (out_tv->getMemoryType() == MemoryType::Local &&
!out_tv->isCircularBuffered()) {
// Vectorized initialization, explicit type conversion is needed for
// complex numbers
indent() << kTab << ": " << genVariableName(out_tv) << ".set("
<< genCall(out_tv->dtype(), gen(top->in3())) << ");\n";
} else {
// Note: currently the arraySet option is not vectorized, so it will
// rely on the auto-vectorization pass of the CUDA compiler.
indent() << kTab << ": "
<< "arraySet<" << out_tv->getDataType().value() << ", "
<< vector_word_size << ">(&" << gen(top->out()) << ", ("
<< out_tv->getDataType().value() << ")" << gen(top->in3())
<< ");\n";
}

return;
}
}

if (!print_inline_) {
indent() << gen(top->out());
if (!top->out()->isScalar()) {
@@ -1338,53 +1432,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
"Invalid input to unary op with tensor output, found: ",
ldst->in()->toString());

auto in_tv = ldst->in()->as<kir::TensorIndex>()->view();
bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Local;

bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local &&
in_tv->getMemoryType() == MemoryType::Global;

bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Global;

bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(out_tv).hasBID();

bool is_volatile_from =
in_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(in_tv).hasBID();

if (localToGlobal) {
indent() << "loadLocalToGlobal<" << ldst->out()->dtype()
<< ", /*vec_size=*/" << vector_word_size
<< ", /*is_volatile=*/"
<< (is_volatile_to ? "true" : "false") << ">(";
code_ << " &" << gen(ldst->out()) << ", &" << gen(ldst->in())
<< ");\n";
} else if (globalToLocal) {
indent() << "loadGlobalToLocal<" << ldst->out()->dtype()
<< ", /*vec_size=*/" << vector_word_size
<< ", /*is_volatile=*/"
<< (is_volatile_from ? "true" : "false") << ", "
<< "CacheOp::" << ldst->cacheOp() << ">(&"
<< gen(ldst->out()) << ", ";
code_ << " &" << gen(ldst->in()) << ");\n";
} else if (globalToGlobal) {
indent() << "loadGlobalToGlobal<" << ldst->out()->dtype()
<< ", /*vec_size=*/" << vector_word_size
<< ", /*is_volatile_to=*/"
<< (is_volatile_to ? "true" : "false")
<< ", /*is_volatile_from=*/"
<< (is_volatile_from ? "true" : "false") << ">(";
code_ << " &" << gen(ldst->out()) << ", ";
code_ << " &" << gen(ldst->in()) << ");\n";
} else {
indent() << "loadGeneric<" << ldst->out()->dtype() << ", "
<< vector_word_size << ">(";
code_ << " &" << gen(ldst->out()) << ", ";
code_ << " &" << gen(ldst->in()) << ");\n";
}
generateVectorizedLdSt(
ldst->in(), ldst->out(), ldst->cacheOp(), vector_word_size);
code_ << ";\n";
}
return;
}
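For illustration, the TernaryOp path above emits code shaped roughly like the following excerpt. This is a hypothetical sketch: the names b3, T0, T2, and i0 are illustrative, and the exact output depends on the schedule and memory types; the helper signatures are taken from the calls in generateVectorizedLdSt above.

    // Vectorized where: the predicate selects between a vectorized load of
    // in2 and a fill of the output with the scalar in3.
    b3
      ?   loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::AllLevels>( &T2[0],  &T0[i0])
      :   T2.set(float(0));
    // If T2 were not in local memory (or were circular buffered), the
    // else-branch would instead call arraySet<float, 4>(&T2[i1], (float)0).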
4 changes: 0 additions & 4 deletions csrc/device_lower/lower2device.h
@@ -45,10 +45,6 @@

namespace nvfuser {

// TODO: we frequently use pairwise root mapping from consumers to producers.
// This information is implicitly in the computeAtMaps, but there's no isolated
// container for this information that we can reuse. Would be nice to generate
// such a structure and propagate it through lowering.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class GpuLower : public NonCopyable {
class KernelIrMapper;
3 changes: 2 additions & 1 deletion csrc/device_lower/pass/predicate.cpp
@@ -103,7 +103,8 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
"Expecting predicated body to only have one vectorized expression.");
auto vec_expr = ite->thenBody()[0];
NVF_ERROR(
vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>() ||
vec_expr->isA<TernaryOp>(),
"Vectorize predicate exprs only supported on set operations.");
NVF_ERROR(
ir_utils::isTvOp(vec_expr),
1 change: 1 addition & 0 deletions csrc/device_lower/validation.cpp
@@ -798,6 +798,7 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) {
Expr* def = tv->definition();
NVF_ERROR(
def == nullptr || def->isA<LoadStoreOp>() || def->isA<SliceOp>() ||
def->isA<PadOp>() ||
(def->isA<ReductionOp>() &&
def->as<ReductionOp>()->serialGridReductionRequested()),
"Vectorized accesses cannot be inline with computation: ",
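In terms of usage, the relaxed check means a tensor defined by a PadOp may now carry a Vectorize axis. A minimal sketch, reusing names from the test added below:

    // tv1 is produced by PadOp; previously only LoadStoreOp, SliceOp, and
    // serial-grid ReductionOp definitions were accepted by this check.
    auto tv1 = pad(tv0, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
    tv1->axis(2)->parallelize(ParallelType::Vectorize);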
33 changes: 33 additions & 0 deletions tests/cpp/test_resize.cpp
@@ -4041,4 +4041,37 @@ TEST_F(ResizeTest, SliceSliceConcatConcat) {
NVF_CHECK(ref.equal(cg_outputs[0]));
}

// Manual scheduling that should produce vectorized loads on padded inputs.
TEST_F(ResizeTest, VectorizePadLowering) {
[Review thread]
Collaborator: Should we have a test for vectorizing where without using pad?

jjsjann123 (Author): Good call. I almost forgot that we have where directly 🤕

auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

const std::vector<int64_t> shape({1024L * 1024L});

auto tv0 = makeContigConcreteTensor(shape);
fusion.addInput(tv0);

auto tv1 = pad(tv0, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
fusion.addOutput(tv1);

tv1->split(0, 4);
tv1->split(0, 128);

tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(1)->parallelize(ParallelType::TIDx);
tv1->axis(2)->parallelize(ParallelType::Vectorize);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({t0});

FusionExecutor fe;
fe.compileFusion(&fusion, aten_inputs);
auto cg_outputs = fe.runFusion(aten_inputs);

auto ref = at::pad(t0, {4, 4});
ASSERT_TRUE(ref.equal(cg_outputs[0]));
}

} // namespace nvfuser
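Following up on the review thread above, a where-without-pad test could look roughly like the sketch below. This is an unverified sketch: the test name and inputs are assumptions, and since validateAndCollectVectorizeInfo (csrc/device_lower/validation.cpp above) admits PadOp but not TernaryOp definitions in this change set, the allow-list there would presumably need a TernaryOp entry before such a test can pass validation.

    // Hypothetical sketch: vectorize a where directly, mirroring the
    // scheduling of VectorizePadLowering. in2 (tv1) is a TensorView and in3
    // is a scalar, matching the NVF_CHECKs in handle(const TernaryOp*).
    TEST_F(ResizeTest, VectorizeWhereLowering) {
      auto fusion_ptr = std::make_unique<Fusion>();
      auto& fusion = *fusion_ptr;
      FusionGuard fg(fusion_ptr.get());

      const std::vector<int64_t> shape({1024L * 1024L});

      auto tv0 = makeContigConcreteTensor(shape, DataType::Bool);
      fusion.addInput(tv0);
      auto tv1 = makeContigConcreteTensor(shape);
      fusion.addInput(tv1);

      auto tv2 = where(tv0, tv1, IrBuilder::create<Val>(0.0));
      fusion.addOutput(tv2);

      tv2->split(0, 4);
      tv2->split(0, 128);
      tv2->axis(0)->parallelize(ParallelType::BIDx);
      tv2->axis(1)->parallelize(ParallelType::TIDx);
      tv2->axis(2)->parallelize(ParallelType::Vectorize);

      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn(shape, options).ge(0.0);
      auto t1 = at::randn(shape, options);
      std::vector<c10::IValue> aten_inputs({t0, t1});

      FusionExecutor fe;
      fe.compileFusion(&fusion, aten_inputs);
      auto cg_outputs = fe.runFusion(aten_inputs);

      auto ref = at::where(t0, t1, at::zeros_like(t1));
      ASSERT_TRUE(ref.equal(cg_outputs[0]));
    }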