diff --git a/util.py b/util.py index 2ef4673..6dbdb6b 100644 --- a/util.py +++ b/util.py @@ -73,6 +73,7 @@ def predict_transform(prediction, inp_dim, anchors, num_classes, CUDA = True): if CUDA: x_offset = x_offset.cuda() y_offset = y_offset.cuda() + prediction = prediction.cuda() x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1,num_anchors).view(-1,2).unsqueeze(0)