fix scaling and dataloader len
chendiqian committed Dec 12, 2022
1 parent 146c361 commit 093316b
Showing 8 changed files with 13 additions and 18 deletions.
4 changes: 4 additions & 0 deletions dataloaders/BaseLoader.py
@@ -30,6 +30,10 @@ def __getitem__(self, *args, **kwargs):
     def __len__(self, *args, **kwargs):
         raise NotImplementedError

+    @property
+    def loader_len(self):
+        return len(self)
+
     def __collate__(self, *args, **kwargs):
         raise NotImplementedError
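The new loader_len property on BaseLoader defaults to len(self); the batched samplers in the files below override it with the actual number of batches. A minimal sketch of the pattern, assuming an illustrative subclass (DummyBatchedLoader is hypothetical, not part of the repository):

from math import ceil


class BaseLoader:
    def __len__(self):
        raise NotImplementedError

    @property
    def loader_len(self):
        # Default: one yielded item per entry, so batches == len(self).
        return len(self)


class DummyBatchedLoader(BaseLoader):
    # Hypothetical sampler yielding batch_size entries per iteration.
    def __init__(self, indices, batch_size):
        self.indices = indices
        self.batch_size = batch_size

    def __len__(self):
        return len(self.indices)  # number of entries, not batches

    @property
    def loader_len(self):
        # Number of batches the loader actually yields.
        return ceil(len(self.indices) / self.batch_size)


loader = DummyBatchedLoader(list(range(10)), batch_size=4)
print(len(loader), loader.loader_len)  # 10 3

Making loader_len a property rather than a method lets callers read it with the same attribute syntax on every loader, which is what the prefetch-generator change below relies on.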
1 change: 1 addition & 0 deletions dataloaders/GraphSAINTRWSampler.py
@@ -196,6 +196,7 @@ def __filename__(self):
     def __len__(self):
         return len(self.output_indices)

+    @property
     def loader_len(self):
         return ceil(len(self.output_indices) / self.batch_size)
1 change: 1 addition & 0 deletions dataloaders/IBMBRandLoader.py
@@ -65,6 +65,7 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.out_aux_pairs)

+    @property
     def loader_len(self):
         return ceil(len(self.out_aux_pairs) / self.batch_size)
1 change: 1 addition & 0 deletions dataloaders/LADIESSampler.py
@@ -67,6 +67,7 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.output_indices)

+    @property
     def loader_len(self):
         return ceil(len(self.output_indices) / self.batch_size)
1 change: 1 addition & 0 deletions dataloaders/NeighborSamplingLoader.py
@@ -35,6 +35,7 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.node_idx)

+    @property
     def loader_len(self):
         return ceil(len(self.node_idx) / self.batch_size)
1 change: 1 addition & 0 deletions dataloaders/ShaDowLoader.py
@@ -65,6 +65,7 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.out_aux_pairs)

+    @property
     def loader_len(self):
         return ceil(len(self.out_aux_pairs) / self.batch_size)
15 changes: 1 addition & 14 deletions train/prefetch_generators.py
@@ -1,12 +1,6 @@
 import queue
 import threading

-from dataloaders.GraphSAINTRWSampler import SaintRWValSampler
-from dataloaders.IBMBRandLoader import IBMBRandLoader
-from dataloaders.ShaDowLoader import ShaDowLoader
-from dataloaders.LADIESSampler import LADIESSampler
-from dataloaders.NeighborSamplingLoader import NeighborSamplingLoader
-

 class BaseGenerator(threading.Thread):
     def __init__(self, max_prefetch=1, device='cuda'):
@@ -37,13 +31,6 @@ def __init__(self, dataloader, max_prefetch=1, device='cuda'):

     def run(self):
         for i, graph in enumerate(self.dataloader):
-            if isinstance(self.dataloader, (SaintRWValSampler,
-                                            ShaDowLoader,
-                                            IBMBRandLoader,
-                                            LADIESSampler,
-                                            NeighborSamplingLoader)):
-                stop_signal = i == self.dataloader.loader_len() - 1
-            else:
-                stop_signal = i == len(self.dataloader) - 1
+            stop_signal = i == self.dataloader.loader_len - 1
             self.queue.put((graph.to(self.device, non_blocking=True), stop_signal))
         self.queue.put(None)
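With loader_len now a uniform property, the generator no longer has to enumerate every batched sampler type just to decide when the last batch arrives: the final iteration is simply index loader_len - 1. A runnable sketch of the stop-signal logic, using a stand-in loader (DummyLoader is illustrative, not from the repository):

from math import ceil


class DummyLoader:
    # Stand-in for the real samplers: yields batches, exposes loader_len.
    def __init__(self, n_items, batch_size):
        self.n_items = n_items
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, self.n_items, self.batch_size):
            yield list(range(start, min(start + self.batch_size, self.n_items)))

    @property
    def loader_len(self):
        return ceil(self.n_items / self.batch_size)


loader = DummyLoader(n_items=10, batch_size=4)
for i, batch in enumerate(loader):
    stop_signal = i == loader.loader_len - 1  # True only on the final batch
    print(batch, stop_signal)
# [0, 1, 2, 3] False
# [4, 5, 6, 7] False
# [8, 9] True

The removed isinstance branch had to be kept in sync with every loader class by hand; any sampler missing from the tuple would fall through to len(self.dataloader), an item count, so the stop signal could fire at the wrong iteration or not at all.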
7 changes: 3 additions & 4 deletions train/trainer.py
@@ -44,8 +44,7 @@ def __init__(self,

     def get_loss_scaling(self, len_loader: int):
         micro_batch = int(min(self.micro_batch, len_loader))
-        num_batches = ceil(len_loader / self.batch_size)
-        loss_scaling_lst = [micro_batch] * (num_batches // micro_batch) + [num_batches % micro_batch]
+        loss_scaling_lst = [micro_batch] * (len_loader // micro_batch) + [len_loader % micro_batch]
         return loss_scaling_lst, micro_batch

     def train(self,
@@ -90,7 +89,7 @@ def train(self,

         # train
         model.train()
-        loss_scaling_lst, cur_micro_batch = self.get_loss_scaling(len(train_loader))
+        loss_scaling_lst, cur_micro_batch = self.get_loss_scaling(train_loader.loader_len)
         loader, next_loader = next_loader, None

         start_time = time.time()
@@ -385,7 +384,7 @@ def full_graph_inference(self,
             adj = BaseLoader.normalize_adjmat(adj, normalization='sym')

         outputs = model.chunked_pass(MyGraph(x=graph.x, adj=adj, idx=torch.from_numpy(mask)),
-                                     self.num_batches // self.batch_size).detach().numpy()
+                                     self.num_batches // self.batch_size).detach().numpy()  # an estimate of #chunks

         for cat in ['val', 'test']:
             nodes = val_nodes if cat == 'val' else test_nodes
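The get_loss_scaling change removes a double division by batch_size: the caller now passes train_loader.loader_len, which is already a batch count, so the old num_batches = ceil(len_loader / self.batch_size) step becomes redundant and is dropped. A standalone rendering of the fixed logic, with the Trainer attribute inlined as a parameter (the free-function form is illustrative):

def get_loss_scaling(len_loader, micro_batch):
    # len_loader is a batch count (loader_len), not an item count.
    micro_batch = int(min(micro_batch, len_loader))
    loss_scaling_lst = [micro_batch] * (len_loader // micro_batch) + [len_loader % micro_batch]
    return loss_scaling_lst, micro_batch


# 10 batches accumulated in micro-batches of 4:
# two full groups of 4, then a remainder group of 2.
print(get_loss_scaling(10, 4))  # ([4, 4, 2], 4)

Each entry is presumably the number of micro-batches whose losses are accumulated before an optimizer step, i.e. the factor the summed loss is scaled by for that group.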
