diff --git a/modeling/models.py b/modeling/models.py index 47d43ee..93961fb 100644 --- a/modeling/models.py +++ b/modeling/models.py @@ -364,7 +364,7 @@ def forward(self, x1, x2): if self.cfg.task == '2': # Concat. TPs - x = torch.cat([x1, x2], dim=1).to(self.cfg.device) + x = torch.cat([x1, x2], dim=1).to(x1.device) # x: (B, 2, SV, SV, SV) else: # Discard x2 @@ -581,9 +581,9 @@ def __init__(self, cfg): def forward(self, x): # Get position IDs and timepoint IDs - position_ids = torch.cat([torch.arange(self.cfg.num_patches)] * 2).unsqueeze(0).to(self.cfg.device) + position_ids = torch.cat([torch.arange(self.cfg.num_patches)] * 2).unsqueeze(0).to(x.device) # position_ids: (1, 2 * S) - timepoint_ids = torch.tensor([0] * self.cfg.num_patches + [1] * self.cfg.num_patches).unsqueeze(0).to(self.cfg.device) + timepoint_ids = torch.tensor([0] * self.cfg.num_patches + [1] * self.cfg.num_patches).unsqueeze(0).to(x.device) # timepoint_ids: (1, 2 * S) x, context_out, context_features = self.embeddings(x, position_ids, timepoint_ids) @@ -632,7 +632,7 @@ def forward(self, x1, x2): # x2: (B, S, 1, P, P, P) # Concat. TPs - x = torch.cat([x1, x2], dim=1).to(self.cfg.device) + x = torch.cat([x1, x2], dim=1).to(x1.device) # x: (B, 2 * S, 1, P, P, P) x, attention_weights, context_out, context_features = self.transformer(x)