diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 16c2b61b397..5c22294b698 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -248,23 +248,6 @@ std::optional Val::getDataType() const { return dtype_; } -bool Val::isProducerOf(const Val* other) const { - NVF_ERROR(other != nullptr); - NVF_ERROR(container() == other->container()); - - if (definition() == nullptr) { - return false; - } - return std::any_of( - definition()->inputs().begin(), - definition()->inputs().end(), - [other](const Val* input) { return input == other; }); -} - -bool Val::isConsumerOf(const Val* other) const { - return other->isProducerOf(this); -} - // We don't register with the active fusion in Expr as this needs to be done // after inputs and outputs are registered with the Expr Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {} diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 024fe3cbf4b..efcb2e475b6 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -373,12 +373,6 @@ class NVF_API Val : public Statement { return is_fusion_output_; } - //! Returns true when other is a producer of this - bool isProducerOf(const Val* other) const; - - //! Returns true when other is a consumer of this - bool isConsumerOf(const Val* other) const; - bool sameType(const Statement* other) override { return Statement::sameType(other) && getDataType() == other->as()->getDataType(); diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 6e67ff84b63..b337b78a3e1 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -57,13 +57,19 @@ PairwiseLogicalDomainMap::PairwiseLogicalDomainMap( NVF_ERROR(producer != nullptr); NVF_ERROR(consumer != nullptr); NVF_ERROR(producer->fusion() == consumer->fusion()); + NVF_ERROR(consumer->definition() != nullptr); + auto producer_tvs_of_consumer = ir_utils::producerTvsOf(consumer); // Make sure they are really a producer and its consumer NVF_ERROR( - producer->isConsumerOf(consumer), - "Not a producer-consumer pair: ", + std::find( + producer_tvs_of_consumer.begin(), + producer_tvs_of_consumer.end(), + producer) != producer_tvs_of_consumer.end(), + "Expected ", producer, - ", ", - consumer); + " is a producer of ", + consumer, + " but it is not."); } namespace {