diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 20f74a2ad5..d8d70085e3 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -15,7 +15,7 @@ from tensorflow.keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params from doctr.utils.repr import NestedObject from ...classification import mobilenet_v3_large @@ -241,7 +241,7 @@ def call( return out if return_model_output or target is None or return_preds: - prob_map = tf.math.sigmoid(logits) + prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits)) if return_model_output: out["out_map"] = prob_map diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 3ac436088e..0e49d93067 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -15,7 +15,7 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params from doctr.utils.repr import NestedObject from .base import LinkNetPostProcessor, _LinkNet @@ -229,7 +229,8 @@ def call( return out if return_model_output or target is None or return_preds: - prob_map = tf.math.sigmoid(logits) + prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits)) + if return_model_output: out["out_map"] = prob_map diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py index 63ad346c2f..190a12da63 100644 --- a/doctr/models/modules/transformer/pytorch.py +++ b/doctr/models/modules/transformer/pytorch.py @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: positional embeddings (batch, max_len, d_model) """ - x = x + self.pe[:, : x.size(1)] # type: ignore + x = x + self.pe[:, : x.size(1)] return self.dropout(x) diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py index b1e50f1ad9..daf0e56e58 100644 --- a/doctr/models/recognition/crnn/pytorch.py +++ b/doctr/models/recognition/crnn/pytorch.py @@ -249,7 +249,7 @@ def _crnn( _cfg["input_shape"] = kwargs["input_shape"] # Build the model - model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # type: ignore[arg-type] + model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 618a4c0e92..e00a3543ec 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -199,7 +199,7 @@ def call( w, h, c = transposed_feat.get_shape().as_list()[1:] # B x W x H x C --> B x W x H * C features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c)) - logits = self.decoder(features_seq, **kwargs) + logits = _bf16_to_numpy_dtype(self.decoder(features_seq, **kwargs)) out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 908c2e8b8f..bbae216f74 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -13,7 +13,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -183,6 +183,8 @@ def call( else: logits = self.decode(encoded, **kwargs) + logits = _bf16_to_numpy_dtype(logits) + if self.exportable: out["logits"] = logits return out diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 21a35605f5..8ef77af4cd 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -16,7 +16,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -390,6 +390,8 @@ def call( else: logits = self.decode_autoregressive(features, **kwargs) + logits = _bf16_to_numpy_dtype(logits) + out: Dict[str, tf.Tensor] = {} if self.exportable: out["logits"] = logits diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 6a688c7bac..a91cb8be6b 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -13,7 +13,7 @@ from doctr.utils.repr import NestedObject from ...classification import resnet31 -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -316,7 +316,9 @@ def call( if kwargs.get("training", False) and target is None: raise ValueError("Need to provide labels during training for teacher forcing") - decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs) + decoded_features = _bf16_to_numpy_dtype( + self.decoder(features, encoded, gt=None if target is None else gt, **kwargs) + ) out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 70c7325b3f..84eeb1303e 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -131,7 +131,7 @@ def call( logits = tf.reshape( self.head(features, **kwargs), (B, N, len(self.vocab) + 1) ) # (batch_size, max_length, vocab + 1) - decoded_features = logits[:, 1:] # remove cls_token + decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index 4e15fa628a..2179ba572b 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -25,6 +25,11 @@ def _copy_tensor(x: torch.Tensor) -> torch.Tensor: return x.clone().detach() +def _bf16_to_numpy_dtype(x: torch.Tensor) -> torch.Tensor: + # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + return x.float() if x.dtype == torch.bfloat16 else x + + def load_pretrained_params( model: nn.Module, url: Optional[str] = None, @@ -157,8 +162,3 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T ) logging.info(f"Model exported to {model_name}.onnx") return f"{model_name}.onnx" - - -def _bf16_to_numpy_dtype(x): - # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype - return x.float() if x.dtype == torch.bfloat16 else x diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 8490c09f11..98baf2d9d3 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -17,13 +17,25 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG) -__all__ = ["load_pretrained_params", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", "_copy_tensor"] +__all__ = [ + "load_pretrained_params", + "conv_sequence", + "IntermediateLayerGetter", + "export_model_to_onnx", + "_copy_tensor", + "_bf16_to_numpy_dtype", +] def _copy_tensor(x: tf.Tensor) -> tf.Tensor: return tf.identity(x) +def _bf16_to_numpy_dtype(x: tf.Tensor) -> tf.Tensor: + # Convert bfloat16 to float32 for numpy compatibility + return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x + + def load_pretrained_params( model: Model, url: Optional[str] = None, diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py index 2e256cacb8..5db792a91e 100644 --- a/tests/tensorflow/test_models_utils_tf.py +++ b/tests/tensorflow/test_models_utils_tf.py @@ -5,7 +5,13 @@ from tensorflow.keras import Sequential, layers from tensorflow.keras.applications import ResNet50 -from doctr.models.utils import IntermediateLayerGetter, _copy_tensor, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_to_numpy_dtype, + _copy_tensor, + conv_sequence, + load_pretrained_params, +) def test_copy_tensor(): @@ -14,6 +20,12 @@ def test_copy_tensor(): assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x)) +def test_bf16_to_numpy_dtype(): + x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16) + m = _bf16_to_numpy_dtype(x) + assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32))) + + def test_load_pretrained_params(tmpdir_factory): model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)]) # Retrieve this URL