diff --git a/test/test_jit.py b/test/test_jit.py index a4f535921e558..b1425a4ed71ca 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3027,6 +3027,46 @@ def forward(self, x): checker.check("def forward") checker.run(str(cm.exception)) + def test_dictionary_as_example_inputs_for_jit_trace(self): + class TestModule_v1(torch.nn.Module): + def __init__(self): + super(TestModule_v1, self).__init__() + + def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None): + return key1 + key2 + key3 + + class TestModule_v2(torch.nn.Module): + def __init__(self): + super(TestModule_v2, self).__init__() + + def forward(self, x, y): + return x + y + + def test_func(x, y): + return x + y + model_1 = TestModule_v1() + model_2 = TestModule_v2() + value1 = torch.ones(1) + value2 = torch.ones(1) + value3 = torch.ones(1) + example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} + example_input_dict_func = {'x': value1, 'y': value2} + traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) + traced_model_1_m = torch.jit.trace_module( + model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) + traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) + traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) + res_1 = traced_model_1(**example_input_dict) + res_1_m = traced_model_1_m(**example_input_dict) + self.assertEqual(res_1, 3 * torch.ones(1)) + self.assertEqual(res_1_m, 3 * torch.ones(1)) + res_func = traced_func(**example_input_dict_func) + self.assertEqual(res_func, 2 * torch.ones(1)) + with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): + res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) + with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."): + res_2 = 
traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) + class TestScript(JitTestCase): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 0e84fa864879c..0218c9fbdf4f8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -332,6 +332,15 @@ def _create_function_from_trace( force_outplace: _bool, argument_names: List[str] ) -> Tuple[Graph, Stack]: ... +def _create_function_from_trace_with_dict( + qualname: str, + func: Callable[..., Any], + input_dict: Dict[str, Any], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: List[str] +) -> Tuple[Graph, Stack]: ... def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... def parse_type_comment(comment: str) -> Decl: ... diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 694d2b8ee4890..835c7d0dc709a 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -565,6 +565,17 @@ inline Stack toTraceableStack(const py::tuple& inputs) { return info.toTupleRef().elements().vec(); } +// Serialize the python dictionary into a traceable stack. 
+inline Stack toTraceableStack(const py::dict& inputs) { + Stack res; + for (auto it = inputs.begin(); it != inputs.end(); it++) { + if (THPVariable_Check(it->second.ptr())) { + res.push_back(toIValue(it->second, tryToInferType(it->second).type())); + } + } + return res; +} + inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { auto elems = c10::impl::GenericList(elem_type); for (auto elem : obj) { diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 494265e161849..83570c85e9b4c 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -73,6 +73,69 @@ SourceRange getPythonInterpreterSourceRange() { return SourceRange(source, 0, stack_trace_text.size()); } +std::pair, Stack> createGraphByTracingWithDict( + const py::function& func, + const py::dict& inputs_dict, + Stack trace_inputs, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + Module* self, + const std::vector& argument_names) { + C10_LOG_API_USAGE_ONCE("torch.tracer"); + + auto lookup_fn_adapter = + [var_name_lookup_fn](const Variable& var) -> std::string { + pybind11::gil_scoped_acquire ag; + return py::cast(var_name_lookup_fn(var)); + }; + + // The argument_names parameter is parsed in python and its order + // is the same as the arguments' decalaration order in forward() method. + // These name shall be added to the graph as debug name and the order + // should align with the traceable stack we generated by the python dict. 
+ std::vector compact_argument_names; + Stack compact_trace_inputs; + for (std::vector::size_type i = 0; i < argument_names.size(); + i++) { + if (inputs_dict.contains(argument_names[i])) { + compact_argument_names.push_back(argument_names[i]); + } + } + for (std::vector::size_type i = 0; + i < compact_argument_names.size(); + i++) { + for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { + if (py::cast(it->first) == compact_argument_names[i]) { + if (THPVariable_Check(it->second.ptr())) { + compact_trace_inputs.push_back( + toIValue(it->second, tryToInferType(it->second).type())); + } + } + } + } + + auto outs = tracer::trace( + std::move(compact_trace_inputs), + [&](Stack inputs) -> Stack { + // We just leave the inputs_dict as it was and pass it to forward + // method. + auto out = func(**inputs_dict); + if (out.ptr() == Py_None) { + AT_ERROR( + "The traced function didn't return any values! Side-effects are not " + "captured in traces, so it would be a no-op."); + } + return {toTypeInferredIValue(out)}; + }, + lookup_fn_adapter, + strict, + force_outplace, + self, + compact_argument_names); + return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs)); +} + std::pair, Stack> createGraphByTracing( const py::function& func, Stack trace_inputs, diff --git a/torch/csrc/jit/python/python_tracer.h b/torch/csrc/jit/python/python_tracer.h index 3f1fca20bfe00..6ec9dc388c31a 100644 --- a/torch/csrc/jit/python/python_tracer.h +++ b/torch/csrc/jit/python/python_tracer.h @@ -24,6 +24,16 @@ Node* preRecordPythonTrace( at::ArrayRef inputs, std::vector scalar_args); +std::pair, Stack> createGraphByTracingWithDict( + const py::function& func, + const py::dict& inputs_dict, + Stack inputs, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + Module* self = nullptr, + const std::vector& argument_names = {}); + std::pair, Stack> createGraphByTracing( const py::function& func, Stack inputs, diff --git 
a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 110c2f4a70c79..ee9509588932c 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1218,6 +1218,43 @@ void initJitScriptBindings(PyObject* module) { py::arg("strict"), py::arg("force_outplace"), py::arg("argument_names") = std::vector()) + .def( + "_create_method_from_trace_with_dict", + [](Module& self, + const std::string& name, + const py::function& func, + const py::dict& input_dict, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + const std::vector& argument_names) { + // prereq: Module's buffers and parameters are unique + // this was ensured in python before calling this function + auto typed_inputs = toTraceableStack(input_dict); + + std::shared_ptr graph = + std::get<0>(tracer::createGraphByTracingWithDict( + func, + input_dict, + typed_inputs, + var_name_lookup_fn, + strict, + force_outplace, + &self, + argument_names)); + const auto method_name = QualifiedName(*self.type()->name(), name); + auto fn = self._ivalue()->compilation_unit()->create_function( + method_name, graph); + self.type()->addMethod(fn); + didFinishEmitModule(self); + }, + py::arg("name"), + py::arg("func"), + py::arg("input_dict"), + py::arg("var_name_lookup_fn"), + py::arg("strict"), + py::arg("force_outplace"), + py::arg("argument_names") = std::vector()) .def( "_get_forward_hooks", [](const Module& m) { @@ -1668,6 +1705,43 @@ void initJitScriptBindings(PyObject* module) { py::arg("force_outplace"), py::arg("argument_names") = std::vector()); + m.def( + "_create_function_from_trace_with_dict", + [](const std::string& qualname, + const py::function& func, + const py::dict& input_dict, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + const std::vector& argument_names) { + auto typed_inputs = toTraceableStack(input_dict); + std::shared_ptr graph = + 
std::get<0>(tracer::createGraphByTracingWithDict( + func, + input_dict, + typed_inputs, + var_name_lookup_fn, + strict, + force_outplace, + /*self=*/nullptr, + argument_names)); + + auto cu = get_python_cu(); + auto name = c10::QualifiedName(qualname); + auto result = cu->create_function( + std::move(name), std::move(graph), /*shouldMangle=*/true); + StrongFunctionPtr ret(std::move(cu), result); + didFinishEmitFunction(ret); + return ret; + }, + py::arg("name"), + py::arg("func"), + py::arg("input_dict"), + py::arg("var_name_lookup_fn"), + py::arg("strict"), + py::arg("force_outplace"), + py::arg("argument_names") = std::vector()); + m.def( "_jit_script_class_compile", [](const std::string& qualifiedName, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index fe0091f63bb66..b4352648df9c9 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -316,6 +316,7 @@ def _check_trace( force_outplace, is_trace_module, _module_class, + example_inputs_is_kwarg=False, ): # Note: tracing is independent of optimizations, which consume the trace for inputs in check_inputs: @@ -335,20 +336,33 @@ def _check_trace( _force_outplace=force_outplace, _module_class=_module_class, _compilation_unit=torch._C.CompilationUnit(), + example_inputs_is_kwarg=example_inputs_is_kwarg, ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor, dict)): + if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and not example_inputs_is_kwarg: inputs = (inputs,) else: - check_mod = torch.jit.trace( - func, - _clone_inputs(inputs), - check_trace=False, - strict=strict, - _force_outplace=force_outplace, - _module_class=_module_class, - ) + if example_inputs_is_kwarg: + check_mod = torch.jit.trace( + func, + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + example_kwarg_inputs=_clone_inputs(inputs), + ) + else: + check_mod = torch.jit.trace( + func, + 
_clone_inputs(inputs), + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + ) + + check_mod_func = check_mod def graph_diagnostic_info(): @@ -440,7 +454,10 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - outs = wrap_retval(mod(*_clone_inputs(inputs))) + if isinstance(inputs, dict) and example_inputs_is_kwarg: + outs = wrap_retval(mod(**inputs)) + else: + outs = wrap_retval(mod(*_clone_inputs(inputs))) outs = [out for out in outs if isinstance(out, torch.Tensor)] return outs except Exception as e: @@ -595,7 +612,7 @@ def wrap_check_inputs(check_inputs): def trace( func, - example_inputs, + example_inputs=None, optimize=None, check_trace=True, check_inputs=None, @@ -604,6 +621,7 @@ def trace( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, + example_kwarg_inputs=None ): """ Trace a function and return an executable or :class:`ScriptFunction` @@ -659,14 +677,17 @@ def trace( tensors. When a module is passed `torch.jit.trace`, only the ``forward`` method is run and traced (see :func:`torch.jit.trace ` for details). - example_inputs (tuple or torch.Tensor): A tuple of example inputs that - will be passed to the function while tracing. The resulting trace - can be run with inputs of different types and shapes assuming the - traced operations support those types and shapes. `example_inputs` - may also be a single Tensor in which case it is automatically - wrapped in a tuple. Keyword arguments: + example_inputs (tuple or torch.Tensor or None, optional): A tuple of example + inputs that will be passed to the function while tracing. + Default: ``None``. Either this argument or ``example_kwarg_inputs`` + should be specified. The resulting trace can be run with inputs of + different types and shapes assuming the traced operations support those + types and shapes. 
`example_inputs` may also be a single Tensor in which + case it is automatically wrapped in a tuple. When the value is None, + ``example_kwarg_inputs`` should be specified. + check_trace (``bool``, optional): Check if the same inputs run through traced code produce the same outputs. Default: ``True``. You might want to disable this if, for example, your network contains non- @@ -690,6 +711,12 @@ def trace( and you are sure that the container you are using in your problem is a ``constant`` structure and does not get used as control flow (if, for) conditions. + example_kwarg_inputs (dict, optional): This parameter is a pack of keyword + arguments of example inputs that will be passed to the function while + tracing. Default: ``None``. Either this argument or ``example_inputs`` + should be specified. The dict will be unpacked by the argument names + of the traced function. If the keys of the dict do not match + the traced function's argument names, a runtime exception will be raised. 
Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns @@ -755,7 +782,13 @@ def forward(self, x): ) return func + if isinstance(func, torch.nn.Module): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") return trace_module( func, {"forward": example_inputs}, @@ -766,6 +799,7 @@ def forward(self, x): strict, _force_outplace, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) if ( @@ -773,6 +807,11 @@ def forward(self, x): and isinstance(func.__self__, torch.nn.Module) and func.__name__ == "forward" ): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") return trace_module( func.__self__, {"forward": example_inputs}, @@ -783,13 +822,14 @@ def forward(self, x): strict, _force_outplace, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) # Special case for common case of passing a single Tensor - if isinstance(example_inputs, (torch.Tensor, dict)): + if isinstance(example_inputs, (torch.Tensor, dict)) and example_kwarg_inputs is None: example_inputs = (example_inputs,) # done primarily so that weird iterables fail here and not pybind11 code - elif not isinstance(example_inputs, tuple): + elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): example_inputs = tuple(example_inputs) var_lookup_fn = _create_interpreter_name_lookup_fn(0) @@ -801,15 +841,27 @@ def forward(self, x): ) name = _qualified_name(func) - traced = torch._C._create_function_from_trace( - name, - func, - example_inputs, - var_lookup_fn, - strict, - _force_outplace, - get_callable_argument_names(func) - ) + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + traced = 
torch._C._create_function_from_trace_with_dict( + name, + func, + example_kwarg_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func) + ) + else: + traced = torch._C._create_function_from_trace( + name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func) + ) # Check the trace against new traces created from user-specified inputs if check_trace: @@ -823,6 +875,7 @@ def forward(self, x): _force_outplace, False, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) else: _check_trace( @@ -834,6 +887,7 @@ def forward(self, x): _force_outplace, False, _module_class, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) return traced @@ -853,6 +907,7 @@ def trace_module( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, + example_inputs_is_kwarg=False, ): """ Trace a module and return an executable :class:`ScriptModule` that will be optimized @@ -887,6 +942,8 @@ def trace_module( check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion. + example_inputs_is_kwarg (``bool``, optional): This parameter indicates whether the example inputs are a pack + of keyword arguments. Default: ``False``. Returns: A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. 
@@ -971,17 +1028,34 @@ def register_submods(mod, prefix): func = getattr(mod, method_name) argument_names = get_callable_argument_names(func) - example_inputs = make_tuple(example_inputs) + if isinstance(example_inputs, dict) and example_inputs_is_kwarg: + # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ + for key in example_inputs: + if key not in argument_names: + valid_arguments = "[" + ','.join(argument_names) + "]" + raise NameError("""'{}' is not in forward() method's arguments, + valid arguments name are {}""".format(key, valid_arguments)) + module._c._create_method_from_trace_with_dict( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + ) + else: + example_inputs = make_tuple(example_inputs) + module._c._create_method_from_trace( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + ) - module._c._create_method_from_trace( - method_name, - func, - example_inputs, - var_lookup_fn, - strict, - _force_outplace, - argument_names, - ) check_trace_method = module._c._get_method(method_name) # Check the trace against new traces created from user-specified inputs @@ -996,6 +1070,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, + example_inputs_is_kwarg=example_inputs_is_kwarg, ) else: _check_trace( @@ -1007,6 +1082,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, + example_inputs_is_kwarg=example_inputs_is_kwarg, ) finally: torch.jit._trace._trace_module_map = old_module_map