Skip to content

Commit

Permalink
[GraphBolt][CUDA] Remove unnecessary check and synchronization (dmlc#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Dec 29, 2023
1 parent a2cb2ec commit 93a5834
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 27 deletions.
8 changes: 8 additions & 0 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ class CSCFormatBase:
indptr: torch.Tensor = None
indices: torch.Tensor = None

def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):
self.indptr = indptr
self.indices = indices
if not indptr.is_cuda:
assert self.indptr[-1] == len(
self.indices
), "The last element of indptr should be the same as the length of indices."

def __repr__(self) -> str:
return _csc_format_base_str(self)

Expand Down
9 changes: 0 additions & 9 deletions python/dgl/graphbolt/internal/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,6 @@ def unique_and_compact_csc_formats(
for etype, csc_format in csc_formats.items():
if device is None:
device = csc_format.indices.device
assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(unique_dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
Expand Down Expand Up @@ -358,9 +355,6 @@ def compact_csc_format(
assert isinstance(
dst_nodes, torch.Tensor
), "Edge type not supported in homogeneous graph."
assert csc_formats.indptr[-1] == len(
csc_formats.indices
), "The last element of indptr should be the same as the length of indices."
assert len(dst_nodes) + 1 == len(
csc_formats.indptr
), "The seed nodes should correspond to indptr."
Expand All @@ -381,9 +375,6 @@ def compact_csc_format(
compacted_csc_formats = {}
original_row_ids = copy.deepcopy(dst_nodes)
for etype, csc_format in csc_formats.items():
assert csc_format.indptr[-1] == len(
csc_format.indices
), "The last element of indptr should be the same as the length of indices."
src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(dst_nodes.get(dst_type, [])) + 1 == len(
csc_format.indptr
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def blocks(self):
v.indices,
torch.arange(
0,
v.indptr[-1],
len(v.indices),
device=v.indptr.device,
dtype=v.indptr.dtype,
),
Expand All @@ -227,7 +227,7 @@ def blocks(self):
sampled_csc.indices,
torch.arange(
0,
sampled_csc.indptr[-1],
len(sampled_csc.indices),
device=sampled_csc.indptr.device,
dtype=sampled_csc.indptr.dtype,
),
Expand Down
8 changes: 8 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,11 @@ def test_csc_format_base_representation():
)"""
)
assert str(csc_format_base) == expected_result, print(csc_format_base)


def test_csc_format_base_incorrect_indptr():
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
with pytest.raises(AssertionError):
# The value of last element in indptr is not corresponding to indices.
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
16 changes: 0 additions & 16 deletions tests/python/pytorch/graphbolt/test_graphbolt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,6 @@ def test_unique_and_compact_incorrect_indptr():
with pytest.raises(AssertionError):
gb.unique_and_compact_csc_formats(csc_formats, seeds)

seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The value of last element in indptr is not corresponding to indices.
with pytest.raises(AssertionError):
gb.unique_and_compact_csc_formats(csc_formats, seeds)


def test_compact_csc_format_hetero():
dst_nodes = {
Expand Down Expand Up @@ -449,11 +441,3 @@ def test_compact_incorrect_indptr():
# The number of seeds is not corresponding to indptr.
with pytest.raises(AssertionError):
gb.compact_csc_format(csc_formats, seeds)

seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
# The value of last element in indptr is not corresponding to indices.
with pytest.raises(AssertionError):
gb.compact_csc_format(csc_formats, seeds)

0 comments on commit 93a5834

Please sign in to comment.