How do you load a custom checkpoint? #53
Hello!

I just went down this path of pain and suffering, so I want to post this for anyone else who wants to use this AI to train on their own data. It was a really painful experience that taught me how Detoxify works at the tensor level, lol.

My original idea after seeing this post was just to add a save line after `trainer.fit`:

```python
trainer.fit(model, data_loader, valid_data_loader)
torch.save({"config": model.config, "state_dict": model.state_dict()}, "model.pt")
```

The issue I kept running into is that you can't use `model.state_dict()` directly, because every key in the state dictionary is prefixed with `model.`, e.g. `model.bert.encoder.layer.8.output.LayerNorm.weight` needs to be converted to `bert.encoder.layer.8.output.LayerNorm.weight`. After translating every key in the state dictionary, I could successfully run the checkpoint method in Detoxify. You need to add this to the bottom of `train.py`:

```python
trainer.fit(model, data_loader, valid_data_loader)

# Strip the leading "model." prefix from every backbone key so the
# saved checkpoint matches the key names Detoxify expects.
statedict = {}
for param_tensor in model.state_dict():
    if "model.bert." in param_tensor:
        newname = param_tensor.replace("model.", "")
        statedict[newname] = model.state_dict()[param_tensor]

# Copy the classifier head weights over explicitly as well.
statedict["classifier.weight"] = model.state_dict()["model.classifier.weight"]
statedict["classifier.bias"] = model.state_dict()["model.classifier.bias"]

torch.save({"config": model.config, "state_dict": statedict}, "model.pt")
```

Then you can just import your model using Detoxify like below:

```python
ai = Detoxify(checkpoint="model.pt")
```

For other models like RoBERTa or ALBERT, you just need to replace the `bert` in the `if` statement above.
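The renaming step above is pure string manipulation on the state dict's keys, so it can be checked without any model at all. A minimal, framework-free sketch of the same idea, using a plain dict with hypothetical key names in place of a real `state_dict` (the helper name `strip_model_prefix` is my own, not part of Detoxify):

```python
def strip_model_prefix(state_dict, backbone="bert"):
    """Drop the leading 'model.' from backbone and classifier keys,
    mirroring the renaming described in the comment above."""
    out = {}
    for key, value in state_dict.items():
        if f"model.{backbone}." in key:
            out[key.replace("model.", "")] = value
    # The classifier head keys are copied over explicitly as well.
    out["classifier.weight"] = state_dict["model.classifier.weight"]
    out["classifier.bias"] = state_dict["model.classifier.bias"]
    return out

# Stand-in keys; a real state_dict would map these to tensors.
dummy = {
    "model.bert.encoder.layer.8.output.LayerNorm.weight": "w",
    "model.classifier.weight": "cw",
    "model.classifier.bias": "cb",
}
renamed = strip_model_prefix(dummy)
print(sorted(renamed))
```

Passing `backbone="roberta"` or `backbone="albert"` covers the other architectures mentioned above.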
Hello, I want to train the network on my own samples, but I'm finding it quite difficult.

Right now I edited `Toxic_comment_classification_BERT.json` to point to my own training and test CSVs. Then I have to edit `train.py` to manually save the model object inside `ToxicClassifier` at the end of training.

Then I have to load that file manually, instantiate a normal Detoxify instance, and replace its internal model object with the saved version to get it to work.

If I try to load a checkpoint generated at `saved\Jigsaw_BERT\lightning_logs\version_x\checkpoints\epoch=3-step=76.ckpt` with Detoxify, or try to instantiate Detoxify with the `checkpoint` parameter, or with a file generated by `torch.save(model)`, it always fails with an error.

What's the proper way of saving the checkpoint so it has the config and state dict with it? Or is my workaround the best way to use custom training data?
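For the `.ckpt` files under `lightning_logs`: PyTorch Lightning stores the weights inside the checkpoint dict under a `state_dict` key, with the same `model.`-prefixed names discussed in this thread, so such a checkpoint can be converted into the `{"config", "state_dict"}` layout rather than re-saved from inside `train.py`. A hedged sketch of that conversion, using plain dicts in place of `torch.load`/`torch.save` so the key handling stays visible (the function name and the sample config are my own illustrations, not Detoxify API):

```python
def lightning_to_detoxify(ckpt, config, backbone="bert"):
    """Convert a loaded Lightning checkpoint dict into the
    {"config", "state_dict"} structure this thread describes.
    In real use, ckpt would come from torch.load(<path to .ckpt>)
    and the returned dict would go to torch.save(..., "model.pt")."""
    raw = ckpt["state_dict"]  # Lightning nests the weights under this key
    state_dict = {
        key.replace("model.", "", 1): value
        for key, value in raw.items()
        if f"model.{backbone}." in key or key.startswith("model.classifier.")
    }
    return {"config": config, "state_dict": state_dict}

# Simulated checkpoint with string stand-ins for tensors.
ckpt = {
    "state_dict": {
        "model.bert.embeddings.word_embeddings.weight": "emb",
        "model.classifier.weight": "cw",
        "model.classifier.bias": "cb",
    },
    "epoch": 3,
}
converted = lightning_to_detoxify(ckpt, config={"arch": "bert-base-uncased"})
print(sorted(converted["state_dict"]))
```

This keeps extra Lightning bookkeeping (optimizer states, epoch counters) out of the saved file, which only needs the config and the renamed weights.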