>>> get_total_memory_used(["abc", "def"])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 503, in get_total_memory_used
    result = traverse_input_data(
  File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
    result = aggregate(
TypeError: unsupported operand type(s) for +: 'int' and 'str'
>>>
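action_fn is not applied to str values, so sys.getsizeof is never called on the strings. They reach the aggregation step unchanged, where adding them to an int raises the TypeError shown above.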
This happens when a model's input is a list of strings, as is common for language models.
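The failure can be reproduced without PyTorch. The following is only a sketch: `FakeTensor`, `traverse`, and `reproduce` are hypothetical names, and `traverse` is a simplified model of the traversal inferred from the traceback, not torchinfo's actual code.

```python
import sys
from collections.abc import Iterable
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, payload: bytes) -> None:
        self.payload = payload


def traverse(
    data: Any,
    action_fn: Callable[[Any], Any],
    aggregate: Callable[[list], Any],
) -> Any:
    """Assumed, simplified traversal: only tensor-like leaves get action_fn."""
    if isinstance(data, FakeTensor):
        return action_fn(data)
    if isinstance(data, Iterable) and not isinstance(data, str):
        return aggregate([traverse(d, action_fn, aggregate) for d in data])
    # Strings and other leaves are returned unchanged, so the aggregate
    # step later receives raw strings instead of integer sizes.
    return data


def reproduce() -> str:
    """Return the message of the error raised for a list of strings."""
    try:
        traverse(["abc", "def"], lambda t: sys.getsizeof(t.payload), sum)
    except TypeError as exc:
        return str(exc)
    return "no error"


print(reproduce())
```

Under these assumptions, `sum` starts from the integer 0 and tries to add the first string to it, which matches the `int` and `str` operands reported in the traceback.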
Possible dirty workaround:

    def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
        """Calculates the total memory of all tensors stored in data."""
        result = traverse_input_data(
            data,
            action_fn=lambda data: sys.getsizeof(
                data.untyped_storage()
                if hasattr(data, "untyped_storage")
                else data.storage()
            ),
            aggregate_fn=(
                # We don't need the dictionary keys in this case.
                # If the data is not an integer, assume the above action_fn
                # is not applied for some reason.
                (
                    lambda data: (
                        lambda d: sum(d.values())
                        if isinstance(d, Mapping)
                        else sys.getsizeof(d)
                    )
                )
                if (isinstance(data, Mapping) or not isinstance(data, int))
                else sum
            ),
        )
        return cast(int, result)
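One caveat with the workaround above: as far as I can tell, for a non-mapping container it returns the size of the list object itself rather than the sum of its elements. A possible alternative is to measure each leaf that is not already an integer and then sum as before. This is a sketch only; `leaf_size` and `tolerant_sum` are hypothetical helpers, not part of torchinfo, and they assume every integer in the results is a size already produced by action_fn.

```python
import sys
from collections.abc import Mapping
from typing import Any


def leaf_size(value: Any) -> int:
    """Return ints unchanged (assumed to be sizes); measure anything else."""
    return value if isinstance(value, int) else sys.getsizeof(value)


def tolerant_sum(results: Any) -> int:
    """Sum per-element sizes of a list or of a mapping's values."""
    values = results.values() if isinstance(results, Mapping) else results
    return sum(leaf_size(v) for v in values)


# Mixed results: two precomputed sizes plus a raw string leaf.
print(tolerant_sum([64, 128, "abc"]) == 64 + 128 + sys.getsizeof("abc"))
```

The same limitation as the original workaround applies: a genuine integer input would be mistaken for a size, so this is a stopgap rather than a proper fix.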
torchinfo/torchinfo/torchinfo.py, line 501 in 73ed568