Skip to content

Commit

Permalink
Support unpacking python dictionary in torch.jit.trace() (pytorch#81623)
Browse files Browse the repository at this point in the history
# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & Motivation
### Problem 1(usability):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=value1, key2=value2, key3=value3)`**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3, key2:value2}`**

The problem is that if you want to trace the model using the dict data by the giving dataset, you need unpack the dictionary and reorder its value manually and make up a tuple as **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.

### Problem 2 (feasibility):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3}`** -> Only **part of** the required value by forward was given, the rest use the default value.

The problem is that if you want to trace the model using the dict data by the giving dataset, it's not feasible at all. Cause neither you can pass a tuple like **`T1 = (value1, value3)`**  nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2 and T2 include **None** type which will be blocked by tracer's type checking. (Of course you can pass **`T3 = (value1,)`**  to make the trace function finish without exception, but the traced model you get probably is not what you expect cause the different input may result in different traced result.).

These problems come from the HuggingFace's PT model, especially in text-classification tasks with datasets such as [MRPC,](https://paperswithcode.com/dataset/mrpc)  [MNLI](https://paperswithcode.com/dataset/multinli) etc.

## Solution
To address these two issues, we propose to support a new type, that is, python dict as example_inputs parameter for torch.jit.trace(). We can base on the runtime type information of the example_inputs object to determine if we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and  problem 2 can be solved by utilizing the "**`**`**"
operator.

## Limitation & Mitigation

1. If we use dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Cause probably we will change the order of debug name of the input parameter in torchscript IR, thus we can't assume the traced model's input parameters order are the same with the original model.). We need highlight this too in the document to mitigate this problem.

    For example:
```
# fetch a data from dataloader, and the data is a dictionary
# and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2}
# the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# use the dictionary to trace the model
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)  # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)

# It's OK to use dict as the parameter for traced model
jit_model(**example_inputs_dict)

example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order.
jit_model(*example_inputs_tuple)

```
## Note
1. This PR will make some UT introduced in [39601](pytorch#39601) fail, which I think should be classified as unpacking a tuple containing a single dictionary element in our solution.
4. I think there is ambiguity since currently we only specify passing a tuple or a single Tensor as our example_inputs parameter in **torch.jit.trace()**'s documentation, but it seems we can still passing a dictionary.

Pull Request resolved: pytorch#81623
Approved by: https://github.com/davidberard98
  • Loading branch information
tangleintel authored and pytorchmergebot committed Oct 15, 2022
1 parent bdefa26 commit 7980ed9
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 38 deletions.
40 changes: 40 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3027,6 +3027,46 @@ def forward(self, x):
checker.check("def forward")
checker.run(str(cm.exception))

def test_dictionary_as_example_inputs_for_jit_trace(self):
class TestModule_v1(torch.nn.Module):
def __init__(self):
super(TestModule_v1, self).__init__()

def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
return key1 + key2 + key3

class TestModule_v2(torch.nn.Module):
def __init__(self):
super(TestModule_v2, self).__init__()

def forward(self, x, y):
return x + y

def test_func(x, y):
return x + y
model_1 = TestModule_v1()
model_2 = TestModule_v2()
value1 = torch.ones(1)
value2 = torch.ones(1)
value3 = torch.ones(1)
example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
example_input_dict_func = {'x': value1, 'y': value2}
traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False)
traced_model_1_m = torch.jit.trace_module(
model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False)
traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])})
traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False)
res_1 = traced_model_1(**example_input_dict)
res_1_m = traced_model_1_m(**example_input_dict)
self.assertEqual(res_1, 3 * torch.ones(1))
self.assertEqual(res_1_m, 3 * torch.ones(1))
res_func = traced_func(**example_input_dict_func)
self.assertEqual(res_func, 2 * torch.ones(1))
with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})
with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})


class TestScript(JitTestCase):

