Commit

[Feature] support non-numeric node_id/src_id/dst_id/graph_id and rename CSVDataset (dmlc#3740)

* [Feature] support non-numeric node_id/src_id/dst_id/graph_id and rename CSVDataset

* change return value when iterating over the dataset

* refine data_parser

* force reload
Rhett-Ying authored Feb 17, 2022
1 parent 42f8c8f commit 39121df
Showing 4 changed files with 109 additions and 70 deletions.
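For context, a minimal usage sketch of the renamed class; the directory name and its contents are illustrative assumptions, not part of this commit:

from dgl.data import CSVDataset  # formerly dgl.data.DGLCSVDataset

# './toy_csv_dataset' is an assumed directory holding a meta.yaml plus the CSV
# files it references; after this change, the node_id/src_id/dst_id/graph_id
# columns may contain non-numeric values such as "user_0", not just integers.
dataset = CSVDataset('./toy_csv_dataset', force_reload=True)
g = dataset[0]  # a DGLGraph (see the __getitem__ change below for the tuple case)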
python/dgl/data/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@
from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import DGLCSVDataset
from .csv_dataset import CSVDataset
from .adapter import AsNodePredDataset, AsLinkPredDataset

def register_data_args(parser):
python/dgl/data/csv_dataset.py (68 changes: 37 additions & 31 deletions)
@@ -5,7 +5,7 @@
from ..base import DGLError


class DGLCSVDataset(DGLDataset):
class CSVDataset(DGLDataset):
""" This class aims to parse data from CSV files, construct DGLGraph
and behaves as a DGLDataset.
@@ -17,22 +17,27 @@ class DGLCSVDataset(DGLDataset):
Whether to reload the dataset. Default: False
verbose: bool, optional
Whether to print out progress information. Default: True.
node_data_parser : dict[str, callable], optional
A dictionary used for node data parsing when loading from CSV files.
The key is node type which specifies the header in CSV file and the
value is a callable object which is used to parse corresponding
column data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
edge_data_parser : dict[(str, str, str), callable], optional
A dictionary used for edge data parsing when loading from CSV files.
The key is edge type which specifies the header in CSV file and the
value is a callable object which is used to parse corresponding
column data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
graph_data_parser : callable, optional
A callable object which is used to parse corresponding column graph
data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
ndata_parser : dict[str, callable] or callable, optional
Callable object which takes in the ``pandas.DataFrame`` object created from a
CSV file, parses node data and returns a dictionary of parsed data. If given a
dictionary, the key is the node type and the value is a callable object used to
parse data of the corresponding node type. If given a single callable object, it
is used to parse the data of all node types. Default: None. If None, a default
data parser is applied which loads data directly and tries to convert lists
into arrays.
edata_parser : dict[(str, str, str), callable] or callable, optional
Callable object which takes in the ``pandas.DataFrame`` object created from a
CSV file, parses edge data and returns a dictionary of parsed data. If given a
dictionary, the key is the edge type and the value is a callable object used to
parse data of the corresponding edge type. If given a single callable object, it
is used to parse the data of all edge types. Default: None. If None, a default
data parser is applied which loads data directly and tries to convert lists
into arrays.
gdata_parser : callable, optional
Callable object which takes in the ``pandas.DataFrame`` object created from a
CSV file, parses graph data and returns a dictionary of parsed data. Default:
None. If None, a default data parser is applied which loads data directly and
tries to convert lists into arrays.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
@@ -50,19 +55,19 @@ class DGLCSVDataset(DGLDataset):
"""
META_YAML_NAME = 'meta.yaml'

def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
edge_data_parser=None, graph_data_parser=None, transform=None):
def __init__(self, data_path, force_reload=False, verbose=True, ndata_parser=None,
edata_parser=None, gdata_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None
self.data = None
self.node_data_parser = {} if node_data_parser is None else node_data_parser
self.edge_data_parser = {} if edge_data_parser is None else edge_data_parser
self.graph_data_parser = graph_data_parser
self.ndata_parser = {} if ndata_parser is None else ndata_parser
self.edata_parser = {} if edata_parser is None else edata_parser
self.gdata_parser = gdata_parser
self.default_data_parser = DefaultDataParser()
meta_yaml_path = os.path.join(data_path, DGLCSVDataset.META_YAML_NAME)
meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME)
if not os.path.exists(meta_yaml_path):
raise DGLError(
"'{}' cannot be found under {}.".format(DGLCSVDataset.META_YAML_NAME, data_path))
"'{}' cannot be found under {}.".format(CSVDataset.META_YAML_NAME, data_path))
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname(
Expand All @@ -80,8 +85,8 @@ def process(self):
if meta_node is None:
continue
ntype = meta_node.ntype
data_parser = self.node_data_parser.get(
ntype, self.default_data_parser)
data_parser = self.ndata_parser if callable(
self.ndata_parser) else self.ndata_parser.get(ntype, self.default_data_parser)
ndata = NodeData.load_from_csv(
meta_node, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
node_data.append(ndata)
@@ -90,15 +95,15 @@ def process(self):
if meta_edge is None:
continue
etype = tuple(meta_edge.etype)
data_parser = self.edge_data_parser.get(
etype, self.default_data_parser)
data_parser = self.edata_parser if callable(
self.edata_parser) else self.edata_parser.get(etype, self.default_data_parser)
edata = EdgeData.load_from_csv(
meta_edge, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
edge_data.append(edata)
graph_data = None
if meta_yaml.graph_data is not None:
meta_graph = meta_yaml.graph_data
data_parser = self.default_data_parser if self.graph_data_parser is None else self.graph_data_parser
data_parser = self.default_data_parser if self.gdata_parser is None else self.gdata_parser
graph_data = GraphData.load_from_csv(
meta_graph, base_dir=base_dir, separator=meta_yaml.separator, data_parser=data_parser)
# construct graphs
@@ -132,8 +137,9 @@ def __getitem__(self, i):
else:
g = self._transform(self.graphs[i])

if 'label' in self.data:
return g, self.data['label'][i]
if len(self.data) > 0:
data = {k: v[i] for (k, v) in self.data.items()}
return g, data
else:
return g

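The __getitem__ change above broadens what indexing and iteration yield: every graph-level field is returned, not only 'label'. A sketch, assuming the dataset defines graph-level 'label' and 'train_mask' columns (assumed names):

from dgl.data import CSVDataset

dataset = CSVDataset('./toy_csv_dataset')  # assumed directory, as in the sketch above
# Before this commit: dataset[i] returned (graph, label) only when a 'label' field existed.
# After this commit: all graph-level data comes back in one dict whenever any is present.
g, data = dataset[0]       # e.g. data == {'label': ..., 'train_mask': ...}
label = data['label']
# Datasets without graph-level data still return just the graph:
# g = dataset[0]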
python/dgl/data/csv_dataset_base.py (24 changes: 13 additions & 11 deletions)
@@ -100,12 +100,13 @@ class NodeData(BaseData):
""" Class of node data which is used for DGLGraph construction. Internal use only. """

def __init__(self, node_id, data, type=None, graph_id=None):
self.id = np.array(node_id, dtype=np.int64)
self.id = np.array(node_id)
self.data = data
self.type = type if type is not None else '_V'
self.graph_id = np.array(graph_id, dtype=np.int) if graph_id is not None else np.full(
len(node_id), 0)
_validate_data_length({**{'id': self.id, 'graph_id': self.graph_id}, **self.data})
self.graph_id = np.array(
graph_id) if graph_id is not None else np.full(len(node_id), 0)
_validate_data_length(
{**{'id': self.id, 'graph_id': self.graph_id}, **self.data})

@staticmethod
def load_from_csv(meta: MetaNode, data_parser: Callable, base_dir=None, separator=','):
@@ -145,13 +146,14 @@ class EdgeData(BaseData):
""" Class of edge data which is used for DGLGraph construction. Internal use only. """

def __init__(self, src_id, dst_id, data, type=None, graph_id=None):
self.src = np.array(src_id, dtype=np.int64)
self.dst = np.array(dst_id, dtype=np.int64)
self.src = np.array(src_id)
self.dst = np.array(dst_id)
self.data = data
self.type = type if type is not None else ('_V', '_E', '_V')
self.graph_id = np.array(graph_id, dtype=np.int) if graph_id is not None else np.full(
len(src_id), 0)
_validate_data_length({**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})
self.graph_id = np.array(
graph_id) if graph_id is not None else np.full(len(src_id), 0)
_validate_data_length(
{**{'src': self.src, 'dst': self.dst, 'graph_id': self.graph_id}, **self.data})

@staticmethod
def load_from_csv(meta: MetaEdge, data_parser: Callable, base_dir=None, separator=','):
@@ -195,7 +197,7 @@ class GraphData(BaseData):
""" Class of graph data which is used for DGLGraph construction. Internal use only. """

def __init__(self, graph_id, data):
self.graph_id = np.array(graph_id, dtype=np.int64)
self.graph_id = np.array(graph_id)
self.data = data
_validate_data_length({**{'graph_id': self.graph_id}, **self.data})
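
The recurring change in this file is dropping the forced integer dtype (np.int64, and the deprecated np.int) from the id, src, dst and graph_id arrays, which is what lets non-numeric IDs survive. A small NumPy sketch of the difference:

import numpy as np

ids = ['user_0', 'user_1', 'item_42']
np.array(ids)                    # new behaviour: dtype '<U7', string IDs preserved
# np.array(ids, dtype=np.int64)  # old behaviour: raises ValueError for such IDs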

@@ -269,7 +271,7 @@ def assign_data(type, src_data, dst_data):


class DefaultDataParser:
""" Default data parser for DGLCSVDataset. It
""" Default data parser for CSVDataset. It
1. ignores any columns which does not have a header.
2. tries to convert to list of numeric values(generated by
np.array().tolist()) if cell data is a str separated by ','.
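The visible part of the DefaultDataParser docstring above says comma-separated string cells are converted into lists of numeric values. A simplified mimic of that conversion, not the actual implementation:

import numpy as np
import pandas as pd

df = pd.DataFrame({'feat': ['0.1,0.2', '0.3,0.4']})  # assumed column and contents
converted = [np.array(cell.split(','), dtype=float).tolist() for cell in df['feat']]
# -> [[0.1, 0.2], [0.3, 0.4]]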
