Skip to content

Commit

Permalink
[Dataset] Add transform argument to built-in datasets (dmlc#3733)
Browse files Browse the repository at this point in the history
* Update

* Fix

* Update
  • Loading branch information
mufeili authored Feb 15, 2022
1 parent b3d3a2c commit 8b8fd2c
Show file tree
Hide file tree
Showing 22 changed files with 621 additions and 258 deletions.
14 changes: 11 additions & 3 deletions python/dgl/data/bitcoinotc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
verbose: bool
Whether to print out progress information.
Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
Expand Down Expand Up @@ -67,12 +71,13 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641'

def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)

def download(self):
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
Expand Down Expand Up @@ -143,7 +148,10 @@ def __getitem__(self, item):
- ``edata['h']`` : edge weights
"""
return self.graphs[item]
if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])

@property
def is_temporal(self):
Expand Down
130 changes: 90 additions & 40 deletions python/dgl/data/citation_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,23 @@ class CitationGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}

def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, name, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
assert name.lower() in ['cora', 'citeseer', 'pubmed']

# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
Expand All @@ -69,7 +74,8 @@ def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)

def process(self):
"""Loads input data from data directory and reorder graph for better locality
Expand Down Expand Up @@ -213,7 +219,10 @@ def load(self):

def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
if self._transform is None:
return self._g
else:
return self._transform(self._g)

def __len__(self):
return 1
Expand Down Expand Up @@ -267,7 +276,7 @@ def features(self):
@property
def reverse_edge(self):
return self._reverse_edge


def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
Expand Down Expand Up @@ -356,10 +365,14 @@ class CoraGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
Expand Down Expand Up @@ -400,10 +413,12 @@ class CoraGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'cora'

super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)

def __getitem__(self, idx):
r"""Gets the graph object
Expand Down Expand Up @@ -496,10 +511,14 @@ class CiteseerGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
Expand Down Expand Up @@ -543,10 +562,12 @@ class CiteseerGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
name = 'citeseer'

super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)

def __getitem__(self, idx):
r"""Gets the graph object
Expand Down Expand Up @@ -639,10 +660,14 @@ class PubmedGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
Expand Down Expand Up @@ -683,10 +708,12 @@ class PubmedGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'pubmed'

super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)

def __getitem__(self, idx):
r"""Gets the graph object
Expand Down Expand Up @@ -714,7 +741,7 @@ def __len__(self):
r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__()

def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None):
"""Get CoraGraphDataset
Parameters
Expand All @@ -724,19 +751,24 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True)
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return
-------
CoraGraphDataset
"""
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data

def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_citeseer(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get CiteseerGraphDataset
Parameters
Expand All @@ -746,38 +778,47 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=T
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return
-------
CiteseerGraphDataset
"""
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data

def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_pubmed(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get PubmedGraphDataset
Parameters
-----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return
-------
PubmedGraphDataset
"""
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data

class CoraBinary(DGLBuiltinDataset):
Expand All @@ -798,15 +839,20 @@ class CoraBinary(DGLBuiltinDataset):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True, transform=None):
name = 'cora_binary'
url = _get_dgl_url('dataset/cora_binary.zip')
super(CoraBinary, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)

def process(self):
root = self.raw_path
Expand Down Expand Up @@ -894,7 +940,11 @@ def __getitem__(self, i):
(dgl.DGLGraph, scipy.sparse.coo_matrix, int)
The graph, scipy sparse coo_matrix and its label.
"""
return (self.graphs[i], self.pmpds[i], self.labels[i])
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return (g, self.pmpds[i], self.labels[i])

@property
def save_name(self):
Expand Down
18 changes: 14 additions & 4 deletions python/dgl/data/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class DGLCSVDataset(DGLDataset):
A callable object which is used to parse corresponding column graph
data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
Expand All @@ -46,7 +50,8 @@ class DGLCSVDataset(DGLDataset):
"""
META_YAML_NAME = 'meta.yaml'

def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None):
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
edge_data_parser=None, graph_data_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None
self.data = None
Expand All @@ -61,7 +66,7 @@ def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname(
meta_yaml_path), force_reload=force_reload, verbose=verbose)
meta_yaml_path), force_reload=force_reload, verbose=verbose, transform=transform)


def process(self):
Expand Down Expand Up @@ -122,10 +127,15 @@ def load(self):
self.graphs, self.data = load_graphs(graph_path)

def __getitem__(self, i):
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])

if 'label' in self.data:
return self.graphs[i], self.data['label'][i]
return g, self.data['label'][i]
else:
return self.graphs[i]
return g

def __len__(self):
return len(self.graphs)
Loading

0 comments on commit 8b8fd2c

Please sign in to comment.