Skip to content

Commit

Permalink
Merge branch 'main' into pm/simplify_bwd_baselines
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 authored Oct 31, 2024
2 parents fcb9363 + abdc3e1 commit b03d7c6
Show file tree
Hide file tree
Showing 33 changed files with 1,353 additions and 209 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nvfuser-ci-trigger.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:

# This job only runs for pull request comments
if: |
startsWith(github.event.comment.body, '!build') &&
( startsWith(github.event.comment.body, '!build') || startsWith(github.event.comment.body, '!test') ) &&
(github.actor == 'xwang233' || github.actor == 'jjsjann123' || github.actor == 'chang-l' || github.actor == 'csarofeen' || github.actor == 'drzejan2' || github.actor == 'IvanYashchuk' || github.actor == 'jacobhinkle' || github.actor == 'kevinstephano' || github.actor == 'liqiangxl' || github.actor == 'mmigdal-nv' || github.actor == 'naoyam' || github.actor == 'ptrblck' || github.actor == 'rdspring1' || github.actor == 'samnordmann' || github.actor == 'zasdfgbnm' || github.actor == 'crcrpar' || github.actor == 'nWEIdia' || github.actor == 'Priya2698' || github.actor == 'wujingyue' || github.actor == 'tfogal' || github.actor == 'protonu' || github.actor == 'cowanmeg' || github.actor == 'nsarka')
steps:
- name: Check if comment is issued by authorized person
Expand Down
25 changes: 25 additions & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# A workflow to send CI-related helpful information to PRs
name: pull
on:
pull_request:

run-name: CI status hello ${{ github.event.pull_request.number }} - ${{ github.event.pull_request.head.sha }}
jobs:
status_hello:
name: send CI hello status
runs-on: ubuntu-latest
permissions:
statuses: write
steps:
- name: Set CI hello status
run: |
curl \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.event.pull_request.head.sha }} \
-d "{\"state\":\"success\",\"target_url\":\"https://github.com/NVIDIA/Fuser/wiki/Bot-Commands\",\"description\":\"Authorized users: comment !build or !test to trigger CI pipelines. See wiki.\",\"context\":\"CI notes\"}"
3 changes: 0 additions & 3 deletions benchmarks/python/test_adaptive_layernorm_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: BSD-3-Clause
import pytest
from nvfuser import FusionDefinition, DataType
from nvfuser.pytorch_utils import clear_cuda_cache
from .core import run_benchmark
import torch

Expand Down Expand Up @@ -73,8 +72,6 @@ def test_adaptive_layernorm_fwd_benchmark(
disable_validation: bool,
disable_benchmarking: bool,
):
clear_cuda_cache()

B = 1
T = 30 * 1024
D = 1024
Expand Down
83 changes: 83 additions & 0 deletions benchmarks/python/test_many_segments_host.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import pytest
from nvfuser import FusionDefinition, DataType
from .core import run_benchmark
import torch


def many_matmul_fusion(fd: FusionDefinition) -> None:
x = fd.define_tensor(
shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False
)
y = fd.define_tensor(
shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False
)
a = fd.ops.add(x, y)
for _ in range(5):
a_transpose = fd.ops.permute(a, [1, 0])
matmul_out = fd.ops.matmul(a_transpose, y)
add_out = fd.ops.add(a_transpose, y)
a = fd.ops.add(matmul_out, add_out)
fd.add_output(a)


@pytest.mark.parametrize("host_bench_mode", ["compile", "steady", "dynamic"])
def test_many_segment_benchmark(
benchmark,
host_bench_mode: str,
disable_validation: bool,
disable_benchmarking: bool,
):
inputs = [torch.randn(16, 16, device="cuda", dtype=torch.float) for _ in range(2)]

# Generate multiple inputs to measure dynamic shape overhead.
if host_bench_mode == "dynamic":
input_sizes = [4, 8, 16, 32, 64, 128]
# Generate matrices of size x size dimensions
inputs = [
[
torch.randn(size, size, device="cuda", dtype=torch.float)
for _ in range(2)
]
for size in input_sizes
]

with FusionDefinition() as fd:
many_matmul_fusion(fd)

def validate(input):
x, y = input
eager_output = x + y
for _ in range(5):
eager_transpose = eager_output.t()
matmul_out = torch.matmul(eager_transpose, y)
add_out = eager_transpose + y
eager_output = matmul_out + add_out
fd.validate(input, [eager_output])

# Validate number of segments
_ = fd.execute(input, profile=True)
num_segments = fd.profile().segments
expected_segments = 12
assert (
num_segments == expected_segments
), f"Expected {expected_segments} fusion segments, got {num_segments}."

