Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SiriusNEO committed Feb 17, 2023
1 parent 35ebddc commit 9001753
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,41 +44,35 @@ class LegalizeMutator : public ExprMutator {
static const Op& call_tir_op = Op::Get("relax.call_tir");
auto* op_node = visited_call->op.as<OpNode>();

if (op_node == nullptr) { // not a OpNode
// Not an OpNode
if (op_node == nullptr) {
return visited_call;
}

// Not all shape values are known
if (!std::all_of(visited_call->args.begin(), visited_call->args.end(),
[](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }) ||
!KnowAllShapeValues(GetStructInfo(visited_call))) { // Not all shape values are known
!KnowAllShapeValues(GetStructInfo(visited_call))) {
return visited_call;
}

auto op = GetRef<Op>(op_node);
FLegalize flegalize;
bool has_legalize = false;

// Priority: customize > default.
// Check if it has customize legalization registered.
if (cmap_.defined() && cmap_.value().count(op->name)) {
flegalize = cmap_.value()[op->name];
has_legalize = true;
return cmap_.value()[op->name](this->builder_, visited_call);
}
// Check if it has default legalization registered.
if (!legalize_map.count(op)) {
if (op != call_tir_op) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
}
} else if (!has_legalize) {
flegalize = legalize_map[op];
has_legalize = true;
if (legalize_map.count(op)) {
return legalize_map[op](this->builder_, visited_call);
}

if (has_legalize) {
return flegalize(this->builder_, visited_call);
} else {
return visited_call; // No legalization.
// No legalization.
if (op != call_tir_op) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
}
return visited_call;
}

IRModule Transform() {
Expand All @@ -94,7 +88,7 @@ class LegalizeMutator : public ExprMutator {
private:
IRModule mod_;
Optional<Map<String, PackedFunc>> cmap_;
};
}; // namespace relax

namespace transform {

Expand Down

0 comments on commit 9001753

Please sign in to comment.