loading state_dict for BertModel #25

Open
AwsomeName opened this issue May 4, 2019 · 12 comments

Comments

@AwsomeName

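Hello, and thank you very much for your code.
While debugging, I downloaded Google's chinese_base archive, unpacked it, and used the method from https://github.com/huggingface/pytorch-pretrained-BERT/tree/1de35b624b9d7998feb4d518e4f7e8e53abac4e1 to convert it to a .bin file. I also tried the Chinese version linked from https://github.com/NLPScott/bert-Chinese-classification-task/issues/13 instead. Both hit the same error when the model is loaded:
RuntimeError: Error(s) in loading state_dict for BertModel:
Missing key(s) in state_dict: "embeddings.word_embeddings.weight",
The parameter names evidently do not match, presumably because they were renamed at some point. I can't resolve this myself; could you take a look?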
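
One way to see exactly which names disagree is to diff the two sets of keys before calling `load_state_dict`. The following is a minimal sketch in plain Python; `diff_keys` is a hypothetical helper (not part of PyTorch), and the sample key names are illustrative stand-ins based on the error message above.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring the error message layout."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # model wants these, file lacks them
    unexpected = sorted(checkpoint_keys - model_keys)   # file has these, model does not
    return missing, unexpected


# With real objects the inputs would be model.state_dict().keys() and
# torch.load(path, map_location='cpu').keys(); plain lists are used here.
model_keys = ["embeddings.word_embeddings.weight", "embeddings.LayerNorm.gamma"]
checkpoint_keys = [
    "bert.embeddings.word_embeddings.weight",
    "bert.embeddings.LayerNorm.gamma",
]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

Reading the two lists side by side usually makes the pattern obvious (here, an extra `bert.` prefix in the file).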

@ylf4910

ylf4910 commented May 28, 2019

Has this been resolved on your side? I'm hitting the same error and would appreciate any pointers.

@zilinly

zilinly commented May 28, 2019

@AwsomeName Same question; I'm running into this problem too.

@rubby33

rubby33 commented Jun 3, 2019

@NLPScott Could you please take a look? Both the checkpoint I converted myself and the one hosted on AWS raise this error. Thanks.

@xjmxmt

xjmxmt commented Jun 12, 2019

Just rename the keys of the state_dict dictionary by hand and it works.

@dawn2034

Missing key(s) in state_dict:
"embeddings.word_embeddings.weight",
"embeddings.position_embeddings.weight",
"embeddings.token_type_embeddings.weight",
"embeddings.LayerNorm.gamma",
"embeddings.LayerNorm.beta",......

Unexpected key(s) in state_dict:
"bert.embeddings.word_embeddings.weight",
"bert.embeddings.position_embeddings.weight",
"bert.embeddings.token_type_embeddings.weight",
"bert.embeddings.LayerNorm.weight",
"bert.embeddings.LayerNorm.bias"......

The current model and the state_dict simply use different key names. The following code shows one way to rename the keys:

        state_dict = torch.load(pretrained_model_file, map_location='cpu')
        old_keys = []
        new_keys = []
        for key in list(state_dict.keys()):
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if 'bert.' in key:
                new_key = key.replace('bert.', '')
            if 'cls.' in key:
                state_dict.pop(key)
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)
        model.bert.load_state_dict(state_dict)
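
Going only by the two key listings in this comment, the file's keys carry a `bert.` prefix and use `LayerNorm.weight`/`LayerNorm.bias` where the model expects `LayerNorm.gamma`/`LayerNorm.beta`. Below is a minimal sketch of a mapping in that direction. `to_model_key` is a hypothetical helper, and the direction of the LayerNorm rename is an assumption read off the listings; it is the opposite of the `gamma` -> `weight` replacement in the snippet above, so which one applies depends on the names the model class actually expects.

```python
def to_model_key(checkpoint_key):
    """Map a checkpoint key to the name the model in the listing expects."""
    key = checkpoint_key
    if key.startswith("bert."):
        key = key[len("bert."):]  # drop the wrapper prefix
    # Apply the LayerNorm rename on top of the prefix strip, not instead of it.
    if ".LayerNorm.weight" in key:
        key = key.replace(".LayerNorm.weight", ".LayerNorm.gamma")
    elif ".LayerNorm.bias" in key:
        key = key.replace(".LayerNorm.bias", ".LayerNorm.beta")
    return key


# The "Unexpected key(s)" entries quoted in the comment above.
unexpected = [
    "bert.embeddings.word_embeddings.weight",
    "bert.embeddings.position_embeddings.weight",
    "bert.embeddings.token_type_embeddings.weight",
    "bert.embeddings.LayerNorm.weight",
    "bert.embeddings.LayerNorm.bias",
]
renamed = [to_model_key(k) for k in unexpected]
print(renamed)
```

Applied to the five quoted "Unexpected" keys, this yields exactly the five quoted "Missing" keys.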

@ARDUJS

ARDUJS commented Sep 22, 2019

#25 Thanks a lot, that is the right answer. Problem solved.

@ARDUJS

ARDUJS commented Sep 22, 2019

Thanks a lot, that is the right answer. Problem solved.

@754563116

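#25 Thanks a lot, that is the right answer. Problem solved.

Hello, did this actually work for you?
I followed the method posted above and got a new error: state_dict[new_key] = state_dict.pop(old_key) KeyError: 'cls.predictions.transform.LayerNorm.gamma'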
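
The KeyError follows from the posted loop: a key such as `cls.predictions.transform.LayerNorm.gamma` matches both the `'gamma'` branch (so it is queued for renaming) and the `'cls.'` branch (so it is popped immediately). The second loop then tries to pop it again. A minimal sketch of one possible fix, skipping the rest of the loop body once a `cls.` key has been dropped; the toy dictionary stands in for a real checkpoint and its values are placeholders.

```python
def rename_keys(state_dict):
    """Rename keys in place, dropping the 'cls.' heads without queuing them."""
    old_keys, new_keys = [], []
    for key in list(state_dict.keys()):
        if "cls." in key:
            state_dict.pop(key)
            continue  # without this, the popped key is still queued below
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if "bert." in key:
            new_key = key.replace("bert.", "")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict


# Toy stand-in for a checkpoint: values would normally be tensors.
toy = {
    "bert.pooler.dense.weight": 1,
    "cls.predictions.transform.LayerNorm.gamma": 2,
    "cls.predictions.bias": 3,
}
rename_keys(toy)
print(toy)
```

Keys like `cls.predictions.bias` never triggered the error because no other branch matched them, which is why only the LayerNorm entries of the head show up in the traceback.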

@754563116

Your method gave me a new error, so I modified it a bit further. The code is as follows:

        state_dict = torch.load(path, map_location='cpu')
        old_keys = []
        new_keys = []
        for key in list(state_dict.keys()):
            # print(key)
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if 'bert.' in key:
                new_key = key.replace('bert.', '')
            if 'cls.' in key:
                # Rename instead of popping, so the key still exists in the loop below.
                new_key = key.replace('cls.', '')
                # state_dict.pop(key)
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            # print('old_key', old_key)
            # print('new_key', new_key)
            state_dict[new_key] = state_dict.pop(old_key)
        if args.init_checkpoint is not None:
            # strict=False skips missing and unexpected keys instead of raising.
            model.bert.load_state_dict(state_dict, strict=False)

The key part is strict=False; it took me a long time to find that.
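
Because `strict=False` silences every mismatch, it can also hide encoder weights that were never loaded. In recent PyTorch releases `load_state_dict` returns an object with `missing_keys` and `unexpected_keys`, which can be checked afterwards. The sketch below is plain Python so it runs without a model; `check_nonstrict_load` and its default prefixes are hypothetical, chosen on the assumption that the leftover head keys look like `predictions.*` and `seq_relationship.*` once `cls.` has been stripped as in the snippet above.

```python
def check_nonstrict_load(missing_keys, unexpected_keys,
                         ignorable_prefixes=("predictions.", "seq_relationship.")):
    """Raise if a non-strict load skipped anything other than the ignorable heads."""
    prefixes = tuple(ignorable_prefixes)  # str.startswith needs a tuple
    bad_unexpected = [k for k in unexpected_keys if not k.startswith(prefixes)]
    if missing_keys or bad_unexpected:
        raise ValueError(
            f"weights not loaded: missing={list(missing_keys)}, "
            f"unexpected={bad_unexpected}"
        )


# Intended use with a real model (assumes a PyTorch version that returns these fields):
#   result = model.bert.load_state_dict(state_dict, strict=False)
#   check_nonstrict_load(result.missing_keys, result.unexpected_keys)

# Only head weights left over: passes silently.
check_nonstrict_load([], ["predictions.bias", "seq_relationship.weight"])

# An encoder weight that was never filled in: raises.
try:
    check_nonstrict_load(["embeddings.LayerNorm.gamma"], [])
except ValueError as err:
    caught = str(err)
    print(caught)
```

If the check raises, the model is running with randomly initialised values for the listed parameters, even though the load itself reported no error.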

@ARDUJS

ARDUJS commented Oct 12, 2019 via email

@754563116

This is how I handled it:

state_dict = torch.load(args.init_checkpoint, map_location='cpu')
new_keys = [
    "embeddings.word_embeddings.weight",
    "embeddings.position_embeddings.weight",
    "embeddings.token_type_embeddings.weight",
    "embeddings.LayerNorm.gamma", "embeddings.LayerNorm.beta",
    "encoder.layer.0.attention.self.query.weight", "encoder.layer.0.attention.self.query.bias",
    "encoder.layer.0.attention.self.key.weight", "encoder.layer.0.attention.self.key.bias",
    "encoder.layer.0.attention.self.value.weight", "encoder.layer.0.attention.self.value.bias",
    "encoder.layer.0.attention.output.dense.weight", "encoder.layer.0.attention.output.dense.bias",
    "encoder.layer.0.attention.output.LayerNorm.gamma", "encoder.layer.0.attention.output.LayerNorm.beta",
    "encoder.layer.0.intermediate.dense.weight", "encoder.layer.0.intermediate.dense.bias",
    "encoder.layer.0.output.dense.weight", "encoder.layer.0.output.dense.bias",
    "encoder.layer.0.output.LayerNorm.gamma", "encoder.layer.0.output.LayerNorm.beta",
    "encoder.layer.1.attention.self.query.weight", "encoder.layer.1.attention.self.query.bias",
    "encoder.layer.1.attention.self.key.weight", "encoder.layer.1.attention.self.key.bias",
    "encoder.layer.1.attention.self.value.weight", "encoder.layer.1.attention.self.value.bias",
    "encoder.layer.1.attention.output.dense.weight", "encoder.layer.1.attention.output.dense.bias",
    "encoder.layer.1.attention.output.LayerNorm.gamma", "encoder.layer.1.attention.output.LayerNorm.beta",
    "encoder.layer.1.intermediate.dense.weight", "encoder.layer.1.intermediate.dense.bias",
    "encoder.layer.1.output.dense.weight", "encoder.layer.1.output.dense.bias",
    "encoder.layer.1.output.LayerNorm.gamma", "encoder.layer.1.output.LayerNorm.beta",
    "encoder.layer.2.attention.self.query.weight", "encoder.layer.2.attention.self.query.bias",
    "encoder.layer.2.attention.self.key.weight", "encoder.layer.2.attention.self.key.bias",
    "encoder.layer.2.attention.self.value.weight", "encoder.layer.2.attention.self.value.bias",
    "encoder.layer.2.attention.output.dense.weight", "encoder.layer.2.attention.output.dense.bias",
    "encoder.layer.2.attention.output.LayerNorm.gamma", "encoder.layer.2.attention.output.LayerNorm.beta",
    "encoder.layer.2.intermediate.dense.weight", "encoder.layer.2.intermediate.dense.bias",
    "encoder.layer.2.output.dense.weight", "encoder.layer.2.output.dense.bias",
    "encoder.layer.2.output.LayerNorm.gamma", "encoder.layer.2.output.LayerNorm.beta",
    "encoder.layer.3.attention.self.query.weight", "encoder.layer.3.attention.self.query.bias",
    "encoder.layer.3.attention.self.key.weight", "encoder.layer.3.attention.self.key.bias",
    "encoder.layer.3.attention.self.value.weight", "encoder.layer.3.attention.self.value.bias",
    "encoder.layer.3.attention.output.dense.weight", "encoder.layer.3.attention.output.dense.bias",
    "encoder.layer.3.attention.output.LayerNorm.gamma", "encoder.layer.3.attention.output.LayerNorm.beta",
    "encoder.layer.3.intermediate.dense.weight", "encoder.layer.3.intermediate.dense.bias",
    "encoder.layer.3.output.dense.weight", "encoder.layer.3.output.dense.bias",
    "encoder.layer.3.output.LayerNorm.gamma", "encoder.layer.3.output.LayerNorm.beta",
    "encoder.layer.4.attention.self.query.weight", "encoder.layer.4.attention.self.query.bias",
    "encoder.layer.4.attention.self.key.weight", "encoder.layer.4.attention.self.key.bias",
    "encoder.layer.4.attention.self.value.weight", "encoder.layer.4.attention.self.value.bias",
    "encoder.layer.4.attention.output.dense.weight", "encoder.layer.4.attention.output.dense.bias",
    "encoder.layer.4.attention.output.LayerNorm.gamma", "encoder.layer.4.attention.output.LayerNorm.beta",
    "encoder.layer.4.intermediate.dense.weight", "encoder.layer.4.intermediate.dense.bias",
    "encoder.layer.4.output.dense.weight", "encoder.layer.4.output.dense.bias",
    "encoder.layer.4.output.LayerNorm.gamma", "encoder.layer.4.output.LayerNorm.beta",
    "encoder.layer.5.attention.self.query.weight", "encoder.layer.5.attention.self.query.bias",
    "encoder.layer.5.attention.self.key.weight", "encoder.layer.5.attention.self.key.bias",
    "encoder.layer.5.attention.self.value.weight", "encoder.layer.5.attention.self.value.bias",
    "encoder.layer.5.attention.output.dense.weight", "encoder.layer.5.attention.output.dense.bias",
    "encoder.layer.5.attention.output.LayerNorm.gamma", "encoder.layer.5.attention.output.LayerNorm.beta",
    "encoder.layer.5.intermediate.dense.weight", "encoder.layer.5.intermediate.dense.bias",
    "encoder.layer.5.output.dense.weight", "encoder.layer.5.output.dense.bias",
    "encoder.layer.5.output.LayerNorm.gamma", "encoder.layer.5.output.LayerNorm.beta",
    "encoder.layer.6.attention.self.query.weight", "encoder.layer.6.attention.self.query.bias",
    "encoder.layer.6.attention.self.key.weight", "encoder.layer.6.attention.self.key.bias",
    "encoder.layer.6.attention.self.value.weight", "encoder.layer.6.attention.self.value.bias",
    "encoder.layer.6.attention.output.dense.weight", "encoder.layer.6.attention.output.dense.bias",
    "encoder.layer.6.attention.output.LayerNorm.gamma", "encoder.layer.6.attention.output.LayerNorm.beta",
    "encoder.layer.6.intermediate.dense.weight", "encoder.layer.6.intermediate.dense.bias",
    "encoder.layer.6.output.dense.weight", "encoder.layer.6.output.dense.bias",
    "encoder.layer.6.output.LayerNorm.gamma", "encoder.layer.6.output.LayerNorm.beta",
    "encoder.layer.7.attention.self.query.weight", "encoder.layer.7.attention.self.query.bias",
    "encoder.layer.7.attention.self.key.weight", "encoder.layer.7.attention.self.key.bias",
    "encoder.layer.7.attention.self.value.weight", "encoder.layer.7.attention.self.value.bias",
    "encoder.layer.7.attention.output.dense.weight", "encoder.layer.7.attention.output.dense.bias",
    "encoder.layer.7.attention.output.LayerNorm.gamma", "encoder.layer.7.attention.output.LayerNorm.beta",
    "encoder.layer.7.intermediate.dense.weight", "encoder.layer.7.intermediate.dense.bias",
    "encoder.layer.7.output.dense.weight", "encoder.layer.7.output.dense.bias",
    "encoder.layer.7.output.LayerNorm.gamma", "encoder.layer.7.output.LayerNorm.beta",
    "encoder.layer.8.attention.self.query.weight", "encoder.layer.8.attention.self.query.bias",
    "encoder.layer.8.attention.self.key.weight", "encoder.layer.8.attention.self.key.bias",
    "encoder.layer.8.attention.self.value.weight", "encoder.layer.8.attention.self.value.bias",
    "encoder.layer.8.attention.output.dense.weight", "encoder.layer.8.attention.output.dense.bias",
    "encoder.layer.8.attention.output.LayerNorm.gamma", "encoder.layer.8.attention.output.LayerNorm.beta",
    "encoder.layer.8.intermediate.dense.weight", "encoder.layer.8.intermediate.dense.bias",
    "encoder.layer.8.output.dense.weight", "encoder.layer.8.output.dense.bias",
    "encoder.layer.8.output.LayerNorm.gamma", "encoder.layer.8.output.LayerNorm.beta",
    "encoder.layer.9.attention.self.query.weight", "encoder.layer.9.attention.self.query.bias",
    "encoder.layer.9.attention.self.key.weight", "encoder.layer.9.attention.self.key.bias",
    "encoder.layer.9.attention.self.value.weight", "encoder.layer.9.attention.self.value.bias",
    "encoder.layer.9.attention.output.dense.weight", "encoder.layer.9.attention.output.dense.bias",
    "encoder.layer.9.attention.output.LayerNorm.gamma", "encoder.layer.9.attention.output.LayerNorm.beta",
    "encoder.layer.9.intermediate.dense.weight", "encoder.layer.9.intermediate.dense.bias",
    "encoder.layer.9.output.dense.weight", "encoder.layer.9.output.dense.bias",
    "encoder.layer.9.output.LayerNorm.gamma", "encoder.layer.9.output.LayerNorm.beta",
    "encoder.layer.10.attention.self.query.weight", "encoder.layer.10.attention.self.query.bias",
    "encoder.layer.10.attention.self.key.weight", "encoder.layer.10.attention.self.key.bias",
    "encoder.layer.10.attention.self.value.weight", "encoder.layer.10.attention.self.value.bias",
    "encoder.layer.10.attention.output.dense.weight", "encoder.layer.10.attention.output.dense.bias",
    "encoder.layer.10.attention.output.LayerNorm.gamma", "encoder.layer.10.attention.output.LayerNorm.beta",
    "encoder.layer.10.intermediate.dense.weight", "encoder.layer.10.intermediate.dense.bias",
    "encoder.layer.10.output.dense.weight", "encoder.layer.10.output.dense.bias",
    "encoder.layer.10.output.LayerNorm.gamma", "encoder.layer.10.output.LayerNorm.beta",
    "encoder.layer.11.attention.self.query.weight", "encoder.layer.11.attention.self.query.bias",
    "encoder.layer.11.attention.self.key.weight", "encoder.layer.11.attention.self.key.bias",
    "encoder.layer.11.attention.self.value.weight", "encoder.layer.11.attention.self.value.bias",
    "encoder.layer.11.attention.output.dense.weight", "encoder.layer.11.attention.output.dense.bias",
    "encoder.layer.11.attention.output.LayerNorm.gamma", "encoder.layer.11.attention.output.LayerNorm.beta",
    "encoder.layer.11.intermediate.dense.weight", "encoder.layer.11.intermediate.dense.bias",
    "encoder.layer.11.output.dense.weight", "encoder.layer.11.output.dense.bias",
    "encoder.layer.11.output.LayerNorm.gamma", "encoder.layer.11.output.LayerNorm.beta",
    "pooler.dense.weight", "pooler.dense.bias",
]
old_keys = [
    "bert.embeddings.word_embeddings.weight",
    "bert.embeddings.position_embeddings.weight",
    "bert.embeddings.token_type_embeddings.weight",
    "bert.embeddings.LayerNorm.gamma", "bert.embeddings.LayerNorm.beta",
    "bert.encoder.layer.0.attention.self.query.weight", "bert.encoder.layer.0.attention.self.query.bias",
    "bert.encoder.layer.0.attention.self.key.weight", "bert.encoder.layer.0.attention.self.key.bias",
    "bert.encoder.layer.0.attention.self.value.weight", "bert.encoder.layer.0.attention.self.value.bias",
    "bert.encoder.layer.0.attention.output.dense.weight", "bert.encoder.layer.0.attention.output.dense.bias",
    "bert.encoder.layer.0.attention.output.LayerNorm.gamma", "bert.encoder.layer.0.attention.output.LayerNorm.beta",
    "bert.encoder.layer.0.intermediate.dense.weight", "bert.encoder.layer.0.intermediate.dense.bias",
    "bert.encoder.layer.0.output.dense.weight", "bert.encoder.layer.0.output.dense.bias",
    "bert.encoder.layer.0.output.LayerNorm.gamma", "bert.encoder.layer.0.output.LayerNorm.beta",
    "bert.encoder.layer.1.attention.self.query.weight", "bert.encoder.layer.1.attention.self.query.bias",
    "bert.encoder.layer.1.attention.self.key.weight", "bert.encoder.layer.1.attention.self.key.bias",
    "bert.encoder.layer.1.attention.self.value.weight", "bert.encoder.layer.1.attention.self.value.bias",
    "bert.encoder.layer.1.attention.output.dense.weight", "bert.encoder.layer.1.attention.output.dense.bias",
    "bert.encoder.layer.1.attention.output.LayerNorm.gamma", "bert.encoder.layer.1.attention.output.LayerNorm.beta",
    "bert.encoder.layer.1.intermediate.dense.weight", "bert.encoder.layer.1.intermediate.dense.bias",
    "bert.encoder.layer.1.output.dense.weight", "bert.encoder.layer.1.output.dense.bias",
    "bert.encoder.layer.1.output.LayerNorm.gamma", "bert.encoder.layer.1.output.LayerNorm.beta",
    "bert.encoder.layer.2.attention.self.query.weight", "bert.encoder.layer.2.attention.self.query.bias",
    "bert.encoder.layer.2.attention.self.key.weight", "bert.encoder.layer.2.attention.self.key.bias",
    "bert.encoder.layer.2.attention.self.value.weight", "bert.encoder.layer.2.attention.self.value.bias",
    "bert.encoder.layer.2.attention.output.dense.weight", "bert.encoder.layer.2.attention.output.dense.bias",
    "bert.encoder.layer.2.attention.output.LayerNorm.gamma", "bert.encoder.layer.2.attention.output.LayerNorm.beta",
    "bert.encoder.layer.2.intermediate.dense.weight", "bert.encoder.layer.2.intermediate.dense.bias",
    "bert.encoder.layer.2.output.dense.weight", "bert.encoder.layer.2.output.dense.bias",
    "bert.encoder.layer.2.output.LayerNorm.gamma", "bert.encoder.layer.2.output.LayerNorm.beta",
    "bert.encoder.layer.3.attention.self.query.weight", "bert.encoder.layer.3.attention.self.query.bias",
    "bert.encoder.layer.3.attention.self.key.weight", "bert.encoder.layer.3.attention.self.key.bias",
    "bert.encoder.layer.3.attention.self.value.weight", "bert.encoder.layer.3.attention.self.value.bias",
    "bert.encoder.layer.3.attention.output.dense.weight", "bert.encoder.layer.3.attention.output.dense.bias",
    "bert.encoder.layer.3.attention.output.LayerNorm.gamma", "bert.encoder.layer.3.attention.output.LayerNorm.beta",
    "bert.encoder.layer.3.intermediate.dense.weight", "bert.encoder.layer.3.intermediate.dense.bias",
    "bert.encoder.layer.3.output.dense.weight", "bert.encoder.layer.3.output.dense.bias",
    "bert.encoder.layer.3.output.LayerNorm.gamma", "bert.encoder.layer.3.output.LayerNorm.beta",
    "bert.encoder.layer.4.attention.self.query.weight", "bert.encoder.layer.4.attention.self.query.bias",
    "bert.encoder.layer.4.attention.self.key.weight", "bert.encoder.layer.4.attention.self.key.bias",
    "bert.encoder.layer.4.attention.self.value.weight", "bert.encoder.layer.4.attention.self.value.bias",
    "bert.encoder.layer.4.attention.output.dense.weight", "bert.encoder.layer.4.attention.output.dense.bias",
    "bert.encoder.layer.4.attention.output.LayerNorm.gamma", "bert.encoder.layer.4.attention.output.LayerNorm.beta",
    "bert.encoder.layer.4.intermediate.dense.weight", "bert.encoder.layer.4.intermediate.dense.bias",
    "bert.encoder.layer.4.output.dense.weight", "bert.encoder.layer.4.output.dense.bias",
    "bert.encoder.layer.4.output.LayerNorm.gamma", "bert.encoder.layer.4.output.LayerNorm.beta",
    "bert.encoder.layer.5.attention.self.query.weight", "bert.encoder.layer.5.attention.self.query.bias",
    "bert.encoder.layer.5.attention.self.key.weight", "bert.encoder.layer.5.attention.self.key.bias",
    "bert.encoder.layer.5.attention.self.value.weight", "bert.encoder.layer.5.attention.self.value.bias",
    "bert.encoder.layer.5.attention.output.dense.weight", "bert.encoder.layer.5.attention.output.dense.bias",
    "bert.encoder.layer.5.attention.output.LayerNorm.gamma", "bert.encoder.layer.5.attention.output.LayerNorm.beta",
    "bert.encoder.layer.5.intermediate.dense.weight", "bert.encoder.layer.5.intermediate.dense.bias",
    "bert.encoder.layer.5.output.dense.weight", "bert.encoder.layer.5.output.dense.bias",
    "bert.encoder.layer.5.output.LayerNorm.gamma", "bert.encoder.layer.5.output.LayerNorm.beta",
    "bert.encoder.layer.6.attention.self.query.weight", "bert.encoder.layer.6.attention.self.query.bias",
    "bert.encoder.layer.6.attention.self.key.weight", "bert.encoder.layer.6.attention.self.key.bias",
    "bert.encoder.layer.6.attention.self.value.weight", "bert.encoder.layer.6.attention.self.value.bias",
    "bert.encoder.layer.6.attention.output.dense.weight", "bert.encoder.layer.6.attention.output.dense.bias",
    "bert.encoder.layer.6.attention.output.LayerNorm.gamma", "bert.encoder.layer.6.attention.output.LayerNorm.beta",
    "bert.encoder.layer.6.intermediate.dense.weight", "bert.encoder.layer.6.intermediate.dense.bias",
    "bert.encoder.layer.6.output.dense.weight", "bert.encoder.layer.6.output.dense.bias",
    "bert.encoder.layer.6.output.LayerNorm.gamma", "bert.encoder.layer.6.output.LayerNorm.beta",
    "bert.encoder.layer.7.attention.self.query.weight", "bert.encoder.layer.7.attention.self.query.bias",
    "bert.encoder.layer.7.attention.self.key.weight", "bert.encoder.layer.7.attention.self.key.bias",
    "bert.encoder.layer.7.attention.self.value.weight", "bert.encoder.layer.7.attention.self.value.bias",
    "bert.encoder.layer.7.attention.output.dense.weight", "bert.encoder.layer.7.attention.output.dense.bias",
    "bert.encoder.layer.7.attention.output.LayerNorm.gamma", "bert.encoder.layer.7.attention.output.LayerNorm.beta",
    "bert.encoder.layer.7.intermediate.dense.weight", "bert.encoder.layer.7.intermediate.dense.bias",
    "bert.encoder.layer.7.output.dense.weight", "bert.encoder.layer.7.output.dense.bias",
    "bert.encoder.layer.7.output.LayerNorm.gamma", "bert.encoder.layer.7.output.LayerNorm.beta",
    "bert.encoder.layer.8.attention.self.query.weight", "bert.encoder.layer.8.attention.self.query.bias",
    "bert.encoder.layer.8.attention.self.key.weight", "bert.encoder.layer.8.attention.self.key.bias",
    "bert.encoder.layer.8.attention.self.value.weight", "bert.encoder.layer.8.attention.self.value.bias",
    "bert.encoder.layer.8.attention.output.dense.weight", "bert.encoder.layer.8.attention.output.dense.bias",
    "bert.encoder.layer.8.attention.output.LayerNorm.gamma", "bert.encoder.layer.8.attention.output.LayerNorm.beta",
    "bert.encoder.layer.8.intermediate.dense.weight", "bert.encoder.layer.8.intermediate.dense.bias",
    "bert.encoder.layer.8.output.dense.weight", "bert.encoder.layer.8.output.dense.bias",
    "bert.encoder.layer.8.output.LayerNorm.gamma", "bert.encoder.layer.8.output.LayerNorm.beta",
    "bert.encoder.layer.9.attention.self.query.weight", "bert.encoder.layer.9.attention.self.query.bias",
    "bert.encoder.layer.9.attention.self.key.weight", "bert.encoder.layer.9.attention.self.key.bias",
    "bert.encoder.layer.9.attention.self.value.weight", "bert.encoder.layer.9.attention.self.value.bias",
    "bert.encoder.layer.9.attention.output.dense.weight", "bert.encoder.layer.9.attention.output.dense.bias",
    "bert.encoder.layer.9.attention.output.LayerNorm.gamma", "bert.encoder.layer.9.attention.output.LayerNorm.beta",
    "bert.encoder.layer.9.intermediate.dense.weight", "bert.encoder.layer.9.intermediate.dense.bias",
    "bert.encoder.layer.9.output.dense.weight", "bert.encoder.layer.9.output.dense.bias",
    "bert.encoder.layer.9.output.LayerNorm.gamma", "bert.encoder.layer.9.output.LayerNorm.beta",
    "bert.encoder.layer.10.attention.self.query.weight", "bert.encoder.layer.10.attention.self.query.bias",
    "bert.encoder.layer.10.attention.self.key.weight", "bert.encoder.layer.10.attention.self.key.bias",
    "bert.encoder.layer.10.attention.self.value.weight", "bert.encoder.layer.10.attention.self.value.bias",
    "bert.encoder.layer.10.attention.output.dense.weight", "bert.encoder.layer.10.attention.output.dense.bias",
    "bert.encoder.layer.10.attention.output.LayerNorm.gamma", "bert.encoder.layer.10.attention.output.LayerNorm.beta",
    "bert.encoder.layer.10.intermediate.dense.weight", "bert.encoder.layer.10.intermediate.dense.bias",
    "bert.encoder.layer.10.output.dense.weight", "bert.encoder.layer.10.output.dense.bias",
    "bert.encoder.layer.10.output.LayerNorm.gamma", "bert.encoder.layer.10.output.LayerNorm.beta",
    "bert.encoder.layer.11.attention.self.query.weight", "bert.encoder.layer.11.attention.self.query.bias",
    "bert.encoder.layer.11.attention.self.key.weight", "bert.encoder.layer.11.attention.self.key.bias",
    "bert.encoder.layer.11.attention.self.value.weight", "bert.encoder.layer.11.attention.self.value.bias",
    "bert.encoder.layer.11.attention.output.dense.weight", "bert.encoder.layer.11.attention.output.dense.bias",
    "bert.encoder.layer.11.attention.output.LayerNorm.gamma", "bert.encoder.layer.11.attention.output.LayerNorm.beta",
    "bert.encoder.layer.11.intermediate.dense.weight", "bert.encoder.layer.11.intermediate.dense.bias",
    "bert.encoder.layer.11.output.dense.weight", "bert.encoder.layer.11.output.dense.bias",
    "bert.encoder.layer.11.output.LayerNorm.gamma", "bert.encoder.layer.11.output.LayerNorm.beta",
    "bert.pooler.dense.weight", "bert.pooler.dense.bias",
    "cls.predictions.bias",
    "cls.predictions.transform.dense.weight", "cls.predictions.transform.dense.bias",
    "cls.predictions.transform.LayerNorm.gamma", "cls.predictions.transform.LayerNorm.beta",
    "cls.predictions.decoder.weight",
    "cls.seq_relationship.weight", "cls.seq_relationship.bias",
]
for key in list(state_dict.keys()):
    # new_key = None
    # if 'gamma' in key:
    #     new_key = key.replace('gamma', 'weight')
    # if 'beta' in key:
    #     new_key = key.replace('beta', 'bias')
    # if 'bert.' in key:
    #     new_key = key.replace('bert.', '')
    if 'cls.' in key:
        state_dict.pop(key)
    # if new_key:
    #     old_keys.append(key)
    #     new_keys.append(new_key)
# zip stops at the shorter list, so the trailing cls.* entries of old_keys are never paired.
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)
model.bert.load_state_dict(state_dict)

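The two long lists above follow a regular pattern, so they could also be generated rather than typed out. A minimal sketch, assuming the lists contain exactly the 12-layer pattern they appear to (five embedding entries, sixteen entries per encoder layer, two pooler entries); `PER_LAYER`, `param_names` and `bert_base_keys` are hypothetical names introduced here for illustration.

```python
# Sub-modules of one encoder layer, in the order used by the lists in this thread.
PER_LAYER = [
    "attention.self.query", "attention.self.key", "attention.self.value",
    "attention.output.dense", "attention.output.LayerNorm",
    "intermediate.dense", "output.dense", "output.LayerNorm",
]


def param_names(module):
    """LayerNorm entries use gamma/beta in these lists; everything else weight/bias."""
    if module.endswith("LayerNorm"):
        return [f"{module}.gamma", f"{module}.beta"]
    return [f"{module}.weight", f"{module}.bias"]


def bert_base_keys(num_layers=12):
    """Build the un-prefixed key names for the embeddings, encoder layers and pooler."""
    keys = [
        "embeddings.word_embeddings.weight",
        "embeddings.position_embeddings.weight",
        "embeddings.token_type_embeddings.weight",
        "embeddings.LayerNorm.gamma",
        "embeddings.LayerNorm.beta",
    ]
    for i in range(num_layers):
        for module in PER_LAYER:
            keys += param_names(f"encoder.layer.{i}.{module}")
    keys += ["pooler.dense.weight", "pooler.dense.bias"]
    return keys


new_keys = bert_base_keys()
old_keys = ["bert." + k for k in new_keys]  # checkpoint names carry the extra prefix
print(len(new_keys), new_keys[5], old_keys[-1])
```

Since every new name is just the old name without `bert.`, the pairing produced by `zip(old_keys, new_keys)` is equivalent to stripping that prefix from each key.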

I handled it the way you did, but it still raises an error.

@monster-fn
Following the modified code above (the version that passes strict=False), I ran into a new problem:
RuntimeError: CUDA error: invalid device ordinal