Expand Down
9 changes: 9 additions & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,15 @@ def _create_function_from_trace(
force_outplace: _bool,
argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
qualname: str,
func: Callable[..., Any],
input_dict: Dict[str, Any],
var_lookup_fn: Callable[[Tensor], str],
strict: _bool,
force_outplace: _bool,
argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/jit/python/pybind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,17 @@ inline Stack toTraceableStack(const py::tuple& inputs) {
return info.toTupleRef().elements().vec();
}

// Serialize the python dictionary into a traceable stack.
inline Stack toTraceableStack(const py::dict& inputs) {
Stack res;
for (auto it = inputs.begin(); it != inputs.end(); it++) {
if (THPVariable_Check(it->second.ptr())) {
res.push_back(toIValue(it->second, tryToInferType(it->second).type()));
}
}
return res;
}

inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
auto elems = c10::impl::GenericList(elem_type);
for (auto elem : obj) {
Expand Down
63 changes: 63 additions & 0 deletions torch/csrc/jit/python/python_tracer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,69 @@ SourceRange getPythonInterpreterSourceRange() {
return SourceRange(source, 0, stack_trace_text.size());
}

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
const py::function& func,
const py::dict& inputs_dict,
Stack trace_inputs,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self,
const std::vector<std::string>& argument_names) {
C10_LOG_API_USAGE_ONCE("torch.tracer");

auto lookup_fn_adapter =
[var_name_lookup_fn](const Variable& var) -> std::string {
pybind11::gil_scoped_acquire ag;
return py::cast<std::string>(var_name_lookup_fn(var));
};

// The argument_names parameter is parsed in python and its order
// is the same as the arguments' decalaration order in forward() method.
// These name shall be added to the graph as debug name and the order
// should align with the traceable stack we generated by the python dict.
std::vector<std::string> compact_argument_names;
Stack compact_trace_inputs;
for (std::vector<std::string>::size_type i = 0; i < argument_names.size();
i++) {
if (inputs_dict.contains(argument_names[i])) {
compact_argument_names.push_back(argument_names[i]);
}
}
for (std::vector<std::string>::size_type i = 0;
i < compact_argument_names.size();
i++) {
for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
if (py::cast<std::string>(it->first) == compact_argument_names[i]) {
if (THPVariable_Check(it->second.ptr())) {
compact_trace_inputs.push_back(
toIValue(it->second, tryToInferType(it->second).type()));
}
}
}
}

auto outs = tracer::trace(
std::move(compact_trace_inputs),
[&](Stack inputs) -> Stack {
// We just leave the inputs_dict as it was and pass it to forward
// method.
auto out = func(**inputs_dict);
if (out.ptr() == Py_None) {
AT_ERROR(
"The traced function didn't return any values! Side-effects are not "
"captured in traces, so it would be a no-op.");
}
return {toTypeInferredIValue(out)};
},
lookup_fn_adapter,
strict,
force_outplace,
self,
compact_argument_names);
return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
}

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
const py::function& func,
Stack trace_inputs,
Expand Down
10 changes: 10 additions & 0 deletions torch/csrc/jit/python/python_tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ Node* preRecordPythonTrace(
at::ArrayRef<autograd::Variable> inputs,
std::vector<THPObjectPtr> scalar_args);

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
const py::function& func,
const py::dict& inputs_dict,
Stack inputs,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self = nullptr,
const std::vector<std::string>& argument_names = {});

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
const py::function& func,
Stack inputs,
Expand Down
74 changes: 74 additions & 0 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,43 @@ void initJitScriptBindings(PyObject* module) {
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>())
.def(
"_create_method_from_trace_with_dict",
[](Module& self,
const std::string& name,
const py::function& func,
const py::dict& input_dict,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
const std::vector<std::string>& argument_names) {
// prereq: Module's buffers and parameters are unique
// this was ensured in python before calling this function
auto typed_inputs = toTraceableStack(input_dict);

std::shared_ptr<Graph> graph =
std::get<0>(tracer::createGraphByTracingWithDict(
func,
input_dict,
typed_inputs,
var_name_lookup_fn,
strict,
force_outplace,
&self,
argument_names));
const auto method_name = QualifiedName(*self.type()->name(), name);
auto fn = self._ivalue()->compilation_unit()->create_function(
method_name, graph);
self.type()->addMethod(fn);
didFinishEmitModule(self);
},
py::arg("name"),
py::arg("func"),
py::arg("input_dict"),
py::arg("var_name_lookup_fn"),
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>())
.def(
"_get_forward_hooks",
[](const Module& m) {
Expand Down Expand Up @@ -1668,6 +1705,43 @@ void initJitScriptBindings(PyObject* module) {
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>());

m.def(
"_create_function_from_trace_with_dict",
[](const std::string& qualname,
const py::function& func,
const py::dict& input_dict,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
const std::vector<std::string>& argument_names) {
auto typed_inputs = toTraceableStack(input_dict);
std::shared_ptr<Graph> graph =
std::get<0>(tracer::createGraphByTracingWithDict(
func,
input_dict,
typed_inputs,
var_name_lookup_fn,
strict,
force_outplace,
/*self=*/nullptr,
argument_names));

auto cu = get_python_cu();
auto name = c10::QualifiedName(qualname);
auto result = cu->create_function(
std::move(name), std::move(graph), /*shouldMangle=*/true);
StrongFunctionPtr ret(std::move(cu), result);
didFinishEmitFunction(ret);
return ret;
},
py::arg("name"),
py::arg("func"),
py::arg("input_dict"),
py::arg("var_name_lookup_fn"),
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>());

m.def(
"_jit_script_class_compile",
[](const std::string& qualifiedName,
Expand Down
Loading

0 comments on commit 7980ed9

Please sign in to comment.