From 7980ed95bd708d6e9baf64c95b1aa83df8891b59 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Sat, 15 Oct 2022 05:33:07 +0000 Subject: [PATCH] Support unpacking python dictionary in torch.jit.trace() (#81623) # Support unpacking python dictionary in **torch.jit.trace()** ## Problem statement & Motivation ### Problem 1(usability): Say, if you have a model and its forward method defined as follows: **`def forward(self, key1=value1, key2=value2, key3=value3)`** And you have a dataset and each data point in the dataset is a python dict as follows: **`data = {key1:value1, key3:value3, key2:value2}`** The problem is that if you want to trace the model using the dict data by the giving dataset, you need unpack the dictionary and reorder its value manually and make up a tuple as **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly. ### Problem 2 (feasibility): Say, if you have a model and its forward method defined as follows: **`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None** And you have a dataset and each data point in the dataset is a python dict as follows: **`data = {key1:value1, key3:value3}`** -> Only **part of** the required value by forward was given, the rest use the default value. The problem is that if you want to trace the model using the dict data by the giving dataset, it's not feasible at all. Cause neither you can pass a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2 and T2 include **None** type which will be blocked by tracer's type checking. (Of course you can pass **`T3 = (value1,)`** to make the trace function finish without exception, but the traced model you get probably is not what you expect cause the different input may result in different traced result.). These problems come from the HuggingFace's PT model, especially in text-classification tasks with datasets such as [MRPC,](https://paperswithcode.com/dataset/mrpc) [MNLI](https://paperswithcode.com/dataset/multinli) etc. ## Solution To address these two issues, we propose to support a new type, that is, python dict as example_inputs parameter for torch.jit.trace(). We can base on the runtime type information of the example_inputs object to determine if we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and problem 2 can be solved by utilizing the "**`**`**" operator. ## Limitation & Mitigation 1. If we use dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Cause probably we will change the order of debug name of the input parameter in torchscript IR, thus we can't assume the traced model's input parameters order are the same with the original model.). We need highlight this too in the document to mitigate this problem. For example: ``` # fetch a data from dataloader, and the data is a dictionary # and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2} # the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3) example_inputs_dict = next(iter(dataloader)) jit_model = model.eval() # use the dictionary to trace the model jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False) # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2) jit_model = torch.jit.freeze(jit_model) # It's OK to use dict as the parameter for traced model jit_model(**example_inputs_dict) example_inputs_tuple = (value1, value3, value2) # It's wrong to rely on the original args order. jit_model(*example_inputs_tuple) ``` ## Note 1. This PR will make some UT introduced in [39601](https://github.com/pytorch/pytorch/pull/39601) fail, which I think should be classified as unpacking a tuple containing a single dictionary element in our solution. 4. I think there is ambiguity since currently we only specify passing a tuple or a single Tensor as our example_inputs parameter in **torch.jit.trace()**'s documentation, but it seems we can still passing a dictionary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81623 Approved by: https://github.com/davidberard98 --- test/test_jit.py | 40 +++++++ torch/_C/__init__.pyi.in | 9 ++ torch/csrc/jit/python/pybind_utils.h | 11 ++ torch/csrc/jit/python/python_tracer.cpp | 63 ++++++++++ torch/csrc/jit/python/python_tracer.h | 10 ++ torch/csrc/jit/python/script_init.cpp | 74 ++++++++++++ torch/jit/_trace.py | 152 ++++++++++++++++++------ 7 files changed, 321 insertions(+), 38 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index a4f535921e558..b1425a4ed71ca 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3027,6 +3027,46 @@ def forward(self, x): checker.check("def forward") checker.run(str(cm.exception)) + def test_dictionary_as_example_inputs_for_jit_trace(self): + class TestModule_v1(torch.nn.Module): + def __init__(self): + super(TestModule_v1, self).__init__() + + def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None): + return key1 + key2 + key3 + + class TestModule_v2(torch.nn.Module): + def __init__(self): + super(TestModule_v2, self).__init__() + + def forward(self, x, y): + return x + y + + def test_func(x, y): + return x + y + model_1 = TestModule_v1() + model_2 = TestModule_v2() + value1 = torch.ones(1) + value2 = torch.ones(1) + value3 = torch.ones(1) + example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} + example_input_dict_func = {'x': value1, 'y': value2} + traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) + traced_model_1_m = torch.jit.trace_module( + model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) + traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) + traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) + res_1 = traced_model_1(**example_input_dict) + res_1_m = traced_model_1_m(**example_input_dict) + self.assertEqual(res_1, 3 * torch.ones(1)) + self.assertEqual(res_1_m, 3 * torch.ones(1)) + res_func = traced_func(**example_input_dict_func) + self.assertEqual(res_func, 2 * torch.ones(1)) + with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): + res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) + with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."): + res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) + class TestScript(JitTestCase): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 0e84fa864879c..0218c9fbdf4f8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -332,6 +332,15 @@ def _create_function_from_trace( force_outplace: _bool, argument_names: List[str] ) -> Tuple[Graph, Stack]: ... +def _create_function_from_trace_with_dict( + qualname: str, + func: Callable[..., Any], + input_dict: Dict[str, Any], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: List[str] +) -> Tuple[Graph, Stack]: ... def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... def parse_type_comment(comment: str) -> Decl: ... diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 694d2b8ee4890..835c7d0dc709a 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -565,6 +565,17 @@ inline Stack toTraceableStack(const py::tuple& inputs) { return info.toTupleRef().elements().vec(); } +// Serialize the python dictionary into a traceable stack. +inline Stack toTraceableStack(const py::dict& inputs) { + Stack res; + for (auto it = inputs.begin(); it != inputs.end(); it++) { + if (THPVariable_Check(it->second.ptr())) { + res.push_back(toIValue(it->second, tryToInferType(it->second).type())); + } + } + return res; +} + inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { auto elems = c10::impl::GenericList(elem_type); for (auto elem : obj) { diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 494265e161849..83570c85e9b4c 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -73,6 +73,69 @@ SourceRange getPythonInterpreterSourceRange() { return SourceRange(source, 0, stack_trace_text.size()); } +std::pair, Stack> createGraphByTracingWithDict( + const py::function& func, + const py::dict& inputs_dict, + Stack trace_inputs, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + Module* self, + const std::vector& argument_names) { + C10_LOG_API_USAGE_ONCE("torch.tracer"); + + auto lookup_fn_adapter = + [var_name_lookup_fn](const Variable& var) -> std::string { + pybind11::gil_scoped_acquire ag; + return py::cast(var_name_lookup_fn(var)); + }; + + // The argument_names parameter is parsed in python and its order + // is the same as the arguments' decalaration order in forward() method. + // These name shall be added to the graph as debug name and the order + // should align with the traceable stack we generated by the python dict. + std::vector compact_argument_names; + Stack compact_trace_inputs; + for (std::vector::size_type i = 0; i < argument_names.size(); + i++) { + if (inputs_dict.contains(argument_names[i])) { + compact_argument_names.push_back(argument_names[i]); + } + } + for (std::vector::size_type i = 0; + i < compact_argument_names.size(); + i++) { + for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { + if (py::cast(it->first) == compact_argument_names[i]) { + if (THPVariable_Check(it->second.ptr())) { + compact_trace_inputs.push_back( + toIValue(it->second, tryToInferType(it->second).type())); + } + } + } + } + + auto outs = tracer::trace( + std::move(compact_trace_inputs), + [&](Stack inputs) -> Stack { + // We just leave the inputs_dict as it was and pass it to forward + // method. + auto out = func(**inputs_dict); + if (out.ptr() == Py_None) { + AT_ERROR( + "The traced function didn't return any values! Side-effects are not " + "captured in traces, so it would be a no-op."); + } + return {toTypeInferredIValue(out)}; + }, + lookup_fn_adapter, + strict, + force_outplace, + self, + compact_argument_names); + return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs)); +} + std::pair, Stack> createGraphByTracing( const py::function& func, Stack trace_inputs, diff --git a/torch/csrc/jit/python/python_tracer.h b/torch/csrc/jit/python/python_tracer.h index 3f1fca20bfe00..6ec9dc388c31a 100644 --- a/torch/csrc/jit/python/python_tracer.h +++ b/torch/csrc/jit/python/python_tracer.h @@ -24,6 +24,16 @@ Node* preRecordPythonTrace( at::ArrayRef inputs, std::vector scalar_args); +std::pair, Stack> createGraphByTracingWithDict( + const py::function& func, + const py::dict& inputs_dict, + Stack inputs, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + Module* self = nullptr, + const std::vector& argument_names = {}); + std::pair, Stack> createGraphByTracing( const py::function& func, Stack inputs, diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 110c2f4a70c79..ee9509588932c 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1218,6 +1218,43 @@ void initJitScriptBindings(PyObject* module) { py::arg("strict"), py::arg("force_outplace"), py::arg("argument_names") = std::vector()) + .def( + "_create_method_from_trace_with_dict", + [](Module& self, + const std::string& name, + const py::function& func, + const py::dict& input_dict, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + const std::vector& argument_names) { + // prereq: Module's buffers and parameters are unique + // this was ensured in python before calling this function + auto typed_inputs = toTraceableStack(input_dict); + + std::shared_ptr graph = + std::get<0>(tracer::createGraphByTracingWithDict( + func, + input_dict, + typed_inputs, + var_name_lookup_fn, + strict, + force_outplace, + &self, + argument_names)); + const auto method_name = QualifiedName(*self.type()->name(), name); + auto fn = self._ivalue()->compilation_unit()->create_function( + method_name, graph); + self.type()->addMethod(fn); + didFinishEmitModule(self); + }, + py::arg("name"), + py::arg("func"), + py::arg("input_dict"), + py::arg("var_name_lookup_fn"), + py::arg("strict"), + py::arg("force_outplace"), + py::arg("argument_names") = std::vector()) .def( "_get_forward_hooks", [](const Module& m) { @@ -1668,6 +1705,43 @@ void initJitScriptBindings(PyObject* module) { py::arg("force_outplace"), py::arg("argument_names") = std::vector()); + m.def( + "_create_function_from_trace_with_dict", + [](const std::string& qualname, + const py::function& func, + const py::dict& input_dict, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + const std::vector& argument_names) { + auto typed_inputs = toTraceableStack(input_dict); + std::shared_ptr graph = + std::get<0>(tracer::createGraphByTracingWithDict( + func, + input_dict, + typed_inputs, + var_name_lookup_fn, + strict, + force_outplace, + /*self=*/nullptr, + argument_names)); + + auto cu = get_python_cu(); + auto name = c10::QualifiedName(qualname); + auto result = cu->create_function( + std::move(name), std::move(graph), /*shouldMangle=*/true); + StrongFunctionPtr ret(std::move(cu), result); + didFinishEmitFunction(ret); + return ret; + }, + py::arg("name"), + py::arg("func"), + py::arg("input_dict"), + py::arg("var_name_lookup_fn"), + py::arg("strict"), + py::arg("force_outplace"), + py::arg("argument_names") = std::vector()); + m.def( "_jit_script_class_compile", [](const std::string& qualifiedName, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index fe0091f63bb66..b4352648df9c9 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -316,6 +316,7 @@ def _check_trace( force_outplace, is_trace_module, _module_class, + example_inputs_is_kwarg=False, ): # Note: tracing is independent of optimizations, which consume the trace for inputs in check_inputs: @@ -335,20 +336,33 @@ def _check_trace( _force_outplace=force_outplace, _module_class=_module_class, _compilation_unit=torch._C.CompilationUnit(), + example_inputs_is_kwarg=example_inputs_is_kwarg, ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor, dict)): + if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and not example_inputs_is_kwarg: inputs = (inputs,) else: - check_mod = torch.jit.trace( - func, - _clone_inputs(inputs), - check_trace=False, - strict=strict, - _force_outplace=force_outplace, - _module_class=_module_class, - ) + if example_inputs_is_kwarg: + check_mod = torch.jit.trace( + func, + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + example_kwarg_inputs=_clone_inputs(inputs), + ) + else: + check_mod = torch.jit.trace( + func, + _clone_inputs(inputs), + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + ) + + check_mod_func = check_mod def graph_diagnostic_info(): @@ -440,7 +454,10 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - outs = wrap_retval(mod(*_clone_inputs(inputs))) + if isinstance(inputs, dict) and example_inputs_is_kwarg: + outs = wrap_retval(mod(**inputs)) + else: + outs = wrap_retval(mod(*_clone_inputs(inputs))) outs = [out for out in outs if isinstance(out, torch.Tensor)] return outs except Exception as e: @@ -595,7 +612,7 @@ def wrap_check_inputs(check_inputs): def trace( func, - example_inputs, + example_inputs=None, optimize=None, check_trace=True, check_inputs=None, @@ -604,6 +621,7 @@ def trace( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, + example_kwarg_inputs=None ): """ Trace a function and return an executable or :class:`ScriptFunction` @@ -659,14 +677,17 @@ def trace( tensors. When a module is passed `torch.jit.trace`, only the ``forward`` method is run and traced (see :func:`torch.jit.trace ` for details). - example_inputs (tuple or torch.Tensor): A tuple of example inputs that - will be passed to the function while tracing. The resulting trace - can be run with inputs of different types and shapes assuming the - traced operations support those types and shapes. `example_inputs` - may also be a single Tensor in which case it is automatically - wrapped in a tuple. Keyword arguments: + example_inputs (tuple or torch.Tensor or None, optional): A tuple of example + inputs that will be passed to the function while tracing. + Default: ``None``. Either this argument or ``example_kwarg_inputs`` + should be specified. The resulting trace can be run with inputs of + different types and shapes assuming the traced operations support those + types and shapes. `example_inputs` may also be a single Tensor in which + case it is automatically wrapped in a tuple. When the value is None, + ``example_kwarg_inputs`` should be specified. + check_trace (``bool``, optional): Check if the same inputs run through traced code produce the same outputs. Default: ``True``. You might want to disable this if, for example, your network contains non- @@ -690,6 +711,12 @@ def trace( and you are sure that the container you are using in your problem is a ``constant`` structure and does not get used as control flow (if, for) conditions. + example_kwarg_inputs (dict, optional): This parameter is a pack of keyword + arguments of example inputs that will be passed to the function while + tracing. Default: ``None``. Either this argument or ``example_inputs`` + should be specified. The dict will be unpacking by the arguments name + of the traced function. If the keys of the dict don't not match with + the traced function's arguments name, a runtime exception will be raised. Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns @@ -755,7 +782,13 @@ def forward(self, x): ) return func + if isinstance(func, torch.nn.Module): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") return trace_module( func, {"forward": example_inputs}, @@ -766,6 +799,7 @@ def forward(self, x): strict, _force_outplace, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) if ( @@ -773,6 +807,11 @@ def forward(self, x): and isinstance(func.__self__, torch.nn.Module) and func.__name__ == "forward" ): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") return trace_module( func.__self__, {"forward": example_inputs}, @@ -783,13 +822,14 @@ def forward(self, x): strict, _force_outplace, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) # Special case for common case of passing a single Tensor - if isinstance(example_inputs, (torch.Tensor, dict)): + if isinstance(example_inputs, (torch.Tensor, dict)) and example_kwarg_inputs is None: example_inputs = (example_inputs,) # done primarily so that weird iterables fail here and not pybind11 code - elif not isinstance(example_inputs, tuple): + elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): example_inputs = tuple(example_inputs) var_lookup_fn = _create_interpreter_name_lookup_fn(0) @@ -801,15 +841,27 @@ def forward(self, x): ) name = _qualified_name(func) - traced = torch._C._create_function_from_trace( - name, - func, - example_inputs, - var_lookup_fn, - strict, - _force_outplace, - get_callable_argument_names(func) - ) + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + traced = torch._C._create_function_from_trace_with_dict( + name, + func, + example_kwarg_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func) + ) + else: + traced = torch._C._create_function_from_trace( + name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func) + ) # Check the trace against new traces created from user-specified inputs if check_trace: @@ -823,6 +875,7 @@ def forward(self, x): _force_outplace, False, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) else: _check_trace( @@ -834,6 +887,7 @@ def forward(self, x): _force_outplace, False, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) return traced @@ -853,6 +907,7 @@ def trace_module( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, + example_inputs_is_kwarg=False, ): """ Trace a module and return an executable :class:`ScriptModule` that will be optimized @@ -887,6 +942,8 @@ def trace_module( check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion. + example_inputs_is_kwarg (``bool``, optional): This parameter indicate wether the example inputs is a pack + pack of keyword arguments. Default: ``False``. Returns: A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. @@ -971,17 +1028,34 @@ def register_submods(mod, prefix): func = getattr(mod, method_name) argument_names = get_callable_argument_names(func) - example_inputs = make_tuple(example_inputs) + if isinstance(example_inputs, dict) and example_inputs_is_kwarg: + # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ + for key in example_inputs: + if key not in argument_names: + valid_arguments = "[" + ','.join(argument_names) + "]" + raise NameError("""'{}' is not in forward() method's arguments, + valid arguments name are {}""".format(key, valid_arguments)) + module._c._create_method_from_trace_with_dict( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + ) + else: + example_inputs = make_tuple(example_inputs) + module._c._create_method_from_trace( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + ) - module._c._create_method_from_trace( - method_name, - func, - example_inputs, - var_lookup_fn, - strict, - _force_outplace, - argument_names, - ) check_trace_method = module._c._get_method(method_name) # Check the trace against new traces created from user-specified inputs @@ -996,6 +1070,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, + example_inputs_is_kwarg=example_inputs_is_kwarg, ) else: _check_trace( @@ -1007,6 +1082,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, + example_inputs_is_kwarg=example_inputs_is_kwarg, ) finally: torch.jit._trace._trace_module_map = old_module_map