Skip to content

Commit

Permalink
fix(pt): convert torch.__version__ to str when serializing (#4106)
Browse files Browse the repository at this point in the history
```py
>>> type(torch.__version__)
<class 'torch.torch_version.TorchVersion'>
```

This causes a YAML error:

```
  File "/home/jz748/anaconda3/lib/python3.10/site-packages/yaml/representer.py", line 231, in represent_undefined
    raise RepresenterError("cannot represent an object", data)
yaml.representer.RepresenterError: ('cannot represent an object', '2.3.1+cu121')
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
- Improved the serialization of the PyTorch version by ensuring it is
represented as a string, enhancing data clarity and consistency.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Sep 6, 2024
1 parent f4139fa commit 866726e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepmd/pt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def serialize_from_file(model_file: str) -> dict:
model_dict = model.serialize()
data = {
"backend": "PyTorch",
"pt_version": torch.__version__,
"pt_version": str(torch.__version__),
"model": model_dict,
"model_def_script": model_def_script,
"@variables": {},
Expand Down

0 comments on commit 866726e

Please sign in to comment.