Skip to content

Commit

Permalink
Merge branch 'main' into rename
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Nov 7, 2024
2 parents 14993fa + 1836ed0 commit ba2a03b
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 49 deletions.
28 changes: 20 additions & 8 deletions csrc/device_lower/pass/replace_size.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,38 @@ std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
// 1. Constant ints. These might be non-immediate constants
// 2. Extents of input TVs.
// 3. Extents of non-input TVs.
// Within these three classes, we find the IterDomain with the smallest
// name().
// Within these three classes, we find the IterDomain with the
// smallest name(). For case 3, we also prefer the IterDomain with
// the simplest extent, which has the smallest number of defining
// expressions.
bool group_is_const = false;
IterDomain* rep = nullptr;
bool rep_is_input_id = false;
int64_t rep_num_defs = 0;
std::unordered_set<Val*> dynamic_scalars;
for (Val* v : *group) {
auto* id = dynamic_cast<IterDomain*>(v);
NVF_ERROR(
id != nullptr, "Expected only IterDomains in exact graph ValGroups");
bool is_input_id = fusion_input_ids.count(id) > 0;
if (rep == nullptr) {
rep = id;
rep_is_input_id = is_input_id;
continue;
}
Val* ext = id->extent();
bool ext_is_const = ext->isConstInt();
if (!ext_is_const) {
dynamic_scalars.insert(ext);
}

// Initializing rep with the first ID
if (rep == nullptr) {
rep = id;
rep_is_input_id = is_input_id;
group_is_const = ext_is_const;
// If neither const nor input, record the number of exprs
if (!ext_is_const && !is_input_id) {
rep_num_defs = ir_utils::getOperationCount(id->extent());
}
continue;
}

if (ext_is_const) {
if (!group_is_const || id->name() < rep->name()) {
rep = id;
Expand All @@ -103,9 +113,11 @@ std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
if (group_is_const || rep_is_input_id) {
continue;
}
if (id->name() < rep->name()) {
auto num_defs = ir_utils::getOperationCount(id->extent());
if (num_defs < rep_num_defs || id->name() < rep->name()) {
rep = id;
rep_is_input_id = is_input_id;
rep_num_defs = num_defs;
continue;
}
}
Expand Down
26 changes: 26 additions & 0 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,32 @@ bool isRecursivelyDefined(Val* val) {
return false;
}

// Counts the expressions that are reachable from val by repeatedly
// following definition() and then the inputs of that definition.
// A value without a definition contributes nothing. No visited set
// is kept, so an expression that is reachable through more than one
// path is counted once per path.
int64_t getOperationCount(Val* val) {
  int64_t op_count = 0;

  // FIFO worklist of values whose definitions are yet to be inspected,
  // seeded with the given val
  std::deque<Val*> worklist{val};

  // The front entry is consumed at the end of every iteration,
  // including the ones that bail out early via continue
  for (; !worklist.empty(); worklist.pop_front()) {
    auto* producer = worklist.front()->definition();
    if (producer == nullptr) {
      // Nothing defines this value, so there is nothing to count
      continue;
    }
    ++op_count;

    // Schedule all inputs of the defining expression for inspection
    const auto& producer_inputs = producer->inputs();
    worklist.insert(
        worklist.end(), producer_inputs.begin(), producer_inputs.end());
  }

  return op_count;
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
Expand Down
4 changes: 4 additions & 0 deletions csrc/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -733,4 +733,8 @@ bool isFunctional(const Val* v);
// such as the Kernel IR
bool isRecursivelyDefined(Val* val);

// Return the number of operations that are used to define val. One
// instance of Expr is counted as a single operation.
int64_t getOperationCount(Val* val);

} // namespace nvfuser::ir_utils
20 changes: 19 additions & 1 deletion csrc/multidevice/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
*/
// clang-format on
#include <ATen/cuda/CUDAContext.h>

#include <device_lower/utils.h>
#include <fusion_segmenter.h>
#include <host_ir/container.h>
#include <host_ir/host_ir.h>
#include <instrumentation.h>
#include <ir/builder.h>
#include <ir/utils.h>
#include <multidevice/device_mesh.h>
Expand Down Expand Up @@ -52,6 +54,22 @@ std::unique_ptr<Fusion> copyFusionAndChangeOutputs(
return fusion_copy;
}

// Used in distributed setting where we only want to allocate output space and
// receive output data from a different rank instead of computing them.
// Returns the tensors produced by allocateOutputs for the outputs of
// the given fusion on the given device.
std::vector<at::Tensor> allocateOutputSpace(
    const at::ArrayRef<c10::IValue>& inputs,
    Fusion* fusion,
    const c10::Device& device) {
  FUSER_PERF_SCOPE("multidevice::executor::allocateOutputSpace");

  // Wrap the given inputs and bind them to the fusion to obtain an
  // expression evaluator
  auto kernel_args = KernelArgumentHolder::createKernelArgumentHolder(inputs);
  auto evaluator = executor_utils::bindInputs(kernel_args, fusion);

  // Gather the buffer information of every fusion output
  auto output_buffer_infos =
      getBufferInfos(evaluator, PrimDataType::Int, fusion->outputs());

  return allocateOutputs(fusion, output_buffer_infos, device, evaluator);
}

} // namespace

MultiDeviceExecutor::MultiDeviceExecutor(
Expand Down Expand Up @@ -186,7 +204,7 @@ std::vector<at::Tensor> MultiDeviceExecutor::runWithInput(
}

auto allocations =
allocOutputSpace(inputs, allocator_fusion_.get(), comm()->device());
allocateOutputSpace(inputs, allocator_fusion_.get(), comm()->device());
NVF_ERROR(vals_to_allocate_.size() == allocations.size());
for (auto i : c10::irange(allocations.size())) {
val_to_IValue[vals_to_allocate_.at(i)] = allocations.at(i);
Expand Down
2 changes: 1 addition & 1 deletion csrc/python_frontend/fusion_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -1823,7 +1823,7 @@ struct SelectOpRecord : RecordFunctor {

void operator()(FusionState& fd) final {
auto arg1 = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
auto arg3 = fd.getFusionState(args_.at(1).index)->template as<TensorView>();
auto arg3 = fd.getFusionState(args_.at(1).index);

Val* output = select(arg1, dim_, arg3);
fd.setFusionState(outputs_.at(0).index, output);
Expand Down
24 changes: 4 additions & 20 deletions csrc/runtime/allocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,20 +366,6 @@ std::vector<at::Tensor> allocateOutputs(
return out_tensors;
}

// Used in distributed setting where we only want to allocate output
// space and receive output data from a different rank instead of
// computing them. Returns the result of allocateOutputs for the
// outputs of the given fusion on the given device.
std::vector<at::Tensor> allocOutputSpace(
const at::ArrayRef<c10::IValue>& inputs,
Fusion* fusion,
const c10::Device& device) {
FUSER_PERF_SCOPE("fusion_executor::allocations::allocOutputSpace");
// Wrap the given inputs and bind them to the fusion to obtain an
// expression evaluator
auto fusion_inputs = KernelArgumentHolder::createKernelArgumentHolder(inputs);
auto expr_eval = executor_utils::bindInputs(fusion_inputs, fusion);

// Gather the buffer information of every fusion output
auto output_info =
getBufferInfos(expr_eval, PrimDataType::Int, fusion->outputs());

return allocateOutputs(fusion, output_info, device, expr_eval);
}

namespace {
GlobalBufferInfo getBufferInfo(
ExpressionEvaluator& expr_eval,
Expand Down Expand Up @@ -685,12 +671,11 @@ class BackwardTraverseFromAllocToLogical {
// Another example, if the logical domain is [I1*I2] and the allocation domain
// is [I1, I2], then we will allocate as [I1, I2] and do a tensor.view(I1*I2) to
// get a tensor whose semantics is [I1*I2] but memory is [I1,I2]
at::Tensor transformOutputFromAllocationToLogical(
at::Tensor transformFromAllocationToLogical(
at::Tensor tensor,
TensorView* tv,
ExpressionEvaluator& ee) {
FUSER_PERF_SCOPE(
"fusion_executor::allocations::transformOutputFromAllocationToLogical");
FUSER_PERF_SCOPE("allocations::transformFromAllocationToLogical");
// Ignore reductions because reductions do not exist in tensor's definition
auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
auto alloc = TensorDomain::noReductions(tv->getMaybeAllocationDomain());
Expand Down Expand Up @@ -765,9 +750,8 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(
at::empty_strided(size_stride.first, size_stride.second, options);
// TODO(jiej): we should refactor it here, there's no need to use
// meta_tensor at all, size + stride should be used directly in the
// `transformOutputFromAllocationToLogical`
meta_tensor =
transformOutputFromAllocationToLogical(meta_tensor, tv, expr_eval);
// `transformFromAllocationToLogical`
meta_tensor = transformFromAllocationToLogical(meta_tensor, tv, expr_eval);
return {meta_tensor.sizes().vec(), meta_tensor.strides().vec()};
}

Expand Down
8 changes: 0 additions & 8 deletions csrc/runtime/allocations.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ NVF_API void setFillAllocationWithNan(bool value);

void fillTensorWithNan(at::Tensor& t);

//! Used in distributed setting where we only want to
//! allocate output space and receive output data from
//! a different rank instead of computing them.
std::vector<at::Tensor> allocOutputSpace(
const at::ArrayRef<c10::IValue>& inputs,
Fusion* fusion,
const c10::Device& device);

// Infer the sizes and strides of an output tensor
std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(
TensorView* tv,
Expand Down
11 changes: 0 additions & 11 deletions csrc/runtime/fusion_executor_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
#include <scheduler/registry.h>
#include <utils.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace nvfuser {

FusionExecutorCache::FusionExecutorCache(
Expand Down Expand Up @@ -81,15 +78,7 @@ std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
" failed");
}

int seq_id = 0;
// Record kernel input and output tensors so profiler can construct
// the data flow graph
RECORD_FUNCTION(
"run_fused_kernel",
std::vector<c10::IValue>(inputs.begin(), inputs.end()),
seq_id);
auto outputs = kernel_runtime->runWithInputs(args);
RECORD_OUTPUTS(outputs);

// Kernel time measurement is off by default
kernel_runtime->disableKernelTimeMeasurement();
Expand Down
6 changes: 6 additions & 0 deletions csrc/scheduler/cache_policy_refiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ bool isLoadGlobalToLocal(const Expr* expr) {
if (ldst->opType() != LoadStoreOpType::Set) {
return false;
}
// It should not be necessary to check the output since it should be
// always a TensorView as long as the input is a TensorView, but
// just in case.
if (!ldst->in()->isA<TensorView>() || !ldst->out()->isA<TensorView>()) {
return false;
}
if (ldst->in()->as<TensorView>()->getMemoryType() != MemoryType::Global) {
return false;
}
Expand Down
58 changes: 58 additions & 0 deletions tests/cpp/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8937,6 +8937,64 @@ TEST_F(NVFuserTest, AvoidReplacingWithDependentVal) {
"not allowed as it would result in a recursive definition")));
}

// Checks that, after replaceSymbolicSizes, every expr output tensor
// shares a single extent and that this extent is the simpler of the
// two candidates (a single multiplication of two GetItem-defined
// values).
// Was also a repro of issue #3347
TEST_F(NVFuserTest, ReplaceSymbolicSizesPreferSimplerExtents) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

// Inputs: a 3D tensor, a 2D tensor and a scalar of Index type
auto tv0 = makeSymbolicTensor(3);
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);
auto i0 = IrBuilder::create<Val>(DataType::Index);
fusion.addInput(i0);

// Both inputs are reshaped to the same 1D shape {i0} and then added
auto tv2 = reshape(tv0, {i0});
auto tv3 = reshape(tv1, {i0});
auto tv4 = add(tv2, tv3);
fusion.addOutput(tv4);

ExpressionEvaluator expr_eval;

// Bind concrete sizes such that 2 * 4 * 8 == 8 * 8 == 64 == i0
expr_eval.bind(tv0->axis(0)->extent(), 2L);
expr_eval.bind(tv0->axis(1)->extent(), 4L);
expr_eval.bind(tv0->axis(2)->extent(), 8L);
expr_eval.bind(tv1->axis(0)->extent(), 8L);
expr_eval.bind(tv1->axis(1)->extent(), 8L);
expr_eval.bind(i0, 64L);

// Concretize the dynamic reshapes with the bound sizes before
// replacing the symbolic sizes
auto initial_info = DynamicTransform::getInitialInfo(&fusion);
auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);

DynamicTransform::concretizeFusion(&fusion, &info);

replaceSymbolicSizes(&fusion);

// All expr output tensors should use the same extent, namely the
// extent of tv3, since it has only one merge, whereas tv2 has two
// merges.
auto ref_ext = fusion.outputs().at(0)->as<TensorView>()->axis(0)->extent();

// ref_ext should look like getMetaData(T1).logical_size[0] *
// getMetaData(T1).logical_size[1]
auto ext_def = dynamic_cast<BinaryOp*>(ref_ext->definition());
ASSERT_NE(ext_def, nullptr);
ASSERT_EQ(ext_def->getBinaryOpType(), BinaryOpType::Mul);
auto lhs = ext_def->input(0);
auto rhs = ext_def->input(1);
// Both operands of the multiplication must be defined by GetItem
ASSERT_NE(dynamic_cast<GetItem*>(lhs->definition()), nullptr);
ASSERT_NE(dynamic_cast<GetItem*>(rhs->definition()), nullptr);

// Every expr output must be 1D and use exactly the reference extent
for (auto expr : fusion.exprs()) {
auto tv_output = ir_utils::getTvOutput(expr);
ASSERT_EQ(tv_output->nDims(), 1);
auto ext = tv_output->axis(0)->extent();
EXPECT_EQ(ref_ext, ext) << "Reference: " << ref_ext->toString()
<< ", actual: " << ext->toString();
}
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace nvfuser

0 comments on commit ba2a03b

Please sign in to comment.