Can torchinfo support BEVFusion (https://github.com/mit-han-lab/bevfusion)? #267

Open · dpan817 opened this issue Aug 2, 2023 · 2 comments


dpan817 commented Aug 2, 2023

Has anyone tried torchinfo with BEVFusion? I tried it, but it failed with:

    TypeError: Model contains a layer with an unsupported input or output type: <mmdet3d.ops.spconv.structure.SparseConvTensor object at 0x7f3d9a48fee0>, type: <class 'mmdet3d.ops.spconv.structure.SparseConvTensor'>
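For what it's worth, the error does not look specific to BEVFusion: torchinfo attaches hooks to every submodule and records input/output shapes, so any layer that passes a non-tensor object between submodules trips the same check. A minimal sketch that reproduces the same class of error (the CustomPayload class is made up here as a stand-in for SparseConvTensor):

    import torch
    import torch.nn as nn
    from torchinfo import summary

    class CustomPayload:
        """Stand-in for mmdet3d's SparseConvTensor: an object torchinfo cannot size."""
        def __init__(self, features):
            self.features = features

    class Wrap(nn.Module):
        def forward(self, x):
            return CustomPayload(x)

    class Unwrap(nn.Module):
        def forward(self, p):
            return p.features

    model = nn.Sequential(Wrap(), Unwrap())
    # Raises: TypeError: Model contains a layer with an unsupported
    # input or output type: <__main__.CustomPayload object at ...>
    summary(model, input_data=torch.randn(1, 3))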

@TylerYep (Owner)

Can you post the full code used to reproduce this error?

dpan817 (Author) commented Sep 1, 2023

Sorry for the late reply; I was working on other issues for the past two weeks.

I debugged the code without torchinfo to capture the parameters passed to the model's forward, then composed the same parameters for the summary() call, but it still failed.
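(For reference, a minimal sketch of one way to capture those forward parameters, using a standard PyTorch forward pre-hook; the hook name dump_forward_args is made up here:)

    # A forward pre-hook fires just before forward() with the positional
    # arguments; on PyTorch 2.x, with_kwargs=True also exposes keyword
    # arguments, which BEVFusion relies on heavily.
    def dump_forward_args(module, args):
        print(f"{type(module).__name__}.forward got {len(args)} positional args")
        for i, a in enumerate(args):
            print(f"  arg[{i}]: {type(a).__name__}")

    handle = model.register_forward_pre_hook(dump_forward_args)
    # ... run one pass through the normal test pipeline ...
    handle.remove()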

The model forward parameters are:
data_parallel_module

Then I composed the same parameters for the summary() call in tools/test.py:

    # (assumes the imports already in tools/test.py: torch,
    # MMDataParallel from mmcv.parallel, plus `from torchinfo import summary`)
    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        print(f"Model:\n{model}")

        # Fetch the sample once; every dataset[0] lookup re-runs the data pipeline.
        sample = dataset[0]
        img_tensor = sample.get('img').data.unsqueeze(0)
        points_list = [sample.get('points').data]
        camera2ego_tensor = sample.get('camera2ego').data.unsqueeze(0)
        lidar2ego_tensor = sample.get('lidar2ego').data.unsqueeze(0)
        lidar2camera_tensor = sample.get('lidar2camera').data.unsqueeze(0)
        lidar2image_tensor = sample.get('lidar2image').data.unsqueeze(0)
        camera_intrinsics_tensor = sample.get('camera_intrinsics').data.unsqueeze(0)
        camera2lidar_tensor = sample.get('camera2lidar').data.unsqueeze(0)
        img_aug_matrix_tensor = sample.get('img_aug_matrix').data.unsqueeze(0)
        lidar_aug_matrix_tensor = sample.get('lidar_aug_matrix').data.unsqueeze(0)
        metas_list = [sample.get('metas').data]
        gt_masks_bev_tensor = torch.zeros(1, 6, 200, 200)
        gt_bboxes_3d_list = [sample.get('gt_bboxes_3d').data]
        gt_labels_3d_list = [torch.tensor(sample.get('gt_labels_3d').data, device='cuda:0')]

        args_dict = {
            'return_loss': False,
            'rescale': True,
            'img': img_tensor,
            'points': points_list,
            'gt_bboxes_3d': gt_bboxes_3d_list,
            'gt_labels_3d': gt_labels_3d_list,
            'gt_masks_bev': gt_masks_bev_tensor,
            'camera_intrinsics': camera_intrinsics_tensor,  # fixed typo: was 'camera_intrinscis'
            'camera2ego': camera2ego_tensor,
            'lidar2ego': lidar2ego_tensor,
            'lidar2camera': lidar2camera_tensor,
            'camera2lidar': camera2lidar_tensor,
            'lidar2image': lidar2image_tensor,
            'img_aug_matrix': img_aug_matrix_tensor,
            'lidar_aug_matrix': lidar_aug_matrix_tensor,
            'metas': metas_list,
        }
        input_dict = {}

        summary(model, input_data=[input_dict, args_dict])
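As an aside, torchinfo's summary() forwards unrecognized keyword arguments straight to the model's forward, so an alternative I have not tried would be to splat the kwargs instead of wrapping them in input_data:

    # Untested alternative: let summary() pass the forward() kwargs through.
    summary(model, input_data=[input_dict], **args_dict)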

And the error from the original call is:

Traceback (most recent call last):
  File "tools/test.py", line 288, in <module>
    main()
  File "tools/test.py", line 250, in main
    summary(model, input_data=[input_dict, args_dict])
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 220, in summary
    x, correct_input_size = process_input(
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 246, in process_input
    correct_input_size = get_input_data_sizes(input_data)
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 496, in get_input_data_sizes
    return traverse_input_data(
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in traverse_input_data
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in <listcomp>
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 435, in traverse_input_data
    {
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 436, in <dictcomp>
    k: traverse_input_data(v, action_fn, aggregate_fn)
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in traverse_input_data
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in <listcomp>
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
    result = aggregate(
  File "/home/adlink/Downloads/Lidar_AI_Solution/CUDA-BEVFusion/bevfusion/mmdet3d/core/bbox/structures/base_box3d.py", line 46, in __init__
    assert tensor.dim() == 2 and tensor.size(-1) == box_dim, tensor.size()
AssertionError: torch.Size([9, 1])
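Reading the traceback, torchinfo's traverse_input_data appears to treat the boxes object inside gt_bboxes_3d (a BaseInstance3DBoxes subclass, per base_box3d.py in the traceback) as a generic iterable and then tries to rebuild an object of the same type from the per-element sizes, which trips mmdet3d's shape assertion. A possible workaround for this particular step, untested here, is to hand torchinfo the raw box tensor instead of the wrapper:

    # Untested sketch: BaseInstance3DBoxes keeps its data in `.tensor`,
    # so passing that keeps torchinfo's size traversal on plain tensors.
    # The model itself may still need the wrapper type, and the earlier
    # SparseConvTensor TypeError can still occur in the sparse-conv layers.
    gt_bboxes_3d_list = [sample.get('gt_bboxes_3d').data.tensor]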
