diff --git a/gconattn.py b/gconattn.py index 0c0bc3d..56f9e45 100644 --- a/gconattn.py +++ b/gconattn.py @@ -122,9 +122,6 @@ def train(args): concept_dim=args.cpt_out_dim, concept_in_dim=concept_dim, freeze_ent_emb=args.freeze_ent_emb, pretrained_concept_emb=cp_emb, hidden_dim=args.decoder_hidden_dim, dropout=args.dropoutm, encoder_config=lstm_config) - if args.freeze_ent_emb: - freeze_net(model.decoder.concept_emb) - try: model.to(device) except RuntimeError as e: diff --git a/kvmem.py b/kvmem.py index 546d807..f75d158 100644 --- a/kvmem.py +++ b/kvmem.py @@ -161,9 +161,6 @@ def train(args): decoder_emb_p=args.d_dropoute, decoder_hidden_p=args.d_dropoutr, decoder_mlp_p=args.d_dropoutm, encoder_config=lstm_config) - if args.freeze_ent_emb: - freeze_net(model.decoder.concept_emb) - model.encoder.to(device0) model.decoder.to(device1) except RuntimeError as e: diff --git a/rgcn.py b/rgcn.py index a6679a2..4f7e413 100644 --- a/rgcn.py +++ b/rgcn.py @@ -118,8 +118,6 @@ def train(args): num_attention_heads=args.att_head_num, fc_dim=args.fc_dim, num_fc_layers=args.fc_layer_num, p_gnn=args.dropoutg, p_fc=args.dropoutf, freeze_ent_emb=args.freeze_ent_emb, pretrained_concept_emb=cp_emb, diag_decompose=args.diag_decompose, ablation=args.ablation, encoder_config=lstm_config) - if args.freeze_ent_emb: - freeze_net(model.decoder.concept_emb) model.to(device) except RuntimeError as e: print(e)