Skip to content

Commit

Permalink
Changed device assignments to be compatible with data-parallel. (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
uzaymacar authored Jul 1, 2021
1 parent cacaa2d commit 3a5e37a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions modeling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def forward(self, x1, x2):

if self.cfg.task == '2':
# Concat. TPs
x = torch.cat([x1, x2], dim=1).to(self.cfg.device)
x = torch.cat([x1, x2], dim=1).to(x1.device)
# x: (B, 2, SV, SV, SV)
else:
# Discard x2
Expand Down Expand Up @@ -581,9 +581,9 @@ def __init__(self, cfg):

def forward(self, x):
# Get position IDs and timepoint IDs
position_ids = torch.cat([torch.arange(self.cfg.num_patches)] * 2).unsqueeze(0).to(self.cfg.device)
position_ids = torch.cat([torch.arange(self.cfg.num_patches)] * 2).unsqueeze(0).to(x.device)
# position_ids: (1, 2 * S)
timepoint_ids = torch.tensor([0] * self.cfg.num_patches + [1] * self.cfg.num_patches).unsqueeze(0).to(self.cfg.device)
timepoint_ids = torch.tensor([0] * self.cfg.num_patches + [1] * self.cfg.num_patches).unsqueeze(0).to(x.device)
# timepoint_ids: (1, 2 * S)

x, context_out, context_features = self.embeddings(x, position_ids, timepoint_ids)
Expand Down Expand Up @@ -632,7 +632,7 @@ def forward(self, x1, x2):
# x2: (B, S, 1, P, P, P)

# Concat. TPs
x = torch.cat([x1, x2], dim=1).to(self.cfg.device)
x = torch.cat([x1, x2], dim=1).to(x1.device)
# x: (B, 2 * S, 1, P, P, P)

x, attention_weights, context_out, context_features = self.transformer(x)
Expand Down

0 comments on commit 3a5e37a

Please sign in to comment.