diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h
index dfc7ff8813045..628b8ce61240a 100644
--- a/aten/src/ATen/SparseCsrTensorUtils.h
+++ b/aten/src/ATen/SparseCsrTensorUtils.h
@@ -143,12 +143,12 @@ inline int columnDimension(Layout layout, IntArrayRef size) {
   return size.size() - (isCompressedColumn(layout) ? 2 : 1);
 }
 
-inline int compressedDimension(Layout layout, IntArrayRef size) {
-  return size.size() - (isCompressedRow(layout) ? 2 : 1);
+inline int compressedDimension(Layout layout, IntArrayRef size, size_t dense_ndim=0) {
+  return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
 }
 
-inline int plainDimension(Layout layout, IntArrayRef size) {
-  return size.size() - (isCompressedRow(layout) ? 1 : 2);
+inline int plainDimension(Layout layout, IntArrayRef size, size_t dense_ndim=0) {
+  return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
 }
 
 } // namespace sparse_csr
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 62d600dc0926d..77979f55647de 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -103,38 +103,52 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
       "number of dimensions of ", compressed_indices_name, " and ", plain_indices_name, " must be the same but got ",
       compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
 
-  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
+  int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
       layout, "validate_sparse_compressed_tensor_args",
       [&] {
         TORCH_CHECK(
-            compressed_indices.dim() == values.dim(),
-            "number of dimensions of indices and values must be the same but got ",
-            compressed_indices.dim(), " and ", values.dim(), ", respectively");
+            compressed_indices.dim() <= values.dim(),
+            "number of dimensions of indices (=", compressed_indices.dim(),
+            ") must be equal or less than the number of dimensions of values (=", values.dim(), ")");
+        return 0;
       },
       [&] {
         TORCH_CHECK(
-            compressed_indices.dim() + 2 == values.dim(),
-            "number of dimensions of indices must be two less than the number of dimensions of the values but got ",
-            compressed_indices.dim(), " + 2 not equal to ", values.dim());
+            compressed_indices.dim() + 2 <= values.dim(),
+            "number of dimensions of indices (=", compressed_indices.dim(),
+            ") plus two must be equal or less than the number of dimensions of values (=", values.dim(), ")");
+        return 2;
       });
+  int dense_ndim = values.dim() - compressed_indices.dim() - block_ndim;
+  TORCH_CHECK(dense_ndim == 0, "non-zero dense dimensions (=", dense_ndim, ") is not supported for ", layout, " layout");
 
-  TORCH_CHECK(
-      static_cast<size_t>(compressed_indices.dim()) == size.size() - 1,
-      "number of dimensions of indices must be one less than the number of dimensions of the provided size but got ",
-      compressed_indices.dim(), " not equal to ", size.size(), " - 1");
+  int batch_ndim = size.size() - 2 - dense_ndim;
+  TORCH_INTERNAL_ASSERT(block_ndim >= 0 && dense_ndim >=0 && batch_ndim >= 0);
 
-  int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{ return 0; }, [&]{ return 2; });
-  IntArrayRef block_size = values.sizes().slice(values.dim() - block_ndim, block_ndim);
-  int64_t numel_per_block = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
-                                                                        [&]() -> int64_t { return 1; }, [&]() -> int64_t { return block_size[0] * block_size[1]; });
-  int compressed_dim = compressedDimension(layout, size);
-  int plain_dim = plainDimension(layout, size);
+  TORCH_CHECK(
+      static_cast<size_t>(compressed_indices.dim()) == size.size() - 1 - dense_ndim,
+      "number of dimensions of indices must be one less than the number of dimensions of the provided size",
+      " (minus the number of dense dimensions) but got ",
+      compressed_indices.dim(), " not equal to ", size.size(), " - 1 - ", dense_ndim);
+
+  // For CSR/CSC formats, we define blocksize=(1, 1) so that checking
+  // the sparse compressed tensor invariants can be unified with the
+  // BSR/BSC invariants.
+  DimVector blocksize{
+      (block_ndim == 2 ? std::max<int64_t>(1, values.sizes()[values.dim() - dense_ndim - 2]) : 1),
+      (block_ndim == 2 ? std::max<int64_t>(1, values.sizes()[values.dim() - dense_ndim - 1]) : 1),
+  };
+  TORCH_INTERNAL_ASSERT(blocksize.size() == 2 && blocksize[0] > 0 && blocksize[1] > 0);
+
+  int64_t numel_per_block = blocksize[0] * blocksize[1];
+  int compressed_dim = compressedDimension(layout, size, dense_ndim);
+  int plain_dim = plainDimension(layout, size, dense_ndim);
 
   // All batch sizes must be the same
-  auto batch_size = size.slice(0, size.size() - 2);
-  auto compressed_indices_batch_size = compressed_indices.sizes().slice(0, compressed_indices.dim() - 1);
-  auto plain_indices_batch_size = plain_indices.sizes().slice(0, plain_indices.dim() - 1);
-  auto values_batch_size = values.sizes().slice(0, values.dim() - 1 - block_ndim);
+  DimVector batch_size = DimVector(size.slice(0, batch_ndim));
+  DimVector compressed_indices_batch_size = DimVector(compressed_indices.sizes().slice(0, compressed_indices.dim() - 1));
+  DimVector plain_indices_batch_size = DimVector(plain_indices.sizes().slice(0, plain_indices.dim() - 1));
+  DimVector values_batch_size = DimVector(values.sizes().slice(0, values.dim() - 1 - block_ndim - dense_ndim));
   TORCH_CHECK(
       batch_size == compressed_indices_batch_size &&
       batch_size == plain_indices_batch_size &&
@@ -143,34 +157,56 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
       compressed_indices_batch_size,", ",
       plain_indices_batch_size, "), and values (",
       values_batch_size,") must be the same.");
 
+  // A tensor constitutes of full blocks
+  for (int i=0; i= 1`
-  TORCH_CHECK(
-      compressed_indices.size(-1) == (size[compressed_dim] + 1),
-      compressed_indices_name, ".size(-1) must be equal to size[-", (size.size() - compressed_dim), "] + 1 (that is ",
-      size[compressed_dim] + 1, "), but got: ", compressed_indices.size(-1));
+  if (block_ndim == 2) {
+    TORCH_CHECK(
+        compressed_indices.size(-1) == (size[compressed_dim] / blocksize[compressed_dim - batch_ndim] + 1),
+        compressed_indices_name, ".size(-1) must be equal to size[-", (size.size() - compressed_dim),
+        "]/blocksize[", compressed_dim - batch_ndim, "] + 1 (that is ",
+        size[compressed_dim] / blocksize[compressed_dim - batch_ndim] + 1, "), but got: ", compressed_indices.size(-1));
+    TORCH_CHECK(
+        plain_indices.numel() * numel_per_block == values.numel(),
+        "number of ", plain_indices_name, " elements must be the same as the number of blocks in values, but got ",
+        plain_indices_name, ".numel() * numel_per_block: ", plain_indices.numel() * numel_per_block,
+        ", values.numel(): ", values.numel(),", numel_per_block: ", numel_per_block);
+  } else {
+    TORCH_CHECK(
+        compressed_indices.size(-1) == (size[compressed_dim] + 1),
+        compressed_indices_name, ".size(-1) must be equal to size[-", (size.size() - compressed_dim),
+        "] + 1 (that is ",
+        size[compressed_dim] + 1, "), but got: ", compressed_indices.size(-1));
+    TORCH_CHECK(
+        plain_indices.numel() == values.numel(),
+        "number of ", plain_indices_name, " elements must be the same number of elements, but got ",
+        plain_indices_name, ".numel(): ", plain_indices.numel(),
+        ", values.numel(): ", values.numel());
+  }
 
-  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
-      [&] {
-        TORCH_CHECK(
-            plain_indices.numel() == values.numel(),
-            plain_indices_name, " and values must have the same number of elements, but got ", plain_indices_name, ".numel(): ",
-            plain_indices.numel(), ", values.numel(): ", values.numel());
-      },
-      [&] {
-        TORCH_CHECK(
-            plain_indices.numel() * numel_per_block == values.numel(),
-            "number of ", plain_indices_name, " elements must be the same as the number of blocks in values, but got ",
-            plain_indices_name, ".numel() * numel_per_block: ", plain_indices.numel() * numel_per_block,
-            ", values.numel(): ", values.numel(),", numel_per_block: ", numel_per_block);
-      });
+  // Type Invariants
+  auto compressed_indices_type = compressed_indices.scalar_type();
+  auto plain_indices_type = plain_indices.scalar_type();
+  TORCH_CHECK(
+      compressed_indices_type == plain_indices_type,
+      "both ", compressed_indices_name, " and ", plain_indices_name, " should have the same type, bot got ",
+      compressed_indices_type, " and ", plain_indices_type, ", respectively");
+  TORCH_CHECK(
+      compressed_indices_type == kInt || compressed_indices_type == kLong,
+      compressed_indices_name, " and ", plain_indices_name, " must be an int32 or int64 type, but got: ",
+      compressed_indices_type);
 
   // Indices invariants
-  AT_DISPATCH_INDEX_TYPES(compressed_indices.scalar_type(), "validate_sparse_compressed_tensor_args",
+  AT_DISPATCH_INDEX_TYPES(compressed_indices_type, "validate_sparse_compressed_tensor_args",
       [&] {
         Tensor compressed_indices_cpu = compressed_indices.to(kCPU);
         auto compressed_indices_data_ptr = compressed_indices_cpu.data_ptr<index_t>();
         auto batch_stride = compressed_indices_cpu.dim() >= 2 ? compressed_indices_cpu.stride(-2) : 0;
-        auto compressed_dims = size[compressedDimension(layout, size)];
+        auto compressed_dims = (block_ndim == 0 ? size[compressed_dim] : size[compressed_dim] / blocksize[compressed_dim - batch_ndim]);
         for (const auto batch_id : c10::irange(batchCount(compressed_indices_cpu))) {
           TORCH_CHECK(
               compressed_indices_data_ptr[batch_id*batch_stride] == 0,
@@ -184,7 +220,8 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
           TORCH_CHECK(
               compressed_indices_data_ptr[batch_id*batch_stride + i - 1] <= compressed_indices_data_ptr[batch_id*batch_stride + i],
               "(Batch element ", batch_id, ") ",
-              "at position i = ", i, ", the condition ", compressed_indices_name, "[i - 1] <= ", compressed_indices_name, "[i] fails");
+              "at position i = ", i, ", the condition ", compressed_indices_name, "[i - 1] <= ", compressed_indices_name, "[i] fails, got ",
+              compressed_indices_data_ptr[batch_id*batch_stride + i - 1], " <= ", compressed_indices_data_ptr[batch_id*batch_stride + i]);
         }
       }
       if (plain_indices.numel() > 0) {
@@ -193,18 +230,6 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
         }
       });
 
-  // Type Invariants
-  auto compressed_indices_type = compressed_indices.scalar_type();
-  auto plain_indices_type = plain_indices.scalar_type();
-  TORCH_CHECK(
-      compressed_indices_type == plain_indices_type,
-      "both ", compressed_indices_name, " and ", plain_indices_name, " should have the same type, bot got ",
-      compressed_indices_type, " and ", plain_indices_type, ", respectively");
-  TORCH_CHECK(
-      compressed_indices_type == kInt || compressed_indices_type == kLong,
-      compressed_indices_name, " and ", plain_indices_name, " must be an int32 or int64 type, but got: ",
-      compressed_indices_type);
-
   // Device Invariants
   TORCH_CHECK(
       plain_indices.get_device() == compressed_indices.get_device(),
@@ -335,6 +360,12 @@ DimVector _estimate_sparse_compressed_tensor_size(
     const Tensor& plain_indices,
     const Tensor& values,
     Layout layout) {
+  int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; });
+  int dense_ndim = values.dim() - compressed_indices.dim() - block_ndim;
+  DimVector blocksize{
+      (block_ndim == 2 ? std::max<int64_t>(1, values.sizes()[values.dim() - dense_ndim - 2]) : 1),
+      (block_ndim == 2 ? std::max<int64_t>(1, values.sizes()[values.dim() - dense_ndim - 1]) : 1),
+  };
   DimVector size = DimVector(IntArrayRef(plain_indices.sizes().data(), plain_indices.dim() - 1));
   int64_t compressed_dim = (plain_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0);
   int64_t plain_dim = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size",
@@ -347,13 +378,16 @@ DimVector _estimate_sparse_compressed_tensor_size(
       });
   AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size",
       [&]{
-        size.push_back(compressed_dim);
-        size.push_back(plain_dim);
+        size.push_back(compressed_dim * blocksize[0]);
+        size.push_back(plain_dim * blocksize[1]);
       },
       [&]{
-        size.push_back(plain_dim);
-        size.push_back(compressed_dim);
+        size.push_back(plain_dim * blocksize[0]);
+        size.push_back(compressed_dim * blocksize[1]);
       });
+  for (int i=0; i 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
         assert len(size) >= sparse_dim
-        if block_size:
-            assert len(block_size) == 2
+        if blocksize:
+            assert len(blocksize) == 2, (size, blocksize)
+            assert size[-2] % blocksize[0] == 0, (size, blocksize)
+            assert size[-1] % blocksize[1] == 0, (size, blocksize)
+            blocksize0, blocksize1 = blocksize
+        else:
+            blocksize0 = blocksize1 = 1
 
         def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
             compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype)
@@ -2064,20 +2069,21 @@ def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
                 torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count])
             low = -1 if dtype != torch.uint8 else 0
             high = 1 if dtype != torch.uint8 else 2
-            values = make_tensor((nnz,) + block_size, device=device, dtype=dtype, low=low, high=high)
+            values = make_tensor((nnz,) + blocksize, device=device, dtype=dtype, low=low, high=high)
             return values, compressed_indices, plain_indices
 
         batch_shape = size[:-2]
         n_batch = reduce(mul, batch_shape, 1)
 
         if layout in {torch.sparse_csr, torch.sparse_bsr}:
-            n_compressed_dims, n_plain_dims = size[-2], size[-1]
+            n_compressed_dims, n_plain_dims = size[-2] // blocksize0, size[-1] // blocksize1
         else:
-            n_compressed_dims, n_plain_dims = size[-1], size[-2]
-        sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz) for _ in range(n_batch)]
+            n_compressed_dims, n_plain_dims = size[-1] // blocksize1, size[-2] // blocksize0
+        blocknnz = nnz // (blocksize0 * blocksize1)
+        sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)]
         sparse_tensors_it = map(list, zip(*sparse_tensors))
-        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, nnz, *block_size)
+        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize)
         compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
         plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
@@ -2086,21 +2092,21 @@ def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
 
     def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype):
         return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device,
-                                              dtype=dtype, index_dtype=index_dtype, block_size=())
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=())
 
     def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype):
         return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device,
-                                              dtype=dtype, index_dtype=index_dtype, block_size=())
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=())
 
-    def genSparseBSRTensor(self, size, block_size, nnz, *, device, dtype, index_dtype):
-        assert len(block_size) == 2
+    def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype):
+        assert len(blocksize) == 2
         return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device,
-                                              dtype=dtype, index_dtype=index_dtype, block_size=block_size)
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize)
 
-    def genSparseBSCTensor(self, size, block_size, nnz, *, device, dtype, index_dtype):
-        assert len(block_size) == 2
+    def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype):
+        assert len(blocksize) == 2
         return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device,
-                                              dtype=dtype, index_dtype=index_dtype, block_size=block_size)
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize)
 
     def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype):
         # Assert not given impossible combination, where the sparse dims have
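
For reference, a minimal sketch (not part of the patch) of the block-sparse invariants that the updated validator enforces, assuming a PyTorch build that exposes torch.sparse_bsr_tensor: for a (4, 6) BSR tensor with blocksize (2, 3), crow_indices must have 4/2 + 1 = 3 entries and col_indices.numel() * 2 * 3 must equal values.numel().

import torch

# Illustrative example only; names and shapes are chosen for this sketch.
# A 4x6 BSR tensor with blocksize (2, 3): 2 block rows, 2 block columns,
# one stored block per block row.
crow_indices = torch.tensor([0, 1, 2], dtype=torch.int64)        # size[-2]/blocksize[0] + 1 entries
col_indices = torch.tensor([0, 1], dtype=torch.int64)            # one entry per stored block
values = torch.arange(12, dtype=torch.float64).reshape(2, 2, 3)  # (number of blocks, *blocksize)

bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(4, 6))
# col_indices.numel() * (2 * 3) == values.numel() == 12, matching the
# numel_per_block check in the validator above.
print(bsr)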