[Fix] PT - convert BF16 tensor to float before calling .numpy() #1342

Merged
7 commits merged into mindee:main on Oct 12, 2023

Conversation

chunyuan-w
Contributor

.numpy() in PyTorch supports only a limited set of scalar types; see aten_to_numpy_dtype.
When running in BF16 with autocast, calling .numpy() here raises an error: TypeError: Got unsupported ScalarType BFloat16.
Converting the BF16 tensor to float before calling .numpy() fixes this error.
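
For illustration, a minimal sketch of the failure and the fix described above. The tensor `prob_map` is only a stand-in for a model output produced under autocast; it is not code from this PR.

```python
import numpy as np
import torch

# Stand-in for a model output produced under BF16 autocast (illustrative only).
prob_map = torch.rand(2, 3).to(torch.bfloat16)

try:
    prob_map.numpy()
    raised = False
except TypeError:
    # NumPy has no native bfloat16 dtype, so PyTorch rejects the conversion.
    raised = True

# Upcasting to float32 first lets the conversion go through.
as_numpy = prob_map.float().numpy()
print(raised, as_numpy.dtype)
```

The upcast copies the data, so it is worth applying only when the dtype actually requires it.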

@chunyuan-w chunyuan-w changed the title convert BF16 tensor to float before calling .numpy() [Fix] convert BF16 tensor to float before calling .numpy() Oct 10, 2023
@chunyuan-w
Contributor Author

Comment on lines 211 to 215
def need_conversion_to_float(dtype):
    # pytorch: torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
    return dtype in [torch.bfloat16]

numpy_dtype_converter = lambda x: x.float() if need_conversion_to_float(x.dtype) else x

Wouldn't directly checking dtype in [torch.bfloat16] be simpler?

Contributor Author

Updated as suggested.
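
A sketch of what the simplification suggested in this thread could look like, with the separate predicate folded into the converter. The name `numpy_dtype_converter` is reused from the snippet under review; the exact form adopted in the PR is not shown in this conversation, so treat the code below as an assumption.

```python
import torch


def numpy_dtype_converter(x: torch.Tensor) -> torch.Tensor:
    # Inline membership check; the list can grow if other dtypes
    # unsupported by .numpy() ever need the same treatment.
    return x.float() if x.dtype in [torch.bfloat16] else x


bf16 = torch.ones(2, dtype=torch.bfloat16)
fp16 = torch.ones(2, dtype=torch.float16)

converted = numpy_dtype_converter(bf16)   # upcast to float32
untouched = numpy_dtype_converter(fp16)   # returned as-is
```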

Contributor

@felixdittrich92 felixdittrich92 left a comment

Hi @chunyuan-w @jgong5 👋

Thanks for the fix 👍

Some points:

We should add a function for the conversion in:
https://github.com/mindee/doctr/blob/main/doctr/models/utils/pytorch.py
and for TF in
https://github.com/mindee/doctr/blob/main/doctr/models/utils/tensorflow.py
because I expect we need this fix in multiple places:

for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())

out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())

out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

Then a short test for the function in:
https://github.com/mindee/doctr/blob/main/tests/pytorch/test_models_utils_pt.py
and
https://github.com/mindee/doctr/blob/main/tests/tensorflow/test_models_utils_tf.py

Afterwards you can run
make style
make quality (sometimes it shows a typing issue in https://github.com/mindee/doctr/tree/main/doctr/models/artefacts which can be ignored)
make test-common
make test-torch
make test-tf
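
As a rough sketch of the shared helper and short test proposed above: the helper name `bf16_to_float32`, the dummy `postprocessor`, the `predict` wrapper, and the tensor shapes are all hypothetical and are not taken from doctr.

```python
import numpy as np
import torch


def bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: upcast only BF16, leave every other dtype untouched.
    return x.float() if x.dtype == torch.bfloat16 else x


def postprocessor(arr: np.ndarray) -> list:
    # Dummy stand-in for a model's postprocessor: one max score per sample.
    return [float(sample.max()) for sample in arr]


def predict(prob_map: torch.Tensor) -> list:
    # Same call pattern as the detection models quoted above, with the
    # conversion applied just before .numpy().
    channels_last = prob_map.detach().cpu().permute((0, 2, 3, 1))
    return postprocessor(bf16_to_float32(channels_last).numpy())


def test_bf16_to_float32() -> None:
    # Shape of a short unit test for the helper.
    assert bf16_to_float32(torch.zeros(2, dtype=torch.bfloat16)).dtype == torch.float32
    assert bf16_to_float32(torch.zeros(2, dtype=torch.float16)).dtype == torch.float16


test_bf16_to_float32()
preds = predict(torch.rand(2, 1, 4, 4).to(torch.bfloat16))
```

Keeping the dtype check in one helper means each call site only wraps the tensor it already passes to `.numpy()`, which is the motivation given above for placing it in the shared utils module.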

EDIT:

After double-checking, we also need the conversion for each recognition model (except CRNN)
e.g.:

out["preds"] = self.postprocessor(decoded_features)

For the detection models, I suggest converting the prob_map directly if needed
e.g.:

@felixdittrich92 felixdittrich92 added this to the 0.7.1 milestone Oct 10, 2023
@felixdittrich92 felixdittrich92 added the labels type: bug (Something isn't working), module: models (Related to doctr.models), ext: tests (Related to tests folder), framework: pytorch (Related to PyTorch backend), framework: tensorflow (Related to TensorFlow backend), and topic: text detection (Related to the task of text detection) on Oct 10, 2023
@felixdittrich92
Contributor

@chunyuan-w see: #1344

In your PR you can do the same for PyTorch and we are fine to merge 🤗

@felixdittrich92 felixdittrich92 added the label topic: text recognition (Related to the task of text recognition) and removed the label framework: tensorflow (Related to TensorFlow backend) on Oct 11, 2023
@chunyuan-w
Contributor Author

> @chunyuan-w see: #1344
>
> In your PR you can do the same for PyTorch and we are fine to merge 🤗

Thanks for the reference. Let me further refine this PR following #1344.

@chunyuan-w chunyuan-w changed the title [Fix] convert BF16 tensor to float before calling .numpy() [Fix] PT - convert BF16 tensor to float before calling .numpy() Oct 12, 2023
Contributor

@felixdittrich92 felixdittrich92 left a comment

Thanks for the fix @chunyuan-w 👍

Could you please add a short comment noting that it fixes the issue in torchbench? :)
@odulcy-mindee mypy fix applied in #1344

@felixdittrich92 felixdittrich92 merged commit 56c8356 into mindee:main Oct 12, 2023
67 of 68 checks passed
@chunyuan-w
Contributor Author

chunyuan-w commented Oct 12, 2023

> Thanks for the fix @chunyuan-w 👍
>
> Could you please add a short comment noting that it fixes the issue in torchbench? :) @odulcy-mindee mypy fix applied in #1344

Thanks for merging it!
I just submitted a draft PR to torchbench to update the doctr version in torchbench to include this fix:
pytorch/benchmark#1979

@felixdittrich92
Contributor

Thanks for the update 👍

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Oct 12, 2023
Summary:
Update the version of `doctr` to include the fix in mindee/doctr#1342 for BF16 mode.

Remove the change of `rapidfuzz==2.15.1` in `requirements.txt` (#1555) since the version has been set in the model repo in the updated version (mindee/doctr#1176).

Pull Request resolved: #1979

Reviewed By: aaronenyeshi

Differential Revision: D50242780

Pulled By: xuzhao9

fbshipit-source-id: d8ed9164d463a1217114408106b2c745431bd159