Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Sep 5, 2024
1 parent f050ab5 commit db323c1
Show file tree
Hide file tree
Showing 7 changed files with 911 additions and 18 deletions.
3 changes: 0 additions & 3 deletions csrc/device_lower/analysis/sync_information.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,6 @@ SyncMap::SyncMap(Fusion* fusion) {
raw_dims.set(producer_ptype);
} // end for ptypes

if (getenv("IGNORE_SYNC")) {
continue;
}
if (raw_dims.hasBID()) {
NVF_ERROR(
producer->getMemoryType() == MemoryType::Global,
Expand Down
3 changes: 3 additions & 0 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,8 @@ void validateLookupTV(Fusion* fusion) {
}

void validateResize(Fusion* fusion) {
// No longer true
#if 0
auto fusion_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
// Make sure resize is only used as part of root to logical transformations
Expand All @@ -1075,6 +1077,7 @@ void validateResize(Fusion* fusion) {
tv->toString(),
". Resize may only be used as part of root to logical transformations.");
}
#endif
}

void validateReductions(Fusion* fusion) {
Expand Down
5 changes: 5 additions & 0 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <transform_iter.h>
#include <val_graph_visitor.h>

#include <fstream>
#include <memory>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -334,6 +335,10 @@ void IdModel::buildExactGraph() {
}

idGraph(IdMappingMode::EXACT).validateConsistency();

std::cerr << "Exact graph\n"
<< nvfuser::idGroupsString(idGraph(IdMappingMode::EXACT))
<< std::endl;
}

namespace {
Expand Down
7 changes: 7 additions & 0 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,8 @@ std::vector<Val*> TensorIndexer::getIndexFor(
bool as_consumer,
const ValGroups& index_groups,
const std::vector<ForLoop*>& for_loops) const {
std::cerr << "getIndexFor: " << expr->toString()
<< "as consumer: " << as_consumer << std::endl;
auto info = computeIndex(expr, index_groups, for_loops);
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, info.loop_domains, for_loops, info.index_map);
Expand Down Expand Up @@ -933,6 +935,10 @@ Val* TensorIndexer::getLinearIndex(

const auto alloc_info = getIndexingAllocationInfo(tv);

std::cerr << "getLinearIndex: " << tv->toString()
<< ", alloc: " << toDelimitedString(alloc_info.domains)
<< std::endl;

const auto [contig_indices, contig_strides] =
getContigIndexFor(expr, as_consumer, alloc_info, for_loops);

Expand Down Expand Up @@ -1099,6 +1105,7 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
const Expr* expr,
const std::vector<ForLoop*>& for_loops,
ForLoop* unswitched_loop) const {
std::cerr << "getPredicates: " << tv->toString() << std::endl;
const auto& zero_val = tv->fusion()->zeroVal();

const std::vector<IterDomain*>& predicate_domains =
Expand Down
6 changes: 0 additions & 6 deletions csrc/id_model/indexing_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ class IndexingTraversal : public ValGraphBFS {
const ValGraph& graph,
const ValGroups& from_groups,
const ValGroups& to_groups) {
std::cerr << "getExprsBetween\n";
std::cerr << nvfuser::toString(from_groups) << " -> "
<< nvfuser::toString(to_groups) << "\n";
IndexingTraversal traversal(
expr,
graph,
Expand All @@ -49,9 +46,6 @@ class IndexingTraversal : public ValGraphBFS {
using ValGraphBFS::isVisited;

bool excludeFromTraversal(const NodeType& group) const override {
if (getenv("DISABLE_EXCLUDE")) {
return false;
}
if (const ExprGroup* eg = std::get_if<ExprGroup>(&group)) {
if ((*eg)->empty()) {
return false;
Expand Down
19 changes: 19 additions & 0 deletions csrc/val_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ void ValGraph::mapVals(Val* val0, Val* val1) {
}
auto def0 = def_group_0->front();
auto def1 = def_group_1->front();
std::cerr << "mapThrough: " << def0->name() << ", " << def1->name()
<< std::endl;
maybeMapThroughExprs(def0, def1, false);
}
}
Expand All @@ -478,6 +480,23 @@ void ValGraph::maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward) {
// respectively, and vice versa.

if (!exprsMap(expr0, expr1, forward)) {
std::cerr << "exprs not mapped: " << expr0->name() << ", " << expr1->name()
<< std::endl;
if (expr0->name() == 10 && expr1->name() == 1) {
std::cerr << expr0->toString();
std::cerr << expr1->toString();
std::cerr << "same? " << expr0->sameOp(expr1) << std::endl;
for (const auto i : c10::irange(expr0->attributes().size())) {
if (!expr0->attribute(i)->sameAs(expr1->attribute(i))) {
std::cerr << "Different attribute at " << i << ": "
<< expr0->attribute(i)->toInlineString() << " ("
<< expr0->attribute(i)->as<Val>()->dtype() << "), "
<< expr1->attribute(i)->toInlineString() << " ("
<< expr1->attribute(i)->as<Val>()->dtype() << ")"
<< std::endl;
}
}
}
return;
}

Expand Down
Loading

0 comments on commit db323c1

Please sign in to comment.