Skip to content

Commit

Permalink
Add missing __main__ in two unittests (pytorch#97302)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#97302
Approved by: https://github.com/zou3519
  • Loading branch information
ppwwyyxx authored and pytorchmergebot committed Mar 22, 2023
1 parent 28929b1 commit 726fc36
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion test/test_comparison_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Owner(s): ["module: internals"]

import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import TestCase, run_tests

class TestComparisonUtils(TestCase):
def test_all_equal_no_assert(self):
Expand Down Expand Up @@ -30,3 +30,7 @@ def test_assert_sizes(self):

with self.assertRaises(RuntimeError):
torch._assert_tensor_metadata(t, [3], [1], torch.float)


if __name__ == '__main__':
run_tests()
6 changes: 5 additions & 1 deletion test/test_pruning_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from hypothesis import given
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import TestCase, run_tests
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

Expand Down Expand Up @@ -76,3 +76,7 @@ def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, we
)
def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)


if __name__ == '__main__':
run_tests()

0 comments on commit 726fc36

Please sign in to comment.