Skip to content

Commit

Permalink
[GraphBolt][CUDA] Fix graph pinning and add tests (dmlc#6864)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Dec 29, 2023
1 parent 22a2513 commit 5b51e96
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ def pin_memory_(self):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""

def _pin(x):
return x.pinned_memory() if hasattr(x, "pinned_memory") else x
return x.pin_memory() if hasattr(x, "pin_memory") else x

self._apply_to_members(_pin)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1511,11 +1511,7 @@ def test_from_dglgraph_heterogeneous():
}


@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
def create_fused_csc_sampling_graph():
# Initialize data.
total_num_nodes = 10
total_num_edges = 9
Expand All @@ -1541,7 +1537,7 @@ def test_csc_sampling_graph_to_device():
}

# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
return gb.fused_csc_sampling_graph(
indptr,
indices,
edge_attributes=edge_attributes,
Expand All @@ -1551,6 +1547,15 @@ def test_csc_sampling_graph_to_device():
edge_type_to_id=etypes,
)


@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()

# Copy to device.
graph = graph.to("cuda")

Expand All @@ -1564,6 +1569,27 @@ def test_csc_sampling_graph_to_device():
assert graph.edge_attributes[key].device.type == "cuda"


@unittest.skipIf(
F._default_context_str == "cpu",
reason="Tests for pinned memory are only meaningful on GPU.",
)
def test_csc_sampling_graph_to_pinned_memory():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()

# Copy to pinned_memory in-place.
graph.pin_memory_()

# Check.
assert graph.csc_indptr.is_pinned()
assert graph.indices.is_pinned()
assert graph.node_type_offset.is_pinned()
assert graph.type_per_edge.is_pinned()
assert graph.csc_indptr.is_pinned()
for key in graph.edge_attributes:
assert graph.edge_attributes[key].is_pinned()


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
Expand Down

0 comments on commit 5b51e96

Please sign in to comment.