Skip to content

Commit

Permalink
Support compressed sparse tensors with dense dimensions (pytorch#80565)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#80565
Approved by: https://github.com/cpuhrsch
  • Loading branch information
pearu authored and pytorchmergebot committed Jul 7, 2022
1 parent beb9867 commit d266256
Show file tree
Hide file tree
Showing 9 changed files with 27,220 additions and 194 deletions.
5,656 changes: 5,640 additions & 16 deletions test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect

Large diffs are not rendered by default.

5,848 changes: 5,832 additions & 16 deletions test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect

Large diffs are not rendered by default.

1,064 changes: 1,048 additions & 16 deletions test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect

Large diffs are not rendered by default.

1,064 changes: 1,048 additions & 16 deletions test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect

Large diffs are not rendered by default.

5,658 changes: 5,642 additions & 16 deletions test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect

Large diffs are not rendered by default.

5,850 changes: 5,834 additions & 16 deletions test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect

Large diffs are not rendered by default.

1,064 changes: 1,048 additions & 16 deletions test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect

Large diffs are not rendered by default.

1,064 changes: 1,048 additions & 16 deletions test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect

Large diffs are not rendered by default.

146 changes: 80 additions & 66 deletions test/test_sparse_csr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Owner(s): ["module: sparse"]

import copy
import torch
import random
import itertools
Expand Down Expand Up @@ -176,7 +177,7 @@ def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_

def _generate_small_inputs_utils(self, layout, device=None, dtype=None):

def shape(shape, basedim=0, blocksize=(1, 1)):
def shape(shape, basedim=0, blocksize=(1, 1), dense_shape=()):
# Below, we define compressed and plain indices that
# correspond to row compressed tensors. In order to reuse
# the indices tensors for column compressed tensors, we
Expand All @@ -197,53 +198,62 @@ def shape(shape, basedim=0, blocksize=(1, 1)):
shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:]
return shape

def values(lst, basedim=0, blocksize=(1, 1), device=device, dtype=dtype):
# Below, we define values for non-blocked tensors. To
# reuse these for blocked tensors, we replace all values
# in lst with a double-list that "shape" corresponds to
# blocksize.
if layout in {torch.sparse_bsr, torch.sparse_bsc}:

if layout is torch.sparse_bsc:
blocksize = tuple(reversed(blocksize))

if not lst:
return torch.tensor(lst, device=device, dtype=dtype).reshape(0, *blocksize)

def list_add(lst, value):
if isinstance(lst, list):
return [list_add(item, value) for item in lst]
return lst + value

def apply_block(value, bdim):
if isinstance(value, list) and bdim >= 0:
return [apply_block(item, bdim - 1) for item in value]
new_value = []
for i in range(blocksize[0]):
row = []
for j in range(blocksize[1]):
row.append(list_add(value, i + 10 * j))
new_value.append(row)
return new_value

lst = apply_block(lst, basedim)
def values(lst, basedim=0, blocksize=(1, 1), densesize=(), device=device, dtype=dtype):
    # Build the values tensor of a sparse compressed tensor from lst.
    #
    # lst holds the values of a non-blocked, non-hybrid tensor,
    # optionally nested inside batch dimensions (basedim tells how
    # many leading list levels are batch levels). Every item found
    # below the batch levels is replaced with a nested list whose
    # shape is:
    #   - blocksize + densesize for BSR layout,
    #   - reversed blocksize + densesize for BSC layout,
    #   - densesize for all other layouts.
    # The result is converted to a tensor with the given device and
    # dtype.

    def add_offset(item, offset):
        # add offset to every leaf of a (possibly nested) list
        if isinstance(item, list):
            return [add_offset(sub, offset) for sub in item]
        return item + offset

    def expand(value, bdim, item_shape):
        # Extend the dimensionality of value by len(item_shape) from
        # the right. List levels are descended for as long as
        # bdim >= 0; these are the batch dimensions.
        if not item_shape:
            return value
        if isinstance(value, list) and bdim >= 0:
            return [expand(sub, bdim - 1, item_shape) for sub in value]

        # Create a nested list of shape item_shape filled with None,
        # constructed from the innermost dimension outwards.
        expanded = None
        for extent in reversed(item_shape):
            expanded = [copy.deepcopy(expanded) for _ in range(extent)]

        # Fill every position with value shifted by an offset that
        # encodes the position: index i along the new dimension d
        # contributes i * 10 ** d.
        for index in itertools.product(*(range(extent) for extent in item_shape)):
            cell = expanded
            for i in index[:-1]:
                cell = cell[i]
            offset = sum(i * 10 ** d for d, i in enumerate(index))
            cell[index[-1]] = add_offset(value, offset)
        return expanded

    if layout is torch.sparse_bsr:
        item_shape = blocksize + densesize
    elif layout is torch.sparse_bsc:
        item_shape = tuple(reversed(blocksize)) + densesize
    else:
        item_shape = densesize

    if lst:
        expanded_lst = expand(lst, basedim, item_shape)
        return torch.tensor(expanded_lst, device=device, dtype=dtype)

    # no specified elements: return an empty tensor that still carries
    # the per-element shape
    return torch.tensor(lst, device=device, dtype=dtype).reshape(0, *item_shape)

