From fbc803df0c420db84429c51599f4fa4354b4493f Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 23 Mar 2023 05:40:19 +0000 Subject: [PATCH] Only warn once for TypedStorage deprecation (#97379) Fixes #97207 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97379 Approved by: https://github.com/ezyang --- test/test_torch.py | 45 +++++++++++++++++++------ torch/storage.py | 37 ++++++++++++++++---- torch/testing/_internal/common_utils.py | 12 +++++++ 3 files changed, 77 insertions(+), 17 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 068c81a211993..34dfcaa12ffa0 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -37,7 +37,8 @@ skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, - skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like) + skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, + AlwaysWarnTypedStorageRemoval) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( expectedFailureMeta, @@ -6821,15 +6822,39 @@ def test_typed_storage_deprecation_warning(self): # Check that each of the TypedStorage function calls produce a warning # if warnings are reset between each for f in funcs: - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - f() - self.assertEqual(len(w), 1, msg=str([str(a) for a in w])) - warning = w[0].message - self.assertTrue(warning, DeprecationWarning) - self.assertTrue(re.search( - '^TypedStorage is deprecated', - str(warning))) + with AlwaysWarnTypedStorageRemoval(True): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + f() + self.assertEqual(len(w), 1, msg=str([str(a) for a in w])) + warning = w[0].message + self.assertTrue(warning, DeprecationWarning) + self.assertTrue(re.search( + '^TypedStorage is deprecated', + str(warning))) + + # Test that only the first warning is raised by default + torch.storage._reset_warn_typed_storage_removal() + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + torch.FloatStorage() + torch.randn(10).storage() + self.assertEqual(len(w), 1, msg=str([str(a) for a in w])) + warning = w[0].message + self.assertTrue(re.search( + '^TypedStorage is deprecated', + str(warning))) + # Check the line of code from the warning's stack + with open(w[0].filename) as f: + code_line = f.readlines()[w[0].lineno - 1] + self.assertTrue(re.search(re.escape('torch.FloatStorage()'), code_line)) + + # Check that warnings are not emitted if it happened in the past + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + torch.FloatStorage() + torch.randn(10).storage() + self.assertEqual(len(w), 0, msg=str([str(a) for a in w])) def test_from_file(self): def assert_with_filename(filename): diff --git a/torch/storage.py b/torch/storage.py index e7fe36a4ae7cb..a332b2e59a6a7 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -362,14 +362,37 @@ def _isint(x): else: return isinstance(x, int) +_always_warn_typed_storage_removal = False + +def _get_always_warn_typed_storage_removal(): + return _always_warn_typed_storage_removal + +def _set_always_warn_typed_storage_removal(always_warn): + global _always_warn_typed_storage_removal + assert isinstance(always_warn, bool) + _always_warn_typed_storage_removal = always_warn + def _warn_typed_storage_removal(stacklevel=2): - message = ( - "TypedStorage is deprecated. It will be removed in the future and " - "UntypedStorage will be the only storage class. This should only matter " - "to you if you are using storages directly. To access UntypedStorage " - "directly, use tensor.untyped_storage() instead of tensor.storage()" - ) - warnings.warn(message, UserWarning, stacklevel=stacklevel + 1) + global _always_warn_typed_storage_removal + + def is_first_time(): + if not hasattr(_warn_typed_storage_removal, 'has_warned'): + return True + else: + return not _warn_typed_storage_removal.__dict__['has_warned'] + + if _get_always_warn_typed_storage_removal() or is_first_time(): + message = ( + "TypedStorage is deprecated. It will be removed in the future and " + "UntypedStorage will be the only storage class. This should only matter " + "to you if you are using storages directly. To access UntypedStorage " + "directly, use tensor.untyped_storage() instead of tensor.storage()" + ) + warnings.warn(message, UserWarning, stacklevel=stacklevel + 1) + _warn_typed_storage_removal.__dict__['has_warned'] = True + +def _reset_warn_typed_storage_removal(): + _warn_typed_storage_removal.__dict__['has_warned'] = False class TypedStorage: is_sparse = False diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index c4602d3afb92c..4e14f9ae6ef16 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1168,6 +1168,18 @@ def __exit__(self, exception_type, exception_value, traceback): self.deterministic_restore, warn_only=self.warn_only_restore) +class AlwaysWarnTypedStorageRemoval: + def __init__(self, always_warn): + assert isinstance(always_warn, bool) + self.always_warn = always_warn + + def __enter__(self): + self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal() + torch.storage._set_always_warn_typed_storage_removal(self.always_warn) + + def __exit__(self, exception_type, exception_value, traceback): + torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore) + # Context manager for setting cuda sync debug mode and reset it # to original value # we are not exposing it to the core because sync debug mode is