Can't torch.export.export the model and set batch size to dynamic #109

Open
rbavery opened this issue Dec 19, 2023 · 0 comments

rbavery commented Dec 19, 2023

The model can be exported without any dynamic shapes:

```python
import torch
from torch.export import export, Dim
from segment_anything_fast import sam_model_fast_registry, SamPredictor

sam_checkpoint = "../segment-anything-fast/experiments/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"

sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
predictor = SamPredictor(sam)
encoder = predictor.model.image_encoder

example_args = (torch.randn(2, 3, 1024, 1024, dtype=torch.bfloat16),)
# Symbolic dimensions for batch size, height, and width
batch = Dim("batch")
h = Dim("h")
w = Dim("w")
# Specify that the batch, height, and width dimensions are dynamic:
# dynamic_shapes=(({0: Dim("batch"), 2: Dim("h"), 3: Dim("w")},),)
# No dynamic dimensions: this export succeeds
dynamic_shapes = ()
exported_program = export(encoder, args=example_args, dynamic_shapes=dynamic_shapes)
```

But setting the batch size to be dynamic causes a strange error:

```python
dynamic_shapes = (({0: Dim("batch")},),)
```

```
UserError: Expecting `args` to be a tuple of example positional inputs, got <class 'torch.Tensor'>
```

Setting h or w to dynamic triggers guard errors as well. Does segment-anything-fast not support dynamic input shapes? Supporting them would be great, because then inputs wouldn't need to be resized.

The full Slack thread where I'm trying to debug torch.export is here. I tried various ways of nesting the dynamic shape specs but can't find a combination that works for segment-anything-fast.

https://pytorch.slack.com/archives/C3PDTEV8E/p1702780972665469

Below is a minimal working example of dynamic-shape export, provided by Angela Yi in the Slack channel:

```python
import torch
from torch.export import export, Dim

def g(x):
    return x + x

def f(*args):
    return g(*args)

example_args = (torch.randn(2, 3, 1024, 1024),)
dynamic_shapes = (({0: Dim("batch")},),)
export(f, example_args, dynamic_shapes=dynamic_shapes)
```

I would expect `dynamic_shapes=(({0: Dim("batch")},),)` to work for the segment-anything-fast encoder as well, given that the encoder takes an input with the same dimensions.
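One plausible explanation, not confirmed in this thread: torch.export matches a tuple-form `dynamic_shapes` against the *parameters* of the exported callable (for a module, its `forward`), not against the raw `args` tuple. Because the minimal example declares `f(*args)`, its single parameter `args` receives the whole tuple of inputs, so the spec needs an extra level of nesting. If the encoder's `forward` instead takes one tensor `x` (an assumption about segment-anything-fast), the same nesting no longer lines up. The stdlib-only sketch below models that rule with `inspect.signature(...).bind`; `FakeTensor`, `FakeEncoder`, `spec_matches`, and `check` are hypothetical stand-ins, strings stand in for `Dim` objects, and the rule is a simplification, not torch's implementation.

```python
import inspect


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without torch."""


def f(*args):  # same signature as the minimal example: one parameter, `args`
    return args


class FakeEncoder:
    def forward(self, x):  # assumed: the encoder's forward takes one tensor
        return x


def spec_matches(value, spec):
    """Simplified rule: None means static, a dict describes one tensor's dims,
    and a tuple/list must mirror a tuple/list of inputs element by element."""
    if spec is None:
        return True
    if isinstance(spec, dict):
        return isinstance(value, FakeTensor)
    if isinstance(spec, (tuple, list)):
        return (
            isinstance(value, (tuple, list))
            and len(value) == len(spec)
            and all(spec_matches(v, s) for v, s in zip(value, spec))
        )
    return False


def check(fn, example_args, dynamic_shapes):
    # Bind the example inputs to fn's parameters, as calling fn(*example_args) would.
    values = list(inspect.signature(fn).bind(*example_args).arguments.values())
    # The tuple-form spec has one entry per *parameter*, not per positional value.
    return len(values) == len(dynamic_shapes) and all(
        spec_matches(v, s) for v, s in zip(values, dynamic_shapes)
    )


example_args = (FakeTensor(),)
nested = (({0: "batch"},),)  # the spec from the minimal example
flat = ({0: "batch"},)       # one dict for the single parameter x

print(check(f, example_args, nested))                      # True
print(check(FakeEncoder().forward, example_args, nested))  # False
print(check(FakeEncoder().forward, example_args, flat))    # True
```

Under this model, `f` binds to `{"args": (tensor,)}` while a one-tensor `forward` binds to `{"x": tensor}`, so the two callables need differently nested specs even though the tensor itself has the same shape.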