if not disable_validation:
if host_bench_mode == "dynamic":
# Run validate for all input sizes.
for input in inputs:
validate(input)
else:
validate(inputs)

if not disable_benchmarking:
run_benchmark(
benchmark,
None,
inputs,
device=f"host:{host_bench_mode}",
fusion_fn=many_matmul_fusion,
)
23 changes: 18 additions & 5 deletions csrc/device_lower/pass/circular_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,17 @@ class CloneTmaCircularBufferLoopAndInsertSync
return wait_exprs;
}

// If there is already an if-then-else with electSync() predicate, use it.
// Otherwise, create a new one.
kir::IfThenElse* getElectSyncIfThenElse() {
if (elect_sync_if_then_else_ == nullptr) {
elect_sync_if_then_else_ = IrBuilder::create<kir::IfThenElse>(
IrBuilder::create<kir::Predicate>(PredicateType::ElectSync));
for_loop_stack_.back()->body().push_back(elect_sync_if_then_else_);
}
return elect_sync_if_then_else_;
}

// This function selects a single thread to launch tma load and mbarrier
// arrive_expected_tx operations. The remaining threads will simply arrive
// at the mbarrier.
Expand All @@ -719,16 +730,14 @@ class CloneTmaCircularBufferLoopAndInsertSync
NVF_ERROR(mbarrier_arrive_tx_ != nullptr);
NVF_ERROR(expr != nullptr);

// Create the if-then-else with electSync() predicate for the arrive expect
// transaction.
kir::IfThenElse* if_expr = IrBuilder::create<kir::IfThenElse>(
IrBuilder::create<kir::Predicate>(PredicateType::ElectSync));
// Use the if-then-else with electSync() predicate for the arrive expect
// and cpAsyncBulk operations.
kir::IfThenElse* if_expr = getElectSyncIfThenElse();

// A single thread issues arriveExpectTx with expected transactions and
// launches the TMA load.
if_expr->thenBody().push_back(mbarrier_arrive_tx_);
if_expr->thenBody().push_back(expr);
for_loop_stack_.back()->body().push_back(if_expr);

mbarrier_arrive_tx_ = nullptr;
}
Expand Down Expand Up @@ -841,6 +850,10 @@ class CloneTmaCircularBufferLoopAndInsertSync
// Mbarrier_ArriveExpectTx to add to cloned_top_level_loop
kir::MBarrierArriveExpectTx* mbarrier_arrive_tx_ = nullptr;

// ElectSync if-then-else for the cloned loop. We put all the circular buffer
// load TMA operations under this if-then-else.
kir::IfThenElse* elect_sync_if_then_else_ = nullptr;

// The circular buffered TVs for the loop being cloned
std::unordered_set<const TensorView*> circular_buffer_load_tvs_;
};
Expand Down
143 changes: 74 additions & 69 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,79 @@ void validateValWithConcreteValue(

} // namespace

