Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 12, 2024
1 parent 243bfe7 commit 2dec4a7
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions csrc/tensor_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,31 +279,28 @@ inferAndValidateAllocationSizesAndStrides(
const at::Tensor& tensor,
TensorView* tv,
ExpressionEvaluator ee) {
const auto& logical = tv->getLogicalDomain();
const auto& alloc = tv->getMaybeAllocationDomain();
const auto& alloc_no_reductions = TensorDomain::noReductions(alloc);
const auto& logical_no_reductions =
TensorDomain::noReductions(tv->getLogicalDomain());

// active IDs and their shape and stride
std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
NVF_ERROR(static_cast<int64_t>(logical_no_reductions.size()) == tensor.dim());
for (const auto i : c10::irange(tensor.dim())) {
IterDomain* id = logical_no_reductions.at(i);
active_ids[id] = {tensor.size(i), tensor.stride(i)};
int64_t dim_index = 0;
for (IterDomain* id : TensorDomain::noReductions(logical)) {
active_ids[id] = {tensor.size(dim_index), tensor.stride(dim_index)};
dim_index++;
}
NVF_ERROR(dim_index == tensor.dim());

ForwardTraverseFromLogicalToAlloc(ee, active_ids)
.run(tv, logical_no_reductions, alloc_no_reductions);
BackwardTraverseFromLogicalToAlloc(ee, active_ids)
.run(tv, logical_no_reductions, alloc_no_reductions);
ForwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
BackwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);

// Now active_ids should contain the final sizes and strides, unordered. We
// need to put them to the correct order.
std::vector<int64_t> sizes;
std::vector<int64_t> strides;
sizes.reserve(alloc_no_reductions.size());
strides.reserve(alloc_no_reductions.size());
for (IterDomain* id : alloc_no_reductions) {
sizes.reserve(alloc.size());
strides.reserve(alloc.size());
for (IterDomain* id : TensorDomain::noReductions(alloc)) {
if (id->isDeviceDim()) {
sizes.push_back(1);
} else {
Expand Down

0 comments on commit 2dec4a7

Please sign in to comment.