Skip to content

Commit

Permalink
Validate Sparse Compressed tensor inputs (pytorch#79385)
Browse files Browse the repository at this point in the history
The validation includes regular tensor inputs, batched tensor inputs, as well as hybrid tensor inputs.
Pull Request resolved: pytorch#79385
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch
  • Loading branch information
pearu authored and pytorchmergebot committed Jun 27, 2022
1 parent 9db3c51 commit cde365a
Show file tree
Hide file tree
Showing 12 changed files with 4,157 additions and 2,616 deletions.
32 changes: 32 additions & 0 deletions aten/src/ATen/SparseCsrTensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,38 @@ inline std::string plainIndicesName(Layout layout) {
[&] { return "row_indices"; });
}

inline std::string compressedDimName(Layout layout) {
switch (layout) {
case kSparseCsr:
return "row";
case kSparseCsc:
return "column";
case kSparseBsr:
return "row block";
case kSparseBsc:
return "column block";
default:
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
return "";
}
}

inline std::string plainDimName(Layout layout) {
switch (layout) {
case kSparseCsr:
return "column";
case kSparseCsc:
return "row";
case kSparseBsr:
return "column block";
case kSparseBsc:
return "row block";
default:
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
return "";
}
}

inline int rowDimension(Layout layout, IntArrayRef size) {
return size.size() - (isCompressedRow(layout) ? 2 : 1);
}
Expand Down
336 changes: 182 additions & 154 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Large diffs are not rendered by default.

Loading

0 comments on commit cde365a

Please sign in to comment.