void ExpressionEvaluator::bindTensorDomain(
const TensorView* tv,
const at::Tensor& t,
const bool evaluate_validate) {
auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain());
NVF_ERROR(
t.dim() == (int64_t)logical_domain.size(),
"Expected ",
getInputPosString(tv),
tv->toString(),
", to be bound to a tensor of rank ",
logical_domain.size(),
", but got a tensor of rank ",
t.dim());
for (auto i : c10::irange(t.dim())) {
auto id = logical_domain[i];
if (id->isBroadcast()) {
// DIDs are ignored for broadcast.
bind_(logical_domain[i]->extent(), 1, evaluate_validate);
if (id->hasExpandedExtent()) {
// Verify that t is also expanded
NVF_ERROR(
t.size(i) == 1 || t.stride(i) == 0,
"IterDomain ",
id->toString(),
" in ",
getInputPosString(tv),
"TensorView ",
tv->toString(),
" has expanded extent but input tensor has size ",
t.size(i),
" and stride ",
t.stride(i),
" in dimension ",
i);
bind_(
logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
}
} else {
if (logical_domain[i]->isDeviceDim()) {
// Currently we have the restrictions:
// (1) Devices parallelized axis extent == DeviceMesh's extent
// (2) Device parallelized axis cannot be split or merged
// Therefore, the device parallelized extents will always be allocated
// with size 1, but the symbolic axis extent is binded with the extent
// of the DeviceMesh
NVF_CHECK(
1 == t.size(i),
"TensorView ",
tv->toString(),
getInputPosString(tv),
" IterDomain ",
id->toString(),
"is sharded and must have size 1, but input tensor has size ",
t.size(i));
NVF_CHECK(
tv->hasDeviceMesh(),
"TV ",
tv->toString(),
getInputPosString(tv),
" has an empty DeviceMesh with DID parallelization")
bind_(
logical_domain[i]->extent(),
static_cast<int64_t>(
tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
evaluate_validate);
} else {
bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
}
}
}
}

void ExpressionEvaluator::bind_(
const Val* value,
PolymorphicValue concrete_value,
Expand Down Expand Up @@ -162,75 +235,7 @@ void ExpressionEvaluator::bind_(
}
if (auto tv = dynamic_cast<const TensorView*>(value)) {
const auto& t = concrete_value.as<at::Tensor>();
auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain());
NVF_ERROR(
t.dim() == (int64_t)logical_domain.size(),
"Expected ",
getInputPosString(tv),
tv->toString(),
", to be bound to a tensor of rank ",
logical_domain.size(),
", but got a tensor of rank ",
t.dim());
for (auto i : c10::irange(t.dim())) {
auto id = logical_domain[i];
if (id->isBroadcast()) {
// DIDs are ignored for broadcast.
bind_(logical_domain[i]->extent(), 1, evaluate_validate);
if (id->hasExpandedExtent()) {
// Verify that t is also expanded
NVF_ERROR(
t.size(i) == 1 || t.stride(i) == 0,
"IterDomain ",
id->toString(),
" in ",
getInputPosString(tv),
"TensorView ",
tv->toString(),
" has expanded extent but input tensor has size ",
t.size(i),
" and stride ",
t.stride(i),
" in dimension ",
i);
bind_(
logical_domain[i]->expandedExtent(),
t.size(i),
evaluate_validate);
}
} else {
if (logical_domain[i]->isDeviceDim()) {
// Currently we have the restrictions:
// (1) Devices parallelized axis extent == DeviceMesh's extent
// (2) Device parallelized axis cannot be split or merged
// Therefore, the device parallelized extents will always be allocated
// with size 1, but the symbolic axis extent is binded with the extent
// of the DeviceMesh
NVF_CHECK(
1 == t.size(i),
"TensorView ",
tv->toString(),
getInputPosString(tv),
" IterDomain ",
id->toString(),
"is sharded and must have size 1, but input tensor has size ",
t.size(i));
NVF_CHECK(
tv->hasDeviceMesh(),
"TV ",
tv->toString(),
getInputPosString(tv),
" has an empty DeviceMesh with DID parallelization")
bind_(
logical_domain[i]->extent(),
static_cast<int64_t>(tv->getDeviceMesh().size(
logical_domain[i]->getParallelType())),
evaluate_validate);
} else {
bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
}
}
}
bindTensorDomain(tv, t, evaluate_validate);
}
if (value->isA<NamedScalar>()) {
known_named_scalars_[value->as<NamedScalar>()->name()] =
Expand Down
18 changes: 12 additions & 6 deletions csrc/expr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ class PrecomputedValues;

//! Calculate Fusion IR expressions
class ExpressionEvaluator {
NVF_API void bind_(
const Val* value,
PolymorphicValue concrete_value,
bool evaluate_validate);
void bind_(const std::string& name, PolymorphicValue concrete_value);

public:
//! Bind a concrete value to an IR variable
//! If evaluate_validate is true, and value is evaluatable with the
Expand Down Expand Up @@ -98,6 +92,18 @@ class ExpressionEvaluator {
ExpressionEvaluator clone(IrCloner& ir_cloner) const;

private:
void bind_(
const Val* value,
PolymorphicValue concrete_value,
bool evaluate_validate);

void bind_(const std::string& name, PolymorphicValue concrete_value);

void bindTensorDomain(
const TensorView* tv,
const at::Tensor& t,
bool evaluate_validate);

const PolymorphicValue& getValue(
const Val* value,
const std::unordered_map<const Val*, PolymorphicValue>&
Expand Down
1 change: 1 addition & 0 deletions csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2238,6 +2238,7 @@ kir::TensorIndex* Index::getConsumerIndex(
DataType as_type) {
Val* index = nullptr;
if (!ir_utils::hasRootToLoopLinearTransformations(consumer) ||
ir_utils::isCpAsyncBulkLoad(consumer->definition()) ||
(isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex) &&
GpuLower::current()->isTensorIndexerEnabled())) {
index = GpuLower::current()->tensorIndexer().getLinearIndex(
Expand Down
2 changes: 1 addition & 1 deletion csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2593,7 +2593,7 @@ IterDomain* IterDomain::merge(
} else {
expanded_extent = mul(outer->expandedExtent(), inner->extent());
}
} else if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
} else if (!outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
if (outer->isBroadcast()) {
expanded_extent = inner->expandedExtent();
} else {
Expand Down
Loading

0 comments on commit b03d7c6

Please sign in to comment.