From cdaae0bf682d718b0eacdfd4f539e721afcc3588 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Wed, 22 Jun 2022 10:16:34 +0800 Subject: [PATCH 01/22] support python dict in jit.trace --- torch/csrc/jit/python/pybind_utils.h | 10 ++++++ torch/csrc/jit/python/python_tracer.cpp | 46 +++++++++++++++++++++++++ torch/csrc/jit/python/python_tracer.h | 10 ++++++ torch/csrc/jit/python/script_init.cpp | 37 ++++++++++++++++++++ torch/jit/_trace.py | 38 ++++++++++++++------ 5 files changed, 130 insertions(+), 11 deletions(-) diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 694d2b8ee4890..b20275420193d 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -565,6 +565,16 @@ inline Stack toTraceableStack(const py::tuple& inputs) { return info.toTupleRef().elements().vec(); } +inline Stack toTraceableStack(const py::dict& inputs) { + Stack res; + for(auto it = inputs.begin(); it != inputs.end(); it++) { + if(THPVariable_Check(it->second.ptr())) { + res.push_back(toIValue(it->second, tryToInferType(it->second).type())); + } + } + return res; +} + inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { auto elems = c10::impl::GenericList(elem_type); for (auto elem : obj) { diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 494265e161849..375856990061f 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -73,6 +73,52 @@ SourceRange getPythonInterpreterSourceRange() { return SourceRange(source, 0, stack_trace_text.size()); } +std::pair, Stack> createGraphByTracing_dict( + const py::function& func, + const py::dict& inputs_dict, + Stack trace_inputs, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + Module* self, + const std::vector& argument_names) { + C10_LOG_API_USAGE_ONCE("torch.tracer"); + + auto lookup_fn_adapter = + [var_name_lookup_fn](const 
Variable& var) -> std::string { + pybind11::gil_scoped_acquire ag; + return py::cast(var_name_lookup_fn(var)); + }; + + std::vector reordered_argument_names; + for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { + for (size_t i = 0; i < argument_names.size(); i++) { + if (py::cast(it->first) == argument_names[i]) { + reordered_argument_names.push_back(argument_names[i]); + break; + } + } + } + + auto outs = tracer::trace( + std::move(trace_inputs), + [&](Stack inputs) -> Stack { + auto out = func(**inputs_dict); + if (out.ptr() == Py_None) { + AT_ERROR( + "The traced function didn't return any values! Side-effects are not " + "captured in traces, so it would be a no-op."); + } + return {toTypeInferredIValue(out)}; + }, + lookup_fn_adapter, + strict, + force_outplace, + self, + reordered_argument_names); + return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs)); +} + std::pair, Stack> createGraphByTracing( const py::function& func, Stack trace_inputs, diff --git a/torch/csrc/jit/python/python_tracer.h b/torch/csrc/jit/python/python_tracer.h index 3f1fca20bfe00..aa6bfa037fdc4 100644 --- a/torch/csrc/jit/python/python_tracer.h +++ b/torch/csrc/jit/python/python_tracer.h @@ -24,6 +24,16 @@ Node* preRecordPythonTrace( at::ArrayRef inputs, std::vector scalar_args); +std::pair, Stack> createGraphByTracing_dict( + const py::function& func, + const py::dict& inputs_dict, + Stack inputs, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + Module* self = nullptr, + const std::vector& argument_names = {}); + std::pair, Stack> createGraphByTracing( const py::function& func, Stack inputs, diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 110c2f4a70c79..c1a04d2a6a575 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1218,6 +1218,43 @@ void initJitScriptBindings(PyObject* module) { py::arg("strict"), py::arg("force_outplace"), 
py::arg("argument_names") = std::vector()) + .def( + "_create_method_from_trace_with_dict", + [](Module& self, + const std::string& name, + const py::function& func, + const py::dict& input_dict, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + const std::vector& argument_names) { + // prereq: Module's buffers and parameters are unique + // this was ensured in python before calling this function + auto typed_inputs = toTraceableStack(input_dict); + + std::shared_ptr graph = + std::get<0>(tracer::createGraphByTracing_dict( + func, + input_dict, + typed_inputs, + var_name_lookup_fn, + strict, + force_outplace, + &self, + argument_names)); + const auto method_name = QualifiedName(*self.type()->name(), name); + auto fn = self._ivalue()->compilation_unit()->create_function( + method_name, graph); + self.type()->addMethod(fn); + didFinishEmitModule(self); + }, + py::arg("name"), + py::arg("func"), + py::arg("input_dict"), + py::arg("var_name_lookup_fn"), + py::arg("strict"), + py::arg("force_outplace"), + py::arg("argument_names") = std::vector()) .def( "_get_forward_hooks", [](const Module& m) { diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index fe0091f63bb66..45587372a2f58 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -440,7 +440,11 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - outs = wrap_retval(mod(*_clone_inputs(inputs))) + if len(inputs) > 0: + if isinstance(inputs[0], dict): + outs = wrap_retval(mod(**inputs[0])) + else: + outs = wrap_retval(mod(*_clone_inputs(inputs))) outs = [out for out in outs if isinstance(out, torch.Tensor)] return outs except Exception as e: @@ -971,17 +975,29 @@ def register_submods(mod, prefix): func = getattr(mod, method_name) argument_names = get_callable_argument_names(func) - example_inputs = make_tuple(example_inputs) + if isinstance(example_inputs, dict): + module._c._create_method_from_trace_with_dict( + method_name, + func, 
+ example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + ) + else: + example_inputs = make_tuple(example_inputs) + + module._c._create_method_from_trace( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + argument_names, + ) - module._c._create_method_from_trace( - method_name, - func, - example_inputs, - var_lookup_fn, - strict, - _force_outplace, - argument_names, - ) check_trace_method = module._c._get_method(method_name) # Check the trace against new traces created from user-specified inputs From bca2afd7e99e974b4d2ca2f353a26e238d137a29 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Thu, 23 Jun 2022 01:13:30 +0800 Subject: [PATCH 02/22] Add comments for key parts --- torch/csrc/jit/python/pybind_utils.h | 1 + torch/csrc/jit/python/python_tracer.h | 2 +- torch/csrc/jit/python/script_init.cpp | 4 ++-- torch/jit/_trace.py | 8 ++++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index b20275420193d..de4f25841ba0e 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -565,6 +565,7 @@ inline Stack toTraceableStack(const py::tuple& inputs) { return info.toTupleRef().elements().vec(); } +// Serialize the python dictionary into a traceable stack. 
inline Stack toTraceableStack(const py::dict& inputs) { Stack res; for(auto it = inputs.begin(); it != inputs.end(); it++) { diff --git a/torch/csrc/jit/python/python_tracer.h b/torch/csrc/jit/python/python_tracer.h index aa6bfa037fdc4..6ec9dc388c31a 100644 --- a/torch/csrc/jit/python/python_tracer.h +++ b/torch/csrc/jit/python/python_tracer.h @@ -24,7 +24,7 @@ Node* preRecordPythonTrace( at::ArrayRef inputs, std::vector scalar_args); -std::pair, Stack> createGraphByTracing_dict( +std::pair, Stack> createGraphByTracingWithDict( const py::function& func, const py::dict& inputs_dict, Stack inputs, diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index c1a04d2a6a575..fdb62082fdf2d 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1183,7 +1183,7 @@ void initJitScriptBindings(PyObject* module) { m.register_attribute(name, type, toIValue(value, type)); }) .def( - "_create_method_from_trace", + "_create_method_from_trace_with_tuple", [](Module& self, const std::string& name, const py::function& func, @@ -1233,7 +1233,7 @@ void initJitScriptBindings(PyObject* module) { auto typed_inputs = toTraceableStack(input_dict); std::shared_ptr graph = - std::get<0>(tracer::createGraphByTracing_dict( + std::get<0>(tracer::createGraphByTracingWithDict( func, input_dict, typed_inputs, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 45587372a2f58..b21c6c965117f 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -976,6 +976,11 @@ def register_submods(mod, prefix): argument_names = get_callable_argument_names(func) if isinstance(example_inputs, dict): + # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ + for key in example_inputs: + if key not in argument_names: + valid_arguments = "[" + ','.join(argument_names) + "]" + raise NameError("'{}' is not in forward() method's arguments, valid arguments name are 
{}".format(key, valid_arguments)) module._c._create_method_from_trace_with_dict( method_name, func, @@ -987,8 +992,7 @@ def register_submods(mod, prefix): ) else: example_inputs = make_tuple(example_inputs) - - module._c._create_method_from_trace( + module._c._create_method_from_trace_with_tuple( method_name, func, example_inputs, From 937d85dd4705b0058560e4417d1d441a1361ef59 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Thu, 23 Jun 2022 01:16:16 +0800 Subject: [PATCH 03/22] add missing file --- torch/csrc/jit/python/python_tracer.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 375856990061f..0ad38c7d19db0 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -73,7 +73,7 @@ SourceRange getPythonInterpreterSourceRange() { return SourceRange(source, 0, stack_trace_text.size()); } -std::pair, Stack> createGraphByTracing_dict( +std::pair, Stack> createGraphByTracingWithDict( const py::function& func, const py::dict& inputs_dict, Stack trace_inputs, @@ -90,6 +90,10 @@ std::pair, Stack> createGraphByTracing_dict( return py::cast(var_name_lookup_fn(var)); }; + // The argument_names parameter is parsed in python and its order + // is the same as the arguments' decalaration order in forward() method. + // These name shall be added to the graph as debug name and the order + // should align with the traceable stack we generated by the python dict. std::vector reordered_argument_names; for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { for (size_t i = 0; i < argument_names.size(); i++) { @@ -103,6 +107,7 @@ std::pair, Stack> createGraphByTracing_dict( auto outs = tracer::trace( std::move(trace_inputs), [&](Stack inputs) -> Stack { + // We just leave the inputs_dict as it was and pass it to forward method. 
auto out = func(**inputs_dict); if (out.ptr() == Py_None) { AT_ERROR( From 110cf0280ee3b1c2886d89e4cb02d48689154330 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Mon, 27 Jun 2022 02:43:43 +0800 Subject: [PATCH 04/22] add UT for this feature --- test/test_jit.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index a4f535921e558..5c985f4a1f0c3 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3027,6 +3027,23 @@ def forward(self, x): checker.check("def forward") checker.run(str(cm.exception)) + def test_dictionary_as_example_inputs_for_jit_trace(self): + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None): + return key1 + key2 + key3 + + model = TestModule() + value1 = torch.ones(1) + value2 = torch.ones(1) + value3 = torch.ones(1) + example_input_dict = {'key1':value1, 'key2':value2, 'key3':value3} + traced_model = torch.jit.trace(model, example_input_dict, strict=False) + res = traced_model(**example_input_dict) + self.assertEqual(res, 3 * torch.ones(1)) + class TestScript(JitTestCase): From a878835d3e0b68a5ffd80c37a417e035064d432b Mon Sep 17 00:00:00 2001 From: tangleintel Date: Sat, 16 Jul 2022 02:10:18 +0800 Subject: [PATCH 05/22] Modify failed UT to obey my solution --- test/jit/test_tracer.py | 10 ++++++++-- torch/jit/_trace.py | 11 +++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 50fdec94b9fc0..d68cdc8a8502e 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -2356,7 +2356,10 @@ def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor: input_map = {"1" : [torch.rand(2, 2), torch.rand(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]} model = testA() - traced_model = torch.jit.trace(model, input_map) + example_input = list() + example_input.append(input_map) + 
example_input = tuple(example_input) + traced_model = torch.jit.trace(model, example_inputs=example_input) new_input_map = {"1" : [torch.rand(2, 2), torch.randn(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]} self.assertEqual(model(new_input_map), traced_model(new_input_map)) @@ -2417,7 +2420,10 @@ def forward(self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tenso checks_dict = torch.jit.script(ChecksDict()) returns_dict = torch.jit.script(ReturnsDict()) eager_module = TestModule(checks_dict, returns_dict) - traced_module = torch.jit.trace(eager_module, input1) + example_input = list() + example_input.append(input1) + example_input = tuple(example_input) + traced_module = torch.jit.trace(eager_module, example_inputs=example_input) self.assertEqual(traced_module(input1), eager_module(input1)) self.assertEqual(traced_module(input2), eager_module(input2)) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index b21c6c965117f..b9a7e82b2c121 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -338,7 +338,7 @@ def _check_trace( ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor, dict)): + if isinstance(inputs, (torch.Tensor)): inputs = (inputs,) else: check_mod = torch.jit.trace( @@ -440,11 +440,10 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - if len(inputs) > 0: - if isinstance(inputs[0], dict): - outs = wrap_retval(mod(**inputs[0])) - else: - outs = wrap_retval(mod(*_clone_inputs(inputs))) + if isinstance(inputs, dict): + outs = wrap_retval(mod(**inputs)) + else: + outs = wrap_retval(mod(*_clone_inputs(inputs))) outs = [out for out in outs if isinstance(out, torch.Tensor)] return outs except Exception as e: From 663102dbb39f460b8d73d3833e3b5308de629543 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Mon, 18 Jul 2022 10:07:37 +0800 Subject: [PATCH 06/22] modify code format --- test/test_jit.py | 4 ++-- 
torch/csrc/jit/python/pybind_utils.h | 4 ++-- torch/csrc/jit/python/python_tracer.cpp | 3 ++- torch/jit/_trace.py | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 5c985f4a1f0c3..2bfe73b9e5f39 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3034,12 +3034,12 @@ def __init__(self): def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None): return key1 + key2 + key3 - + model = TestModule() value1 = torch.ones(1) value2 = torch.ones(1) value3 = torch.ones(1) - example_input_dict = {'key1':value1, 'key2':value2, 'key3':value3} + example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} traced_model = torch.jit.trace(model, example_input_dict, strict=False) res = traced_model(**example_input_dict) self.assertEqual(res, 3 * torch.ones(1)) diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index de4f25841ba0e..835c7d0dc709a 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -568,8 +568,8 @@ inline Stack toTraceableStack(const py::tuple& inputs) { // Serialize the python dictionary into a traceable stack. 
inline Stack toTraceableStack(const py::dict& inputs) { Stack res; - for(auto it = inputs.begin(); it != inputs.end(); it++) { - if(THPVariable_Check(it->second.ptr())) { + for (auto it = inputs.begin(); it != inputs.end(); it++) { + if (THPVariable_Check(it->second.ptr())) { res.push_back(toIValue(it->second, tryToInferType(it->second).type())); } } diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 0ad38c7d19db0..49dcdc2482988 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -107,7 +107,8 @@ std::pair, Stack> createGraphByTracingWithDict( auto outs = tracer::trace( std::move(trace_inputs), [&](Stack inputs) -> Stack { - // We just leave the inputs_dict as it was and pass it to forward method. + // We just leave the inputs_dict as it was and pass it to forward + // method. auto out = func(**inputs_dict); if (out.ptr() == Py_None) { AT_ERROR( diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index b9a7e82b2c121..5f7bf74505f32 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -979,7 +979,8 @@ def register_submods(mod, prefix): for key in example_inputs: if key not in argument_names: valid_arguments = "[" + ','.join(argument_names) + "]" - raise NameError("'{}' is not in forward() method's arguments, valid arguments name are {}".format(key, valid_arguments)) + raise NameError("""'{}' is not in forward() method's arguments, + valid arguments name are {}""".format(key, valid_arguments)) module._c._create_method_from_trace_with_dict( method_name, func, From f13c6582f1ae519039eabae5e0b32b4eb0001881 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Mon, 5 Sep 2022 23:49:07 +0800 Subject: [PATCH 07/22] Complete UT --- test/test_jit.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 2bfe73b9e5f39..c6ffe081bcb9a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ 
-3028,21 +3028,34 @@ def forward(self, x): checker.run(str(cm.exception)) def test_dictionary_as_example_inputs_for_jit_trace(self): - class TestModule(torch.nn.Module): + class TestModule_v1(torch.nn.Module): def __init__(self): - super(TestModule, self).__init__() + super(TestModule_v1, self).__init__() def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None): return key1 + key2 + key3 - model = TestModule() + class TestModule_v2(torch.nn.Module): + def __init__(self): + super(TestModule_v2, self).__init__() + + def forward(self, x, y): + return x + y + + model_1 = TestModule_v1() + model_2 = TestModule_v2() value1 = torch.ones(1) value2 = torch.ones(1) value3 = torch.ones(1) example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} - traced_model = torch.jit.trace(model, example_input_dict, strict=False) - res = traced_model(**example_input_dict) - self.assertEqual(res, 3 * torch.ones(1)) + traced_model_1 = torch.jit.trace(model_1, example_input_dict, strict=False) + traced_model_2 = torch.jit.trace(model_2, {'x': torch.rand([2]), 'y': torch.rand([2])}) + res_1 = traced_model_1(**example_input_dict) + self.assertEqual(res_1, 3 * torch.ones(1)) + with self.assertRaisesRegex(RuntimeError, "forward\(\) is missing value for argument 'x'."): + res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) + with self.assertRaisesRegex(RuntimeError, "forward\(\) is missing value for argument 'y'."): + res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) class TestScript(JitTestCase): From 5860d782d3a4a7fd45b6b4eeede041151016a077 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Tue, 4 Oct 2022 16:30:45 +0800 Subject: [PATCH 08/22] modify the python internal API _create_method_from_trace_with_tuple() to its origin name _create_method_from_trace() to maintain backward compatibility --- torch/csrc/jit/python/script_init.cpp | 2 +- torch/jit/_trace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) 
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index fdb62082fdf2d..594ca8f78b1f9 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1183,7 +1183,7 @@ void initJitScriptBindings(PyObject* module) { m.register_attribute(name, type, toIValue(value, type)); }) .def( - "_create_method_from_trace_with_tuple", + "_create_method_from_trace", [](Module& self, const std::string& name, const py::function& func, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 5f7bf74505f32..fb3fb53314e3f 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -992,7 +992,7 @@ def register_submods(mod, prefix): ) else: example_inputs = make_tuple(example_inputs) - module._c._create_method_from_trace_with_tuple( + module._c._create_method_from_trace( method_name, func, example_inputs, From eda9a2216a3dd1ccdf7c558c3a343901668a3d2e Mon Sep 17 00:00:00 2001 From: tangleintel Date: Wed, 5 Oct 2022 21:14:55 +0800 Subject: [PATCH 09/22] add an option to trace() and trace_module() to extend the meaning of python dict as example_input and maintain the backward compatibility at the same time --- test/jit/test_tracer.py | 5 +---- test/test_jit.py | 4 ++-- torch/jit/_trace.py | 17 ++++++++++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index d68cdc8a8502e..369bf239dff62 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -2356,10 +2356,7 @@ def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor: input_map = {"1" : [torch.rand(2, 2), torch.rand(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]} model = testA() - example_input = list() - example_input.append(input_map) - example_input = tuple(example_input) - traced_model = torch.jit.trace(model, example_inputs=example_input) + traced_model = torch.jit.trace(model, input_map) new_input_map = {"1" : [torch.rand(2, 2), torch.randn(2, 2)], "3" : 
[torch.rand(2, 2), torch.rand(2, 2)]} self.assertEqual(model(new_input_map), traced_model(new_input_map)) diff --git a/test/test_jit.py b/test/test_jit.py index c6ffe081bcb9a..6c4c48b7679fd 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3048,8 +3048,8 @@ def forward(self, x, y): value2 = torch.ones(1) value3 = torch.ones(1) example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} - traced_model_1 = torch.jit.trace(model_1, example_input_dict, strict=False) - traced_model_2 = torch.jit.trace(model_2, {'x': torch.rand([2]), 'y': torch.rand([2])}) + traced_model_1 = torch.jit.trace(model_1, example_input_dict, unpack_input_dict=True, strict=False) + traced_model_2 = torch.jit.trace(model_2, {'x': torch.rand([2]), 'y': torch.rand([2])}, unpack_input_dict=True) res_1 = traced_model_1(**example_input_dict) self.assertEqual(res_1, 3 * torch.ones(1)) with self.assertRaisesRegex(RuntimeError, "forward\(\) is missing value for argument 'x'."): diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index fb3fb53314e3f..b16f66e5e9997 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -316,6 +316,7 @@ def _check_trace( force_outplace, is_trace_module, _module_class, + unpack_input_dict=False ): # Note: tracing is independent of optimizations, which consume the trace for inputs in check_inputs: @@ -335,10 +336,11 @@ def _check_trace( _force_outplace=force_outplace, _module_class=_module_class, _compilation_unit=torch._C.CompilationUnit(), + unpack_input_dict=unpack_input_dict ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor)): + if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and not unpack_input_dict: inputs = (inputs,) else: check_mod = torch.jit.trace( @@ -348,6 +350,7 @@ def _check_trace( strict=strict, _force_outplace=force_outplace, _module_class=_module_class, + unpack_input_dict=unpack_input_dict ) check_mod_func = check_mod @@ -440,7 
+443,7 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - if isinstance(inputs, dict): + if isinstance(inputs, dict) and unpack_input_dict: outs = wrap_retval(mod(**inputs)) else: outs = wrap_retval(mod(*_clone_inputs(inputs))) @@ -607,6 +610,7 @@ def trace( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, + unpack_input_dict=False, ): """ Trace a function and return an executable or :class:`ScriptFunction` @@ -769,6 +773,7 @@ def forward(self, x): strict, _force_outplace, _module_class, + unpack_input_dict=unpack_input_dict ) if ( @@ -786,6 +791,7 @@ def forward(self, x): strict, _force_outplace, _module_class, + unpack_input_dict=unpack_input_dict ) # Special case for common case of passing a single Tensor @@ -826,6 +832,7 @@ def forward(self, x): _force_outplace, False, _module_class, + unpack_input_dict=unpack_input_dict ) else: _check_trace( @@ -837,6 +844,7 @@ def forward(self, x): _force_outplace, False, _module_class, + unpack_input_dict=unpack_input_dict ) return traced @@ -856,6 +864,7 @@ def trace_module( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, + unpack_input_dict=False ): """ Trace a module and return an executable :class:`ScriptModule` that will be optimized @@ -974,7 +983,7 @@ def register_submods(mod, prefix): func = getattr(mod, method_name) argument_names = get_callable_argument_names(func) - if isinstance(example_inputs, dict): + if isinstance(example_inputs, dict) and unpack_input_dict: # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ for key in example_inputs: if key not in argument_names: @@ -1016,6 +1025,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, + unpack_input_dict=unpack_input_dict ) else: _check_trace( @@ -1027,6 +1037,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, + unpack_input_dict=unpack_input_dict ) finally: 
torch.jit._trace._trace_module_map = old_module_map From 0deda0039c2e795940c4cd67c2d8f571efadfd87 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Wed, 5 Oct 2022 21:19:29 +0800 Subject: [PATCH 10/22] revert the UT --- test/jit/test_tracer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 369bf239dff62..50fdec94b9fc0 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -2417,10 +2417,7 @@ def forward(self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tenso checks_dict = torch.jit.script(ChecksDict()) returns_dict = torch.jit.script(ReturnsDict()) eager_module = TestModule(checks_dict, returns_dict) - example_input = list() - example_input.append(input1) - example_input = tuple(example_input) - traced_module = torch.jit.trace(eager_module, example_inputs=example_input) + traced_module = torch.jit.trace(eager_module, input1) self.assertEqual(traced_module(input1), eager_module(input1)) self.assertEqual(traced_module(input2), eager_module(input2)) From d37e25273762666180fd54679d06e9804020ad1c Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 7 Oct 2022 00:00:45 +0800 Subject: [PATCH 11/22] add warning msg and function doc of the adding arguments --- torch/jit/_trace.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index b16f66e5e9997..ee9f519e57048 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -666,12 +666,14 @@ def trace( tensors. When a module is passed `torch.jit.trace`, only the ``forward`` method is run and traced (see :func:`torch.jit.trace ` for details). - example_inputs (tuple or torch.Tensor): A tuple of example inputs that + example_inputs (tuple or torch.Tensor or dict): A tuple of example inputs that will be passed to the function while tracing. 
The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. `example_inputs` may also be a single Tensor in which case it is automatically - wrapped in a tuple. + wrapped in a tuple. When example_inputs is a dict, if unpack_input_dict + is set to True, it reprensents a pack of keyword arguments and can be unpacked + by the traced function's arguments name. Keyword arguments: check_trace (``bool``, optional): Check if the same inputs run through @@ -697,6 +699,10 @@ def trace( and you are sure that the container you are using in your problem is a ``constant`` structure and does not get used as control flow (if, for) conditions. + unpack_input_dict (``bool``, optional): When we use a python dict as + example_inputs, it suggests wether it is a pack of keyword arguments(True) or + a single value(False). This is a workaround for the ambiguity of using dict + as example_input. The old behavior will be deprecated. Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns @@ -1000,6 +1006,15 @@ def register_submods(mod, prefix): argument_names, ) else: + if isinstance(example_inputs, dict): + warnings.warn( + "Directly using python dict as a `single` value(with unpack_input_dict=False) to " + "example_inputs is deprecated and will be removed in upcoming PyTorch release(2.1)." + "Instead, users should wrap the dict with python tuple first, and then assign to " + "example_inputs for the deprecated behavior. In the future, passing a dict to example_inputs" + "will represent a pack of keyword parameters which will be unpacked according to the " + "traced function's parameter name." 
+ ) example_inputs = make_tuple(example_inputs) module._c._create_method_from_trace( method_name, From b5539e478d224787c6c59e8aa15eecea1111c635 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 7 Oct 2022 01:21:39 +0800 Subject: [PATCH 12/22] fix the lint error --- test/test_jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 6c4c48b7679fd..99dee3ca64ca5 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3052,9 +3052,9 @@ def forward(self, x, y): traced_model_2 = torch.jit.trace(model_2, {'x': torch.rand([2]), 'y': torch.rand([2])}, unpack_input_dict=True) res_1 = traced_model_1(**example_input_dict) self.assertEqual(res_1, 3 * torch.ones(1)) - with self.assertRaisesRegex(RuntimeError, "forward\(\) is missing value for argument 'x'."): + with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) - with self.assertRaisesRegex(RuntimeError, "forward\(\) is missing value for argument 'y'."): + with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."): res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) From 34d585da810f54db4c2d402d479e0071c8706f88 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 7 Oct 2022 23:36:05 +0800 Subject: [PATCH 13/22] didn't change the debug name's order, just compact when there is missing args --- torch/csrc/jit/python/python_tracer.cpp | 23 +++++++++++++++-------- torch/jit/_trace.py | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 49dcdc2482988..e5e511333a625 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -94,18 +94,25 @@ std::pair, Stack> createGraphByTracingWithDict( // is the same as the arguments' decalaration order in forward() method. 
// These name shall be added to the graph as debug name and the order // should align with the traceable stack we generated by the python dict. - std::vector reordered_argument_names; - for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { - for (size_t i = 0; i < argument_names.size(); i++) { - if (py::cast(it->first) == argument_names[i]) { - reordered_argument_names.push_back(argument_names[i]); - break; + std::vector compact_argument_names; + Stack compact_trace_inputs; + for (auto i = 0; i < argument_names.size(); i++) { + if (inputs_dict.contains(argument_names[i])) { + compact_argument_names.push_back(argument_names[i]); + } + } + for (auto i = 0; i < compact_argument_names.size(); i++) { + for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { + if (py::cast(it->first) == compact_argument_names[i]) { + if (THPVariable_Check(it->second.ptr())) { + compact_trace_inputs.push_back(toIValue(it->second, tryToInferType(it->second).type())); + } } } } auto outs = tracer::trace( - std::move(trace_inputs), + std::move(compact_trace_inputs), [&](Stack inputs) -> Stack { // We just leave the inputs_dict as it was and pass it to forward // method. @@ -121,7 +128,7 @@ std::pair, Stack> createGraphByTracingWithDict( strict, force_outplace, self, - reordered_argument_names); + compact_argument_names); return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs)); } diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index ee9f519e57048..2178a94713eb8 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -1009,7 +1009,7 @@ def register_submods(mod, prefix): if isinstance(example_inputs, dict): warnings.warn( "Directly using python dict as a `single` value(with unpack_input_dict=False) to " - "example_inputs is deprecated and will be removed in upcoming PyTorch release(2.1)." + "example_inputs is deprecated and will be removed in upcoming PyTorch release(1.16)." 
"Instead, users should wrap the dict with python tuple first, and then assign to " "example_inputs for the deprecated behavior. In the future, passing a dict to example_inputs" "will represent a pack of keyword parameters which will be unpacked according to the " From 6f3753163ae39b4f3b0043f54e193b0d6d304691 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 7 Oct 2022 23:56:59 +0800 Subject: [PATCH 14/22] lint error clang-format --- torch/csrc/jit/python/python_tracer.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index e5e511333a625..0b40afd10ffc3 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -105,7 +105,8 @@ std::pair, Stack> createGraphByTracingWithDict( for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { if (py::cast(it->first) == compact_argument_names[i]) { if (THPVariable_Check(it->second.ptr())) { - compact_trace_inputs.push_back(toIValue(it->second, tryToInferType(it->second).type())); + compact_trace_inputs.push_back( + toIValue(it->second, tryToInferType(it->second).type())); } } } From bd8ea4c614815722ab6aa74b2f22e6872971f4f7 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Sat, 8 Oct 2022 00:38:00 +0800 Subject: [PATCH 15/22] fix build error for some compilers on other platform --- torch/csrc/jit/python/python_tracer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 0b40afd10ffc3..d09069ac7d3c2 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -96,12 +96,12 @@ std::pair, Stack> createGraphByTracingWithDict( // should align with the traceable stack we generated by the python dict. 
std::vector compact_argument_names; Stack compact_trace_inputs; - for (auto i = 0; i < argument_names.size(); i++) { + for (std::vector::size_type i = 0; i < argument_names.size(); i++) { if (inputs_dict.contains(argument_names[i])) { compact_argument_names.push_back(argument_names[i]); } } - for (auto i = 0; i < compact_argument_names.size(); i++) { + for (std::vector::size_type i = 0; i < compact_argument_names.size(); i++) { for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { if (py::cast(it->first) == compact_argument_names[i]) { if (THPVariable_Check(it->second.ptr())) { From f8a47183d8ea696818a8092e361252d2fbbd101a Mon Sep 17 00:00:00 2001 From: tangleintel Date: Sat, 8 Oct 2022 00:47:56 +0800 Subject: [PATCH 16/22] clang-format --- torch/csrc/jit/python/python_tracer.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index d09069ac7d3c2..83570c85e9b4c 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -96,12 +96,15 @@ std::pair, Stack> createGraphByTracingWithDict( // should align with the traceable stack we generated by the python dict. 
std::vector compact_argument_names; Stack compact_trace_inputs; - for (std::vector::size_type i = 0; i < argument_names.size(); i++) { + for (std::vector::size_type i = 0; i < argument_names.size(); + i++) { if (inputs_dict.contains(argument_names[i])) { compact_argument_names.push_back(argument_names[i]); } } - for (std::vector::size_type i = 0; i < compact_argument_names.size(); i++) { + for (std::vector::size_type i = 0; + i < compact_argument_names.size(); + i++) { for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { if (py::cast(it->first) == compact_argument_names[i]) { if (THPVariable_Check(it->second.ptr())) { From 89497ff3956c35400aa20a76022307438eacb2a3 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Tue, 11 Oct 2022 23:14:23 +0800 Subject: [PATCH 17/22] add argument example_kwarg_inputs to unpack dict --- test/test_jit.py | 4 +-- torch/jit/_trace.py | 62 +++++++++++++++++++++------------------------ 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 99dee3ca64ca5..615425ab6adb5 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3048,8 +3048,8 @@ def forward(self, x, y): value2 = torch.ones(1) value3 = torch.ones(1) example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} - traced_model_1 = torch.jit.trace(model_1, example_input_dict, unpack_input_dict=True, strict=False) - traced_model_2 = torch.jit.trace(model_2, {'x': torch.rand([2]), 'y': torch.rand([2])}, unpack_input_dict=True) + traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) + traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) res_1 = traced_model_1(**example_input_dict) self.assertEqual(res_1, 3 * torch.ones(1)) with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 2178a94713eb8..6657868723b98 100644 
--- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -316,7 +316,7 @@ def _check_trace( force_outplace, is_trace_module, _module_class, - unpack_input_dict=False + example_kwarg_inputs=None, ): # Note: tracing is independent of optimizations, which consume the trace for inputs in check_inputs: @@ -336,11 +336,11 @@ def _check_trace( _force_outplace=force_outplace, _module_class=_module_class, _compilation_unit=torch._C.CompilationUnit(), - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and not unpack_input_dict: + if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and example_kwarg_inputs is None: inputs = (inputs,) else: check_mod = torch.jit.trace( @@ -350,7 +350,7 @@ def _check_trace( strict=strict, _force_outplace=force_outplace, _module_class=_module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) check_mod_func = check_mod @@ -443,7 +443,7 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - if isinstance(inputs, dict) and unpack_input_dict: + if isinstance(inputs, dict) and isinstance(example_kwarg_inputs, dict): outs = wrap_retval(mod(**inputs)) else: outs = wrap_retval(mod(*_clone_inputs(inputs))) @@ -601,7 +601,7 @@ def wrap_check_inputs(check_inputs): def trace( func, - example_inputs, + example_inputs=None, optimize=None, check_trace=True, check_inputs=None, @@ -610,7 +610,7 @@ def trace( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, - unpack_input_dict=False, + example_kwarg_inputs=None ): """ Trace a function and return an executable or :class:`ScriptFunction` @@ -666,16 +666,16 @@ def trace( tensors. 
When a module is passed `torch.jit.trace`, only the ``forward`` method is run and traced (see :func:`torch.jit.trace ` for details). - example_inputs (tuple or torch.Tensor or dict): A tuple of example inputs that + + Keyword arguments: + example_inputs (tuple or torch.Tensor or None): A tuple of example inputs that will be passed to the function while tracing. The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. `example_inputs` may also be a single Tensor in which case it is automatically - wrapped in a tuple. When example_inputs is a dict, if unpack_input_dict - is set to True, it reprensents a pack of keyword arguments and can be unpacked - by the traced function's arguments name. + wrapped in a tuple. When the value is None, example_kwarg_inputs should + be specified. - Keyword arguments: check_trace (``bool``, optional): Check if the same inputs run through traced code produce the same outputs. Default: ``True``. You might want to disable this if, for example, your network contains non- @@ -699,10 +699,9 @@ def trace( and you are sure that the container you are using in your problem is a ``constant`` structure and does not get used as control flow (if, for) conditions. - unpack_input_dict (``bool``, optional): When we use a python dict as - example_inputs, it suggests wether it is a pack of keyword arguments(True) or - a single value(False). This is a workaround for the ambiguity of using dict - as example_input. The old behavior will be deprecated. + example_kwarg_inputs (dict): This parameter is a pack of keyword arguments + example inputs that will be passed to the function while tracing. The dict + will be unpacking by the arguments name of the traced function. 
Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns @@ -768,6 +767,12 @@ def forward(self, x): ) return func + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") + if isinstance(func, torch.nn.Module): return trace_module( func, @@ -779,7 +784,7 @@ def forward(self, x): strict, _force_outplace, _module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) if ( @@ -797,7 +802,7 @@ def forward(self, x): strict, _force_outplace, _module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) # Special case for common case of passing a single Tensor @@ -838,7 +843,7 @@ def forward(self, x): _force_outplace, False, _module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) else: _check_trace( @@ -850,7 +855,7 @@ def forward(self, x): _force_outplace, False, _module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) return traced @@ -870,7 +875,7 @@ def trace_module( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, - unpack_input_dict=False + example_kwarg_inputs=None, ): """ Trace a module and return an executable :class:`ScriptModule` that will be optimized @@ -989,7 +994,7 @@ def register_submods(mod, prefix): func = getattr(mod, method_name) argument_names = get_callable_argument_names(func) - if isinstance(example_inputs, dict) and unpack_input_dict: + if isinstance(example_inputs, dict) and isinstance(example_kwarg_inputs, dict): # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ for key in example_inputs: if key not in argument_names: @@ -1006,15 +1011,6 @@ def register_submods(mod, prefix): argument_names, ) else: - if isinstance(example_inputs, dict): - warnings.warn( - 
"Directly using python dict as a `single` value(with unpack_input_dict=False) to " - "example_inputs is deprecated and will be removed in upcoming PyTorch release(1.16)." - "Instead, users should wrap the dict with python tuple first, and then assign to " - "example_inputs for the deprecated behavior. In the future, passing a dict to example_inputs" - "will represent a pack of keyword parameters which will be unpacked according to the " - "traced function's parameter name." - ) example_inputs = make_tuple(example_inputs) module._c._create_method_from_trace( method_name, @@ -1040,7 +1036,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) else: _check_trace( @@ -1052,7 +1048,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, - unpack_input_dict=unpack_input_dict + example_kwarg_inputs=example_kwarg_inputs, ) finally: torch.jit._trace._trace_module_map = old_module_map From ead480a4960337a25afce4d21990007a7de10f22 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Wed, 12 Oct 2022 13:48:44 +0800 Subject: [PATCH 18/22] refine the docstring --- torch/jit/_trace.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 6657868723b98..ea19686ab913a 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -668,13 +668,14 @@ def trace( ` for details). Keyword arguments: - example_inputs (tuple or torch.Tensor or None): A tuple of example inputs that - will be passed to the function while tracing. The resulting trace - can be run with inputs of different types and shapes assuming the - traced operations support those types and shapes. `example_inputs` - may also be a single Tensor in which case it is automatically - wrapped in a tuple. When the value is None, example_kwarg_inputs should - be specified. 
+ example_inputs (tuple or torch.Tensor or None, optional): A tuple of example + inputs that will be passed to the function while tracing. + Default: ``None``. Either this argument or ``example_kwarg_inputs`` + should be specified. The resulting trace can be run with inputs of + different types and shapes assuming the traced operations support those + types and shapes. `example_inputs` may also be a single Tensor in which + case it is automatically wrapped in a tuple. When the value is None, + ``example_kwarg_inputs`` should be specified. check_trace (``bool``, optional): Check if the same inputs run through traced code produce the same outputs. Default: ``True``. You might want @@ -699,9 +700,12 @@ def trace( and you are sure that the container you are using in your problem is a ``constant`` structure and does not get used as control flow (if, for) conditions. - example_kwarg_inputs (dict): This parameter is a pack of keyword arguments - example inputs that will be passed to the function while tracing. The dict - will be unpacking by the arguments name of the traced function. + example_kwarg_inputs (dict, optional): This parameter is a pack of keyword + arguments example inputs that will be passed to the function while tracing. + Default: ``None``. Either this argument or ``example_inputs`` should be + specified. The dict will be unpacking by the arguments name of the traced + function. If the keys of the dict don't not match with the traced + function'a arguments name, a runtime exception will be raised. 
Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns From 54c3f66ff0e9bdd529a692e55eff1451a6dcc15e Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 14 Oct 2022 00:13:43 +0800 Subject: [PATCH 19/22] support this feature to pure python function in jit.trace() and modify the doc --- test/test_jit.py | 6 +++ torch/csrc/jit/python/script_init.cpp | 37 +++++++++++++++++ torch/jit/_trace.py | 57 +++++++++++++++++++-------- 3 files changed, 83 insertions(+), 17 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 615425ab6adb5..3244db7488d84 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3042,16 +3042,22 @@ def __init__(self): def forward(self, x, y): return x + y + def test_func(x, y): + return x + y model_1 = TestModule_v1() model_2 = TestModule_v2() value1 = torch.ones(1) value2 = torch.ones(1) value3 = torch.ones(1) example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} + example_input_dict_func = {'x': value1, 'y': value2} traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) + traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) res_1 = traced_model_1(**example_input_dict) self.assertEqual(res_1, 3 * torch.ones(1)) + res_func = traced_func(**example_input_dict_func) + self.assertEqual(res_func, 2 * torch.ones(1)) with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."): diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 594ca8f78b1f9..ee9509588932c 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1705,6 
+1705,43 @@ void initJitScriptBindings(PyObject* module) { py::arg("force_outplace"), py::arg("argument_names") = std::vector()); + m.def( + "_create_function_from_trace_with_dict", + [](const std::string& qualname, + const py::function& func, + const py::dict& input_dict, + const py::function& var_name_lookup_fn, + bool strict, + bool force_outplace, + const std::vector& argument_names) { + auto typed_inputs = toTraceableStack(input_dict); + std::shared_ptr graph = + std::get<0>(tracer::createGraphByTracingWithDict( + func, + input_dict, + typed_inputs, + var_name_lookup_fn, + strict, + force_outplace, + /*self=*/nullptr, + argument_names)); + + auto cu = get_python_cu(); + auto name = c10::QualifiedName(qualname); + auto result = cu->create_function( + std::move(name), std::move(graph), /*shouldMangle=*/true); + StrongFunctionPtr ret(std::move(cu), result); + didFinishEmitFunction(ret); + return ret; + }, + py::arg("name"), + py::arg("func"), + py::arg("input_dict"), + py::arg("var_name_lookup_fn"), + py::arg("strict"), + py::arg("force_outplace"), + py::arg("argument_names") = std::vector()); + m.def( "_jit_script_class_compile", [](const std::string& qualifiedName, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index ea19686ab913a..e8d20ff6706dc 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -705,7 +705,7 @@ def trace( Default: ``None``. Either this argument or ``example_inputs`` should be specified. The dict will be unpacking by the arguments name of the traced function. If the keys of the dict don't not match with the traced - function'a arguments name, a runtime exception will be raised. + function's arguments name, a runtime exception will be raised. 
Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns @@ -771,13 +771,13 @@ def forward(self, x): ) return func - if example_inputs is None: - if isinstance(example_kwarg_inputs, dict): - example_inputs = example_kwarg_inputs - else: - raise RuntimeError("example_kwarg_inputs should be a dict") if isinstance(func, torch.nn.Module): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") return trace_module( func, {"forward": example_inputs}, @@ -796,6 +796,11 @@ def forward(self, x): and isinstance(func.__self__, torch.nn.Module) and func.__name__ == "forward" ): + if example_inputs is None: + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + else: + raise RuntimeError("example_kwarg_inputs should be a dict") return trace_module( func.__self__, {"forward": example_inputs}, @@ -810,10 +815,10 @@ def forward(self, x): ) # Special case for common case of passing a single Tensor - if isinstance(example_inputs, (torch.Tensor, dict)): + if isinstance(example_inputs, (torch.Tensor, dict)) and example_kwarg_inputs is None: example_inputs = (example_inputs,) # done primarily so that weird iterables fail here and not pybind11 code - elif not isinstance(example_inputs, tuple): + elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): example_inputs = tuple(example_inputs) var_lookup_fn = _create_interpreter_name_lookup_fn(0) @@ -825,15 +830,27 @@ def forward(self, x): ) name = _qualified_name(func) - traced = torch._C._create_function_from_trace( - name, - func, - example_inputs, - var_lookup_fn, - strict, - _force_outplace, - get_callable_argument_names(func) - ) + if isinstance(example_kwarg_inputs, dict): + example_inputs = example_kwarg_inputs + traced = torch._C._create_function_from_trace_with_dict( + name, + func, + example_kwarg_inputs, + var_lookup_fn, + 
strict, + _force_outplace, + get_callable_argument_names(func) + ) + else: + traced = torch._C._create_function_from_trace( + name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + get_callable_argument_names(func) + ) # Check the trace against new traces created from user-specified inputs if check_trace: @@ -914,6 +931,12 @@ def trace_module( check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion. + example_kwarg_inputs (dict, optional): This parameter is a pack of keyword arguments example inputs + that will be passed to the function while tracing. Default: ``None``. + Either this argument or ``example_inputs`` in ``inputs`` should be specified. + The dict will be unpacking by the arguments name of the traced function. + If the keys of the dict don't not match with the traced module's forward + function's arguments name, a runtime exception will be raised. Returns: A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. 
From 8dba25fc09df7a38be96686889cd93c2f0b1bb9a Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 14 Oct 2022 01:52:39 +0800 Subject: [PATCH 20/22] modify the interface of trace_module() and add UT for it & update doc & fix lint error --- test/test_jit.py | 3 ++ torch/_C/__init__.pyi.in | 9 ++++++ torch/jit/_trace.py | 61 ++++++++++++++++++++++------------------ 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 3244db7488d84..d341b6ce2a450 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3052,10 +3052,13 @@ def test_func(x, y): example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} example_input_dict_func = {'x': value1, 'y': value2} traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) + traced_model_1_m = torch.jit.trace_module(model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) res_1 = traced_model_1(**example_input_dict) + res_1_m = traced_model_1_m(**example_input_dict) self.assertEqual(res_1, 3 * torch.ones(1)) + self.assertEqual(res_1_m, 3 * torch.ones(1)) res_func = traced_func(**example_input_dict_func) self.assertEqual(res_func, 2 * torch.ones(1)) with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 0e84fa864879c..0218c9fbdf4f8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -332,6 +332,15 @@ def _create_function_from_trace( force_outplace: _bool, argument_names: List[str] ) -> Tuple[Graph, Stack]: ... 
+def _create_function_from_trace_with_dict( + qualname: str, + func: Callable[..., Any], + input_dict: Dict[str, Any], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: List[str] +) -> Tuple[Graph, Stack]: ... def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... def parse_type_comment(comment: str) -> Decl: ... diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index e8d20ff6706dc..0935884ef2ec3 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -316,7 +316,7 @@ def _check_trace( force_outplace, is_trace_module, _module_class, - example_kwarg_inputs=None, + example_inputs_is_kwarg=False, ): # Note: tracing is independent of optimizations, which consume the trace for inputs in check_inputs: @@ -336,22 +336,33 @@ def _check_trace( _force_outplace=force_outplace, _module_class=_module_class, _compilation_unit=torch._C.CompilationUnit(), - example_kwarg_inputs=example_kwarg_inputs + example_inputs_is_kwarg=example_inputs_is_kwarg, ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and example_kwarg_inputs is None: + if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and not example_inputs_is_kwarg: inputs = (inputs,) else: - check_mod = torch.jit.trace( - func, - _clone_inputs(inputs), - check_trace=False, - strict=strict, - _force_outplace=force_outplace, - _module_class=_module_class, - example_kwarg_inputs=example_kwarg_inputs, - ) + if example_inputs_is_kwarg: + check_mod = torch.jit.trace( + func, + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + example_kwarg_inputs=_clone_inputs(inputs), + ) + else: + check_mod = torch.jit.trace( + func, + _clone_inputs(inputs), + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + ) + 
+ check_mod_func = check_mod def graph_diagnostic_info(): @@ -443,7 +454,7 @@ def wrap_retval(x): def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): try: - if isinstance(inputs, dict) and isinstance(example_kwarg_inputs, dict): + if isinstance(inputs, dict) and example_inputs_is_kwarg: outs = wrap_retval(mod(**inputs)) else: outs = wrap_retval(mod(*_clone_inputs(inputs))) @@ -788,7 +799,7 @@ def forward(self, x): strict, _force_outplace, _module_class, - example_kwarg_inputs=example_kwarg_inputs, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) if ( @@ -811,7 +822,7 @@ def forward(self, x): strict, _force_outplace, _module_class, - example_kwarg_inputs=example_kwarg_inputs, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) # Special case for common case of passing a single Tensor @@ -864,7 +875,7 @@ def forward(self, x): _force_outplace, False, _module_class, - example_kwarg_inputs=example_kwarg_inputs, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) else: _check_trace( @@ -876,7 +887,7 @@ def forward(self, x): _force_outplace, False, _module_class, - example_kwarg_inputs=example_kwarg_inputs, + example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), ) return traced @@ -896,7 +907,7 @@ def trace_module( _force_outplace=False, _module_class=None, _compilation_unit=_python_cu, - example_kwarg_inputs=None, + example_inputs_is_kwarg=False, ): """ Trace a module and return an executable :class:`ScriptModule` that will be optimized @@ -931,12 +942,8 @@ def trace_module( check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion. - example_kwarg_inputs (dict, optional): This parameter is a pack of keyword arguments example inputs - that will be passed to the function while tracing. Default: ``None``. 
- Either this argument or ``example_inputs`` in ``inputs`` should be specified. - The dict will be unpacking by the arguments name of the traced function. - If the keys of the dict don't not match with the traced module's forward - function's arguments name, a runtime exception will be raised. + example_inputs_is_kwarg (``bool``, optional): This parameter indicates whether the example inputs is a + pack of keyword arguments. Default: ``False``. Returns: A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. @@ -1021,7 +1028,7 @@ def register_submods(mod, prefix): func = getattr(mod, method_name) argument_names = get_callable_argument_names(func) - if isinstance(example_inputs, dict) and isinstance(example_kwarg_inputs, dict): + if isinstance(example_inputs, dict) and example_inputs_is_kwarg: # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ for key in example_inputs: if key not in argument_names: @@ -1063,7 +1070,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, - example_kwarg_inputs=example_kwarg_inputs, + example_inputs_is_kwarg=example_inputs_is_kwarg, ) else: _check_trace( @@ -1075,7 +1082,7 @@ def register_submods(mod, prefix): _force_outplace, True, _module_class, - example_kwarg_inputs=example_kwarg_inputs, + example_inputs_is_kwarg=example_inputs_is_kwarg, ) finally: torch.jit._trace._trace_module_map = old_module_map From 5beeb639772968ee3f75f50bf3af44a8fa7360c2 Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 14 Oct 2022 08:54:53 +0800 Subject: [PATCH 21/22] fix lint error --- test/test_jit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_jit.py b/test/test_jit.py index d341b6ce2a450..b1425a4ed71ca 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3052,7 +3052,8 @@ def test_func(x, y): example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} example_input_dict_func = {'x':
value1, 'y': value2} traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) - traced_model_1_m = torch.jit.trace_module(model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) + traced_model_1_m = torch.jit.trace_module( + model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) res_1 = traced_model_1(**example_input_dict) From 073861cbcac80acff8eb729325054989c8c94faf Mon Sep 17 00:00:00 2001 From: tangleintel Date: Fri, 14 Oct 2022 10:28:11 +0800 Subject: [PATCH 22/22] doc format --- torch/jit/_trace.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 0935884ef2ec3..b4352648df9c9 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -712,11 +712,11 @@ def trace( problem is a ``constant`` structure and does not get used as control flow (if, for) conditions. example_kwarg_inputs (dict, optional): This parameter is a pack of keyword - arguments example inputs that will be passed to the function while tracing. - Default: ``None``. Either this argument or ``example_inputs`` should be - specified. The dict will be unpacking by the arguments name of the traced - function. If the keys of the dict don't not match with the traced - function's arguments name, a runtime exception will be raised. + arguments of example inputs that will be passed to the function while + tracing. Default: ``None``. Either this argument or ``example_inputs`` + should be specified. The dict will be unpacked by the arguments name + of the traced function. If the keys of the dict don't match with + the traced function's arguments name, a runtime exception will be raised.
Returns: If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns