Support unpacking python dictionary in torch.jit.trace() (pytorch#81623)
# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & motivation

### Problem 1 (usability)

Say you have a model whose forward method is defined as follows:

**`def forward(self, key1=value1, key2=value2, key3=value3)`**

and a dataset in which each data point is a Python dict:

**`data = {key1: value1, key3: value3, key2: value2}`**

The problem is that to trace the model with the dict data from the given dataset, you must manually unpack the dictionary, reorder its values, and build a tuple such as **`data_tuple = (value1, value2, value3)`** to use as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.

### Problem 2 (feasibility)

Say you have a model whose forward method is defined as follows:

**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**

and a dataset in which each data point is a Python dict:

**`data = {key1: value1, key3: value3}`** -> Only **part of** the values required by forward are given; the rest use the defaults.

The problem is that tracing the model with this dict data is not feasible at all, because you can pass neither a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**: T1 would mismatch value3 with key2, and T2 includes **None**, which is blocked by the tracer's type checking. (Of course you can pass **`T3 = (value1,)`** to make the trace function finish without an exception, but the traced model you get is probably not what you expect, since different inputs may produce different traced results.)

These problems arise with HuggingFace's PyTorch models, especially in text-classification tasks on datasets such as [MRPC](https://paperswithcode.com/dataset/mrpc), [MNLI](https://paperswithcode.com/dataset/multinli), etc.
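The marshalling problem above can be sketched with a toy module (the module, key names, and tensor values here are illustrative, not taken from the PR):

```python
import torch

# Toy module illustrating Problem 1: forward() takes keyword arguments,
# but the dataset yields dictionaries whose key order need not match
# the parameter order of forward().
class MyModule(torch.nn.Module):
    def forward(self, key1=None, key2=None, key3=None):
        out = key1
        if key2 is not None:
            out = out + key2
        if key3 is not None:
            out = out + key3
        return out

# A data point as a dict, with keys in a different order than forward():
data = {
    "key1": torch.ones(2),
    "key3": torch.full((2,), 3.0),
    "key2": torch.full((2,), 2.0),
}

# Without dict support, the caller must manually look up and reorder the
# values into a tuple matching forward()'s parameter order before tracing:
data_tuple = (data["key1"], data["key2"], data["key3"])
traced = torch.jit.trace(MyModule().eval(), data_tuple)
```

This per-model, per-dataset reordering is exactly the boilerplate the dict path removes.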
## Solution

To address these two issues, we propose supporting a new type, the Python dict, as the `example_inputs` parameter of `torch.jit.trace()`. Based on the runtime type of the `example_inputs` object, we either fall back to the original tuple path or take the new dictionary path. Both problem 1 and problem 2 can be solved by utilizing the "**`**`**" operator.

## Limitation & mitigation

1. If we use a dict as `example_inputs` to trace the model, then we have to pass a dictionary to the traced model too. (Because we will probably change the order of the input parameters' debug names in the TorchScript IR, we cannot assume the traced model's input parameter order matches the original model's.) We need to highlight this in the documentation to mitigate this problem. For example:

```python
# Fetch a data point from the dataloader; the data is a dictionary,
# e.g. example_inputs_dict = {key1: value1, key3: value3, key2: value2},
# and forward() is defined as:
#   def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# Use the dictionary to trace the model.
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)
# Now the IR will be:
#   graph(%self : __torch__.module.___torch_mangle_n.Mymodule,
#         %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)
# It's OK to use a dict as the parameter for the traced model:
jit_model(**example_inputs_dict)
example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order:
jit_model(*example_inputs_tuple)
```

## Note

1. This PR will make some UTs introduced in [39601](pytorch#39601) fail; in our solution those cases should be classified as unpacking a tuple that contains a single dictionary element.
2. I think there is some ambiguity here: currently the documentation of **torch.jit.trace()** only specifies passing a tuple or a single Tensor as the **`example_inputs`** parameter, but it seems we can still pass a dictionary.

Pull Request resolved: pytorch#81623
Approved by: https://github.com/davidberard98
1 parent bdefa26 · commit 7980ed9 · 7 changed files with 321 additions and 38 deletions.