[Fix] PT - convert BF16 tensor to float before calling .numpy() (#1342)
chunyuan-w authored Oct 12, 2023
1 parent 50d65d7 commit 56c8356
Showing 8 changed files with 40 additions and 12 deletions.
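
Context for the change: when these models run in bfloat16 (e.g. under torch.autocast on CPU), the tensors handed to the NumPy-based post-processors are bfloat16, a dtype PyTorch cannot convert to NumPy. A minimal sketch of the failure and the workaround, illustrative rather than taken from the diff:

    import torch

    x = torch.randn(2, 2, dtype=torch.bfloat16)
    # x.numpy() would raise: TypeError: Got unsupported ScalarType BFloat16
    y = x.float().numpy()  # upcast to float32 first, then convert

The commit therefore routes every tensor that later reaches .numpy() through a small _bf16_to_numpy_dtype helper, added below in doctr/models/utils/pytorch.py.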
4 changes: 2 additions & 2 deletions doctr/models/detection/differentiable_binarization/pytorch.py
@@ -16,7 +16,7 @@
 from doctr.file_utils import CLASS_NAME

 from ...classification import mobilenet_v3_large
-from ...utils import load_pretrained_params
+from ...utils import _bf16_to_numpy_dtype, load_pretrained_params
 from .base import DBPostProcessor, _DBNet

 __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"]
@@ -203,7 +203,7 @@ def forward(
             return out

         if return_model_output or target is None or return_preds:
-            prob_map = torch.sigmoid(logits)
+            prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits))

         if return_model_output:
             out["out_map"] = prob_map
4 changes: 2 additions & 2 deletions doctr/models/detection/linknet/pytorch.py
@@ -14,7 +14,7 @@
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification import resnet18, resnet34, resnet50

-from ...utils import load_pretrained_params
+from ...utils import _bf16_to_numpy_dtype, load_pretrained_params
 from .base import LinkNetPostProcessor, _LinkNet

 __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
@@ -175,7 +175,7 @@ def forward(
             return out

         if return_model_output or target is None or return_preds:
-            prob_map = torch.sigmoid(logits)
+            prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits))
         if return_model_output:
             out["out_map"] = prob_map

4 changes: 3 additions & 1 deletion doctr/models/recognition/master/pytorch.py
@@ -15,7 +15,7 @@
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding

-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor

 __all__ = ["MASTER", "master"]
@@ -195,6 +195,8 @@ def forward(
         else:
             logits = self.decode(encoded)

+        logits = _bf16_to_numpy_dtype(logits)
+
         if self.exportable:
             out["logits"] = logits
             return out
4 changes: 3 additions & 1 deletion doctr/models/recognition/parseq/pytorch.py
@@ -18,7 +18,7 @@
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

 from ...classification import vit_s
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor

 __all__ = ["PARSeq", "parseq"]
@@ -362,6 +362,8 @@ def forward(
         else:
             logits = self.decode_autoregressive(features)

+        logits = _bf16_to_numpy_dtype(logits)
+
         out: Dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
4 changes: 2 additions & 2 deletions doctr/models/recognition/sar/pytorch.py
@@ -14,7 +14,7 @@
 from doctr.datasets import VOCABS

 from ...classification import resnet31
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["SAR", "sar_resnet31"]
@@ -249,7 +249,7 @@ def forward(
         if self.training and target is None:
             raise ValueError("Need to provide labels during training for teacher forcing")

-        decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
+        decoded_features = _bf16_to_numpy_dtype(self.decoder(features, encoded, gt=None if target is None else gt))

         out: Dict[str, Any] = {}
         if self.exportable:
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/pytorch.py
@@ -14,7 +14,7 @@
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils.pytorch import load_pretrained_params
+from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -95,7 +95,7 @@ def forward(
         B, N, E = features.size()
         features = features.reshape(B * N, E)
         logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
-        decoded_features = logits[:, 1:]  # remove cls_token
+        decoded_features = _bf16_to_numpy_dtype(logits[:, 1:])  # remove cls_token

         out: Dict[str, Any] = {}
         if self.exportable:
14 changes: 13 additions & 1 deletion doctr/models/utils/pytorch.py
@@ -11,7 +11,14 @@

 from doctr.utils.data import download_from_url

-__all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor"]
+__all__ = [
+    "load_pretrained_params",
+    "conv_sequence_pt",
+    "set_device_and_dtype",
+    "export_model_to_onnx",
+    "_copy_tensor",
+    "_bf16_to_numpy_dtype",
+]


 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
@@ -150,3 +157,8 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor
     )
     logging.info(f"Model exported to {model_name}.onnx")
     return f"{model_name}.onnx"
+
+
+def _bf16_to_numpy_dtype(x: torch.Tensor) -> torch.Tensor:
+    # bfloat16 is not supported by Tensor.numpy() (torch/csrc/utils/tensor_numpy.cpp: aten_to_numpy_dtype)
+    return x.float() if x.dtype == torch.bfloat16 else x
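
A usage sketch (tensor shape and variable names are illustrative, not from the diff): the helper is a no-op for every dtype except bfloat16, so callers can apply it unconditionally right before a NumPy round-trip.

    import torch
    from doctr.models.utils import _bf16_to_numpy_dtype

    prob_map = torch.sigmoid(torch.randn(1, 256, 256, dtype=torch.bfloat16))
    np_map = _bf16_to_numpy_dtype(prob_map).numpy()  # float32 ndarray; raw bf16 would raise
    assert np_map.dtype.name == "float32"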
14 changes: 13 additions & 1 deletion tests/pytorch/test_models_utils_pt.py
@@ -4,7 +4,13 @@
 import torch
 from torch import nn

-from doctr.models.utils import _copy_tensor, conv_sequence_pt, load_pretrained_params, set_device_and_dtype
+from doctr.models.utils import (
+    _bf16_to_numpy_dtype,
+    _copy_tensor,
+    conv_sequence_pt,
+    load_pretrained_params,
+    set_device_and_dtype,
+)


 def test_copy_tensor():
@@ -52,3 +58,9 @@ def test_set_device_and_dtype():
     model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16)
     assert model[0].weight.dtype == torch.float16
     assert batches[0].dtype == torch.float16
+
+
+def test_bf16_to_numpy_dtype():
+    x = torch.randn([2, 2], dtype=torch.bfloat16)
+    converted_x = _bf16_to_numpy_dtype(x)
+    assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float())
