Skip to content

Commit

Permalink
oops, fixing the generated code
Browse files: browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Nov 4, 2024
1 parent 1f75d7a commit 4c92371
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions csrc/codegen.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -429,26 +429,26 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << "loadLocalToGlobal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile=*/"
<< (is_volatile_to ? "true" : "false") << ">(";
code_ << " &" << gen(out) << ", &" << gen(in) << ");\n";
code_ << " &" << gen(out) << ", &" << gen(in) << ")";
} else if (globalToLocal) {
indent() << "loadGlobalToLocal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile=*/"
<< (is_volatile_from ? "true" : "false") << ", "
<< "CacheOp::" << cache_op << ">(&" << gen(out) << ", ";
code_ << " &" << gen(in) << ");\n";
code_ << " &" << gen(in) << ")";
} else if (globalToGlobal) {
indent() << "loadGlobalToGlobal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile_to=*/"
<< (is_volatile_to ? "true" : "false")
<< ", /*is_volatile_from=*/"
<< (is_volatile_from ? "true" : "false") << ">(";
code_ << " &" << gen(out) << ", ";
code_ << " &" << gen(in) << ");\n";
code_ << " &" << gen(in) << ")";
} else {
indent() << "loadGeneric<" << out->dtype() << ", " << vector_word_size
<< ">(";
code_ << " &" << gen(out) << ", ";
code_ << " &" << gen(in) << ");\n";
code_ << " &" << gen(in) << ")";
}
}

Expand Down Expand Up @@ -1073,6 +1073,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
// TODO: should we have the option to specify cache level?
generateVectorizedLdSt(
top->in2(), top->out(), CacheOp::AllLevels, vector_word_size);
code_ << "\n";

if (out_tv->getMemoryType() == MemoryType::Local &&
!out_tv->isCircularBuffered()) {
Expand Down Expand Up @@ -1433,6 +1434,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

generateVectorizedLdSt(
ldst->in(), ldst->out(), ldst->cacheOp(), vector_word_size);
code_ << ";\n";
}
return;
}
Expand Down

0 comments on commit 4c92371

Please sign in to comment.