Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] TF - add bf16 numpy dtype conversion #1344

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tensorflow.keras.applications import ResNet50

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params
from doctr.utils.repr import NestedObject

from ...classification import mobilenet_v3_large
Expand Down Expand Up @@ -241,7 +241,7 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = tf.math.sigmoid(logits)
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
Expand Down
5 changes: 3 additions & 2 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50
from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params
from doctr.utils.repr import NestedObject

from .base import LinkNetPostProcessor, _LinkNet
Expand Down Expand Up @@ -229,7 +229,8 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = tf.math.sigmoid(logits)
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/modules/transformer/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
positional embeddings (batch, max_len, d_model)
"""
x = x + self.pe[:, : x.size(1)] # type: ignore
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)


Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/crnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def _crnn(
_cfg["input_shape"] = kwargs["input_shape"]

# Build the model
model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # type: ignore[arg-type]
model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
# The number of classes is not the same as the number of classes in the pretrained model =>
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/crnn/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.datasets import VOCABS

from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
Expand Down Expand Up @@ -199,7 +199,7 @@ def call(
w, h, c = transposed_feat.get_shape().as_list()[1:]
# B x W x H x C --> B x W x H * C
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
logits = self.decoder(features_seq, **kwargs)
logits = _bf16_to_numpy_dtype(self.decoder(features_seq, **kwargs))

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 3 additions & 1 deletion doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
Expand Down Expand Up @@ -183,6 +183,8 @@ def call(
else:
logits = self.decode(encoded, **kwargs)

logits = _bf16_to_numpy_dtype(logits)

if self.exportable:
out["logits"] = logits
return out
Expand Down
4 changes: 3 additions & 1 deletion doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

from ...classification import vit_s
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
Expand Down Expand Up @@ -390,6 +390,8 @@ def call(
else:
logits = self.decode_autoregressive(features, **kwargs)

logits = _bf16_to_numpy_dtype(logits)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
out["logits"] = logits
Expand Down
6 changes: 4 additions & 2 deletions doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.utils.repr import NestedObject

from ...classification import resnet31
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]
Expand Down Expand Up @@ -316,7 +316,9 @@ def call(
if kwargs.get("training", False) and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")

decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
decoded_features = _bf16_to_numpy_dtype(
self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
Expand Down Expand Up @@ -131,7 +131,7 @@ def call(
logits = tf.reshape(
self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
) # (batch_size, max_length, vocab + 1)
decoded_features = logits[:, 1:] # remove cls_token
decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
10 changes: 5 additions & 5 deletions doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
return x.clone().detach()


def _bf16_to_numpy_dtype(x: torch.Tensor) -> torch.Tensor:
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x


def load_pretrained_params(
model: nn.Module,
url: Optional[str] = None,
Expand Down Expand Up @@ -157,8 +162,3 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
)
logging.info(f"Model exported to {model_name}.onnx")
return f"{model_name}.onnx"


def _bf16_to_numpy_dtype(x):
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x
14 changes: 13 additions & 1 deletion doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,25 @@
logging.getLogger("tensorflow").setLevel(logging.DEBUG)


__all__ = ["load_pretrained_params", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", "_copy_tensor"]
__all__ = [
"load_pretrained_params",
"conv_sequence",
"IntermediateLayerGetter",
"export_model_to_onnx",
"_copy_tensor",
"_bf16_to_numpy_dtype",
]


def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
return tf.identity(x)


def _bf16_to_numpy_dtype(x: tf.Tensor) -> tf.Tensor:
# Convert bfloat16 to float32 for numpy compatibility
return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x


def load_pretrained_params(
model: Model,
url: Optional[str] = None,
Expand Down
14 changes: 13 additions & 1 deletion tests/tensorflow/test_models_utils_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from tensorflow.keras import Sequential, layers
from tensorflow.keras.applications import ResNet50

from doctr.models.utils import IntermediateLayerGetter, _copy_tensor, conv_sequence, load_pretrained_params
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_to_numpy_dtype,
_copy_tensor,
conv_sequence,
load_pretrained_params,
)


def test_copy_tensor():
Expand All @@ -14,6 +20,12 @@ def test_copy_tensor():
assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x))


def test_bf16_to_numpy_dtype():
x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16)
m = _bf16_to_numpy_dtype(x)
assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32)))


def test_load_pretrained_params(tmpdir_factory):
model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)])
# Retrieve this URL
Expand Down
Loading