get_total_memory_used fails to handle list of str #254

Open
minostauros opened this issue May 26, 2023 · 1 comment

minostauros commented May 26, 2023

def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:

>>> get_total_memory_used(["abc", "def"])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 503, in get_total_memory_used
    result = traverse_input_data(
  File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
    result = aggregate(
TypeError: unsupported operand type(s) for +: 'int' and 'str'
>>> 

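action_fn is never applied to str elements, so sys.getsizeof is not called on the strings. They reach the aggregation step unchanged, where adding them to an int raises the TypeError above.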
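The failure mode can be reproduced without PyTorch. The following is a minimal sketch, not torchinfo's actual code: `FakeTensor`, `traverse`, and `total_memory` are hypothetical stand-ins that assume the traversal applies `action_fn` only to tensor-like leaves and returns every other leaf, including `str`, unchanged.

```python
from collections.abc import Iterable, Mapping


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, nbytes: int) -> None:
        self.nbytes = nbytes


def traverse(data, action_fn, aggregate_fn):
    """Simplified nested-input traversal (assumed behavior, not torchinfo's code)."""
    # action_fn only runs on tensor-like leaves.
    if isinstance(data, FakeTensor):
        return action_fn(data)
    aggregate = aggregate_fn(data)
    if isinstance(data, Mapping):
        return aggregate(
            {k: traverse(v, action_fn, aggregate_fn) for k, v in data.items()}
        )
    if isinstance(data, Iterable) and not isinstance(data, str):
        return aggregate([traverse(d, action_fn, aggregate_fn) for d in data])
    # Any other leaf, including str, is returned unchanged.
    return data


def total_memory(data) -> int:
    return traverse(
        data,
        action_fn=lambda t: t.nbytes,
        aggregate_fn=lambda d: (
            (lambda m: sum(m.values())) if isinstance(d, Mapping) else sum
        ),
    )
```

With this sketch, `total_memory([FakeTensor(8), FakeTensor(16)])` sums two integers, while `total_memory(["abc", "def"])` ends up calling `sum(["abc", "def"])` and raises the same `TypeError` as in the traceback.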

minostauros commented May 26, 2023

This happens when a model's input is a list of strings, e.g., for language models.

A possible (dirty) workaround:

def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
    """Calculates the total memory of all tensors stored in data."""
    result = traverse_input_data(
        data,
        action_fn=lambda data: sys.getsizeof(
            data.untyped_storage()
            if hasattr(data, "untyped_storage")
            else data.storage()
        ),
        aggregate_fn=(
            # We don't need the dictionary keys in this case.
            # If the data is not an integer, assume action_fn was not applied to it.
            (
                lambda data: (
                    lambda d: sum(d.values())
                    if isinstance(d, Mapping)
                    else sys.getsizeof(d)
                )
            )
            if (isinstance(data, Mapping) or not isinstance(data, int))
            else sum
        ),
    )
    return cast(int, result)
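An alternative to patching the aggregate step is to size string leaves during the traversal itself. The sketch below is only an illustration under stated assumptions: `total_leaf_size` and its `leaf_size` parameter are hypothetical names, not part of torchinfo, and whether strings should count toward the total at all is a design choice the maintainers would need to make.

```python
import sys
from collections.abc import Iterable, Mapping


def total_leaf_size(data, leaf_size=sys.getsizeof) -> int:
    """Sum leaf_size over all leaves of a nested structure (hypothetical helper)."""
    # Treat str/bytes as leaves first; iterating a str would recurse forever.
    if isinstance(data, (str, bytes)):
        return leaf_size(data)
    if isinstance(data, Mapping):
        # Dictionary keys are ignored, only the values are sized.
        return sum(total_leaf_size(v, leaf_size) for v in data.values())
    if isinstance(data, Iterable):
        # A real version would check for tensors before this branch and
        # size them via their storage, as the original function does.
        return sum(total_leaf_size(d, leaf_size) for d in data)
    return leaf_size(data)
```

Called as `total_leaf_size(["abc", "def"])`, this returns the sum of `sys.getsizeof` over both strings instead of raising.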
