Skip to content

Commit

Permalink
Remove isProducerOf/isConsumerOf (#2867)
Browse files Browse the repository at this point in the history
Remove isProducerOf/isConsumerOf, replace the one instance. Original
definitions were backwards.
  • Loading branch information
csarofeen authored Aug 28, 2024
1 parent 2eeef46 commit 13fba3b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 27 deletions.
17 changes: 0 additions & 17 deletions csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,23 +248,6 @@ std::optional<DataType> Val::getDataType() const {
return dtype_;
}

bool Val::isProducerOf(const Val* other) const {
NVF_ERROR(other != nullptr);
NVF_ERROR(container() == other->container());

if (definition() == nullptr) {
return false;
}
return std::any_of(
definition()->inputs().begin(),
definition()->inputs().end(),
[other](const Val* input) { return input == other; });
}

bool Val::isConsumerOf(const Val* other) const {
return other->isProducerOf(this);
}

// We don't register with the active fusion in Expr as this needs to be done
// after inputs and outputs are registered with the Expr
Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {}
Expand Down
6 changes: 0 additions & 6 deletions csrc/ir/base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,6 @@ class NVF_API Val : public Statement {
return is_fusion_output_;
}

//! Returns true when other is a producer of this
bool isProducerOf(const Val* other) const;

//! Returns true when other is a consumer of this
bool isConsumerOf(const Val* other) const;

bool sameType(const Statement* other) override {
return Statement::sameType(other) &&
getDataType() == other->as<Val>()->getDataType();
Expand Down
14 changes: 10 additions & 4 deletions csrc/logical_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,19 @@ PairwiseLogicalDomainMap::PairwiseLogicalDomainMap(
NVF_ERROR(producer != nullptr);
NVF_ERROR(consumer != nullptr);
NVF_ERROR(producer->fusion() == consumer->fusion());
NVF_ERROR(consumer->definition() != nullptr);
auto producer_tvs_of_consumer = ir_utils::producerTvsOf(consumer);
// Make sure they are really a producer and its consumer
NVF_ERROR(
producer->isConsumerOf(consumer),
"Not a producer-consumer pair: ",
std::find(
producer_tvs_of_consumer.begin(),
producer_tvs_of_consumer.end(),
producer) != producer_tvs_of_consumer.end(),
"Expected ",
producer,
", ",
consumer);
" is a producer of ",
consumer,
" but it is not.");
}

namespace {
Expand Down

0 comments on commit 13fba3b

Please sign in to comment.