Commit

[traced-graph][sparse] fix restrictive assert for sparse add (pytorch#134037)

Exported sparse addition can run on the CPU or Meta device; this fixes the overly restrictive assert and adds an export test.

Pull Request resolved: pytorch#134037
Approved by: https://github.com/ezyang
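For reference, a minimal sketch of the scenario this change enables, condensed from the test added below (a PyTorch build that includes this fix is assumed; the tensor values are illustrative, not taken from the test):

import torch

class AddNet(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

net = AddNet()
S = torch.eye(4).to_sparse_csr()                       # sparse CSR operand
Y = torch.arange(16, dtype=torch.float32).view(4, 4)   # dense operand
# torch.export traces with fake/meta tensors, which previously tripped the
# CPU-only assert in the dense + sparse-compressed add kernel.
prog = torch.export.export(net, (S, Y))
print(prog.graph)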
aartbik authored and pytorchmergebot committed Aug 24, 2024
1 parent 90fb837 commit 050aa67
Showing 2 changed files with 40 additions and 5 deletions.
12 changes: 7 additions & 5 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -816,19 +816,19 @@ static void add_out_dense_sparse_compressed_cpu(
   TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
   TORCH_INTERNAL_ASSERT(
       src.layout() == kSparseCsr || src.layout() == kSparseCsc);
-  TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
+  TORCH_INTERNAL_ASSERT(dense.device() == kCPU || dense.device() == kMeta);
 
   TORCH_CHECK(
       out.is_contiguous(),
       "out argument must be contiguous, but got: ",
       out.suggest_memory_format());
   TORCH_CHECK(
-      out.device() == kCPU,
-      "add: expected 'out' to be CPU tensor, but got tensor on device: ",
+      out.device() == dense.device(),
+      "add: expected 'out' to match dense tensor, but got tensor on device: ",
       out.device());
   TORCH_CHECK(
-      src.device() == kCPU,
-      "add: expected 'other' to be a CPU tensor, but got tensor on device: ",
+      src.device() == dense.device(),
+      "add: expected 'src' to match dense tensor, but got tensor on device: ",
       src.device());
 
   TORCH_CHECK(
@@ -864,6 +864,8 @@ static void add_out_dense_sparse_compressed_cpu(
     return;
   }
 
+  TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
+
   auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)});
   resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
   Tensor src_compressed_indices;
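For context, these checks guard the CPU kernel that adds a dense tensor and a sparse compressed (CSR/CSC) tensor; the second hunk re-asserts kCPU just before the value computation, so Meta tensors presumably only flow through the argument checks and the early-return path above it. A minimal eager-mode sketch of the operation served by this kernel, with the dispatch path assumed rather than traced:

import torch

# Dense and sparse CSR operands on CPU; shapes mirror the test added below.
Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
S = torch.eye(4).to_sparse_csr()
# With this change the same device checks also accept Meta tensors,
# so the addition can be traced during export.
result = torch.add(S, Y)
print(result)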
33 changes: 33 additions & 0 deletions test/export/test_sparse.py
@@ -70,6 +70,11 @@ def forward(self, x):
         return x.to_dense()
 
 
+class AddNet(torch.nn.Module):
+    def forward(self, x, y):
+        return torch.add(x, y)
+
+
 class SparseActivationCOO(torch.nn.Module):
     def forward(self, x):
         return [xi.to_sparse() for xi in x]
@@ -218,6 +223,34 @@ def test_todensenet(self, dtype, itype, layout):
             else:
                 self.assertEqual(meta, None)
 
+    def test_add(self):
+        net = AddNet()
+        Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
+        A = torch.tensor(
+            [
+                [0.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 2.0],
+                [0.0, 0.0, 1.0, 1.0],
+                [3.0, 0.0, 3.0, 0.0],
+            ],
+            dtype=torch.float32,
+        )
+        S = A.to_sparse_csr()
+        result = net(S, Y)
+        # Build the traced graph.
+        prog = torch.export.export(net, (S, Y))
+        # Test args/add/output.
+        for i, node in enumerate(prog.graph.nodes):
+            meta = node.meta.get("val", None)
+            if i == 0:
+                self.assertEqualMeta(meta, S)
+            elif i == 1:
+                self.assertEqualMeta(meta, Y)
+            elif i == 2:
+                self.assertEqualMeta(meta, result)
+            else:
+                self.assertEqual(meta, None)
+
     def test_activation_coo(self):
         net = SparseActivationCOO()
         x = [torch.randn(3, 3) for _ in range(3)]
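To run just the new test in isolation, one option (assuming pytest is available in the environment) is:

pytest test/export/test_sparse.py -k test_add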
