Skip to content

Commit

Permalink
Minor fix to DGL Enter (dmlc#3753)
Browse files Browse the repository at this point in the history
* [Fix] Convert float64 to float32 when creating tensor

* fix

Co-authored-by: RhettYing <[email protected]>
Co-authored-by: Rhett Ying <[email protected]>
  • Loading branch information
3 people authored Feb 18, 2022
1 parent 5558ce2 commit f247d29
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions enter/dglenter/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ class NodeBase(DGLBaseModel):

DataFactory.register(
"csv",
import_code="from dgl.data import DGLCSVDataset",
import_code="from dgl.data import CSVDataset",
extra_args={"data_path": "./"},
class_name="DGLCSVDataset({})",
class_name="CSVDataset({})",
allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"])

DataFactory.register(
Expand Down
8 changes: 4 additions & 4 deletions python/dgl/data/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self,
self.split_ratio = split_ratio
self.target_ntype = target_ntype
super().__init__(self.dataset.name + '-as-nodepred',
hash_key=(split_ratio, target_ntype), **kwargs)
hash_key=(split_ratio, target_ntype, dataset.name, 'nodepred'), **kwargs)

def process(self):
is_ogb = hasattr(self.dataset, 'get_idx_split')
Expand Down Expand Up @@ -211,7 +211,7 @@ class AsLinkPredDataset(DGLDataset):
Dataset("cora_v2", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("cora_v2-as-edgepred", num_graphs=1, save_path=/home/ubuntu/.dgl/cora_v2-as-edgepred)
Dataset("cora_v2-as-linkpred", num_graphs=1, save_path=/home/ubuntu/.dgl/cora_v2-as-linkpred)
>>> print(hasattr(new_ds, "get_test_edges"))
True
"""
Expand All @@ -226,8 +226,8 @@ def __init__(self,
self.dataset = dataset
self.split_ratio = split_ratio
self.neg_ratio = neg_ratio
super().__init__(dataset.name + '-as-edgepred',
hash_key=(neg_ratio, split_ratio), **kwargs)
super().__init__(dataset.name + '-as-linkpred',
hash_key=(neg_ratio, split_ratio, dataset.name, 'linkpred'), **kwargs)

def process(self):
if self.split_ratio is None:
Expand Down

0 comments on commit f247d29

Please sign in to comment.