Skip to content

Commit

Permalink
[Unity][Fix] FX translating dtype (#14201)
Browse files Browse the repository at this point in the history
This PR fixes a bug of the current FX translator when dealing with
dtype.

Previously, the translator does not take the cases
```python
dtype = x.getattr("dtype")
```
into consideration. In this case, the dtype will be a fx.Node object,
while the translator assumes that the dtype is either a string or
a torch native datatype (e.g., torch.float32).

This PR fixes this by doing an environment table lookup before for all
dtypes.
  • Loading branch information
MasterJH5574 authored Mar 5, 2023
1 parent 85976ea commit fcb0127
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
15 changes: 9 additions & 6 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
# pylint: disable=import-outside-toplevel
"""PyTorch FX frontend of Relax."""
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from functools import reduce

import tvm
Expand Down Expand Up @@ -61,10 +61,13 @@ def _fetch_attr(model, target: str):
return attr_itr

@staticmethod
def _convert_data_type(input_type):
def _convert_data_type(input_type, env: Optional[Dict] = None):
"""converts the PyTorch scalar type input_type to a TVM dtype."""
import torch # type: ignore

if env is not None and input_type in env:
input_type = env[input_type]

input_type = input_type.lower() if isinstance(input_type, str) else input_type
if input_type in ["float", "float32", "torch.float32", torch.float32]:
return "float32"
Expand Down Expand Up @@ -247,7 +250,7 @@ def _arange(self, node: fx.node.Node) -> relax.Var:
start_end_step[2] = 1

if "dtype" in node.kwargs:
dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
elif any([isinstance(x, float) for x in start_end_step]):
dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype())
else:
Expand All @@ -256,7 +259,7 @@ def _arange(self, node: fx.node.Node) -> relax.Var:
return relax.const(np.arange(*start_end_step, dtype=dtype))

def _empty(self, node: fx.node.Node) -> relax.Var:
dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
return self.block_builder.emit(relax.op.zeros(node.args, dtype))

def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
Expand Down Expand Up @@ -334,7 +337,7 @@ def _half(self, node: fx.node.Node) -> relax.Var:

def _type(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
dtype = self._convert_data_type(node.args[1])
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
return self.block_builder.emit(relax.op.astype(x, dtype))

########## Linear Algebra ##########
Expand Down Expand Up @@ -565,7 +568,7 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = self.params[module.bias]
dtype = self._convert_data_type(str(module.running_mean.dtype))
dtype = TorchFXImporter._convert_data_type(str(module.running_mean.dtype))
running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype)
running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype)
eps = module.eps
Expand Down
14 changes: 10 additions & 4 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,7 @@ def forward(self, input):
return torch.arange(0, 20, dtype=torch.int32)

graph_model = fx.symbolic_trace(Arange())
mod = from_fx(graph_model, [([10, 10], "float32")])
mod = from_fx(graph_model, [([10, 10], "float32")]).mod
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
Expand All @@ -1707,7 +1707,7 @@ def forward(self, input):
return torch.empty((10, 10), dtype=torch.float32)

graph_model = fx.symbolic_trace(Empty())
mod = from_fx(graph_model, [([10, 10], "float32")])
mod = from_fx(graph_model, [([10, 10], "float32")]).mod
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
Expand All @@ -1734,15 +1734,15 @@ def forward(self, input):
return torch.tensor(3)

graph_model1 = fx.symbolic_trace(Empty1())
mod1 = from_fx(graph_model1, [([10, 10], "float32")])
mod1 = from_fx(graph_model1, [([10, 10], "float32")]).mod
assert len(mod1["main"].body.blocks) == 1
assert len(mod1["main"].body.blocks[0].bindings) == 1
assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, relax.Constant)
assert mod1["main"].body.blocks[0].bindings[0].value.data.shape == ()
assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"

graph_model2 = fx.symbolic_trace(Empty2())
mod2 = from_fx(graph_model2, [([10, 10], "float32")])
mod2 = from_fx(graph_model2, [([10, 10], "float32")]).mod
assert len(mod2["main"].body.blocks) == 1
assert len(mod2["main"].body.blocks[0].bindings) == 1
assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, relax.Constant)
Expand Down Expand Up @@ -1968,12 +1968,18 @@ class Type(Module):
def forward(self, x):
return x.type(torch.float32)

# type
class TypeFromAttr(Module):
def forward(self, x):
return x.type(x.getattr("dtype"))

# astype
class AsType(Module):
def forward(self, x):
return x.astype(torch.float32)

verify_model(Type(), input_info, {}, expected1)
verify_model(TypeFromAttr(), input_info, {}, expected1)
verify_model(AsType(), input_info, {}, expected1)


Expand Down

0 comments on commit fcb0127

Please sign in to comment.