Skip to content

Commit

Permalink
adding comments per review request
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Nov 5, 2024
1 parent 1b4f2c1 commit 0e4e61f
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,8 +1050,21 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

void handle(const TernaryOp* top) final {
// Get vectorization information
// Note: vectorized TernaryOp looks something like:
// ```
// predicate
// ? LoadGlobalToLocal(&dst[0], &in2[index])
// : arraySet(&dst[0], in3);
// ```
//
// Current limitation:
// 1. only TernaryOpType::Where is supported;
// 2. predicate needs to be a scalar;
// 3. output needs to be a TensorView;
// 4. one and only one of the inputs needs to be a TensorView. (This is
// coming from validation analysis.)
if (top->out()->isA<kir::TensorIndex>()) {
// Get vectorization information
auto out_tv = top->out()->as<kir::TensorIndex>()->view();
int64_t vector_word_size = ir_utils::getVectorizeSize(out_tv);
bool is_vector_op = vectorize_scope_ && vector_word_size != 1;
Expand All @@ -1066,10 +1079,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
NVF_CHECK(
top->getTernaryOpType() == TernaryOpType::Where,
"vectorization only works on TernaryOp::where");

indent() << gen(top->in1()) << "\n";
indent() << kTab << "? ";

auto vec_load = [&out_tv, &top, &vector_word_size, this](Val* in) {
if (in->isScalar()) {
if (out_tv->getMemoryType() == MemoryType::Local &&
Expand Down

0 comments on commit 0e4e61f

Please sign in to comment.