Skip to content

Commit

Permalink
Only warn once for TypedStorage deprecation (pytorch#97379)
Browse files Browse the repository at this point in the history
Fixes pytorch#97207

Pull Request resolved: pytorch#97379
Approved by: https://github.com/ezyang
  • Loading branch information
kurtamohler authored and pytorchmergebot committed Mar 23, 2023
1 parent b507d7d commit fbc803d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
45 changes: 35 additions & 10 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
skipCUDAMemoryLeakCheckIf, BytesIOContext,
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like)
skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like,
AlwaysWarnTypedStorageRemoval)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
expectedFailureMeta,
Expand Down Expand Up @@ -6821,15 +6822,39 @@ def test_typed_storage_deprecation_warning(self):
# Check that each of the TypedStorage function calls produce a warning
# if warnings are reset between each
for f in funcs:
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
f()
self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
warning = w[0].message
self.assertTrue(warning, DeprecationWarning)
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))
with AlwaysWarnTypedStorageRemoval(True):
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
f()
self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
warning = w[0].message
self.assertTrue(warning, DeprecationWarning)
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))

# Test that only the first warning is raised by default
torch.storage._reset_warn_typed_storage_removal()
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
torch.FloatStorage()
torch.randn(10).storage()
self.assertEqual(len(w), 1, msg=str([str(a) for a in w]))
warning = w[0].message
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))
# Check the line of code from the warning's stack
with open(w[0].filename) as f:
code_line = f.readlines()[w[0].lineno - 1]
self.assertTrue(re.search(re.escape('torch.FloatStorage()'), code_line))

# Check that warnings are not emitted if it happened in the past
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
torch.FloatStorage()
torch.randn(10).storage()
self.assertEqual(len(w), 0, msg=str([str(a) for a in w]))

def test_from_file(self):
def assert_with_filename(filename):
Expand Down
37 changes: 30 additions & 7 deletions torch/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,37 @@ def _isint(x):
else:
return isinstance(x, int)

_always_warn_typed_storage_removal = False

def _get_always_warn_typed_storage_removal():
return _always_warn_typed_storage_removal

def _set_always_warn_typed_storage_removal(always_warn):
global _always_warn_typed_storage_removal
assert isinstance(always_warn, bool)
_always_warn_typed_storage_removal = always_warn

def _warn_typed_storage_removal(stacklevel=2):
message = (
"TypedStorage is deprecated. It will be removed in the future and "
"UntypedStorage will be the only storage class. This should only matter "
"to you if you are using storages directly. To access UntypedStorage "
"directly, use tensor.untyped_storage() instead of tensor.storage()"
)
warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
global _always_warn_typed_storage_removal

def is_first_time():
if not hasattr(_warn_typed_storage_removal, 'has_warned'):
return True
else:
return not _warn_typed_storage_removal.__dict__['has_warned']

if _get_always_warn_typed_storage_removal() or is_first_time():
message = (
"TypedStorage is deprecated. It will be removed in the future and "
"UntypedStorage will be the only storage class. This should only matter "
"to you if you are using storages directly. To access UntypedStorage "
"directly, use tensor.untyped_storage() instead of tensor.storage()"
)
warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
_warn_typed_storage_removal.__dict__['has_warned'] = True

def _reset_warn_typed_storage_removal():
_warn_typed_storage_removal.__dict__['has_warned'] = False

class TypedStorage:
is_sparse = False
Expand Down
12 changes: 12 additions & 0 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,18 @@ def __exit__(self, exception_type, exception_value, traceback):
self.deterministic_restore,
warn_only=self.warn_only_restore)

class AlwaysWarnTypedStorageRemoval:
def __init__(self, always_warn):
assert isinstance(always_warn, bool)
self.always_warn = always_warn

def __enter__(self):
self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal()
torch.storage._set_always_warn_typed_storage_removal(self.always_warn)

def __exit__(self, exception_type, exception_value, traceback):
torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore)

# Context manager for setting cuda sync debug mode and reset it
# to original value
# we are not exposing it to the core because sync debug mode is
Expand Down

0 comments on commit fbc803d

Please sign in to comment.