From fcb01279252b7af3bc385a9f2ede15771ec6321a Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 5 Mar 2023 07:51:22 -0500
Subject: [PATCH] [Unity][Fix] FX translating dtype (#14201)

This PR fixes a bug in the current FX translator's handling of dtype.
Previously, the translator did not take cases such as

```python
dtype = x.getattr("dtype")
```

into consideration. In such cases the dtype is an fx.Node object, while the
translator assumed it was either a string or a native torch datatype (e.g.,
torch.float32). This PR fixes the issue by looking every dtype up in the
environment table before converting it.
---
 python/tvm/relax/frontend/torch/fx_translator.py | 15 +++++++++------
 tests/python/relax/test_frontend_from_fx.py      | 14 ++++++++++----
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a73bc9d0db8c..fa68b2eee3ea 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,7 +18,7 @@
 # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 from functools import reduce
 
 import tvm
@@ -61,10 +61,13 @@ def _fetch_attr(model, target: str):
         return attr_itr
 
     @staticmethod
-    def _convert_data_type(input_type):
+    def _convert_data_type(input_type, env: Optional[Dict] = None):
         """converts the PyTorch scalar type input_type to a TVM dtype."""
         import torch  # type: ignore
 
+        if env is not None and input_type in env:
+            input_type = env[input_type]
+
         input_type = input_type.lower() if isinstance(input_type, str) else input_type
         if input_type in ["float", "float32", "torch.float32", torch.float32]:
             return "float32"
@@ -247,7 +250,7 @@ def _arange(self, node: fx.node.Node) -> relax.Var:
             start_end_step[2] = 1
 
         if "dtype" in node.kwargs:
-            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+            dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
         elif any([isinstance(x, float) for x in start_end_step]):
             dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
         else:
@@ -256,7 +259,7 @@ def _arange(self, node: fx.node.Node) -> relax.Var:
         return relax.const(np.arange(*start_end_step, dtype=dtype))
 
     def _empty(self, node: fx.node.Node) -> relax.Var:
-        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
         return self.block_builder.emit(relax.op.zeros(node.args, dtype))
 
     def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
@@ -334,7 +337,7 @@ def _half(self, node: fx.node.Node) -> relax.Var:
 
     def _type(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        dtype = self._convert_data_type(node.args[1])
+        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
         return self.block_builder.emit(relax.op.astype(x, dtype))
 
     ########## Linear Algebra ##########
@@ -565,7 +568,7 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         bias = self.params[module.bias]
-        dtype = self._convert_data_type(str(module.running_mean.dtype))
+        dtype = TorchFXImporter._convert_data_type(str(module.running_mean.dtype))
         running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype)
         running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype)
         eps = module.eps
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e28483dc2fab..4fd7cee81298 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1683,7 +1683,7 @@ def forward(self, input):
             return torch.arange(0, 20, dtype=torch.int32)
 
     graph_model = fx.symbolic_trace(Arange())
-    mod = from_fx(graph_model, [([10, 10], "float32")])
+    mod = from_fx(graph_model, [([10, 10], "float32")]).mod
     assert len(mod["main"].body.blocks) == 1
     assert len(mod["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -1707,7 +1707,7 @@ def forward(self, input):
             return torch.empty((10, 10), dtype=torch.float32)
 
     graph_model = fx.symbolic_trace(Empty())
-    mod = from_fx(graph_model, [([10, 10], "float32")])
+    mod = from_fx(graph_model, [([10, 10], "float32")]).mod
     assert len(mod["main"].body.blocks) == 1
     assert len(mod["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -1734,7 +1734,7 @@ def forward(self, input):
             return torch.tensor(3)
 
     graph_model1 = fx.symbolic_trace(Empty1())
-    mod1 = from_fx(graph_model1, [([10, 10], "float32")])
+    mod1 = from_fx(graph_model1, [([10, 10], "float32")]).mod
     assert len(mod1["main"].body.blocks) == 1
     assert len(mod1["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -1742,7 +1742,7 @@ def forward(self, input):
     assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
 
     graph_model2 = fx.symbolic_trace(Empty2())
-    mod2 = from_fx(graph_model2, [([10, 10], "float32")])
+    mod2 = from_fx(graph_model2, [([10, 10], "float32")]).mod
     assert len(mod2["main"].body.blocks) == 1
     assert len(mod2["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, relax.Constant)
@@ -1968,12 +1968,18 @@ class Type(Module):
         def forward(self, x):
             return x.type(torch.float32)
 
+    # type
+    class TypeFromAttr(Module):
+        def forward(self, x):
+            return x.type(x.getattr("dtype"))
+
    # astype
     class AsType(Module):
         def forward(self, x):
             return x.astype(torch.float32)
 
     verify_model(Type(), input_info, {}, expected1)
+    verify_model(TypeFromAttr(), input_info, {}, expected1)
     verify_model(AsType(), input_info, {}, expected1)
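
Additional context (not part of the patch): the sketch below, which only assumes a local PyTorch install (no TVM) and mirrors the new `TypeFromAttr` test case, shows why the dtype reaching the translator can be an `fx.Node` rather than a string or a `torch.dtype`.

```python
# Standalone illustration; requires only torch.
import torch
from torch import fx, nn


class TypeFromAttr(nn.Module):
    def forward(self, x):
        # Under symbolic tracing, `x.getattr("dtype")` is recorded as a
        # call_method node, so the value passed to `x.type(...)` is an
        # fx.Node instead of a concrete dtype.
        return x.type(x.getattr("dtype"))


graph_model = fx.symbolic_trace(TypeFromAttr())
for node in graph_model.graph.nodes:
    if node.op == "call_method" and node.target == "type":
        dtype_arg = node.args[1]
        # Prints <class 'torch.fx.node.Node'>, not str or torch.dtype.
        print(type(dtype_arg))
```

Resolving such a node through the importer's environment table (the `self.env` dict that the diff above already indexes by `fx.Node`, e.g. `x = self.env[node.args[0]]`) before the string/torch-dtype checks is what the new `env` parameter of `_convert_data_type` does.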