Skip to content

Commit

Permalink
[data] refine AsNodePredDataset and add tests for DGLCSVDataset (dmlc/dgl#3722)
Browse files Browse the repository at this point in the history

* [data] refine AsNodePredDataset and add tests for DGLCSVDataset

* fix

* remove add_self_loop

* refine
  • Loading branch information
Rhett-Ying authored Feb 10, 2022
1 parent fcd8ed9 commit 45ac572
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/dgl/data/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .dgl_dataset import DGLDataset
from . import utils
from .. import backend as F

__all__ = ['AsNodePredDataset']

Expand Down Expand Up @@ -68,13 +69,15 @@ def __init__(self,
self.g = dataset[0].clone()
self.split_ratio = split_ratio
self.target_ntype = target_ntype
self.num_classes = dataset.num_classes
self.num_classes = getattr(dataset, 'num_classes', None)
super().__init__(dataset.name + '-as-nodepred', **kwargs)

def process(self):
if 'label' not in self.g.nodes[self.target_ntype].data:
raise ValueError("Missing node labels. Make sure labels are stored "
"under name 'label'.")
if self.num_classes is None:
self.num_classes = len(F.unique(self.g.nodes[self.target_ntype].data['label']))
if self.verbose:
print('Generating train/val/test masks...')
utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)
Expand Down
43 changes: 43 additions & 0 deletions tests/compute/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,48 @@ def test_as_nodepred2():
ds = data.AsNodePredDataset(data.AIFBDataset(), [0.1, 0.1, 0.8], 'Personen', verbose=True)
assert F.sum(F.astype(ds[0].nodes['Personen'].data['train_mask'], F.int32), 0) == int(ds[0].num_nodes('Personen') * 0.1)

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_nodepred_csvdataset():
    """Verify AsNodePredDataset on top of a DGLCSVDataset: the wrapper must
    infer ``num_classes`` from the node labels and add a train/val/test
    split, neither of which the CSV dataset provides by itself."""
    num_nodes = 100
    num_edges = 500
    num_dims = 3
    # One distinct label per node, so num_classes equals num_nodes.
    num_classes = num_nodes

    with tempfile.TemporaryDirectory() as test_dir:
        # Lay out the on-disk dataset: meta.yaml plus node/edge CSV files.
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path = os.path.join(test_dir, "test_edges.csv")
        nodes_csv_path = os.path.join(test_dir, "test_nodes.csv")

        # meta.yaml tells DGLCSVDataset where the node/edge tables live.
        meta = {
            'version': '1.0.0',
            'dataset_name': 'default_name',
            'node_data': [{'file_name': os.path.basename(nodes_csv_path)}],
            'edge_data': [{'file_name': os.path.basename(edges_csv_path)}],
        }
        with open(meta_yaml_path, 'w') as f:
            yaml.dump(meta, f, sort_keys=False)

        # Node table: id, integer label, and a small random feature vector.
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.arange(num_classes)
        node_table = pd.DataFrame({
            'node_id': np.arange(num_nodes),
            'label': label_ndata,
            'feat': [row.tolist() for row in feat_ndata],
        })
        node_table.to_csv(nodes_csv_path, index=False)

        # Edge table: random endpoints within [0, num_nodes).
        edge_table = pd.DataFrame({
            'src_id': np.random.randint(num_nodes, size=num_edges),
            'dst_id': np.random.randint(num_nodes, size=num_edges),
        })
        edge_table.to_csv(edges_csv_path, index=False)

        # The raw CSV dataset carries features and labels but no split
        # masks and no num_classes attribute.
        ds = data.DGLCSVDataset(test_dir, force_reload=True)
        graph = ds[0]
        assert 'feat' in graph.ndata
        assert 'label' in graph.ndata
        assert 'train_mask' not in graph.ndata
        assert not hasattr(graph, 'num_classes')

        # Wrapping it must infer num_classes from labels and add masks.
        new_ds = data.AsNodePredDataset(ds, force_reload=True)
        assert new_ds.num_classes == num_classes
        assert 'feat' in new_ds[0].ndata
        assert 'label' in new_ds[0].ndata
        assert 'train_mask' in new_ds[0].ndata

if __name__ == '__main__':
test_minigc()
test_gin()
Expand All @@ -1079,3 +1121,4 @@ def test_as_nodepred2():
test_add_nodepred_split()
test_as_nodepred1()
test_as_nodepred2()
test_as_nodepred_csvdataset()

0 comments on commit 45ac572

Please sign in to comment.