return shape, values

def _generate_small_inputs(self, layout, device, dtype, index_dtype,
enable_batched=True,
enable_hybrid=False):
enable_batched=True, enable_hybrid=True):
"""Generator of inputs to sparse compressed tensor factory functions.
The input is defined as a 4-tuple:
compressed_indices, plain_indices, values, expected_size_from_shape_inference
"""

# TODO: set enable_hybrid default to True when Sparse
# Compressed tensors support dense dimensions

shape, values = self._generate_small_inputs_utils(layout, device, dtype)

# a regular tensor
Expand Down Expand Up @@ -280,13 +290,13 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype,
# a tensor with one dense dimension
yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype),
torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype),
values([[1, 11], [2, 12], [3, 13], [4, 14]], 0, (3, 2)),
values([1, 2, 3, 4], 0, (3, 2), (2,)),
shape((2, 3, 2), 0, (3, 2)))

# a tensor with two dense dimensions
yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype),
torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype),
values([[[1, 11]], [[2, 12]], [[3, 13]], [[4, 14]]], 0, (2, 3)),
values([1, 2, 3, 4], 0, (2, 3), (4, 2)),
shape((2, 3, 4, 2), 0, (2, 3)))

if enable_batched and enable_hybrid:
Expand All @@ -297,13 +307,8 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype,
torch.tensor([[[0, 1, 0, 1], [0, 1, 2, 0], [0, 0, 1, 2]],
[[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]],
device=device, dtype=index_dtype),
values([[[[[1], [11]], [[2], [12]], [[3], [13]], [[4], [14]]],
[[[5], [15]], [[6], [16]], [[7], [17]], [[8], [18]]],
[[[9], [19]], [[10], [20]], [[11], [21]], [[12], [22]]]],
[[[[3], [13]], [[4], [14]], [[5], [15]], [[6], [16]]],
[[[7], [17]], [[8], [18]], [[9], [19]], [[10], [20]]],
[[[11], [21]], [[12], [22]], [[13], [23]], [[14], [24]]]]],
2, (3, 2)),
values([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], 2, (3, 2), (2, 1)),
shape((2, 3, 2, 3, 2, 1), 2, (3, 2)))

@all_sparse_compressed_layouts()
Expand Down Expand Up @@ -418,28 +423,37 @@ def test_clone(self, layout, device, dtype):
def test_print(self, layout, device):
compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
printed = []
for index_dtype in [torch.int32, torch.int64]:
for dtype in [torch.float32, torch.float64]:
for compressed_indices, plain_indices, values, size in self._generate_small_inputs(
layout, device, dtype, index_dtype):
batch_shape = tuple(size[:-2])
block_shape = tuple(values.shape[-2:]) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
blocksize0, blocksize1 = block_shape if layout in {torch.sparse_bsr, torch.sparse_bsc} else (1, 1)
printed.append("########## {}/{}/batch_shape={}/block_shape={} ##########".format(
dtype, index_dtype, batch_shape, block_shape))
x = torch.sparse_compressed_tensor(compressed_indices,
plain_indices,
values, dtype=dtype, layout=layout, device=device)
printed.append("# sparse tensor")
printed.append(str(x))
printed.append(f"# _{compressed_indices_mth.__name__}")
printed.append(str(compressed_indices_mth(x)))
printed.append(f"# _{plain_indices_mth.__name__}")
printed.append(str(plain_indices_mth(x)))
printed.append("# _values")
printed.append(str(x.values()))
for enable_hybrid in [False, True]:
for index_dtype in [torch.int32, torch.int64]:
for dtype in [torch.float32, torch.float64]:
for compressed_indices, plain_indices, values, size in self._generate_small_inputs(
layout, device, dtype, index_dtype, enable_hybrid=enable_hybrid):
block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
base_ndim = 2
batch_ndim = compressed_indices.dim() - 1
dense_ndim = values.dim() - batch_ndim - block_ndim - 1
if enable_hybrid and dense_ndim == 0:
# non-hybrid cases are covered by the enable_hybrid==False loop
continue
batchsize = size[:batch_ndim]
basesize = size[batch_ndim:batch_ndim + base_ndim]
densesize = size[batch_ndim + base_ndim:]
assert len(densesize) == dense_ndim
printed.append("########## {}/{}/size={}+{}+{} ##########".format(
dtype, index_dtype, batchsize, basesize, densesize))
x = torch.sparse_compressed_tensor(compressed_indices,
plain_indices,
values, size, dtype=dtype, layout=layout, device=device)
printed.append("# sparse tensor")
printed.append(str(x))
printed.append(f"# _{compressed_indices_mth.__name__}")
printed.append(str(compressed_indices_mth(x)))
printed.append(f"# _{plain_indices_mth.__name__}")
printed.append(str(plain_indices_mth(x)))
printed.append("# _values")
printed.append(str(x.values()))
printed.append('')
printed.append('')
printed.append('')
orig_maxDiff = self.maxDiff
self.maxDiff = None
try:
Expand Down

0 comments on commit d266256

Please sign in to comment.