Skip to content

Commit

Permalink
[Cleanup] Change Byte to Bool for training masks (dmlc#954)
Browse files Browse the repository at this point in the history
* Change Byte to Bool for training masks

* Check if module has Bool, otherwise use Byte
  • Loading branch information
Jacob Stevens authored and jermainewang committed Oct 29, 2019
1 parent 86cf154 commit 98c1448
Show file tree
Hide file tree
Showing 15 changed files with 121 additions and 46 deletions.
11 changes: 8 additions & 3 deletions examples/pytorch/appnp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/cluster_gcn/cluster_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ def main(args):
labels = torch.LongTensor(data.labels)
else:
labels = torch.FloatTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask).type(torch.bool)
val_mask = torch.ByteTensor(data.val_mask).type(torch.bool)
test_mask = torch.ByteTensor(data.test_mask).type(torch.bool)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/dgi/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/gat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/gcn/gcn_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/gcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/graphsage/graphsage.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
13 changes: 9 additions & 4 deletions examples/pytorch/model_zoo/citation_network/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down Expand Up @@ -147,4 +152,4 @@ def main(args):
help="graph self-loop (default=False)")
args = parser.parse_args()
print(args)
main(args)
main(args)
11 changes: 8 additions & 3 deletions examples/pytorch/sampling/dis_sampling/gcn_cv_sc_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ def main(args):

features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/sampling/dis_sampling/gcn_ns_sc_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ def main(args):

features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/sampling/gcn_cv_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,14 @@ def main(args):

features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/sampling/gcn_ns_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,14 @@ def main(args):

features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/sgc/sgc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/sgc/sgc_reddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down
11 changes: 8 additions & 3 deletions examples/pytorch/tagcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ def main(args):
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
Expand Down

0 comments on commit 98c1448

Please sign in to comment.