diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index d1afdd74ccdd..dd9fac9fdcd0 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -199,6 +199,8 @@ class RemoveUnusedVars : public ExprMutator { do { prev_size = unused.size(); + std::vector used; + used.reserve(users.size()); for (const auto& kv : users) { // var -> [users...] // var is unused iff @@ -207,17 +209,22 @@ class RemoveUnusedVars : public ExprMutator { if (kv.second.empty() && // kv.first is not used by fn outputs. fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) { unused.push_back(kv.first); + } else { + used.push_back(kv.first); } } for (size_t i = prev_size; i < unused.size(); ++i) { users.erase(unused[i]); // remove def site. - for (auto kv : users) { // remove use site. - auto it = std::find(kv.second.begin(), kv.second.end(), unused[i]); - if (it != kv.second.end()) { - kv.second.erase(it); - users.Set(kv.first, std::move(kv.second)); + for (const auto& used_var : used) { + ICHECK(users.count(used_var)); + Array var_users = users[used_var]; + // remove the unused var from the use site. + auto it = std::find(var_users.begin(), var_users.end(), unused[i]); + if (it != var_users.end()) { + var_users.erase(it); + users.Set(used_var, std::move(var_users)); } } }