Running pt2model.py (shown below) gives the following error:
Traceback (most recent call last):
  File "pt2model.py", line 11, in <module>
    checkpoint = torch.jit.load(args.checkpoint_path)
  File "/share/apps/anaconda3/envs/mace1/lib/python3.8/site-packages/torch/jit/_serialization.py", line 158, in load
    cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found
The command I use is:
python pt2model.py checkpoints/MACE-OFF23_medium_sol-retrain_run-112_epoch-454.pt MACE-OFF23_medium_sol-retrain.model test.model
where the first argument is a .pt file in checkpoints/ and the second is the .model file that is generated when training finishes.
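The error message names constants.pkl, a file that TorchScript archives written by torch.jit.save contain but ordinary torch.save files do not. That suggests the .pt checkpoint is a plain torch.save file, which torch.jit.load cannot open. Below is a minimal sketch for checking which loader a file probably needs. It assumes the zip layout current PyTorch versions write, which is an internal detail rather than a documented API; is_torchscript_archive and suggested_loader are hypothetical helper names.

```python
import zipfile


def is_torchscript_archive(path: str) -> bool:
    """Heuristically report whether `path` looks like a TorchScript archive.

    Assumption: files written by torch.jit.save contain a `constants.pkl`
    entry (the file the traceback says is missing), while files written by
    torch.save in the zip format (PyTorch >= 1.6) do not. This relies on
    PyTorch's internal archive layout, not on a documented API.
    """
    if not zipfile.is_zipfile(path):
        # Legacy (pre-1.6) torch.save files are not zip archives at all.
        return False
    with zipfile.ZipFile(path) as archive:
        names = archive.namelist()
    # Entries sit under a top-level folder, e.g. "archive/constants.pkl".
    return any(name.split("/")[-1] == "constants.pkl" for name in names)


def suggested_loader(path: str) -> str:
    """Return the name of the loader that is likely to open `path`."""
    return "torch.jit.load" if is_torchscript_archive(path) else "torch.load"
```

For example, calling suggested_loader on the checkpoint and on the .model file shows whether either of them is actually a TorchScript archive before choosing between torch.load and torch.jit.load.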
The following is my pt2model.py:
import torch
import argparse

parser = argparse.ArgumentParser(description="Load model from checkpoint and save to a new path.")
parser.add_argument('checkpoint_path', type=str, help="Path to the checkpoint file.")
parser.add_argument('model_path', type=str, help="Path to the model file.")
parser.add_argument('new_model_path', type=str, help="Path where the new model will be saved.")
args = parser.parse_args()

# Load the training checkpoint (this is the call that raises the RuntimeError above)
checkpoint = torch.jit.load(args.checkpoint_path)
# Load the model written at the end of training
model = torch.jit.load(args.model_path)
# Copy the checkpointed weights into the model and save the result
model.load_state_dict(checkpoint["model"])
torch.save(model, args.new_model_path)
print(f"Model saved to {args.new_model_path}")
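A minimal sketch of a possible fix, assuming both files were written with torch.save rather than torch.jit.save: the checkpoint as a dict whose "model" entry is a state_dict (as the original script already assumes), and the .model file as a full pickled model. Under those assumptions both should be opened with torch.load instead of torch.jit.load. restore_checkpoint is a hypothetical helper name, and unpickling the .model file also requires the mace package that defined the model to be importable in the same environment.

```python
import torch


def restore_checkpoint(
    checkpoint_path: str, model_path: str, new_model_path: str
) -> torch.nn.Module:
    """Load checkpoint weights into a saved model and write the result to a new file."""
    # Assumption: the .pt checkpoint was written with torch.save (not torch.jit.save),
    # so torch.jit.load cannot read it, but torch.load can.
    # weights_only=False unpickles arbitrary objects: only use it on trusted files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # Assumption: the .model file is a full model pickled with torch.save.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    # As in the original script, the checkpoint is assumed to be a dict whose
    # "model" entry holds the model's state_dict.
    model.load_state_dict(checkpoint["model"])
    torch.save(model, new_model_path)
    return model
```

To use it from pt2model.py, keep the argparse setup and replace the last four lines with restore_checkpoint(args.checkpoint_path, args.model_path, args.new_model_path) followed by the print. map_location="cpu" lets the conversion run without a GPU; the saved model's tensors then live on the CPU.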
I wrote this script following the discussion at https://github.com/ACEsuit/mace/discussions/564.

Environment:
Python 3.8.9
PyTorch 2.3.1